diff --git a/dialect_common.go b/dialect_common.go index ef351f9e..9ccff6e9 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -92,7 +92,8 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { func (s commonDialect) HasIndex(tableName string, indexName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count) + currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) return count > 0 } @@ -107,13 +108,25 @@ func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bo func (s commonDialect) HasTable(tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count) + currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) return count > 0 } +func (s commonDialect) currentDatabaseAndTable(tableName string) (string, string) { + currentDatabase := s.CurrentDatabase() + if currentDatabase == "" && strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + currentDatabase = splitStrings[0] + tableName = splitStrings[1] + } + return currentDatabase, tableName +} + func (s commonDialect) HasColumn(tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) + currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 }