diff --git a/common_dialect.go b/common_dialect.go index 9360cd26..281df8a7 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -9,19 +9,19 @@ import ( type commonDialect struct{} -func (s *commonDialect) BinVar(i int) string { - return "?" +func (commonDialect) BinVar(i int) string { + return "$$" // ? } -func (s *commonDialect) SupportLastInsertId() bool { +func (commonDialect) SupportLastInsertId() bool { return true } -func (s *commonDialect) HasTop() bool { +func (commonDialect) HasTop() bool { return false } -func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "BOOLEAN" @@ -57,19 +57,19 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String())) } -func (s *commonDialect) ReturningStr(tableName, key string) string { +func (commonDialect) ReturningStr(tableName, key string) string { return "" } -func (s *commonDialect) SelectFromDummyTable() string { +func (commonDialect) SelectFromDummyTable() string { return "" } -func (s *commonDialect) Quote(key string) string { - return fmt.Sprintf("`%s`", key) +func (commonDialect) Quote(key string) string { + return fmt.Sprintf(`"%s"`, key) } -func (s *commonDialect) databaseName(scope *Scope) string { +func (commonDialect) databaseName(scope *Scope) string { from := strings.Index(scope.db.parent.source, "/") + 1 to := strings.Index(scope.db.parent.source, "?") if to == -1 { @@ -78,24 +78,24 @@ func (s *commonDialect) databaseName(scope *Scope) string { return scope.db.parent.source[from:to] } -func (s *commonDialect) HasTable(scope *Scope, tableName string) bool { +func (c commonDialect) HasTable(scope *Scope, tableName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, c.databaseName(scope)).Row().Scan(&count) return count > 0 } -func (s *commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", c.databaseName(scope), tableName, columnName).Row().Scan(&count) return count > 0 } -func (s *commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count) return count > 0 } -func (s *commonDialect) RemoveIndex(scope *Scope, indexName string) { +func (commonDialect) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) } diff --git a/mssql.go b/mssql.go index dc8e2917..c44541c7 100644 --- a/mssql.go +++ b/mssql.go @@ -7,21 +7,15 @@ import ( "time" ) -type mssql struct{} - -func (s *mssql) BinVar(i int) string { - return "$$" // ? +type mssql struct { + commonDialect } -func (s *mssql) SupportLastInsertId() bool { +func (mssql) HasTop() bool { return true } -func (s *mssql) HasTop() bool { - return true -} - -func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "bit" @@ -57,19 +51,7 @@ func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) } -func (s *mssql) ReturningStr(tableName, key string) string { - return "" -} - -func (s *mssql) SelectFromDummyTable() string { - return "" -} - -func (s *mssql) Quote(key string) string { - return fmt.Sprintf(" \"%s\"", key) -} - -func (s *mssql) databaseName(scope *Scope) string { +func (mssql) databaseName(scope *Scope) string { dbStr := strings.Split(scope.db.parent.source, ";") for _, value := range dbStr { s := strings.Split(value, "=") @@ -80,24 +62,20 @@ func (s *mssql) databaseName(scope *Scope) string { return "" } -func (s *mssql) HasTable(scope *Scope, tableName string) bool { +func (s mssql) HasTable(scope *Scope, tableName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.databaseName(scope)).Row().Scan(&count) return count > 0 } -func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) return count > 0 } -func (s *mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count) return count > 0 } - -func (s *mssql) RemoveIndex(scope *Scope, indexName string) { - scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) -} diff --git a/mysql.go b/mysql.go index d2eb08a5..e37a23e0 100644 --- a/mysql.go +++ b/mysql.go @@ -3,25 +3,14 @@ package gorm import ( "fmt" "reflect" - "strings" "time" ) -type mysql struct{} - -func (s *mysql) BinVar(i int) string { - return "$$" // ? +type mysql struct { + commonDialect } -func (s *mysql) SupportLastInsertId() bool { - return true -} - -func (s *mysql) HasTop() bool { - return false -} - -func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "boolean" @@ -57,45 +46,10 @@ func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) } -func (s *mysql) ReturningStr(tableName, key string) string { - return "" -} - -func (s *mysql) SelectFromDummyTable() string { - return "FROM DUAL" -} - -func (s *mysql) Quote(key string) string { +func (mysql) Quote(key string) string { return fmt.Sprintf("`%s`", key) } -func (s *mysql) databaseName(scope *Scope) string { - from := strings.Index(scope.db.parent.source, "/") + 1 - to := strings.Index(scope.db.parent.source, "?") - if to == -1 { - to = len(scope.db.parent.source) - } - return scope.db.parent.source[from:to] -} - -func (s *mysql) HasTable(scope *Scope, tableName string) bool { - var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count) - return count > 0 -} - -func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) bool { - var count int - scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) - return count > 0 -} - -func (s *mysql) HasIndex(scope *Scope, tableName string, indexName string) bool { - var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count) - return count > 0 -} - -func (s *mysql) RemoveIndex(scope *Scope, indexName string) { - scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) +func (mysql) SelectFromDummyTable() string { + return "FROM DUAL" } diff --git a/postgres.go b/postgres.go index 83c37e1f..4218e1ba 100644 --- a/postgres.go +++ b/postgres.go @@ -11,21 +11,18 @@ import ( ) type postgres struct { + commonDialect } -func (s *postgres) BinVar(i int) string { +func (postgres) BinVar(i int) string { return fmt.Sprintf("$%v", i) } -func (s *postgres) SupportLastInsertId() bool { +func (postgres) SupportLastInsertId() bool { return false } -func (s *postgres) HasTop() bool { - return false -} - -func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "boolean" @@ -62,35 +59,27 @@ func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) stri panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String())) } -func (s *postgres) ReturningStr(tableName, key string) string { +func (s postgres) ReturningStr(tableName, key string) string { return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key) } -func (s *postgres) SelectFromDummyTable() string { - return "" -} - -func (s *postgres) Quote(key string) string { - return fmt.Sprintf("\"%s\"", key) -} - -func (s *postgres) HasTable(scope *Scope, tableName string) bool { +func (postgres) HasTable(scope *Scope, tableName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName).Row().Scan(&count) return count > 0 } -func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (postgres) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count) return count > 0 } -func (s *postgres) RemoveIndex(scope *Scope, indexName string) { +func (postgres) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) } -func (s *postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName).Row().Scan(&count) return count > 0 diff --git a/sqlite3.go b/sqlite3.go index ce71ee08..afe70e3a 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -6,21 +6,11 @@ import ( "time" ) -type sqlite3 struct{} - -func (s *sqlite3) BinVar(i int) string { - return "$$" // ? +type sqlite3 struct { + commonDialect } -func (s *sqlite3) SupportLastInsertId() bool { - return true -} - -func (s *sqlite3) HasTop() bool { - return false -} - -func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "bool" @@ -50,36 +40,24 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) strin panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String())) } -func (s *sqlite3) ReturningStr(tableName, key string) string { - return "" -} - -func (s *sqlite3) SelectFromDummyTable() string { - return "" -} - -func (s *sqlite3) Quote(key string) string { - return fmt.Sprintf("\"%s\"", key) -} - -func (s *sqlite3) HasTable(scope *Scope, tableName string) bool { +func (sqlite3) HasTable(scope *Scope, tableName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Row().Scan(&count) return count > 0 } -func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName).Row().Scan(&count) return count > 0 } -func (s *sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Row().Scan(&count) return count > 0 } -func (s *sqlite3) RemoveIndex(scope *Scope, indexName string) { +func (sqlite3) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) }