Fix config and add ut
This commit is contained in:
		
							parent
							
								
									40b4998893
								
							
						
					
					
						commit
						dac14f0b78
					
				| @ -1,7 +1,6 @@ | |||||||
| package callbacks | package callbacks | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"database/sql/driver" |  | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"strings" | 	"strings" | ||||||
| @ -114,7 +113,7 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 			db.Statement.Result.RowsAffected = db.RowsAffected | 			db.Statement.Result.RowsAffected = db.RowsAffected | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if db.RowsAffected == 0 { | 		if db.RowsAffected == 0 || db.DisableLastInsertID { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| @ -127,22 +126,6 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 		insertOk := err == nil && insertID > 0 | 		insertOk := err == nil && insertID > 0 | ||||||
| 
 | 
 | ||||||
| 		if !insertOk { | 		if !insertOk { | ||||||
| 			if db.IgnoreLastInsertIDWhenNotSupport { |  | ||||||
| 				_, rowsAffectedErr := driver.RowsAffected(0).LastInsertId() |  | ||||||
| 				if strings.Compare(err.Error(), rowsAffectedErr.Error()) == 0 { |  | ||||||
| 					return |  | ||||||
| 				} |  | ||||||
| 				_, resultNoRowsErr := driver.ResultNoRows.LastInsertId() |  | ||||||
| 				if strings.Compare(err.Error(), resultNoRowsErr.Error()) == 0 { |  | ||||||
| 					return |  | ||||||
| 				} |  | ||||||
| 				if db.IsNotSupportLastInsertIDErr != nil && db.IsNotSupportLastInsertIDErr(err) { |  | ||||||
| 					return |  | ||||||
| 				} |  | ||||||
| 				if db.Logger != nil { |  | ||||||
| 					db.Logger.Warn(db.Statement.Context, "Failed to get last insert ID, err: %v", err) |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 			if !supportReturning { | 			if !supportReturning { | ||||||
| 				db.AddError(err) | 				db.AddError(err) | ||||||
| 			} | 			} | ||||||
|  | |||||||
							
								
								
									
										7
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								gorm.go
									
									
									
									
									
								
							| @ -24,11 +24,8 @@ type Config struct { | |||||||
| 	SkipDefaultTransaction    bool | 	SkipDefaultTransaction    bool | ||||||
| 	DefaultTransactionTimeout time.Duration | 	DefaultTransactionTimeout time.Duration | ||||||
| 
 | 
 | ||||||
| 	// Not all database support LastInsertId, you can set `IgnoreLastInsertIDWhenNotSupport` to true in those cases
 | 	// Not all database support LastInsertId, you can set `DisableLastInsertID` to true in those cases
 | ||||||
| 	IgnoreLastInsertIDWhenNotSupport bool | 	DisableLastInsertID bool | ||||||
| 	// When 'IgnoreLastInsertIDWhenNotSupport' is true, you can set `IsNotSupportLastInsertIDErr` to check if the error is 'NotSupportLastInsertID'
 |  | ||||||
| 	// Gorm only asserts the type of returned sql/driver.Result if 'IsNotSupportLastInsertIDErr' is not set.
 |  | ||||||
| 	IsNotSupportLastInsertIDErr func(error) bool |  | ||||||
| 
 | 
 | ||||||
| 	// NamingStrategy tables, columns naming strategy
 | 	// NamingStrategy tables, columns naming strategy
 | ||||||
| 	NamingStrategy schema.Namer | 	NamingStrategy schema.Namer | ||||||
|  | |||||||
| @ -1,6 +1,9 @@ | |||||||
| package tests_test | package tests_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"database/sql" | ||||||
|  | 	"database/sql/driver" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"regexp" | 	"regexp" | ||||||
| @ -535,6 +538,40 @@ func TestCreateNilPointer(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | type ConnPoolLastInsertIDMock struct { | ||||||
|  | 	gorm.ConnPool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m ConnPoolLastInsertIDMock) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { | ||||||
|  | 	rst, err := m.ConnPool.ExecContext(ctx, query, args...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	affected, _ := rst.RowsAffected() | ||||||
|  | 	return driver.RowsAffected(affected), nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestCreateWithDisableLastInsertID(t *testing.T) { | ||||||
|  | 	rawPool := DB.ConnPool | ||||||
|  | 	DB.ConnPool = ConnPoolLastInsertIDMock{rawPool} | ||||||
|  | 	defer func() { | ||||||
|  | 		DB.ConnPool = rawPool | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	user := &User{Name: "TestCreateWithDisableLastInsertID"} | ||||||
|  | 	err := DB.Create(&user).Error | ||||||
|  | 	if err == nil { | ||||||
|  | 		t.Fatalf("it should be error") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.DisableLastInsertID = true | ||||||
|  | 	err = DB.Create(&user).Error | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("it should be nil") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestFirstOrCreateRowsAffected(t *testing.T) { | func TestFirstOrCreateRowsAffected(t *testing.T) { | ||||||
| 	user := User{Name: "TestFirstOrCreateRowsAffected"} | 	user := User{Name: "TestFirstOrCreateRowsAffected"} | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Krisdiano
						Krisdiano