Test CreateWithNoGORMPrimayKey
This commit is contained in:
		
							parent
							
								
									b3b19a5577
								
							
						
					
					
						commit
						1546f8a4a1
					
				| @ -64,7 +64,7 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 			result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 			result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
| 
 | 
 | ||||||
| 			if err == nil { | 			if err == nil { | ||||||
| 				if db.Statement.Schema != nil { | 				if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { | ||||||
| 					if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { | 					if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { | ||||||
| 						if insertID, err := result.LastInsertId(); err == nil { | 						if insertID, err := result.LastInsertId(); err == nil { | ||||||
| 							switch db.Statement.ReflectValue.Kind() { | 							switch db.Statement.ReflectValue.Kind() { | ||||||
|  | |||||||
| @ -68,26 +68,30 @@ func Create(db *gorm.DB) { | |||||||
| 
 | 
 | ||||||
| 		switch db.Statement.ReflectValue.Kind() { | 		switch db.Statement.ReflectValue.Kind() { | ||||||
| 		case reflect.Slice, reflect.Array: | 		case reflect.Slice, reflect.Array: | ||||||
| 			for rows.Next() { | 			if len(db.Statement.Schema.PrimaryFields) > 0 { | ||||||
| 				// for idx, field := range fields {
 | 				values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) | ||||||
| 				// 	values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
 |  | ||||||
| 				// }
 |  | ||||||
| 
 | 
 | ||||||
| 				values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() | 				for rows.Next() { | ||||||
| 				if err := rows.Scan(values); err != nil { | 					for idx, field := range db.Statement.Schema.PrimaryFields { | ||||||
| 					db.AddError(err) | 						values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | 					db.RowsAffected++ | ||||||
|  | 					db.AddError(rows.Scan(values...)) | ||||||
| 				} | 				} | ||||||
| 				db.RowsAffected++ |  | ||||||
| 			} | 			} | ||||||
| 		case reflect.Struct: | 		case reflect.Struct: | ||||||
| 			// for idx, field := range fields {
 | 			if len(db.Statement.Schema.PrimaryFields) > 0 { | ||||||
| 			// 	values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
 | 				values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) | ||||||
| 			// }
 |  | ||||||
| 			values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() |  | ||||||
| 
 | 
 | ||||||
| 			if rows.Next() { | 				for idx, field := range db.Statement.Schema.PrimaryFields { | ||||||
| 				db.RowsAffected++ | 					values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() | ||||||
| 				db.AddError(rows.Scan(values)) | 				} | ||||||
|  | 
 | ||||||
|  | 				if rows.Next() { | ||||||
|  | 					db.RowsAffected++ | ||||||
|  | 					db.AddError(rows.Scan(values...)) | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| @ -177,8 +181,14 @@ func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func outputInserted(db *gorm.DB) { | func outputInserted(db *gorm.DB) { | ||||||
| 	if db.Statement.Schema.PrioritizedPrimaryField != nil { | 	if len(db.Statement.Schema.PrimaryFields) > 0 { | ||||||
| 		db.Statement.WriteString(" OUTPUT INSERTED.") | 		db.Statement.WriteString(" OUTPUT ") | ||||||
| 		db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName) | 		for idx, field := range db.Statement.Schema.PrimaryFields { | ||||||
|  | 			if idx > 0 { | ||||||
|  | 				db.Statement.WriteString(",") | ||||||
|  | 			} | ||||||
|  | 			db.Statement.WriteString(" INSERTED.") | ||||||
|  | 			db.Statement.AddVar(db.Statement, clause.Column{Name: field.DBName}) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -149,7 +149,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { | |||||||
| 				createTableSQL += "," | 				createTableSQL += "," | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			if !hasPrimaryKeyInDataType { | 			if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { | ||||||
| 				createTableSQL += "PRIMARY KEY ?," | 				createTableSQL += "PRIMARY KEY ?," | ||||||
| 				primaryKeys := []interface{}{} | 				primaryKeys := []interface{}{} | ||||||
| 				for _, field := range stmt.Schema.PrimaryFields { | 				for _, field := range stmt.Schema.PrimaryFields { | ||||||
|  | |||||||
| @ -245,3 +245,21 @@ func TestCreateWithNowFuncOverride(t *testing.T) { | |||||||
| 	AssertEqual(t, newUser.CreatedAt, curTime) | 	AssertEqual(t, newUser.CreatedAt, curTime) | ||||||
| 	AssertEqual(t, newUser.UpdatedAt, curTime) | 	AssertEqual(t, newUser.UpdatedAt, curTime) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestCreateWithNoGORMPrimayKey(t *testing.T) { | ||||||
|  | 	type JoinTable struct { | ||||||
|  | 		UserID   uint | ||||||
|  | 		FriendID uint | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Migrator().DropTable(&JoinTable{}) | ||||||
|  | 	if err := DB.AutoMigrate(&JoinTable{}); err != nil { | ||||||
|  | 		t.Errorf("no error should happen when auto migrate, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	jt := JoinTable{UserID: 1, FriendID: 2} | ||||||
|  | 	err := DB.Create(&jt).Error | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
| @ -34,6 +34,7 @@ func TestScannerValuer(t *testing.T) { | |||||||
| 			{"name1", "value1"}, | 			{"name1", "value1"}, | ||||||
| 			{"name2", "value2"}, | 			{"name2", "value2"}, | ||||||
| 		}, | 		}, | ||||||
|  | 		Role: Role{Name: "admin"}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := DB.Create(&data).Error; err != nil { | 	if err := DB.Create(&data).Error; err != nil { | ||||||
| @ -91,6 +92,7 @@ type ScannerValuerStruct struct { | |||||||
| 	Num      Num | 	Num      Num | ||||||
| 	Strings  StringsSlice | 	Strings  StringsSlice | ||||||
| 	Structs  StructsSlice | 	Structs  StructsSlice | ||||||
|  | 	Role     Role | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type EncryptedData []byte | type EncryptedData []byte | ||||||
| @ -176,3 +178,24 @@ func (l *StructsSlice) Scan(input interface{}) error { | |||||||
| 		return errors.New("not supported") | 		return errors.New("not supported") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | type Role struct { | ||||||
|  | 	Name string `gorm:"size:256"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (role *Role) Scan(value interface{}) error { | ||||||
|  | 	if b, ok := value.([]uint8); ok { | ||||||
|  | 		role.Name = string(b) | ||||||
|  | 	} else { | ||||||
|  | 		role.Name = value.(string) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (role Role) Value() (driver.Value, error) { | ||||||
|  | 	return role.Name, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (role Role) IsAdmin() bool { | ||||||
|  | 	return role.Name == "admin" | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu