diff --git a/config_option.go b/config_option.go new file mode 100644 index 00000000..5331fc77 --- /dev/null +++ b/config_option.go @@ -0,0 +1,147 @@ +package gorm + +import ( + "sync" + "time" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" +) + +// ConfigOption for Config,use functional option. +type ConfigOption func(c *Config) + +// WithSkipDefaultTransaction enable SkipDefaultTransaction. +func WithSkipDefaultTransaction() ConfigOption { + return func(c *Config) { + c.SkipDefaultTransaction = true + } +} + +// WithNameStrategy set shema namer. +func WithNameStrategy(namer schema.Namer) ConfigOption { + return func(c *Config) { + c.NamingStrategy = namer + } +} + +// WithFullSaveAssociations set FullSaveAssociations = true. +func WithFullSaveAssociations() ConfigOption { + return func(c *Config) { + c.FullSaveAssociations = true + } +} + +// WithLogger set logger. +func WithLogger(logger logger.Interface) ConfigOption { + return func(c *Config) { + c.Logger = logger + } +} + +// WithNowFunc set now func. +func WithNowFunc(fn func() time.Time) ConfigOption { + return func(c *Config) { + c.NowFunc = fn + } +} + +// WithEnableDryRun enable dry run. +func WithEnableDryRun() ConfigOption { + return func(c *Config) { + c.DryRun = true + } +} + +// WithPrepareStmt enable PrepareStmt. +func WithPrepareStmt() ConfigOption { + return func(c *Config) { + c.PrepareStmt = true + } +} + +// WithEnableAutomaticPing enable ping. +func WithEnableAutomaticPing() ConfigOption { + return func(c *Config) { + c.DisableAutomaticPing = true + } +} + +// WithEnableForeignKeyConstraintWhenMigrating +// ForeignKeyConstraintWhenMigrating config. +func WithEnableForeignKeyConstraintWhenMigrating() ConfigOption { + return func(c *Config) { + c.DisableForeignKeyConstraintWhenMigrating = true + } +} + +// WithEnableNestedTransaction enable NestedTransaction. +func WithEnableNestedTransaction() ConfigOption { + return func(c *Config) { + c.DisableNestedTransaction = true + } +} + +// WithAllowGlobalUpdate allow global update. +func WithAllowGlobalUpdate() ConfigOption { + return func(c *Config) { + c.AllowGlobalUpdate = true + } +} + +// WithEnableQueryFields open QueryFields. +func WithEnableQueryFields() ConfigOption { + return func(c *Config) { + c.QueryFields = true + } +} + +// WithCreateBatchSize set batch size. +func WithCreateBatchSize(size int) ConfigOption { + return func(c *Config) { + c.CreateBatchSize = size + } +} + +// WithClauseBuilders set clause builder. +func WithClauseBuilders(m map[string]clause.ClauseBuilder) ConfigOption { + return func(c *Config) { + c.ClauseBuilders = m + } +} + +// WithConnPool set conn pool. +func WithConnPool(connPool ConnPool) ConfigOption { + return func(c *Config) { + c.ConnPool = connPool + } +} + +// WithDialector set dialector. +func WithDialector(dialector Dialector) ConfigOption { + return func(c *Config) { + c.Dialector = dialector + } +} + +// WithConfigPlugins set config plugins. +func WithConfigPlugins(m map[string]Plugin) ConfigOption { + return func(c *Config) { + c.Plugins = m + } +} + +// WithConfigCallbacks set cb for Config entry. +func WithConfigCallbacks(cb *callbacks) ConfigOption { + return func(c *Config) { + c.callbacks = cb + } +} + +// WithCacheStore set cacheStore for Config entry. +func WithCacheStore(s *sync.Map) ConfigOption { + return func(c *Config) { + c.cacheStore = s + } +} diff --git a/gorm.go b/gorm.go index 1109e8cd..ec7d04d0 100644 --- a/gorm.go +++ b/gorm.go @@ -15,7 +15,8 @@ import ( // Config GORM config type Config struct { - // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity + // GORM perform single create, update, delete operations in transactions + // by default to ensure database data integrity // You can disable it by setting `SkipDefaultTransaction` to true SkipDefaultTransaction bool // NamingStrategy tables, columns naming strategy @@ -85,7 +86,7 @@ type Session struct { // Open initialize db session based on dialector func Open(dialector Dialector, config *Config) (db *DB, err error) { if config == nil { - config = &Config{} + config = DefaultConfig() } if config.NamingStrategy == nil { @@ -104,7 +105,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { config.Dialector = dialector } - if config.Plugins == nil { + if len(config.Plugins) == 0 { config.Plugins = map[string]Plugin{} } @@ -116,7 +117,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { db.callbacks = initializeCallbacks(db) - if config.ClauseBuilders == nil { + if len(config.ClauseBuilders) == 0 { config.ClauseBuilders = map[string]clause.ClauseBuilder{} } @@ -150,7 +151,82 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { } if err != nil { - config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err) + config.Logger.Error(context.Background(), + "failed to initialize database, got error %v", err) + } + + return +} + +// DefaultConfig default config. +func DefaultConfig() *Config { + c := &Config{ + NamingStrategy: schema.NamingStrategy{}, + Logger: logger.Default, + NowFunc: func() time.Time { return time.Now().Local() }, + ClauseBuilders: map[string]clause.ClauseBuilder{}, + Plugins: map[string]Plugin{}, + cacheStore: &sync.Map{}, + } + + return c +} + +// Apply apply option for Config. +func (config *Config) Apply(opts ...ConfigOption) { + for _, o := range opts { + o(config) + } +} + +// OpenWithOption create *db.DB with ConfigOption +// return *db.DB,error. +func OpenWithOption(dialector Dialector, + opts ...ConfigOption) (db *DB, err error) { + config := DefaultConfig() + + // apply config option + config.Apply(opts...) + + if dialector != nil { + config.Dialector = dialector + } + + // create DB entry + db = &DB{Config: config, clone: 1} + db.callbacks = initializeCallbacks(db) + if config.Dialector != nil { + err = config.Dialector.Initialize(db) + } + + preparedStmt := &PreparedStmtDB{ + ConnPool: db.ConnPool, + Stmts: map[string]Stmt{}, + Mux: &sync.RWMutex{}, + PreparedSQL: make([]string, 0, 100), + } + db.cacheStore.Store("preparedStmt", preparedStmt) + + if config.PrepareStmt { + db.ConnPool = preparedStmt + } + + db.Statement = &Statement{ + DB: db, + ConnPool: db.ConnPool, + Context: context.Background(), + Clauses: map[string]clause.Clause{}, + } + + if err == nil && !config.DisableAutomaticPing { + if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok { + err = pinger.Ping() + } + } + + if err != nil { + config.Logger.Error(context.Background(), + "failed to initialize database, got error %v", err) } return