Improve Hooks
This commit is contained in:
		
							parent
							
								
									95a6539331
								
							
						
					
					
						commit
						befef0c9a9
					
				| @ -10,7 +10,7 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func SaveBeforeAssociations(db *gorm.DB) { | func SaveBeforeAssociations(db *gorm.DB) { | ||||||
| 	if db.Statement.Schema != nil { | 	if db.Error == nil && db.Statement.Schema != nil { | ||||||
| 		selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) | 		selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) | ||||||
| 
 | 
 | ||||||
| 		// Save Belongs To associations
 | 		// Save Belongs To associations
 | ||||||
| @ -83,7 +83,7 @@ func SaveBeforeAssociations(db *gorm.DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func SaveAfterAssociations(db *gorm.DB) { | func SaveAfterAssociations(db *gorm.DB) { | ||||||
| 	if db.Statement.Schema != nil { | 	if db.Error == nil && db.Statement.Schema != nil { | ||||||
| 		selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) | 		selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) | ||||||
| 
 | 
 | ||||||
| 		// Save Has One associations
 | 		// Save Has One associations
 | ||||||
|  | |||||||
| @ -9,20 +9,21 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func BeforeCreate(db *gorm.DB) { | func BeforeCreate(db *gorm.DB) { | ||||||
| 	if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { | 	if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { | ||||||
|  | 		tx := db.Session(&gorm.Session{}) | ||||||
| 		callMethod := func(value interface{}) bool { | 		callMethod := func(value interface{}) bool { | ||||||
| 			var ok bool | 			var ok bool | ||||||
| 			if db.Statement.Schema.BeforeSave { | 			if db.Statement.Schema.BeforeSave { | ||||||
| 				if i, ok := value.(gorm.BeforeSaveInterface); ok { | 				if i, ok := value.(gorm.BeforeSaveInterface); ok { | ||||||
| 					ok = true | 					ok = true | ||||||
| 					i.BeforeSave(db) | 					db.AddError(i.BeforeSave(tx)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			if db.Statement.Schema.BeforeCreate { | 			if db.Statement.Schema.BeforeCreate { | ||||||
| 				if i, ok := value.(gorm.BeforeCreateInterface); ok { | 				if i, ok := value.(gorm.BeforeCreateInterface); ok { | ||||||
| 					ok = true | 					ok = true | ||||||
| 					i.BeforeCreate(db) | 					db.AddError(i.BeforeCreate(tx)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 			return ok | 			return ok | ||||||
| @ -31,7 +32,7 @@ func BeforeCreate(db *gorm.DB) { | |||||||
| 		if ok := callMethod(db.Statement.Dest); !ok { | 		if ok := callMethod(db.Statement.Dest); !ok { | ||||||
| 			switch db.Statement.ReflectValue.Kind() { | 			switch db.Statement.ReflectValue.Kind() { | ||||||
| 			case reflect.Slice, reflect.Array: | 			case reflect.Slice, reflect.Array: | ||||||
| 				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { | 				for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||||
| 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | ||||||
| 				} | 				} | ||||||
| 			case reflect.Struct: | 			case reflect.Struct: | ||||||
| @ -46,48 +47,127 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 		return CreateWithReturning | 		return CreateWithReturning | ||||||
| 	} else { | 	} else { | ||||||
| 		return func(db *gorm.DB) { | 		return func(db *gorm.DB) { | ||||||
| 			if db.Statement.Schema != nil && !db.Statement.Unscoped { | 			if db.Error == nil { | ||||||
| 				for _, c := range db.Statement.Schema.CreateClauses { | 				if db.Statement.Schema != nil && !db.Statement.Unscoped { | ||||||
| 					db.Statement.AddClause(c) | 					for _, c := range db.Statement.Schema.CreateClauses { | ||||||
| 				} | 						db.Statement.AddClause(c) | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			if db.Statement.SQL.String() == "" { |  | ||||||
| 				db.Statement.AddClauseIfNotExists(clause.Insert{ |  | ||||||
| 					Table: clause.Table{Name: db.Statement.Table}, |  | ||||||
| 				}) |  | ||||||
| 				db.Statement.AddClause(ConvertToCreateValues(db.Statement)) |  | ||||||
| 
 |  | ||||||
| 				db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) |  | ||||||
| 
 |  | ||||||
| 			if err == nil { |  | ||||||
| 				if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { |  | ||||||
| 					if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { |  | ||||||
| 						if insertID, err := result.LastInsertId(); err == nil { |  | ||||||
| 							switch db.Statement.ReflectValue.Kind() { |  | ||||||
| 							case reflect.Slice, reflect.Array: |  | ||||||
| 								if config.LastInsertIDReversed { |  | ||||||
| 									for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { |  | ||||||
| 										db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) |  | ||||||
| 										insertID-- |  | ||||||
| 									} |  | ||||||
| 								} else { |  | ||||||
| 									for i := 0; i < db.Statement.ReflectValue.Len(); i++ { |  | ||||||
| 										db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) |  | ||||||
| 										insertID++ |  | ||||||
| 									} |  | ||||||
| 								} |  | ||||||
| 							case reflect.Struct: |  | ||||||
| 								db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) |  | ||||||
| 							} |  | ||||||
| 						} else { |  | ||||||
| 							db.AddError(err) |  | ||||||
| 						} |  | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
|  | 
 | ||||||
|  | 				if db.Statement.SQL.String() == "" { | ||||||
|  | 					db.Statement.AddClauseIfNotExists(clause.Insert{ | ||||||
|  | 						Table: clause.Table{Name: db.Statement.Table}, | ||||||
|  | 					}) | ||||||
|  | 					db.Statement.AddClause(ConvertToCreateValues(db.Statement)) | ||||||
|  | 
 | ||||||
|  | 					db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
|  | 
 | ||||||
|  | 				if err == nil { | ||||||
|  | 					if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { | ||||||
|  | 						if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { | ||||||
|  | 							if insertID, err := result.LastInsertId(); err == nil { | ||||||
|  | 								switch db.Statement.ReflectValue.Kind() { | ||||||
|  | 								case reflect.Slice, reflect.Array: | ||||||
|  | 									if config.LastInsertIDReversed { | ||||||
|  | 										for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { | ||||||
|  | 											db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) | ||||||
|  | 											insertID-- | ||||||
|  | 										} | ||||||
|  | 									} else { | ||||||
|  | 										for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||||
|  | 											db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) | ||||||
|  | 											insertID++ | ||||||
|  | 										} | ||||||
|  | 									} | ||||||
|  | 								case reflect.Struct: | ||||||
|  | 									db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) | ||||||
|  | 								} | ||||||
|  | 							} else { | ||||||
|  | 								db.AddError(err) | ||||||
|  | 							} | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 					db.RowsAffected, _ = result.RowsAffected() | ||||||
|  | 				} else { | ||||||
|  | 					db.AddError(err) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func CreateWithReturning(db *gorm.DB) { | ||||||
|  | 	if db.Error == nil { | ||||||
|  | 		if db.Statement.Schema != nil && !db.Statement.Unscoped { | ||||||
|  | 			for _, c := range db.Statement.Schema.CreateClauses { | ||||||
|  | 				db.Statement.AddClause(c) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if db.Statement.SQL.String() == "" { | ||||||
|  | 			db.Statement.AddClauseIfNotExists(clause.Insert{ | ||||||
|  | 				Table: clause.Table{Name: db.Statement.Table}, | ||||||
|  | 			}) | ||||||
|  | 			db.Statement.AddClause(ConvertToCreateValues(db.Statement)) | ||||||
|  | 
 | ||||||
|  | 			db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { | ||||||
|  | 			db.Statement.WriteString(" RETURNING ") | ||||||
|  | 
 | ||||||
|  | 			var ( | ||||||
|  | 				idx    int | ||||||
|  | 				fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) | ||||||
|  | 				values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) | ||||||
|  | 			) | ||||||
|  | 
 | ||||||
|  | 			for dbName, field := range sch.FieldsWithDefaultDBValue { | ||||||
|  | 				if idx != 0 { | ||||||
|  | 					db.Statement.WriteByte(',') | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				fields[idx] = field | ||||||
|  | 				db.Statement.WriteQuoted(dbName) | ||||||
|  | 				idx++ | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
|  | 
 | ||||||
|  | 			if err == nil { | ||||||
|  | 				defer rows.Close() | ||||||
|  | 
 | ||||||
|  | 				switch db.Statement.ReflectValue.Kind() { | ||||||
|  | 				case reflect.Slice, reflect.Array: | ||||||
|  | 					for rows.Next() { | ||||||
|  | 						for idx, field := range fields { | ||||||
|  | 							values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() | ||||||
|  | 						} | ||||||
|  | 						if err := rows.Scan(values...); err != nil { | ||||||
|  | 							db.AddError(err) | ||||||
|  | 						} | ||||||
|  | 						db.RowsAffected++ | ||||||
|  | 					} | ||||||
|  | 				case reflect.Struct: | ||||||
|  | 					for idx, field := range fields { | ||||||
|  | 						values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | 					if rows.Next() { | ||||||
|  | 						db.RowsAffected++ | ||||||
|  | 						err = rows.Scan(values...) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if err != nil { | ||||||
|  | 				db.AddError(err) | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { | ||||||
| 				db.RowsAffected, _ = result.RowsAffected() | 				db.RowsAffected, _ = result.RowsAffected() | ||||||
| 			} else { | 			} else { | ||||||
| 				db.AddError(err) | 				db.AddError(err) | ||||||
| @ -96,96 +176,22 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func CreateWithReturning(db *gorm.DB) { |  | ||||||
| 	if db.Statement.Schema != nil && !db.Statement.Unscoped { |  | ||||||
| 		for _, c := range db.Statement.Schema.CreateClauses { |  | ||||||
| 			db.Statement.AddClause(c) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if db.Statement.SQL.String() == "" { |  | ||||||
| 		db.Statement.AddClauseIfNotExists(clause.Insert{ |  | ||||||
| 			Table: clause.Table{Name: db.Statement.Table}, |  | ||||||
| 		}) |  | ||||||
| 		db.Statement.AddClause(ConvertToCreateValues(db.Statement)) |  | ||||||
| 
 |  | ||||||
| 		db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { |  | ||||||
| 		db.Statement.WriteString(" RETURNING ") |  | ||||||
| 
 |  | ||||||
| 		var ( |  | ||||||
| 			idx    int |  | ||||||
| 			fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) |  | ||||||
| 			values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) |  | ||||||
| 		) |  | ||||||
| 
 |  | ||||||
| 		for dbName, field := range sch.FieldsWithDefaultDBValue { |  | ||||||
| 			if idx != 0 { |  | ||||||
| 				db.Statement.WriteByte(',') |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			fields[idx] = field |  | ||||||
| 			db.Statement.WriteQuoted(dbName) |  | ||||||
| 			idx++ |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) |  | ||||||
| 
 |  | ||||||
| 		if err == nil { |  | ||||||
| 			defer rows.Close() |  | ||||||
| 
 |  | ||||||
| 			switch db.Statement.ReflectValue.Kind() { |  | ||||||
| 			case reflect.Slice, reflect.Array: |  | ||||||
| 				for rows.Next() { |  | ||||||
| 					for idx, field := range fields { |  | ||||||
| 						values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() |  | ||||||
| 					} |  | ||||||
| 					if err := rows.Scan(values...); err != nil { |  | ||||||
| 						db.AddError(err) |  | ||||||
| 					} |  | ||||||
| 					db.RowsAffected++ |  | ||||||
| 				} |  | ||||||
| 			case reflect.Struct: |  | ||||||
| 				for idx, field := range fields { |  | ||||||
| 					values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				if rows.Next() { |  | ||||||
| 					db.RowsAffected++ |  | ||||||
| 					err = rows.Scan(values...) |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if err != nil { |  | ||||||
| 			db.AddError(err) |  | ||||||
| 		} |  | ||||||
| 	} else { |  | ||||||
| 		if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { |  | ||||||
| 			db.RowsAffected, _ = result.RowsAffected() |  | ||||||
| 		} else { |  | ||||||
| 			db.AddError(err) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func AfterCreate(db *gorm.DB) { | func AfterCreate(db *gorm.DB) { | ||||||
| 	if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { | 	if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { | ||||||
|  | 		tx := db.Session(&gorm.Session{}) | ||||||
| 		callMethod := func(value interface{}) bool { | 		callMethod := func(value interface{}) bool { | ||||||
| 			var ok bool | 			var ok bool | ||||||
| 			if db.Statement.Schema.AfterSave { | 			if db.Statement.Schema.AfterSave { | ||||||
| 				if i, ok := value.(gorm.AfterSaveInterface); ok { | 				if i, ok := value.(gorm.AfterSaveInterface); ok { | ||||||
| 					ok = true | 					ok = true | ||||||
| 					i.AfterSave(db) | 					db.AddError(i.AfterSave(tx)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			if db.Statement.Schema.AfterCreate { | 			if db.Statement.Schema.AfterCreate { | ||||||
| 				if i, ok := value.(gorm.AfterCreateInterface); ok { | 				if i, ok := value.(gorm.AfterCreateInterface); ok { | ||||||
| 					ok = true | 					ok = true | ||||||
| 					i.AfterCreate(db) | 					db.AddError(i.AfterCreate(tx)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 			return ok | 			return ok | ||||||
| @ -194,7 +200,7 @@ func AfterCreate(db *gorm.DB) { | |||||||
| 		if ok := callMethod(db.Statement.Dest); !ok { | 		if ok := callMethod(db.Statement.Dest); !ok { | ||||||
| 			switch db.Statement.ReflectValue.Kind() { | 			switch db.Statement.ReflectValue.Kind() { | ||||||
| 			case reflect.Slice, reflect.Array: | 			case reflect.Slice, reflect.Array: | ||||||
| 				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { | 				for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||||
| 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | ||||||
| 				} | 				} | ||||||
| 			case reflect.Struct: | 			case reflect.Struct: | ||||||
|  | |||||||
| @ -9,11 +9,12 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func BeforeDelete(db *gorm.DB) { | func BeforeDelete(db *gorm.DB) { | ||||||
| 	if db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { | 	if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { | ||||||
|  | 		tx := db.Session(&gorm.Session{}) | ||||||
| 		callMethod := func(value interface{}) bool { | 		callMethod := func(value interface{}) bool { | ||||||
| 			if db.Statement.Schema.BeforeDelete { | 			if db.Statement.Schema.BeforeDelete { | ||||||
| 				if i, ok := value.(gorm.BeforeDeleteInterface); ok { | 				if i, ok := value.(gorm.BeforeDeleteInterface); ok { | ||||||
| 					i.BeforeDelete(db) | 					db.AddError(i.BeforeDelete(tx)) | ||||||
| 					return true | 					return true | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| @ -23,7 +24,7 @@ func BeforeDelete(db *gorm.DB) { | |||||||
| 		if ok := callMethod(db.Statement.Dest); !ok { | 		if ok := callMethod(db.Statement.Dest); !ok { | ||||||
| 			switch db.Statement.ReflectValue.Kind() { | 			switch db.Statement.ReflectValue.Kind() { | ||||||
| 			case reflect.Slice, reflect.Array: | 			case reflect.Slice, reflect.Array: | ||||||
| 				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { | 				for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||||
| 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | ||||||
| 				} | 				} | ||||||
| 			case reflect.Struct: | 			case reflect.Struct: | ||||||
| @ -34,57 +35,60 @@ func BeforeDelete(db *gorm.DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Delete(db *gorm.DB) { | func Delete(db *gorm.DB) { | ||||||
| 	if db.Statement.Schema != nil && !db.Statement.Unscoped { | 	if db.Error == nil { | ||||||
| 		for _, c := range db.Statement.Schema.DeleteClauses { | 		if db.Statement.Schema != nil && !db.Statement.Unscoped { | ||||||
| 			db.Statement.AddClause(c) | 			for _, c := range db.Statement.Schema.DeleteClauses { | ||||||
| 		} | 				db.Statement.AddClause(c) | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if db.Statement.SQL.String() == "" { |  | ||||||
| 		db.Statement.AddClauseIfNotExists(clause.Delete{}) |  | ||||||
| 
 |  | ||||||
| 		if db.Statement.Schema != nil { |  | ||||||
| 			_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) |  | ||||||
| 			column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) |  | ||||||
| 
 |  | ||||||
| 			if len(values) > 0 { |  | ||||||
| 				db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) |  | ||||||
| 			} | 			} | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 			if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { | 		if db.Statement.SQL.String() == "" { | ||||||
| 				_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) | 			db.Statement.AddClauseIfNotExists(clause.Delete{}) | ||||||
| 				column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) | 
 | ||||||
|  | 			if db.Statement.Schema != nil { | ||||||
|  | 				_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) | ||||||
|  | 				column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) | ||||||
| 
 | 
 | ||||||
| 				if len(values) > 0 { | 				if len(values) > 0 { | ||||||
| 					db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) | 					db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) | ||||||
| 				} | 				} | ||||||
|  | 
 | ||||||
|  | 				if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { | ||||||
|  | 					_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) | ||||||
|  | 					column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) | ||||||
|  | 
 | ||||||
|  | 					if len(values) > 0 { | ||||||
|  | 						db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
|  | 
 | ||||||
|  | 			if _, ok := db.Statement.Clauses["WHERE"]; !ok { | ||||||
|  | 				db.AddError(gorm.ErrMissingWhereClause) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			db.Statement.AddClauseIfNotExists(clause.From{}) | ||||||
|  | 			db.Statement.Build("DELETE", "FROM", "WHERE") | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if _, ok := db.Statement.Clauses["WHERE"]; !ok { | 		result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
| 			db.AddError(gorm.ErrMissingWhereClause) | 
 | ||||||
| 			return | 		if err == nil { | ||||||
|  | 			db.RowsAffected, _ = result.RowsAffected() | ||||||
|  | 		} else { | ||||||
|  | 			db.AddError(err) | ||||||
| 		} | 		} | ||||||
| 
 |  | ||||||
| 		db.Statement.AddClauseIfNotExists(clause.From{}) |  | ||||||
| 		db.Statement.Build("DELETE", "FROM", "WHERE") |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) |  | ||||||
| 
 |  | ||||||
| 	if err == nil { |  | ||||||
| 		db.RowsAffected, _ = result.RowsAffected() |  | ||||||
| 	} else { |  | ||||||
| 		db.AddError(err) |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func AfterDelete(db *gorm.DB) { | func AfterDelete(db *gorm.DB) { | ||||||
| 	if db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { | 	if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { | ||||||
|  | 		tx := db.Session(&gorm.Session{}) | ||||||
| 		callMethod := func(value interface{}) bool { | 		callMethod := func(value interface{}) bool { | ||||||
| 			if db.Statement.Schema.AfterDelete { | 			if db.Statement.Schema.AfterDelete { | ||||||
| 				if i, ok := value.(gorm.AfterDeleteInterface); ok { | 				if i, ok := value.(gorm.AfterDeleteInterface); ok { | ||||||
| 					i.AfterDelete(db) | 					db.AddError(i.AfterDelete(tx)) | ||||||
| 					return true | 					return true | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| @ -94,7 +98,7 @@ func AfterDelete(db *gorm.DB) { | |||||||
| 		if ok := callMethod(db.Statement.Dest); !ok { | 		if ok := callMethod(db.Statement.Dest); !ok { | ||||||
| 			switch db.Statement.ReflectValue.Kind() { | 			switch db.Statement.ReflectValue.Kind() { | ||||||
| 			case reflect.Slice, reflect.Array: | 			case reflect.Slice, reflect.Array: | ||||||
| 				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { | 				for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||||
| 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | ||||||
| 				} | 				} | ||||||
| 			case reflect.Struct: | 			case reflect.Struct: | ||||||
|  | |||||||
| @ -12,24 +12,26 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func Query(db *gorm.DB) { | func Query(db *gorm.DB) { | ||||||
| 	if db.Statement.Schema != nil && !db.Statement.Unscoped { | 	if db.Error == nil { | ||||||
| 		for _, c := range db.Statement.Schema.QueryClauses { | 		if db.Statement.Schema != nil && !db.Statement.Unscoped { | ||||||
| 			db.Statement.AddClause(c) | 			for _, c := range db.Statement.Schema.QueryClauses { | ||||||
|  | 				db.Statement.AddClause(c) | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	if db.Statement.SQL.String() == "" { | 		if db.Statement.SQL.String() == "" { | ||||||
| 		BuildQuerySQL(db) | 			BuildQuerySQL(db) | ||||||
| 	} | 		} | ||||||
| 
 | 
 | ||||||
| 	rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 		rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
| 	if err != nil { | 		if err != nil { | ||||||
| 		db.AddError(err) | 			db.AddError(err) | ||||||
| 		return | 			return | ||||||
| 	} | 		} | ||||||
| 	defer rows.Close() | 		defer rows.Close() | ||||||
| 
 | 
 | ||||||
| 	gorm.Scan(rows, db, false) | 		gorm.Scan(rows, db, false) | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func BuildQuerySQL(db *gorm.DB) { | func BuildQuerySQL(db *gorm.DB) { | ||||||
| @ -129,50 +131,53 @@ func BuildQuerySQL(db *gorm.DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Preload(db *gorm.DB) { | func Preload(db *gorm.DB) { | ||||||
| 	if len(db.Statement.Preloads) > 0 { | 	if db.Error == nil { | ||||||
| 		preloadMap := map[string][]string{} | 		if len(db.Statement.Preloads) > 0 { | ||||||
| 		for name := range db.Statement.Preloads { | 			preloadMap := map[string][]string{} | ||||||
| 			preloadFields := strings.Split(name, ".") | 			for name := range db.Statement.Preloads { | ||||||
| 			for idx := range preloadFields { | 				preloadFields := strings.Split(name, ".") | ||||||
| 				preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] | 				for idx := range preloadFields { | ||||||
| 			} | 					preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		preloadNames := make([]string, len(preloadMap)) |  | ||||||
| 		idx := 0 |  | ||||||
| 		for key := range preloadMap { |  | ||||||
| 			preloadNames[idx] = key |  | ||||||
| 			idx++ |  | ||||||
| 		} |  | ||||||
| 		sort.Strings(preloadNames) |  | ||||||
| 
 |  | ||||||
| 		for _, name := range preloadNames { |  | ||||||
| 			var ( |  | ||||||
| 				curSchema     = db.Statement.Schema |  | ||||||
| 				preloadFields = preloadMap[name] |  | ||||||
| 				rels          = make([]*schema.Relationship, len(preloadFields)) |  | ||||||
| 			) |  | ||||||
| 
 |  | ||||||
| 			for idx, preloadField := range preloadFields { |  | ||||||
| 				if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { |  | ||||||
| 					rels[idx] = rel |  | ||||||
| 					curSchema = rel.FieldSchema |  | ||||||
| 				} else { |  | ||||||
| 					db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) |  | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			preload(db, rels, db.Statement.Preloads[name]) | 			preloadNames := make([]string, len(preloadMap)) | ||||||
|  | 			idx := 0 | ||||||
|  | 			for key := range preloadMap { | ||||||
|  | 				preloadNames[idx] = key | ||||||
|  | 				idx++ | ||||||
|  | 			} | ||||||
|  | 			sort.Strings(preloadNames) | ||||||
|  | 
 | ||||||
|  | 			for _, name := range preloadNames { | ||||||
|  | 				var ( | ||||||
|  | 					curSchema     = db.Statement.Schema | ||||||
|  | 					preloadFields = preloadMap[name] | ||||||
|  | 					rels          = make([]*schema.Relationship, len(preloadFields)) | ||||||
|  | 				) | ||||||
|  | 
 | ||||||
|  | 				for idx, preloadField := range preloadFields { | ||||||
|  | 					if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { | ||||||
|  | 						rels[idx] = rel | ||||||
|  | 						curSchema = rel.FieldSchema | ||||||
|  | 					} else { | ||||||
|  | 						db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				preload(db, rels, db.Statement.Preloads[name]) | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func AfterQuery(db *gorm.DB) { | func AfterQuery(db *gorm.DB) { | ||||||
| 	if db.Statement.Schema != nil && db.Statement.Schema.AfterFind { | 	if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { | ||||||
|  | 		tx := db.Session(&gorm.Session{}) | ||||||
| 		callMethod := func(value interface{}) bool { | 		callMethod := func(value interface{}) bool { | ||||||
| 			if db.Statement.Schema.AfterFind { | 			if db.Statement.Schema.AfterFind { | ||||||
| 				if i, ok := value.(gorm.AfterFindInterface); ok { | 				if i, ok := value.(gorm.AfterFindInterface); ok { | ||||||
| 					i.AfterFind(db) | 					db.AddError(i.AfterFind(tx)) | ||||||
| 					return true | 					return true | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| @ -182,7 +187,7 @@ func AfterQuery(db *gorm.DB) { | |||||||
| 		if ok := callMethod(db.Statement.Dest); !ok { | 		if ok := callMethod(db.Statement.Dest); !ok { | ||||||
| 			switch db.Statement.ReflectValue.Kind() { | 			switch db.Statement.ReflectValue.Kind() { | ||||||
| 			case reflect.Slice, reflect.Array: | 			case reflect.Slice, reflect.Array: | ||||||
| 				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { | 				for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||||
| 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | ||||||
| 				} | 				} | ||||||
| 			case reflect.Struct: | 			case reflect.Struct: | ||||||
|  | |||||||
| @ -5,10 +5,12 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func RawExec(db *gorm.DB) { | func RawExec(db *gorm.DB) { | ||||||
| 	result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 	if db.Error == nil { | ||||||
| 	if err != nil { | 		result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
| 		db.AddError(err) | 		if err != nil { | ||||||
| 	} else { | 			db.AddError(err) | ||||||
| 		db.RowsAffected, _ = result.RowsAffected() | 		} else { | ||||||
|  | 			db.RowsAffected, _ = result.RowsAffected() | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -5,13 +5,15 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func RowQuery(db *gorm.DB) { | func RowQuery(db *gorm.DB) { | ||||||
| 	if db.Statement.SQL.String() == "" { | 	if db.Error == nil { | ||||||
| 		BuildQuerySQL(db) | 		if db.Statement.SQL.String() == "" { | ||||||
| 	} | 			BuildQuerySQL(db) | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 	if _, ok := db.Get("rows"); ok { | 		if _, ok := db.Get("rows"); ok { | ||||||
| 		db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 			db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
| 	} else { | 		} else { | ||||||
| 		db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 			db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -1,9 +1,25 @@ | |||||||
| package callbacks | package callbacks | ||||||
| 
 | 
 | ||||||
| import "github.com/jinzhu/gorm" | import ( | ||||||
|  | 	"github.com/jinzhu/gorm" | ||||||
|  | ) | ||||||
| 
 | 
 | ||||||
| func BeginTransaction(db *gorm.DB) { | func BeginTransaction(db *gorm.DB) { | ||||||
|  | 	if tx := db.Begin(); tx.Error == nil { | ||||||
|  | 		db.Statement.ConnPool = tx.Statement.ConnPool | ||||||
|  | 		tx.InstanceSet("gorm:started_transaction", true) | ||||||
|  | 	} else { | ||||||
|  | 		tx.Error = nil | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func CommitOrRollbackTransaction(db *gorm.DB) { | func CommitOrRollbackTransaction(db *gorm.DB) { | ||||||
|  | 	if _, ok := db.InstanceGet("gorm:started_transaction"); ok { | ||||||
|  | 		if db.Error == nil { | ||||||
|  | 			db.Commit() | ||||||
|  | 		} else { | ||||||
|  | 			db.Rollback() | ||||||
|  | 		} | ||||||
|  | 		db.Statement.ConnPool = db.ConnPool | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -10,20 +10,21 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func BeforeUpdate(db *gorm.DB) { | func BeforeUpdate(db *gorm.DB) { | ||||||
| 	if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { | 	if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { | ||||||
|  | 		tx := db.Session(&gorm.Session{}) | ||||||
| 		callMethod := func(value interface{}) bool { | 		callMethod := func(value interface{}) bool { | ||||||
| 			var ok bool | 			var ok bool | ||||||
| 			if db.Statement.Schema.BeforeSave { | 			if db.Statement.Schema.BeforeSave { | ||||||
| 				if i, ok := value.(gorm.BeforeSaveInterface); ok { | 				if i, ok := value.(gorm.BeforeSaveInterface); ok { | ||||||
| 					ok = true | 					ok = true | ||||||
| 					i.BeforeSave(db) | 					db.AddError(i.BeforeSave(tx)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			if db.Statement.Schema.BeforeUpdate { | 			if db.Statement.Schema.BeforeUpdate { | ||||||
| 				if i, ok := value.(gorm.BeforeUpdateInterface); ok { | 				if i, ok := value.(gorm.BeforeUpdateInterface); ok { | ||||||
| 					ok = true | 					ok = true | ||||||
| 					i.BeforeUpdate(db) | 					db.AddError(i.BeforeUpdate(tx)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 			return ok | 			return ok | ||||||
| @ -32,7 +33,7 @@ func BeforeUpdate(db *gorm.DB) { | |||||||
| 		if ok := callMethod(db.Statement.Dest); !ok { | 		if ok := callMethod(db.Statement.Dest); !ok { | ||||||
| 			switch db.Statement.ReflectValue.Kind() { | 			switch db.Statement.ReflectValue.Kind() { | ||||||
| 			case reflect.Slice, reflect.Array: | 			case reflect.Slice, reflect.Array: | ||||||
| 				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { | 				for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||||
| 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | ||||||
| 				} | 				} | ||||||
| 			case reflect.Struct: | 			case reflect.Struct: | ||||||
| @ -43,51 +44,54 @@ func BeforeUpdate(db *gorm.DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Update(db *gorm.DB) { | func Update(db *gorm.DB) { | ||||||
| 	if db.Statement.Schema != nil && !db.Statement.Unscoped { | 	if db.Error == nil { | ||||||
| 		for _, c := range db.Statement.Schema.UpdateClauses { | 		if db.Statement.Schema != nil && !db.Statement.Unscoped { | ||||||
| 			db.Statement.AddClause(c) | 			for _, c := range db.Statement.Schema.UpdateClauses { | ||||||
|  | 				db.Statement.AddClause(c) | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	if db.Statement.SQL.String() == "" { | 		if db.Statement.SQL.String() == "" { | ||||||
| 		db.Statement.AddClauseIfNotExists(clause.Update{}) | 			db.Statement.AddClauseIfNotExists(clause.Update{}) | ||||||
| 		if set := ConvertToAssignments(db.Statement); len(set) != 0 { | 			if set := ConvertToAssignments(db.Statement); len(set) != 0 { | ||||||
| 			db.Statement.AddClause(set) | 				db.Statement.AddClause(set) | ||||||
| 		} else { | 			} else { | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			db.Statement.Build("UPDATE", "SET", "WHERE") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if _, ok := db.Statement.Clauses["WHERE"]; !ok { | ||||||
|  | 			db.AddError(gorm.ErrMissingWhereClause) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		db.Statement.Build("UPDATE", "SET", "WHERE") |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	if _, ok := db.Statement.Clauses["WHERE"]; !ok { | 		result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
| 		db.AddError(gorm.ErrMissingWhereClause) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 		if err == nil { | ||||||
| 
 | 			db.RowsAffected, _ = result.RowsAffected() | ||||||
| 	if err == nil { | 		} else { | ||||||
| 		db.RowsAffected, _ = result.RowsAffected() | 			db.AddError(err) | ||||||
| 	} else { | 		} | ||||||
| 		db.AddError(err) |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func AfterUpdate(db *gorm.DB) { | func AfterUpdate(db *gorm.DB) { | ||||||
| 	if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { | 	if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { | ||||||
|  | 		tx := db.Session(&gorm.Session{}) | ||||||
| 		callMethod := func(value interface{}) bool { | 		callMethod := func(value interface{}) bool { | ||||||
| 			var ok bool | 			var ok bool | ||||||
| 			if db.Statement.Schema.AfterSave { | 			if db.Statement.Schema.AfterSave { | ||||||
| 				if i, ok := value.(gorm.AfterSaveInterface); ok { | 				if i, ok := value.(gorm.AfterSaveInterface); ok { | ||||||
| 					ok = true | 					ok = true | ||||||
| 					i.AfterSave(db) | 					db.AddError(i.AfterSave(tx)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			if db.Statement.Schema.AfterUpdate { | 			if db.Statement.Schema.AfterUpdate { | ||||||
| 				if i, ok := value.(gorm.AfterUpdateInterface); ok { | 				if i, ok := value.(gorm.AfterUpdateInterface); ok { | ||||||
| 					ok = true | 					ok = true | ||||||
| 					i.AfterUpdate(db) | 					db.AddError(i.AfterUpdate(tx)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 			return ok | 			return ok | ||||||
| @ -96,7 +100,7 @@ func AfterUpdate(db *gorm.DB) { | |||||||
| 		if ok := callMethod(db.Statement.Dest); !ok { | 		if ok := callMethod(db.Statement.Dest); !ok { | ||||||
| 			switch db.Statement.ReflectValue.Kind() { | 			switch db.Statement.ReflectValue.Kind() { | ||||||
| 			case reflect.Slice, reflect.Array: | 			case reflect.Slice, reflect.Array: | ||||||
| 				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { | 				for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||||
| 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | ||||||
| 				} | 				} | ||||||
| 			case reflect.Struct: | 			case reflect.Struct: | ||||||
|  | |||||||
| @ -16,7 +16,7 @@ var ( | |||||||
| 	// ErrNotImplemented not implemented
 | 	// ErrNotImplemented not implemented
 | ||||||
| 	ErrNotImplemented = errors.New("not implemented") | 	ErrNotImplemented = errors.New("not implemented") | ||||||
| 	// ErrMissingWhereClause missing where clause
 | 	// ErrMissingWhereClause missing where clause
 | ||||||
| 	ErrMissingWhereClause = errors.New("missing WHERE clause while deleting") | 	ErrMissingWhereClause = errors.New("WHERE conditions required") | ||||||
| 	// ErrUnsupportedRelation unsupported relations
 | 	// ErrUnsupportedRelation unsupported relations
 | ||||||
| 	ErrUnsupportedRelation = errors.New("unsupported relations") | 	ErrUnsupportedRelation = errors.New("unsupported relations") | ||||||
| 	// ErrPtrStructSupported only ptr of struct supported
 | 	// ErrPtrStructSupported only ptr of struct supported
 | ||||||
|  | |||||||
							
								
								
									
										108
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										108
									
								
								gorm.go
									
									
									
									
									
								
							| @ -40,14 +40,15 @@ type DB struct { | |||||||
| 	Error        error | 	Error        error | ||||||
| 	RowsAffected int64 | 	RowsAffected int64 | ||||||
| 	Statement    *Statement | 	Statement    *Statement | ||||||
| 	clone        bool | 	clone        int | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Session session config when create session with Session() method
 | // Session session config when create session with Session() method
 | ||||||
| type Session struct { | type Session struct { | ||||||
| 	Context context.Context | 	WithConditions bool | ||||||
| 	Logger  logger.Interface | 	Context        context.Context | ||||||
| 	NowFunc func() time.Time | 	Logger         logger.Interface | ||||||
|  | 	NowFunc        func() time.Time | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Open initialize db session based on dialector
 | // Open initialize db session based on dialector
 | ||||||
| @ -76,10 +77,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { | |||||||
| 		config.cacheStore = &sync.Map{} | 		config.cacheStore = &sync.Map{} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	db = &DB{ | 	db = &DB{Config: config, clone: 1} | ||||||
| 		Config: config, |  | ||||||
| 		clone:  true, |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	db.callbacks = initializeCallbacks(db) | 	db.callbacks = initializeCallbacks(db) | ||||||
| 
 | 
 | ||||||
| @ -96,38 +94,54 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { | |||||||
| // Session create new db session
 | // Session create new db session
 | ||||||
| func (db *DB) Session(config *Session) *DB { | func (db *DB) Session(config *Session) *DB { | ||||||
| 	var ( | 	var ( | ||||||
| 		tx       = db.getInstance() | 		txConfig = *db.Config | ||||||
| 		stmt     = tx.Statement.clone() | 		tx       = &DB{ | ||||||
| 		txConfig = *tx.Config | 			Config:    &txConfig, | ||||||
|  | 			Statement: db.Statement, | ||||||
|  | 			clone:     1, | ||||||
|  | 		} | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	if config.Context != nil { | 	if config.Context != nil { | ||||||
| 		stmt.Context = config.Context | 		if tx.Statement != nil { | ||||||
|  | 			tx.Statement = tx.Statement.clone() | ||||||
|  | 		} else { | ||||||
|  | 			tx.Statement = &Statement{ | ||||||
|  | 				DB:       tx, | ||||||
|  | 				Clauses:  map[string]clause.Clause{}, | ||||||
|  | 				ConnPool: tx.ConnPool, | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		tx.Statement.Context = config.Context | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if config.WithConditions { | ||||||
|  | 		tx.clone = 3 | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if config.Logger != nil { | 	if config.Logger != nil { | ||||||
| 		txConfig.Logger = config.Logger | 		tx.Config.Logger = config.Logger | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if config.NowFunc != nil { | 	if config.NowFunc != nil { | ||||||
| 		txConfig.NowFunc = config.NowFunc | 		tx.Config.NowFunc = config.NowFunc | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &DB{ | 	return tx | ||||||
| 		Config:    &txConfig, |  | ||||||
| 		Statement: stmt, |  | ||||||
| 		clone:     true, |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // WithContext change current instance db's context to ctx
 | // WithContext change current instance db's context to ctx
 | ||||||
| func (db *DB) WithContext(ctx context.Context) *DB { | func (db *DB) WithContext(ctx context.Context) *DB { | ||||||
| 	return db.Session(&Session{Context: ctx}) | 	return db.Session(&Session{WithConditions: true, Context: ctx}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Debug start debug mode
 | // Debug start debug mode
 | ||||||
| func (db *DB) Debug() (tx *DB) { | func (db *DB) Debug() (tx *DB) { | ||||||
| 	return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)}) | 	return db.Session(&Session{ | ||||||
|  | 		WithConditions: true, | ||||||
|  | 		Logger:         db.Logger.LogMode(logger.Info), | ||||||
|  | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Set store value with key into current db instance's context
 | // Set store value with key into current db instance's context
 | ||||||
| @ -145,6 +159,21 @@ func (db *DB) Get(key string) (interface{}, bool) { | |||||||
| 	return nil, false | 	return nil, false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // InstanceSet store value with key into current db instance's context
 | ||||||
|  | func (db *DB) InstanceSet(key string, value interface{}) *DB { | ||||||
|  | 	tx := db.getInstance() | ||||||
|  | 	tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value) | ||||||
|  | 	return tx | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // InstanceGet get value with key from current db instance's context
 | ||||||
|  | func (db *DB) InstanceGet(key string) (interface{}, bool) { | ||||||
|  | 	if db.Statement != nil { | ||||||
|  | 		return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) | ||||||
|  | 	} | ||||||
|  | 	return nil, false | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Callback returns callback manager
 | // Callback returns callback manager
 | ||||||
| func (db *DB) Callback() *callbacks { | func (db *DB) Callback() *callbacks { | ||||||
| 	return db.callbacks | 	return db.callbacks | ||||||
| @ -166,18 +195,37 @@ func (db *DB) AddError(err error) error { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db *DB) getInstance() *DB { | func (db *DB) getInstance() *DB { | ||||||
| 	if db.clone { | 	if db.clone > 0 { | ||||||
| 		stmt := &Statement{DB: db, Clauses: map[string]clause.Clause{}} | 		tx := &DB{Config: db.Config} | ||||||
| 
 | 
 | ||||||
| 		if db.Statement != nil { | 		switch db.clone { | ||||||
| 			stmt.Context = db.Statement.Context | 		case 1: // clone with new statement
 | ||||||
| 			stmt.ConnPool = db.Statement.ConnPool | 		case 2: // with old statement, generate new statement for future call, used to pass to callbacks
 | ||||||
| 		} else { | 			db.clone = 1 | ||||||
| 			stmt.Context = context.Background() | 			tx.Statement = db.Statement | ||||||
| 			stmt.ConnPool = db.ConnPool | 		case 3: // with clone statement
 | ||||||
|  | 			if db.Statement != nil { | ||||||
|  | 				tx.Statement = db.Statement.clone() | ||||||
|  | 				tx.Statement.DB = tx | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		return &DB{Config: db.Config, Statement: stmt} | 		if tx.Statement == nil { | ||||||
|  | 			tx.Statement = &Statement{ | ||||||
|  | 				DB:      tx, | ||||||
|  | 				Clauses: map[string]clause.Clause{}, | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if db.Statement != nil { | ||||||
|  | 			tx.Statement.Context = db.Statement.Context | ||||||
|  | 			tx.Statement.ConnPool = db.Statement.ConnPool | ||||||
|  | 		} else { | ||||||
|  | 			tx.Statement.Context = context.Background() | ||||||
|  | 			tx.Statement.ConnPool = db.ConnPool | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return tx | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return db | 	return db | ||||||
|  | |||||||
| @ -36,37 +36,37 @@ type TxCommiter interface { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type BeforeCreateInterface interface { | type BeforeCreateInterface interface { | ||||||
| 	BeforeCreate(*DB) | 	BeforeCreate(*DB) error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type AfterCreateInterface interface { | type AfterCreateInterface interface { | ||||||
| 	AfterCreate(*DB) | 	AfterCreate(*DB) error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type BeforeUpdateInterface interface { | type BeforeUpdateInterface interface { | ||||||
| 	BeforeUpdate(*DB) | 	BeforeUpdate(*DB) error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type AfterUpdateInterface interface { | type AfterUpdateInterface interface { | ||||||
| 	AfterUpdate(*DB) | 	AfterUpdate(*DB) error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type BeforeSaveInterface interface { | type BeforeSaveInterface interface { | ||||||
| 	BeforeSave(*DB) | 	BeforeSave(*DB) error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type AfterSaveInterface interface { | type AfterSaveInterface interface { | ||||||
| 	AfterSave(*DB) | 	AfterSave(*DB) error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type BeforeDeleteInterface interface { | type BeforeDeleteInterface interface { | ||||||
| 	BeforeDelete(*DB) | 	BeforeDelete(*DB) error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type AfterDeleteInterface interface { | type AfterDeleteInterface interface { | ||||||
| 	AfterDelete(*DB) | 	AfterDelete(*DB) error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type AfterFindInterface interface { | type AfterFindInterface interface { | ||||||
| 	AfterFind(*DB) | 	AfterFind(*DB) error | ||||||
| } | } | ||||||
|  | |||||||
| @ -12,10 +12,12 @@ import ( | |||||||
| type UserWithCallback struct { | type UserWithCallback struct { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (UserWithCallback) BeforeSave(*gorm.DB) { | func (UserWithCallback) BeforeSave(*gorm.DB) error { | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (UserWithCallback) AfterCreate(*gorm.DB) { | func (UserWithCallback) AfterCreate(*gorm.DB) error { | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestCallback(t *testing.T) { | func TestCallback(t *testing.T) { | ||||||
|  | |||||||
| @ -200,12 +200,12 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	reflectValue := reflect.Indirect(reflect.New(modelType)) | 	reflectValue := reflect.New(modelType) | ||||||
| 	callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} | 	callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} | ||||||
| 	for _, name := range callbacks { | 	for _, name := range callbacks { | ||||||
| 		if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() { | 		if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() { | ||||||
| 			switch methodValue.Type().String() { | 			switch methodValue.Type().String() { | ||||||
| 			case "func(*gorm.DB)": // TODO hack
 | 			case "func(*gorm.DB) error": // TODO hack
 | ||||||
| 				reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) | 				reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) | ||||||
| 			default: | 			default: | ||||||
| 				logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name) | 				logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name) | ||||||
|  | |||||||
							
								
								
									
										201
									
								
								tests/hooks_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										201
									
								
								tests/hooks_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,201 @@ | |||||||
|  | package tests_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"github.com/jinzhu/gorm" | ||||||
|  | 	. "github.com/jinzhu/gorm/tests" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Product struct { | ||||||
|  | 	gorm.Model | ||||||
|  | 	Name                  string | ||||||
|  | 	Code                  string | ||||||
|  | 	Price                 float64 | ||||||
|  | 	AfterFindCallTimes    int64 | ||||||
|  | 	BeforeCreateCallTimes int64 | ||||||
|  | 	AfterCreateCallTimes  int64 | ||||||
|  | 	BeforeUpdateCallTimes int64 | ||||||
|  | 	AfterUpdateCallTimes  int64 | ||||||
|  | 	BeforeSaveCallTimes   int64 | ||||||
|  | 	AfterSaveCallTimes    int64 | ||||||
|  | 	BeforeDeleteCallTimes int64 | ||||||
|  | 	AfterDeleteCallTimes  int64 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Product) BeforeCreate(tx *gorm.DB) (err error) { | ||||||
|  | 	if s.Code == "Invalid" { | ||||||
|  | 		err = errors.New("invalid product") | ||||||
|  | 	} | ||||||
|  | 	s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Product) BeforeUpdate(tx *gorm.DB) (err error) { | ||||||
|  | 	if s.Code == "dont_update" { | ||||||
|  | 		err = errors.New("can't update") | ||||||
|  | 	} | ||||||
|  | 	s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Product) BeforeSave(tx *gorm.DB) (err error) { | ||||||
|  | 	if s.Code == "dont_save" { | ||||||
|  | 		err = errors.New("can't save") | ||||||
|  | 	} | ||||||
|  | 	s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Product) AfterFind(tx *gorm.DB) (err error) { | ||||||
|  | 	s.AfterFindCallTimes = s.AfterFindCallTimes + 1 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Product) AfterCreate(tx *gorm.DB) (err error) { | ||||||
|  | 	return tx.Model(s).UpdateColumn("AfterCreateCallTimes", s.AfterCreateCallTimes+1).Error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Product) AfterUpdate(tx *gorm.DB) (err error) { | ||||||
|  | 	s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Product) AfterSave(tx *gorm.DB) (err error) { | ||||||
|  | 	if s.Code == "after_save_error" { | ||||||
|  | 		err = errors.New("can't save") | ||||||
|  | 	} | ||||||
|  | 	s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Product) BeforeDelete(tx *gorm.DB) (err error) { | ||||||
|  | 	if s.Code == "dont_delete" { | ||||||
|  | 		err = errors.New("can't delete") | ||||||
|  | 	} | ||||||
|  | 	s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Product) AfterDelete(tx *gorm.DB) (err error) { | ||||||
|  | 	if s.Code == "after_delete_error" { | ||||||
|  | 		err = errors.New("can't delete") | ||||||
|  | 	} | ||||||
|  | 	s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Product) GetCallTimes() []int64 { | ||||||
|  | 	return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestRunCallbacks(t *testing.T) { | ||||||
|  | 	DB.Migrator().DropTable(&Product{}) | ||||||
|  | 	DB.AutoMigrate(&Product{}) | ||||||
|  | 
 | ||||||
|  | 	p := Product{Code: "unique_code", Price: 100} | ||||||
|  | 	DB.Save(&p) | ||||||
|  | 
 | ||||||
|  | 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { | ||||||
|  | 		t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Where("Code = ?", "unique_code").First(&p) | ||||||
|  | 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { | ||||||
|  | 		t.Fatalf("After callbacks values are not saved, %v", p.GetCallTimes()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	p.Price = 200 | ||||||
|  | 	DB.Save(&p) | ||||||
|  | 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { | ||||||
|  | 		t.Fatalf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var products []Product | ||||||
|  | 	DB.Find(&products, "code = ?", "unique_code") | ||||||
|  | 	if products[0].AfterFindCallTimes != 1 { | ||||||
|  | 		t.Fatalf("AfterFind callbacks should work with slice, called %v", products[0].AfterFindCallTimes) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Where("Code = ?", "unique_code").First(&p) | ||||||
|  | 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { | ||||||
|  | 		t.Fatalf("After update callbacks values are not saved, %v", p.GetCallTimes()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Delete(&p) | ||||||
|  | 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { | ||||||
|  | 		t.Fatalf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { | ||||||
|  | 		t.Fatalf("Can't find a deleted record") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestCallbacksWithErrors(t *testing.T) { | ||||||
|  | 	DB.Migrator().DropTable(&Product{}) | ||||||
|  | 	DB.AutoMigrate(&Product{}) | ||||||
|  | 
 | ||||||
|  | 	p := Product{Code: "Invalid", Price: 100} | ||||||
|  | 	if DB.Save(&p).Error == nil { | ||||||
|  | 		t.Fatalf("An error from before create callbacks happened when create with invalid value") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { | ||||||
|  | 		t.Fatalf("Should not save record that have errors") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { | ||||||
|  | 		t.Fatalf("An error from after create callbacks happened when create with invalid value") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	p2 := Product{Code: "update_callback", Price: 100} | ||||||
|  | 	DB.Save(&p2) | ||||||
|  | 
 | ||||||
|  | 	p2.Code = "dont_update" | ||||||
|  | 	if DB.Save(&p2).Error == nil { | ||||||
|  | 		t.Fatalf("An error from before update callbacks happened when update with invalid value") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { | ||||||
|  | 		t.Fatalf("Record Should not be updated due to errors happened in before update callback") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { | ||||||
|  | 		t.Fatalf("Record Should not be updated due to errors happened in before update callback") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	p2.Code = "dont_save" | ||||||
|  | 	if DB.Save(&p2).Error == nil { | ||||||
|  | 		t.Fatalf("An error from before save callbacks happened when update with invalid value") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	p3 := Product{Code: "dont_delete", Price: 100} | ||||||
|  | 	DB.Save(&p3) | ||||||
|  | 	if DB.Delete(&p3).Error == nil { | ||||||
|  | 		t.Fatalf("An error from before delete callbacks happened when delete") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { | ||||||
|  | 		t.Fatalf("An error from before delete callbacks happened") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	p4 := Product{Code: "after_save_error", Price: 100} | ||||||
|  | 	DB.Save(&p4) | ||||||
|  | 	if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { | ||||||
|  | 		t.Fatalf("Record should be reverted if get an error in after save callback") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	p5 := Product{Code: "after_delete_error", Price: 100} | ||||||
|  | 	DB.Save(&p5) | ||||||
|  | 	if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { | ||||||
|  | 		t.Fatalf("Record should be found") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Delete(&p5) | ||||||
|  | 	if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { | ||||||
|  | 		t.Fatalf("Record shouldn't be deleted because of an error happened in after delete callback") | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @ -59,9 +59,9 @@ func OpenTestConnection() (db *gorm.DB, err error) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if debug := os.Getenv("DEBUG"); debug == "true" { | 	if debug := os.Getenv("DEBUG"); debug == "true" { | ||||||
| 		db.Logger.LogMode(logger.Info) | 		db.Logger = db.Logger.LogMode(logger.Info) | ||||||
| 	} else if debug == "false" { | 	} else if debug == "false" { | ||||||
| 		db.Logger.LogMode(logger.Silent) | 		db.Logger = db.Logger.LogMode(logger.Silent) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return | 	return | ||||||
|  | |||||||
| @ -14,37 +14,37 @@ func TestTransaction(t *testing.T) { | |||||||
| 	user := *GetUser("transcation", Config{}) | 	user := *GetUser("transcation", Config{}) | ||||||
| 
 | 
 | ||||||
| 	if err := tx.Save(&user).Error; err != nil { | 	if err := tx.Save(&user).Error; err != nil { | ||||||
| 		t.Errorf("No error should raise, but got %v", err) | 		t.Fatalf("No error should raise, but got %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { | 	if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { | ||||||
| 		t.Errorf("Should find saved record, but got %v", err) | 		t.Fatalf("Should find saved record, but got %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil { | 	if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil { | ||||||
| 		t.Errorf("Should return the underlying sql.Tx") | 		t.Fatalf("Should return the underlying sql.Tx") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	tx.Rollback() | 	tx.Rollback() | ||||||
| 
 | 
 | ||||||
| 	if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { | 	if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { | ||||||
| 		t.Errorf("Should not find record after rollback, but got %v", err) | 		t.Fatalf("Should not find record after rollback, but got %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	tx2 := DB.Begin() | 	tx2 := DB.Begin() | ||||||
| 	user2 := *GetUser("transcation-2", Config{}) | 	user2 := *GetUser("transcation-2", Config{}) | ||||||
| 	if err := tx2.Save(&user2).Error; err != nil { | 	if err := tx2.Save(&user2).Error; err != nil { | ||||||
| 		t.Errorf("No error should raise, but got %v", err) | 		t.Fatalf("No error should raise, but got %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { | 	if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { | ||||||
| 		t.Errorf("Should find saved record, but got %v", err) | 		t.Fatalf("Should find saved record, but got %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	tx2.Commit() | 	tx2.Commit() | ||||||
| 
 | 
 | ||||||
| 	if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { | 	if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { | ||||||
| 		t.Errorf("Should be able to find committed record, but got %v", err) | 		t.Fatalf("Should be able to find committed record, but got %v", err) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -52,7 +52,7 @@ func TestTransactionWithBlock(t *testing.T) { | |||||||
| 	assertPanic := func(f func()) { | 	assertPanic := func(f func()) { | ||||||
| 		defer func() { | 		defer func() { | ||||||
| 			if r := recover(); r == nil { | 			if r := recover(); r == nil { | ||||||
| 				t.Errorf("The code did not panic") | 				t.Fatalf("The code did not panic") | ||||||
| 			} | 			} | ||||||
| 		}() | 		}() | ||||||
| 		f() | 		f() | ||||||
| @ -62,39 +62,39 @@ func TestTransactionWithBlock(t *testing.T) { | |||||||
| 	err := DB.Transaction(func(tx *gorm.DB) error { | 	err := DB.Transaction(func(tx *gorm.DB) error { | ||||||
| 		user := *GetUser("transcation-block", Config{}) | 		user := *GetUser("transcation-block", Config{}) | ||||||
| 		if err := tx.Save(&user).Error; err != nil { | 		if err := tx.Save(&user).Error; err != nil { | ||||||
| 			t.Errorf("No error should raise") | 			t.Fatalf("No error should raise") | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { | 		if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { | ||||||
| 			t.Errorf("Should find saved record") | 			t.Fatalf("Should find saved record") | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		return errors.New("the error message") | 		return errors.New("the error message") | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	if err.Error() != "the error message" { | 	if err.Error() != "the error message" { | ||||||
| 		t.Errorf("Transaction return error will equal the block returns error") | 		t.Fatalf("Transaction return error will equal the block returns error") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := DB.First(&User{}, "name = ?", "transcation-block").Error; err == nil { | 	if err := DB.First(&User{}, "name = ?", "transcation-block").Error; err == nil { | ||||||
| 		t.Errorf("Should not find record after rollback") | 		t.Fatalf("Should not find record after rollback") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// commit
 | 	// commit
 | ||||||
| 	DB.Transaction(func(tx *gorm.DB) error { | 	DB.Transaction(func(tx *gorm.DB) error { | ||||||
| 		user := *GetUser("transcation-block-2", Config{}) | 		user := *GetUser("transcation-block-2", Config{}) | ||||||
| 		if err := tx.Save(&user).Error; err != nil { | 		if err := tx.Save(&user).Error; err != nil { | ||||||
| 			t.Errorf("No error should raise") | 			t.Fatalf("No error should raise") | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { | 		if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { | ||||||
| 			t.Errorf("Should find saved record") | 			t.Fatalf("Should find saved record") | ||||||
| 		} | 		} | ||||||
| 		return nil | 		return nil | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	if err := DB.First(&User{}, "name = ?", "transcation-block-2").Error; err != nil { | 	if err := DB.First(&User{}, "name = ?", "transcation-block-2").Error; err != nil { | ||||||
| 		t.Errorf("Should be able to find committed record") | 		t.Fatalf("Should be able to find committed record") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// panic will rollback
 | 	// panic will rollback
 | ||||||
| @ -102,11 +102,11 @@ func TestTransactionWithBlock(t *testing.T) { | |||||||
| 		DB.Transaction(func(tx *gorm.DB) error { | 		DB.Transaction(func(tx *gorm.DB) error { | ||||||
| 			user := *GetUser("transcation-block-3", Config{}) | 			user := *GetUser("transcation-block-3", Config{}) | ||||||
| 			if err := tx.Save(&user).Error; err != nil { | 			if err := tx.Save(&user).Error; err != nil { | ||||||
| 				t.Errorf("No error should raise") | 				t.Fatalf("No error should raise") | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { | 			if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { | ||||||
| 				t.Errorf("Should find saved record") | 				t.Fatalf("Should find saved record") | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			panic("force panic") | 			panic("force panic") | ||||||
| @ -114,7 +114,7 @@ func TestTransactionWithBlock(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	if err := DB.First(&User{}, "name = ?", "transcation-block-3").Error; err == nil { | 	if err := DB.First(&User{}, "name = ?", "transcation-block-3").Error; err == nil { | ||||||
| 		t.Errorf("Should not find record after panic rollback") | 		t.Fatalf("Should not find record after panic rollback") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -122,14 +122,14 @@ func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) { | |||||||
| 	tx := DB.Begin() | 	tx := DB.Begin() | ||||||
| 	user := User{Name: "transcation"} | 	user := User{Name: "transcation"} | ||||||
| 	if err := tx.Save(&user).Error; err != nil { | 	if err := tx.Save(&user).Error; err != nil { | ||||||
| 		t.Errorf("No error should raise") | 		t.Fatalf("No error should raise") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := tx.Commit().Error; err != nil { | 	if err := tx.Commit().Error; err != nil { | ||||||
| 		t.Errorf("Commit should not raise error") | 		t.Fatalf("Commit should not raise error") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := tx.Rollback().Error; err == nil { | 	if err := tx.Rollback().Error; err == nil { | ||||||
| 		t.Errorf("Rollback after commit should raise error") | 		t.Fatalf("Rollback after commit should raise error") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu