Merge remote-tracking branch 'gorm/master'

This commit is contained in:
Jason Chavannes 2021-04-29 15:16:21 -07:00
commit cd42f9571b
201 changed files with 23356 additions and 12102 deletions

View File

@ -1,11 +0,0 @@
---
engines:
gofmt:
enabled: true
govet:
enabled: true
golint:
enabled: true
ratings:
paths:
- "**.go"

5
.github/FUNDING.yml vendored Normal file
View File

@ -0,0 +1,5 @@
# These are supported funding model platforms
github: [jinzhu]
patreon: jinzhu
open_collective: gorm

View File

@ -1,54 +0,0 @@
Before posting a bug report about a problem, please try to verify that it is a bug and that it has not been reported already, please apply corresponding GitHub labels to the issue, for feature requests, please apply `type:feature`.
DON'T post usage related questions, ask in https://gitter.im/jinzhu/gorm or http://stackoverflow.com/questions/tagged/go-gorm,
Please answer these questions before submitting your issue. Thanks!
### What version of Go are you using (`go version`)?
### Which database and its version are you using?
### What did you do?
Please provide a complete runnable program to reproduce your issue.
```go
package main
import (
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mssql"
_ "github.com/jinzhu/gorm/dialects/mysql"
_ "github.com/jinzhu/gorm/dialects/postgres"
_ "github.com/jinzhu/gorm/dialects/sqlite"
)
var db *gorm.DB
func init() {
var err error
db, err = gorm.Open("sqlite3", "test.db")
// Please use below username, password as your database's account for the script.
// db, err = gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
// db, err = gorm.Open("mysql", "gorm:gorm@/dbname?charset=utf8&parseTime=True")
// db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm")
if err != nil {
panic(err)
}
db.LogMode(true)
}
func main() {
// your code here
if /* failure condition */ {
fmt.Println("failed")
} else {
fmt.Println("success")
}
}
```

View File

@ -1,14 +0,0 @@
Make sure these boxes checked before submitting your pull request.
- [] Do only one thing
- [] No API-breaking changes
- [] New code/logic commented & tested
- [] Write good commit message, try to squash your commits into a single one
- [] Run `./build.sh` in `gh-pages` branch for document changes
For significant changes like big bug fixes, new features, please open an issue to make a agreement on an implementation design/plan first before starting it.
Thank you.
### What did this pull request do?

166
.github/labels.json vendored Normal file
View File

@ -0,0 +1,166 @@
{
"labels": {
"critical": {
"name": "type:critical",
"colour": "#E84137",
"description": "critical questions"
},
"question": {
"name": "type:question",
"colour": "#EDEDED",
"description": "general questions"
},
"feature": {
"name": "type:feature_request",
"colour": "#43952A",
"description": "feature request"
},
"invalid_question": {
"name": "type:invalid question",
"colour": "#CF2E1F",
"description": "invalid question (not related to GORM or described in document or not enough information provided)"
},
"with_playground": {
"name": "type:with reproduction steps",
"colour": "#00ff00",
"description": "with reproduction steps"
},
"without_playground": {
"name": "type:missing reproduction steps",
"colour": "#CF2E1F",
"description": "missing reproduction steps"
},
"has_pr": {
"name": "type:has pull request",
"colour": "#43952A",
"description": "has pull request"
},
"not_tested": {
"name": "type:not tested",
"colour": "#CF2E1F",
"description": "not tested"
},
"tested": {
"name": "type:tested",
"colour": "#00ff00",
"description": "tested"
},
"breaking_change": {
"name": "type:breaking change",
"colour": "#CF2E1F",
"description": "breaking change"
}
},
"issue": {
"with_playground": {
"requires": 1,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/github.com\/go-gorm\/playground\/pull\/\\d\\d+/s"
}
]
},
"critical": {
"requires": 1,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/(critical|urgent)/i"
},
{
"type": "titleMatches",
"pattern": "/(critical|urgent)/i"
}
]
},
"question": {
"requires": 1,
"conditions": [
{
"type": "titleMatches",
"pattern": "/question/i"
},
{
"type": "descriptionMatches",
"pattern": "/question/i"
}
]
},
"feature": {
"requires": 1,
"conditions": [
{
"type": "titleMatches",
"pattern": "/feature/i"
},
{
"type": "descriptionMatches",
"pattern": "/Describe the feature/i"
}
]
},
"without_playground": {
"requires": 6,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/^((?!github.com\/go-gorm\/playground\/pull\/\\d\\d+).)*$/s"
},
{
"type": "titleMatches",
"pattern": "/^((?!question).)*$/s"
},
{
"type": "descriptionMatches",
"pattern": "/^((?!question).)*$/is"
},
{
"type": "descriptionMatches",
"pattern": "/^((?!Describe the feature).)*$/is"
},
{
"type": "titleMatches",
"pattern": "/^((?!critical|urgent).)*$/s"
},
{
"type": "descriptionMatches",
"pattern": "/^((?!critical|urgent).)*$/s"
}
]
}
},
"pr": {
"critical": {
"requires": 1,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/(critical|urgent)/i"
},
{
"type": "titleMatches",
"pattern": "/(critical|urgent)/i"
}
]
},
"not_tested": {
"requires": 1,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/\\[\\] Tested/"
}
]
},
"breaking_change": {
"requires": 1,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/\\[\\] Non breaking API changes/"
}
]
}
}
}

22
.github/workflows/invalid_question.yml vendored Normal file
View File

@ -0,0 +1,22 @@
name: "Close invalid questions issues"
on:
schedule:
- cron: "*/10 * * * *"
jobs:
stale:
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v3.0.7
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-label: "status:stale"
days-before-stale: 0
days-before-close: 2
remove-stale-when-updated: true
only-labels: "type:invalid question"

19
.github/workflows/labeler.yml vendored Normal file
View File

@ -0,0 +1,19 @@
name: "Issue Labeler"
on:
issues:
types: [opened, edited, reopened]
pull_request:
types: [opened, edited, reopened]
jobs:
triage:
runs-on: ubuntu-latest
name: Label issues and pull requests
steps:
- name: check out
uses: actions/checkout@v2
- name: labeler
uses: jinzhu/super-labeler-action@develop
with:
GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}"

View File

@ -0,0 +1,21 @@
name: "Close Missing Playground issues"
on:
schedule:
- cron: "*/10 * * * *"
jobs:
stale:
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v3.0.7
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-label: "status:stale"
days-before-stale: 0
days-before-close: 2
remove-stale-when-updated: true
only-labels: "type:missing reproduction steps"

11
.github/workflows/reviewdog.yml vendored Normal file
View File

@ -0,0 +1,11 @@
name: reviewdog
on: [pull_request]
jobs:
golangci-lint:
name: runner / golangci-lint
runs-on: ubuntu-latest
steps:
- name: Check out code into the Go module directory
uses: actions/checkout@v1
- name: golangci-lint
uses: reviewdog/action-golangci-lint@v1

22
.github/workflows/stale.yml vendored Normal file
View File

@ -0,0 +1,22 @@
name: "Stale"
on:
schedule:
- cron: "0 2 * * *"
jobs:
stale:
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v3.0.7
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been automatically marked as stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 30 days"
days-before-stale: 60
days-before-close: 30
stale-issue-label: "status:stale"
exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request'
stale-pr-label: 'status:stale'
exempt-pr-labels: 'type:feature,type:with reproduction steps,type:has pull request'

166
.github/workflows/tests.yml vendored Normal file
View File

@ -0,0 +1,166 @@
name: tests
on:
push:
branches-ignore:
- 'gh-pages'
pull_request:
branches-ignore:
- 'gh-pages'
jobs:
# Label of the container job
sqlite:
strategy:
matrix:
go: ['1.16', '1.15', '1.14']
platform: [ubuntu-latest] # can not run in windows OS
runs-on: ${{ matrix.platform }}
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- name: go mod package cache
uses: actions/cache@v2
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GORM_DIALECT=sqlite ./tests/tests_all.sh
mysql:
strategy:
matrix:
dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest']
go: ['1.16', '1.15', '1.14']
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
services:
mysql:
image: ${{ matrix.dbversion }}
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
ports:
- 9910:3306
options: >-
--health-cmd "mysqladmin ping -ugorm -pgorm"
--health-interval 10s
--health-start-period 10s
--health-timeout 5s
--health-retries 10
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- name: go mod package cache
uses: actions/cache@v2
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
postgres:
strategy:
matrix:
dbversion: ['postgres:latest', 'postgres:11', 'postgres:10']
go: ['1.16', '1.15', '1.14']
platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }}
services:
postgres:
image: ${{ matrix.dbversion }}
env:
POSTGRES_PASSWORD: gorm
POSTGRES_USER: gorm
POSTGRES_DB: gorm
TZ: Asia/Shanghai
ports:
- 9920:5432
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- name: go mod package cache
uses: actions/cache@v2
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
sqlserver:
strategy:
matrix:
go: ['1.16', '1.15', '1.14']
platform: [ubuntu-latest] # can not run test in macOS and windows
runs-on: ${{ matrix.platform }}
services:
mssql:
image: mcmoe/mssqldocker:latest
env:
ACCEPT_EULA: Y
SA_PASSWORD: LoremIpsum86
MSSQL_DB: gorm
MSSQL_USER: gorm
MSSQL_PASSWORD: LoremIpsum86
ports:
- 9930:1433
options: >-
--health-cmd="/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P LoremIpsum86 -l 30 -Q \"SELECT 1\" || exit 1"
--health-start-period 10s
--health-interval 10s
--health-timeout 5s
--health-retries 10
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- name: go mod package cache
uses: actions/cache@v2
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh

3
.gitignore vendored
View File

@ -1,2 +1,5 @@
TODO*
documents documents
coverage.txt
_book _book
.idea

View File

@ -2,49 +2,41 @@
The fantastic ORM library for Golang, aims to be developer friendly. The fantastic ORM library for Golang, aims to be developer friendly.
[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm)
[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) [![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions)
[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) [![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm)
[![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm)
[![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT)
[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc)
## Overview ## Overview
* Full-Featured ORM (almost) * Full-Featured ORM
* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism) * Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism, Single-table inheritance)
* Callbacks (Before/After Create/Save/Update/Delete/Find) * Hooks (Before/After Create/Save/Update/Delete/Find)
* Preloading (eager loading) * Eager loading with `Preload`, `Joins`
* Transactions * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
* Context, Prepared Statement Mode, DryRun Mode
* Batch Insert, FindInBatches, Find To Map
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr
* Composite Primary Key * Composite Primary Key
* SQL Builder
* Auto Migrations * Auto Migrations
* Logger * Logger
* Extendable, write Plugins based on GORM callbacks * Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus…
* Every feature comes with tests * Every feature comes with tests
* Developer Friendly * Developer Friendly
## Getting Started ## Getting Started
* GORM Guides [jinzhu.github.com/gorm](http://jinzhu.github.io/gorm) * GORM Guides [https://gorm.io](https://gorm.io)
## Upgrading To V1.0 ## Contributing
* [CHANGELOG](http://jinzhu.github.io/gorm/changelog.html) [You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
## Supporting the project
[![http://patreon.com/jinzhu](http://patreon_public_assets.s3.amazonaws.com/sized/becomeAPatronBanner.png)](http://patreon.com/jinzhu)
## Author
**jinzhu**
* <http://github.com/jinzhu>
* <wosmvp@gmail.com>
* <http://twitter.com/zhangjinzhu>
## Contributors
https://github.com/jinzhu/gorm/graphs/contributors
## License ## License
Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License). © Jinzhu, 2013~time.Now
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)

View File

@ -1,375 +1,513 @@
package gorm package gorm
import ( import (
"errors"
"fmt" "fmt"
"reflect" "reflect"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
// Association Mode contains some helper methods to handle relationship things easily. // Association Mode contains some helper methods to handle relationship things easily.
type Association struct { type Association struct {
DB *DB
Relationship *schema.Relationship
Error error Error error
scope *Scope
column string
field *Field
} }
// Find find out all related associations func (db *DB) Association(column string) *Association {
func (association *Association) Find(value interface{}) *Association { association := &Association{DB: db}
association.scope.related(value, association.column) table := db.Statement.Table
return association.setErr(association.scope.db.Error)
if err := db.Statement.Parse(db.Statement.Model); err == nil {
db.Statement.Table = table
association.Relationship = db.Statement.Schema.Relationships.Relations[column]
if association.Relationship == nil {
association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column)
} }
// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
func (association *Association) Append(values ...interface{}) *Association { for db.Statement.ReflectValue.Kind() == reflect.Ptr {
if association.Error != nil { db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
return association
}
if relationship := association.field.Relationship; relationship.Kind == "has_one" {
return association.Replace(values...)
}
return association.saveAssociations(values...)
}
// Replace replace current associations with new one
func (association *Association) Replace(values ...interface{}) *Association {
if association.Error != nil {
return association
}
var (
relationship = association.field.Relationship
scope = association.scope
field = association.field.Field
newDB = scope.NewDB()
)
// Append new values
association.field.Set(reflect.Zero(association.field.Field.Type()))
association.saveAssociations(values...)
// Belongs To
if relationship.Kind == "belongs_to" {
// Set foreign key to be null when clearing value (length equals 0)
if len(values) == 0 {
// Set foreign key to be nil
var foreignKeyMap = map[string]interface{}{}
for _, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil
}
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
} }
} else { } else {
// Polymorphic Relations association.Error = err
if relationship.PolymorphicDBName != "" {
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue)
} }
// Delete Relations except new created
if len(values) > 0 {
var associationForeignFieldNames, associationForeignDBNames []string
if relationship.Kind == "many_to_many" {
// if many to many relations, get association fields name from association foreign keys
associationScope := scope.New(reflect.New(field.Type()).Interface())
for idx, dbName := range relationship.AssociationForeignFieldNames {
if field, ok := associationScope.FieldByName(dbName); ok {
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx])
}
}
} else {
// If has one/many relations, use primary keys
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
associationForeignDBNames = append(associationForeignDBNames, field.DBName)
}
}
newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface())
if len(newPrimaryKeys) > 0 {
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys))
newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...)
}
}
if relationship.Kind == "many_to_many" {
// if many to many relations, delete related relations from join table
var sourceForeignFieldNames []string
for _, dbName := range relationship.ForeignFieldNames {
if field, ok := scope.FieldByName(dbName); ok {
sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name)
}
}
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
}
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
var foreignKeyMap = map[string]interface{}{}
for idx, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil
if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok {
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
}
}
fieldValue := reflect.New(association.field.Field.Type()).Interface()
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
}
}
return association return association
} }
// Delete remove relationship between source & passed arguments, but won't delete those arguments func (association *Association) Find(out interface{}, conds ...interface{}) error {
func (association *Association) Delete(values ...interface{}) *Association {
if association.Error != nil {
return association
}
var (
relationship = association.field.Relationship
scope = association.scope
field = association.field.Field
newDB = scope.NewDB()
)
if len(values) == 0 {
return association
}
var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
}
deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...)
if relationship.Kind == "many_to_many" {
// source value's foreign keys
for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
}
}
// get association's foreign fields name
var associationScope = scope.New(reflect.New(field.Type()).Interface())
var associationForeignFieldNames []string
for _, associationDBName := range relationship.AssociationForeignFieldNames {
if field, ok := associationScope.FieldByName(associationDBName); ok {
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
}
}
// association value's foreign keys
deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...)
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
} else {
var foreignKeyMap = map[string]interface{}{}
for _, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil
}
if relationship.Kind == "belongs_to" {
// find with deleting relation's foreign keys
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...)
newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
// set foreign key to be null if there are some records affected
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
if results.RowsAffected > 0 {
scope.updatedAttrsWithValues(foreignKeyMap)
}
} else {
association.setErr(results.Error)
}
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
// find all relations
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
// only include those deleting relations
newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)),
toQueryValues(deletingPrimaryKeys)...,
)
// set matched relation's foreign key to be null
fieldValue := reflect.New(association.field.Field.Type()).Interface()
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
}
}
// Remove deleted records from source's field
if association.Error == nil { if association.Error == nil {
if field.Kind() == reflect.Slice { association.Error = association.buildCondition().Find(out, conds...).Error
leftValues := reflect.Zero(field.Type()) }
return association.Error
}
for i := 0; i < field.Len(); i++ { func (association *Association) Append(values ...interface{}) error {
reflectValue := field.Index(i) if association.Error == nil {
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] switch association.Relationship.Type {
var isDeleted = false case schema.HasOne, schema.BelongsTo:
for _, pk := range deletingPrimaryKeys { if len(values) > 0 {
if equalAsString(primaryKey, pk) { association.Error = association.Replace(values...)
isDeleted = true }
default:
association.saveAssociation( /*clear*/ false, values...)
}
}
return association.Error
}
func (association *Association) Replace(values ...interface{}) error {
if association.Error == nil {
// save associations
if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
return association.Error
}
// set old associations's foreign key to null
reflectValue := association.DB.Statement.ReflectValue
rel := association.Relationship
switch rel.Type {
case schema.BelongsTo:
if len(values) == 0 {
updateMap := map[string]interface{}{}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
}
case reflect.Struct:
association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
}
for _, ref := range rel.References {
updateMap[ref.ForeignKey.DBName] = nil
}
association.Error = association.DB.UpdateColumns(updateMap).Error
}
case schema.HasOne, schema.HasMany:
var (
primaryFields []*schema.Field
foreignKeys []string
updateMap = map[string]interface{}{}
relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel})
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
tx.Not(clause.IN{Column: column, Values: values})
}
}
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
primaryFields = append(primaryFields, ref.PrimaryKey)
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
updateMap[ref.ForeignKey.DBName] = nil
} else if ref.PrimaryValue != "" {
tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
}
}
if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
}
case schema.Many2Many:
var (
primaryFields, relPrimaryFields []*schema.Field
joinPrimaryKeys, joinRelPrimaryKeys []string
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
for _, ref := range rel.References {
if ref.PrimaryValue == "" {
if ref.OwnPrimaryKey {
primaryFields = append(primaryFields, ref.PrimaryKey)
joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
} else {
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
}
} else {
tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
}
}
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
tx.Where(clause.IN{Column: column, Values: values})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
}
association.Error = tx.Delete(modelValue).Error
}
}
return association.Error
}
func (association *Association) Delete(values ...interface{}) error {
if association.Error == nil {
var (
reflectValue = association.DB.Statement.ReflectValue
rel = association.Relationship
primaryFields []*schema.Field
foreignKeys []string
updateAttrs = map[string]interface{}{}
conds []clause.Expression
)
for _, ref := range rel.References {
if ref.PrimaryValue == "" {
primaryFields = append(primaryFields, ref.PrimaryKey)
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
updateAttrs[ref.ForeignKey.DBName] = nil
} else {
conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
}
}
switch rel.Type {
case schema.BelongsTo:
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields)
relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
case schema.HasOne, schema.HasMany:
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
case schema.Many2Many:
var (
primaryFields, relPrimaryFields []*schema.Field
joinPrimaryKeys, joinRelPrimaryKeys []string
joinValue = reflect.New(rel.JoinTable.ModelType).Interface()
)
for _, ref := range rel.References {
if ref.PrimaryValue == "" {
if ref.OwnPrimaryKey {
primaryFields = append(primaryFields, ref.PrimaryKey)
joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
} else {
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
}
} else {
conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
}
}
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
}
if association.Error == nil {
// clean up deleted values's foreign key
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
cleanUpDeletedRelations := func(data reflect.Value) {
if _, zero := rel.Field.ValueOf(data); !zero {
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data))
primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
switch fieldValue.Kind() {
case reflect.Slice, reflect.Array:
validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
for i := 0; i < fieldValue.Len(); i++ {
for idx, field := range rel.FieldSchema.PrimaryFields {
primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i))
}
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i))
}
}
association.Error = rel.Field.Set(data, validFieldValues.Interface())
case reflect.Struct:
for idx, field := range rel.FieldSchema.PrimaryFields {
primaryValues[idx], _ = field.ValueOf(fieldValue)
}
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
break break
} }
}
if !isDeleted {
leftValues = reflect.Append(leftValues, reflectValue)
}
}
association.field.Set(leftValues) if rel.JoinTable == nil {
} else if field.Kind() == reflect.Struct { for _, ref := range rel.References {
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
for _, pk := range deletingPrimaryKeys { association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
if equalAsString(primaryKey, pk) { } else {
association.field.Set(reflect.Zero(field.Type())) association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
break }
}
}
} }
} }
} }
} }
return association switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i)))
}
case reflect.Struct:
cleanUpDeletedRelations(reflectValue)
}
}
} }
// Clear remove relationship between source & current associations, won't delete those associations return association.Error
func (association *Association) Clear() *Association { }
func (association *Association) Clear() error {
return association.Replace() return association.Replace()
} }
// Count return the count of current associations func (association *Association) Count() (count int64) {
func (association *Association) Count() int { if association.Error == nil {
association.Error = association.buildCondition().Count(&count).Error
}
return
}
type assignBack struct {
Source reflect.Value
Index int
Dest reflect.Value
}
func (association *Association) saveAssociation(clear bool, values ...interface{}) {
var ( var (
count = 0 reflectValue = association.DB.Statement.ReflectValue
relationship = association.field.Relationship assignBacks []assignBack // assign association values back to arguments after save
scope = association.scope
fieldValue = association.field.Field.Interface()
query = scope.DB()
) )
if relationship.Kind == "many_to_many" { appendToRelations := func(source, rv reflect.Value, clear bool) {
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) switch association.Relationship.Type {
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { case schema.HasOne, schema.BelongsTo:
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) switch rv.Kind() {
query = query.Where( case reflect.Slice, reflect.Array:
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), if rv.Len() > 0 {
toQueryValues(primaryKeys)..., association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface())
)
} else if relationship.Kind == "belongs_to" { if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
query = query.Where( }
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), }
toQueryValues(primaryKeys)..., case reflect.Struct:
) association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
}
}
case schema.HasMany, schema.Many2Many:
elemType := association.Relationship.Field.IndirectFieldType.Elem()
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source))
if clear {
fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
} }
if relationship.PolymorphicType != "" { appendToFieldValues := func(ev reflect.Value) {
query = query.Where( if ev.Type().AssignableTo(elemType) {
fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), fieldValue = reflect.Append(fieldValue, ev)
relationship.PolymorphicValue, } else if ev.Type().Elem().AssignableTo(elemType) {
) fieldValue = reflect.Append(fieldValue, ev.Elem())
} else {
association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name)
} }
if err := query.Model(fieldValue).Count(&count).Error; err != nil { if elemType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()})
}
}
switch rv.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < rv.Len(); i++ {
appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
}
case reflect.Struct:
appendToFieldValues(rv.Addr())
}
if association.Error == nil {
association.Error = association.Relationship.Field.Set(source, fieldValue.Interface())
}
}
}
selectedSaveColumns := []string{association.Relationship.Name}
omitColumns := []string{}
selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false)
for name, ok := range selectColumns {
columnName := ""
if strings.HasPrefix(name, association.Relationship.Name) {
if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" {
columnName = name
}
} else if strings.HasPrefix(name, clause.Associations) {
columnName = name
}
if columnName != "" {
if ok {
selectedSaveColumns = append(selectedSaveColumns, columnName)
} else {
omitColumns = append(omitColumns, columnName)
}
}
}
for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey {
selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name)
}
}
associationDB := association.DB.Session(&Session{}).Model(nil)
if !association.DB.FullSaveAssociations {
associationDB.Select(selectedSaveColumns)
}
if len(omitColumns) > 0 {
associationDB.Omit(omitColumns...)
}
associationDB = associationDB.Session(&Session{})
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if len(values) != reflectValue.Len() {
// clear old data
if clear && len(values) == 0 {
for i := 0; i < reflectValue.Len(); i++ {
if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
association.Error = err association.Error = err
} break
return count
} }
// saveAssociations save passed values as associations if association.Relationship.JoinTable == nil {
func (association *Association) saveAssociations(values ...interface{}) *Association { for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
association.Error = err
break
}
}
}
}
}
break
}
association.Error = ErrInvalidValueOfLength
return
}
for i := 0; i < reflectValue.Len(); i++ {
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
// TODO support save slice data, sql with case?
association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
}
case reflect.Struct:
// clear old data
if clear && len(values) == 0 {
association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
if association.Relationship.JoinTable == nil && association.Error == nil {
for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
}
}
}
}
for idx, value := range values {
rv := reflect.Indirect(reflect.ValueOf(value))
appendToRelations(reflectValue, rv, clear && idx == 0)
}
if len(values) > 0 {
association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error
}
}
for _, assignBack := range assignBacks {
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source))
if assignBack.Index > 0 {
reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
} else {
reflect.Indirect(assignBack.Dest).Set(fieldValue)
}
}
}
func (association *Association) buildCondition() *DB {
var ( var (
scope = association.scope queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
field = association.field modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
relationship = field.Relationship tx = association.DB.Model(modelValue)
) )
saveAssociation := func(reflectValue reflect.Value) { if association.Relationship.JoinTable != nil {
// value has to been pointer if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
if reflectValue.Kind() != reflect.Ptr { joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
reflectPtr := reflect.New(reflectValue.Type()) for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
reflectPtr.Elem().Set(reflectValue) joinStmt.AddClause(queryClause)
reflectValue = reflectPtr }
joinStmt.Build("WHERE")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
} }
// value has to been saved for many2many tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{
if relationship.Kind == "many_to_many" { Table: clause.Table{Name: association.Relationship.JoinTable.Table},
if scope.New(reflectValue.Interface()).PrimaryKeyZero() { ON: clause.Where{Exprs: queryConds},
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) }}})
}
}
// Assign Fields
var fieldType = field.Field.Type()
var setFieldBackToValue, setSliceFieldBackToValue bool
if reflectValue.Type().AssignableTo(fieldType) {
field.Set(reflectValue)
} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
// if field's type is struct, then need to set value back to argument after save
setFieldBackToValue = true
field.Set(reflectValue.Elem())
} else if fieldType.Kind() == reflect.Slice {
if reflectValue.Type().AssignableTo(fieldType.Elem()) {
field.Set(reflect.Append(field.Field, reflectValue))
} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
// if field's type is slice of struct, then need to set value back to argument after save
setSliceFieldBackToValue = true
field.Set(reflect.Append(field.Field, reflectValue.Elem()))
}
}
if relationship.Kind == "many_to_many" {
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
} else { } else {
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) tx.Clauses(clause.Where{Exprs: queryConds})
if setFieldBackToValue {
reflectValue.Elem().Set(field.Field)
} else if setSliceFieldBackToValue {
reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
}
}
} }
for _, value := range values { return tx
reflectValue := reflect.ValueOf(value)
indirectReflectValue := reflect.Indirect(reflectValue)
if indirectReflectValue.Kind() == reflect.Struct {
saveAssociation(reflectValue)
} else if indirectReflectValue.Kind() == reflect.Slice {
for i := 0; i < indirectReflectValue.Len(); i++ {
saveAssociation(indirectReflectValue.Index(i))
}
} else {
association.setErr(errors.New("invalid value type"))
}
}
return association
}
func (association *Association) setErr(err error) *Association {
if err != nil {
association.Error = err
}
return association
} }

View File

@ -1,907 +0,0 @@
package gorm_test
import (
"fmt"
"os"
"reflect"
"sort"
"testing"
"github.com/jinzhu/gorm"
)
func TestBelongsTo(t *testing.T) {
post := Post{
Title: "post belongs to",
Body: "body belongs to",
Category: Category{Name: "Category 1"},
MainCategory: Category{Name: "Main Category 1"},
}
if err := DB.Save(&post).Error; err != nil {
t.Error("Got errors when save post", err)
}
if post.Category.ID == 0 || post.MainCategory.ID == 0 {
t.Errorf("Category's primary key should be updated")
}
if post.CategoryId.Int64 == 0 || post.MainCategoryId == 0 {
t.Errorf("post's foreign key should be updated")
}
// Query
var category1 Category
DB.Model(&post).Association("Category").Find(&category1)
if category1.Name != "Category 1" {
t.Errorf("Query belongs to relations with Association")
}
var mainCategory1 Category
DB.Model(&post).Association("MainCategory").Find(&mainCategory1)
if mainCategory1.Name != "Main Category 1" {
t.Errorf("Query belongs to relations with Association")
}
var category11 Category
DB.Model(&post).Related(&category11)
if category11.Name != "Category 1" {
t.Errorf("Query belongs to relations with Related")
}
if DB.Model(&post).Association("Category").Count() != 1 {
t.Errorf("Post's category count should be 1")
}
if DB.Model(&post).Association("MainCategory").Count() != 1 {
t.Errorf("Post's main category count should be 1")
}
// Append
var category2 = Category{
Name: "Category 2",
}
DB.Model(&post).Association("Category").Append(&category2)
if category2.ID == 0 {
t.Errorf("Category should has ID when created with Append")
}
var category21 Category
DB.Model(&post).Related(&category21)
if category21.Name != "Category 2" {
t.Errorf("Category should be updated with Append")
}
if DB.Model(&post).Association("Category").Count() != 1 {
t.Errorf("Post's category count should be 1")
}
// Replace
var category3 = Category{
Name: "Category 3",
}
DB.Model(&post).Association("Category").Replace(&category3)
if category3.ID == 0 {
t.Errorf("Category should has ID when created with Replace")
}
var category31 Category
DB.Model(&post).Related(&category31)
if category31.Name != "Category 3" {
t.Errorf("Category should be updated with Replace")
}
if DB.Model(&post).Association("Category").Count() != 1 {
t.Errorf("Post's category count should be 1")
}
// Delete
DB.Model(&post).Association("Category").Delete(&category2)
if DB.Model(&post).Related(&Category{}).RecordNotFound() {
t.Errorf("Should not delete any category when Delete a unrelated Category")
}
if post.Category.Name == "" {
t.Errorf("Post's category should not be reseted when Delete a unrelated Category")
}
DB.Model(&post).Association("Category").Delete(&category3)
if post.Category.Name != "" {
t.Errorf("Post's category should be reseted after Delete")
}
var category41 Category
DB.Model(&post).Related(&category41)
if category41.Name != "" {
t.Errorf("Category should be deleted with Delete")
}
if count := DB.Model(&post).Association("Category").Count(); count != 0 {
t.Errorf("Post's category count should be 0 after Delete, but got %v", count)
}
// Clear
DB.Model(&post).Association("Category").Append(&Category{
Name: "Category 2",
})
if DB.Model(&post).Related(&Category{}).RecordNotFound() {
t.Errorf("Should find category after append")
}
if post.Category.Name == "" {
t.Errorf("Post's category should has value after Append")
}
DB.Model(&post).Association("Category").Clear()
if post.Category.Name != "" {
t.Errorf("Post's category should be cleared after Clear")
}
if !DB.Model(&post).Related(&Category{}).RecordNotFound() {
t.Errorf("Should not find any category after Clear")
}
if count := DB.Model(&post).Association("Category").Count(); count != 0 {
t.Errorf("Post's category count should be 0 after Clear, but got %v", count)
}
// Check Association mode with soft delete
category6 := Category{
Name: "Category 6",
}
DB.Model(&post).Association("Category").Append(&category6)
if count := DB.Model(&post).Association("Category").Count(); count != 1 {
t.Errorf("Post's category count should be 1 after Append, but got %v", count)
}
DB.Delete(&category6)
if count := DB.Model(&post).Association("Category").Count(); count != 0 {
t.Errorf("Post's category count should be 0 after the category has been deleted, but got %v", count)
}
if err := DB.Model(&post).Association("Category").Find(&Category{}).Error; err == nil {
t.Errorf("Post's category is not findable after Delete")
}
if count := DB.Unscoped().Model(&post).Association("Category").Count(); count != 1 {
t.Errorf("Post's category count should be 1 when query with Unscoped, but got %v", count)
}
if err := DB.Unscoped().Model(&post).Association("Category").Find(&Category{}).Error; err != nil {
t.Errorf("Post's category should be findable when query with Unscoped, got %v", err)
}
}
func TestBelongsToOverrideForeignKey1(t *testing.T) {
type Profile struct {
gorm.Model
Name string
}
type User struct {
gorm.Model
Profile Profile `gorm:"ForeignKey:ProfileRefer"`
ProfileRefer int
}
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "belongs_to" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
func TestBelongsToOverrideForeignKey2(t *testing.T) {
type Profile struct {
gorm.Model
Refer string
Name string
}
type User struct {
gorm.Model
Profile Profile `gorm:"ForeignKey:ProfileID;AssociationForeignKey:Refer"`
ProfileID int
}
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "belongs_to" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
func TestHasOne(t *testing.T) {
user := User{
Name: "has one",
CreditCard: CreditCard{Number: "411111111111"},
}
if err := DB.Save(&user).Error; err != nil {
t.Error("Got errors when save user", err.Error())
}
if user.CreditCard.UserId.Int64 == 0 {
t.Errorf("CreditCard's foreign key should be updated")
}
// Query
var creditCard1 CreditCard
DB.Model(&user).Related(&creditCard1)
if creditCard1.Number != "411111111111" {
t.Errorf("Query has one relations with Related")
}
var creditCard11 CreditCard
DB.Model(&user).Association("CreditCard").Find(&creditCard11)
if creditCard11.Number != "411111111111" {
t.Errorf("Query has one relations with Related")
}
if DB.Model(&user).Association("CreditCard").Count() != 1 {
t.Errorf("User's credit card count should be 1")
}
// Append
var creditcard2 = CreditCard{
Number: "411111111112",
}
DB.Model(&user).Association("CreditCard").Append(&creditcard2)
if creditcard2.ID == 0 {
t.Errorf("Creditcard should has ID when created with Append")
}
var creditcard21 CreditCard
DB.Model(&user).Related(&creditcard21)
if creditcard21.Number != "411111111112" {
t.Errorf("CreditCard should be updated with Append")
}
if DB.Model(&user).Association("CreditCard").Count() != 1 {
t.Errorf("User's credit card count should be 1")
}
// Replace
var creditcard3 = CreditCard{
Number: "411111111113",
}
DB.Model(&user).Association("CreditCard").Replace(&creditcard3)
if creditcard3.ID == 0 {
t.Errorf("Creditcard should has ID when created with Replace")
}
var creditcard31 CreditCard
DB.Model(&user).Related(&creditcard31)
if creditcard31.Number != "411111111113" {
t.Errorf("CreditCard should be updated with Replace")
}
if DB.Model(&user).Association("CreditCard").Count() != 1 {
t.Errorf("User's credit card count should be 1")
}
// Delete
DB.Model(&user).Association("CreditCard").Delete(&creditcard2)
var creditcard4 CreditCard
DB.Model(&user).Related(&creditcard4)
if creditcard4.Number != "411111111113" {
t.Errorf("Should not delete credit card when Delete a unrelated CreditCard")
}
if DB.Model(&user).Association("CreditCard").Count() != 1 {
t.Errorf("User's credit card count should be 1")
}
DB.Model(&user).Association("CreditCard").Delete(&creditcard3)
if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() {
t.Errorf("Should delete credit card with Delete")
}
if DB.Model(&user).Association("CreditCard").Count() != 0 {
t.Errorf("User's credit card count should be 0 after Delete")
}
// Clear
var creditcard5 = CreditCard{
Number: "411111111115",
}
DB.Model(&user).Association("CreditCard").Append(&creditcard5)
if DB.Model(&user).Related(&CreditCard{}).RecordNotFound() {
t.Errorf("Should added credit card with Append")
}
if DB.Model(&user).Association("CreditCard").Count() != 1 {
t.Errorf("User's credit card count should be 1")
}
DB.Model(&user).Association("CreditCard").Clear()
if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() {
t.Errorf("Credit card should be deleted with Clear")
}
if DB.Model(&user).Association("CreditCard").Count() != 0 {
t.Errorf("User's credit card count should be 0 after Clear")
}
// Check Association mode with soft delete
var creditcard6 = CreditCard{
Number: "411111111116",
}
DB.Model(&user).Association("CreditCard").Append(&creditcard6)
if count := DB.Model(&user).Association("CreditCard").Count(); count != 1 {
t.Errorf("User's credit card count should be 1 after Append, but got %v", count)
}
DB.Delete(&creditcard6)
if count := DB.Model(&user).Association("CreditCard").Count(); count != 0 {
t.Errorf("User's credit card count should be 0 after credit card deleted, but got %v", count)
}
if err := DB.Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err == nil {
t.Errorf("User's creditcard is not findable after Delete")
}
if count := DB.Unscoped().Model(&user).Association("CreditCard").Count(); count != 1 {
t.Errorf("User's credit card count should be 1 when query with Unscoped, but got %v", count)
}
if err := DB.Unscoped().Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err != nil {
t.Errorf("User's creditcard should be findable when query with Unscoped, got %v", err)
}
}
func TestHasOneOverrideForeignKey1(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserRefer uint
}
type User struct {
gorm.Model
Profile Profile `gorm:"ForeignKey:UserRefer"`
}
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "has_one" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
func TestHasOneOverrideForeignKey2(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserID uint
}
type User struct {
gorm.Model
Refer string
Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
}
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "has_one" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
func TestHasMany(t *testing.T) {
post := Post{
Title: "post has many",
Body: "body has many",
Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}},
}
if err := DB.Save(&post).Error; err != nil {
t.Error("Got errors when save post", err)
}
for _, comment := range post.Comments {
if comment.PostId == 0 {
t.Errorf("comment's PostID should be updated")
}
}
var compareComments = func(comments []Comment, contents []string) bool {
var commentContents []string
for _, comment := range comments {
commentContents = append(commentContents, comment.Content)
}
sort.Strings(commentContents)
sort.Strings(contents)
return reflect.DeepEqual(commentContents, contents)
}
// Query
if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil {
t.Errorf("Comment 1 should be saved")
}
var comments1 []Comment
DB.Model(&post).Association("Comments").Find(&comments1)
if !compareComments(comments1, []string{"Comment 1", "Comment 2"}) {
t.Errorf("Query has many relations with Association")
}
var comments11 []Comment
DB.Model(&post).Related(&comments11)
if !compareComments(comments11, []string{"Comment 1", "Comment 2"}) {
t.Errorf("Query has many relations with Related")
}
if DB.Model(&post).Association("Comments").Count() != 2 {
t.Errorf("Post's comments count should be 2")
}
// Append
DB.Model(&post).Association("Comments").Append(&Comment{Content: "Comment 3"})
var comments2 []Comment
DB.Model(&post).Related(&comments2)
if !compareComments(comments2, []string{"Comment 1", "Comment 2", "Comment 3"}) {
t.Errorf("Append new record to has many relations")
}
if DB.Model(&post).Association("Comments").Count() != 3 {
t.Errorf("Post's comments count should be 3 after Append")
}
// Delete
DB.Model(&post).Association("Comments").Delete(comments11)
var comments3 []Comment
DB.Model(&post).Related(&comments3)
if !compareComments(comments3, []string{"Comment 3"}) {
t.Errorf("Delete an existing resource for has many relations")
}
if DB.Model(&post).Association("Comments").Count() != 1 {
t.Errorf("Post's comments count should be 1 after Delete 2")
}
// Replace
DB.Model(&Post{Id: 999}).Association("Comments").Replace()
var comments4 []Comment
DB.Model(&post).Related(&comments4)
if len(comments4) == 0 {
t.Errorf("Replace for other resource should not clear all comments")
}
DB.Model(&post).Association("Comments").Replace(&Comment{Content: "Comment 4"}, &Comment{Content: "Comment 5"})
var comments41 []Comment
DB.Model(&post).Related(&comments41)
if !compareComments(comments41, []string{"Comment 4", "Comment 5"}) {
t.Errorf("Replace has many relations")
}
// Clear
DB.Model(&Post{Id: 999}).Association("Comments").Clear()
var comments5 []Comment
DB.Model(&post).Related(&comments5)
if len(comments5) == 0 {
t.Errorf("Clear should not clear all comments")
}
DB.Model(&post).Association("Comments").Clear()
var comments51 []Comment
DB.Model(&post).Related(&comments51)
if len(comments51) != 0 {
t.Errorf("Clear has many relations")
}
// Check Association mode with soft delete
var comment6 = Comment{
Content: "comment 6",
}
DB.Model(&post).Association("Comments").Append(&comment6)
if count := DB.Model(&post).Association("Comments").Count(); count != 1 {
t.Errorf("post's comments count should be 1 after Append, but got %v", count)
}
DB.Delete(&comment6)
if count := DB.Model(&post).Association("Comments").Count(); count != 0 {
t.Errorf("post's comments count should be 0 after comment been deleted, but got %v", count)
}
var comments6 []Comment
if DB.Model(&post).Association("Comments").Find(&comments6); len(comments6) != 0 {
t.Errorf("post's comments count should be 0 when find with Find, but got %v", len(comments6))
}
if count := DB.Unscoped().Model(&post).Association("Comments").Count(); count != 1 {
t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", count)
}
var comments61 []Comment
if DB.Unscoped().Model(&post).Association("Comments").Find(&comments61); len(comments61) != 1 {
t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", len(comments61))
}
}
func TestHasManyOverrideForeignKey1(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserRefer uint
}
type User struct {
gorm.Model
Profile []Profile `gorm:"ForeignKey:UserRefer"`
}
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "has_many" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
func TestHasManyOverrideForeignKey2(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserID uint
}
type User struct {
gorm.Model
Refer string
Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
}
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "has_many" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
func TestManyToMany(t *testing.T) {
DB.Raw("delete from languages")
var languages = []Language{{Name: "ZH"}, {Name: "EN"}}
user := User{Name: "Many2Many", Languages: languages}
DB.Save(&user)
// Query
var newLanguages []Language
DB.Model(&user).Related(&newLanguages, "Languages")
if len(newLanguages) != len([]string{"ZH", "EN"}) {
t.Errorf("Query many to many relations")
}
DB.Model(&user).Association("Languages").Find(&newLanguages)
if len(newLanguages) != len([]string{"ZH", "EN"}) {
t.Errorf("Should be able to find many to many relations")
}
if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) {
t.Errorf("Count should return correct result")
}
// Append
DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"})
if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() {
t.Errorf("New record should be saved when append")
}
languageA := Language{Name: "AA"}
DB.Save(&languageA)
DB.Model(&User{Id: user.Id}).Association("Languages").Append(&languageA)
languageC := Language{Name: "CC"}
DB.Save(&languageC)
DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC})
DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}})
totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"}
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) {
t.Errorf("All appended languages should be saved")
}
// Delete
user.Languages = []Language{}
DB.Model(&user).Association("Languages").Find(&user.Languages)
var language Language
DB.Where("name = ?", "EE").First(&language)
DB.Model(&user).Association("Languages").Delete(language, &language)
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 {
t.Errorf("Relations should be deleted with Delete")
}
if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() {
t.Errorf("Language EE should not be deleted")
}
DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages)
user2 := User{Name: "Many2Many_User2", Languages: languages}
DB.Save(&user2)
DB.Model(&user).Association("Languages").Delete(languages, &languages)
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 {
t.Errorf("Relations should be deleted with Delete")
}
if DB.Model(&user2).Association("Languages").Count() == 0 {
t.Errorf("Other user's relations should not be deleted")
}
// Replace
var languageB Language
DB.Where("name = ?", "BB").First(&languageB)
DB.Model(&user).Association("Languages").Replace(languageB)
if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 {
t.Errorf("Relations should be replaced")
}
DB.Model(&user).Association("Languages").Replace()
if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 {
t.Errorf("Relations should be replaced with empty")
}
DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}})
if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) {
t.Errorf("Relations should be replaced")
}
// Clear
DB.Model(&user).Association("Languages").Clear()
if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 {
t.Errorf("Relations should be cleared")
}
// Check Association mode with soft delete
var language6 = Language{
Name: "language 6",
}
DB.Model(&user).Association("Languages").Append(&language6)
if count := DB.Model(&user).Association("Languages").Count(); count != 1 {
t.Errorf("user's languages count should be 1 after Append, but got %v", count)
}
DB.Delete(&language6)
if count := DB.Model(&user).Association("Languages").Count(); count != 0 {
t.Errorf("user's languages count should be 0 after language been deleted, but got %v", count)
}
var languages6 []Language
if DB.Model(&user).Association("Languages").Find(&languages6); len(languages6) != 0 {
t.Errorf("user's languages count should be 0 when find with Find, but got %v", len(languages6))
}
if count := DB.Unscoped().Model(&user).Association("Languages").Count(); count != 1 {
t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", count)
}
var languages61 []Language
if DB.Unscoped().Model(&user).Association("Languages").Find(&languages61); len(languages61) != 1 {
t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", len(languages61))
}
}
func TestRelated(t *testing.T) {
user := User{
Name: "jinzhu",
BillingAddress: Address{Address1: "Billing Address - Address 1"},
ShippingAddress: Address{Address1: "Shipping Address - Address 1"},
Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}},
CreditCard: CreditCard{Number: "1234567890"},
Company: Company{Name: "company1"},
}
if err := DB.Save(&user).Error; err != nil {
t.Errorf("No error should happen when saving user")
}
if user.CreditCard.ID == 0 {
t.Errorf("After user save, credit card should have id")
}
if user.BillingAddress.ID == 0 {
t.Errorf("After user save, billing address should have id")
}
if user.Emails[0].Id == 0 {
t.Errorf("After user save, billing address should have id")
}
var emails []Email
DB.Model(&user).Related(&emails)
if len(emails) != 2 {
t.Errorf("Should have two emails")
}
var emails2 []Email
DB.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2)
if len(emails2) != 1 {
t.Errorf("Should have two emails")
}
var emails3 []*Email
DB.Model(&user).Related(&emails3)
if len(emails3) != 2 {
t.Errorf("Should have two emails")
}
var user1 User
DB.Model(&user).Related(&user1.Emails)
if len(user1.Emails) != 2 {
t.Errorf("Should have only one email match related condition")
}
var address1 Address
DB.Model(&user).Related(&address1, "BillingAddressId")
if address1.Address1 != "Billing Address - Address 1" {
t.Errorf("Should get billing address from user correctly")
}
user1 = User{}
DB.Model(&address1).Related(&user1, "BillingAddressId")
if DB.NewRecord(user1) {
t.Errorf("Should get user from address correctly")
}
var user2 User
DB.Model(&emails[0]).Related(&user2)
if user2.Id != user.Id || user2.Name != user.Name {
t.Errorf("Should get user from email correctly")
}
var creditcard CreditCard
var user3 User
DB.First(&creditcard, "number = ?", "1234567890")
DB.Model(&creditcard).Related(&user3)
if user3.Id != user.Id || user3.Name != user.Name {
t.Errorf("Should get user from credit card correctly")
}
if !DB.Model(&CreditCard{}).Related(&User{}).RecordNotFound() {
t.Errorf("RecordNotFound for Related")
}
var company Company
if DB.Model(&user).Related(&company, "Company").RecordNotFound() || company.Name != "company1" {
t.Errorf("RecordNotFound for Related")
}
}
func TestForeignKey(t *testing.T) {
for _, structField := range DB.NewScope(&User{}).GetStructFields() {
for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
for _, structField := range DB.NewScope(&Email{}).GetStructFields() {
for _, foreignKey := range []string{"UserId"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
for _, structField := range DB.NewScope(&Post{}).GetStructFields() {
for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
for _, structField := range DB.NewScope(&Comment{}).GetStructFields() {
for _, foreignKey := range []string{"PostId"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
}
func testForeignKey(t *testing.T, source interface{}, sourceFieldName string, target interface{}, targetFieldName string) {
if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" {
// sqlite does not support ADD CONSTRAINT in ALTER TABLE
return
}
targetScope := DB.NewScope(target)
targetTableName := targetScope.TableName()
modelScope := DB.NewScope(source)
modelField, ok := modelScope.FieldByName(sourceFieldName)
if !ok {
t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", sourceFieldName))
}
targetField, ok := targetScope.FieldByName(targetFieldName)
if !ok {
t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", targetFieldName))
}
dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName)
err := DB.Model(source).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error
if err != nil {
t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err))
}
}
func TestLongForeignKey(t *testing.T) {
testForeignKey(t, &NotSoLongTableName{}, "ReallyLongThingID", &ReallyLongTableNameToTestMySQLNameLengthLimit{}, "ID")
}
func TestLongForeignKeyWithShortDest(t *testing.T) {
testForeignKey(t, &ReallyLongThingThatReferencesShort{}, "ShortID", &Short{}, "ID")
}
func TestHasManyChildrenWithOneStruct(t *testing.T) {
category := Category{
Name: "main",
Categories: []Category{
{Name: "sub1"},
{Name: "sub2"},
},
}
DB.Save(&category)
}
func TestSkipSaveAssociation(t *testing.T) {
type Company struct {
gorm.Model
Name string
}
type User struct {
gorm.Model
Name string
CompanyID uint
Company Company `gorm:"save_associations:false"`
}
DB.AutoMigrate(&Company{}, &User{})
DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}})
if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() {
t.Errorf("Company skip_save_association should not been saved")
}
}

View File

@ -1,242 +0,0 @@
package gorm
import "log"
// DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{}
// Callback is a struct that contains all CRUD callbacks
// Field `creates` contains callbacks will be call when creating object
// Field `updates` contains callbacks will be call when updating object
// Field `deletes` contains callbacks will be call when deleting object
// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association...
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
type Callback struct {
creates []*func(scope *Scope)
updates []*func(scope *Scope)
deletes []*func(scope *Scope)
queries []*func(scope *Scope)
rowQueries []*func(scope *Scope)
processors []*CallbackProcessor
}
// CallbackProcessor contains callback informations
type CallbackProcessor struct {
name string // current callback's name
before string // register current callback before a callback
after string // register current callback after a callback
replace bool // replace callbacks with same name
remove bool // delete callbacks with same name
kind string // callback type: create, update, delete, query, row_query
processor *func(scope *Scope) // callback handler
parent *Callback
}
func (c *Callback) clone() *Callback {
return &Callback{
creates: c.creates,
updates: c.updates,
deletes: c.deletes,
queries: c.queries,
rowQueries: c.rowQueries,
processors: c.processors,
}
}
// Create could be used to register callbacks for creating object
// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
// // business logic
// ...
//
// // set error if some thing wrong happened, will rollback the creating
// scope.Err(errors.New("error"))
// })
func (c *Callback) Create() *CallbackProcessor {
return &CallbackProcessor{kind: "create", parent: c}
}
// Update could be used to register callbacks for updating object, refer `Create` for usage
func (c *Callback) Update() *CallbackProcessor {
return &CallbackProcessor{kind: "update", parent: c}
}
// Delete could be used to register callbacks for deleting object, refer `Create` for usage
func (c *Callback) Delete() *CallbackProcessor {
return &CallbackProcessor{kind: "delete", parent: c}
}
// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
// Refer `Create` for usage
func (c *Callback) Query() *CallbackProcessor {
return &CallbackProcessor{kind: "query", parent: c}
}
// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
func (c *Callback) RowQuery() *CallbackProcessor {
return &CallbackProcessor{kind: "row_query", parent: c}
}
// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
cp.after = callbackName
return cp
}
// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
cp.before = callbackName
return cp
}
// Register a new callback, refer `Callbacks.Create`
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
if cp.kind == "row_query" {
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
cp.before = "gorm:row_query"
}
}
cp.name = callbackName
cp.processor = &callback
cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder()
}
// Remove a registered callback
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
func (cp *CallbackProcessor) Remove(callbackName string) {
log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
cp.name = callbackName
cp.remove = true
cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder()
}
// Replace a registered callback with new callback
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
// scope.SetColumn("Created", now)
// scope.SetColumn("Updated", now)
// })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
cp.name = callbackName
cp.processor = &callback
cp.replace = true
cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder()
}
// Get registered callback
// db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
for _, p := range cp.parent.processors {
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
return *p.processor
}
}
return nil
}
// getRIndex get right index from string slice
func getRIndex(strs []string, str string) int {
for i := len(strs) - 1; i >= 0; i-- {
if strs[i] == str {
return i
}
}
return -1
}
// sortProcessors sort callback processors based on its before, after, remove, replace
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
var (
allNames, sortedNames []string
sortCallbackProcessor func(c *CallbackProcessor)
)
for _, cp := range cps {
// show warning message the callback name already exists
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
}
allNames = append(allNames, cp.name)
}
sortCallbackProcessor = func(c *CallbackProcessor) {
if getRIndex(sortedNames, c.name) == -1 { // if not sorted
if c.before != "" { // if defined before callback
if index := getRIndex(sortedNames, c.before); index != -1 {
// if before callback already sorted, append current callback just after it
sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
} else if index := getRIndex(allNames, c.before); index != -1 {
// if before callback exists but haven't sorted, append current callback to last
sortedNames = append(sortedNames, c.name)
sortCallbackProcessor(cps[index])
}
}
if c.after != "" { // if defined after callback
if index := getRIndex(sortedNames, c.after); index != -1 {
// if after callback already sorted, append current callback just before it
sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
} else if index := getRIndex(allNames, c.after); index != -1 {
// if after callback exists but haven't sorted
cp := cps[index]
// set after callback's before callback to current callback
if cp.before == "" {
cp.before = c.name
}
sortCallbackProcessor(cp)
}
}
// if current callback haven't been sorted, append it to last
if getRIndex(sortedNames, c.name) == -1 {
sortedNames = append(sortedNames, c.name)
}
}
}
for _, cp := range cps {
sortCallbackProcessor(cp)
}
var sortedFuncs []*func(scope *Scope)
for _, name := range sortedNames {
if index := getRIndex(allNames, name); !cps[index].remove {
sortedFuncs = append(sortedFuncs, cps[index].processor)
}
}
return sortedFuncs
}
// reorder all registered processors, and reset CRUD callbacks
func (c *Callback) reorder() {
var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
for _, processor := range c.processors {
if processor.name != "" {
switch processor.kind {
case "create":
creates = append(creates, processor)
case "update":
updates = append(updates, processor)
case "delete":
deletes = append(deletes, processor)
case "query":
queries = append(queries, processor)
case "row_query":
rowQueries = append(rowQueries, processor)
}
}
}
c.creates = sortProcessors(creates)
c.updates = sortProcessors(updates)
c.deletes = sortProcessors(deletes)
c.queries = sortProcessors(queries)
c.rowQueries = sortProcessors(rowQueries)
}

View File

@ -1,163 +0,0 @@
package gorm
import (
"fmt"
"strings"
)
// Define callbacks for creating
func init() {
DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
DefaultCallback.Create().Register("gorm:create", createCallback)
DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
func beforeCreateCallback(scope *Scope) {
if !scope.HasError() {
scope.CallMethod("BeforeSave")
}
if !scope.HasError() {
scope.CallMethod("BeforeCreate")
}
}
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
func updateTimeStampForCreateCallback(scope *Scope) {
if !scope.HasError() {
now := NowFunc()
if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
if createdAtField.IsBlank {
createdAtField.Set(now)
}
}
if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok {
if updatedAtField.IsBlank {
updatedAtField.Set(now)
}
}
}
}
// createCallback the callback used to insert data into database
func createCallback(scope *Scope) {
if !scope.HasError() {
defer scope.trace(NowFunc())
var (
columns, placeholders []string
blankColumnsWithDefaultValue []string
)
for _, field := range scope.Fields() {
if scope.changeableField(field) {
if field.IsNormal {
if field.IsBlank && field.HasDefaultValue {
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
} else if !field.IsPrimaryKey || !field.IsBlank {
columns = append(columns, scope.Quote(field.DBName))
placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
}
} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
for _, foreignKey := range field.Relationship.ForeignDBNames {
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
columns = append(columns, scope.Quote(foreignField.DBName))
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
}
}
}
}
}
var (
returningColumn = "*"
quotedTableName = scope.QuotedTableName()
primaryField = scope.PrimaryField()
extraOption string
)
if str, ok := scope.Get("gorm:insert_option"); ok {
extraOption = fmt.Sprint(str)
}
if primaryField != nil {
returningColumn = scope.Quote(primaryField.DBName)
}
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
if len(columns) == 0 {
scope.Raw(fmt.Sprintf(
"INSERT INTO %v DEFAULT VALUES%v%v",
quotedTableName,
addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
))
} else {
scope.Raw(fmt.Sprintf(
"INSERT INTO %v (%v) VALUES (%v)%v%v",
scope.QuotedTableName(),
strings.Join(columns, ","),
strings.Join(placeholders, ","),
addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
))
}
// execute create sql
if lastInsertIDReturningSuffix == "" || primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
// set primary value to primary field
if primaryField != nil && primaryField.IsBlank {
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
scope.Err(primaryField.Set(primaryValue))
}
}
}
} else {
if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
scope.db.RowsAffected = 1
}
} else {
scope.Err(ErrUnaddressable)
}
}
}
}
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
func forceReloadAfterCreateCallback(scope *Scope) {
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
for _, field := range scope.Fields() {
if field.IsPrimaryKey && !field.IsBlank {
db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface())
}
}
db.Scan(scope.Value)
}
}
// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
func afterCreateCallback(scope *Scope) {
if !scope.HasError() {
scope.CallMethod("AfterCreate")
}
if !scope.HasError() {
scope.CallMethod("AfterSave")
}
}

View File

@ -1,63 +0,0 @@
package gorm
import (
"errors"
"fmt"
)
// Define callbacks for deleting
func init() {
DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
DefaultCallback.Delete().Register("gorm:delete", deleteCallback)
DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
func beforeDeleteCallback(scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("Missing WHERE clause while deleting"))
return
}
if !scope.HasError() {
scope.CallMethod("BeforeDelete")
}
}
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
func deleteCallback(scope *Scope) {
if !scope.HasError() {
var extraOption string
if str, ok := scope.Get("gorm:delete_option"); ok {
extraOption = fmt.Sprint(str)
}
deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt")
if !scope.Search.Unscoped && hasDeletedAtField {
scope.Raw(fmt.Sprintf(
"UPDATE %v SET %v=%v%v%v",
scope.QuotedTableName(),
scope.Quote(deletedAtField.DBName),
scope.AddToVars(NowFunc()),
addExtraSpaceIfExist(scope.CombinedConditionSql()),
addExtraSpaceIfExist(extraOption),
)).Exec()
} else {
scope.Raw(fmt.Sprintf(
"DELETE FROM %v%v%v",
scope.QuotedTableName(),
addExtraSpaceIfExist(scope.CombinedConditionSql()),
addExtraSpaceIfExist(extraOption),
)).Exec()
}
}
}
// afterDeleteCallback will invoke `AfterDelete` method after deleting
func afterDeleteCallback(scope *Scope) {
if !scope.HasError() {
scope.CallMethod("AfterDelete")
}
}

View File

@ -1,95 +0,0 @@
package gorm
import (
"errors"
"fmt"
"reflect"
)
// Define callbacks for querying
func init() {
DefaultCallback.Query().Register("gorm:query", queryCallback)
DefaultCallback.Query().Register("gorm:preload", preloadCallback)
DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
}
// queryCallback used to query data from database
func queryCallback(scope *Scope) {
defer scope.trace(NowFunc())
var (
isSlice, isPtr bool
resultType reflect.Type
results = scope.IndirectValue()
)
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
if primaryField := scope.PrimaryField(); primaryField != nil {
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
}
}
if value, ok := scope.Get("gorm:query_destination"); ok {
results = indirect(reflect.ValueOf(value))
}
if kind := results.Kind(); kind == reflect.Slice {
isSlice = true
resultType = results.Type().Elem()
results.Set(reflect.MakeSlice(results.Type(), 0, 0))
if resultType.Kind() == reflect.Ptr {
isPtr = true
resultType = resultType.Elem()
}
} else if kind != reflect.Struct {
scope.Err(errors.New("unsupported destination, should be slice or struct"))
return
}
scope.prepareQuerySQL()
if !scope.HasError() {
scope.db.RowsAffected = 0
if str, ok := scope.Get("gorm:query_option"); ok {
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
}
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
defer rows.Close()
columns, _ := rows.Columns()
for rows.Next() {
scope.db.RowsAffected++
elem := results
if isSlice {
elem = reflect.New(resultType).Elem()
}
scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
if isSlice {
if isPtr {
results.Set(reflect.Append(results, elem.Addr()))
} else {
results.Set(reflect.Append(results, elem))
}
}
}
if err := rows.Err(); err != nil {
scope.Err(err)
} else if scope.db.RowsAffected == 0 && !isSlice {
scope.Err(ErrRecordNotFound)
}
}
}
}
// afterQueryCallback will invoke `AfterFind` method after querying
func afterQueryCallback(scope *Scope) {
if !scope.HasError() {
scope.CallMethod("AfterFind")
}
}

View File

@ -1,380 +0,0 @@
package gorm
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
)
// preloadCallback used to preload associations
func preloadCallback(scope *Scope) {
if _, ok := scope.Get("gorm:auto_preload"); ok {
autoPreload(scope)
}
if scope.Search.preload == nil || scope.HasError() {
return
}
var (
preloadedMap = map[string]bool{}
fields = scope.Fields()
)
for _, preload := range scope.Search.preload {
var (
preloadFields = strings.Split(preload.schema, ".")
currentScope = scope
currentFields = fields
)
for idx, preloadField := range preloadFields {
var currentPreloadConditions []interface{}
if currentScope == nil {
continue
}
// if not preloaded
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
// assign search conditions to last preload
if idx == len(preloadFields)-1 {
currentPreloadConditions = preload.conditions
}
for _, field := range currentFields {
if field.Name != preloadField || field.Relationship == nil {
continue
}
switch field.Relationship.Kind {
case "has_one":
currentScope.handleHasOnePreload(field, currentPreloadConditions)
case "has_many":
currentScope.handleHasManyPreload(field, currentPreloadConditions)
case "belongs_to":
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
case "many_to_many":
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
default:
scope.Err(errors.New("unsupported relation"))
}
preloadedMap[preloadKey] = true
break
}
if !preloadedMap[preloadKey] {
scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
return
}
}
// preload next level
if idx < len(preloadFields)-1 {
currentScope = currentScope.getColumnAsScope(preloadField)
if currentScope != nil {
currentFields = currentScope.Fields()
}
}
}
}
}
func autoPreload(scope *Scope) {
for _, field := range scope.Fields() {
if field.Relationship == nil {
continue
}
if val, ok := field.TagSettings["PRELOAD"]; ok {
if preload, err := strconv.ParseBool(val); err != nil {
scope.Err(errors.New("invalid preload option"))
return
} else if !preload {
continue
}
}
scope.Search.Preload(field.Name)
}
}
func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
var (
preloadDB = scope.NewDB()
preloadConditions []interface{}
)
for _, condition := range conditions {
if scopes, ok := condition.(func(*DB) *DB); ok {
preloadDB = scopes(preloadDB)
} else {
preloadConditions = append(preloadConditions, condition)
}
}
return preloadDB, preloadConditions
}
// handleHasOnePreload used to preload has one associations
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
relation := field.Relationship
// get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 {
return
}
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
// find relations
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
values := toQueryValues(primaryKeys)
if relation.PolymorphicType != "" {
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
values = append(values, relation.PolymorphicValue)
}
results := makeSlice(field.Struct.Type)
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
// assign find results
var (
resultsValue = indirect(reflect.ValueOf(results))
indirectScopeValue = scope.IndirectValue()
)
if indirectScopeValue.Kind() == reflect.Slice {
for j := 0; j < indirectScopeValue.Len(); j++ {
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
indirectValue.FieldByName(field.Name).Set(result)
break
}
}
}
} else {
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
scope.Err(field.Set(result))
}
}
}
// handleHasManyPreload used to preload has many associations
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
relation := field.Relationship
// get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 {
return
}
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
// find relations
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
values := toQueryValues(primaryKeys)
if relation.PolymorphicType != "" {
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
values = append(values, relation.PolymorphicValue)
}
results := makeSlice(field.Struct.Type)
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
// assign find results
var (
resultsValue = indirect(reflect.ValueOf(results))
indirectScopeValue = scope.IndirectValue()
)
if indirectScopeValue.Kind() == reflect.Slice {
preloadMap := make(map[string][]reflect.Value)
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
}
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
f := object.FieldByName(field.Name)
if results, ok := preloadMap[toString(objectRealValue)]; ok {
f.Set(reflect.Append(f, results...))
} else {
f.Set(reflect.MakeSlice(f.Type(), 0, 0))
}
}
} else {
scope.Err(field.Set(resultsValue))
}
}
// handleBelongsToPreload used to preload belongs to associations
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
relation := field.Relationship
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
// get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 {
return
}
// find relations
results := makeSlice(field.Struct.Type)
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
// assign find results
var (
resultsValue = indirect(reflect.ValueOf(results))
indirectScopeValue = scope.IndirectValue()
)
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
if indirectScopeValue.Kind() == reflect.Slice {
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
object.FieldByName(field.Name).Set(result)
}
}
} else {
scope.Err(field.Set(result))
}
}
}
// handleManyToManyPreload used to preload many to many associations
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
var (
relation = field.Relationship
joinTableHandler = relation.JoinTableHandler
fieldType = field.Struct.Type.Elem()
foreignKeyValue interface{}
foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
linkHash = map[string][]reflect.Value{}
isPtr bool
)
if fieldType.Kind() == reflect.Ptr {
isPtr = true
fieldType = fieldType.Elem()
}
var sourceKeys = []string{}
for _, key := range joinTableHandler.SourceForeignKeys() {
sourceKeys = append(sourceKeys, key.DBName)
}
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
// generate query with join table
newScope := scope.New(reflect.New(fieldType).Interface())
preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
if len(preloadDB.search.selects) == 0 {
preloadDB = preloadDB.Select("*")
}
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
// preload inline conditions
if len(preloadConditions) > 0 {
preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
}
rows, err := preloadDB.Rows()
if scope.Err(err) != nil {
return
}
defer rows.Close()
columns, _ := rows.Columns()
for rows.Next() {
var (
elem = reflect.New(fieldType).Elem()
fields = scope.New(elem.Addr().Interface()).Fields()
)
// register foreign keys in join tables
var joinTableFields []*Field
for _, sourceKey := range sourceKeys {
joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
}
scope.scan(rows, columns, append(fields, joinTableFields...))
var foreignKeys = make([]interface{}, len(sourceKeys))
// generate hashed forkey keys in join table
for idx, joinTableField := range joinTableFields {
if !joinTableField.Field.IsNil() {
foreignKeys[idx] = joinTableField.Field.Elem().Interface()
}
}
hashedSourceKeys := toString(foreignKeys)
if isPtr {
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
} else {
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
}
}
if err := rows.Err(); err != nil {
scope.Err(err)
}
// assign find results
var (
indirectScopeValue = scope.IndirectValue()
fieldsSourceMap = map[string][]reflect.Value{}
foreignFieldNames = []string{}
)
for _, dbName := range relation.ForeignFieldNames {
if field, ok := scope.FieldByName(dbName); ok {
foreignFieldNames = append(foreignFieldNames, field.Name)
}
}
if indirectScopeValue.Kind() == reflect.Slice {
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
key := toString(getValueFromFields(object, foreignFieldNames))
fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
}
} else if indirectScopeValue.IsValid() {
key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
}
for source, link := range linkHash {
for i, field := range fieldsSourceMap[source] {
//If not 0 this means Value is a pointer and we already added preloaded models to it
if fieldsSourceMap[source][i].Len() != 0 {
continue
}
field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
}
}
}

View File

@ -1,30 +0,0 @@
package gorm
import "database/sql"
// Define callbacks for row query
func init() {
DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback)
}
type RowQueryResult struct {
Row *sql.Row
}
type RowsQueryResult struct {
Rows *sql.Rows
Error error
}
// queryCallback used to query data from database
func rowQueryCallback(scope *Scope) {
if result, ok := scope.InstanceGet("row_query_result"); ok {
scope.prepareQuerySQL()
if rowResult, ok := result.(*RowQueryResult); ok {
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
} else if rowsResult, ok := result.(*RowsQueryResult); ok {
rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
}
}
}

View File

@ -1,99 +0,0 @@
package gorm
import "reflect"
func beginTransactionCallback(scope *Scope) {
scope.Begin()
}
func commitOrRollbackTransactionCallback(scope *Scope) {
scope.CommitOrRollback()
}
func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) {
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") {
if relationship := field.Relationship; relationship != nil {
return true, relationship
}
}
}
return false, nil
}
func saveBeforeAssociationsCallback(scope *Scope) {
if !scope.shouldSaveAssociations() {
return
}
for _, field := range scope.Fields() {
if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship.Kind == "belongs_to" {
fieldValue := field.Field.Addr().Interface()
scope.Err(scope.NewDB().Save(fieldValue).Error)
if len(relationship.ForeignFieldNames) != 0 {
// set value's foreign key
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok {
scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface()))
}
}
}
}
}
}
func saveAfterAssociationsCallback(scope *Scope) {
if !scope.shouldSaveAssociations() {
return
}
for _, field := range scope.Fields() {
if ok, relationship := saveFieldAsAssociation(scope, field); ok &&
(relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
value := field.Field
switch value.Kind() {
case reflect.Slice:
for i := 0; i < value.Len(); i++ {
newDB := scope.NewDB()
elem := value.Index(i).Addr().Interface()
newScope := newDB.NewScope(elem)
if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if f, ok := scope.FieldByName(associationForeignName); ok {
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
}
}
}
if relationship.PolymorphicType != "" {
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
}
scope.Err(newDB.Save(elem).Error)
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
}
}
default:
elem := value.Addr().Interface()
newScope := scope.New(elem)
if len(relationship.ForeignFieldNames) != 0 {
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if f, ok := scope.FieldByName(associationForeignName); ok {
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
}
}
}
if relationship.PolymorphicType != "" {
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
}
scope.Err(scope.NewDB().Save(elem).Error)
}
}
}
}

View File

@ -1,112 +0,0 @@
package gorm
import (
"reflect"
"runtime"
"strings"
"testing"
)
func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
var names []string
for _, f := range funcs {
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
names = append(names, fnames[len(fnames)-1])
}
return reflect.DeepEqual(names, fnames)
}
func create(s *Scope) {}
func beforeCreate1(s *Scope) {}
func beforeCreate2(s *Scope) {}
func afterCreate1(s *Scope) {}
func afterCreate2(s *Scope) {}
func TestRegisterCallback(t *testing.T) {
var callback = &Callback{}
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("before_create2", beforeCreate2)
callback.Create().Register("create", create)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Register("after_create2", afterCreate2)
if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
t.Errorf("register callback")
}
}
func TestRegisterCallbackWithOrder(t *testing.T) {
var callback1 = &Callback{}
callback1.Create().Register("before_create1", beforeCreate1)
callback1.Create().Register("create", create)
callback1.Create().Register("after_create1", afterCreate1)
callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
t.Errorf("register callback with order")
}
var callback2 = &Callback{}
callback2.Update().Register("create", create)
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
callback2.Update().Register("after_create2", afterCreate2)
if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
t.Errorf("register callback with order")
}
}
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
var callback1 = &Callback{}
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
callback1.Query().Register("before_create1", beforeCreate1)
callback1.Query().Register("after_create1", afterCreate1)
if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
t.Errorf("register callback with order")
}
var callback2 = &Callback{}
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
callback2.Delete().Register("after_create1", afterCreate1)
callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
t.Errorf("register callback with order")
}
}
func replaceCreate(s *Scope) {}
func TestReplaceCallback(t *testing.T) {
var callback = &Callback{}
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Replace("create", replaceCreate)
if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
t.Errorf("replace callback")
}
}
func TestRemoveCallback(t *testing.T) {
var callback = &Callback{}
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Remove("create")
if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
t.Errorf("remove callback")
}
}

View File

@ -1,117 +0,0 @@
package gorm
import (
"errors"
"fmt"
"strings"
)
// Define callbacks for updating
func init() {
DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
DefaultCallback.Update().Register("gorm:update", updateCallback)
DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}
// assignUpdatingAttributesCallback assign updating attributes to model
func assignUpdatingAttributesCallback(scope *Scope) {
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
scope.InstanceSet("gorm:update_attrs", updateMaps)
} else {
scope.SkipLeft()
}
}
}
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
func beforeUpdateCallback(scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("Missing WHERE clause while updating"))
return
}
if _, ok := scope.Get("gorm:update_column"); !ok {
if !scope.HasError() {
scope.CallMethod("BeforeSave")
}
if !scope.HasError() {
scope.CallMethod("BeforeUpdate")
}
}
}
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
func updateTimeStampForUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.SetColumn("UpdatedAt", NowFunc())
}
}
// updateCallback the callback used to update data to database
func updateCallback(scope *Scope) {
if !scope.HasError() {
var sqls []string
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
for column, value := range updateAttrs.(map[string]interface{}) {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
}
} else {
for _, field := range scope.Fields() {
if scope.changeableField(field) {
if !field.IsPrimaryKey && field.IsNormal {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
for _, foreignKey := range relationship.ForeignDBNames {
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
sqls = append(sqls,
fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
}
}
}
}
}
}
var extraOption string
if str, ok := scope.Get("gorm:update_option"); ok {
extraOption = fmt.Sprint(str)
}
if len(sqls) > 0 {
joinSQL := scope.joinsSQL()
whereSQL := scope.whereSQL()
if scope.Search.raw {
whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")")
}
combinedSql := whereSQL + scope.groupSQL() +
scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL()
scope.Raw(fmt.Sprintf(
"UPDATE %v %v SET %v%v%v",
scope.QuotedTableName(),
addExtraSpaceIfExist(joinSQL),
strings.Join(sqls, ", "),
addExtraSpaceIfExist(combinedSql),
addExtraSpaceIfExist(extraOption),
)).Exec()
}
}
}
// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
func afterUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
if !scope.HasError() {
scope.CallMethod("AfterUpdate")
}
if !scope.HasError() {
scope.CallMethod("AfterSave")
}
}
}

327
callbacks.go Normal file
View File

@ -0,0 +1,327 @@
package gorm
import (
"context"
"errors"
"fmt"
"reflect"
"sort"
"time"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
func initializeCallbacks(db *DB) *callbacks {
return &callbacks{
processors: map[string]*processor{
"create": {db: db},
"query": {db: db},
"update": {db: db},
"delete": {db: db},
"row": {db: db},
"raw": {db: db},
},
}
}
// callbacks gorm callbacks manager
type callbacks struct {
processors map[string]*processor
}
type processor struct {
db *DB
Clauses []string
fns []func(*DB)
callbacks []*callback
}
type callback struct {
name string
before string
after string
remove bool
replace bool
match func(*DB) bool
handler func(*DB)
processor *processor
}
func (cs *callbacks) Create() *processor {
return cs.processors["create"]
}
func (cs *callbacks) Query() *processor {
return cs.processors["query"]
}
func (cs *callbacks) Update() *processor {
return cs.processors["update"]
}
func (cs *callbacks) Delete() *processor {
return cs.processors["delete"]
}
func (cs *callbacks) Row() *processor {
return cs.processors["row"]
}
func (cs *callbacks) Raw() *processor {
return cs.processors["raw"]
}
func (p *processor) Execute(db *DB) {
// call scopes
for len(db.Statement.scopes) > 0 {
scopes := db.Statement.scopes
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
}
}
var (
curTime = time.Now()
stmt = db.Statement
resetBuildClauses bool
)
if len(stmt.BuildClauses) == 0 {
stmt.BuildClauses = p.Clauses
resetBuildClauses = true
}
// assign model values
if stmt.Model == nil {
stmt.Model = stmt.Dest
} else if stmt.Dest == nil {
stmt.Dest = stmt.Model
}
// parse model values
if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
} else {
db.AddError(err)
}
}
}
// assign stmt.ReflectValue
if stmt.Dest != nil {
stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
for stmt.ReflectValue.Kind() == reflect.Ptr {
if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() {
stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
}
stmt.ReflectValue = stmt.ReflectValue.Elem()
}
if !stmt.ReflectValue.IsValid() {
db.AddError(ErrInvalidValue)
}
}
for _, f := range p.fns {
f(db)
}
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
}, db.Error)
if !stmt.DB.DryRun {
stmt.SQL.Reset()
stmt.Vars = nil
}
if resetBuildClauses {
stmt.BuildClauses = nil
}
}
func (p *processor) Get(name string) func(*DB) {
for i := len(p.callbacks) - 1; i >= 0; i-- {
if v := p.callbacks[i]; v.name == name && !v.remove {
return v.handler
}
}
return nil
}
func (p *processor) Before(name string) *callback {
return &callback{before: name, processor: p}
}
func (p *processor) After(name string) *callback {
return &callback{after: name, processor: p}
}
func (p *processor) Match(fc func(*DB) bool) *callback {
return &callback{match: fc, processor: p}
}
func (p *processor) Register(name string, fn func(*DB)) error {
return (&callback{processor: p}).Register(name, fn)
}
func (p *processor) Remove(name string) error {
return (&callback{processor: p}).Remove(name)
}
func (p *processor) Replace(name string, fn func(*DB)) error {
return (&callback{processor: p}).Replace(name, fn)
}
func (p *processor) compile() (err error) {
var callbacks []*callback
for _, callback := range p.callbacks {
if callback.match == nil || callback.match(p.db) {
callbacks = append(callbacks, callback)
}
}
p.callbacks = callbacks
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
}
return
}
func (c *callback) Before(name string) *callback {
c.before = name
return c
}
func (c *callback) After(name string) *callback {
c.after = name
return c
}
func (c *callback) Register(name string, fn func(*DB)) error {
c.name = name
c.handler = fn
c.processor.callbacks = append(c.processor.callbacks, c)
return c.processor.compile()
}
func (c *callback) Remove(name string) error {
c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.name = name
c.remove = true
c.processor.callbacks = append(c.processor.callbacks, c)
return c.processor.compile()
}
func (c *callback) Replace(name string, fn func(*DB)) error {
c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.name = name
c.handler = fn
c.replace = true
c.processor.callbacks = append(c.processor.callbacks, c)
return c.processor.compile()
}
// getRIndex get right index from string slice
func getRIndex(strs []string, str string) int {
for i := len(strs) - 1; i >= 0; i-- {
if strs[i] == str {
return i
}
}
return -1
}
func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
var (
names, sorted []string
sortCallback func(*callback) error
)
sort.Slice(cs, func(i, j int) bool {
return cs[j].before == "*" || cs[j].after == "*"
})
for _, c := range cs {
// show warning message the callback name already exists
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
}
names = append(names, c.name)
}
sortCallback = func(c *callback) error {
if c.before != "" { // if defined before callback
if c.before == "*" && len(sorted) > 0 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
sorted = append([]string{c.name}, sorted...)
}
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
// if before callback already sorted, append current callback just after it
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
} else if curIdx > sortedIdx {
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before)
}
} else if idx := getRIndex(names, c.before); idx != -1 {
// if before callback exists
cs[idx].after = c.name
}
}
if c.after != "" { // if defined after callback
if c.after == "*" && len(sorted) > 0 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
sorted = append(sorted, c.name)
}
} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
// if after callback sorted, append current callback to last
sorted = append(sorted, c.name)
} else if curIdx < sortedIdx {
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after)
}
} else if idx := getRIndex(names, c.after); idx != -1 {
// if after callback exists but haven't sorted
// set after callback's before callback to current callback
after := cs[idx]
if after.before == "" {
after.before = c.name
}
if err := sortCallback(after); err != nil {
return err
}
if err := sortCallback(c); err != nil {
return err
}
}
}
// if current callback haven't been sorted, append it to last
if getRIndex(sorted, c.name) == -1 {
sorted = append(sorted, c.name)
}
return nil
}
for _, c := range cs {
if err = sortCallback(c); err != nil {
return
}
}
for _, name := range sorted {
if idx := getRIndex(names, name); !cs[idx].remove {
fns = append(fns, cs[idx].handler)
}
}
return
}

390
callbacks/associations.go Normal file
View File

@ -0,0 +1,390 @@
package callbacks
import (
"reflect"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
return func(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
// Save Belongs To associations
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue
}
setupReferences := func(obj reflect.Value, elem reflect.Value) {
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(elem)
db.AddError(ref.ForeignKey.Set(obj, pv))
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
dest[ref.ForeignKey.DBName] = pv
if _, ok := dest[rel.Name]; ok {
dest[rel.Name] = elem.Interface()
}
}
}
}
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
objs = make([]reflect.Value, 0, db.Statement.ReflectValue.Len())
fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr
)
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
objs = append(objs, obj)
if isPtr {
elems = reflect.Append(elems, rv)
} else {
elems = reflect.Append(elems, rv.Addr())
}
}
} else {
break
}
}
if elems.Len() > 0 {
if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i))
}
}
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil {
setupReferences(db.Statement.ReflectValue, rv)
}
}
}
}
}
}
}
func SaveAfterAssociations(create bool) func(db *gorm.DB) {
return func(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
// Save Has One associations
for _, rel := range db.Statement.Schema.Relationships.HasOne {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr
)
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(obj); !zero {
rv := rel.Field.ReflectValueOf(obj)
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj)
db.AddError(ref.ForeignKey.Set(rv, fv))
} else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue))
}
}
elems = reflect.Append(elems, rv)
}
}
}
if elems.Len() > 0 {
assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if f.Kind() != reflect.Ptr {
f = f.Addr()
}
assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
ref.ForeignKey.Set(f, fv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(f, ref.PrimaryValue)
}
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
}
}
}
// Save Has Many associations
for _, rel := range db.Statement.Schema.Relationships.HasMany {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue
}
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(v)
ref.ForeignKey.Set(elem, pv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(elem, ref.PrimaryValue)
}
}
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
}
}
}
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
appendToElems(obj)
}
}
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
}
if elems.Len() > 0 {
assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
}
}
// Save Many2Many associations
for _, rel := range db.Statement.Schema.Relationships.Many2Many {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue
}
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
objs := []reflect.Value{}
appendToJoins := func(obj reflect.Value, elem reflect.Value) {
joinValue := reflect.New(rel.JoinTable.ModelType)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj)
ref.ForeignKey.Set(joinValue, fv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(joinValue, ref.PrimaryValue)
} else {
fv, _ := ref.PrimaryKey.ValueOf(elem)
ref.ForeignKey.Set(joinValue, fv)
}
}
joins = reflect.Append(joins, joinValue)
}
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
objs = append(objs, v)
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
}
}
}
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
appendToElems(obj)
}
}
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
}
// optimize elems of reflect value length
if elemLen := elems.Len(); elemLen > 0 {
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
}
for i := 0; i < elemLen; i++ {
appendToJoins(objs[i], elems.Index(i))
}
}
if joins.Len() > 0 {
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
SkipHooks: db.Statement.SkipHooks,
DisableNestedTransaction: true,
}).Create(joins.Interface()).Error)
}
}
}
}
}
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) clause.OnConflict {
if stmt.DB.FullSaveAssociations {
defaultUpdatingColumns = make([]string, 0, len(s.DBNames))
for _, dbName := range s.DBNames {
if v, ok := selectColumns[dbName]; (ok && !v) || (!ok && restricted) {
continue
}
if !s.LookUpField(dbName).PrimaryKey {
defaultUpdatingColumns = append(defaultUpdatingColumns, dbName)
}
}
}
if len(defaultUpdatingColumns) > 0 {
columns := make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
for _, dbName := range s.PrimaryFieldDBNames {
columns = append(columns, clause.Column{Name: dbName})
}
return clause.OnConflict{
Columns: columns,
DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns),
}
}
return clause.OnConflict{DoNothing: true}
}
func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
var (
selects, omits []string
onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
refName = rel.Name + "."
)
for name, ok := range selectColumns {
columnName := ""
if strings.HasPrefix(name, refName) {
columnName = strings.TrimPrefix(name, refName)
}
if columnName != "" {
if ok {
selects = append(selects, columnName)
} else {
omits = append(omits, columnName)
}
}
}
tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{
FullSaveAssociations: db.FullSaveAssociations,
SkipHooks: db.Statement.SkipHooks,
DisableNestedTransaction: true,
})
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if tx.Statement.FullSaveAssociations {
tx = tx.InstanceSet("gorm:update_track_time", true)
}
if len(selects) > 0 {
tx = tx.Select(selects)
} else if restricted && len(omits) == 0 {
tx = tx.Omit(clause.Associations)
}
if len(omits) > 0 {
tx = tx.Omit(omits...)
}
return db.AddError(tx.Create(values).Error)
}

83
callbacks/callbacks.go Normal file
View File

@ -0,0 +1,83 @@
package callbacks
import (
"gorm.io/gorm"
)
var (
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
updateClauses = []string{"UPDATE", "SET", "WHERE"}
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
)
type Config struct {
LastInsertIDReversed bool
WithReturning bool
CreateClauses []string
QueryClauses []string
UpdateClauses []string
DeleteClauses []string
}
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
enableTransaction := func(db *gorm.DB) bool {
return !db.SkipDefaultTransaction
}
createCallback := db.Callback().Create()
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
createCallback.Register("gorm:before_create", BeforeCreate)
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
createCallback.Register("gorm:create", Create(config))
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.CreateClauses) == 0 {
config.CreateClauses = createClauses
}
createCallback.Clauses = config.CreateClauses
queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery)
if len(config.QueryClauses) == 0 {
config.QueryClauses = queryClauses
}
queryCallback.Clauses = config.QueryClauses
deleteCallback := db.Callback().Delete()
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
deleteCallback.Register("gorm:before_delete", BeforeDelete)
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
deleteCallback.Register("gorm:delete", Delete)
deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.DeleteClauses) == 0 {
config.DeleteClauses = deleteClauses
}
deleteCallback.Clauses = config.DeleteClauses
updateCallback := db.Callback().Update()
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
updateCallback.Register("gorm:before_update", BeforeUpdate)
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
updateCallback.Register("gorm:update", Update)
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
updateCallback.Register("gorm:after_update", AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.UpdateClauses) == 0 {
config.UpdateClauses = updateClauses
}
updateCallback.Clauses = config.UpdateClauses
rowCallback := db.Callback().Row()
rowCallback.Register("gorm:row", RowQuery)
rowCallback.Clauses = config.QueryClauses
rawCallback := db.Callback().Raw()
rawCallback.Register("gorm:raw", RawExec)
rawCallback.Clauses = config.QueryClauses
}

23
callbacks/callmethod.go Normal file
View File

@ -0,0 +1,23 @@
package callbacks
import (
"reflect"
"gorm.io/gorm"
)
func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
tx := db.Session(&gorm.Session{NewDB: true})
if called := fc(db.Statement.ReflectValue.Interface(), tx); !called {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
db.Statement.CurDestIndex = 0
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx)
db.Statement.CurDestIndex++
}
case reflect.Struct:
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
}
}
}

370
callbacks/create.go Normal file
View File

@ -0,0 +1,370 @@
package callbacks
import (
"fmt"
"reflect"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
func BeforeCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.BeforeSave {
if i, ok := value.(BeforeSaveInterface); ok {
called = true
db.AddError(i.BeforeSave(tx))
}
}
if db.Statement.Schema.BeforeCreate {
if i, ok := value.(BeforeCreateInterface); ok {
called = true
db.AddError(i.BeforeCreate(tx))
}
}
return called
})
}
}
func Create(config *Config) func(db *gorm.DB) {
if config.WithReturning {
return CreateWithReturning
} else {
return func(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build(db.Statement.BuildClauses...)
}
if !db.DryRun && db.Error == nil {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
db.RowsAffected, _ = result.RowsAffected()
if db.RowsAffected > 0 {
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if config.LastInsertIDReversed {
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
if isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
} else {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
}
case reflect.Struct:
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
}
}
} else {
db.AddError(err)
}
}
}
} else {
db.AddError(err)
}
}
}
}
}
}
func CreateWithReturning(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build(db.Statement.BuildClauses...)
}
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
db.Statement.WriteString(" RETURNING ")
var (
fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
)
for idx, field := range sch.FieldsWithDefaultDBValue {
if idx > 0 {
db.Statement.WriteByte(',')
}
fields[idx] = field
db.Statement.WriteQuoted(field.DBName)
}
if !db.DryRun && db.Error == nil {
db.RowsAffected = 0
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
defer rows.Close()
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
c := db.Statement.Clauses["ON CONFLICT"]
onConflict, _ := c.Expression.(clause.OnConflict)
for rows.Next() {
BEGIN:
reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected))
if reflect.Indirect(reflectValue).Kind() != reflect.Struct {
break
}
for idx, field := range fields {
fieldValue := field.ReflectValueOf(reflectValue)
if onConflict.DoNothing && !fieldValue.IsZero() {
db.RowsAffected++
if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() {
return
}
goto BEGIN
}
values[idx] = fieldValue.Addr().Interface()
}
db.RowsAffected++
if err := rows.Scan(values...); err != nil {
db.AddError(err)
}
}
case reflect.Struct:
for idx, field := range fields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
}
if rows.Next() {
db.RowsAffected++
db.AddError(rows.Scan(values...))
}
}
} else {
db.AddError(err)
}
}
} else if !db.DryRun && db.Error == nil {
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
}
}
}
func AfterCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok {
called = true
db.AddError(i.AfterSave(tx))
}
}
if db.Statement.Schema.AfterCreate {
if i, ok := value.(AfterCreateInterface); ok {
called = true
db.AddError(i.AfterCreate(tx))
}
}
return called
})
}
}
// ConvertToCreateValues convert to create values
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
switch value := stmt.Dest.(type) {
case map[string]interface{}:
values = ConvertMapToValuesForCreate(stmt, value)
case *map[string]interface{}:
values = ConvertMapToValuesForCreate(stmt, *value)
case []map[string]interface{}:
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
case *[]map[string]interface{}:
values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
default:
var (
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
curTime = stmt.DB.NowFunc()
isZero bool
)
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
for _, db := range stmt.Schema.DBNames {
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
values.Columns = append(values.Columns, clause.Column{Name: db})
}
}
}
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
stmt.SQL.Grow(stmt.ReflectValue.Len() * 18)
values.Values = make([][]interface{}, stmt.ReflectValue.Len())
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
if stmt.ReflectValue.Len() == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
for i := 0; i < stmt.ReflectValue.Len(); i++ {
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
if !rv.IsValid() {
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
return
}
values.Values[i] = make([]interface{}, len(values.Columns))
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[i][idx], isZero = field.ValueOf(rv); isZero {
if field.DefaultValueInterface != nil {
values.Values[i][idx] = field.DefaultValueInterface
field.Set(rv, field.DefaultValueInterface)
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
field.Set(rv, curTime)
values.Values[i][idx], _ = field.ValueOf(rv)
}
} else if field.AutoUpdateTime > 0 {
if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok {
field.Set(rv, curTime)
values.Values[i][idx], _ = field.ValueOf(rv)
}
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, isZero := field.ValueOf(rv); !isZero {
if len(defaultValueFieldsHavingValue[field]) == 0 {
defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len())
}
defaultValueFieldsHavingValue[field][i] = v
}
}
}
}
for field, vs := range defaultValueFieldsHavingValue {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
for idx := range values.Values {
if vs[idx] == nil {
values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
} else {
values.Values[idx] = append(values.Values[idx], vs[idx])
}
}
}
case reflect.Struct:
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero {
if field.DefaultValueInterface != nil {
values.Values[0][idx] = field.DefaultValueInterface
field.Set(stmt.ReflectValue, field.DefaultValueInterface)
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
field.Set(stmt.ReflectValue, curTime)
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
}
} else if field.AutoUpdateTime > 0 {
if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok {
field.Set(stmt.ReflectValue, curTime)
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
}
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], v)
}
}
}
default:
stmt.AddError(gorm.ErrInvalidData)
}
}
if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
if stmt.Schema != nil && len(values.Columns) > 1 {
columns := make([]string, 0, len(values.Columns)-1)
for _, column := range values.Columns {
if field := stmt.Schema.LookUpField(column.Name); field != nil {
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
columns = append(columns, column.Name)
}
}
}
onConflict.DoUpdates = clause.AssignmentColumns(columns)
// use primary fields as default OnConflict columns
if len(onConflict.Columns) == 0 {
for _, field := range stmt.Schema.PrimaryFields {
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName})
}
}
stmt.AddClause(onConflict)
}
}
}
return values
}

168
callbacks/delete.go Normal file
View File

@ -0,0 +1,168 @@
package callbacks
import (
"reflect"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
func BeforeDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(BeforeDeleteInterface); ok {
db.AddError(i.BeforeDelete(tx))
return true
}
return false
})
}
}
func DeleteBeforeAssociations(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
if restricted {
for column, v := range selectColumns {
if v {
if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok {
switch rel.Type {
case schema.HasOne, schema.HasMany:
queryConds := rel.ToQueryConditions(db.Statement.ReflectValue)
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
withoutConditions := false
if db.Statement.Unscoped {
tx = tx.Unscoped()
}
if len(db.Statement.Selects) > 0 {
selects := make([]string, 0, len(db.Statement.Selects))
for _, s := range db.Statement.Selects {
if s == clause.Associations {
selects = append(selects, s)
} else if strings.HasPrefix(s, column+".") {
selects = append(selects, strings.TrimPrefix(s, column+"."))
}
}
if len(selects) > 0 {
tx = tx.Select(selects)
}
}
for _, cond := range queryConds {
if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
withoutConditions = true
break
}
}
if !withoutConditions {
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
}
case schema.Many2Many:
var (
queryConds = make([]clause.Expression, 0, len(rel.References))
foreignFields = make([]*schema.Field, 0, len(rel.References))
relForeignKeys = make([]string, 0, len(rel.References))
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
table = rel.JoinTable.Table
tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table)
)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
} else if ref.PrimaryValue != "" {
queryConds = append(queryConds, clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
}
}
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields)
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
}
}
}
}
}
}
}
func Delete(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.DeleteClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
db.Statement.SQL.Grow(100)
db.Statement.AddClauseIfNotExists(clause.Delete{})
if db.Statement.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
}
}
db.Statement.AddClauseIfNotExists(clause.From{})
db.Statement.Build(db.Statement.BuildClauses...)
}
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil {
db.AddError(gorm.ErrMissingWhereClause)
return
}
if !db.DryRun && db.Error == nil {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
}
}
}
func AfterDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterDeleteInterface); ok {
db.AddError(i.AfterDelete(tx))
return true
}
return false
})
}
}

95
callbacks/helper.go Normal file
View File

@ -0,0 +1,95 @@
package callbacks
import (
"sort"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// ConvertMapToValuesForCreate convert map to values
func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) {
values.Columns = make([]clause.Column, 0, len(mapValue))
selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
var keys = make([]string, 0, len(mapValue))
for k := range mapValue {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
value := mapValue[k]
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil {
k = field.DBName
}
}
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
values.Columns = append(values.Columns, clause.Column{Name: k})
if len(values.Values) == 0 {
values.Values = [][]interface{}{{}}
}
values.Values[0] = append(values.Values[0], value)
}
}
return
}
// ConvertSliceOfMapToValuesForCreate convert slice of map to values
func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) {
var (
columns = make([]string, 0, len(mapValues))
)
// when the length of mapValues,return directly here
// no need to call stmt.SelectAndOmitColumns method
if len(mapValues) == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
var (
result = make(map[string][]interface{}, len(mapValues))
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
)
for idx, mapValue := range mapValues {
for k, v := range mapValue {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil {
k = field.DBName
}
}
if _, ok := result[k]; !ok {
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
result[k] = make([]interface{}, len(mapValues))
columns = append(columns, k)
} else {
continue
}
}
result[k][idx] = v
}
}
sort.Strings(columns)
values.Values = make([][]interface{}, len(mapValues))
values.Columns = make([]clause.Column, len(columns))
for idx, column := range columns {
values.Columns[idx] = clause.Column{Name: column}
for i, v := range result[column] {
if len(values.Values[i]) == 0 {
values.Values[i] = make([]interface{}, len(columns))
}
values.Values[i][idx] = v
}
}
return
}

39
callbacks/interfaces.go Normal file
View File

@ -0,0 +1,39 @@
package callbacks
import "gorm.io/gorm"
type BeforeCreateInterface interface {
BeforeCreate(*gorm.DB) error
}
type AfterCreateInterface interface {
AfterCreate(*gorm.DB) error
}
type BeforeUpdateInterface interface {
BeforeUpdate(*gorm.DB) error
}
type AfterUpdateInterface interface {
AfterUpdate(*gorm.DB) error
}
type BeforeSaveInterface interface {
BeforeSave(*gorm.DB) error
}
type AfterSaveInterface interface {
AfterSave(*gorm.DB) error
}
type BeforeDeleteInterface interface {
BeforeDelete(*gorm.DB) error
}
type AfterDeleteInterface interface {
AfterDelete(*gorm.DB) error
}
type AfterFindInterface interface {
AfterFind(*gorm.DB) error
}

164
callbacks/preload.go Normal file
View File

@ -0,0 +1,164 @@
package callbacks
import (
"reflect"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) {
var (
reflectValue = db.Statement.ReflectValue
tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
relForeignKeys []string
relForeignFields []*schema.Field
foreignFields []*schema.Field
foreignValues [][]interface{}
identityMap = map[string][]reflect.Value{}
inlineConds []interface{}
)
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if rel.JoinTable != nil {
var (
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
joinRelForeignFields = make([]*schema.Field, 0, len(rel.References))
joinForeignKeys = make([]string, 0, len(rel.References))
)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
joinForeignFields = append(joinForeignFields, ref.ForeignKey)
foreignFields = append(foreignFields, ref.PrimaryKey)
} else if ref.PrimaryValue != "" {
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
} else {
joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey)
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
relForeignFields = append(relForeignFields, ref.PrimaryKey)
}
}
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
if len(joinForeignValues) == 0 {
return
}
joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error)
// convert join identity map to relation identity map
fieldValues := make([]interface{}, len(joinForeignFields))
joinFieldValues := make([]interface{}, len(joinRelForeignFields))
for i := 0; i < joinResults.Len(); i++ {
for idx, field := range joinForeignFields {
fieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
}
for idx, field := range joinRelForeignFields {
joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
}
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
joinKey := utils.ToStringKey(joinFieldValues...)
identityMap[joinKey] = append(identityMap[joinKey], results...)
}
}
_, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields)
} else {
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
relForeignFields = append(relForeignFields, ref.ForeignKey)
foreignFields = append(foreignFields, ref.PrimaryKey)
} else if ref.PrimaryValue != "" {
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
} else {
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
relForeignFields = append(relForeignFields, ref.PrimaryKey)
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
if len(foreignValues) == 0 {
return
}
}
// nested preload
for p, pvs := range preloads {
tx = tx.Preload(p, pvs...)
}
reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx)
} else {
inlineConds = append(inlineConds, cond)
}
}
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error)
fieldValues := make([]interface{}, len(relForeignFields))
// clean up old values before preloading
switch reflectValue.Kind() {
case reflect.Struct:
switch rel.Type {
case schema.HasMany, schema.Many2Many:
rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
default:
rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface())
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type {
case schema.HasMany, schema.Many2Many:
rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
default:
rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
}
}
}
for i := 0; i < reflectResults.Len(); i++ {
elem := reflectResults.Index(i)
for idx, field := range relForeignFields {
fieldValues[idx], _ = field.ValueOf(elem)
}
for _, data := range identityMap[utils.ToStringKey(fieldValues...)] {
reflectFieldValue := rel.Field.ReflectValueOf(data)
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
}
reflectFieldValue = reflect.Indirect(reflectFieldValue)
switch reflectFieldValue.Kind() {
case reflect.Struct:
rel.Field.Set(data, reflectResults.Index(i).Interface())
case reflect.Slice, reflect.Array:
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface())
} else {
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
}
}
}
}
}

228
callbacks/query.go Normal file
View File

@ -0,0 +1,228 @@
package callbacks
import (
"fmt"
"reflect"
"sort"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
func Query(db *gorm.DB) {
if db.Error == nil {
BuildQuerySQL(db)
if !db.DryRun && db.Error == nil {
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
db.AddError(err)
return
}
defer rows.Close()
gorm.Scan(rows, db, false)
}
}
}
func BuildQuerySQL(db *gorm.DB) {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.QueryClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
db.Statement.SQL.Grow(100)
clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
var conds []clause.Expression
for _, primaryField := range db.Statement.Schema.PrimaryFields {
if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero {
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
}
}
if len(conds) > 0 {
db.Statement.AddClause(clause.Where{Exprs: conds})
}
}
if len(db.Statement.Selects) > 0 {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
for idx, name := range db.Statement.Selects {
if db.Statement.Schema == nil {
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
} else if f := db.Statement.Schema.LookUpField(name); f != nil {
clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
} else {
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
}
}
} else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 {
selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false)
clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames))
for _, dbName := range db.Statement.Schema.DBNames {
if v, ok := selectColumns[dbName]; (ok && v) || !ok {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName})
}
}
} else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() {
queryFields := db.QueryFields
if !queryFields {
switch db.Statement.ReflectValue.Kind() {
case reflect.Struct:
queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
case reflect.Slice:
queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
}
}
if queryFields {
stmt := gorm.Statement{DB: db}
// smaller struct
if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) {
clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))
for idx, dbName := range stmt.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
}
}
}
}
// inline joins
if len(db.Statement.Joins) != 0 {
if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
for idx, dbName := range db.Statement.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
}
}
joins := []clause.Join{}
if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
joins = fromClause.Joins
}
for _, join := range db.Statement.Joins {
if db.Statement.Schema == nil {
joins = append(joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
tableAliasName := relation.Name
for _, s := range relation.FieldSchema.DBNames {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName,
Name: s,
Alias: tableAliasName + "__" + s,
})
}
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
}
} else {
if ref.PrimaryValue == "" {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
}
}
}
joins = append(joins, clause.Join{
Type: clause.LeftJoin,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
})
} else {
joins = append(joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
}
db.Statement.Joins = nil
db.Statement.AddClause(clause.From{Joins: joins})
} else {
db.Statement.AddClauseIfNotExists(clause.From{})
}
db.Statement.AddClauseIfNotExists(clauseSelect)
db.Statement.Build(db.Statement.BuildClauses...)
}
}
func Preload(db *gorm.DB) {
if db.Error == nil && len(db.Statement.Preloads) > 0 {
preloadMap := map[string]map[string][]interface{}{}
for name := range db.Statement.Preloads {
preloadFields := strings.Split(name, ".")
if preloadFields[0] == clause.Associations {
for _, rel := range db.Statement.Schema.Relationships.Relations {
if rel.Schema == db.Statement.Schema {
if _, ok := preloadMap[rel.Name]; !ok {
preloadMap[rel.Name] = map[string][]interface{}{}
}
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[rel.Name][value] = db.Statement.Preloads[name]
}
}
}
} else {
if _, ok := preloadMap[preloadFields[0]]; !ok {
preloadMap[preloadFields[0]] = map[string][]interface{}{}
}
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
}
}
}
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames = append(preloadNames, key)
}
sort.Strings(preloadNames)
for _, name := range preloadNames {
if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil {
preload(db, rel, db.Statement.Preloads[name], preloadMap[name])
} else {
db.AddError(fmt.Errorf("%v: %w for schema %v", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
}
}
}
}
func AfterQuery(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterFindInterface); ok {
db.AddError(i.AfterFind(tx))
return true
}
return false
})
}
}

16
callbacks/raw.go Normal file
View File

@ -0,0 +1,16 @@
package callbacks
import (
"gorm.io/gorm"
)
func RawExec(db *gorm.DB) {
if db.Error == nil && !db.DryRun {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
db.AddError(err)
} else {
db.RowsAffected, _ = result.RowsAffected()
}
}
}

21
callbacks/row.go Normal file
View File

@ -0,0 +1,21 @@
package callbacks
import (
"gorm.io/gorm"
)
func RowQuery(db *gorm.DB) {
if db.Error == nil {
BuildQuerySQL(db)
if !db.DryRun {
if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) {
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} else {
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
}
db.RowsAffected = -1
}
}
}

31
callbacks/transaction.go Normal file
View File

@ -0,0 +1,31 @@
package callbacks
import (
"gorm.io/gorm"
)
func BeginTransaction(db *gorm.DB) {
if !db.Config.SkipDefaultTransaction {
if tx := db.Begin(); tx.Error == nil {
db.Statement.ConnPool = tx.Statement.ConnPool
db.InstanceSet("gorm:started_transaction", true)
} else if tx.Error == gorm.ErrInvalidTransaction {
tx.Error = nil
} else {
db.Error = tx.Error
}
}
}
func CommitOrRollbackTransaction(db *gorm.DB) {
if !db.Config.SkipDefaultTransaction {
if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
if db.Error == nil {
db.Commit()
} else {
db.Rollback()
}
db.Statement.ConnPool = db.ConnPool
}
}
}

261
callbacks/update.go Normal file
View File

@ -0,0 +1,261 @@
package callbacks
import (
"reflect"
"sort"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
func SetupUpdateReflectValue(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest {
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
for db.Statement.ReflectValue.Kind() == reflect.Ptr {
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
}
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if _, ok := dest[rel.Name]; ok {
rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name])
}
}
}
}
}
}
func BeforeUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.BeforeSave {
if i, ok := value.(BeforeSaveInterface); ok {
called = true
db.AddError(i.BeforeSave(tx))
}
}
if db.Statement.Schema.BeforeUpdate {
if i, ok := value.(BeforeUpdateInterface); ok {
called = true
db.AddError(i.BeforeUpdate(tx))
}
}
return called
})
}
}
func Update(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{})
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
db.Statement.AddClause(set)
} else {
return
}
db.Statement.Build(db.Statement.BuildClauses...)
}
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
db.AddError(gorm.ErrMissingWhereClause)
return
}
if !db.DryRun && db.Error == nil {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
}
}
}
func AfterUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok {
called = true
db.AddError(i.AfterSave(tx))
}
}
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(AfterUpdateInterface); ok {
called = true
db.AddError(i.AfterUpdate(tx))
}
}
return called
})
}
}
// ConvertToAssignments convert to update assignments
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
var (
selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
assignValue func(field *schema.Field, value interface{})
)
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
assignValue = func(field *schema.Field, value interface{}) {
for i := 0; i < stmt.ReflectValue.Len(); i++ {
field.Set(stmt.ReflectValue.Index(i), value)
}
}
case reflect.Struct:
assignValue = func(field *schema.Field, value interface{}) {
if stmt.ReflectValue.CanAddr() {
field.Set(stmt.ReflectValue, value)
}
}
default:
assignValue = func(field *schema.Field, value interface{}) {
}
}
updatingValue := reflect.ValueOf(stmt.Dest)
for updatingValue.Kind() == reflect.Ptr {
updatingValue = updatingValue.Elem()
}
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var primaryKeyExprs []clause.Expression
for i := 0; i < stmt.ReflectValue.Len(); i++ {
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
var notZero bool
for idx, field := range stmt.Schema.PrimaryFields {
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
notZero = notZero || !isZero
}
if notZero {
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
}
}
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
}
}
switch value := updatingValue.Interface().(type) {
case map[string]interface{}:
set = make([]clause.Assignment, 0, len(value))
keys := make([]string, 0, len(value))
for k := range value {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
kv := value[k]
if _, ok := kv.(*gorm.DB); ok {
kv = []interface{}{kv}
}
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil {
if field.DBName != "" {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
assignValue(field, value[k])
}
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
assignValue(field, value[k])
}
continue
}
}
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
}
}
if !stmt.SkipHooks && stmt.Schema != nil {
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.LookUpField(dbName)
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || !ok {
now := stmt.DB.NowFunc()
assignValue(field, now)
if field.AutoUpdateTime == schema.UnixNanosecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
} else if field.AutoUpdateTime == schema.UnixMillisecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
} else if field.GORMDataType == schema.Time {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
} else {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
}
}
}
}
}
default:
switch updatingValue.Kind() {
case reflect.Struct:
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.LookUpField(dbName)
if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
value, isZero := field.ValueOf(updatingValue)
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixNano() / 1e6
} else if field.GORMDataType == schema.Time {
value = stmt.DB.NowFunc()
} else {
value = stmt.DB.NowFunc().Unix()
}
isZero = false
}
if ok || !isZero {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
assignValue(field, value)
}
}
} else {
if value, isZero := field.ValueOf(updatingValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
}
default:
stmt.AddError(gorm.ErrInvalidData)
}
}
return
}

View File

@ -1,177 +0,0 @@
package gorm_test
import (
"errors"
"github.com/jinzhu/gorm"
"reflect"
"testing"
)
func (s *Product) BeforeCreate() (err error) {
if s.Code == "Invalid" {
err = errors.New("invalid product")
}
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
return
}
func (s *Product) BeforeUpdate() (err error) {
if s.Code == "dont_update" {
err = errors.New("can't update")
}
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
return
}
func (s *Product) BeforeSave() (err error) {
if s.Code == "dont_save" {
err = errors.New("can't save")
}
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
return
}
func (s *Product) AfterFind() {
s.AfterFindCallTimes = s.AfterFindCallTimes + 1
}
func (s *Product) AfterCreate(tx *gorm.DB) {
tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1})
}
func (s *Product) AfterUpdate() {
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
}
func (s *Product) AfterSave() (err error) {
if s.Code == "after_save_error" {
err = errors.New("can't save")
}
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
return
}
func (s *Product) BeforeDelete() (err error) {
if s.Code == "dont_delete" {
err = errors.New("can't delete")
}
s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
return
}
func (s *Product) AfterDelete() (err error) {
if s.Code == "after_delete_error" {
err = errors.New("can't delete")
}
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
return
}
func (s *Product) GetCallTimes() []int64 {
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
}
func TestRunCallbacks(t *testing.T) {
p := Product{Code: "unique_code", Price: 100}
DB.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
}
DB.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes())
}
p.Price = 200
DB.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
}
var products []Product
DB.Find(&products, "code = ?", "unique_code")
if products[0].AfterFindCallTimes != 2 {
t.Errorf("AfterFind callbacks should work with slice")
}
DB.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes())
}
DB.Delete(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
}
if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
t.Errorf("Can't find a deleted record")
}
}
func TestCallbacksWithErrors(t *testing.T) {
p := Product{Code: "Invalid", Price: 100}
if DB.Save(&p).Error == nil {
t.Errorf("An error from before create callbacks happened when create with invalid value")
}
if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
t.Errorf("Should not save record that have errors")
}
if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
t.Errorf("An error from after create callbacks happened when create with invalid value")
}
p2 := Product{Code: "update_callback", Price: 100}
DB.Save(&p2)
p2.Code = "dont_update"
if DB.Save(&p2).Error == nil {
t.Errorf("An error from before update callbacks happened when update with invalid value")
}
if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback")
}
if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback")
}
p2.Code = "dont_save"
if DB.Save(&p2).Error == nil {
t.Errorf("An error from before save callbacks happened when update with invalid value")
}
p3 := Product{Code: "dont_delete", Price: 100}
DB.Save(&p3)
if DB.Delete(&p3).Error == nil {
t.Errorf("An error from before delete callbacks happened when delete")
}
if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
t.Errorf("An error from before delete callbacks happened")
}
p4 := Product{Code: "after_save_error", Price: 100}
DB.Save(&p4)
if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
t.Errorf("Record should be reverted if get an error in after save callback")
}
p5 := Product{Code: "after_delete_error", Price: 100}
DB.Save(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record should be found")
}
DB.Delete(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
}
}

293
chainable_api.go Normal file
View File

@ -0,0 +1,293 @@
package gorm
import (
"fmt"
"regexp"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/utils"
)
// Model specify the model you would like to run db operations
// // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello")
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
// db.Model(&user).Update("name", "hello")
func (db *DB) Model(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Model = value
return
}
// Clauses Add clauses
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
tx = db.getInstance()
var whereConds []interface{}
for _, cond := range conds {
if c, ok := cond.(clause.Interface); ok {
tx.Statement.AddClause(c)
} else if optimizer, ok := cond.(StatementModifier); ok {
optimizer.ModifyStatement(tx.Statement)
} else {
whereConds = append(whereConds, cond)
}
}
if len(whereConds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)})
}
return
}
var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`)
// Table specify the table you would like to run db operations
func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 {
tx.Statement.Table = results[1]
return
}
} else if tables := strings.Split(name, "."); len(tables) == 2 {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = tables[1]
return
}
tx.Statement.Table = name
return
}
// Distinct specify distinct fields that you want querying
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Distinct = true
if len(args) > 0 {
tx = tx.Select(args[0], args[1:]...)
}
return
}
// Select specify fields that you want when querying, creating, updating
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
switch v := query.(type) {
case []string:
tx.Statement.Selects = v
for _, arg := range args {
switch arg := arg.(type) {
case string:
tx.Statement.Selects = append(tx.Statement.Selects, arg)
case []string:
tx.Statement.Selects = append(tx.Statement.Selects, arg...)
default:
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
return
}
}
delete(tx.Statement.Clauses, "SELECT")
case string:
if strings.Count(v, "?") >= len(args) && len(args) > 0 {
tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
Expression: clause.Expr{SQL: v, Vars: args},
})
} else if strings.Count(v, "@") > 0 && len(args) > 0 {
tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
Expression: clause.NamedExpr{SQL: v, Vars: args},
})
} else {
tx.Statement.Selects = []string{v}
for _, arg := range args {
switch arg := arg.(type) {
case string:
tx.Statement.Selects = append(tx.Statement.Selects, arg)
case []string:
tx.Statement.Selects = append(tx.Statement.Selects, arg...)
default:
tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
Expression: clause.Expr{SQL: v, Vars: args},
})
return
}
}
delete(tx.Statement.Clauses, "SELECT")
}
default:
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
}
return
}
// Omit specify fields that you want to ignore when creating, updating and querying
func (db *DB) Omit(columns ...string) (tx *DB) {
tx = db.getInstance()
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
} else {
tx.Statement.Omits = columns
}
return
}
// Where add conditions
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: conds})
}
return
}
// Not add NOT conditions
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}})
}
return
}
// Or add OR conditions
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}})
}
return
}
// Joins specify Joins conditions
// db.Joins("Account").Find(&user)
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
return
}
// Group specify the group method on the find
func (db *DB) Group(name string) (tx *DB) {
tx = db.getInstance()
fields := strings.FieldsFunc(name, utils.IsValidDBNameChar)
tx.Statement.AddClause(clause.GroupBy{
Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}},
})
return
}
// Having specify HAVING conditions for GROUP BY
func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.GroupBy{
Having: tx.Statement.BuildCondition(query, args...),
})
return
}
// Order specify order when retrieve records from database
// db.Order("name DESC")
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
func (db *DB) Order(value interface{}) (tx *DB) {
tx = db.getInstance()
switch v := value.(type) {
case clause.OrderByColumn:
tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{v},
})
default:
tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{{
Column: clause.Column{Name: fmt.Sprint(value), Raw: true},
}},
})
}
return
}
// Limit specify the number of records to be retrieved
func (db *DB) Limit(limit int) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Limit: limit})
return
}
// Offset specify the number of records to skip before starting to return the records
func (db *DB) Offset(offset int) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Offset: offset})
return
}
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000)
// }
//
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
// return func (db *gorm.DB) *gorm.DB {
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
// }
// }
//
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
tx = db.getInstance()
tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
return tx
}
// Preload preload associations with given conditions
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if tx.Statement.Preloads == nil {
tx.Statement.Preloads = map[string][]interface{}{}
}
tx.Statement.Preloads[query] = args
return
}
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.attrs = attrs
return
}
func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.assigns = attrs
return
}
func (db *DB) Unscoped() (tx *DB) {
tx = db.getInstance()
tx.Statement.Unscoped = true
return
}
func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.SQL = strings.Builder{}
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
} else {
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
}
return
}

56
clause/benchmarks_test.go Normal file
View File

@ -0,0 +1,56 @@
package clause_test
import (
"sync"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
func BenchmarkSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clauses := []clause.Interface{clause.Select{}, clause.From{}, clause.Where{Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}}}
for _, clause := range clauses {
stmt.AddClause(clause)
}
stmt.Build("SELECT", "FROM", "WHERE")
_ = stmt.SQL.String()
}
}
func BenchmarkComplexSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clauses := []clause.Interface{
clause.Select{}, clause.From{},
clause.Where{Exprs: []clause.Expression{
clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.Gt{Column: "age", Value: 18},
clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}),
}},
clause.Where{Exprs: []clause.Expression{
clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}),
}},
clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}},
clause.Limit{Limit: 10, Offset: 20},
clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}},
}
for _, clause := range clauses {
stmt.AddClause(clause)
}
stmt.Build("SELECT", "FROM", "WHERE", "GROUP BY", "LIMIT", "ORDER BY")
_ = stmt.SQL.String()
}
}

88
clause/clause.go Normal file
View File

@ -0,0 +1,88 @@
package clause
// Interface clause interface
type Interface interface {
Name() string
Build(Builder)
MergeClause(*Clause)
}
// ClauseBuilder clause builder, allows to customize how to build clause
type ClauseBuilder func(Clause, Builder)
type Writer interface {
WriteByte(byte) error
WriteString(string) (int, error)
}
// Builder builder interface
type Builder interface {
Writer
WriteQuoted(field interface{})
AddVar(Writer, ...interface{})
}
// Clause
type Clause struct {
Name string // WHERE
BeforeExpression Expression
AfterNameExpression Expression
AfterExpression Expression
Expression Expression
Builder ClauseBuilder
}
// Build build clause
func (c Clause) Build(builder Builder) {
if c.Builder != nil {
c.Builder(c, builder)
} else if c.Expression != nil {
if c.BeforeExpression != nil {
c.BeforeExpression.Build(builder)
builder.WriteByte(' ')
}
if c.Name != "" {
builder.WriteString(c.Name)
builder.WriteByte(' ')
}
if c.AfterNameExpression != nil {
c.AfterNameExpression.Build(builder)
builder.WriteByte(' ')
}
c.Expression.Build(builder)
if c.AfterExpression != nil {
builder.WriteByte(' ')
c.AfterExpression.Build(builder)
}
}
}
const (
PrimaryKey string = "~~~py~~~" // primary key
CurrentTable string = "~~~ct~~~" // current table
Associations string = "~~~as~~~" // associations
)
var (
currentTable = Table{Name: CurrentTable}
PrimaryColumn = Column{Table: CurrentTable, Name: PrimaryKey}
)
// Column quote with name
type Column struct {
Table string
Name string
Alias string
Raw bool
}
// Table quote with name
type Table struct {
Name string
Alias string
Raw bool
}

43
clause/clause_test.go Normal file
View File

@ -0,0 +1,43 @@
package clause_test
import (
"reflect"
"strings"
"sync"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
var db, _ = gorm.Open(tests.DummyDialector{}, nil)
func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) {
var (
buildNames []string
buildNamesMap = map[string]bool{}
user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
)
for _, c := range clauses {
if _, ok := buildNamesMap[c.Name()]; !ok {
buildNames = append(buildNames, c.Name())
buildNamesMap[c.Name()] = true
}
stmt.AddClause(c)
}
stmt.Build(buildNames...)
if strings.TrimSpace(stmt.SQL.String()) != result {
t.Errorf("SQL expects %v got %v", result, stmt.SQL.String())
}
if !reflect.DeepEqual(stmt.Vars, vars) {
t.Errorf("Vars expects %+v got %v", stmt.Vars, vars)
}
}

23
clause/delete.go Normal file
View File

@ -0,0 +1,23 @@
package clause
type Delete struct {
Modifier string
}
func (d Delete) Name() string {
return "DELETE"
}
func (d Delete) Build(builder Builder) {
builder.WriteString("DELETE")
if d.Modifier != "" {
builder.WriteByte(' ')
builder.WriteString(d.Modifier)
}
}
func (d Delete) MergeClause(clause *Clause) {
clause.Name = ""
clause.Expression = d
}

31
clause/delete_test.go Normal file
View File

@ -0,0 +1,31 @@
package clause_test
import (
"fmt"
"testing"
"gorm.io/gorm/clause"
)
func TestDelete(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Delete{}, clause.From{}},
"DELETE FROM `users`", nil,
},
{
[]clause.Interface{clause.Delete{Modifier: "LOW_PRIORITY"}, clause.From{}},
"DELETE LOW_PRIORITY FROM `users`", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

344
clause/expression.go Normal file
View File

@ -0,0 +1,344 @@
package clause
import (
"database/sql"
"database/sql/driver"
"go/ast"
"reflect"
)
// Expression expression interface
type Expression interface {
Build(builder Builder)
}
// NegationExpressionBuilder negation expression builder
type NegationExpressionBuilder interface {
NegationBuild(builder Builder)
}
// Expr raw expression
type Expr struct {
SQL string
Vars []interface{}
WithoutParentheses bool
}
// Build build raw expression
func (expr Expr) Build(builder Builder) {
var (
afterParenthesis bool
idx int
)
for _, v := range []byte(expr.SQL) {
if v == '?' && len(expr.Vars) > idx {
if afterParenthesis || expr.WithoutParentheses {
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
builder.AddVar(builder, expr.Vars[idx])
} else {
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
builder.AddVar(builder, nil)
} else {
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
}
default:
builder.AddVar(builder, expr.Vars[idx])
}
}
} else {
builder.AddVar(builder, expr.Vars[idx])
}
idx++
} else {
if v == '(' {
afterParenthesis = true
} else {
afterParenthesis = false
}
builder.WriteByte(v)
}
}
}
// NamedExpr raw expression for named expr
type NamedExpr struct {
SQL string
Vars []interface{}
}
// Build build raw expression
func (expr NamedExpr) Build(builder Builder) {
var (
idx int
inName bool
afterParenthesis bool
namedMap = make(map[string]interface{}, len(expr.Vars))
)
for _, v := range expr.Vars {
switch value := v.(type) {
case sql.NamedArg:
namedMap[value.Name] = value.Value
case map[string]interface{}:
for k, v := range value {
namedMap[k] = v
}
default:
var appendFieldsToMap func(reflect.Value)
appendFieldsToMap = func(reflectValue reflect.Value) {
reflectValue = reflect.Indirect(reflectValue)
switch reflectValue.Kind() {
case reflect.Struct:
modelType := reflectValue.Type()
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface()
if fieldStruct.Anonymous {
appendFieldsToMap(reflectValue.Field(i))
}
}
}
}
}
appendFieldsToMap(reflect.ValueOf(value))
}
}
name := make([]byte, 0, 10)
for _, v := range []byte(expr.SQL) {
if v == '@' && !inName {
inName = true
name = []byte{}
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' {
if inName {
if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv)
} else {
builder.WriteByte('@')
builder.WriteString(string(name))
}
inName = false
}
afterParenthesis = false
builder.WriteByte(v)
} else if v == '?' && len(expr.Vars) > idx {
if afterParenthesis {
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
builder.AddVar(builder, expr.Vars[idx])
} else {
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
builder.AddVar(builder, nil)
} else {
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
}
default:
builder.AddVar(builder, expr.Vars[idx])
}
}
} else {
builder.AddVar(builder, expr.Vars[idx])
}
idx++
} else if inName {
name = append(name, v)
} else {
if v == '(' {
afterParenthesis = true
} else {
afterParenthesis = false
}
builder.WriteByte(v)
}
}
if inName {
builder.AddVar(builder, namedMap[string(name)])
}
}
// IN Whether a value is within a set of values
type IN struct {
Column interface{}
Values []interface{}
}
func (in IN) Build(builder Builder) {
builder.WriteQuoted(in.Column)
switch len(in.Values) {
case 0:
builder.WriteString(" IN (NULL)")
case 1:
if _, ok := in.Values[0].([]interface{}); !ok {
builder.WriteString(" = ")
builder.AddVar(builder, in.Values[0])
break
}
fallthrough
default:
builder.WriteString(" IN (")
builder.AddVar(builder, in.Values...)
builder.WriteByte(')')
}
}
func (in IN) NegationBuild(builder Builder) {
switch len(in.Values) {
case 0:
case 1:
if _, ok := in.Values[0].([]interface{}); !ok {
builder.WriteQuoted(in.Column)
builder.WriteString(" <> ")
builder.AddVar(builder, in.Values[0])
break
}
fallthrough
default:
builder.WriteQuoted(in.Column)
builder.WriteString(" NOT IN (")
builder.AddVar(builder, in.Values...)
builder.WriteByte(')')
}
}
// Eq equal to for where
type Eq struct {
Column interface{}
Value interface{}
}
func (eq Eq) Build(builder Builder) {
builder.WriteQuoted(eq.Column)
if eqNil(eq.Value) {
builder.WriteString(" IS NULL")
} else {
builder.WriteString(" = ")
builder.AddVar(builder, eq.Value)
}
}
func (eq Eq) NegationBuild(builder Builder) {
Neq(eq).Build(builder)
}
// Neq not equal to for where
type Neq Eq
func (neq Neq) Build(builder Builder) {
builder.WriteQuoted(neq.Column)
if eqNil(neq.Value) {
builder.WriteString(" IS NOT NULL")
} else {
builder.WriteString(" <> ")
builder.AddVar(builder, neq.Value)
}
}
func (neq Neq) NegationBuild(builder Builder) {
Eq(neq).Build(builder)
}
// Gt greater than for where
type Gt Eq
func (gt Gt) Build(builder Builder) {
builder.WriteQuoted(gt.Column)
builder.WriteString(" > ")
builder.AddVar(builder, gt.Value)
}
func (gt Gt) NegationBuild(builder Builder) {
Lte(gt).Build(builder)
}
// Gte greater than or equal to for where
type Gte Eq
func (gte Gte) Build(builder Builder) {
builder.WriteQuoted(gte.Column)
builder.WriteString(" >= ")
builder.AddVar(builder, gte.Value)
}
func (gte Gte) NegationBuild(builder Builder) {
Lt(gte).Build(builder)
}
// Lt less than for where
type Lt Eq
func (lt Lt) Build(builder Builder) {
builder.WriteQuoted(lt.Column)
builder.WriteString(" < ")
builder.AddVar(builder, lt.Value)
}
func (lt Lt) NegationBuild(builder Builder) {
Gte(lt).Build(builder)
}
// Lte less than or equal to for where
type Lte Eq
func (lte Lte) Build(builder Builder) {
builder.WriteQuoted(lte.Column)
builder.WriteString(" <= ")
builder.AddVar(builder, lte.Value)
}
func (lte Lte) NegationBuild(builder Builder) {
Gt(lte).Build(builder)
}
// Like whether string matches regular expression
type Like Eq
func (like Like) Build(builder Builder) {
builder.WriteQuoted(like.Column)
builder.WriteString(" LIKE ")
builder.AddVar(builder, like.Value)
}
func (like Like) NegationBuild(builder Builder) {
builder.WriteQuoted(like.Column)
builder.WriteString(" NOT LIKE ")
builder.AddVar(builder, like.Value)
}
func eqNil(value interface{}) bool {
if valuer, ok := value.(driver.Valuer); ok {
value, _ = valuer.Value()
}
return value == nil || eqNilReflect(value)
}
func eqNilReflect(value interface{}) bool {
reflectValue := reflect.ValueOf(value)
return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil()
}

153
clause/expression_test.go Normal file
View File

@ -0,0 +1,153 @@
package clause_test
import (
"database/sql"
"fmt"
"reflect"
"sync"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
func TestExpr(t *testing.T) {
results := []struct {
SQL string
Result string
Vars []interface{}
}{{
SQL: "create table ? (? ?, ? ?)",
Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}},
Result: "create table `users` (`id` int, `name` text)",
}}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
if stmt.SQL.String() != result.Result {
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
}
})
}
}
func TestNamedExpr(t *testing.T) {
type Base struct {
Name2 string
}
type NamedArgument struct {
Name1 string
Base
}
results := []struct {
SQL string
Result string
Vars []interface{}
ExpectedVars []interface{}
}{{
SQL: "create table ? (? ?, ? ?)",
Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}},
Result: "create table `users` (`id` int, `name` text)",
}, {
SQL: "name1 = @name AND name2 = @name",
Vars: []interface{}{sql.Named("name", "jinzhu")},
Result: "name1 = ? AND name2 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1",
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
Result: "name1 = ? AND name2 = ? AND name3 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
}, {
SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1",
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu2"}},
Result: "name1 = ? AND name2 = ? AND name3 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
}, {
SQL: "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist",
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
}, {
SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist",
Vars: []interface{}{NamedArgument{Name1: "jinzhu", Base: Base{Name2: "jinzhu2"}}},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
}, {
SQL: "create table ? (? ?, ? ?)",
Vars: []interface{}{},
Result: "create table ? (? ?, ? ?)",
}}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clause.NamedExpr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
if stmt.SQL.String() != result.Result {
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
}
if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) {
t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars)
}
})
}
}
func TestExpression(t *testing.T) {
column := "column-name"
results := []struct {
Expressions []clause.Expression
Result string
}{{
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: "column-value"},
},
Result: "`column-name` = ?",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: nil},
clause.Eq{Column: column, Value: (*string)(nil)},
clause.Eq{Column: column, Value: (*int)(nil)},
clause.Eq{Column: column, Value: (*bool)(nil)},
clause.Eq{Column: column, Value: (interface{})(nil)},
clause.Eq{Column: column, Value: sql.NullString{String: "", Valid: false}},
},
Result: "`column-name` IS NULL",
}, {
Expressions: []clause.Expression{
clause.Neq{Column: column, Value: "column-value"},
},
Result: "`column-name` <> ?",
}, {
Expressions: []clause.Expression{
clause.Neq{Column: column, Value: nil},
clause.Neq{Column: column, Value: (*string)(nil)},
clause.Neq{Column: column, Value: (*int)(nil)},
clause.Neq{Column: column, Value: (*bool)(nil)},
clause.Neq{Column: column, Value: (interface{})(nil)},
},
Result: "`column-name` IS NOT NULL",
}}
for idx, result := range results {
for idy, expression := range result.Expressions {
t.Run(fmt.Sprintf("case #%v.%v", idx, idy), func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
expression.Build(stmt)
if stmt.SQL.String() != result.Result {
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
}
})
}
}
}

37
clause/from.go Normal file
View File

@ -0,0 +1,37 @@
package clause
// From from clause
type From struct {
Tables []Table
Joins []Join
}
// Name from clause name
func (from From) Name() string {
return "FROM"
}
// Build build from clause
func (from From) Build(builder Builder) {
if len(from.Tables) > 0 {
for idx, table := range from.Tables {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(table)
}
} else {
builder.WriteQuoted(currentTable)
}
for _, join := range from.Joins {
builder.WriteByte(' ')
join.Build(builder)
}
}
// MergeClause merge from clause
func (from From) MergeClause(clause *Clause) {
clause.Expression = from
}

75
clause/from_test.go Normal file
View File

@ -0,0 +1,75 @@
package clause_test
import (
"fmt"
"testing"
"gorm.io/gorm/clause"
)
func TestFrom(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}},
"SELECT * FROM `users`", nil,
},
{
[]clause.Interface{
clause.Select{}, clause.From{
Tables: []clause.Table{{Name: "users"}},
Joins: []clause.Join{
{
Type: clause.InnerJoin,
Table: clause.Table{Name: "articles"},
ON: clause.Where{
[]clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}},
},
},
},
},
},
"SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id`", nil,
},
{
[]clause.Interface{
clause.Select{}, clause.From{
Tables: []clause.Table{{Name: "users"}},
Joins: []clause.Join{
{
Type: clause.RightJoin,
Table: clause.Table{Name: "profiles"},
ON: clause.Where{
[]clause.Expression{clause.Eq{clause.Column{Table: "profiles", Name: "email"}, clause.Column{Table: clause.CurrentTable, Name: "email"}}},
},
},
},
}, clause.From{
Joins: []clause.Join{
{
Type: clause.InnerJoin,
Table: clause.Table{Name: "articles"},
ON: clause.Where{
[]clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}},
},
}, {
Type: clause.LeftJoin,
Table: clause.Table{Name: "companies"},
Using: []string{"company_name"},
},
},
},
},
"SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id` LEFT JOIN `companies` USING (`company_name`)", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

48
clause/group_by.go Normal file
View File

@ -0,0 +1,48 @@
package clause
// GroupBy group by clause
type GroupBy struct {
Columns []Column
Having []Expression
}
// Name from clause name
func (groupBy GroupBy) Name() string {
return "GROUP BY"
}
// Build build group by clause
func (groupBy GroupBy) Build(builder Builder) {
for idx, column := range groupBy.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
if len(groupBy.Having) > 0 {
builder.WriteString(" HAVING ")
Where{Exprs: groupBy.Having}.Build(builder)
}
}
// MergeClause merge group by clause
func (groupBy GroupBy) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(GroupBy); ok {
copiedColumns := make([]Column, len(v.Columns))
copy(copiedColumns, v.Columns)
groupBy.Columns = append(copiedColumns, groupBy.Columns...)
copiedHaving := make([]Expression, len(v.Having))
copy(copiedHaving, v.Having)
groupBy.Having = append(copiedHaving, groupBy.Having...)
}
clause.Expression = groupBy
if len(groupBy.Columns) == 0 {
clause.Name = ""
} else {
clause.Name = groupBy.Name()
}
}

40
clause/group_by_test.go Normal file
View File

@ -0,0 +1,40 @@
package clause_test
import (
"fmt"
"testing"
"gorm.io/gorm/clause"
)
func TestGroupBy(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{
Columns: []clause.Column{{Name: "role"}},
Having: []clause.Expression{clause.Eq{"role", "admin"}},
}},
"SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", []interface{}{"admin"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{
Columns: []clause.Column{{Name: "role"}},
Having: []clause.Expression{clause.Eq{"role", "admin"}},
}, clause.GroupBy{
Columns: []clause.Column{{Name: "gender"}},
Having: []clause.Expression{clause.Neq{"gender", "U"}},
}},
"SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", []interface{}{"admin", "U"},
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

39
clause/insert.go Normal file
View File

@ -0,0 +1,39 @@
package clause
type Insert struct {
Table Table
Modifier string
}
// Name insert clause name
func (insert Insert) Name() string {
return "INSERT"
}
// Build build insert clause
func (insert Insert) Build(builder Builder) {
if insert.Modifier != "" {
builder.WriteString(insert.Modifier)
builder.WriteByte(' ')
}
builder.WriteString("INTO ")
if insert.Table.Name == "" {
builder.WriteQuoted(currentTable)
} else {
builder.WriteQuoted(insert.Table)
}
}
// MergeClause merge insert clause
func (insert Insert) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(Insert); ok {
if insert.Modifier == "" {
insert.Modifier = v.Modifier
}
if insert.Table.Name == "" {
insert.Table = v.Table
}
}
clause.Expression = insert
}

35
clause/insert_test.go Normal file
View File

@ -0,0 +1,35 @@
package clause_test
import (
"fmt"
"testing"
"gorm.io/gorm/clause"
)
func TestInsert(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Insert{}},
"INSERT INTO `users`", nil,
},
{
[]clause.Interface{clause.Insert{Modifier: "LOW_PRIORITY"}},
"INSERT LOW_PRIORITY INTO `users`", nil,
},
{
[]clause.Interface{clause.Insert{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}},
"INSERT LOW_PRIORITY INTO `products`", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

47
clause/joins.go Normal file
View File

@ -0,0 +1,47 @@
package clause
type JoinType string
const (
CrossJoin JoinType = "CROSS"
InnerJoin JoinType = "INNER"
LeftJoin JoinType = "LEFT"
RightJoin JoinType = "RIGHT"
)
// Join join clause for from
type Join struct {
Type JoinType
Table Table
ON Where
Using []string
Expression Expression
}
func (join Join) Build(builder Builder) {
if join.Expression != nil {
join.Expression.Build(builder)
} else {
if join.Type != "" {
builder.WriteString(string(join.Type))
builder.WriteByte(' ')
}
builder.WriteString("JOIN ")
builder.WriteQuoted(join.Table)
if len(join.ON.Exprs) > 0 {
builder.WriteString(" ON ")
join.ON.Build(builder)
} else if len(join.Using) > 0 {
builder.WriteString(" USING (")
for idx, c := range join.Using {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(c)
}
builder.WriteByte(')')
}
}
}

48
clause/limit.go Normal file
View File

@ -0,0 +1,48 @@
package clause
import "strconv"
// Limit limit clause
type Limit struct {
Limit int
Offset int
}
// Name where clause name
func (limit Limit) Name() string {
return "LIMIT"
}
// Build build where clause
func (limit Limit) Build(builder Builder) {
if limit.Limit > 0 {
builder.WriteString("LIMIT ")
builder.WriteString(strconv.Itoa(limit.Limit))
}
if limit.Offset > 0 {
if limit.Limit > 0 {
builder.WriteString(" ")
}
builder.WriteString("OFFSET ")
builder.WriteString(strconv.Itoa(limit.Offset))
}
}
// MergeClause merge order by clauses
func (limit Limit) MergeClause(clause *Clause) {
clause.Name = ""
if v, ok := clause.Expression.(Limit); ok {
if limit.Limit == 0 && v.Limit != 0 {
limit.Limit = v.Limit
}
if limit.Offset == 0 && v.Offset > 0 {
limit.Offset = v.Offset
} else if limit.Offset < 0 {
limit.Offset = 0
}
}
clause.Expression = limit
}

58
clause/limit_test.go Normal file
View File

@ -0,0 +1,58 @@
package clause_test
import (
"fmt"
"testing"
"gorm.io/gorm/clause"
)
func TestLimit(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{
Limit: 10,
Offset: 20,
}},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
"SELECT * FROM `users` OFFSET 20", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` OFFSET 30", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` LIMIT 10 OFFSET 30", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
"SELECT * FROM `users` LIMIT 10", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}},
"SELECT * FROM `users` OFFSET 30", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}},
"SELECT * FROM `users` LIMIT 50 OFFSET 30", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

31
clause/locking.go Normal file
View File

@ -0,0 +1,31 @@
package clause
type Locking struct {
Strength string
Table Table
Options string
}
// Name where clause name
func (locking Locking) Name() string {
return "FOR"
}
// Build build where clause
func (locking Locking) Build(builder Builder) {
builder.WriteString(locking.Strength)
if locking.Table.Name != "" {
builder.WriteString(" OF ")
builder.WriteQuoted(locking.Table)
}
if locking.Options != "" {
builder.WriteByte(' ')
builder.WriteString(locking.Options)
}
}
// MergeClause merge order by clauses
func (locking Locking) MergeClause(clause *Clause) {
clause.Expression = locking
}

35
clause/locking_test.go Normal file
View File

@ -0,0 +1,35 @@
package clause_test
import (
"fmt"
"testing"
"gorm.io/gorm/clause"
)
func TestLocking(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}},
"SELECT * FROM `users` FOR UPDATE", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}},
"SELECT * FROM `users` FOR SHARE OF `users`", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}},
"SELECT * FROM `users` FOR UPDATE NOWAIT", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

52
clause/on_conflict.go Normal file
View File

@ -0,0 +1,52 @@
package clause
type OnConflict struct {
Columns []Column
Where Where
OnConstraint string
DoNothing bool
DoUpdates Set
UpdateAll bool
}
func (OnConflict) Name() string {
return "ON CONFLICT"
}
// Build build onConflict clause
func (onConflict OnConflict) Build(builder Builder) {
if len(onConflict.Columns) > 0 {
builder.WriteByte('(')
for idx, column := range onConflict.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
builder.WriteString(`) `)
}
if onConflict.OnConstraint != "" {
builder.WriteString("ON CONSTRAINT ")
builder.WriteString(onConflict.OnConstraint)
builder.WriteByte(' ')
}
if onConflict.DoNothing {
builder.WriteString("DO NOTHING")
} else {
builder.WriteString("DO UPDATE SET ")
onConflict.DoUpdates.Build(builder)
}
if len(onConflict.Where.Exprs) > 0 {
builder.WriteString(" WHERE ")
onConflict.Where.Build(builder)
builder.WriteByte(' ')
}
}
// MergeClause merge onConflict clauses
func (onConflict OnConflict) MergeClause(clause *Clause) {
clause.Expression = onConflict
}

54
clause/order_by.go Normal file
View File

@ -0,0 +1,54 @@
package clause
type OrderByColumn struct {
Column Column
Desc bool
Reorder bool
}
type OrderBy struct {
Columns []OrderByColumn
Expression Expression
}
// Name where clause name
func (orderBy OrderBy) Name() string {
return "ORDER BY"
}
// Build build where clause
func (orderBy OrderBy) Build(builder Builder) {
if orderBy.Expression != nil {
orderBy.Expression.Build(builder)
} else {
for idx, column := range orderBy.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column.Column)
if column.Desc {
builder.WriteString(" DESC")
}
}
}
}
// MergeClause merge order by clauses
func (orderBy OrderBy) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(OrderBy); ok {
for i := len(orderBy.Columns) - 1; i >= 0; i-- {
if orderBy.Columns[i].Reorder {
orderBy.Columns = orderBy.Columns[i:]
clause.Expression = orderBy
return
}
}
copiedColumns := make([]OrderByColumn, len(v.Columns))
copy(copiedColumns, v.Columns)
orderBy.Columns = append(copiedColumns, orderBy.Columns...)
}
clause.Expression = orderBy
}

57
clause/order_by_test.go Normal file
View File

@ -0,0 +1,57 @@
package clause_test
import (
"fmt"
"testing"
"gorm.io/gorm/clause"
)
func TestOrderBy(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
}},
"SELECT * FROM `users` ORDER BY `users`.`id` DESC", nil,
},
{
[]clause.Interface{
clause.Select{}, clause.From{}, clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
}, clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}}},
},
},
"SELECT * FROM `users` ORDER BY `users`.`id` DESC,`name`", nil,
},
{
[]clause.Interface{
clause.Select{}, clause.From{}, clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
}, clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}, Reorder: true}},
},
},
"SELECT * FROM `users` ORDER BY `name`", nil,
},
{
[]clause.Interface{
clause.Select{}, clause.From{}, clause.OrderBy{
Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true},
},
},
"SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)", []interface{}{1, 2, 3},
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

30
clause/returning.go Normal file
View File

@ -0,0 +1,30 @@
package clause
type Returning struct {
Columns []Column
}
// Name where clause name
func (returning Returning) Name() string {
return "RETURNING"
}
// Build build where clause
func (returning Returning) Build(builder Builder) {
for idx, column := range returning.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
}
// MergeClause merge order by clauses
func (returning Returning) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(Returning); ok {
returning.Columns = append(v.Columns, returning.Columns...)
}
clause.Expression = returning
}

36
clause/returning_test.go Normal file
View File

@ -0,0 +1,36 @@
package clause_test
import (
"fmt"
"testing"
"gorm.io/gorm/clause"
)
func TestReturning(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
[]clause.Column{clause.PrimaryColumn},
}},
"SELECT * FROM `users` RETURNING `users`.`id`", nil,
}, {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
[]clause.Column{clause.PrimaryColumn},
}, clause.Returning{
[]clause.Column{{Name: "name"}, {Name: "age"}},
}},
"SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

45
clause/select.go Normal file
View File

@ -0,0 +1,45 @@
package clause
// Select select attrs when querying, updating, creating
type Select struct {
Distinct bool
Columns []Column
Expression Expression
}
func (s Select) Name() string {
return "SELECT"
}
func (s Select) Build(builder Builder) {
if len(s.Columns) > 0 {
if s.Distinct {
builder.WriteString("DISTINCT ")
}
for idx, column := range s.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
} else {
builder.WriteByte('*')
}
}
func (s Select) MergeClause(clause *Clause) {
if s.Expression != nil {
if s.Distinct {
if expr, ok := s.Expression.(Expr); ok {
expr.SQL = "DISTINCT " + expr.SQL
clause.Expression = expr
return
}
}
clause.Expression = s.Expression
} else {
clause.Expression = s
}
}

41
clause/select_test.go Normal file
View File

@ -0,0 +1,41 @@
package clause_test
import (
"fmt"
"testing"
"gorm.io/gorm/clause"
)
func TestSelect(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}},
"SELECT * FROM `users`", nil,
},
{
[]clause.Interface{clause.Select{
Columns: []clause.Column{clause.PrimaryColumn},
}, clause.From{}},
"SELECT `users`.`id` FROM `users`", nil,
},
{
[]clause.Interface{clause.Select{
Columns: []clause.Column{clause.PrimaryColumn},
}, clause.Select{
Columns: []clause.Column{{Name: "name"}},
}, clause.From{}},
"SELECT `name` FROM `users`", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

60
clause/set.go Normal file
View File

@ -0,0 +1,60 @@
package clause
import "sort"
type Set []Assignment
type Assignment struct {
Column Column
Value interface{}
}
func (set Set) Name() string {
return "SET"
}
func (set Set) Build(builder Builder) {
if len(set) > 0 {
for idx, assignment := range set {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(assignment.Column)
builder.WriteByte('=')
builder.AddVar(builder, assignment.Value)
}
} else {
builder.WriteQuoted(PrimaryColumn)
builder.WriteByte('=')
builder.WriteQuoted(PrimaryColumn)
}
}
// MergeClause merge assignments clauses
func (set Set) MergeClause(clause *Clause) {
copiedAssignments := make([]Assignment, len(set))
copy(copiedAssignments, set)
clause.Expression = Set(copiedAssignments)
}
func Assignments(values map[string]interface{}) Set {
keys := make([]string, 0, len(values))
for key := range values {
keys = append(keys, key)
}
sort.Strings(keys)
assignments := make([]Assignment, len(keys))
for idx, key := range keys {
assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]}
}
return assignments
}
func AssignmentColumns(values []string) Set {
assignments := make([]Assignment, len(values))
for idx, value := range values {
assignments[idx] = Assignment{Column: Column{Name: value}, Value: Column{Table: "excluded", Name: value}}
}
return assignments
}

57
clause/set_test.go Normal file
View File

@ -0,0 +1,57 @@
package clause_test
import (
"fmt"
"sort"
"strings"
"testing"
"gorm.io/gorm/clause"
)
func TestSet(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{
clause.Update{},
clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}),
},
"UPDATE `users` SET `users`.`id`=?", []interface{}{1},
},
{
[]clause.Interface{
clause.Update{},
clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}),
clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}),
},
"UPDATE `users` SET `name`=?", []interface{}{"jinzhu"},
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}
func TestAssignments(t *testing.T) {
set := clause.Assignments(map[string]interface{}{
"name": "jinzhu",
"age": 18,
})
assignments := []clause.Assignment(set)
sort.Slice(assignments, func(i, j int) bool {
return strings.Compare(assignments[i].Column.Name, assignments[j].Column.Name) > 0
})
if len(assignments) != 2 || assignments[0].Column.Name != "name" || assignments[0].Value.(string) != "jinzhu" || assignments[1].Column.Name != "age" || assignments[1].Value.(int) != 18 {
t.Errorf("invalid assignments, got %v", assignments)
}
}

38
clause/update.go Normal file
View File

@ -0,0 +1,38 @@
package clause
type Update struct {
Modifier string
Table Table
}
// Name update clause name
func (update Update) Name() string {
return "UPDATE"
}
// Build build update clause
func (update Update) Build(builder Builder) {
if update.Modifier != "" {
builder.WriteString(update.Modifier)
builder.WriteByte(' ')
}
if update.Table.Name == "" {
builder.WriteQuoted(currentTable)
} else {
builder.WriteQuoted(update.Table)
}
}
// MergeClause merge update clause
func (update Update) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(Update); ok {
if update.Modifier == "" {
update.Modifier = v.Modifier
}
if update.Table.Name == "" {
update.Table = v.Table
}
}
clause.Expression = update
}

35
clause/update_test.go Normal file
View File

@ -0,0 +1,35 @@
package clause_test
import (
"fmt"
"testing"
"gorm.io/gorm/clause"
)
func TestUpdate(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Update{}},
"UPDATE `users`", nil,
},
{
[]clause.Interface{clause.Update{Modifier: "LOW_PRIORITY"}},
"UPDATE LOW_PRIORITY `users`", nil,
},
{
[]clause.Interface{clause.Update{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}},
"UPDATE LOW_PRIORITY `products`", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

45
clause/values.go Normal file
View File

@ -0,0 +1,45 @@
package clause
type Values struct {
Columns []Column
Values [][]interface{}
}
// Name from clause name
func (Values) Name() string {
return "VALUES"
}
// Build build from clause
func (values Values) Build(builder Builder) {
if len(values.Columns) > 0 {
builder.WriteByte('(')
for idx, column := range values.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
builder.WriteByte(')')
builder.WriteString(" VALUES ")
for idx, value := range values.Values {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteByte('(')
builder.AddVar(builder, value...)
builder.WriteByte(')')
}
} else {
builder.WriteString("DEFAULT VALUES")
}
}
// MergeClause merge values clauses
func (values Values) MergeClause(clause *Clause) {
clause.Name = ""
clause.Expression = values
}

33
clause/values_test.go Normal file
View File

@ -0,0 +1,33 @@
package clause_test
import (
"fmt"
"testing"
"gorm.io/gorm/clause"
)
func TestValues(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{
clause.Insert{},
clause.Values{
Columns: []clause.Column{{Name: "name"}, {Name: "age"}},
Values: [][]interface{}{{"jinzhu", 18}, {"josh", 1}},
},
},
"INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)", []interface{}{"jinzhu", 18, "josh", 1},
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

177
clause/where.go Normal file
View File

@ -0,0 +1,177 @@
package clause
import (
"strings"
)
// Where where clause
type Where struct {
Exprs []Expression
}
// Name where clause name
func (where Where) Name() string {
return "WHERE"
}
// Build build where clause
func (where Where) Build(builder Builder) {
// Switch position if the first query expression is a single Or condition
for idx, expr := range where.Exprs {
if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
if idx != 0 {
where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0]
}
break
}
}
buildExprs(where.Exprs, builder, " AND ")
}
func buildExprs(exprs []Expression, builder Builder, joinCond string) {
wrapInParentheses := false
for idx, expr := range exprs {
if idx > 0 {
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
builder.WriteString(" OR ")
} else {
builder.WriteString(joinCond)
}
}
if len(exprs) > 1 {
switch v := expr.(type) {
case OrConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToLower(e.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
}
}
case AndConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToLower(e.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
}
}
case Expr:
sql := strings.ToLower(v.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
}
}
if wrapInParentheses {
builder.WriteString(`(`)
expr.Build(builder)
builder.WriteString(`)`)
wrapInParentheses = false
} else {
expr.Build(builder)
}
}
}
// MergeClause merge where clauses
func (where Where) MergeClause(clause *Clause) {
if w, ok := clause.Expression.(Where); ok {
exprs := make([]Expression, len(w.Exprs)+len(where.Exprs))
copy(exprs, w.Exprs)
copy(exprs[len(w.Exprs):], where.Exprs)
where.Exprs = exprs
}
clause.Expression = where
}
func And(exprs ...Expression) Expression {
if len(exprs) == 0 {
return nil
} else if len(exprs) == 1 {
return exprs[0]
}
return AndConditions{Exprs: exprs}
}
type AndConditions struct {
Exprs []Expression
}
func (and AndConditions) Build(builder Builder) {
if len(and.Exprs) > 1 {
builder.WriteByte('(')
buildExprs(and.Exprs, builder, " AND ")
builder.WriteByte(')')
} else {
buildExprs(and.Exprs, builder, " AND ")
}
}
func Or(exprs ...Expression) Expression {
if len(exprs) == 0 {
return nil
}
return OrConditions{Exprs: exprs}
}
type OrConditions struct {
Exprs []Expression
}
func (or OrConditions) Build(builder Builder) {
if len(or.Exprs) > 1 {
builder.WriteByte('(')
buildExprs(or.Exprs, builder, " OR ")
builder.WriteByte(')')
} else {
buildExprs(or.Exprs, builder, " OR ")
}
}
func Not(exprs ...Expression) Expression {
if len(exprs) == 0 {
return nil
}
return NotConditions{Exprs: exprs}
}
type NotConditions struct {
Exprs []Expression
}
func (not NotConditions) Build(builder Builder) {
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}
for idx, c := range not.Exprs {
if idx > 0 {
builder.WriteString(" AND ")
}
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder)
} else {
builder.WriteString("NOT ")
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToLower(e.SQL)
if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses {
builder.WriteByte('(')
}
}
c.Build(builder)
if wrapInParentheses {
builder.WriteByte(')')
}
}
}
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
}

69
clause/where_test.go Normal file
View File

@ -0,0 +1,69 @@
package clause_test
import (
"fmt"
"testing"
"gorm.io/gorm/clause"
)
func TestWhere(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?", []interface{}{"1", 18, "jinzhu"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?", []interface{}{"1", "jinzhu"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"})},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Not(clause.Gt{Column: "score", Value: 100}), clause.Like{Column: "name", Value: "%linus%"})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))},
}},
"SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", []interface{}{18, "jinzhu"},
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

4
clause/with.go Normal file
View File

@ -0,0 +1,4 @@
package clause
type With struct {
}

View File

@ -1,218 +0,0 @@
package gorm_test
import (
"os"
"reflect"
"testing"
"time"
"github.com/jinzhu/now"
)
func TestCreate(t *testing.T) {
float := 35.03554004971999
now := time.Now()
user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float}
if !DB.NewRecord(user) || !DB.NewRecord(&user) {
t.Error("User should be new record before create")
}
if count := DB.Save(&user).RowsAffected; count != 1 {
t.Error("There should be one record be affected when create record")
}
if DB.NewRecord(user) || DB.NewRecord(&user) {
t.Error("User should not new record after save")
}
var newUser User
DB.First(&newUser, user.Id)
if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) {
t.Errorf("User's PasswordHash should be saved ([]byte)")
}
if newUser.Age != 18 {
t.Errorf("User's Age should be saved (int)")
}
if newUser.UserNum != Num(111) {
t.Errorf("User's UserNum should be saved (custom type)")
}
if newUser.Latitude != float {
t.Errorf("Float64 should not be changed after save")
}
if user.CreatedAt.IsZero() {
t.Errorf("Should have created_at after create")
}
if newUser.CreatedAt.IsZero() {
t.Errorf("Should have created_at after create")
}
DB.Model(user).Update("name", "create_user_new_name")
DB.First(&user, user.Id)
if user.CreatedAt.Format(time.RFC3339Nano) != newUser.CreatedAt.Format(time.RFC3339Nano) {
t.Errorf("CreatedAt should not be changed after update")
}
}
func TestCreateWithExistingTimestamp(t *testing.T) {
user := User{Name: "CreateUserExistingTimestamp"}
timeA := now.MustParse("2016-01-01")
user.CreatedAt = timeA
user.UpdatedAt = timeA
DB.Save(&user)
if user.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
t.Errorf("CreatedAt should not be changed")
}
if user.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
t.Errorf("UpdatedAt should not be changed")
}
var newUser User
DB.First(&newUser, user.Id)
if newUser.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
t.Errorf("CreatedAt should not be changed")
}
if newUser.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
t.Errorf("UpdatedAt should not be changed")
}
}
type AutoIncrementUser struct {
User
Sequence uint `gorm:"AUTO_INCREMENT"`
}
func TestCreateWithAutoIncrement(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
t.Skip("Skipping this because only postgres properly support auto_increment on a non-primary_key column")
}
DB.AutoMigrate(&AutoIncrementUser{})
user1 := AutoIncrementUser{}
user2 := AutoIncrementUser{}
DB.Create(&user1)
DB.Create(&user2)
if user2.Sequence-user1.Sequence != 1 {
t.Errorf("Auto increment should apply on Sequence")
}
}
func TestCreateWithNoGORMPrimayKey(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" {
t.Skip("Skipping this because MSSQL will return identity only if the table has an Id column")
}
jt := JoinTable{From: 1, To: 2}
err := DB.Create(&jt).Error
if err != nil {
t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err)
}
}
func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
animal := Animal{Name: "Ferdinand"}
if DB.Save(&animal).Error != nil {
t.Errorf("No error should happen when create a record without std primary key")
}
if animal.Counter == 0 {
t.Errorf("No std primary key should be filled value after create")
}
if animal.Name != "Ferdinand" {
t.Errorf("Default value should be overrided")
}
// Test create with default value not overrided
an := Animal{From: "nerdz"}
if DB.Save(&an).Error != nil {
t.Errorf("No error should happen when create an record without std primary key")
}
// We must fetch the value again, to have the default fields updated
// (We can't do this in the update statements, since sql default can be expressions
// And be different from the fields' type (eg. a time.Time fields has a default value of "now()"
DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an)
if an.Name != "galeone" {
t.Errorf("Default value should fill the field. But got %v", an.Name)
}
}
func TestAnonymousScanner(t *testing.T) {
user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}}
DB.Save(&user)
var user2 User
DB.First(&user2, "name = ?", "anonymous_scanner")
if user2.Role.Name != "admin" {
t.Errorf("Should be able to get anonymous scanner")
}
if !user2.Role.IsAdmin() {
t.Errorf("Should be able to get anonymous scanner")
}
}
func TestAnonymousField(t *testing.T) {
user := User{Name: "anonymous_field", Company: Company{Name: "company"}}
DB.Save(&user)
var user2 User
DB.First(&user2, "name = ?", "anonymous_field")
DB.Model(&user2).Related(&user2.Company)
if user2.Company.Name != "company" {
t.Errorf("Should be able to get anonymous field")
}
}
func TestSelectWithCreate(t *testing.T) {
user := getPreparedUser("select_user", "select_with_create")
DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user)
var queryuser User
DB.Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id)
if queryuser.Name != user.Name || queryuser.Age == user.Age {
t.Errorf("Should only create users with name column")
}
if queryuser.BillingAddressID.Int64 == 0 || queryuser.ShippingAddressId != 0 ||
queryuser.CreditCard.ID == 0 || len(queryuser.Emails) == 0 {
t.Errorf("Should only create selected relationships")
}
}
func TestOmitWithCreate(t *testing.T) {
user := getPreparedUser("omit_user", "omit_with_create")
DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user)
var queryuser User
DB.Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id)
if queryuser.Name == user.Name || queryuser.Age != user.Age {
t.Errorf("Should only create users with age column")
}
if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 ||
queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 {
t.Errorf("Should not create omitted relationships")
}
}

View File

@ -1,303 +0,0 @@
package gorm_test
import (
"testing"
"time"
"github.com/jinzhu/gorm"
)
type CustomizeColumn struct {
ID int64 `gorm:"column:mapped_id; primary_key:yes"`
Name string `gorm:"column:mapped_name"`
Date *time.Time `gorm:"column:mapped_time"`
}
// Make sure an ignored field does not interfere with another field's custom
// column name that matches the ignored field.
type CustomColumnAndIgnoredFieldClash struct {
Body string `sql:"-"`
RawBody string `gorm:"column:body"`
}
func TestCustomizeColumn(t *testing.T) {
col := "mapped_name"
DB.DropTable(&CustomizeColumn{})
DB.AutoMigrate(&CustomizeColumn{})
scope := DB.NewScope(&CustomizeColumn{})
if !scope.Dialect().HasColumn(scope.TableName(), col) {
t.Errorf("CustomizeColumn should have column %s", col)
}
col = "mapped_id"
if scope.PrimaryKey() != col {
t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey())
}
expected := "foo"
now := time.Now()
cc := CustomizeColumn{ID: 666, Name: expected, Date: &now}
if count := DB.Create(&cc).RowsAffected; count != 1 {
t.Error("There should be one record be affected when create record")
}
var cc1 CustomizeColumn
DB.First(&cc1, 666)
if cc1.Name != expected {
t.Errorf("Failed to query CustomizeColumn")
}
cc.Name = "bar"
DB.Save(&cc)
var cc2 CustomizeColumn
DB.First(&cc2, 666)
if cc2.Name != "bar" {
t.Errorf("Failed to query CustomizeColumn")
}
}
func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
DB.DropTable(&CustomColumnAndIgnoredFieldClash{})
if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil {
t.Errorf("Should not raise error: %s", err)
}
}
type CustomizePerson struct {
IdPerson string `gorm:"column:idPerson;primary_key:true"`
Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"`
}
type CustomizeAccount struct {
IdAccount string `gorm:"column:idAccount;primary_key:true"`
Name string
}
func TestManyToManyWithCustomizedColumn(t *testing.T) {
DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount")
DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{})
account := CustomizeAccount{IdAccount: "account", Name: "id1"}
person := CustomizePerson{
IdPerson: "person",
Accounts: []CustomizeAccount{account},
}
if err := DB.Create(&account).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if err := DB.Create(&person).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
var person1 CustomizePerson
scope := DB.NewScope(nil)
if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil {
t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err)
}
if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" {
t.Errorf("should preload correct accounts")
}
}
type CustomizeUser struct {
gorm.Model
Email string `sql:"column:email_address"`
}
type CustomizeInvitation struct {
gorm.Model
Address string `sql:"column:invitation"`
Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"`
}
func TestOneToOneWithCustomizedColumn(t *testing.T) {
DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{})
DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{})
user := CustomizeUser{
Email: "hello@example.com",
}
invitation := CustomizeInvitation{
Address: "hello@example.com",
}
DB.Create(&user)
DB.Create(&invitation)
var invitation2 CustomizeInvitation
if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if invitation2.Person.Email != user.Email {
t.Errorf("Should preload one to one relation with customize foreign keys")
}
}
type PromotionDiscount struct {
gorm.Model
Name string
Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"`
Rule *PromotionRule `gorm:"ForeignKey:discount_id"`
Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"`
}
type PromotionBenefit struct {
gorm.Model
Name string
PromotionID uint
Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"`
}
type PromotionCoupon struct {
gorm.Model
Code string
DiscountID uint
Discount PromotionDiscount
}
type PromotionRule struct {
gorm.Model
Name string
Begin *time.Time
End *time.Time
DiscountID uint
Discount *PromotionDiscount
}
func TestOneToManyWithCustomizedColumn(t *testing.T) {
DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{})
DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{})
discount := PromotionDiscount{
Name: "Happy New Year",
Coupons: []*PromotionCoupon{
{Code: "newyear1"},
{Code: "newyear2"},
},
}
if err := DB.Create(&discount).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
var discount1 PromotionDiscount
if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
if len(discount.Coupons) != 2 {
t.Errorf("should find two coupons")
}
var coupon PromotionCoupon
if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
if coupon.Discount.Name != "Happy New Year" {
t.Errorf("should preload discount from coupon")
}
}
func TestHasOneWithPartialCustomizedColumn(t *testing.T) {
DB.DropTable(&PromotionDiscount{}, &PromotionRule{})
DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{})
var begin = time.Now()
var end = time.Now().Add(24 * time.Hour)
discount := PromotionDiscount{
Name: "Happy New Year 2",
Rule: &PromotionRule{
Name: "time_limited",
Begin: &begin,
End: &end,
},
}
if err := DB.Create(&discount).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
var discount1 PromotionDiscount
if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) {
t.Errorf("Should be able to preload Rule")
}
var rule PromotionRule
if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
if rule.Discount.Name != "Happy New Year 2" {
t.Errorf("should preload discount from rule")
}
}
func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{})
DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{})
discount := PromotionDiscount{
Name: "Happy New Year 3",
Benefits: []PromotionBenefit{
{Name: "free cod"},
{Name: "free shipping"},
},
}
if err := DB.Create(&discount).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
var discount1 PromotionDiscount
if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
if len(discount.Benefits) != 2 {
t.Errorf("should find two benefits")
}
var benefit PromotionBenefit
if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
if benefit.Discount.Name != "Happy New Year 3" {
t.Errorf("should preload discount from coupon")
}
}
type SelfReferencingUser struct {
gorm.Model
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;AssociationForeignKey:ID=friend_id"`
}
func TestSelfReferencingMany2ManyColumn(t *testing.T) {
DB.DropTable(&SelfReferencingUser{}, "UserFriends")
DB.AutoMigrate(&SelfReferencingUser{})
friend := SelfReferencingUser{}
if err := DB.Create(&friend).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
user := SelfReferencingUser{
Friends: []*SelfReferencingUser{&friend},
}
if err := DB.Create(&user).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
}

View File

@ -1,91 +0,0 @@
package gorm_test
import (
"testing"
"time"
)
func TestDelete(t *testing.T) {
user1, user2 := User{Name: "delete1"}, User{Name: "delete2"}
DB.Save(&user1)
DB.Save(&user2)
if err := DB.Delete(&user1).Error; err != nil {
t.Errorf("No error should happen when delete a record, err=%s", err)
}
if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
t.Errorf("User can't be found after delete")
}
if DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() {
t.Errorf("Other users that not deleted should be found-able")
}
}
func TestInlineDelete(t *testing.T) {
user1, user2 := User{Name: "inline_delete1"}, User{Name: "inline_delete2"}
DB.Save(&user1)
DB.Save(&user2)
if DB.Delete(&User{}, user1.Id).Error != nil {
t.Errorf("No error should happen when delete a record")
} else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
t.Errorf("User can't be found after delete")
}
if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Errorf("No error should happen when delete a record, err=%s", err)
} else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() {
t.Errorf("User can't be found after delete")
}
}
func TestSoftDelete(t *testing.T) {
type User struct {
Id int64
Name string
DeletedAt *time.Time
}
DB.AutoMigrate(&User{})
user := User{Name: "soft_delete"}
DB.Save(&user)
DB.Delete(&user)
if DB.First(&User{}, "name = ?", user.Name).Error == nil {
t.Errorf("Can't find a soft deleted record")
}
if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err)
}
DB.Unscoped().Delete(&user)
if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() {
t.Errorf("Can't find permanently deleted record")
}
}
func TestSoftDeleteWithCustomizedDeletedAtColumnName(t *testing.T) {
creditCard := CreditCard{Number: "411111111234567"}
DB.Save(&creditCard)
DB.Delete(&creditCard)
if deletedAtField, ok := DB.NewScope(&CreditCard{}).FieldByName("DeletedAt"); !ok || deletedAtField.DBName != "deleted_time" {
t.Errorf("CreditCard's DeletedAt's column name should be `deleted_time`")
}
if DB.First(&CreditCard{}, "number = ?", creditCard.Number).Error == nil {
t.Errorf("Can't find a soft deleted record")
}
if err := DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).Error; err != nil {
t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err)
}
DB.Unscoped().Delete(&creditCard)
if !DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).RecordNotFound() {
t.Errorf("Can't find permanently deleted record")
}
}

View File

@ -1,116 +0,0 @@
package gorm
import (
"database/sql"
"fmt"
"reflect"
"strconv"
"strings"
)
// Dialect interface contains behaviors that differ across SQL database
type Dialect interface {
// GetName get dialect's name
GetName() string
// SetDB set db for dialect
SetDB(db SQLCommon)
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
BindVar(i int) string
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
Quote(key string) string
// DataTypeOf return data's sql type
DataTypeOf(field *StructField) string
// HasIndex check has index or not
HasIndex(tableName string, indexName string) bool
// HasForeignKey check has foreign key or not
HasForeignKey(tableName string, foreignKeyName string) bool
// RemoveIndex remove index
RemoveIndex(tableName string, indexName string) error
// HasTable check has table or not
HasTable(tableName string) bool
// HasColumn check has column or not
HasColumn(tableName string, columnName string) bool
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
LimitAndOffsetSQL(limit, offset interface{}) string
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
SelectFromDummyTable() string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string
// BuildForeignKeyName returns a foreign key name for the given table, field and reference
BuildForeignKeyName(tableName, field, dest string) string
// CurrentDatabase return current database name
CurrentDatabase() string
}
var dialectsMap = map[string]Dialect{}
func newDialect(name string, db SQLCommon) Dialect {
if value, ok := dialectsMap[name]; ok {
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
dialect.SetDB(db)
return dialect
}
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
commontDialect := &commonDialect{}
commontDialect.SetDB(db)
return commontDialect
}
// RegisterDialect register new dialect
func RegisterDialect(name string, dialect Dialect) {
dialectsMap[name] = dialect
}
// ParseFieldStructForDialect get field's sql data type
var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
// Get redirected field type
var (
reflectType = field.Struct.Type
dataType = field.TagSettings["TYPE"]
)
for reflectType.Kind() == reflect.Ptr {
reflectType = reflectType.Elem()
}
// Get redirected field value
fieldValue = reflect.Indirect(reflect.New(reflectType))
if gormDataType, ok := fieldValue.Interface().(interface {
GormDataType(Dialect) string
}); ok {
dataType = gormDataType.GormDataType(dialect)
}
// Get scanner's real value
var getScannerValue func(reflect.Value)
getScannerValue = func(value reflect.Value) {
fieldValue = value
if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct {
getScannerValue(fieldValue.Field(0))
}
}
getScannerValue(fieldValue)
// Default Size
if num, ok := field.TagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
} else {
size = 255
}
// Default type from tag setting
additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
if value, ok := field.TagSettings["DEFAULT"]; ok {
additionalType = additionalType + " DEFAULT " + value
}
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
}

View File

@ -1,156 +0,0 @@
package gorm
import (
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
)
// DefaultForeignKeyNamer contains the default foreign key name generator method
type DefaultForeignKeyNamer struct {
}
type commonDialect struct {
db SQLCommon
DefaultForeignKeyNamer
}
func init() {
RegisterDialect("common", &commonDialect{})
}
func (commonDialect) GetName() string {
return "common"
}
func (s *commonDialect) SetDB(db SQLCommon) {
s.db = db
}
func (commonDialect) BindVar(i int) string {
return "$$$" // ?
}
func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (s *commonDialect) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
sqlType = "INTEGER AUTO_INCREMENT"
} else {
sqlType = "INTEGER"
}
case reflect.Int64, reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
sqlType = "BIGINT AUTO_INCREMENT"
} else {
sqlType = "BIGINT"
}
case reflect.Float32, reflect.Float64:
sqlType = "FLOAT"
case reflect.String:
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("VARCHAR(%d)", size)
} else {
sqlType = "VARCHAR(65532)"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "TIMESTAMP"
}
default:
if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("BINARY(%d)", size)
} else {
sqlType = "BINARY(65532)"
}
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count)
return count > 0
}
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
return err
}
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
return false
}
func (s commonDialect) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count)
return count > 0
}
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}
func (s commonDialect) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}
func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
}
}
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
}
}
return
}
func (commonDialect) SelectFromDummyTable() string {
return ""
}
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}
func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string {
keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest)
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
return keyName
}
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
func IsByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
}

View File

@ -1,176 +0,0 @@
package gorm
import (
"crypto/sha1"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"unicode/utf8"
)
type mysql struct {
commonDialect
}
func init() {
RegisterDialect("mysql", &mysql{})
}
func (mysql) GetName() string {
return "mysql"
}
func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
}
// Get Data Type for MySQL Dialect
func (s *mysql) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
// MySQL allows only one auto increment column per table, and it must
// be a KEY column.
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey {
delete(field.TagSettings, "AUTO_INCREMENT")
}
}
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "boolean"
case reflect.Int8:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "tinyint AUTO_INCREMENT"
} else {
sqlType = "tinyint"
}
case reflect.Int, reflect.Int16, reflect.Int32:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "int AUTO_INCREMENT"
} else {
sqlType = "int"
}
case reflect.Uint8:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "tinyint unsigned AUTO_INCREMENT"
} else {
sqlType = "tinyint unsigned"
}
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "int unsigned AUTO_INCREMENT"
} else {
sqlType = "int unsigned"
}
case reflect.Int64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "bigint AUTO_INCREMENT"
} else {
sqlType = "bigint"
}
case reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "bigint unsigned AUTO_INCREMENT"
} else {
sqlType = "bigint unsigned"
}
case reflect.Float32, reflect.Float64:
sqlType = "double"
case reflect.String:
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varchar(%d)", size)
} else {
sqlType = "longtext"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
if _, ok := field.TagSettings["NOT NULL"]; ok {
sqlType = "timestamp"
} else {
sqlType = "timestamp NULL"
}
}
default:
if IsByteArrayOrSlice(dataValue) {
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varbinary(%d)", size)
} else {
sqlType = "longblob"
}
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s mysql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
}
}
}
}
return
}
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
return count > 0
}
func (s mysql) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}
func (mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}
func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
if utf8.RuneCountInString(keyName) <= 64 {
return keyName
}
h := sha1.New()
h.Write([]byte(keyName))
bs := h.Sum(nil)
// sha1 is 40 digits, keep first 24 characters of destination
destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_"))
if len(destRunes) > 24 {
destRunes = destRunes[:24]
}
return fmt.Sprintf("%s%x", string(destRunes), bs)
}

View File

@ -1,132 +0,0 @@
package gorm
import (
"fmt"
"reflect"
"strings"
"time"
)
type postgres struct {
commonDialect
}
func init() {
RegisterDialect("postgres", &postgres{})
RegisterDialect("cloudsqlpostgres", &postgres{})
}
func (postgres) GetName() string {
return "postgres"
}
func (postgres) BindVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (s *postgres) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "serial"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint32, reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "bigserial"
} else {
sqlType = "bigint"
}
case reflect.Float32, reflect.Float64:
sqlType = "numeric"
case reflect.String:
if _, ok := field.TagSettings["SIZE"]; !ok {
size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
}
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varchar(%d)", size)
} else {
sqlType = "text"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "timestamp with time zone"
}
case reflect.Map:
if dataValue.Type().Name() == "Hstore" {
sqlType = "hstore"
}
default:
if IsByteArrayOrSlice(dataValue) {
sqlType = "bytea"
if isUUID(dataValue) {
sqlType = "uuid"
}
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s postgres) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
return count > 0
}
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
return count > 0
}
func (s postgres) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
return count > 0
}
func (s postgres) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
return count > 0
}
func (s postgres) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
return
}
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
}
func (postgres) SupportLastInsertID() bool {
return false
}
func isUUID(value reflect.Value) bool {
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
return false
}
typename := value.Type().Name()
lower := strings.ToLower(typename)
return "uuid" == lower || "guid" == lower
}

View File

@ -1,107 +0,0 @@
package gorm
import (
"fmt"
"reflect"
"strings"
"time"
)
type sqlite3 struct {
commonDialect
}
func init() {
RegisterDialect("sqlite3", &sqlite3{})
}
func (sqlite3) GetName() string {
return "sqlite3"
}
// Get Data Type for Sqlite Dialect
func (s *sqlite3) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if field.IsPrimaryKey {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "integer primary key autoincrement"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint64:
if field.IsPrimaryKey {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "integer primary key autoincrement"
} else {
sqlType = "bigint"
}
case reflect.Float32, reflect.Float64:
sqlType = "real"
case reflect.String:
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varchar(%d)", size)
} else {
sqlType = "text"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "datetime"
}
default:
if IsByteArrayOrSlice(dataValue) {
sqlType = "blob"
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
return count > 0
}
func (s sqlite3) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
return count > 0
}
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
return count > 0
}
func (s sqlite3) CurrentDatabase() (name string) {
var (
ifaces = make([]interface{}, 3)
pointers = make([]*string, 3)
i int
)
for i = 0; i < 3; i++ {
ifaces[i] = &pointers[i]
}
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
return
}
if pointers[1] != nil {
name = *pointers[1]
}
return
}

View File

@ -1,170 +0,0 @@
package mssql
import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
_ "github.com/denisenkom/go-mssqldb"
"github.com/jinzhu/gorm"
)
func setIdentityInsert(scope *gorm.Scope) {
if scope.Dialect().GetName() == "mssql" {
for _, field := range scope.PrimaryFields() {
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank {
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
scope.InstanceSet("mssql:identity_insert_on", true)
}
}
}
}
func turnOffIdentityInsert(scope *gorm.Scope) {
if scope.Dialect().GetName() == "mssql" {
if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok {
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName()))
}
}
}
func init() {
gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert)
gorm.DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert)
gorm.RegisterDialect("mssql", &mssql{})
}
type mssql struct {
db gorm.SQLCommon
gorm.DefaultForeignKeyNamer
}
func (mssql) GetName() string {
return "mssql"
}
func (s *mssql) SetDB(db gorm.SQLCommon) {
s.db = db
}
func (mssql) BindVar(i int) string {
return "$$$" // ?
}
func (mssql) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (s *mssql) DataTypeOf(field *gorm.StructField) string {
var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "bit"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "int IDENTITY(1,1)"
} else {
sqlType = "int"
}
case reflect.Int64, reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "bigint IDENTITY(1,1)"
} else {
sqlType = "bigint"
}
case reflect.Float32, reflect.Float64:
sqlType = "float"
case reflect.String:
if size > 0 && size < 8000 {
sqlType = fmt.Sprintf("nvarchar(%d)", size)
} else {
sqlType = "nvarchar(max)"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "datetimeoffset"
}
default:
if gorm.IsByteArrayOrSlice(dataValue) {
if size > 0 && size < 8000 {
sqlType = fmt.Sprintf("varbinary(%d)", size)
} else {
sqlType = "varbinary(max)"
}
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s mssql) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
return count > 0
}
func (s mssql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
return false
}
func (s mssql) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count)
return count > 0
}
func (s mssql) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}
func (s mssql) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
return
}
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
}
}
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
if sql == "" {
// add default zero offset
sql += " OFFSET 0 ROWS"
}
sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit)
}
}
return
}
func (mssql) SelectFromDummyTable() string {
return ""
}
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}

View File

@ -1,3 +0,0 @@
package mysql
import _ "github.com/go-sql-driver/mysql"

View File

@ -1,77 +0,0 @@
package postgres
import (
"database/sql"
"database/sql/driver"
_ "github.com/lib/pq"
"github.com/lib/pq/hstore"
"encoding/json"
"errors"
"fmt"
)
type Hstore map[string]*string
// Value get value of Hstore
func (h Hstore) Value() (driver.Value, error) {
hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
if len(h) == 0 {
return nil, nil
}
for key, value := range h {
var s sql.NullString
if value != nil {
s.String = *value
s.Valid = true
}
hstore.Map[key] = s
}
return hstore.Value()
}
// Scan scan value into Hstore
func (h *Hstore) Scan(value interface{}) error {
hstore := hstore.Hstore{}
if err := hstore.Scan(value); err != nil {
return err
}
if len(hstore.Map) == 0 {
return nil
}
*h = Hstore{}
for k := range hstore.Map {
if hstore.Map[k].Valid {
s := hstore.Map[k].String
(*h)[k] = &s
} else {
(*h)[k] = nil
}
}
return nil
}
// Jsonb Postgresql's JSONB data type
type Jsonb struct {
json.RawMessage
}
// Value get value of Jsonb
func (j Jsonb) Value() (driver.Value, error) {
return j.MarshalJSON()
}
// Scan scan value into Jsonb
func (j *Jsonb) Scan(value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
}
return json.Unmarshal(bytes, j)
}

View File

@ -1,3 +0,0 @@
package sqlite
import _ "github.com/mattn/go-sqlite3"

View File

@ -1,73 +0,0 @@
package gorm_test
import "testing"
type BasePost struct {
Id int64
Title string
URL string
}
type Author struct {
ID string
Name string
Email string
}
type HNPost struct {
BasePost
Author `gorm:"embedded_prefix:user_"` // Embedded struct
Upvotes int32
}
type EngadgetPost struct {
BasePost BasePost `gorm:"embedded"`
Author Author `gorm:"embedded;embedded_prefix:author_"` // Embedded struct
ImageUrl string
}
func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
dialect := DB.NewScope(&EngadgetPost{}).Dialect()
engadgetPostScope := DB.NewScope(&EngadgetPost{})
if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") {
t.Errorf("should has prefix for embedded columns")
}
if len(engadgetPostScope.PrimaryFields()) != 1 {
t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields()))
}
hnScope := DB.NewScope(&HNPost{})
if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") {
t.Errorf("should has prefix for embedded columns")
}
}
func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
DB.Save(&HNPost{BasePost: BasePost{Title: "news"}})
DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
var news HNPost
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if news.Title != "hn_news" {
t.Errorf("embedded struct's value should be scanned correctly")
}
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
var egNews EngadgetPost
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if egNews.BasePost.Title != "engadget_news" {
t.Errorf("embedded struct's value should be scanned correctly")
}
if DB.NewScope(&HNPost{}).PrimaryField() == nil {
t.Errorf("primary key with embedded struct should works")
}
for _, field := range DB.NewScope(&HNPost{}).Fields() {
if field.Name == "BasePost" {
t.Errorf("scope Fields should not contain embedded struct")
}
}
}

View File

@ -2,59 +2,41 @@ package gorm
import ( import (
"errors" "errors"
"strings"
"gorm.io/gorm/logger"
) )
var ( var (
// ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct // ErrRecordNotFound record not found error
ErrRecordNotFound = errors.New("record not found") ErrRecordNotFound = logger.ErrRecordNotFound
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
ErrInvalidSQL = errors.New("invalid SQL")
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
ErrInvalidTransaction = errors.New("no valid transaction") ErrInvalidTransaction = errors.New("no valid transaction")
// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` // ErrNotImplemented not implemented
ErrCantStartTransaction = errors.New("can't start transaction") ErrNotImplemented = errors.New("not implemented")
// ErrUnaddressable unaddressable value // ErrMissingWhereClause missing where clause
ErrUnaddressable = errors.New("using unaddressable value") ErrMissingWhereClause = errors.New("WHERE conditions required")
// ErrUnsupportedRelation unsupported relations
ErrUnsupportedRelation = errors.New("unsupported relations")
// ErrPrimaryKeyRequired primary keys required
ErrPrimaryKeyRequired = errors.New("primary key required")
// ErrModelValueRequired model value required
ErrModelValueRequired = errors.New("model value required")
// ErrInvalidData unsupported data
ErrInvalidData = errors.New("unsupported data")
// ErrUnsupportedDriver unsupported driver
ErrUnsupportedDriver = errors.New("unsupported driver")
// ErrRegistered registered
ErrRegistered = errors.New("registered")
// ErrInvalidField invalid field
ErrInvalidField = errors.New("invalid field")
// ErrEmptySlice empty slice found
ErrEmptySlice = errors.New("empty slice found")
// ErrDryRunModeUnsupported dry run mode unsupported
ErrDryRunModeUnsupported = errors.New("dry run mode unsupported")
// ErrInvalidDB invalid db
ErrInvalidDB = errors.New("invalid db")
// ErrInvalidValue invalid value
ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice")
// ErrInvalidValueOfLength invalid values do not match length
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
) )
// Errors contains all happened errors
type Errors []error
// GetErrors gets all happened errors
func (errs Errors) GetErrors() []error {
return errs
}
// Add adds an error
func (errs Errors) Add(newErrors ...error) Errors {
for _, err := range newErrors {
if err == nil {
continue
}
if errors, ok := err.(Errors); ok {
errs = errs.Add(errors...)
} else {
ok = true
for _, e := range errs {
if err == e {
ok = false
}
}
if ok {
errs = append(errs, err)
}
}
}
return errs
}
// Error format happened errors
func (errs Errors) Error() string {
var errors = []string{}
for _, e := range errs {
errors = append(errors, e.Error())
}
return strings.Join(errors, "; ")
}

View File

@ -1,20 +0,0 @@
package gorm_test
import (
"errors"
"testing"
"github.com/jinzhu/gorm"
)
func TestErrorsCanBeUsedOutsideGorm(t *testing.T) {
errs := []error{errors.New("First"), errors.New("Second")}
gErrs := gorm.Errors(errs)
gErrs = gErrs.Add(errors.New("Third"))
gErrs = gErrs.Add(gErrs)
if gErrs.Error() != "First; Second; Third" {
t.Fatalf("Gave wrong error, got %s", gErrs.Error())
}
}

View File

@ -1,58 +0,0 @@
package gorm
import (
"database/sql"
"errors"
"fmt"
"reflect"
)
// Field model field definition
type Field struct {
*StructField
IsBlank bool
Field reflect.Value
}
// Set set a value to the field
func (field *Field) Set(value interface{}) (err error) {
if !field.Field.IsValid() {
return errors.New("field value not valid")
}
if !field.Field.CanAddr() {
return ErrUnaddressable
}
reflectValue, ok := value.(reflect.Value)
if !ok {
reflectValue = reflect.ValueOf(value)
}
fieldValue := field.Field
if reflectValue.IsValid() {
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
} else {
if fieldValue.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.Struct.Type.Elem()))
}
fieldValue = fieldValue.Elem()
}
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
err = scanner.Scan(reflectValue.Interface())
} else {
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
}
}
} else {
field.Field.Set(reflect.Zero(field.Field.Type()))
}
field.IsBlank = isBlank(field.Field)
return err
}

View File

@ -1,49 +0,0 @@
package gorm_test
import (
"testing"
"github.com/jinzhu/gorm"
)
type CalculateField struct {
gorm.Model
Name string
Children []CalculateFieldChild
Category CalculateFieldCategory
EmbeddedField
}
type EmbeddedField struct {
EmbeddedName string `sql:"NOT NULL;DEFAULT:'hello'"`
}
type CalculateFieldChild struct {
gorm.Model
CalculateFieldID uint
Name string
}
type CalculateFieldCategory struct {
gorm.Model
CalculateFieldID uint
Name string
}
func TestCalculateField(t *testing.T) {
var field CalculateField
var scope = DB.NewScope(&field)
if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil {
t.Errorf("Should calculate fields correctly for the first time")
}
if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil {
t.Errorf("Should calculate fields correctly for the first time")
}
if field, ok := scope.FieldByName("embedded_name"); !ok {
t.Errorf("should find embedded field")
} else if _, ok := field.TagSettings["NOT NULL"]; !ok {
t.Errorf("should find embedded field's tag settings")
}
}

649
finisher_api.go Normal file
View File

@ -0,0 +1,649 @@
package gorm
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
// Create insert the value into database
func (db *DB) Create(value interface{}) (tx *DB) {
if db.CreateBatchSize > 0 {
return db.CreateInBatches(value, db.CreateBatchSize)
}
tx = db.getInstance()
tx.Statement.Dest = value
tx.callbacks.Create().Execute(tx)
return
}
// CreateInBatches insert the value in batches into database
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
var rowsAffected int64
tx = db.getInstance()
callFc := func(tx *DB) error {
// the reflection length judgment of the optimized value
reflectLen := reflectValue.Len()
for i := 0; i < reflectLen; i += batchSize {
ends := i + batchSize
if ends > reflectLen {
ends = reflectLen
}
subtx := tx.getInstance()
subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface()
subtx.callbacks.Create().Execute(subtx)
if subtx.Error != nil {
return subtx.Error
}
rowsAffected += subtx.RowsAffected
}
return nil
}
if tx.SkipDefaultTransaction {
tx.AddError(callFc(tx.Session(&Session{})))
} else {
tx.AddError(tx.Transaction(callFc))
}
tx.RowsAffected = rowsAffected
default:
tx = db.getInstance()
tx.Statement.Dest = value
tx.callbacks.Create().Execute(tx)
}
return
}
// Save update value in database, if the value doesn't have primary key, will insert it
func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = value
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
}
tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true))
case reflect.Struct:
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
for _, pf := range tx.Statement.Schema.PrimaryFields {
if _, isZero := pf.ValueOf(reflectValue); isZero {
tx.callbacks.Create().Execute(tx)
return
}
}
}
fallthrough
default:
selectedUpdate := len(tx.Statement.Selects) != 0
// when updating, use all fields including those zero-value fields
if !selectedUpdate {
tx.Statement.Selects = append(tx.Statement.Selects, "*")
}
tx.callbacks.Update().Execute(tx)
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) {
return tx.Create(value)
}
}
}
return
}
// First find first record that match given conditions, order by primary key
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
}
tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
return
}
// Take return a record that match given conditions, the order will depend on the database implementation
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1)
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
}
tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
return
}
// Last find last record that match given conditions, order by primary key
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
Desc: true,
})
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
}
tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
return
}
// Find find records that match given conditions
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance()
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
}
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
return
}
// FindInBatches find records in batches
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
var (
tx = db.Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}).Session(&Session{})
queryDB = tx
rowsAffected int64
batch int
)
for {
result := queryDB.Limit(batchSize).Find(dest)
rowsAffected += result.RowsAffected
batch++
if result.Error == nil && result.RowsAffected != 0 {
tx.AddError(fc(result, batch))
} else if result.Error != nil {
tx.AddError(result.Error)
}
if tx.Error != nil || int(result.RowsAffected) < batchSize {
break
} else {
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
if result.Statement.Schema.PrioritizedPrimaryField == nil {
tx.AddError(ErrPrimaryKeyRequired)
break
} else {
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
}
}
}
tx.RowsAffected = rowsAffected
return tx
}
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
for _, value := range values {
switch v := value.(type) {
case []clause.Expression:
for _, expr := range v {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
if field := tx.Statement.Schema.LookUpField(column); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
}
case clause.Column:
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
}
}
} else if andCond, ok := expr.(clause.AndConditions); ok {
tx.assignInterfacesToValue(andCond.Exprs)
}
}
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
if exprs := tx.Statement.BuildCondition(value); len(exprs) > 0 {
tx.assignInterfacesToValue(exprs)
}
default:
if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:
for _, f := range s.Fields {
if f.Readable {
if v, isZero := f.ValueOf(reflectValue); !isZero {
if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, v))
}
}
}
}
}
} else if len(values) > 0 {
if exprs := tx.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
tx.assignInterfacesToValue(exprs)
}
return
}
}
}
}
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs)
}
}
// initialize with attrs, conds
if len(tx.Statement.attrs) > 0 {
tx.assignInterfacesToValue(tx.Statement.attrs...)
}
}
// initialize with attrs, conds
if len(tx.Statement.assigns) > 0 {
tx.assignInterfacesToValue(tx.Statement.assigns...)
}
return
}
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs)
}
}
// initialize with attrs, conds
if len(tx.Statement.attrs) > 0 {
tx.assignInterfacesToValue(tx.Statement.attrs...)
}
// initialize with attrs, conds
if len(tx.Statement.assigns) > 0 {
tx.assignInterfacesToValue(tx.Statement.assigns...)
}
return tx.Create(dest)
} else if len(db.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
assigns := map[string]interface{}{}
for _, expr := range exprs {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
assigns[column] = eq.Value
case clause.Column:
assigns[column.Name] = eq.Value
default:
}
}
}
return tx.Model(dest).Updates(assigns)
}
return db
}
// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Update(column string, value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value}
tx.callbacks.Update().Execute(tx)
return
}
// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Updates(values interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = values
tx.callbacks.Update().Execute(tx)
return
}
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value}
tx.Statement.SkipHooks = true
tx.callbacks.Update().Execute(tx)
return
}
func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = values
tx.Statement.SkipHooks = true
tx.callbacks.Update().Execute(tx)
return
}
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance()
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
}
tx.Statement.Dest = value
tx.callbacks.Delete().Execute(tx)
return
}
func (db *DB) Count(count *int64) (tx *DB) {
tx = db.getInstance()
if tx.Statement.Model == nil {
tx.Statement.Model = tx.Statement.Dest
defer func() {
tx.Statement.Model = nil
}()
}
if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
defer func() {
db.Statement.Clauses["SELECT"] = selectClause
}()
} else {
defer delete(tx.Statement.Clauses, "SELECT")
}
if len(tx.Statement.Selects) == 0 {
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
} else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") {
expr := clause.Expr{SQL: "count(1)"}
if len(tx.Statement.Selects) == 1 {
dbName := tx.Statement.Selects[0]
fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
if tx.Statement.Parse(tx.Statement.Model) == nil {
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
dbName = f.DBName
}
}
if tx.Statement.Distinct {
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
} else {
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
}
}
}
tx.Statement.AddClause(clause.Select{Expression: expr})
}
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
delete(db.Statement.Clauses, "ORDER BY")
defer func() {
db.Statement.Clauses["ORDER BY"] = orderByClause
}()
}
}
tx.Statement.Dest = count
tx.callbacks.Query().Execute(tx)
if tx.RowsAffected != 1 {
*count = tx.RowsAffected
}
return
}
func (db *DB) Row() *sql.Row {
tx := db.getInstance().InstanceSet("rows", false)
tx.callbacks.Row().Execute(tx)
row, ok := tx.Statement.Dest.(*sql.Row)
if !ok && tx.DryRun {
db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
}
return row
}
func (db *DB) Rows() (*sql.Rows, error) {
tx := db.getInstance().InstanceSet("rows", true)
tx.callbacks.Row().Execute(tx)
rows, ok := tx.Statement.Dest.(*sql.Rows)
if !ok && tx.DryRun && tx.Error == nil {
tx.Error = ErrDryRunModeUnsupported
}
return rows, tx.Error
}
// Scan scan value to a struct
func (db *DB) Scan(dest interface{}) (tx *DB) {
config := *db.Config
currentLogger, newLogger := config.Logger, logger.Recorder.New()
config.Logger = newLogger
tx = db.getInstance()
tx.Config = &config
if rows, err := tx.Rows(); err != nil {
tx.AddError(err)
} else {
defer rows.Close()
if rows.Next() {
tx.ScanRows(rows, dest)
} else {
tx.RowsAffected = 0
}
}
currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
return newLogger.SQL, tx.RowsAffected
}, tx.Error)
tx.Logger = currentLogger
return
}
// Pluck used to query single column from a model as a map
// var ages []int64
// db.Find(&users).Pluck("age", &ages)
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
tx = db.getInstance()
if tx.Statement.Model != nil {
if tx.Statement.Parse(tx.Statement.Model) == nil {
if f := tx.Statement.Schema.LookUpField(column); f != nil {
column = f.DBName
}
}
} else if tx.Statement.Table == "" {
tx.AddError(ErrModelValueRequired)
}
if len(tx.Statement.Selects) != 1 {
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
tx.Statement.AddClauseIfNotExists(clause.Select{
Distinct: tx.Statement.Distinct,
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
})
}
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
return
}
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
tx := db.getInstance()
if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
tx.AddError(err)
}
tx.Statement.Dest = dest
tx.Statement.ReflectValue = reflect.ValueOf(dest)
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem()
}
Scan(rows, tx, true)
return tx.Error
}
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction
if !db.DisableNestedTransaction {
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
db.RollbackTo(fmt.Sprintf("sp%p", fc))
}
}()
}
if err == nil {
err = fc(db.Session(&Session{}))
}
} else {
tx := db.Begin(opts...)
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
tx.Rollback()
}
}()
if err = tx.Error; err == nil {
err = fc(tx)
}
if err == nil {
err = tx.Commit().Error
}
}
panicked = false
return
}
// Begin begins a transaction
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
var (
// clone statement
tx = db.getInstance().Session(&Session{Context: db.Statement.Context})
opt *sql.TxOptions
err error
)
if len(opts) > 0 {
opt = opts[0]
}
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else {
err = ErrInvalidTransaction
}
if err != nil {
tx.AddError(err)
}
return tx
}
// Commit commit a transaction
func (db *DB) Commit() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Commit())
} else {
db.AddError(ErrInvalidTransaction)
}
return db
}
// Rollback rollback a transaction
func (db *DB) Rollback() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
if !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Rollback())
}
} else {
db.AddError(ErrInvalidTransaction)
}
return db
}
func (db *DB) SavePoint(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
db.AddError(savePointer.SavePoint(db, name))
} else {
db.AddError(ErrUnsupportedDriver)
}
return db
}
func (db *DB) RollbackTo(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
db.AddError(savePointer.RollbackTo(db, name))
} else {
db.AddError(ErrUnsupportedDriver)
}
return db
}
// Exec execute raw sql
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.SQL = strings.Builder{}
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
} else {
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
}
tx.callbacks.Raw().Execute(tx)
return
}

8
go.mod Normal file
View File

@ -0,0 +1,8 @@
module gorm.io/gorm
go 1.14
require (
github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.1.2
)

4
go.sum Normal file
View File

@ -0,0 +1,4 @@
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.2 h1:eVKgfIdy9b6zbWBMgFpfDPoAMifwSZagU9HmEU6zgiI=
github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=

441
gorm.go Normal file
View File

@ -0,0 +1,441 @@
package gorm
import (
"context"
"database/sql"
"fmt"
"sort"
"sync"
"time"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
)
// for Config.cacheStore store PreparedStmtDB key
const preparedStmtDBKey = "preparedStmt"
// Config GORM config
type Config struct {
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool
// NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer
// FullSaveAssociations full save associations
FullSaveAssociations bool
// Logger
Logger logger.Interface
// NowFunc the function to be used when creating a new timestamp
NowFunc func() time.Time
// DryRun generate sql without execute
DryRun bool
// PrepareStmt executes the given query in cached statement
PrepareStmt bool
// DisableAutomaticPing
DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating
DisableForeignKeyConstraintWhenMigrating bool
// DisableNestedTransaction disable nested transaction
DisableNestedTransaction bool
// AllowGlobalUpdate allow global update
AllowGlobalUpdate bool
// QueryFields executes the SQL query with all fields of the table
QueryFields bool
// CreateBatchSize default create batch size
CreateBatchSize int
// ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
// ConnPool db conn pool
ConnPool ConnPool
// Dialector database dialector
Dialector
// Plugins registered plugins
Plugins map[string]Plugin
callbacks *callbacks
cacheStore *sync.Map
}
func (c *Config) Apply(config *Config) error {
if config != c {
*config = *c
}
return nil
}
func (c *Config) AfterInitialize(db *DB) error {
if db != nil {
for _, plugin := range c.Plugins {
if err := plugin.Initialize(db); err != nil {
return err
}
}
}
return nil
}
type Option interface {
Apply(*Config) error
AfterInitialize(*DB) error
}
// DB GORM DB definition
type DB struct {
*Config
Error error
RowsAffected int64
Statement *Statement
clone int
}
// Session session config when create session with Session() method
type Session struct {
DryRun bool
PrepareStmt bool
NewDB bool
SkipHooks bool
SkipDefaultTransaction bool
DisableNestedTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
QueryFields bool
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
CreateBatchSize int
}
// Open initialize db session based on dialector
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
config := &Config{}
sort.Slice(opts, func(i, j int) bool {
_, isConfig := opts[i].(*Config)
_, isConfig2 := opts[j].(*Config)
return isConfig && !isConfig2
})
for _, opt := range opts {
if opt != nil {
if err := opt.Apply(config); err != nil {
return nil, err
}
defer func(opt Option) {
if errr := opt.AfterInitialize(db); errr != nil {
err = errr
}
}(opt)
}
}
if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
if err = d.Apply(config); err != nil {
return
}
}
if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{}
}
if config.Logger == nil {
config.Logger = logger.Default
}
if config.NowFunc == nil {
config.NowFunc = func() time.Time { return time.Now().Local() }
}
if dialector != nil {
config.Dialector = dialector
}
if config.Plugins == nil {
config.Plugins = map[string]Plugin{}
}
if config.cacheStore == nil {
config.cacheStore = &sync.Map{}
}
db = &DB{Config: config, clone: 1}
db.callbacks = initializeCallbacks(db)
if config.ClauseBuilders == nil {
config.ClauseBuilders = map[string]clause.ClauseBuilder{}
}
if config.Dialector != nil {
err = config.Dialector.Initialize(db)
}
preparedStmt := &PreparedStmtDB{
ConnPool: db.ConnPool,
Stmts: map[string]Stmt{},
Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100),
}
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
if config.PrepareStmt {
db.ConnPool = preparedStmt
}
db.Statement = &Statement{
DB: db,
ConnPool: db.ConnPool,
Context: context.Background(),
Clauses: map[string]clause.Clause{},
}
if err == nil && !config.DisableAutomaticPing {
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
err = pinger.Ping()
}
}
if err != nil {
config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
}
return
}
// Session create new db session
func (db *DB) Session(config *Session) *DB {
var (
txConfig = *db.Config
tx = &DB{
Config: &txConfig,
Statement: db.Statement,
Error: db.Error,
clone: 1,
}
)
if config.CreateBatchSize > 0 {
tx.Config.CreateBatchSize = config.CreateBatchSize
}
if config.SkipDefaultTransaction {
tx.Config.SkipDefaultTransaction = true
}
if config.AllowGlobalUpdate {
txConfig.AllowGlobalUpdate = true
}
if config.FullSaveAssociations {
txConfig.FullSaveAssociations = true
}
if config.Context != nil || config.PrepareStmt || config.SkipHooks {
tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx
}
if config.Context != nil {
tx.Statement.Context = config.Context
}
if config.PrepareStmt {
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
preparedStmt := v.(*PreparedStmtDB)
tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool,
Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts,
}
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
}
}
if config.SkipHooks {
tx.Statement.SkipHooks = true
}
if config.DisableNestedTransaction {
txConfig.DisableNestedTransaction = true
}
if !config.NewDB {
tx.clone = 2
}
if config.DryRun {
tx.Config.DryRun = true
}
if config.QueryFields {
tx.Config.QueryFields = true
}
if config.Logger != nil {
tx.Config.Logger = config.Logger
}
if config.NowFunc != nil {
tx.Config.NowFunc = config.NowFunc
}
return tx
}
// WithContext change current instance db's context to ctx
func (db *DB) WithContext(ctx context.Context) *DB {
return db.Session(&Session{Context: ctx})
}
// Debug start debug mode
func (db *DB) Debug() (tx *DB) {
return db.Session(&Session{
Logger: db.Logger.LogMode(logger.Info),
})
}
// Set store value with key into current db instance's context
func (db *DB) Set(key string, value interface{}) *DB {
tx := db.getInstance()
tx.Statement.Settings.Store(key, value)
return tx
}
// Get get value with key from current db instance's context
func (db *DB) Get(key string) (interface{}, bool) {
return db.Statement.Settings.Load(key)
}
// InstanceSet store value with key into current db instance's context
func (db *DB) InstanceSet(key string, value interface{}) *DB {
tx := db.getInstance()
tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value)
return tx
}
// InstanceGet get value with key from current db instance's context
func (db *DB) InstanceGet(key string) (interface{}, bool) {
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
}
// Callback returns callback manager
func (db *DB) Callback() *callbacks {
return db.callbacks
}
// AddError add error to db
func (db *DB) AddError(err error) error {
if db.Error == nil {
db.Error = err
} else if err != nil {
db.Error = fmt.Errorf("%v; %w", db.Error, err)
}
return db.Error
}
// DB returns `*sql.DB`
func (db *DB) DB() (*sql.DB, error) {
connPool := db.ConnPool
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}
if sqldb, ok := connPool.(*sql.DB); ok {
return sqldb, nil
}
return nil, ErrInvalidDB
}
func (db *DB) getInstance() *DB {
if db.clone > 0 {
tx := &DB{Config: db.Config, Error: db.Error}
if db.clone == 1 {
// clone with new statement
tx.Statement = &Statement{
DB: tx,
ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context,
Clauses: map[string]clause.Clause{},
Vars: make([]interface{}, 0, 8),
}
} else {
// with clone statement
tx.Statement = db.Statement.clone()
tx.Statement.DB = tx
}
return tx
}
return db
}
func Expr(expr string, args ...interface{}) clause.Expr {
return clause.Expr{SQL: expr, Vars: args}
}
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
var (
tx = db.getInstance()
stmt = tx.Statement
modelSchema, joinSchema *schema.Schema
)
if err := stmt.Parse(model); err == nil {
modelSchema = stmt.Schema
} else {
return err
}
if err := stmt.Parse(joinTable); err == nil {
joinSchema = stmt.Schema
} else {
return err
}
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
for _, ref := range relation.References {
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
f.DataType = ref.ForeignKey.DataType
f.GORMDataType = ref.ForeignKey.GORMDataType
if f.Size == 0 {
f.Size = ref.ForeignKey.Size
}
ref.ForeignKey = f
} else {
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
}
}
for name, rel := range relation.JoinTable.Relationships.Relations {
if _, ok := joinSchema.Relationships.Relations[name]; !ok {
rel.Schema = joinSchema
joinSchema.Relationships.Relations[name] = rel
}
}
relation.JoinTable = joinSchema
} else {
return fmt.Errorf("failed to found relation: %v", field)
}
return nil
}
func (db *DB) Use(plugin Plugin) error {
name := plugin.Name()
if _, ok := db.Plugins[name]; ok {
return ErrRegistered
}
if err := plugin.Initialize(db); err != nil {
return err
}
db.Plugins[name] = plugin
return nil
}

View File

@ -1,20 +0,0 @@
package gorm
import "database/sql"
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
type SQLCommon interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type sqlDb interface {
Begin() (*sql.Tx, error)
}
type sqlTx interface {
Commit() error
Rollback() error
}

63
interfaces.go Normal file
View File

@ -0,0 +1,63 @@
package gorm
import (
"context"
"database/sql"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
// Dialector GORM database dialector
type Dialector interface {
Name() string
Initialize(*DB) error
Migrator(db *DB) Migrator
DataTypeOf(*schema.Field) string
DefaultValueOf(*schema.Field) clause.Expression
BindVarTo(writer clause.Writer, stmt *Statement, v interface{})
QuoteTo(clause.Writer, string)
Explain(sql string, vars ...interface{}) string
}
// Plugin GORM plugin interface
type Plugin interface {
Name() string
Initialize(*DB) error
}
// ConnPool db conns pool interface
type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
// SavePointerDialectorInterface save pointer interface
type SavePointerDialectorInterface interface {
SavePoint(tx *DB, name string) error
RollbackTo(tx *DB, name string) error
}
type TxBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
type ConnPoolBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
}
type TxCommitter interface {
Commit() error
Rollback() error
}
// Valuer gorm valuer interface
type Valuer interface {
GormValue(context.Context, *DB) clause.Expr
}
type GetDBConnector interface {
GetDBConn() (*sql.DB, error)
}

View File

@ -1,223 +0,0 @@
package gorm
import (
"errors"
"fmt"
"reflect"
"strings"
)
// JoinTableHandlerInterface is an interface for how to handle many2many relations
type JoinTableHandlerInterface interface {
// initialize join table handler
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
// Table return join table's table name
Table(db *DB) string
// Add create relationship in join table for source and destination
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
// Delete delete relationship in join table for sources
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
// JoinWith query with `Join` conditions
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
// SourceForeignKeys return source foreign keys
SourceForeignKeys() []JoinTableForeignKey
// DestinationForeignKeys return destination foreign keys
DestinationForeignKeys() []JoinTableForeignKey
}
// JoinTableForeignKey join table foreign key struct
type JoinTableForeignKey struct {
DBName string
AssociationDBName string
}
// JoinTableSource is a struct that contains model type and foreign keys
type JoinTableSource struct {
ModelType reflect.Type
ForeignKeys []JoinTableForeignKey
}
// JoinTableHandler default join table handler
type JoinTableHandler struct {
TableName string `sql:"-"`
Source JoinTableSource `sql:"-"`
Destination JoinTableSource `sql:"-"`
}
// SourceForeignKeys return source foreign keys
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
return s.Source.ForeignKeys
}
// DestinationForeignKeys return destination foreign keys
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
return s.Destination.ForeignKeys
}
// Setup initialize a default join table handler
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
s.TableName = tableName
s.Source = JoinTableSource{ModelType: source}
s.Source.ForeignKeys = []JoinTableForeignKey{}
for idx, dbName := range relationship.ForeignFieldNames {
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
DBName: relationship.ForeignDBNames[idx],
AssociationDBName: dbName,
})
}
s.Destination = JoinTableSource{ModelType: destination}
s.Destination.ForeignKeys = []JoinTableForeignKey{}
for idx, dbName := range relationship.AssociationForeignFieldNames {
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
DBName: relationship.AssociationForeignDBNames[idx],
AssociationDBName: dbName,
})
}
}
// Table return join table's table name
func (s JoinTableHandler) Table(db *DB) string {
return s.TableName
}
func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
values := map[string]interface{}{}
for _, source := range sources {
scope := db.NewScope(source)
modelType := scope.GetModelStruct().ModelType
if s.Source.ModelType == modelType {
for _, foreignKey := range s.Source.ForeignKeys {
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
values[foreignKey.DBName] = field.Field.Interface()
}
}
} else if s.Destination.ModelType == modelType {
for _, foreignKey := range s.Destination.ForeignKeys {
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
values[foreignKey.DBName] = field.Field.Interface()
}
}
}
}
return values
}
// Add create relationship in join table for source and destination
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
scope := db.NewScope("")
searchMap := map[string]interface{}{}
// getSearchMap() cannot be used here since the source and destination
// model types may be identical
sourceScope := db.NewScope(source)
for _, foreignKey := range s.Source.ForeignKeys {
if field, ok := sourceScope.FieldByName(foreignKey.AssociationDBName); ok {
searchMap[foreignKey.DBName] = field.Field.Interface()
}
}
destinationScope := db.NewScope(destination)
for _, foreignKey := range s.Destination.ForeignKeys {
if field, ok := destinationScope.FieldByName(foreignKey.AssociationDBName); ok {
searchMap[foreignKey.DBName] = field.Field.Interface()
}
}
var assignColumns, binVars, conditions []string
var values []interface{}
for key, value := range searchMap {
assignColumns = append(assignColumns, scope.Quote(key))
binVars = append(binVars, `?`)
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
values = append(values, value)
}
for _, value := range values {
values = append(values, value)
}
quotedTable := scope.Quote(handler.Table(db))
sql := fmt.Sprintf(
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
quotedTable,
strings.Join(assignColumns, ","),
strings.Join(binVars, ","),
scope.Dialect().SelectFromDummyTable(),
quotedTable,
strings.Join(conditions, " AND "),
)
return db.Exec(sql, values...).Error
}
// Delete delete relationship in join table for sources
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
var (
scope = db.NewScope(nil)
conditions []string
values []interface{}
)
for key, value := range s.getSearchMap(db, sources...) {
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
values = append(values, value)
}
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
}
// JoinWith query with `Join` conditions
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
var (
scope = db.NewScope(source)
tableName = handler.Table(db)
quotedTableName = scope.Quote(tableName)
joinConditions []string
values []interface{}
)
if s.Source.ModelType == scope.GetModelStruct().ModelType {
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
for _, foreignKey := range s.Destination.ForeignKeys {
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
}
var foreignDBNames []string
var foreignFieldNames []string
for _, foreignKey := range s.Source.ForeignKeys {
foreignDBNames = append(foreignDBNames, foreignKey.DBName)
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
foreignFieldNames = append(foreignFieldNames, field.Name)
}
}
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
var condString string
if len(foreignFieldValues) > 0 {
var quotedForeignDBNames []string
for _, dbName := range foreignDBNames {
quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName)
}
condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
values = append(values, toQueryValues(keys))
} else {
condString = fmt.Sprintf("1 <> 1")
}
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
Where(condString, toQueryValues(foreignFieldValues)...)
}
db.Error = errors.New("wrong source type for join table handler")
return db
}

View File

@ -1,117 +0,0 @@
package gorm_test
import (
"fmt"
"strconv"
"testing"
"time"
"github.com/jinzhu/gorm"
)
type Person struct {
Id int
Name string
Addresses []*Address `gorm:"many2many:person_addresses;"`
}
type PersonAddress struct {
gorm.JoinTableHandler
PersonID int
AddressID int
DeletedAt *time.Time
CreatedAt time.Time
}
func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error {
foreignPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(foreignValue).PrimaryKeyValue()))
associationPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(associationValue).PrimaryKeyValue()))
if result := db.Unscoped().Model(&PersonAddress{}).Where(map[string]interface{}{
"person_id": foreignPrimaryKey,
"address_id": associationPrimaryKey,
}).Update(map[string]interface{}{
"person_id": foreignPrimaryKey,
"address_id": associationPrimaryKey,
"deleted_at": gorm.Expr("NULL"),
}).RowsAffected; result == 0 {
return db.Create(&PersonAddress{
PersonID: foreignPrimaryKey,
AddressID: associationPrimaryKey,
}).Error
}
return nil
}
func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error {
return db.Delete(&PersonAddress{}).Error
}
func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB {
table := pa.Table(db)
return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
}
func TestJoinTable(t *testing.T) {
DB.Exec("drop table person_addresses;")
DB.AutoMigrate(&Person{})
DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{})
address1 := &Address{Address1: "address 1"}
address2 := &Address{Address1: "address 2"}
person := &Person{Name: "person", Addresses: []*Address{address1, address2}}
DB.Save(person)
DB.Model(person).Association("Addresses").Delete(address1)
if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 {
t.Errorf("Should found one address")
}
if DB.Model(person).Association("Addresses").Count() != 1 {
t.Errorf("Should found one address")
}
if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 {
t.Errorf("Found two addresses with Unscoped")
}
if DB.Model(person).Association("Addresses").Clear(); DB.Model(person).Association("Addresses").Count() != 0 {
t.Errorf("Should deleted all addresses")
}
}
func TestEmbeddedMany2ManyRelationship(t *testing.T) {
type EmbeddedPerson struct {
ID int
Name string
Addresses []*Address `gorm:"many2many:person_addresses;"`
}
type NewPerson struct {
EmbeddedPerson
ExternalID uint
}
DB.Exec("drop table person_addresses;")
DB.AutoMigrate(&NewPerson{})
address1 := &Address{Address1: "address 1"}
address2 := &Address{Address1: "address 2"}
person := &NewPerson{ExternalID: 100, EmbeddedPerson: EmbeddedPerson{Name: "person", Addresses: []*Address{address1, address2}}}
if err := DB.Save(person).Error; err != nil {
t.Errorf("no error should return when save embedded many2many relationship, but got %v", err)
}
if err := DB.Model(person).Association("Addresses").Delete(address1).Error; err != nil {
t.Errorf("no error should return when delete embedded many2many relationship, but got %v", err)
}
association := DB.Model(person).Association("Addresses")
if count := association.Count(); count != 1 || association.Error != nil {
t.Errorf("Should found one address, but got %v, error is %v", count, association.Error)
}
if association.Clear(); association.Count() != 0 {
t.Errorf("Should deleted all addresses")
}
}

119
logger.go
View File

@ -1,119 +0,0 @@
package gorm
import (
"database/sql/driver"
"fmt"
"log"
"os"
"reflect"
"regexp"
"strconv"
"time"
"unicode"
)
var (
defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
sqlRegexp = regexp.MustCompile(`\?`)
numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`)
)
func isPrintable(s string) bool {
for _, r := range s {
if !unicode.IsPrint(r) {
return false
}
}
return true
}
var LogFormatter = func(values ...interface{}) (messages []interface{}) {
if len(values) > 1 {
var (
sql string
formattedValues []string
level = values[0]
currentTime = "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m"
source = fmt.Sprintf("\033[35m(%v)\033[0m", values[1])
)
messages = []interface{}{source, currentTime}
if level == "sql" {
// duration
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
// sql
for _, value := range values[4].([]interface{}) {
indirectValue := reflect.Indirect(reflect.ValueOf(value))
if indirectValue.IsValid() {
value = indirectValue.Interface()
if t, ok := value.(time.Time); ok {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05")))
} else if b, ok := value.([]byte); ok {
if str := string(b); isPrintable(str) {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
} else {
formattedValues = append(formattedValues, "'<binary>'")
}
} else if r, ok := value.(driver.Valuer); ok {
if value, err := r.Value(); err == nil && value != nil {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
} else {
formattedValues = append(formattedValues, "NULL")
}
} else {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
}
} else {
formattedValues = append(formattedValues, "NULL")
}
}
// differentiate between $n placeholders or else treat like ?
if numericPlaceHolderRegexp.MatchString(values[3].(string)) {
sql = values[3].(string)
for index, value := range formattedValues {
placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1)
sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1")
}
} else {
formattedValuesLength := len(formattedValues)
for index, value := range sqlRegexp.Split(values[3].(string), -1) {
sql += value
if index < formattedValuesLength {
sql += formattedValues[index]
}
}
}
messages = append(messages, sql)
messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned "))
} else {
messages = append(messages, "\033[31;1m")
messages = append(messages, values[2:]...)
messages = append(messages, "\033[0m")
}
}
return
}
type logger interface {
Print(v ...interface{})
}
// LogWriter log writer interface
type LogWriter interface {
Println(v ...interface{})
}
// Logger default logger
type Logger struct {
LogWriter
}
// Print format & print log
func (logger Logger) Print(values ...interface{}) {
logger.Println(LogFormatter(values...)...)
}

Some files were not shown because too many files have changed in this diff Show More