From c1c13129f1960b8290037c012eaa0ecce5019c4b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 14 Nov 2013 18:59:11 +0800 Subject: [PATCH] Finish dialects --- README.md | 4 +- chain.go | 4 - dialect/dialect.go | 21 +++--- dialect/mysql.go | 61 +++++++++++++++ dialect/postgres.go | 57 ++++++++++++++ dialect/sqlite3.go | 47 ++++++++++++ do.go | 14 ++-- main.go | 2 +- model.go | 60 +++++++-------- sql_type.go | 176 -------------------------------------------- utils.go | 64 ++++++++++++++++ 11 files changed, 275 insertions(+), 235 deletions(-) delete mode 100644 sql_type.go diff --git a/README.md b/README.md index 0552bce6..9d5d82bd 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ type User struct { // TableName: `users`, gorm will pluralize struct's n UpdatedAt time.Time // Time of record is updated, will be updated automatically DeletedAt time.Time // Time of record is deleted, refer `Soft Delete` for more - Email []Email // Embedded structs + Emails []Email // Embedded structs BillingAddress Address // Embedded struct BillingAddressId sql.NullInt64 // Embedded struct BillingAddress's foreign key ShippingAddress Address // Embedded struct @@ -125,7 +125,7 @@ user := User{ Name: "jinzhu", BillingAddress: Address{Address1: "Billing Address - Address 1"}, ShippingAddress: Address{Address1: "Shipping Address - Address 1"}, - Email: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}}, + Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}}, } db.Save(&user) diff --git a/chain.go b/chain.go index fdc7767a..2c71ae5c 100644 --- a/chain.go +++ b/chain.go @@ -28,10 +28,6 @@ type Chain struct { unscoped bool } -func (s *Chain) driver() string { - return s.d.driver -} - func (s *Chain) err(err error) error { if err != nil { s.Errors = append(s.Errors, err) diff --git a/dialect/dialect.go b/dialect/dialect.go index 50773aa4..8aa3ae31 100644 --- a/dialect/dialect.go +++ b/dialect/dialect.go @@ -1,23 +1,22 @@ package dialect type Dialect interface { + BinVar(i int) string + SupportLastInsertId() bool + SqlTag(column interface{}, size int) string + PrimaryKeyTag(column interface{}, size int) string + ReturningStr(key string) string } -func NewDialect(driver string) *Dialect { +func NewDialect(driver string) Dialect { var d Dialect switch driver { case "postgres": - d = postgres{} + d = &postgres{} case "mysql": - d = mysql{} + d = &mysql{} case "sqlite3": - d = sqlite3{} + d = &sqlite3{} } - return &d + return d } - -type mysql struct{} - -type postgres struct{} - -type sqlite3 struct{} diff --git a/dialect/mysql.go b/dialect/mysql.go index 2ac98ecf..0e71612b 100644 --- a/dialect/mysql.go +++ b/dialect/mysql.go @@ -1 +1,62 @@ package dialect + +import ( + "database/sql" + "fmt" + "time" +) + +type mysql struct{} + +func (s *mysql) BinVar(i int) string { + return "?" +} + +func (s *mysql) SupportLastInsertId() bool { + return true +} + +func (d *mysql) SqlTag(column interface{}, size int) string { + switch column.(type) { + case time.Time: + return "timestamp" + case bool, sql.NullBool: + return "boolean" + case int, int8, int16, int32, uint, uint8, uint16, uint32: + return "int" + case int64, uint64, sql.NullInt64: + return "bigint" + case float32, float64, sql.NullFloat64: + return "double" + case []byte: + if size > 0 && size < 65532 { + return fmt.Sprintf("varbinary(%d)", size) + } else { + return "longblob" + } + case string, sql.NullString: + if size > 0 && size < 65532 { + return fmt.Sprintf("varchar(%d)", size) + } else { + return "longtext" + } + default: + panic("Invalid sql type for mysql") + } +} + +func (s *mysql) PrimaryKeyTag(column interface{}, size int) string { + suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY" + switch column.(type) { + case int, int8, int16, int32, uint, uint8, uint16, uint32: + return "int" + suffix_str + case int64, uint64: + return "bigint" + suffix_str + default: + panic("Invalid primary key type") + } +} + +func (s *mysql) ReturningStr(key string) (str string) { + return +} diff --git a/dialect/postgres.go b/dialect/postgres.go index 2ac98ecf..13a0afa0 100644 --- a/dialect/postgres.go +++ b/dialect/postgres.go @@ -1 +1,58 @@ package dialect + +import ( + "database/sql" + "fmt" + "time" +) + +type postgres struct { +} + +func (s *postgres) BinVar(i int) string { + return fmt.Sprintf("$%v", i) +} + +func (s *postgres) SupportLastInsertId() bool { + return false +} + +func (d *postgres) SqlTag(column interface{}, size int) string { + switch column.(type) { + case time.Time: + return "timestamp with time zone" + case bool, sql.NullBool: + return "boolean" + case int, int8, int16, int32, uint, uint8, uint16, uint32: + return "integer" + case int64, uint64, sql.NullInt64: + return "bigint" + case float32, float64, sql.NullFloat64: + return "double precision" + case []byte: + return "bytea" + case string, sql.NullString: + if size > 0 && size < 65532 { + return fmt.Sprintf("varchar(%d)", size) + } else { + return "text" + } + default: + panic("Invalid sql type for postgres") + } +} + +func (s *postgres) PrimaryKeyTag(column interface{}, size int) string { + switch column.(type) { + case int, int8, int16, int32, uint, uint8, uint16, uint32: + return "sehrial" + case int64, uint64: + return "bigserial" + default: + panic("Invalid primary key type") + } +} + +func (s *postgres) ReturningStr(key string) (str string) { + return fmt.Sprintf("RETURNING \"%v\"", key) +} diff --git a/dialect/sqlite3.go b/dialect/sqlite3.go index 2ac98ecf..3dec02ef 100644 --- a/dialect/sqlite3.go +++ b/dialect/sqlite3.go @@ -1 +1,48 @@ package dialect + +import ( + "database/sql" + "fmt" + "time" +) + +type sqlite3 struct{} + +func (s *sqlite3) BinVar(i int) string { + return "?" +} + +func (s *sqlite3) SupportLastInsertId() bool { + return true +} + +func (s *sqlite3) SqlTag(column interface{}, size int) string { + switch column.(type) { + case time.Time: + return "datetime" + case bool, sql.NullBool: + return "bool" + case int, int8, int16, int32, uint, uint8, uint16, uint32: + return "integer" + case int64, uint64, sql.NullInt64: + return "bigint" + case float32, float64, sql.NullFloat64: + return "real" + case string, sql.NullString: + if size > 0 && size < 65532 { + return fmt.Sprintf("varchar(%d)", size) + } else { + return "text" + } + default: + panic("Invalid sql type for sqlite3") + } +} + +func (s *sqlite3) PrimaryKeyTag(column interface{}, size int) string { + return "INTEGER PRIMARY KEY" +} + +func (s *sqlite3) ReturningStr(key string) (str string) { + return +} diff --git a/do.go b/do.go index 9a170da1..1bcc0d17 100644 --- a/do.go +++ b/do.go @@ -60,11 +60,7 @@ func (s *Do) setModel(value interface{}) *Do { func (s *Do) addToVars(value interface{}) string { s.sqlVars = append(s.sqlVars, value) - if s.chain.driver() == "postgres" { - return fmt.Sprintf("$%d", len(s.sqlVars)) - } else { - return "?" - } + return s.chain.d.dialect.BinVar(len(s.sqlVars)) } func (s *Do) exec(sqls ...string) (err error) { @@ -102,7 +98,7 @@ func (s *Do) prepareCreateSql() { s.tableName(), strings.Join(columns, ","), strings.Join(sqls, ","), - s.model.returningStr(), + s.chain.d.dialect.ReturningStr(s.model.primaryKeyDb()), ) return } @@ -178,13 +174,13 @@ func (s *Do) create() (i interface{}) { var id interface{} now := time.Now() - if s.chain.driver() == "postgres" { - s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id)) - } else { + if s.chain.d.dialect.SupportLastInsertId() { if sql_result, err := s.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil { id, err = sql_result.LastInsertId() s.err(err) } + } else { + s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id)) } s.chain.slog(s.sql, now, s.sqlVars...) diff --git a/main.go b/main.go index c0f97864..c835b231 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,7 @@ func init() { type DB struct { db sql_common - dialect *dialect.Dialect + dialect dialect.Dialect logger Logger log_mode bool } diff --git a/model.go b/model.go index c3e4b06e..011d1bb8 100644 --- a/model.go +++ b/model.go @@ -3,11 +3,11 @@ package gorm import ( "database/sql" "errors" - "fmt" + "go/ast" "reflect" "regexp" - "strconv" + "time" ) @@ -115,7 +115,6 @@ func (m *Model) fields(operation string) (fields []Field) { } } - tag_value := p.Tag.Get(tagIdentifier) if is_time { field.AutoCreateTime = "created_at" == field.DbName field.AutoUpdateTime = "updated_at" == field.DbName @@ -130,10 +129,14 @@ func (m *Model) fields(operation string) (fields []Field) { value.Set(reflect.ValueOf(time.Now())) } } + } - field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value) + field.Value = value.Interface() + + if is_time { + field.SqlType = m.getSqlTag(field, p) } else if field.IsPrimaryKey { - field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, tag_value) + field.SqlType = m.getSqlTag(field, p) } else { field_value := reflect.Indirect(value) @@ -148,7 +151,7 @@ func (m *Model) fields(operation string) (fields []Field) { _, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner) if is_scanner { - field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value) + field.SqlType = m.getSqlTag(field, p) } else { if indirect_value.FieldByName(p.Name + "Id").IsValid() { field.foreignKey = p.Name + "Id" @@ -162,11 +165,10 @@ func (m *Model) fields(operation string) (fields []Field) { } } default: - field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value) + field.SqlType = m.getSqlTag(field, p) } } - field.Value = value.Interface() fields = append(fields, field) } } @@ -313,13 +315,6 @@ func (m *Model) callMethod(method string) { return } -func (m *Model) returningStr() (str string) { - if m.do.chain.driver() == "postgres" { - str = fmt.Sprintf("RETURNING \"%v\"", m.primaryKeyDb()) - } - return -} - func (m *Model) setValueByColumn(name string, value interface{}, out interface{}) { data := reflect.Indirect(reflect.ValueOf(out)) setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value) @@ -343,23 +338,24 @@ func (m *Model) afterAssociations() (fields []Field) { return } -func setFieldValue(field reflect.Value, value interface{}) bool { - if field.IsValid() && field.CanAddr() { - switch field.Kind() { - case reflect.Int, reflect.Int32, reflect.Int64: - if str, ok := value.(string); ok { - value, _ = strconv.Atoi(str) - } - field.SetInt(reflect.ValueOf(value).Int()) - default: - if scanner, ok := field.Addr().Interface().(sql.Scanner); ok { - scanner.Scan(value) - } else { - field.Set(reflect.ValueOf(value)) - } - } - return true +func (m *Model) getSqlTag(field Field, struct_field reflect.StructField) string { + column := getInterfaceValue(field.Value) + typ, addational_typ, size := parseSqlTag(struct_field.Tag.Get(tagIdentifier)) + + if typ == "-" { + return "" } - return false + if len(typ) == 0 { + if field.IsPrimaryKey { + typ = m.do.chain.d.dialect.PrimaryKeyTag(column, size) + } else { + typ = m.do.chain.d.dialect.SqlTag(column, size) + } + } + + if len(addational_typ) > 0 { + typ = typ + " " + addational_typ + } + return typ } diff --git a/sql_type.go b/sql_type.go deleted file mode 100644 index 56c5be5f..00000000 --- a/sql_type.go +++ /dev/null @@ -1,176 +0,0 @@ -package gorm - -import ( - "database/sql" - "database/sql/driver" - "fmt" - "reflect" - "strconv" - "strings" - "time" -) - -func parseTag(str string) (typ string, addational_typ string, size int) { - if str == "-" { - typ = str - } else if str != "" { - tags := strings.Split(str, ";") - m := make(map[string]string) - for _, value := range tags { - v := strings.Split(value, ":") - k := strings.Trim(strings.ToUpper(v[0]), " ") - if len(v) == 2 { - m[k] = v[1] - } else { - m[k] = k - } - } - - if len(m["SIZE"]) > 0 { - size, _ = strconv.Atoi(m["SIZE"]) - } - - if len(m["TYPE"]) > 0 { - typ = m["TYPE"] - } - - addational_typ = m["NOT NULL"] + " " + m["UNIQUE"] - } - return -} - -func formatColumnValue(column interface{}) interface{} { - if v, ok := column.(reflect.Value); ok { - column = v.Interface() - } - - if valuer, ok := interface{}(column).(driver.Valuer); ok { - column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface() - } - return column -} - -func getPrimaryKeySqlType(adaptor string, column interface{}, tag string) string { - column = formatColumnValue(column) - typ, addational_typ, _ := parseTag(tag) - - if len(typ) != 0 { - return typ + addational_typ - } - - switch adaptor { - case "sqlite3": - return "INTEGER PRIMARY KEY" - case "mysql": - suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY" - switch column.(type) { - case int, int8, int16, int32, uint, uint8, uint16, uint32: - typ = "int" + suffix_str - case int64, uint64: - typ = "bigint" + suffix_str - } - case "postgres": - switch column.(type) { - case int, int8, int16, int32, uint, uint8, uint16, uint32: - typ = "serial" - case int64, uint64: - typ = "bigserial" - } - default: - panic("unsupported sql adaptor, please submit an issue in github") - } - return typ -} - -func getSqlType(adaptor string, column interface{}, tag string) string { - column = formatColumnValue(column) - typ, addational_typ, size := parseTag(tag) - - if typ == "-" { - return "" - } - - if len(typ) == 0 { - switch adaptor { - case "sqlite3": - switch column.(type) { - case time.Time: - typ = "datetime" - case bool, sql.NullBool: - typ = "bool" - case int, int8, int16, int32, uint, uint8, uint16, uint32: - typ = "integer" - case int64, uint64, sql.NullInt64: - typ = "bigint" - case float32, float64, sql.NullFloat64: - typ = "real" - case string, sql.NullString: - if size > 0 && size < 65532 { - typ = fmt.Sprintf("varchar(%d)", size) - } else { - typ = "text" - } - default: - panic("invalid sql type") - } - case "mysql": - switch column.(type) { - case time.Time: - typ = "timestamp" - case bool, sql.NullBool: - typ = "boolean" - case int, int8, int16, int32, uint, uint8, uint16, uint32: - typ = "int" - case int64, uint64, sql.NullInt64: - typ = "bigint" - case float32, float64, sql.NullFloat64: - typ = "double" - case []byte: - if size > 0 && size < 65532 { - typ = fmt.Sprintf("varbinary(%d)", size) - } else { - typ = "longblob" - } - case string, sql.NullString: - if size > 0 && size < 65532 { - typ = fmt.Sprintf("varchar(%d)", size) - } else { - typ = "longtext" - } - default: - panic("invalid sql type") - } - - case "postgres": - switch column.(type) { - case time.Time: - typ = "timestamp with time zone" - case bool, sql.NullBool: - typ = "boolean" - case int, int8, int16, int32, uint, uint8, uint16, uint32: - typ = "integer" - case int64, uint64, sql.NullInt64: - typ = "bigint" - case float32, float64, sql.NullFloat64: - typ = "double precision" - case []byte: - typ = "bytea" - case string, sql.NullString: - if size > 0 && size < 65532 { - typ = fmt.Sprintf("varchar(%d)", size) - } else { - typ = "text" - } - default: - panic("invalid sql type") - } - default: - panic("unsupported sql adaptor, please submit an issue in github") - } - } - - if len(addational_typ) > 0 { - typ = typ + " " + addational_typ - } - return typ -} diff --git a/utils.go b/utils.go index 7627e459..4db884e6 100644 --- a/utils.go +++ b/utils.go @@ -2,7 +2,10 @@ package gorm import ( "bytes" + "database/sql" + "database/sql/driver" "errors" + "reflect" "strconv" "fmt" @@ -63,3 +66,64 @@ func getInterfaceAsString(value interface{}) (str string, err error) { } return } + +func parseSqlTag(str string) (typ string, addational_typ string, size int) { + if str == "-" { + typ = str + } else if str != "" { + tags := strings.Split(str, ";") + m := make(map[string]string) + for _, value := range tags { + v := strings.Split(value, ":") + k := strings.Trim(strings.ToUpper(v[0]), " ") + if len(v) == 2 { + m[k] = v[1] + } else { + m[k] = k + } + } + + if len(m["SIZE"]) > 0 { + size, _ = strconv.Atoi(m["SIZE"]) + } + + if len(m["TYPE"]) > 0 { + typ = m["TYPE"] + } + + addational_typ = m["NOT NULL"] + " " + m["UNIQUE"] + } + return +} + +func getInterfaceValue(column interface{}) interface{} { + if v, ok := column.(reflect.Value); ok { + column = v.Interface() + } + + if valuer, ok := interface{}(column).(driver.Valuer); ok { + column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface() + } + return column +} + +func setFieldValue(field reflect.Value, value interface{}) bool { + if field.IsValid() && field.CanAddr() { + switch field.Kind() { + case reflect.Int, reflect.Int32, reflect.Int64: + if str, ok := value.(string); ok { + value, _ = strconv.Atoi(str) + } + field.SetInt(reflect.ValueOf(value).Int()) + default: + if scanner, ok := field.Addr().Interface().(sql.Scanner); ok { + scanner.Scan(value) + } else { + field.Set(reflect.ValueOf(value)) + } + } + return true + } + + return false +}