From a549b6bd4964d51d37e2ba3f45f4f5bf909deb87 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 11 Mar 2015 17:05:58 +0800 Subject: [PATCH] Refactor SQL Tag --- README.md | 1 + common_dialect.go | 20 +++++++------------- dialect.go | 3 +-- model_struct.go | 19 +++++++------------ mssql.go | 20 +++++++------------- mysql.go | 20 +++++++------------- postgres.go | 19 +++++++------------ scope_private.go | 28 ++++++++++++++++++++-------- sqlite3.go | 14 ++++---------- 9 files changed, 61 insertions(+), 83 deletions(-) diff --git a/README.md b/README.md index 7085ef46..c99a3d0c 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ type User struct { Birthday time.Time Age int Name string `sql:"size:255"` // Default size for string is 255, you could reset it with this tag + Num int `sql:"AUTO_INCREMENT"` CreatedAt time.Time UpdatedAt time.Time DeletedAt time.Time diff --git a/common_dialect.go b/common_dialect.go index 87be97bf..9360cd26 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -21,13 +21,19 @@ func (s *commonDialect) HasTop() bool { return false } -func (s *commonDialect) SqlTag(value reflect.Value, size int) string { +func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "BOOLEAN" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if autoIncrease { + return "INTEGER AUTO_INCREMENT" + } return "INTEGER" case reflect.Int64, reflect.Uint64: + if autoIncrease { + return "BIGINT AUTO_INCREMENT" + } return "BIGINT" case reflect.Float32, reflect.Float64: return "FLOAT" @@ -51,18 +57,6 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int) string { panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String())) } -func (s *commonDialect) PrimaryKeyTag(value reflect.Value, size int) string { - suffix := " NOT NULL PRIMARY KEY" - switch value.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - return "INTEGER" + suffix - case reflect.Int64, reflect.Uint64: - return "BIGINT" + suffix - default: - panic("Invalid primary key type") - } -} - func (s *commonDialect) ReturningStr(tableName, key string) string { return "" } diff --git a/dialect.go b/dialect.go index 0c58d61a..2e64cca5 100644 --- a/dialect.go +++ b/dialect.go @@ -9,8 +9,7 @@ type Dialect interface { BinVar(i int) string SupportLastInsertId() bool HasTop() bool - SqlTag(value reflect.Value, size int) string - PrimaryKeyTag(value reflect.Value, size int) string + SqlTag(value reflect.Value, size int, autoIncrease bool) string ReturningStr(tableName, key string) string SelectFromDummyTable() string Quote(key string) string diff --git a/model_struct.go b/model_struct.go index 17605e50..ffd0a522 100644 --- a/model_struct.go +++ b/model_struct.go @@ -27,7 +27,6 @@ type StructField struct { IsIgnored bool IsScanner bool HasDefaultValue bool - SqlTag string Tag reflect.StructTag Struct reflect.StructField IsForeignKey bool @@ -44,7 +43,6 @@ func (structField *StructField) clone() *StructField { IsIgnored: structField.IsIgnored, IsScanner: structField.IsScanner, HasDefaultValue: structField.HasDefaultValue, - SqlTag: structField.SqlTag, Tag: structField.Tag, Struct: structField.Struct, IsForeignKey: structField.IsForeignKey, @@ -281,10 +279,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct { field.IsPrimaryKey = true modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) } - - if scope.db != nil { - scope.generateSqlTag(field) - } } } modelStruct.StructFields = append(modelStruct.StructFields, field) @@ -301,7 +295,7 @@ func (scope *Scope) GetStructFields() (fields []*StructField) { return scope.GetModelStruct().StructFields } -func (scope *Scope) generateSqlTag(field *StructField) { +func (scope *Scope) generateSqlTag(field *StructField) string { var sqlType string structType := field.Struct.Type if structType.Kind() == reflect.Ptr { @@ -337,17 +331,18 @@ func (scope *Scope) generateSqlTag(field *StructField) { size, _ = strconv.Atoi(value) } + _, autoIncrease := sqlSettings["AUTO_INCREMENT"] if field.IsPrimaryKey { - sqlType = scope.Dialect().PrimaryKeyTag(reflectValue, size) - } else { - sqlType = scope.Dialect().SqlTag(reflectValue, size) + autoIncrease = true } + + sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease) } if strings.TrimSpace(additionalType) == "" { - field.SqlTag = sqlType + return sqlType } else { - field.SqlTag = fmt.Sprintf("%v %v", sqlType, additionalType) + return fmt.Sprintf("%v %v", sqlType, additionalType) } } diff --git a/mssql.go b/mssql.go index 3323874c..dc8e2917 100644 --- a/mssql.go +++ b/mssql.go @@ -21,13 +21,19 @@ func (s *mssql) HasTop() bool { return true } -func (s *mssql) SqlTag(value reflect.Value, size int) string { +func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "bit" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if autoIncrease { + return "int IDENTITY(1,1)" + } return "int" case reflect.Int64, reflect.Uint64: + if autoIncrease { + return "bigint IDENTITY(1,1)" + } return "bigint" case reflect.Float32, reflect.Float64: return "float" @@ -51,18 +57,6 @@ func (s *mssql) SqlTag(value reflect.Value, size int) string { panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) } -func (s *mssql) PrimaryKeyTag(value reflect.Value, size int) string { - suffix := " IDENTITY(1,1) PRIMARY KEY" - switch value.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - return "int" + suffix - case reflect.Int64, reflect.Uint64: - return "bigint" + suffix - default: - panic("Invalid primary key type") - } -} - func (s *mssql) ReturningStr(tableName, key string) string { return "" } diff --git a/mysql.go b/mysql.go index e608619d..d2eb08a5 100644 --- a/mysql.go +++ b/mysql.go @@ -21,13 +21,19 @@ func (s *mysql) HasTop() bool { return false } -func (s *mysql) SqlTag(value reflect.Value, size int) string { +func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "boolean" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if autoIncrease { + return "int AUTO_INCREMENT" + } return "int" case reflect.Int64, reflect.Uint64: + if autoIncrease { + return "bigint AUTO_INCREMENT" + } return "bigint" case reflect.Float32, reflect.Float64: return "double" @@ -51,18 +57,6 @@ func (s *mysql) SqlTag(value reflect.Value, size int) string { panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) } -func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string { - suffix := " NOT NULL AUTO_INCREMENT PRIMARY KEY" - switch value.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - return "int" + suffix - case reflect.Int64, reflect.Uint64: - return "bigint" + suffix - default: - panic("Invalid primary key type") - } -} - func (s *mysql) ReturningStr(tableName, key string) string { return "" } diff --git a/postgres.go b/postgres.go index 98068536..83c37e1f 100644 --- a/postgres.go +++ b/postgres.go @@ -25,13 +25,19 @@ func (s *postgres) HasTop() bool { return false } -func (s *postgres) SqlTag(value reflect.Value, size int) string { +func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "boolean" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if autoIncrease { + return "serial" + } return "integer" case reflect.Int64, reflect.Uint64: + if autoIncrease { + return "bigserial" + } return "bigint" case reflect.Float32, reflect.Float64: return "numeric" @@ -56,17 +62,6 @@ func (s *postgres) SqlTag(value reflect.Value, size int) string { panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String())) } -func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string { - switch value.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - return "serial PRIMARY KEY" - case reflect.Int64, reflect.Uint64: - return "bigserial PRIMARY KEY" - default: - panic("Invalid primary key type") - } -} - func (s *postgres) ReturningStr(tableName, key string) string { return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key) } diff --git a/scope_private.go b/scope_private.go index e4262d64..b5b5bb9c 100644 --- a/scope_private.go +++ b/scope_private.go @@ -447,7 +447,7 @@ func (scope *Scope) createJoinTable(field *StructField) { joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable) joinTable := joinTableHandler.Table(scope.db, relationship) if !scope.Dialect().HasTable(scope, joinTable) { - primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255) + primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false) scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", scope.Quote(joinTable), strings.Join([]string{ @@ -460,14 +460,25 @@ func (scope *Scope) createJoinTable(field *StructField) { } func (scope *Scope) createTable() *Scope { - var sqls []string - for _, structField := range scope.GetStructFields() { - if structField.IsNormal { - sqls = append(sqls, scope.Quote(structField.DBName)+" "+structField.SqlTag) + var tags []string + var primaryKeys []string + for _, field := range scope.GetStructFields() { + if field.IsNormal { + sqlTag := scope.generateSqlTag(field) + tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag) } - scope.createJoinTable(structField) + + if field.IsPrimaryKey { + primaryKeys = append(primaryKeys, field.DBName) + } + scope.createJoinTable(field) } - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.QuotedTableName(), strings.Join(sqls, ","))).Exec() + + var primaryKeyStr string + if len(primaryKeys) > 0 { + primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) + } + scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr)).Exec() return scope } @@ -530,7 +541,8 @@ func (scope *Scope) autoMigrate() *Scope { for _, field := range scope.GetStructFields() { if !scope.Dialect().HasColumn(scope, tableName, field.DBName) { if field.IsNormal { - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, field.SqlTag)).Exec() + sqlTag := scope.generateSqlTag(field) + scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec() } } scope.createJoinTable(field) diff --git a/sqlite3.go b/sqlite3.go index e24d2410..ce71ee08 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -20,13 +20,16 @@ func (s *sqlite3) HasTop() bool { return false } -func (s *sqlite3) SqlTag(value reflect.Value, size int) string { +func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "bool" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: return "integer" case reflect.Int64, reflect.Uint64: + if autoIncrease { + return "integer" + } return "bigint" case reflect.Float32, reflect.Float64: return "real" @@ -47,15 +50,6 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int) string { panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String())) } -func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string { - switch value.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr, reflect.Int64, reflect.Uint64: - return "INTEGER PRIMARY KEY" - default: - panic("Invalid primary key type") - } -} - func (s *sqlite3) ReturningStr(tableName, key string) string { return "" }