diff --git a/callbacks.go b/callbacks.go index 50b5b0e9..c099b44e 100644 --- a/callbacks.go +++ b/callbacks.go @@ -121,6 +121,26 @@ func (p *processor) Execute(db *DB) *DB { stmt.ReflectValue = stmt.ReflectValue.Elem() } + if (stmt.ReflectValue.Kind() == reflect.Slice || stmt.ReflectValue.Kind() == reflect.Array) && + (stmt.ReflectValue.Len() > 0 && stmt.ReflectValue.Index(0).Kind() == reflect.Interface) { + len := stmt.ReflectValue.Len() + firstElem := stmt.ReflectValue.Index(0) + for firstElem.Kind() == reflect.Interface || firstElem.Kind() == reflect.Ptr { + firstElem = firstElem.Elem() + } + elemType := firstElem.Type() + sliceType := reflect.SliceOf(elemType) + structArrayReflectValue := reflect.MakeSlice(sliceType, 0, len) + + for i := 0; i < len; i++ { + elem := stmt.ReflectValue.Index(i) + for elem.Kind() == reflect.Interface || elem.Kind() == reflect.Ptr { + elem = elem.Elem() + } + structArrayReflectValue = reflect.Append(structArrayReflectValue, elem) + } + stmt.ReflectValue = structArrayReflectValue + } if !stmt.ReflectValue.IsValid() { db.AddError(ErrInvalidValue) } diff --git a/schema/schema.go b/schema/schema.go index 9419846b..fb5c0bf3 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -135,6 +135,14 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam modelType = modelType.Elem() } + if modelType.Kind() == reflect.Interface { + if value.Len() > 0 { + modelType = reflect.Indirect(value.Index(0)).Elem().Type() + } + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + } if modelType.Kind() != reflect.Struct { if modelType.Kind() == reflect.Interface { modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() diff --git a/tests/create_test.go b/tests/create_test.go index 5200427e..5c1f31df 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -808,3 +808,15 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Errorf("failed to create data from map with table, @id != id") } } + +func TestCreateWithInterfaceArrayType(t *testing.T) { + user := *GetUser("create", Config{}) + type UserInterface interface{} + var userInterface UserInterface = &user + + if results := DB.Create([]UserInterface{userInterface}); results.Error != nil { + t.Fatalf("errors happened when create: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } +}