From b5d069300cf1514832ac692f9ce2cbc94cdf8e1a Mon Sep 17 00:00:00 2001 From: KEHyeon Date: Sun, 13 Apr 2025 18:57:57 +0900 Subject: [PATCH 1/3] fix: prevent interface type array from causing runtime errors --- callbacks.go | 21 +++++++++++++++++++++ schema/schema.go | 8 ++++++++ 2 files changed, 29 insertions(+) diff --git a/callbacks.go b/callbacks.go index 50b5b0e9..ac10d4af 100644 --- a/callbacks.go +++ b/callbacks.go @@ -121,6 +121,27 @@ func (p *processor) Execute(db *DB) *DB { stmt.ReflectValue = stmt.ReflectValue.Elem() } + if (stmt.ReflectValue.Kind() == reflect.Slice || stmt.ReflectValue.Kind() == reflect.Array) && + (stmt.ReflectValue.Len() > 0 || stmt.ReflectValue.Index(0).Kind() == reflect.Interface) { + len := stmt.ReflectValue.Len() + firstElem := stmt.ReflectValue.Index(0) + for firstElem.Kind() == reflect.Interface || firstElem.Kind() == reflect.Ptr { + firstElem = firstElem.Elem() + } + elemType := firstElem.Type() + sliceType := reflect.SliceOf(elemType) + structArrayReflectValue := reflect.MakeSlice(sliceType, 0, len) + + for i := 0; i < len; i++ { + elem := stmt.ReflectValue.Index(i) + for elem.Kind() == reflect.Interface || elem.Kind() == reflect.Ptr { + elem = elem.Elem() + } + structArrayReflectValue = reflect.Append(structArrayReflectValue, elem) + } + stmt.ReflectValue = structArrayReflectValue + fmt.Println(stmt.ReflectValue.Type()) + } if !stmt.ReflectValue.IsValid() { db.AddError(ErrInvalidValue) } diff --git a/schema/schema.go b/schema/schema.go index db236797..997a7009 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -138,6 +138,14 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam modelType = modelType.Elem() } + if modelType.Kind() == reflect.Interface { + if value.Len() > 0 { + modelType = reflect.Indirect(value.Index(0)).Elem().Type() + } + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + } if modelType.Kind() != reflect.Struct { if modelType.PkgPath() == "" { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) From f5ac40e3955473586506ba692ab7aa550201a8b4 Mon Sep 17 00:00:00 2001 From: KEHyeon Date: Sun, 13 Apr 2025 18:59:32 +0900 Subject: [PATCH 2/3] Add test about interface type array --- tests/create_test.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/create_test.go b/tests/create_test.go index abb82472..e90d2142 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -791,3 +791,15 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Errorf("failed to create data from map with table, @id != id") } } + +func TestCreateWithInterfaceArrayType(t *testing.T) { + user := *GetUser("create", Config{}) + type UserInterface interface{} + var userInterface UserInterface = &user + + if results := DB.Create([]UserInterface{userInterface}); results.Error != nil { + t.Fatalf("errors happened when create: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } +} From 8492af206f1e7a4eee7c38b28fd56999f6ea68ac Mon Sep 17 00:00:00 2001 From: KEHyeon Date: Sun, 13 Apr 2025 19:05:14 +0900 Subject: [PATCH 3/3] fix out of range error --- callbacks.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/callbacks.go b/callbacks.go index ac10d4af..c099b44e 100644 --- a/callbacks.go +++ b/callbacks.go @@ -122,7 +122,7 @@ func (p *processor) Execute(db *DB) *DB { stmt.ReflectValue = stmt.ReflectValue.Elem() } if (stmt.ReflectValue.Kind() == reflect.Slice || stmt.ReflectValue.Kind() == reflect.Array) && - (stmt.ReflectValue.Len() > 0 || stmt.ReflectValue.Index(0).Kind() == reflect.Interface) { + (stmt.ReflectValue.Len() > 0 && stmt.ReflectValue.Index(0).Kind() == reflect.Interface) { len := stmt.ReflectValue.Len() firstElem := stmt.ReflectValue.Index(0) for firstElem.Kind() == reflect.Interface || firstElem.Kind() == reflect.Ptr { @@ -140,7 +140,6 @@ func (p *processor) Execute(db *DB) *DB { structArrayReflectValue = reflect.Append(structArrayReflectValue, elem) } stmt.ReflectValue = structArrayReflectValue - fmt.Println(stmt.ReflectValue.Type()) } if !stmt.ReflectValue.IsValid() { db.AddError(ErrInvalidValue)