Compare commits

...

1046 Commits

Author SHA1 Message Date
贾一饼
49eaeacb89
optimize: field.ReflectValueOf (#7530)
Some checks failed
tests (sqlite, mysql 5.7/8/9, mariadb, postgres 13 to latest, tidb, gaussdb, sqlserver) and golangci-lint failed on push; the Stale and issue-closing workflows succeeded
2025-07-23 13:02:49 +08:00
贾一饼
52b4744410
optimize: performance optimization (#7526) 2025-07-22 18:32:28 +08:00
Jinzhu
9af6d510b5 Fix query when map keys include table-qualified column names, close #7507 2025-07-22 14:21:04 +08:00
Jinzhu
c63374f5d1 Don't request LastInsertID from database if not necessary, close #7469 2025-07-21 17:55:20 +08:00
jc
b9c7e562b0
fix(schema): check the hook function parameter type (#7468)
* fix(schema): Check the callback function parameter type

* fix log

* fix
2025-07-21 17:06:16 +08:00
Riseif
985940f0d8
should check inner condition length (#7512) 2025-07-21 11:57:12 +08:00
moseszane168
991c2d4891
Add GaussDB Database Support (#7508)
* support gaussdb

* use github CI

* change function name

* use gorm.io/driver/gaussdb

---------

Co-authored-by: bing.ma <bing.ma@daocloud.io>
2025-07-21 10:46:58 +08:00
Jinzhu
751a6dde7a
Call after initialize for gorm.Config (#7518) 2025-07-15 12:05:03 +08:00
贾一饼
2f4925e017
A little optimization for field.ValueOf (#7499)
Co-authored-by: 贾一饼 <Boyang.Liang@apulis.com>
2025-07-07 11:15:10 +08:00
Eshan-Jogwar
1e8baf5459
fixes #7486 (#7492)
* fixes #7486

* Added A test case for subset changes of model

* completed the test file for check_subset_model_change_test.go
2025-06-25 11:11:08 +08:00
Salent Olivick
842ee527eb
fix decimal migrate error.(#7450) (#7450)
Signed-off-by: Chise1 <chise123@live.com>
2025-06-06 10:35:01 +08:00
enomotodev
23c0d7cf05
test: update MySQL test matrix to use official images and add 9.0, 8.4 versions (#7476)
* test: update MySQL test matrix to use official images and add 9.0, 8.4 versions

* test: use major version tags for MySQL test matrix
2025-06-06 10:10:23 +08:00
Jinzhu
718eae4fdd fix tests for mysql 9.0 2025-06-05 19:34:13 +08:00
Jinzhu
49b01a3e93 Fix Generics Scan, close https://github.com/go-gorm/playground/pull/803 2025-05-29 14:23:57 +08:00
Jinzhu
c44405a25b
Implement Generics API (#7424)
* Implement Generics API

* Add more generics tests

* Add more tests and Take method

* use delayed‑ops pipeline for generics API

* fix generics tests for mysql

* Support SubQuery for Generics

* Add clause.JoinTable helper method

* Fix golangci-lint error

* Complete the design and implementation of generic version Join

* improve generics version Joins support

* allow configuring select/omit columns for joins via subqueries

* finish generic version Preload

* handle error of generics Joins/Preload

* fix tests

* Add LimitPerRecord for generic version Preload

* fix tests for mysql 5.7

* test for nested generic version Join/Preload

* Add WithResult support for generics API

* test reuse generics db conditions

* fix data race

* remove ExampleLRU test

* Add default transaction timeout support

* fix test
2025-05-25 15:40:40 +08:00
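
The generics API above enables type-safe CRUD without interface{} destinations. Below is a minimal usage sketch assuming the gorm.G[T] entrypoint and context-first method signatures documented for recent GORM releases; the sqlite driver and the User model are illustrative.

```go
package main

import (
	"context"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

type User struct {
	ID   uint
	Name string
	Age  int
}

func main() {
	db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	if err := db.AutoMigrate(&User{}); err != nil {
		panic(err)
	}

	ctx := context.Background()

	// Typed create: no need to pass the model through interface{}.
	if err := gorm.G[User](db).Create(ctx, &User{Name: "jinzhu", Age: 30}); err != nil {
		panic(err)
	}

	// Typed query: Find returns []User directly.
	users, err := gorm.G[User](db).Where("age > ?", 18).Find(ctx)
	if err != nil {
		panic(err)
	}
	_ = users
}
```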
Name
751c1d6b45
perf(schema): avoid redundant strings.ToLower call (#7464)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2025-05-25 09:27:21 +08:00
codingplz
8e7ab46c1b
fix: return init dialector error (#7379)
* fix: return init dialector error

* mock defer

* fix: skip AfterInitialize

---------

Co-authored-by: wenyazhou.13 <wenyazhou.13@bytedance.com>
2025-05-22 10:53:47 +08:00
Name
e3037e4ef0
perf: break early on match failure in ParseConstraint (#7402)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2025-05-22 10:49:19 +08:00
pipipipipip
1204330419
feat: error message show field name (#7452)
* feat: error message show field name

* feat: failed to parse field

* feat: failed to parse field

---------

Co-authored-by: zyp <zyp>
2025-05-21 11:13:31 +08:00
Name
9703eb775f
perf: use strings.IndexByte to replace strings.Index (#7454)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2025-05-21 10:35:56 +08:00
Name
1c966e0d25
perf: use strings.Cut to replace strings.SplitN (#7455)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2025-05-21 10:35:23 +08:00
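
These two micro-optimizations replace general string helpers with cheaper ones. A small standalone comparison of the patterns (not GORM's actual call sites):

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	tag := "column:user_name"

	// strings.SplitN allocates a slice just to get two halves...
	parts := strings.SplitN(tag, ":", 2)
	fmt.Println(parts[0], parts[1])

	// ...while strings.Cut returns them directly with no slice allocation.
	key, value, _ := strings.Cut(tag, ":")
	fmt.Println(key, value)

	// For a single-byte separator, strings.IndexByte skips the generic
	// substring search that strings.Index performs.
	fmt.Println(strings.IndexByte(tag, ':'), strings.Index(tag, ":"))
}
```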
Jinzhu
e5b867e785 remove unnecessary session-level configuration for prepared statements 2025-05-07 14:56:49 +08:00
iTanken
8c4e8e2d2a
fix: int type variable defaultMaxSize overflows in 32-bit environment (#7439)
Refs: #7435
2025-04-27 14:05:16 +08:00
Zhaodong Xie
a827495be1
Preparestmt use LRU Map instead default map (#7435)
* Support LRU eviction for the prepared statement cache

* Support LRU eviction for the prepared statement cache

* Support LRU eviction for the prepared statement cache

* Use only LRU

* Use only LRU

* Use only LRU

* Use only LRU

* Use only LRU

* Use only LRU

* Use only LRU

* Use only LRU

* Use only LRU

* change const export

* Add stmt_store

* refact prepare stmt store

* Rename lru store

* change const export

* ADD UT

* format code and add session level prepare stmt config

* code format according to golinter ci

* ADD UT

---------

Co-authored-by: xiezhaodong <xiezhaodong@bytedance.com>
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2025-04-25 16:22:26 +08:00
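
The change above replaces the unbounded prepared-statement map with an LRU-evicting store. A minimal sketch of such a store, with illustrative names rather than GORM's actual stmt_store:

```go
package main

import (
	"container/list"
	"database/sql"
	"fmt"
)

type entry struct {
	sql  string
	stmt *sql.Stmt
}

// lruStmtStore keeps at most capacity prepared statements and evicts the
// least recently used one, closing it so it does not leak.
type lruStmtStore struct {
	capacity int
	order    *list.List               // front = most recently used
	items    map[string]*list.Element // SQL text -> list element holding an entry
}

func newLRUStmtStore(capacity int) *lruStmtStore {
	return &lruStmtStore{capacity: capacity, order: list.New(), items: map[string]*list.Element{}}
}

func (s *lruStmtStore) Get(query string) (*sql.Stmt, bool) {
	if el, ok := s.items[query]; ok {
		s.order.MoveToFront(el) // mark as recently used
		return el.Value.(*entry).stmt, true
	}
	return nil, false
}

func (s *lruStmtStore) Put(query string, stmt *sql.Stmt) {
	if el, ok := s.items[query]; ok {
		s.order.MoveToFront(el)
		el.Value.(*entry).stmt = stmt
		return
	}
	s.items[query] = s.order.PushFront(&entry{sql: query, stmt: stmt})
	if s.order.Len() > s.capacity {
		oldest := s.order.Back()
		s.order.Remove(oldest)
		e := oldest.Value.(*entry)
		delete(s.items, e.sql)
		e.stmt.Close() // evicted statements must be closed to avoid leaks
	}
}

func main() {
	store := newLRUStmtStore(2)
	store.Put("SELECT 1", nil)
	fmt.Println(store.Get("SELECT 1"))
}
```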
Jinzhu
489a563293 only check new issues for golangci linter 2025-04-17 15:30:17 +08:00
Jinzhu
42bd4f603c
use golangci replace reviewdog (#7426)
* use golangci replace reviewdog

* Update golangci config
2025-04-17 11:55:13 +08:00
a631807682
a9d27293de test: mssql ci 2025-03-11 15:56:46 +08:00
a631807682
3876ffe4bb test: mssql ci 2025-03-11 15:56:46 +08:00
Jinzhu
ee3b549d7d Update tests script 2025-03-09 15:35:17 +08:00
Aman
9f273777f5
fix deprecated reflect.PtrTo reflect.PointerTo usage (#7366)
* fix deprecated reflect.PtrTo reflect.PointerTo usage

* replace all deprecated reflect.PtrTo reflect.PointerTo usage
2025-02-13 14:16:26 +08:00
Vladimir Avtsenov
9ca84b3dde
fix concurrent map writes (#7298) 2025-01-12 19:49:06 +08:00
Max Katz
86b1d22911
chore: update copyright year (#7332) 2025-01-12 18:32:39 +08:00
Evyatar Yaffe
fed49230cb
Enhance db.Scan with ParamsFilter - Issue 7336 - Suggestion (#7337) 2025-01-12 18:19:28 +08:00
aviyam181199
8503287ca4
Fixed Empty Returning Clause Merge Bug (#7339) 2025-01-12 18:18:04 +08:00
nowindexman
4ef3af10ed
feat: Capitalize the priority field of IndexOption so that other systems can access this field from outside the package. (#7342) 2025-01-12 18:16:48 +08:00
Bennett Amodio
f482f25c71
fix: deterministic index ordering when migrating (#7208)
Issue: We observed that, when creating a database based on the same
gORM schema multiple times, indexes could appear in different orders,
hurting determinism for use-cases like schema comparison.

In order to fix this, it's simple to switch the ParseIndexes function
to return a list of indices rather than a map, so the callers will
iterate in deterministic order.
2024-12-06 10:27:44 +08:00
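
The fix works because Go randomizes map iteration order, so returning a sorted slice is what makes callers deterministic. A sketch of the idea; GORM's real ParseIndexes works on schema fields and has a different signature:

```go
package main

import (
	"fmt"
	"sort"
)

type Index struct {
	Name   string
	Fields []string
}

// parseIndexes returns indexes as a sorted slice instead of a map, so callers
// iterate in a deterministic order (Go randomizes map iteration order).
func parseIndexes(byName map[string]Index) []Index {
	names := make([]string, 0, len(byName))
	for name := range byName {
		names = append(names, name)
	}
	sort.Strings(names)

	indexes := make([]Index, 0, len(byName))
	for _, name := range names {
		indexes = append(indexes, byName[name])
	}
	return indexes
}

func main() {
	m := map[string]Index{
		"idx_users_email": {Name: "idx_users_email", Fields: []string{"email"}},
		"idx_users_name":  {Name: "idx_users_name", Fields: []string{"name"}},
	}
	for _, idx := range parseIndexes(m) {
		fmt.Println(idx.Name, idx.Fields)
	}
}
```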
Jinzhu
6bfccf8afa Refactor all tests script 2024-11-21 17:03:31 +08:00
Ivan Ryabov
49bbaa637f
use map look-up for indexes (#7242) 2024-11-14 17:41:43 +08:00
홍성욱
b0d70a26d1
[#6372] Fixed nullable constraint bug for columns during auto migration (#7269)
* [#6372] Fixed nullable constraint bug for columns during auto migration

* [#6372] fix comment

* [#6372] Add test code

* [#6372] Add test code

* [#6372] Fix failed test case

* [#6372] Fix failed test case

* [#6372] wip

* [#6372] wip

* [#6372] wip

* [#6372] wip
2024-11-14 17:40:18 +08:00
omid fth
deceebfab8
Create CODE_OF_CONDUCT.md (#7240) 2024-10-17 14:18:13 +08:00
Yidi Sprei
52e3b353eb
refactor(workflow): update release workflow to enhance automation (#7224)
Replaced the old release workflow with a new setup using Release
Drafter. This refactor allows for more detailed release notes by
categorizing changes and automatically generating release drafts.
The new workflow triggers on semantic version tags and improves
permissions management. This change enhances the release process
by providing better documentation and automation.
2024-10-09 19:31:04 +08:00
Hansu Park
8020e8c166
refactor: improve logging for unimplemented ErrorTranslator in TranslateError config (#7225) 2024-10-09 19:29:48 +08:00
Yidi Sprei
62bd0b9331
Add GitHub Actions workflow to automate release creation on tagged pushes (#7209)
* feat(workflows): add GitHub Action to create release on new tag

This workflow automates the release creation process whenever a new
tag is pushed to the repository. It checks if a release for the tag
already exists and creates one if it doesn't, enhancing the release
management and streamlining the deployment process.

* test

* fix(workflow): improve release existence check in release-on-tag.yml

Refactor the script to improve checking for existing releases by tag.
Return an object instead of using core.setOutput to streamline the
workflow logic. Also, set result-encoding to string for better
compatibility.

* fix(workflow): correct output handling in release-on-tag.yml

Corrected the output handling in the 'release-on-tag.yml' workflow file
by changing 'result-encoding' to 'outputs'. This ensures that the step
correctly checks if a release exists before attempting to create a new
one, thereby avoiding potential errors during the release process.
Added a blank line for better readability.

* fix(workflow): correct output setting in release-on-tag workflow

Refactored how the release_exists output is set to be compatible with
GitHub Actions syntax. This ensures the workflow reliably detects if
a release already exists, improving the robustness of the release
process.

* fix(release): correct output handling in release-on-tag workflow

Improved the way outputs are managed in the 'check_release' step by
returning values instead of direct assignments. This change ensures
better handling of release existence checks and improves code
readability. Added 'result-encoding' to specify string encoding for
results.

* fix(workflow): correct release existence check and add debug output

Refactored the release existence check to return a simple boolean
string ('true'/'false') rather than an object. Added a step to print
the release existence status for debugging purposes. This ensures
correct conditional evaluation and aids in troubleshooting workflow
issues.

* fix(workflow): simplify release process by removing redundant checks

The release-on-tag workflow has been streamlined by eliminating the
redundant steps checking for existing releases. This change reduces
complexity and speeds up the release process by directly creating a
release on every tag push without prior existence verification.
2024-09-30 14:18:13 +08:00
Omkar P
c6ac54812a
Use official SQL Server docker image for tests (#7205)
* Use official SQL Server docker image for tests

* Try with tag 2019-latest instead of latest

* Use platform ubuntu-20.04 for SQL Server

* Switch to 2019-CU18-ubuntu-20.04

* Check with 2022-latest image tag and ubuntu-latest platform

* Update health-cmd

* Try sqlcmd without -N -C

* Re-include -N -C, try with ubuntu-20.04

* Try ubuntu-20.04 without -N -C (last trial)

* Finalize working config

* Remove unused env variables
2024-09-30 11:21:19 +08:00
Jinzhu
68434b76eb fix test script with latest docker release 2024-09-18 20:03:35 +08:00
Kazuki Isogai
c2515ce260
feat: remove version top-level element and rename. (#7086) 2024-09-18 19:42:52 +08:00
Leo Sjöberg
7f75b12bb2
Generate unique savepoint names for nested transactions (#7174)
* Generate unique savepoint names

* Add a test for deeply nested wrapped transactions
2024-09-14 20:58:29 +08:00
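
A sketch of the idea behind unique savepoint names: each nesting level gets a fresh name so deeply nested wrapped transactions never release each other's savepoints. The naming scheme is illustrative, not GORM's exact format:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// savePointCounter hands out a unique suffix per savepoint, so nested
// transactions cannot accidentally reuse an outer level's savepoint name.
var savePointCounter uint64

func nextSavePointName() string {
	return fmt.Sprintf("gorm_sp_%d", atomic.AddUint64(&savePointCounter, 1))
}

func main() {
	fmt.Println(nextSavePointName(), nextSavePointName()) // gorm_sp_1 gorm_sp_2
}
```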
abhijeet45
0daaf1747c
fix: AfterQuery uses a safer right trim while clearing the FROM clause's joins added as part of https://github.com/go-gorm/gorm/pull/7027 (#7153)
Co-authored-by: Abhijeet Bhowmik <abhijeet.bhowmik@cambiumnetworks.com>
2024-08-22 19:03:42 +08:00
ivila
0dbfda5d7e
fix memory leaks in PrepareStatementDB (#7142)
* fix memory leaks in PrepareStatementDB

* Fix CR:
1) Fix potential Segmentation Fault in Reset function
2) Setting db.Stmts to nil map when Close to avoid further using

* Add Test:
1) TestPreparedStmtConcurrentReset
2) TestPreparedStmtConcurrentClose

* Fix test, create new connection to keep away from other tests

---------

Co-authored-by: Zehui Chen <zehui@ssc-hn.com>
2024-08-22 19:02:05 +08:00
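
A minimal sketch of the idea behind this fix: Reset closes every cached statement and installs a fresh map, and Close leaves the store unusable instead of pointing at stale statements. Type and field names are illustrative, not GORM's actual PreparedStmtDB:

```go
package main

import (
	"database/sql"
	"sync"
)

type stmtStore struct {
	mu    sync.Mutex
	stmts map[string]*sql.Stmt
}

// Reset closes every cached statement and installs a fresh map, so previously
// prepared statements are released instead of leaking.
func (s *stmtStore) Reset() {
	s.mu.Lock()
	defer s.mu.Unlock()
	for _, stmt := range s.stmts {
		stmt.Close()
	}
	s.stmts = make(map[string]*sql.Stmt)
}

// Close releases all statements and sets the map to nil so further use fails
// fast rather than silently caching into a dead store.
func (s *stmtStore) Close() {
	s.mu.Lock()
	defer s.mu.Unlock()
	for _, stmt := range s.stmts {
		stmt.Close()
	}
	s.stmts = nil
}

func main() {
	store := &stmtStore{stmts: map[string]*sql.Stmt{}}
	store.Reset()
	store.Close()
}
```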
enomotodev
4a50b36f63
ci: Add PostgreSQL 14 and 15 to GitHub Actions matrix (#7081)
* ci: Add PostgreSQL 14 and 15 to GitHub Actions matrix

* ci: Remove older PostgreSQL versions from test matrix
2024-06-25 10:36:06 +08:00
molon
11c4331058
feat: add MapColumns method (#6901)
* add MapColumns method

* fix MapColumns desc

* add TestMapColumns
2024-06-24 17:42:59 +08:00
Jinzhu
8a0af58cc5 fix map fields with clickhouse driver 2024-06-20 20:19:31 +08:00
Jinzhu
4f6291154b Allow to support other field types 2024-06-20 16:45:38 +08:00
Sérgio Prata Almeida
109f239fae
add DB level propagation for the Unscoped flag (#7007)
* adds PropagateUnscoped to db Config

* adds PropagateUnscoped test

* adds PropagateUnscoped to Session and sets it accordingly
2024-06-17 11:59:06 +08:00
Jinzhu
79bf7f92ed fix CI for sqlserver 2024-06-17 11:58:13 +08:00
Jinzhu
3d09f7947f only listen local port 2024-06-13 18:45:02 +08:00
Waleed Masoom
73a988ceb2
fix(scan): update Scan function to reset structs to zero values for each scan (#7061)
Co-authored-by: waleed.masoom <waleed.masoom@wheniwork.com>
2024-06-12 18:57:36 +08:00
Emilien
05167fd591
fix: use reflect.Append when preloading nested associations (#7014)
Co-authored-by: Emilien Kofman <emilien.kofman@miimosa.com>
2024-06-12 18:52:33 +08:00
Sergei Sadov
78c6dfd712
Fix association replace non-addressable panic (#7012)
* Fix association replace non-addressable panic

* Fix tests

* Add has one panic test

---------

Co-authored-by: sgsv <->
2024-06-12 18:49:45 +08:00
Nico Schäfer
3fe7fcf356
fix: unsupported data on nested joins with preloads (#6957)
* fix: `unsupported data` on nested joins with preloads

* Add test case for pointer join with nested preloads

* Fix tests
2024-06-12 18:00:47 +08:00
贾一饼
9c4070ed19
fix: AfterQuery should clear FROM Clause's Joins rather than the Statement (#7027) 2024-06-12 17:51:44 +08:00
supergem3000
49d524aaea
feat: chainable order support clause.OrderBy (#7054)
* feat: chainable order support clause.OrderBy

* indent
2024-06-12 17:47:34 +08:00
Jinzhu
49d94c173c upgrade github action tests template 2024-06-12 17:24:34 +08:00
Ryuji Kokubu
0f105ec163
fix: strings.Title -> cases.Title because strings.Title is deprecated (#6999)
* add: add cases library

* fix: strings.Title -> cases.Title

* run goimports to solve the error
2024-06-12 11:46:59 +08:00
hakusai22
5e599a07ec
fix: typo (#7003)
* fix: typo

* fix: covered
2024-05-08 12:07:58 +08:00
Wenli Looi
9d370bcb3e
Fix handling of unknown column types (#6540) 2024-04-26 17:53:11 +08:00
PiexlMax(奇淼
78920199f0
Fix panic bug in migrator due to lack of nil check for stmt.Schema (#6932) 2024-04-26 15:15:49 +08:00
Anıl Şenay
ac59252327
Add new error for "Violation Check Constraint" (#6992) 2024-04-26 10:53:17 +08:00
Cr
207f1ac68f
fix: not clause with or condition (#6984) 2024-04-25 20:22:53 +08:00
Cr
85299bfca7
perf: merge nested preload query when using join (#6990)
* perf: merge nested preload query

* fix: preload test
2024-04-25 20:21:03 +08:00
Jinzhu
5553ff3dcb downgrade mssql driver 2024-04-25 20:12:15 +08:00
kkocdko
bc49365de2
Faster utils.FileWithLineNum (#6981)
* faster FileWithLineNum

* tweak caller skip count
2024-04-22 14:43:02 +08:00
Ivan Chavez
d0b4ceb726
Added comment describing Unscoped() method (#6969) 2024-04-17 11:38:55 +08:00
yetone
9a61ef2af8
fix: duplicated preload (#6948) 2024-04-15 11:20:20 +08:00
Jinzhu
1e13fd7543 Fix duplicated columns in INSERT SQL for some fields with default value 2024-04-08 11:29:55 +08:00
hjwblog.com
1b48aa072d
feat: prepare_stmt support ping (#6924)
* feat: prepare_stmt support ping

* feat: prepare_stmt tx support ping
2024-03-28 16:47:39 +08:00
snackmgmg
26195e6d16
fix: remove callback from callbacks if Remove() called (#6916)
* fix: remove callback from callbacks if Remove() called

* reduce number of loops

* remove unnecessary blank line
2024-03-26 11:33:36 +08:00
givemeafish
956f7ce843
fix: 'type XXXX int' will print wrong sql to terminal (#6917)
Co-authored-by: 王泽平 <zeping.wang@yo-star.com>
2024-03-21 16:00:02 +08:00
Jinzhu
0d6c5345f3 Don't close prepared stmt for normal db error 2024-03-21 15:55:43 +08:00
Jinzhu
57603882ea Only close bad conn prepared stmt 2024-03-20 19:47:20 +08:00
Jinzhu
81536f823c Fix insert id into map results, fix #6812 2024-03-19 11:50:28 +08:00
Jinzhu
1b0aa802df Fix AutoMigrate for bool fields with default value 2024-03-18 19:24:16 +08:00
Jinzhu
e0c3be03fb Fix tests in local 2024-03-18 16:28:46 +08:00
jessetang
303de6e7c8
chore: optimize regEnLetterAndMidline regular (#6908)
* chore: optimize regular

* fix
2024-03-18 15:33:54 +08:00
Jinghao Lu
f7ebf049da
fix(create): fix insert column order (#6855)
* fix(create): fix insert column order

* chore: add ConvertToCreateValues ut for Slice case

* fix: remove testify dependency

---------

Co-authored-by: lujinghao <lujinghao@bytedance.com>
2024-03-18 13:48:42 +08:00
jessetang
ab89d54d87
chore: UnixNano convert to UnixMilli (#6907) 2024-03-18 13:44:55 +08:00
Jinzhu
281f3e369a Fix constraint name regexp 2024-03-18 11:32:30 +08:00
jessetang
7b1fb0bd73
fix(scan): array element is set to a zero value (#6890)
* fix(scan): array element is set to a zero value

* add test

* fix test

* optimization
2024-03-15 14:14:48 +08:00
black-06
e4e23d26d2
fix: nested preload with join panic when find (#6877) 2024-03-09 21:27:19 +08:00
jessetang
c4c9aa45e3
fix(scan.go): reflect.MakeSlice passes in the reflect.Array type (#6880) 2024-03-09 17:39:01 +08:00
Cr
9efae659cb
test: namer identifier length (#6872) 2024-03-09 17:31:28 +08:00
hishope
f17a75242e
fix some typos in tests

Signed-off-by: hishope <csqiye@126.com>
2024-03-07 16:19:17 +08:00
tsuba3
3e2c4fc446
Fix regression in db.Not introduced in v1.25.6. (#6844)
* Fix regression in db.Not introduced in 940358e.

* Fix
2024-03-05 10:23:51 +08:00
Chef
f118e55db5
Add unittest test helper function ConvertSliceOfMapToValuesForCreate (#6854) 2024-03-05 10:22:57 +08:00
Chef
52404cddbb
CHORE add unittest test function ConvertMapToValueForCreate (#6846)
* CHORE add unittest test function  ConvertMapToValueForCreate

* CHORE move the test cases located in the files convert_map_test.go and visit_map_test.go into the file helper_test.go.
2024-02-27 10:48:04 +08:00
M Dmitry
d81ae6f701
Fixed: panic on nullable value with multiple foreign key usage (#6839)
See: https://github.com/go-gorm/playground/pull/537
2024-02-19 11:42:25 +08:00
black-06
8fb9a31775
refactor: part 2 of distinguish between Unique and UniqueIndex (#6822) 2024-02-06 19:48:40 +08:00
jasonchuan
9514d5f9e6
let limit and offset use bind parameter (#6806)
* let limit and offset use bind parameter

* format

* format limt_test

* try again

* fix test case fro connpool

* add driverName for postgres; otherwise a wrong stmt var called pgx.QueryExecModeSimpleProtocol will be added, causing the SQL with limit to need 1 parameter but be given two.

* delete trunk files

* restore the test_test.go

* restore test_test.go

* driver/postgres->v1.5.5

* change postgres version rollback to 1.5.4

---------

Co-authored-by: chenchuan <chenchuan@360.cn>
Co-authored-by: jason_chuan <jason_chuan@126.com>
2024-02-06 10:54:40 +08:00
black-06
46816ad31d
refactor: distinguish between Unique and UniqueIndex (#6386)
* refactor: distinguish between UniqueIndex and Index

* add test

* add ParseIndex test

* modify unique to constraint

* modify unique to constraint

* fix MigrateColumnUnique

* fix test

* fix unit test

* update test mod

* add MigrateColumnUnique to Migrator interface

* fix format lint

* add comment

* go mod tidy

* revert: revert MigrateColumn

* resolve conflicts
2024-02-04 15:49:19 +08:00
black-06
418ee3fc19
fix: preload shouldn't overwrite the value of join (#6771)
* fix: preload shouldn't overwrite the value of join

* fix lint

* fix: join may automatically add nested query
2024-01-29 11:34:57 +08:00
dependabot[bot]
e043924fe7
chore(deps): bump actions/cache from 3 to 4 (#6802)
Bumps [actions/cache](https://github.com/actions/cache) from 3 to 4.
- [Release notes](https://github.com/actions/cache/releases)
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md)
- [Commits](https://github.com/actions/cache/compare/v3...v4)

---
updated-dependencies:
- dependency-name: actions/cache
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-01-29 10:34:20 +08:00
Jacky
0123dd4509
fix: ignore .gen.go suffix in logger to get the real caller when using gen #6697 (#6785) 2024-01-12 17:09:22 +08:00
Jinzhu
940358e0dd Fix tests that don't follow the https://gorm.io/docs/method_chaining.html convention 2024-01-12 16:42:21 +08:00
iTanken
87decced23
fix: ExplainSQL treats consecutive pairs of the escaper in a SQL string as a literal escaper (#6766)
This prevents it from being interpreted as the string terminator. This is a widely used escape mechanism in SQL standards and is applicable in most relational databases.
2023-12-28 19:53:36 +08:00
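
The escaping rule described above: inside a quoted SQL string, writing the escaper twice yields a literal escaper instead of terminating the string. An illustrative helper, not GORM's ExplainSQL itself:

```go
package main

import (
	"fmt"
	"strings"
)

// quoteSQLString doubles every occurrence of the escaper inside the value and
// wraps it in the escaper, so embedded quotes do not terminate the string.
func quoteSQLString(s string, escaper string) string {
	return escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
}

func main() {
	fmt.Println(quoteSQLString("it's a 'test'", "'"))
	// Output: 'it''s a ''test'''
}
```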
Stephano George
436cca753c
fix: join and select mytable.* not working (#6761)
* fix: select mytable.* not working

* fix: select mytable.*: will not match `mytable."*"`.
feat: increase readability of code matching table name and column name
2023-12-23 21:19:41 +08:00
Alexis Viscogliosi
a2cac75218
feature: bring custom type and id column name to polymorphism (#6716)
* feature: bring custom type and id column name to polymorphism

* relationship: better returns for hasPolymorphicRelation

* fix: tests
2023-12-15 16:36:08 +08:00
Maciej Laskowski
b9ebdb13c7
Making locking parameters more intuitive (#6719)
* Making locking parameters more intuitive

* remove dedicated type
2023-12-15 16:32:56 +08:00
BugKillerPro
2fb4928aa8
refactor: Resolve implicit memory aliasing in for loop (#6730) 2023-12-15 16:31:23 +08:00
Franco Liberali
f0af94cd16 add test to show that update from works 2023-12-04 11:50:58 +08:00
FangSqing
3207ad6033
map insert supports returning the increment id (#6662) 2023-11-15 21:32:56 +08:00
Jinzhu
c1e911f6ed Update tests/go.mod 2023-11-09 18:46:39 +08:00
Kijima Daigo
40f4afe8c2
docs: fix broken link (#6673) 2023-11-07 10:20:06 +08:00
Flc゛
d2fb7a942b
chore(logger): optimize (#6675)
* chore(logger): optimize

* chore(logger): optimize
2023-11-07 10:19:41 +08:00
black-06
9fea15ae75
feat: add MigrateColumnUnique (#6640)
* feat: add MigrateColumnUnique

* feat: define new methods

* delete debug in test
2023-10-30 17:15:49 +08:00
Cr
5adc0ce5f6
test: fix TestEmbeddedRelations (#6639) 2023-10-26 11:58:13 +08:00
gleb
78e905919f
tests/sqlite: enable FOREIGN_KEYS inside OpenTestConnection (#6641) 2023-10-26 11:54:15 +08:00
Franco Liberali
6bef318891
add support for returning in sqlserver (#6585) 2023-10-10 15:03:34 +08:00
dependabot[bot]
1b24081010
chore(deps): bump actions/checkout from 3 to 4 (#6586)
Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v3...v4)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-10 14:50:45 +08:00
Jeremy Quirke
8c18714462
Don't call MethodByName with a variable arg (#6602)
Go 1.22 goes some way toward addressing the issue of reflect
MethodByName disabling linker dead-code elimination (DCE), and the
resultant large increase in binary size, because the linker cannot
prune unused code that might be reached via reflection.

Go Issue golang/go#62257 reduces the number of incidences of this
problem by leveraging a compiler assist to avoid marking functions
containing calls to MethodByName as ReflectMethods as long as the
arguments are constants.

An analysis of Uber Technologies code base however shows that a number
of transitive imports still contain calls to MethodByName with a
variable argument, including GORM.

In the case of GORM, because the number of possible methods is finite,
the solution we are proposing is to "unroll" this. This demonstrably
shows that GORM is no longer a problem for DCE.

Before
```
% go version
go version devel go1.22-2f3458a8ce Sat Sep 16 16:26:48 2023 -0700 darwin/arm64
% go  test ./... -ldflags=-dumpdep   2>  >(grep -i -e  '->.*<reflectmethod>')
gorm.io/gorm.(*Statement).BuildCondition -> gorm.io/gorm/schema.ParseWithSpecialTableName <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).Method <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).MethodByName <ReflectMethod>
ok  	gorm.io/gorm	(cached)
ok  	gorm.io/gorm/callbacks	(cached)
gorm.io/gorm/clause_test.BenchmarkComplexSelect -> gorm.io/gorm/schema.ParseWithSpecialTableName <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).Method <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).MethodByName <ReflectMethod>
?   	gorm.io/gorm/migrator	[no test files]
ok  	gorm.io/gorm/clause	(cached)
ok  	gorm.io/gorm/logger	(cached)
gorm.io/gorm/schema_test.TestAdvancedDataTypeValuerAndSetter -> gorm.io/gorm/schema.ParseWithSpecialTableName <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).Method <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).MethodByName <ReflectMethod>
?   	gorm.io/gorm/utils/tests	[no test files]
ok  	gorm.io/gorm/schema	(cached)
ok  	gorm.io/gorm/utils	(cached)
```

After

```
%go version
go version devel go1.22-2f3458a8ce Sat Sep 16 16:26:48 2023 -0700 darwin/arm64
%go  test ./... -ldflags=-dumpdep   2>  >(grep -i -e  '->.*<reflectmethod>')
ok  	gorm.io/gorm	(cached)
ok  	gorm.io/gorm/callbacks	(cached)
?   	gorm.io/gorm/migrator	[no test files]
?   	gorm.io/gorm/utils/tests	[no test files]
ok  	gorm.io/gorm/clause	(cached)
ok  	gorm.io/gorm/logger	(cached)
ok  	gorm.io/gorm/schema	(cached)
ok  	gorm.io/gorm/utils	(cached)
```
2023-10-10 14:50:29 +08:00
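
A sketch of the "unrolling" described above: every MethodByName call site receives a constant string, so the compiler assist from golang/go#62257 applies and linker dead-code elimination stays effective. The hook names are illustrative:

```go
package main

import (
	"fmt"
	"reflect"
)

// methodByName switches over a finite set of hook names so that each reflect
// MethodByName call uses a constant argument; unknown names return an invalid
// Value rather than falling back to a variable-argument lookup.
func methodByName(v reflect.Value, name string) reflect.Value {
	switch name {
	case "BeforeCreate":
		return v.MethodByName("BeforeCreate")
	case "AfterCreate":
		return v.MethodByName("AfterCreate")
	case "BeforeUpdate":
		return v.MethodByName("BeforeUpdate")
	case "AfterUpdate":
		return v.MethodByName("AfterUpdate")
	default:
		return reflect.Value{}
	}
}

type user struct{}

func (user) BeforeCreate() error { return nil }

func main() {
	fmt.Println(methodByName(reflect.ValueOf(user{}), "BeforeCreate").IsValid()) // true
	fmt.Println(methodByName(reflect.ValueOf(user{}), "Unknown").IsValid())      // false
}
```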
Mathias Zeller
12ba285a52
*datatypes.JSON in model causes panic on tx.Statement.Changed (#6611)
* do not panic on nil

* more explanation in comments

* get things compact
2023-10-10 14:46:32 +08:00
hjwblog.com
9d8a5bb208
feat: reuse name (#6626) 2023-10-10 14:45:48 +08:00
Samuel N Cui
2095d42b4c
fix: sqlite dialector cannot apply PRIMARY KEY AUTOINCREMENT type (#6624)
* fix: sqlite dialector cannot apply `PRIMARY KEY AUTOINCREMENT` type

fix #4760

* feat: add auto increment test

* feat: update sqlite

* feat: update tests deps sqlite to v1.5.4
2023-10-09 17:26:27 +08:00
Jinzhu
e57e5d8884 Update go.mod 2023-08-27 15:40:54 +08:00
Jinzhu
653732e1c3 Update go testing versions 2023-08-24 20:19:38 +08:00
Rataj
ac07543962
Fixed error message when dialector fails to initialize (#6509)
Let's say we have a problem with the DSN which leads to a dialector initialization error. However, the DB connection is not created, and for some reason the error on line 184 provides <nil> even though "db" doesn't exist.

Previously, this code led to:
panic: runtime error: invalid memory address or nil pointer dereference

This fix no longer attempts to close the non-existent database connection and instead continues, so the proper error is shown. In my case:
[error] failed to initialize database, got error default addr for network 'localhost' unknown
2023-08-20 19:46:56 +08:00
龚一涛
7e44f73ad3
fix schema GetIdentityFieldValuesMap interface or ptr (#6417)
Co-authored-by: uptutu <yitao.gong@vzenith.com>
2023-08-19 21:35:14 +08:00
Heliner
2c2089760c
add float32 test case (#6530) 2023-08-19 21:33:57 +08:00
qqxhb
fef42941ba
feat: rm GetDBConnWithContext method (#6535)
* feat: rm contextconnpool method

* feat: nil
2023-08-19 21:33:31 +08:00
weih
bae684b363
fix(clause): when the value of clause.Eq is an empty array, the SQL should be IN (NULL) (#6503) 2023-08-10 13:34:33 +08:00
Jinzhu
15162afaf2 Support GetDBConnWithContext PreparedStmtDB 2023-08-10 13:30:57 +08:00
fayvori
3c34bc2f59
refactor: Regex description (#6507)
* Mirror cleanup

* Regex description

---------

Co-authored-by: Ignat Belousov <ignatbelousov@Ignats-MacBook-Pro.local>
2023-08-07 16:35:19 +08:00
Aayush Acharya
f473761813
fix: added SkipHooks in db getInstance() (#6484) 2023-08-04 10:35:59 +08:00
San Ye
193c454cf4
keep float precision in ExplainSQL (#6495) 2023-08-04 10:31:18 +08:00
Saeid
1fb26ac90e
test: coverage for tabletype added (#6496)
* test: coverage for tabletype added

* test: tidb excluded

---------

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-08-04 10:30:07 +08:00
Jinzhu
a7f01bd1b2 Test Pluck with customized type 2023-07-25 10:47:19 +08:00
Saeid
c10f807d3c
test: coverage for foreign key violation err (#6403)
* test: coverage for foreign key violation err

* test: enabled foreign keys constraint for sqlite

* test: enabled mysql& mssql for ErrForeignKeyViolate

* test: disabled mysql & updated sqlserver driver version

* test: skipped tidb

---------

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-07-12 21:21:22 +08:00
Saeid
2066138684
ci: fix mariadb mysqladmin (#6401)
Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-06-11 07:42:18 +08:00
Saeid
c2d571cbc8
test: coverage for duplicated key err (#6389)
* test: ErrDuplicatedKey coverage added

* test: updated sqlserver version

* test: removed sqlserver

* test: support added for sqlserver

---------

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-06-10 21:05:19 +08:00
Johannes Riecken
7dd702d379
Fix incorrect documentation comment (has many -> has one) (#6382) 2023-06-07 15:02:30 +08:00
Nuno Cruces
7157b7e375
fix: database/sql.Scanner should not retain references (#6380) 2023-06-07 15:02:07 +08:00
Lev Zakharov
661781a3d7
feat: add *sql.DB connector that uses database context (#6366)
* feat: add SQLConnector

* rename
2023-06-05 16:25:05 +08:00
KantaHasegawa
5eaccaa624
refactor: add nil detection when sqldb returns (#6373)
* refactor: add null detection when sqldb returns

* refactor: Detecting nil in dbConnector.GetDBConn()

* refactor: Revert partial code from c1ea73036715018a1bb55cdb8690441044e13a76

* fix: fix if statement
2023-06-05 16:24:00 +08:00
Lev Zakharov
7a76c042e6
refactor: remove unnecessary prepared statement allocation (#6374) 2023-06-05 16:23:17 +08:00
black
c1ea730367 fix: avoid panic when open fails 2023-06-01 15:22:21 +08:00
东方上人
740f2be453
fix: begin transaction fail, rollback panic (#6365) 2023-05-31 19:21:51 +08:00
mohammad ali
26663ab9bf
max identifier length changed to 63 (#6337)
* max identifier length changed to 63

* default maxIdentifierLength is 64

* renamed License to LICENSE (#6336)

* Added support of "Violates Foreign Key Constraint" (#6329)

* Added support of "Violates Foreign Key Constraint"

Updated the translator and added support for "foreign key constraint violation". This error type is needed here for that.

* changed the description of ErrForeignKeyViolated

* refactor: error translator test (#6350)

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>

* fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements (#6220)

* test: add nested transaction and prepareStmt coexist test case

note: please test in the MySQL environment

Change-Id: I0db32adc5f74b0d443e98943d3b182236583b959
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

* fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements

1. SavePoint SQL statements are not supported in Prepared Statements
 e.g. see mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html

Change-Id: I082012db9b140e8ec69764c633724665cc802692
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

* revert(transaction_api): remove savepoint name pool,meaningless

Change-Id: I84aa9924fc54612005a81c83d66fdf8968ee56ad
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

---------

Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>
Co-authored-by: 王柳洋 <wangliuyang.520@bytedance.com>

* fix: save with hook (#6285) (#6294)

---------

Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>
Co-authored-by: Avinaba Bhattacharjee <avinababhattacharjee2002@gmail.com>
Co-authored-by: Muhammad Amir Ejaz <37077032+codingamir@users.noreply.github.com>
Co-authored-by: Saeid <sk.saeidee@yahoo.com>
Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
Co-authored-by: wangliuyang <54885906+wangliuyang520@users.noreply.github.com>
Co-authored-by: 王柳洋 <wangliuyang.520@bytedance.com>
Co-authored-by: black-06 <hello.bug@foxmail.com>
2023-05-30 10:00:48 +08:00
black-06
11fdf46a9f
fix: save with hook (#6285) (#6294) 2023-05-26 10:28:02 +08:00
wangliuyang
812bb20c34
fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements (#6220)
* test: add nested transaction and prepareStmt coexist test case

note: please test in the MySQL environment

Change-Id: I0db32adc5f74b0d443e98943d3b182236583b959
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

* fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements

1. SavePoint SQL statements are not supported in Prepared Statements
 e.g. see mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html

Change-Id: I082012db9b140e8ec69764c633724665cc802692
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

* revert(transaction_api): remove savepoint name pool,meaningless

Change-Id: I84aa9924fc54612005a81c83d66fdf8968ee56ad
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

---------

Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>
Co-authored-by: 王柳洋 <wangliuyang.520@bytedance.com>
2023-05-26 10:24:28 +08:00
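
The constraint behind this fix: MySQL rejects SAVEPOINT and ROLLBACK TO SAVEPOINT inside prepared statements, so these statements must bypass the prepared-statement path and be sent as plain queries. A hedged sketch with illustrative names, not GORM's internals:

```go
package main

import (
	"context"
	"database/sql"
)

// savePoint sends the SAVEPOINT statement directly, without placeholders, so
// the driver can issue it as a plain query instead of trying to PREPARE it
// (which MySQL rejects for SAVEPOINT statements).
func savePoint(ctx context.Context, conn *sql.Conn, name string) error {
	_, err := conn.ExecContext(ctx, "SAVEPOINT "+name)
	return err
}

// rollbackTo likewise bypasses the prepared-statement cache.
func rollbackTo(ctx context.Context, conn *sql.Conn, name string) error {
	_, err := conn.ExecContext(ctx, "ROLLBACK TO SAVEPOINT "+name)
	return err
}

func main() {}
```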
Saeid
8197c00def
refactor: error translator test (#6350)
Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-05-25 11:10:00 +08:00
Muhammad Amir Ejaz
001738be49
Added support of "Violates Foreign Key Constraint" (#6329)
* Added support of "Violates Foreign Key Constraint"

Updated the translator and added support for "foreign key constraint violation". This error type is needed here for that.

* changed the description of ErrForeignKeyViolated
2023-05-21 21:27:22 +08:00
Avinaba Bhattacharjee
6698ba709e
renamed License to LICENSE (#6336) 2023-05-21 21:24:00 +08:00
201430098137
f5837deef3
fix: clickhouse error not captured (#6277) (#6321)
Co-authored-by: zhuangg <zhuangg@mingyuanyun.com>
2023-05-17 10:15:41 +08:00
Jinzhu
c3d7d08b9a Clear SET clause after build SQL 2023-05-15 15:43:44 +08:00
aclich
63534145fd
fix: 🐛 embedded struct test failed with custom datatypes (#6311)
* fix: 🐛 embedded struct test failed with custom datatypes

Fix the issue where a pointer embedded struct within custom datatypes and
*time.Time should be nil.

* fix: 🐛 change test case to avoid mssql driver issue

change test cases from bytes to string to avoid mssql driver issue
2023-05-15 09:59:26 +08:00
John Mai
e61b98d696
feat: migrator support table comment (#6225)
* feat: migrator support table comment

* feat: migrator supports tableType. It is like ColumnTypes

* Avoid updating the go.mod file.

* Update tests_all.sh

* Update migrator.go

* remove Catalog() & Engine() methods.

* remove CatalogValue & EngineValue.

---------

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-05-05 15:58:27 +08:00
black-06
32045fdd7d
feat: unscoped association (#5899) (#6246)
* feat: unscoped association (#5899)

* modify name because the mysql character set is latin1

* work only on has association

* format

* Unscoped on belongs_to association
2023-05-04 19:30:45 +08:00
hykuan
67642abfff
fix: 🐛 numeric types in pointer embedded struct test failed (#6293) 2023-05-04 19:29:31 +08:00
hanwn
aeb298635b
debug: use slice stable sort (#6263)
Co-authored-by: hanwang <hanwang.7721@bytedance.com>
2023-04-26 22:19:46 +08:00
Cr
407bedae0a
fix: nested joins alias (#6265) 2023-04-26 22:19:32 +08:00
yikakia
1f763c81cb
fix typo chainable_api.go (#6266) 2023-04-26 22:19:06 +08:00
Zhiheng Lin
32fc201554
fix: avoid goroutine leaks when the dialector initialization fails. (#6249)
Co-authored-by: Kevin Lin <kevin.lin@shopee.com>
2023-04-21 22:17:21 +08:00
black-06
ac20d9e222
fix: unit test (#6250)
* fix: unit test

* fix create test

https://github.com/go-gorm/gorm/pull/6127#discussion_r1171214125

* style: rename to adaptorSerializerModel
2023-04-21 22:09:38 +08:00
Jinzhu
e9637024d3 Update README 2023-04-11 13:16:25 +08:00
black-06
828e22b17f
feat: support embedded preload (#6137)
* feat: support embedded preload

* fix lint and test

* fix test...
2023-04-11 13:10:38 +08:00
black-06
4b0da0e97a
fix cond in scopes (#6152)
* fix cond in scopes

* replace quote

* fix execute scopes
2023-04-11 12:01:23 +08:00
bsmith-auth0
ccc3cb758a
fix: many2many association with duplicate belongs to elem (#6206) 2023-04-11 11:06:13 +08:00
jessetang
05bb9d6106
refactor(migrator): non-standard codes (#6180) 2023-04-11 10:32:46 +08:00
dependabot[bot]
1d9f4b0f55
chore(deps): bump actions/stale from 7 to 8 (#6190)
Bumps [actions/stale](https://github.com/actions/stale) from 7 to 8.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v7...v8)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-11 10:27:05 +08:00
hanwn
59ca46db3c
fix: limit(0).offset(0) returns all data (#6191)
Co-authored-by: hanwang <hanwang.7721@bytedance.com>
2023-04-11 10:25:47 +08:00
Cr
f0360dccbf
fix: embedded should be nil if not exists (#6219) 2023-04-11 10:13:25 +08:00
Saeid Kanishka
b444011d09
refactor: translatorError flag added for backward compatibility (#6178)
Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-03-24 10:07:05 +08:00
cyhone
5d1cdfef2e
avoid starting a transaction when performing only one insert operation in CreateInBatches function (#6174) 2023-03-23 14:02:35 +08:00
black-06
1a7ea98ac5
fix: count with group (#6157) (#6160)
* fix: count with group (#6157)

* add an easy-to-understand ut
2023-03-23 11:19:53 +08:00
black-06
0c7e575f19
save should be idempotent #6139 (#6149) 2023-03-23 11:18:57 +08:00
dependabot[bot]
d2dd0ce4a7
chore(deps): bump actions/setup-go from 3 to 4 (#6165)
Bumps [actions/setup-go](https://github.com/actions/setup-go) from 3 to 4.
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](https://github.com/actions/setup-go/compare/v3...v4)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-23 11:18:02 +08:00
Jinzhu
cc2d46e5be reuse name for savepoints from nested transaction, close #6060 2023-03-10 17:42:38 +08:00
Cr
8bf1f269cf
feat: support nested join (#6067)
* feat: support nested join

* fix: empty rel value
2023-03-10 17:21:56 +08:00
Jeffry Luqman
654b5f2006
test: pgsql alter column from smallint or string to boolean (#6107)
* test: pgsql alter column from smallint to boolean

* test: pgsql alter column from string to boolean
2023-03-10 17:11:56 +08:00
Cr
b62192456f
fix: diff schema update assign value (#6096) 2023-03-10 17:04:54 +08:00
Saeid Kanishka
707d70a542
refactor: translate error only when it is not nil (#6133)
* refactor: translate error only when it is not nil

* refactor: fix the error flow

* refactor: update the error if checks

* Update gorm.go

---------

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-03-10 16:51:27 +08:00
Truong Nguyen
ed474152b1
Fix: Composite primary key with auto-increment value returns 0 after insert (#6127)
* Fix #4930 workaround for databases that support auto-increment in composite primary key.

* Add test for composite key with auto-increment.

* schema.go: use field.AutoIncrement instead of field.TagSettings["AUTOINCREMENT"], add test to check autoincrement:false

create_test.go: remove unused code: drop table CompositeKeyProduct

---------

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-03-10 16:50:03 +08:00
Jinzhu
1643a36260 Fix possible concurrency problem for serializer 2023-03-10 16:39:57 +08:00
Cr
e9f25c73ee
fix: on conflict with default null (#6129)
* fix: on conflict with default null

* Update create.go

---------

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-03-10 16:35:26 +08:00
Saeid Kanishka
85eaf9eeda
feat: Unique Constraint Violation error translator for different drivers (#6004)
* feat: duplicated key error translator for different drivers

* test: removed the dependency

* test: fixed broken tests

* refactor: added ErrorTranslator interface

* style: applied styler

---------

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-03-06 14:03:31 +08:00
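
A sketch of the translator idea behind this feature: a dialector inspects driver-specific error codes and maps them onto GORM's portable errors. The mysqlError type and the wiring are assumptions for illustration; 1062 is MySQL's ER_DUP_ENTRY code:

```go
package main

import (
	"errors"
	"fmt"
)

// ErrDuplicatedKey stands in for gorm.ErrDuplicatedKey in this sketch.
var ErrDuplicatedKey = errors.New("duplicated key not allowed")

// mysqlError is a stand-in for a driver-specific error carrying an error number.
type mysqlError struct{ Number uint16 }

func (e *mysqlError) Error() string { return fmt.Sprintf("mysql error %d", e.Number) }

type errorTranslator struct{}

// Translate maps a driver-specific duplicate-key error to the portable error.
func (errorTranslator) Translate(err error) error {
	var me *mysqlError
	if errors.As(err, &me) && me.Number == 1062 { // 1062 = ER_DUP_ENTRY
		return ErrDuplicatedKey
	}
	return err
}

func main() {
	fmt.Println(errorTranslator{}.Translate(&mysqlError{Number: 1062}))
}
```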
Jinzhu
f3874339ef Fix Save with stress tests 2023-03-02 17:22:51 +08:00
Jiepeng Cao
877cc9148f
Remove redundant code (#6087) 2023-02-27 15:44:35 +08:00
black-06
a80707de9e
Create and drop view (#6097)
* create view

* add comment

* fix test

* check param and add comment
2023-02-27 15:43:10 +08:00
Jiepeng Cao
391c961c7f
quotes on docker-compose.yml ports (#6089) 2023-02-27 15:39:02 +08:00
Cr
04cbd956eb
test: pgsql migrate unique index (#6028) 2023-02-18 09:21:07 +08:00
black-06
e66a059b82
fix: update panic if model is not ptr (#6037)
* fix: update panic if model is not ptr

* fix: update panic if model is not ptr

* fix: update panic if model is not ptr

* fix: raise an error if the value is not addressable

* fix: return
2023-02-18 09:20:29 +08:00
black-06
42fc75cb2c
fix: association concurrently appending (#6044)
* fix: association concurrently appending

* fix: fix unit test

* fix: fix gofumpt
2023-02-18 09:19:24 +08:00
Cr
aa89736db2
fix: miss join type (#6056) 2023-02-18 09:13:36 +08:00
Michael Anstis
532e9cf4cc
Issue 6054: Unscoped not working with PreLoad on Joins (#6058)
* Issue 6054: Unscoped not working with PreLoad on Joins

* Formatting

---------

Co-authored-by: Michael Anstis <manstis@redhat.com>
2023-02-18 09:06:43 +08:00
Cheese
02b7e26f6b
feat: add tidb integration test cases (#6014)
* feat: support tidb integration test

* feat: update the mysql driver version to test
2023-02-08 16:29:09 +08:00
Cr
878ac51e98
fix: throw model value required error (#6031)
* fix: throw model value required error

* chore: ignore typecheck

* chore: ignore errcheck

* refactor: use other error

* chore: gofumpt style
2023-02-08 13:40:41 +08:00
chyroc
e1f46eb802
fix: ignore nil query (#6021) 2023-02-02 17:54:51 +08:00
Jinzhu
4d6b70ec88 Allow modify statement from dest 2023-02-02 17:15:08 +08:00
qiankunli
cfbcedbf03
fix: support zeroValue tag on DeletedAt (#6011)
* fix: support zeroValue tag on DeletedAt

Signed-off-by: qiankunli <qiankun.li@qq.com>

* Update soft_delete_test.go

* Update tests_test.go

* Update soft_delete.go

---------

Signed-off-by: qiankunli <qiankun.li@qq.com>
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-02-01 14:40:55 +08:00
Jinzhu
d834dd60b7 Remove unnecessary code 2023-01-19 15:22:13 +08:00
Jinzhu
3d35ddba55 Fix use table.* as select/omit columns 2023-01-12 16:52:56 +08:00
Haibo
baf1afa1fc
fix(schema): field is only unique when there is one unique index (#5974) 2023-01-11 14:05:39 +08:00
Jinzhu
2bc913787b support implicit table alias, close #5840 #5940 2023-01-02 21:46:27 +08:00
Jinzhu
3d91802b1d Fix unexpected alter table in auto migration, close #5942, #5943 2023-01-02 21:06:04 +08:00
Jinzhu
b0e13d95b4 update github tests action 2023-01-01 22:27:49 +08:00
Jinzhu
4b768c8aff Upgrade tests deps 2023-01-01 22:22:08 +08:00
Haibo
16a272209a
fix(migrator): Tag default:'null' always causes field migration #5953 (#5954)
* fix(migrator): Tag default:'null' always causes field migration #5953

* Update migrate_test.go

* Update migrate_test.go

* Update migrate_test.go

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-01-01 22:14:28 +08:00
Haibo
da2b2861de
fix(migrator): ignore relationships when migrating #5913 (#5946) 2023-01-01 19:54:28 +08:00
dependabot[bot]
7da24d1d52
chore(deps): bump actions/stale from 6 to 7 (#5945)
Bumps [actions/stale](https://github.com/actions/stale) from 6 to 7.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v6...v7)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-12-27 08:47:17 +08:00
Jinzhu
ddd3cc2502 Add ParameterizedQueries option support for logger, close #5288 2022-12-25 11:37:23 +08:00
Cr
794edad60e
test(MigrateColumn): mock alter column to improve field compare (#5499)
* test(MigrateColumn): mock alter column to improve field compare

* Update migrate_test.go

* Update migrate_test.go

* Update migrate_test.go

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-12-24 17:42:16 +08:00
Cr
1935eb0adb
feat: support inner join (#5583)
* feat: support inner join

* test: mixed inner join and left join

* chore: code comment

* Update statement.go

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-12-24 12:27:38 +08:00
Defoo Li
775fa70af5
DryRun for migrator (#5689)
* DryRun for migrator

* Update migrator.go

* Update migrator.go

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-12-24 12:14:23 +08:00
Ning
bbd2bbe521
fix: issue migrating field with CURRENT_TIMESTAMP (#5906)
Co-authored-by: ningfei <accelerator314@outlook.com>
2022-12-24 11:02:11 +08:00
Nate Armstrong
f3c6fc2533
Update func comments in chainable_api and FirstOr_ (#5935)
Add comments to functions in chainable_api. Depending on the method,
these comments add some additional context or details that are relevant
when reading the function, link to the actual docs at gorm.io/docs, or
provide examples of use. These comments should make GORM much more
pleasant to use with an IDE that provides hoverable comments, and are
minimal examples.

Also add in-code documentation to FirstOrInit and FirstOrCreate.

Almost all examples are directly pulled from the docs, with short
comments explaining the code. Most examples omit the `db.Model(&User{})`
for brevity, and would not actually work.

Co-authored-by: Nate Armstrong <nate.armstrong@eluv.io>
2022-12-23 16:51:01 +08:00
Edward McFarlane
4ec73c9bf4
Add test case for embedded value selects (#5901)
* Add test case for embedded value selects

* Revert recycle struct optimisation to avoid pointer overwrites
2022-12-19 11:49:05 +08:00
Cr
d9525d4da4
fix: skip append relation field to default db value (#5885)
* fix: relation field returning

* chore: gofumpt style
2022-12-01 20:26:59 +08:00
wjw1758548031
f931def33d
clear code syntax (#5889)
* clear code syntax

* clear code syntax
2022-12-01 20:25:53 +08:00
Jinzhu
f91313436a Fix group by with count logic 2022-11-21 11:10:56 +08:00
Cr
342310fba4
fix(FindInBatches): throw err if pk not exists (#5868) 2022-11-21 10:49:27 +08:00
kvii
b6836c2d3e
fix bug in windows (#5844)
* fix bug in windows

* fix file name bug

* test in unix like platform
2022-11-21 10:48:13 +08:00
jessetang
cef3de694d
cleanup(prepare_stmt.go): unnecessary map delete (#5849) 2022-11-13 11:12:09 +08:00
jessetang
1b9cd56c53
doc(README.md): add contributors (#5847) 2022-11-10 16:30:32 +08:00
kvii
871f1de6b9
fix logger path bug (#5836) 2022-11-05 11:52:08 +08:00
jessetang
fb640cf7da
test(utils): add utils unit test (#5834) 2022-11-05 08:38:14 +08:00
jessetang
5c8ecc3a2a
feat: golangci add goimports and whitespace (#5835) 2022-11-05 08:37:37 +08:00
jessetang
f82e9cfdbe
test(clause/joins): add join unit test (#5832) 2022-11-03 21:03:13 +08:00
Cr
b2f42528a4
fix(Joins): args with select and omit (#5790)
* fix(Joins): args with select and omit

* chore: gofumpt style
2022-11-02 10:28:00 +08:00
Cr
9d82aa5673
test: invalid cache plan with prepare stmt (#5778)
* test: invalid cache plan with prepare stmt

* test: more test cases

* test: drop and rename column
2022-10-20 14:10:47 +08:00
Cr
5dd2bb4827
feat(PreparedStmtDB): support reset (#5782)
* feat(PreparedStmtDB): support reset

* fix: close all stmt

* test: fix test

* fix: delete one by one
2022-10-19 14:46:59 +08:00
Jinzhu
3f20a543fa Support use clause.Interface as query params 2022-10-18 18:01:55 +08:00
viatoriche / Maxim Panfilov
62593cfad0 add test: TestAutoMigrateInt8PG: shouldn't execute ALTER COLUMN TYPE smallint, close #5762 2022-10-18 17:28:06 +08:00
Jinzhu
a0f4d3f7d2 Save as empty string for not nullable nil field serialized into json 2022-10-18 16:25:39 +08:00
Jinzhu
ab5f80a8d8 Save as NULL for nil object serialized into json 2022-10-18 15:44:56 +08:00
Cr
186e8a9e14
fix: association without pks (#5779) 2022-10-18 11:58:42 +08:00
Jinzhu
2a788fb20c Upgrade tests go.mod 2022-10-17 17:01:42 +08:00
Jinzhu
aa4312ee74 Don't display any GORM related package path as source 2022-10-17 17:01:42 +08:00
Jinzhu
08aa2f9888 Update README 2022-10-14 20:30:41 +08:00
Jinzhu
2c56954cb1 tests mariadb with returning support 2022-10-08 20:48:22 +08:00
Jinzhu
e93dc3426e Test postgres autoincrement check 2022-10-08 17:16:32 +08:00
Jinzhu
983e96f142 Add tests for alter column type 2022-10-08 16:04:57 +08:00
Jinzhu
34fbe84580 Add TableName with NamingStrategy support, close #5726 2022-10-07 21:18:37 +08:00
robhafner
e8f48b5c15
fix: limit=0 results (#5735) (#5736) 2022-10-07 20:14:14 +08:00
jesse.tang
4b22a55a75
fix: primaryFields are overwritten (#5721) 2022-10-07 18:29:28 +08:00
Wen Sun
9564b82975
Fix OnConstraint builder (#5738) 2022-10-07 13:46:20 +08:00
Cr
0b7113b618
fix: prepare deadlock (#5568)
* fix: prepare deadlock

* chore[ci skip]: code style

* chore[ci skip]: test remove unnecessary params

* fix: prepare deadlock

* fix: double check prepare

* test: more goroutines

* chore[ci skip]: improve code comments

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-09-30 18:13:36 +08:00
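
A sketch of the double-checked pattern this fix mentions: look the statement up under a read lock, then re-check after taking the write lock before preparing, so concurrent goroutines neither prepare the same query twice nor deadlock waiting on each other. Names are illustrative, not GORM's PreparedStmtDB:

```go
package main

import (
	"context"
	"database/sql"
	"sync"
)

type preparer struct {
	mu    sync.RWMutex
	db    *sql.DB
	stmts map[string]*sql.Stmt
}

func (p *preparer) prepare(ctx context.Context, query string) (*sql.Stmt, error) {
	// Fast path: read lock only.
	p.mu.RLock()
	stmt, ok := p.stmts[query]
	p.mu.RUnlock()
	if ok {
		return stmt, nil
	}

	p.mu.Lock()
	defer p.mu.Unlock()
	// Double check: another goroutine may have prepared it while we waited.
	if stmt, ok := p.stmts[query]; ok {
		return stmt, nil
	}
	stmt, err := p.db.PrepareContext(ctx, query)
	if err != nil {
		return nil, err
	}
	p.stmts[query] = stmt
	return stmt, nil
}

func main() {
	_ = (&preparer{}).prepare
}
```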
Stephano George
a3cc6c6088 Fix: wrong value when Find with Join with same column name, close #5723, #5711 2022-09-30 17:18:42 +08:00
jesse.tang
be440e7512
fix possible nil panic in tests (#5720)
* fix maybe nil panic

* reset code
2022-09-30 11:14:34 +08:00
dependabot[bot]
e1dd0dcbc4
chore(deps): bump actions/stale from 5 to 6 (#5717)
Bumps [actions/stale](https://github.com/actions/stale) from 5 to 6.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-09-30 11:13:01 +08:00
Nguyen Huu Tuan
328f301982
add some test cases related to the logic (#5477) 2022-09-22 18:35:21 +08:00
kinggo
12237454ed fix: use preparestmt in transaction will use new conn, close #5508 2022-09-22 16:47:43 +08:00
Cr
73bc53f061
feat: migrator support type aliases (#5627)
* feat: migrator support type aliases

* perf: check type
2022-09-22 15:56:32 +08:00
Cr
101a7c789f
fix: scan array (#5624)
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-09-22 15:51:47 +08:00
Jinzhu
3a72ba102e Allow shared foreign key for many2many jointable 2022-09-22 15:03:41 +08:00
jesse.tang
1f634c3937
support scan assign slice cap (#5634)
* support scan assign slice cap

* fix
2022-09-22 14:50:35 +08:00
Cr
5ed7b1a65e
fix: same embedded field name (#5705) 2022-09-22 11:25:03 +08:00
qqxhb
490625981a
fix: update omit (#5699) 2022-09-16 15:02:44 +08:00
Googol Lee
edb00c10ad
AutoMigrate() should always migrate checks, even when there are no relationship constraints. (#5644)
* fix: remove uuid autoincrement

* AutoMigrate() should always migrate checks, even when there are no relationship constraints.

Co-authored-by: a631807682 <631807682@qq.com>
2022-09-14 10:26:51 +08:00
Bruce MacKenzie
f29afdd329
Rewrite of finisher_api Godocs (#5618) 2022-09-09 11:16:41 +08:00
Jiepeng Cao
b3eb1c8c51
simplified regexp (#5677) 2022-09-05 15:39:19 +08:00
jesse.tang
f78f635fae
Optimize: code logic db.scanIntoStruct() (#5633) 2022-09-05 15:34:33 +08:00
Cr
d71caef7d9
fix: remove uuid autoincrement (#5620) 2022-09-03 20:00:21 +08:00
Shunsuke Otani
8c3018b96a
Replace ioutil.Discard with io.Discard (#5603) 2022-08-15 10:50:06 +08:00
Shunsuke Otani
3f92b9b0df
Refactor: redundant type from composite literal (#5604) 2022-08-15 10:47:26 +08:00
Aoang
ba227e8939
Add Go 1.19 Support (#5608) 2022-08-15 10:46:57 +08:00
enwawerueli
573b9fa536 fix: correct grammar 2022-08-15 10:28:36 +08:00
Bruce MacKenzie
a35883590b
update Delete Godoc to describe soft delete behaviour (#5554) 2022-08-11 11:38:04 +08:00
Cr
f223279384
chore: fix gorm tag (#5577) 2022-08-10 11:03:42 +08:00
hjwblog.com
6e03b97e26
fix: empty serializer err #5524 (#5525)
* fix: empty serializer err #5524

* feat: fix UnixSecondSerializer return nil

* feat: split type case

Co-authored-by: huanjiawei <huanjiawei@bytedance.com>
2022-07-27 13:59:47 +08:00
MJrocker
3c6eb14c92 Fixed some typos in the code comment 2022-07-27 10:22:49 +08:00
Cr
06e174e24d
fix: embedded default value (#5540) 2022-07-25 14:10:30 +08:00
Xudong Zhang
bab3cd1724
fix bad logging performance of bulk create (#5520) (#5521) 2022-07-18 20:47:00 +08:00
Jinzhu
75720099b5 Create a new db in FindInBatches 2022-07-18 18:07:05 +08:00
Goxiaoy
2ba599e8b7
fix empty QueryClauses in association (#5502) (#5503)
* fix empty QueryClauses in association (#5502)

* test: empty QueryClauses in association (#5502)

* style: empty QueryClauses in association (#5502)

* style: empty QueryClauses in association (#5502)
2022-07-15 11:15:18 +08:00
alingse
099813bf11
Adjust ToStringKey to use unpacked params, fix passing []any as any in variadic function (#5500)
* fix pass []any as any in variadic function

* add .vscode to gitignore
2022-07-14 20:05:22 +08:00
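
The bug class fixed here: passing a []any to a variadic ...any parameter without the ... unpack operator wraps the whole slice as a single element. A tiny illustration:

```go
package main

import "fmt"

// printArgs stands in for a variadic helper like ToStringKey(values ...any).
func printArgs(values ...any) {
	fmt.Println(len(values), values)
}

func main() {
	keys := []any{1, "a", true}

	// Without unpacking, the whole slice arrives as a single element.
	printArgs(keys) // 1 [[1 a true]]

	// With unpacking, each element is passed separately.
	printArgs(keys...) // 3 [1 a true]
}
```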
Jinzhu
4d40e34734 Update select tests 2022-07-14 14:55:54 +08:00
Jinzhu
3262daf8d4 Fix select with association column 2022-07-13 18:26:35 +08:00
Jinzhu
cae30e9a50 Fix select with association column 2022-07-13 18:02:11 +08:00
Jinzhu
a7063848ef Fix select with uppercase column name 2022-07-13 17:44:14 +08:00
Jinzhu
08f6d06e47 Fix select with quoted column name 2022-07-13 17:21:19 +08:00
Jinzhu
62fdc2bb3b Fix serializer with empty string 2022-07-11 11:51:05 +08:00
Jinzhu
b13d1757fa Refactor Model with slice data 2022-07-07 15:39:29 +08:00
Jinzhu
9fd73ae4f1 Revert "use callback to handle transaction"
This reverts commit 93f28bc116526ba4decdd969a7b2b0b245ad70f1.
2022-07-07 15:06:48 +08:00
Jinzhu
fe01e1b9f4 Fix Model with slice data 2022-07-07 14:43:33 +08:00
Cr
46bce170ca
test: pg array type (#5480) 2022-07-04 16:42:27 +08:00
Jason Lee
5c4016d9a3
Merge pull request #5455 from longbridgeapp/feat-support-transaction-calllback 2022-07-01 23:34:26 +08:00
Cr
c74bc57add
fix: association many2many duplicate elem (#5473)
* fix: association many2many duplicate elem

* chore: gofumpt style
2022-07-01 15:12:15 +08:00
Joe
2cb4088456 ignore AddError return error 2022-07-01 14:37:38 +08:00
Cr
235c093bb9
fix(MigrateColumn):declared different type without length (#5465) 2022-06-29 10:07:42 +08:00
wws
3e6ab99043
fix: serializer contain field panic (#5461) 2022-06-25 16:32:47 +08:00
Joe
93f28bc116 use callback to handle transaction
- make transaction have before and after hooks, so plugin can have hack before
or after transaction
2022-06-24 10:33:39 +08:00
Jinzhu
a70af2a4c0 Fix Select with digits in column name 2022-06-20 15:35:40 +08:00
qqxhb
1305f637f8
feat: add method GetIndexes (#5436)
* feat: add method GetIndexes

* feat: add default impl for Index interface

* feat: fmt
2022-06-17 11:00:57 +08:00
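A minimal Go sketch (not part of the commit) of how the GetIndexes method added above might be used; the User model and the printed index fields are illustrative assumptions:

```go
package example

import (
	"fmt"

	"gorm.io/gorm"
)

// User is a placeholder model for this sketch.
type User struct {
	ID   uint
	Name string `gorm:"index"`
}

// listIndexes prints the indexes GORM reports for the User table.
func listIndexes(db *gorm.DB) error {
	indexes, err := db.Migrator().GetIndexes(&User{})
	if err != nil {
		return err
	}
	for _, idx := range indexes {
		fmt.Println(idx.Name(), idx.Columns())
	}
	return nil
}
```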
Cr
8d45714628
fix: reset null value in slice (#5417)
* fix: reset null value in slice

* fix: can not set field in-place in join
2022-06-14 13:48:50 +08:00
Bexanderthebex
d01de7232b
enhancement: Avoid calling reflect.New() when passing in slice of values to Scan() (#5388)
* fix: reduce allocations when slice of values

* chore[test]: Add benchmark for scan

* chore[test]: add bench for scan slice

* chore[test]: add bench for slice pointer and improve tests

* chore[test]: make sure database is empty when doing slice tests

* fix[test]: correct sql delete statement

* enhancement: skip new if rows affected = 0
2022-06-01 11:50:57 +08:00
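A minimal sketch, assuming a users table with an age column, of the usage pattern whose allocations the enhancement above targets: scanning a single selected column into a slice of plain values rather than a struct per row.

```go
package example

import "gorm.io/gorm"

// agesOfUsers scans one selected column into a slice of plain values.
func agesOfUsers(db *gorm.DB) ([]int64, error) {
	var ages []int64
	err := db.Table("users").Select("age").Scan(&ages).Error
	return ages, err
}
```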
dependabot[bot]
f4e9904b02
chore(deps): bump gorm.io/driver/mysql from 1.3.3 to 1.3.4 in /tests (#5385)
Bumps [gorm.io/driver/mysql](https://github.com/go-gorm/mysql) from 1.3.3 to 1.3.4.
- [Release notes](https://github.com/go-gorm/mysql/releases)
- [Commits](https://github.com/go-gorm/mysql/compare/v1.3.3...v1.3.4)

---
updated-dependencies:
- dependency-name: gorm.io/driver/mysql
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-06-01 10:26:09 +08:00
Cr
93986de8e4
fix: migrate column default value (#5359)
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-05-28 23:09:13 +08:00
t-inagaki@hum_op
dc1ae394f3
fixed FirstOrCreate not handling error when table does not exist (#5367)
* fixed FirstOrCreate not handling error when table does not exist

* delete useless part
2022-05-28 22:18:43 +08:00
Cr
7e13b03bd4
fix: duplicate column scan (#5369)
* fix: duplicate column scan

* fix: dup field in inconsistent schema and database

* chore[ci skip]: gofumpt style

* chore[ci skip]: fix typo
2022-05-28 22:18:07 +08:00
Cr
7d1a92d60e
test: test for skip prepared when auto migrate (#5350) 2022-05-22 16:12:28 +08:00
Clark McCauley
540fb49bcb
Fixed #5355 - Named variables don't work when followed by Windows CRLF line endings (#5356)
* Fixed #5355.

* Fixed unit test to test both CRLF and CR line endings
2022-05-22 15:16:01 +08:00
Cr
7496c3a56e
fix: trx in hooks clone stmt (#5338)
* fix: trx in hooks

* chore: format by gofumpt
2022-05-17 14:13:41 +08:00
black-06
f5e77aab2f
fix: quote index when creating table (#5331) 2022-05-17 10:59:53 +08:00
Cr
373bcf7aca
fix: many2many auto migrate (#5322)
* fix: many2many auto migrate

* fix: uuid ossp
2022-05-09 10:07:18 +08:00
Cr
19b8d37ae8
fix: preload with skip hooks (#5310) 2022-05-04 18:57:53 +08:00
Cr
b0104943ed
fix: callback sort when using multiple plugins (#5304) 2022-04-30 09:57:16 +08:00
Heliner
d3488ae6bc
fix: check the result of auto_migrate (#5306)
Co-authored-by: fredhan <fredhan@futunn.com>
2022-04-30 09:50:53 +08:00
Cr
bd7e42ec65
fix: AutoMigrate with special table name (#5301)
* fix: AutoMigrate with special table name

* test: migrate with special table name
2022-04-27 21:13:48 +08:00
Jinzhu
6a6dfdae72 Refactor FirstOrCreate, FirstOrInit 2022-04-26 17:16:48 +08:00
Chiung-Ming Huang
0211ac91a2
index: add composite id (#5269)
* index: add composite id

* index: add test cases of composite id

* index: improve the comments for the test cases of composite id
2022-04-25 11:39:23 +08:00
Cr
a0cc631272
test: test for postgres serial column (#5234)
* test: test for postgres serial column

* test: only for postgres

* chore: spelling mistake

* test: for drop sequence
2022-04-24 12:13:27 +08:00
aelmel
3643f856a3
check for pointer to pointer value (#5278)
* check for pointer to pointer value

* revert to Ptr

Co-authored-by: Alexei Melnic <alexei.melnic@meliora.xyz>
2022-04-24 09:10:36 +08:00
Cr
9b80fe9e96
fix: stmt.Changed zero value field behavior (#5281)
* fix: stmt.Changed zero value field behavior

* chore: rename var
2022-04-24 09:08:52 +08:00
glebarez
395606ac7c
fix missing error-check in AutoMigrate (#5283) 2022-04-22 11:19:33 +08:00
Jinzhu
88c26b62ee Support Scopes in group conditions 2022-04-20 17:21:38 +08:00
Cr
b49ae84780
fix: FindInBatches with offset limit (#5255)
* fix: FindInBatches with offset limit

* fix: break first

* fix: FindInBatches Limit zero
2022-04-17 09:58:33 +08:00
ZhangShenao
e0ed3ce400
fix spelling mistake (#5256)
Co-authored-by: Shenao Zhang <shenao.zhang@shopee.com>
2022-04-14 20:32:57 +08:00
Jinzhu
d421c67ef5 Remove ErrRecordNotFound error from log when using Save 2022-04-14 10:51:39 +08:00
dependabot[bot]
ce53ea53ee
chore(deps): bump actions/setup-go from 2 to 3 (#5243)
Bumps [actions/setup-go](https://github.com/actions/setup-go) from 2 to 3.
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](https://github.com/actions/setup-go/compare/v2...v3)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-04-13 15:53:12 +08:00
dependabot[bot]
771cbed755
chore(deps): bump actions/stale from 4 to 5 (#5244)
Bumps [actions/stale](https://github.com/actions/stale) from 4 to 5.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v4...v5)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-04-13 15:52:40 +08:00
Cr
a65912c588
fix: FirstOrCreate RowsAffected (#5250) 2022-04-13 15:52:07 +08:00
Filippo Del Moro
6aa6d37fc4
Fix scanIntoStruct (#5241)
* Reproduces error case

* Fix scanIntoStruct

Co-authored-by: Filippo Del Moro <filippo.delmoro@facile.it>
2022-04-13 15:47:04 +08:00
Jinzhu
74e07b049c Serializer unixtime support ptr of int 2022-04-11 22:07:56 +08:00
Jinzhu
41bef26f13 Remove shared sync pool for Scanner compatibility 2022-04-11 21:37:44 +08:00
Naveen
5c9ef9a843
Set permissions for GitHub actions (#5237)
Restrict the GitHub token permissions to only the required ones; this way, even if attackers succeed in compromising your workflow, they won’t be able to do much.

- Included permissions for the action. https://github.com/ossf/scorecard/blob/main/docs/checks.md#token-permissions

https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#permissions

https://docs.github.com/en/actions/using-jobs/assigning-permissions-to-jobs

[Keeping your GitHub Actions and workflows secure Part 1: Preventing pwn requests](https://securitylab.github.com/research/github-actions-preventing-pwn-requests/)

Signed-off-by: naveensrinivasan <172697+naveensrinivasan@users.noreply.github.com>
2022-04-10 09:38:43 +08:00
Jinzhu
0729261b62 Support double ptr for Save 2022-04-08 14:23:25 +08:00
Hasan
81c4024232
Offset issue resolved for scanning results back into struct (#5227) 2022-04-07 23:56:41 +08:00
huangcheng1
38a24606da fix: tables lost when joins exist in the from clause, close #5218
commit 7f6a603afa26820e187489b5203f93adc513687c
Author: Jinzhu <wosmvp@gmail.com>
Date:   Sat Apr 2 17:26:48 2022 +0800

    Refactor #5218

commit 95d00e6ff2668233f3eca98aa4917291e3d869bd
Author: huangcheng1 <huangcheng1@sensetime.com>
Date:   Fri Apr 1 16:30:27 2022 +0800

    fix: tables lost when joins exist in the from clause
2022-04-02 17:27:53 +08:00
Jinzhu
9144969c83 Allow using tag to disable auto create/update time 2022-04-02 17:17:47 +08:00
ZhangShenao
f7b52bb649
unify db receiver name (#5215)
Co-authored-by: Shenao Zhang <shenao.zhang@shopee.com>
2022-04-01 08:35:16 +08:00
Goxiaoy
cd0315334b
fix: context missing in association (#5214) 2022-04-01 08:33:39 +08:00
ZhangShenao
8333844f71
fix variable shadowing (#5212)
Co-authored-by: Shenao Zhang <shenao.zhang@shopee.com>
2022-03-31 20:57:20 +08:00
Jinzhu
ea8509b777 Use defer to close rows to avoid scan panic leak rows 2022-03-29 18:48:32 +08:00
Jinzhu
9dd6ed9c65 Scan with Rows interface 2022-03-29 18:14:37 +08:00
dependabot[bot]
6c827ff2e3
chore(deps): bump actions/cache from 2 to 3 (#5196)
Bumps [actions/cache](https://github.com/actions/cache) from 2 to 3.
- [Release notes](https://github.com/actions/cache/releases)
- [Commits](https://github.com/actions/cache/compare/v2...v3)

---
updated-dependencies:
- dependency-name: actions/cache
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-03-28 19:55:05 +08:00
qqxhb
6d40a83432
Update README.md
add gorm gen
2022-03-24 16:30:14 +08:00
Cr
3d7019a7c2
fix: throw err if association model is missing primary key (#5187) 2022-03-24 09:34:06 +08:00
Jin
9a4d10be64
style: fix coding typo (#5184) 2022-03-24 09:31:58 +08:00
Jinzhu
f92e6747cb Handle field set value error 2022-03-23 17:24:25 +08:00
Jinzhu
a7b3b5956f Fix hooks order, close https://github.com/go-gorm/gorm.io/pull/519 2022-03-22 22:42:36 +08:00
Jinzhu
d66f37ad32 Add Go 1.18 2022-03-21 10:50:14 +08:00
Jin
2d5cb997ed
style: fix linter check for NamingStrategy and onConflictOption (#5174) 2022-03-20 09:02:45 +08:00
Jinzhu
0097b39a77 Should ignore error when parsing default value for time, close #5176 2022-03-20 08:55:08 +08:00
Jinzhu
540b47571a Fix update select clause with before/after expressions, close #5164 2022-03-18 20:57:33 +08:00
Cr
d402765f69
test: fix utils.AssertEqual (#5172) 2022-03-18 20:11:23 +08:00
ag9920
3c00980e01 fix: serializer use default valueOf in assignInterfacesToValue, close #5168
commit 58e1b2bffbc216f2862d040fb545a8a486e473b6
Author: Jinzhu <wosmvp@gmail.com>
Date:   Fri Mar 18 17:06:43 2022 +0800

    Refactor #5168

commit fb9233011d209174e8223e970f0f732412852908
Author: ag9920 <alexgong7@outlook.com>
Date:   Thu Mar 17 21:23:28 2022 +0800

    fix: serializer use default valueOf in assignInterfacesToValue
2022-03-18 17:12:17 +08:00
Jinzhu
e6f7da0e0d Support Variable Relation 2022-03-18 14:30:30 +08:00
chenrui
5431da8caf fix: preload panic when model and dest different close #5130
commit e8307b5ef5273519a32cd8e4fd29250d1c277f6e
Author: Jinzhu <wosmvp@gmail.com>
Date:   Fri Mar 18 13:37:22 2022 +0800

    Refactor #5130

commit 40cbba49f374c9bae54f80daee16697ae45e905b
Author: chenrui <chenrui@jingdaka.com>
Date:   Sat Mar 5 17:36:56 2022 +0800

    test: fix test fail

commit 66d3f078291102a30532b6a9d97c757228a9b543
Author: chenrui <chenrui@jingdaka.com>
Date:   Sat Mar 5 17:29:09 2022 +0800

    test: drop table and auto migrate

commit 7cbf019a930019476a97ac7ac0f5fc186e8d5b42
Author: chenrui <chenrui@jingdaka.com>
Date:   Sat Mar 5 15:27:45 2022 +0800

    fix: preload panic when model and dest different
2022-03-18 13:38:46 +08:00
chenrui
c2e36ebe62 fix: soft delete for join, close #5132
commit a83023bdfc0dc6eaccc6704b64ff6436c2fe7725
Author: Jinzhu <wosmvp@gmail.com>
Date:   Fri Mar 18 01:05:25 2022 +0800

    Refactor #5132

commit 8559f51102c01be6c19913c0bc3a5771721ff1f5
Author: chenrui <chenrui@jingdaka.com>
Date:   Mon Mar 7 20:33:12 2022 +0800

    fix: should add deleted_at exprs for every joins

commit 2b7a1bdcf3eff9d23253173d21e73c1f056f9be4
Author: chenrui <chenrui@jingdaka.com>
Date:   Mon Mar 7 14:46:48 2022 +0800

    test: move debug flag

commit ce13a2a7bc50d2c23678806acf65dbd589827c77
Author: chenrui <chenrui@jingdaka.com>
Date:   Mon Mar 7 14:39:56 2022 +0800

    fix: soft delete for join.on
2022-03-18 01:09:20 +08:00
chenrui
9b9ae325bb fix: circular reference save, close #5140
commit 2ac099a37ac7bd74f0a98a6fdc42cc8527404144
Author: Jinzhu <wosmvp@gmail.com>
Date:   Thu Mar 17 23:49:21 2022 +0800

    Refactor #5140

commit 6e3ca2d1aa09943dcfb5d9a4b93bea28212f71be
Author: a631807682 <631807682@qq.com>
Date:   Sun Mar 13 12:52:08 2022 +0800

    test: add test for LoadOrStoreVisitMap

commit 9d5c68e41000fd15dea124797dd5f2656bf6b304
Author: chenrui <chenrui@jingdaka.com>
Date:   Thu Mar 10 20:33:47 2022 +0800

    chore: add more comment

commit bfffefb179c883389b72bef8f04469c0a8418043
Author: chenrui <chenrui@jingdaka.com>
Date:   Thu Mar 10 20:28:48 2022 +0800

    fix: should check values has been saved instead of rel.Name

commit e55cdfa4b3fbcf8b80baf009e8ddb2e40d471494
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:48:01 2022 +0800

    chore: go lint

commit fe4715c5bd4ac28950c97dded9848710d8becb88
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:27:24 2022 +0800

    chore: add test comment

commit 326862f3f8980482a09d7d1a7f4d1011bb8a7c59
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:22:33 2022 +0800

    fix: circular reference save
2022-03-17 23:53:31 +08:00
Mikhail Faraponov
2990790fbc
Use WriteByte for single byte operations (#5167)
Co-authored-by: Mikhail Faraponov <mikefaraponov@Mikhails-MacBook-Pro.local>
2022-03-17 22:54:30 +08:00
Hasan
f3e2da5ba3 Added offset when scanning the result back to struct, close #5143
commit 9a2058164d44c98d7b586b87bed1757f89d6fad7
Author: Jinzhu <wosmvp@gmail.com>
Date:   Thu Mar 17 22:34:19 2022 +0800

    Refactor #5143

commit c259de21768936428c9d89f7b31afb95b8acb36a
Author: Hasan <mr.k779@outlook.com>
Date:   Mon Mar 14 20:04:01 2022 +0545

    Update scan_test.go

commit 09f127b49151a52fbb8b354a03e6610d4f70262f
Author: Hasan <mr.k779@outlook.com>
Date:   Mon Mar 14 19:23:47 2022 +0545

    Added test for scanning embedded data into structs

commit aeaca493cf412def7813d36fd6a68acc832bf79f
Author: Hasan <mr.k779@outlook.com>
Date:   Tue Mar 8 04:08:16 2022 +0600

    Added offset when scanning the result back to struct
2022-03-17 22:52:40 +08:00
Jinzhu
63ac66b569 Support default tag for time.Time 2022-03-17 11:34:27 +08:00
Jinzhu
6befa0c947 Refactor preload error check 2022-03-17 11:22:25 +08:00
labulakalia
61b4c31236
fix: when index name is "type", parseFieldIndexes will set index TYPE to "TYPE" (#5155)
* fix: when index name is "type", parseFieldIndexes will set index TYPE to "TYPE"

* check TYPE empty
2022-03-14 21:47:59 +08:00
dependabot[bot]
f961bf1c14
chore(deps): bump actions/checkout from 2 to 3 (#5133)
Bumps [actions/checkout](https://github.com/actions/checkout) from 2 to 3.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v2...v3)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-03-12 22:28:18 +08:00
Jason Lee
b566ed7913
Merge pull request #5125 from CaoManhDat/master
ToSQL should enable SkipDefaultTransaction by default
2022-03-04 11:37:37 +08:00
Cao Manh Dat
29a8557384 ToSQL should enable SkipDefaultTransaction by default 2022-03-03 09:18:01 +07:00
Jinzhu
4e523499d1 Refactor Tx interface 2022-03-01 16:59:50 +08:00
lianghuan
996b96e812 Add TxConnPoolBeginner and Tx interface 2022-03-01 16:59:50 +08:00
Jinzhu
e2e802b837 Refactor Scan 2022-02-28 13:00:30 +08:00
Jinzhu
43a72b369e Refactor Scan 2022-02-28 00:16:57 +08:00
Jinzhu
530b0a12b4 Add fast path for ValueOf, ReflectValueOf 2022-02-27 22:16:31 +08:00
Jinzhu
4d14ac39ff Merge branch 'a631807682-i5091' 2022-02-27 09:12:35 +08:00
Jinzhu
68bb5379d9 Refactor scan into struct 2022-02-27 09:09:29 +08:00
Jinzhu
f2edda50e1 Refactor check missing where condition 2022-02-27 08:40:15 +08:00
chenrui
397b583b8e fix: query scanner in single column 2022-02-25 22:38:48 +08:00
Jinzhu
6a18a15c93 Refactor check missing where condition 2022-02-25 10:48:23 +08:00
jing1
3741f258d0
feat: support gob serialize (#5108) 2022-02-24 10:21:27 +08:00
Michael Nussbaum
45ef1da7e4
Fix naming longer than 64 chars with dots in table (#5045)
Ensures that foreign key relationships and indexes are given
syntactically valid names when their name length exceeds 64 characters
and they contain dot characters within the name. This is most often
relevant when a Postgres table name is fully qualified by including its schema
as part of its name.
2022-02-24 10:10:20 +08:00
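A minimal sketch of the situation the commit above describes: a model whose table name is schema-qualified (contains a dot), where generated foreign-key and index names can exceed 64 characters. The model, schema, and table names are illustrative, not taken from the commit.

```go
package example

// AccountHistoryRecord is an illustrative model whose table name is fully
// qualified with its Postgres schema; long generated index/foreign-key names
// derived from it must still be syntactically valid.
type AccountHistoryRecord struct {
	ID     uint
	UserID uint `gorm:"index"`
}

// TableName implements GORM's Tabler interface with a dotted, schema-qualified name.
func (AccountHistoryRecord) TableName() string {
	return "billing.account_history_records"
}
```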
Jinzhu
b1201fce4e Fix update with customized time type, close #5101 2022-02-23 17:48:26 +08:00
Qt
7837fb6fa0
fix typo in TxCommitter interface comment & improve CheckTruth, check val empty first (#5094)
* fix typo in TxCommitter interface comment

* improve CheckTruth, check val empty first
2022-02-20 21:19:15 +08:00
codingxh
664c5fb767
strings.Replace -> strings.ReplaceAll (#5095)
Co-authored-by: huquan<xxhh_quan_g@163.com>
2022-02-20 19:55:04 +08:00
Gilad Weiss
f3547e00cc
Inherit clone flag (NewDB) on transaction creation (#5012)
* Inherit clone flag (NewDB) on transaction creation

I find it very reassuring to know that after a finisher API, I get a clean db object for my next queries.
If you look at the example in https://gorm.io/docs you’d see many queries running one after the other, but in reality they wouldn’t work as they are portrayed, and that’s because in default mode NewDB is false, which makes all the clauses stay even after a finisher API.

My solution is just to have the value of the clone flag in the “parent” db object be injected into its children transactions.

* Fix typo
2022-02-20 08:33:12 +08:00
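A minimal sketch, assuming an illustrative users table and active column, of the session behaviour the description above refers to: with NewDB set, conditions added before one finisher do not leak into the next query.

```go
package example

import "gorm.io/gorm"

// listActiveThenAll runs two queries from the same session; because NewDB is
// true, the Where clause of the first query is not carried into the second.
func listActiveThenAll(db *gorm.DB) (active, all []map[string]interface{}, err error) {
	tx := db.Session(&gorm.Session{NewDB: true})
	if err = tx.Table("users").Where("active = ?", true).Find(&active).Error; err != nil {
		return
	}
	// Starts from a clean statement: no ORDER BY/WHERE carried over.
	err = tx.Table("users").Find(&all).Error
	return
}
```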
sammyrnycreal
5edc78116f Fixed the use of "or" to be " OR ", to account for words that contain "or" or "and" (e.g., 'score', 'band') in a SQL statement as the name of a field. 2022-02-20 08:22:21 +08:00
Jinzhu
48ced75d1d Improve support for AutoMigrate 2022-02-19 23:42:20 +08:00
Jinzhu
e0b4e0ec8f Update auto stale days 2022-02-19 17:11:23 +08:00
Jinzhu
0af95f509a Enhance migrator Columntype interface (#5088)
* Update Migrator ColumnType interface

* Update MigrateColumn Test

* Upgrade test drivers

* Fix typo
2022-02-19 17:02:53 +08:00
Jinzhu
39d84cba5f Add serializer support (#5078)
* Update context

* Update GormFieldValuer

* Add Serializer

* Add Serializer Interface

* Refactor gorm field

* Refactor setter, valuer

* Add sync.Pool

* Fix test

* Add pool manager

* Fix pool manager

* Add poolInitializer

* Add Serializer Scan support

* Add Serializer Value method

* Add serializer test

* Finish Serializer

* Fix JSONSerializer for postgres

* Fix JSONSerializer for sqlserver

* Test serializer tag

* Add unixtime serializer

* Update go.mod
2022-02-19 17:02:53 +08:00
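A minimal sketch of the serializer tags introduced above; the model and field names are illustrative assumptions, while the serializer:json and serializer:unixtime tags are the ones named in the commit.

```go
package example

// UserWithSettings shows the two serializer tags on illustrative fields.
type UserWithSettings struct {
	ID          uint
	Settings    map[string]string `gorm:"serializer:json"`               // serialized to JSON text in the column
	CreatedTime int64             `gorm:"serializer:unixtime;type:time"` // int stored as datetime in the database
}
```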
li-jin-gou
19ac396a22
fix: isPrintable incorrect (#5076)
* fix: isPrintable incorrect

* fix: isPrintable incorrect

* style: use ReplaceAll instead of Replace
2022-02-15 20:32:03 +08:00
Jinzhu
a0aceeb33e Migrator AlterColumn with full data type 2022-02-10 10:40:48 +08:00
Jinzhu
df2365057b Remove unnecessary switch case 2022-02-09 17:23:16 +08:00
Jinzhu
4eeb839cea Better support Stringer when explain SQL 2022-02-09 15:17:25 +08:00
li-jin-gou
d22215129e
fix: replacing empty table name results in panic (#5048)
* fix: replacing empty name results in panic

* fix: replacing empty table name results in panic
2022-02-08 17:06:10 +08:00
Jinzhu
416c4d0653 Test query with Or and soft delete 2022-02-08 16:31:24 +08:00
Jason Lee
93b1a6f7ea
Merge pull request #5043 from Saurabh-Thakre/patch-2 2022-02-04 22:31:21 +08:00
Saurabh Thakre
581a879bf1
Added comments to existing methods
Added two comments to describe FirstOrInit and FirstOrCreate methods.
2022-01-31 17:26:28 +05:30
Jinzhu
f19b84d104 Fix github action 2022-01-30 22:46:41 +08:00
Jinzhu
8d293d44dd Fix docker-compose test env for Mac M1 2022-01-30 22:05:38 +08:00
Ning
8c3673286d
preload not allowed before count (#5023)
Co-authored-by: ningfei <accelerator314@outlook.com>
2022-01-30 18:17:06 +08:00
li-jin-gou
c0bea447b9
fix: omit not working when using join (#5034) 2022-01-28 22:16:42 +08:00
Jinzhu
98c4b78e4d Add Session Initialized option 2022-01-28 19:26:10 +08:00
Jinzhu
cec0d32aec Support use clause.Expression as argument 2022-01-28 18:48:32 +08:00
dependabot[bot]
e5894ca449
chore(deps): bump gorm.io/driver/mysql from 1.2.1 to 1.2.3 in /tests (#4987)
Bumps [gorm.io/driver/mysql](https://github.com/go-gorm/mysql) from 1.2.1 to 1.2.3.
- [Release notes](https://github.com/go-gorm/mysql/releases)
- [Commits](https://github.com/go-gorm/mysql/compare/v1.2.1...v1.2.3)

---
updated-dependencies:
- dependency-name: gorm.io/driver/mysql
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-01-12 13:11:57 +08:00
piyongcai
a0d6ff1fea
time.Time, []byte type add alias support. (rebase master) (#4992)
* time.Time, []byte type add alias support

* reformat
2022-01-12 13:11:40 +08:00
Jinzhu
eae73624ad Fix returning "failed to begin transaction" error when failing to start a transaction 2022-01-07 10:04:35 +08:00
kinggo
0df42e9afc
feat: add Connection to execute multiple commands in a single connection (#4982) 2022-01-07 09:49:56 +08:00
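A minimal sketch of the Connection helper added above: every statement in the callback runs on the same underlying connection, which matters for connection-scoped SQL. The MySQL user-variable statement and the users table are illustrative assumptions.

```go
package example

import "gorm.io/gorm"

// withSameConnection sets a connection-scoped variable and then queries on
// the same connection.
func withSameConnection(db *gorm.DB) (count int64, err error) {
	err = db.Connection(func(tx *gorm.DB) error {
		if err := tx.Exec("SET @counter := 0").Error; err != nil {
			return err
		}
		return tx.Table("users").Count(&count).Error
	})
	return
}
```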
halfcrazy
f757b8fdc9
fix: auto migration column order unpredictable (#4980) 2022-01-06 18:55:20 +08:00
kinggo
b47cf57f5e
ci: add gofumpt check in reviewdog (#4973) 2022-01-06 15:02:53 +08:00
kinggo
4dd2647967
Merge pull request #4964 from liweitingwt/f_test_error
improve the error handling in tests_test
2021-12-31 14:25:04 +08:00
kinggo
8dde09e0be
fix: incorrect SQL generated when using soft_delete and only one OR (#4969)
* fix: incorrect SQL generated when using soft_delete and only one OR
2021-12-30 11:47:14 +08:00
liweiting.wt
b9667cb747 fix: fix the error handling in tests_test 2021-12-28 18:22:17 +08:00
Emre Güllü
2c3fc2db28
Fix: Where clauses with named arguments may cause generation of unintended queries (#4937) 2021-12-21 19:50:00 +08:00
liweitingwt
24026bf1fe
modify unscoped judge (#4929)
* modify unscoped judge

* modify unscoped judge

Co-authored-by: liweiting <liweiting1995@gmail.com>
2021-12-16 10:41:34 +08:00
Jinzhu
adf8f70f06 Upgrade go.mod 2021-12-10 17:50:19 +08:00
piyongcai
380cc64ff5
fix type alias AutoMigrate bug(Add Test Case) (#4888)
* fix type alias AutoMigrate bug, e.g.

```go
package main

import (
	"log"

	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)

type IDer interface{ GetID() int64 }

// ID will add some method to implement some interface, e.g. GetID
type ID int64

func (z ID) GetID() int64 { return int64(z) }

type Test struct {
	ID
	Code string `gorm:"size:50"`
	Name string `gorm:"size:50"`
}

func main() {
	db, err := gorm.Open(postgres.New(postgres.Config{
		DSN:                  `dsn`,
		PreferSimpleProtocol: false,
	}), &gorm.Config{
		Logger:                 logger.Default.LogMode(logger.Info),
		SkipDefaultTransaction: true,
	})
	if err != nil {
		log.Fatal(err)
	}

	if err = db.AutoMigrate(&Test{}); err != nil {
		// invalid embedded struct for Test's field ID, should be struct, but got main.ID
		log.Fatal(err)
	}
}
```

* Add type alias test.

* try to fix golangci-lint
2021-12-10 17:45:36 +08:00
Matthieu MOREL
2a578d767f
Use Golangci configuration file (#4896) 2021-12-10 17:44:11 +08:00
kinggo
e5bdd610c3
fix: save not using soft_delete (#4897)
* fix: Save not use soft_delete

* fix: save not use soft_delete

* fix: save not use soft_delete

* fix: save not use soft_delete

Co-authored-by: kinggo <>
2021-12-08 13:58:06 +08:00
Jinzhu
300a23fc31 Check rows.Close error, close #4891 2021-12-02 10:39:24 +08:00
Jinzhu
8627634959 Fix create associations with zero primary key, close #4890 2021-12-02 10:20:16 +08:00
Jinzhu
3a3b82263a Fix auto migration always altering table, close #4198 2021-11-29 20:24:16 +08:00
kinggo
d8a710cba2
fix: count() when using group by and only one record is found (#4885)
Co-authored-by: 李龙 <lilong.21@bytedance.com>
2021-11-29 20:14:23 +08:00
Jinzhu
27e2753c9d Fix create duplicated value when updating nested has many relationship, close #4796 2021-11-29 18:43:39 +08:00
Jinzhu
45e804dd3f Fix call valuer interface when using nil value 2021-11-29 16:19:11 +08:00
Jinzhu
92d5a959a0 Fix tests 2021-11-29 15:16:57 +08:00
Jinzhu
270e38c518 Fix duplicated error when Scan, close #4525 2021-11-29 14:23:10 +08:00
Jinzhu
e1b4c066a8 Fix FullSaveAssociations, close #4874 2021-11-29 11:02:44 +08:00
heige
9d5f315b6d
feat: go code style adjust and optimize code for callbacks package (#4861)
* feat: go code style adjust and optimize code for callbacks package

* Update scan.go
2021-11-29 09:33:20 +08:00
Jinzhu
b8f33a42a4
Add unused argument (#4871)
* Append unused argument to gorm statement
2021-11-23 17:11:52 +08:00
dependabot[bot]
cff7845e58
Bump gorm.io/driver/mysql from 1.1.3 to 1.2.0 in /tests (#4856)
Bumps [gorm.io/driver/mysql](https://github.com/go-gorm/mysql) from 1.1.3 to 1.2.0.
- [Release notes](https://github.com/go-gorm/mysql/releases)
- [Commits](https://github.com/go-gorm/mysql/compare/v1.1.3...v1.2.0)

---
updated-dependencies:
- dependency-name: gorm.io/driver/mysql
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-11-23 11:40:18 +08:00
dependabot[bot]
0f8e861597
Bump github.com/jinzhu/now from 1.1.2 to 1.1.3 in /tests (#4866)
Bumps [github.com/jinzhu/now](https://github.com/jinzhu/now) from 1.1.2 to 1.1.3.
- [Release notes](https://github.com/jinzhu/now/releases)
- [Commits](https://github.com/jinzhu/now/compare/v1.1.2...v1.1.3)

---
updated-dependencies:
- dependency-name: github.com/jinzhu/now
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-11-23 11:40:03 +08:00
dependabot[bot]
11d5c346ae
Bump github.com/jinzhu/now from 1.1.2 to 1.1.3 (#4865)
Bumps [github.com/jinzhu/now](https://github.com/jinzhu/now) from 1.1.2 to 1.1.3.
- [Release notes](https://github.com/jinzhu/now/releases)
- [Commits](https://github.com/jinzhu/now/compare/v1.1.2...v1.1.3)

---
updated-dependencies:
- dependency-name: github.com/jinzhu/now
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-11-23 11:39:42 +08:00
dino.ma
5e64ac7de9
feat(migrator,migrator/migrator.go,tests/migrate_test.go) : Get multiple data tables for migrator. (#4841)
* feat(migrator,migrator/migrator.go,tests/migrate_test.go) : Get multiple data tables for migrator.

* feat(migrator.go and migrator/migrator.go) : remove Table Struct replace with []string

* fix(migrator)  : Return all data tables

* Update migrator.go

* fix(migrator/migrator.go):remove var sql

* feat(migrate_test.go/go.mod):update sqlserver,sqlite,postgres,pq version and add getTables test

* fix(migrate_test.go):change GetTables Method Test,use intersection

Co-authored-by: dino.ma <mashengjie03@baidu.com>
2021-11-13 14:03:33 +08:00
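A minimal sketch of the GetTables method the entry above adds to the migrator; nothing beyond the call itself is taken from the commit.

```go
package example

import "gorm.io/gorm"

// tableNames returns the names of all tables in the current database.
func tableNames(db *gorm.DB) ([]string, error) {
	return db.Migrator().GetTables()
}
```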
riverchu
33bc56cbb5 feat(update): update when a SET clause exists 2021-11-09 19:55:47 +08:00
Jinzhu
5daa413f41 Stabilize schema.FieldsWithDefaultDBValue's order, close #4643 2021-11-08 20:20:55 +08:00
Jinzhu
ca7accdbf6 Fix preload all associations with inline conditions, close #4836 2021-11-08 19:47:10 +08:00
Jinzhu
b23c3b290e Don't query with primary key when using Save 2021-11-08 18:49:59 +08:00
Mayank Govilla
d9d5c4dce0
Fix self-referential belongs to constraint (#4801)
* create tests for self-ref has one migration

* add relation equality check to avoid skipping self-referential schemas

* remove drop table error check
2021-11-08 09:47:29 +08:00
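A minimal sketch, with illustrative model and field names, of the kind of self-referential belongs-to relation the fix above concerns.

```go
package example

// Employee refers to another Employee through ManagerID, a self-referential
// belongs-to relation.
type Employee struct {
	ID        uint
	Name      string
	ManagerID *uint
	Manager   *Employee `gorm:"foreignKey:ManagerID"`
}
```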
heige
4c8810a848
Refactor if logic (#4683)
* adjust code for preload

* adjust code for Create
2021-11-04 13:45:44 +08:00
kinggo
c170af11e9
fix connections leak (#4826)
* fix connections leak

* fix connections leak

* fix connections leak

* fix connections leak

Co-authored-by: 李龙 <lilong.21@bytedance.com>
2021-11-03 13:39:52 +08:00
dependabot[bot]
7b927900e9
Bump gorm.io/driver/sqlserver from 1.1.2 to 1.2.0 in /tests (#4820)
Bumps [gorm.io/driver/sqlserver](https://github.com/go-gorm/sqlserver) from 1.1.2 to 1.2.0.
- [Release notes](https://github.com/go-gorm/sqlserver/releases)
- [Commits](https://github.com/go-gorm/sqlserver/compare/v1.1.2...v1.2.0)

---
updated-dependencies:
- dependency-name: gorm.io/driver/sqlserver
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-11-01 17:09:08 +08:00
Jason Lee
8de266b4a7
Add ToSQL support to generate SQL string. (#4787)
* Add db.ToSQL method for generate SQL string.

* Improve sql builder test for all dialects.

Improve assertEqualSQL test helper for ignore quotes in SQL.
2021-11-01 17:08:54 +08:00
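A minimal sketch of the ToSQL helper added above: the callback is run in dry-run mode and the generated SQL is returned as a string. The User model and the query are illustrative assumptions.

```go
package example

import "gorm.io/gorm"

// User is a placeholder model for this sketch.
type User struct {
	ID   uint
	Name string
}

// findUserSQL returns the SQL that the query inside the callback would generate.
func findUserSQL(db *gorm.DB) string {
	return db.ToSQL(func(tx *gorm.DB) *gorm.DB {
		return tx.Model(&User{}).Where("id = ?", 100).Find(&[]User{})
	})
}
```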
Jinzhu
9635d25150 Fix query with uninitialized map 2021-11-01 13:00:52 +08:00
Jinzhu
9f533950a2 Add dest value if current size equal zero 2021-10-28 17:12:31 +08:00
Jinzhu
e953880d19 Add returning tests 2021-10-28 09:17:33 +08:00
Jinzhu
835d7bde59 Add returning support to delete 2021-10-28 07:56:55 +08:00
Jinzhu
af3fbdc2fc Improve returning support 2021-10-26 22:40:14 +08:00
Jason Lee
d3211908a0
Refactor ParseWithSchemaTable method and improve test. (#4789)
* Refactor ParseWithSchemaTable method and improve test.

* Fix schema.ParseWithSchemaTable method for only use schemaTable in migrator and improve test.

* Rename `schemaTable` to `specialTableName` for clearly argument.
2021-10-25 11:26:44 +08:00
Jason Lee
38e55f1117
Merge pull request #4773 from xwjdsh/master
fix: automigrate error caused by indexes while using dynamic table name
2021-10-19 10:11:11 +08:00
Wendell Sun
a3bd9c3ea2 fix: automigrate error caused by indexes while using dynamic table name 2021-10-19 09:59:57 +08:00
Jinzhu
9a5ba37604 Merge branch 'hashicorp-jimlambrt-null-without-ptrs' 2021-10-13 21:02:03 +08:00
Jinzhu
b27095e8a1 Refactor Convert SQL null values to zero values for model fields which are not pointers #4710 2021-10-13 21:01:36 +08:00
Jim
19cf645dbd feat: Convert SQL nulls to zero values (ConvertNullToZeroValues)
Makes it the default behavior to convert SQL null values to zero
values for model fields which are not pointers.
2021-10-13 08:11:22 -04:00
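A minimal sketch of the behaviour described above, with illustrative field names: because the fields are not pointers, SQL NULLs scan into Go zero values instead of raising an error.

```go
package example

// Profile has non-pointer fields, so NULL columns become zero values on scan.
type Profile struct {
	ID       uint
	Nickname string // NULL becomes ""
	Age      int    // NULL becomes 0
}
```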
kinggo
696092e287
update tests' go.mod and tests_all.sh (#4774) 2021-10-13 14:41:33 +08:00
kinggo
ec58e3319f
fixed: panic when creating value from nil struct pointer. (#4771)
* fixed: create nil pointer

* fixed: panic when creating value from nil struct pointer.
2021-10-12 21:19:08 +08:00
kinggo
418c60c83c
fixed: clauseSelect.Columns missed when using Join and executing multiple queries. (#4757) 2021-10-09 16:55:45 +08:00
Jinzhu
bfda75d099 Support specify select/omit columns with table 2021-10-09 10:42:41 +08:00
Jinzhu
6312d86c54 Support specify select/omit columns with table 2021-10-08 17:51:27 +08:00
Jinzhu
d4c838c1ce Upgrade sqlite driver 2021-10-08 17:31:58 +08:00
kinggo
b46e2afc4a
fix: update misses where condition when primary key uses "<-:create" tag (#4738)
* fix: update misses where condition

* fix: rename test case
2021-10-08 13:47:01 +08:00
heige
e3fc49a694
feat: adjust PreparedStmtDB unlock location and BuildCondition if logic (#4681) 2021-10-08 11:16:58 +08:00
heige
c13f3011f9
feat: adjust SetupJoinTable func if..else code (#4680) 2021-10-08 11:05:50 +08:00
Paras Waykole
5d91ddac8c
fixed belongs_to & has_one reversed if field same (proper fix) (#4694)
* fixed belongs_to & has_one reversed if field same

* hasmany same foreign key bug fixed and test added

* belongsToSameForeignKey fixed and reverted old fix
2021-10-08 10:59:55 +08:00
dependabot[bot]
57d927d046
Bump gorm.io/driver/postgres from 1.1.1 to 1.1.2 in /tests (#4740)
Bumps [gorm.io/driver/postgres](https://github.com/go-gorm/postgres) from 1.1.1 to 1.1.2.
- [Release notes](https://github.com/go-gorm/postgres/releases)
- [Commits](https://github.com/go-gorm/postgres/compare/v1.1.1...v1.1.2)

---
updated-dependencies:
- dependency-name: gorm.io/driver/postgres
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-10-08 10:54:50 +08:00
s-takehana
0b6bd33934
Update tests.yml (#4741) 2021-10-08 10:51:53 +08:00
River
851fea0221
fix: QuoteTo not fully support raw mode (#4735)
* fix: QuoteTo not fully support raw mode

* fix: table alias without AS

* test: clause.Column/Table quote test

* fix: revert table alias quote
2021-09-29 14:02:35 +08:00
Jinzhu
c4a2e891da Fix Join condition with DB 2021-09-28 22:37:15 +08:00
Jinzhu
002bf78ea7 Fix Join condition with DB, close #4719 2021-09-28 21:43:31 +08:00
kinggo
6864a24150
fix: remove the tableName judgment in pluck (#4731) 2021-09-27 22:11:29 +08:00
Jim
5202529ea1
fix (clause/expression): Allow sql stmt terminator (#4693)
Allow the SQL stmt terminator ";" at the end of a named parameter.

Example: select * from table_name where name = @name;
2021-09-20 21:40:48 +08:00
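A minimal sketch of a raw query whose named parameter is immediately followed by the ";" terminator now tolerated by the change above; the users table and result type are illustrative assumptions.

```go
package example

import (
	"database/sql"

	"gorm.io/gorm"
)

// NamedUser is a placeholder result type for this sketch.
type NamedUser struct {
	ID   uint
	Name string
}

// findByName binds a named argument and ends the statement with ";".
func findByName(db *gorm.DB, name string) (NamedUser, error) {
	var u NamedUser
	err := db.Raw("SELECT * FROM users WHERE name = @name;", sql.Named("name", name)).Scan(&u).Error
	return u, err
}
```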
dependabot[bot]
199c8529b6
Bump gorm.io/driver/postgres from 1.1.0 to 1.1.1 in /tests (#4699)
Bumps [gorm.io/driver/postgres](https://github.com/go-gorm/postgres) from 1.1.0 to 1.1.1.
- [Release notes](https://github.com/go-gorm/postgres/releases)
- [Commits](https://github.com/go-gorm/postgres/compare/v1.1.0...v1.1.1)

---
updated-dependencies:
- dependency-name: gorm.io/driver/postgres
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-09-20 21:33:38 +08:00
dependabot[bot]
d67120a155
Bump gorm.io/driver/sqlite from 1.1.4 to 1.1.5 in /tests (#4701)
Bumps [gorm.io/driver/sqlite](https://github.com/go-gorm/sqlite) from 1.1.4 to 1.1.5.
- [Release notes](https://github.com/go-gorm/sqlite/releases)
- [Commits](https://github.com/go-gorm/sqlite/compare/v1.1.4...v1.1.5)

---
updated-dependencies:
- dependency-name: gorm.io/driver/sqlite
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-09-20 21:25:29 +08:00
Jinzhu
ab355336cb Fix scan with interface 2021-09-17 18:35:14 +08:00
Jinzhu
da16a8aac6 Update updated_at when upserting with Create OnConflict 2021-09-17 15:29:49 +08:00
Jinzhu
12bbde89e6 Fix Scan with interface 2021-09-17 14:04:19 +08:00
Jinzhu
61b018cb94 Fix count with selected * 2021-09-16 11:17:54 +08:00
Jinzhu
d41fb3acdc Refactor dummy driver QuoteTo method 2021-09-11 16:22:35 +08:00
Jinzhu
04f049c1da Add tests for RowsAffected 2021-09-09 11:22:55 +08:00
Jinzhu
a16db07945 Refactor Join ON 2021-09-07 21:21:44 +08:00
Jinzhu
ba16b2368f
Refactor update record (#4679) 2021-09-07 20:04:54 +08:00
Jinzhu
6c94b07e98 try to fix fatal error: concurrent map read and map write 2021-09-07 15:30:14 +08:00
Jinzhu
3b6a7c8aec Update sqlserver driver 2021-09-07 12:01:19 +08:00
Adrien Carreira
d047f854e6 PR Comments 2021-09-06 20:13:20 +08:00
Adrien Carreira
c301aeb524 Refactor for readability 2021-09-06 20:13:20 +08:00
Adrien Carreira
52cc438d07 JoinsOn unit test + use all primary keys 2021-09-06 20:13:20 +08:00
Adrien Carreira
895c1178a0 Proposal, Add Specific on for Joins queries 2021-09-06 20:13:20 +08:00
riverchu
eaa63d15e7 feat: copy dest fields to model struct 2021-09-06 20:13:20 +08:00
riverchu
4581e8b590 test: update Save test 2021-09-06 20:13:20 +08:00
riverchu
c898622791 test: add testcase in TestSave 2021-09-06 20:13:20 +08:00
riverchu
1d9e563023 style: prepose error judgement 2021-09-06 20:13:20 +08:00
dependabot[bot]
a89d4d8fd5
Bump github.com/lib/pq from 1.10.2 to 1.10.3 in /tests (#4676)
Bumps [github.com/lib/pq](https://github.com/lib/pq) from 1.10.2 to 1.10.3.
- [Release notes](https://github.com/lib/pq/releases)
- [Commits](https://github.com/lib/pq/compare/v1.10.2...v1.10.3)

---
updated-dependencies:
- dependency-name: github.com/lib/pq
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-09-06 16:26:14 +08:00
dependabot[bot]
5f019f74bf
Bump gorm.io/gorm from 1.21.13 to 1.21.14 in /tests (#4655)
Bumps [gorm.io/gorm](https://github.com/go-gorm/gorm) from 1.21.13 to 1.21.14.
- [Release notes](https://github.com/go-gorm/gorm/releases)
- [Commits](https://github.com/go-gorm/gorm/compare/v1.21.13...v1.21.14)

---
updated-dependencies:
- dependency-name: gorm.io/gorm
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-09-03 17:47:50 +08:00
jxlwqq
15188cf409
Add Go 1.17 (#4666) 2021-09-03 17:47:32 +08:00
Jinzhu
3a8c250180 Refactor calc associations onConflictOption 2021-08-26 13:37:49 +08:00
zkqiang
74746211b8 Test update association with non-updatable 2021-08-26 13:37:49 +08:00
zkqiang
e81833fd11 Fix onConflict with non-updatable in associations 2021-08-26 13:37:49 +08:00
Jinzhu
f21e35f7c5 Fix table not supported error when using unexpected table name 2021-08-26 13:14:16 +08:00
dependabot[bot]
0934b10856
Bump gorm.io/driver/sqlserver from 1.0.7 to 1.0.8 in /tests (#4631)
Bumps [gorm.io/driver/sqlserver](https://github.com/go-gorm/sqlserver) from 1.0.7 to 1.0.8.
- [Release notes](https://github.com/go-gorm/sqlserver/releases)
- [Commits](https://github.com/go-gorm/sqlserver/compare/v1.0.7...v1.0.8)

---
updated-dependencies:
- dependency-name: gorm.io/driver/sqlserver
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-08-23 15:30:02 +08:00
Sec Cake
093694fbf2
Fix extra 'AND' when len(values) == 0 in IN.NegationBuild() (#4618) 2021-08-20 18:06:48 +08:00
dependabot[bot]
7a53d8e46b
Bump gorm.io/driver/mysql from 1.1.1 to 1.1.2 in /tests (#4615)
Bumps [gorm.io/driver/mysql](https://github.com/go-gorm/mysql) from 1.1.1 to 1.1.2.
- [Release notes](https://github.com/go-gorm/mysql/releases)
- [Commits](https://github.com/go-gorm/mysql/compare/v1.1.1...v1.1.2)

---
updated-dependencies:
- dependency-name: gorm.io/driver/mysql
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-08-20 17:52:56 +08:00
dependabot[bot]
e076e9e0fb
Bump gorm.io/gorm from 1.21.12 to 1.21.13 in /tests (#4616)
Bumps [gorm.io/gorm](https://github.com/go-gorm/gorm) from 1.21.12 to 1.21.13.
- [Release notes](https://github.com/go-gorm/gorm/releases)
- [Commits](https://github.com/go-gorm/gorm/compare/v1.21.12...v1.21.13)

---
updated-dependencies:
- dependency-name: gorm.io/gorm
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-08-20 17:52:48 +08:00
River
1bb0d8732d
feat: count accept db.table (#4626)
* feat: count accept `db`.`table`

* fix: logic fix
2021-08-20 17:37:21 +08:00
River
25f561a742
feat: QuoteTo accept clause.Expr (#4621)
* feat: QuoteTo accept clause.Expr

* test: update Expr build test
2021-08-19 14:33:18 +08:00
Jinzhu
2b2f6e77af Add SchemaName to NamingStrategy 2021-08-11 16:20:29 +08:00
Sungyun Hur
a83d25e25e
chore(logger): explicitly set config of Default Logger (#4605) 2021-08-11 11:49:46 +08:00
张凯强
21e85b89d6
Fix create with ignore migration (#4571) 2021-08-09 13:23:44 +08:00
SmallTianTian
82fe815303
fix: table couldn't be reentrant (#4556) 2021-08-09 13:20:22 +08:00
Matthieu MOREL
cbe72751ac
Update Dependencies (#4582)
* Create dependabot.yml

* Bump reviewdog/action-golangci-lint from 1 to 2 (#1)

Bumps [reviewdog/action-golangci-lint](https://github.com/reviewdog/action-golangci-lint) from 1 to 2.
- [Release notes](https://github.com/reviewdog/action-golangci-lint/releases)
- [Commits](https://github.com/reviewdog/action-golangci-lint/compare/v1...v2)

---
updated-dependencies:
- dependency-name: reviewdog/action-golangci-lint
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump actions/stale from 3.0.7 to 4 (#2)

Bumps [actions/stale](https://github.com/actions/stale) from 3.0.7 to 4.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v3.0.7...v4)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump gorm.io/gorm from 1.21.9 to 1.21.12 in /tests (#3)

Bumps [gorm.io/gorm](https://github.com/go-gorm/gorm) from 1.21.9 to 1.21.12.
- [Release notes](https://github.com/go-gorm/gorm/releases)
- [Commits](https://github.com/go-gorm/gorm/compare/v1.21.9...v1.21.12)

---
updated-dependencies:
- dependency-name: gorm.io/gorm
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump gorm.io/driver/mysql from 1.0.5 to 1.1.1 in /tests (#4)

Bumps [gorm.io/driver/mysql](https://github.com/go-gorm/mysql) from 1.0.5 to 1.1.1.
- [Release notes](https://github.com/go-gorm/mysql/releases)
- [Commits](https://github.com/go-gorm/mysql/compare/v1.0.5...v1.1.1)

---
updated-dependencies:
- dependency-name: gorm.io/driver/mysql
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump github.com/lib/pq from 1.6.0 to 1.10.2 in /tests (#5)

Bumps [github.com/lib/pq](https://github.com/lib/pq) from 1.6.0 to 1.10.2.
- [Release notes](https://github.com/lib/pq/releases)
- [Commits](https://github.com/lib/pq/compare/v1.6.0...v1.10.2)

---
updated-dependencies:
- dependency-name: github.com/lib/pq
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump github.com/google/uuid from 1.2.0 to 1.3.0 in /tests (#6)

Bumps [github.com/google/uuid](https://github.com/google/uuid) from 1.2.0 to 1.3.0.
- [Release notes](https://github.com/google/uuid/releases)
- [Commits](https://github.com/google/uuid/compare/v1.2.0...v1.3.0)

---
updated-dependencies:
- dependency-name: github.com/google/uuid
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-08-09 13:16:25 +08:00
Walter Scheper
a870486c4f
Do not emit ORDER BY for empty values (#4592)
This restores the behavior from gorm v1, where calling `DB.Order` with
an empty string, nil, or any unexpected type is a no-op.
2021-08-09 13:14:23 +08:00
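A minimal sketch of the restored behaviour described above: when the order value is an empty string, Order is a no-op and no ORDER BY clause is emitted. The users table is an illustrative assumption.

```go
package example

import "gorm.io/gorm"

// listUsers passes the caller-supplied ordering through; an empty orderBy
// produces SQL without any ORDER BY clause.
func listUsers(db *gorm.DB, orderBy string) ([]map[string]interface{}, error) {
	var results []map[string]interface{}
	err := db.Table("users").Order(orderBy).Find(&results).Error
	return results, err
}
```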
heige
9e5a4e30b4
Fix migrator GuessConstraintAndTable method for return value for *schema.Check (#4527) 2021-08-03 11:40:57 +08:00
heige
413fe587c6
Optimize migrator.go MigrateColumn and ColumnTypes func. (#4532) 2021-08-02 18:44:10 +08:00
daheige
7a49629fd1 optimize Parse func for fieldValue.Interface 2021-07-28 19:00:34 +08:00
daheige
41ac73b6a1 update comment for ConvertSliceOfMapToValuesForCreate func 2021-07-28 18:56:39 +08:00
Jason Lee
3e15478534
Merge pull request #4530 from daheige/optimize-setupValuerAndSetter
optimize setupValuerAndSetter func
2021-07-28 18:51:09 +08:00
heige
5115813c50
Fix preload fmt.Errorf formatter (#4531) 2021-07-28 18:50:08 +08:00
s-takehana
2202e99cbf
Fix create index with comments in MySQL (#4521)
* Fix create index with comments in MySQL

* Fix tests
2021-07-18 11:47:44 +08:00
daheige
a70254609d optimize setupValuerAndSetter func 2021-07-14 22:03:17 +08:00
Jinzhu
74752018dc Fix hang when closing a prepared statement 2021-07-14 18:35:10 +08:00
River
ac97aec513
New Comma Expression (#4524)
* Add new comma expression

* Add comma expression unit test
2021-07-14 15:51:24 +08:00
Jinzhu
d4f3c109d6 Fix OnConflict with one column, close #4370 2021-07-13 21:29:31 +08:00
Jinzhu
83530ec659 Fix delete order by clause when counting, close #4478 2021-07-13 21:17:43 +08:00
Jinzhu
52b72d7ef2 Add error explanations when preloading assocations w/o foreign fields, close #4356 2021-07-13 21:00:13 +08:00
Jinzhu
b13732c450 Fix invalid preload SQL when no data found, close #4443 2021-07-13 20:23:05 +08:00
Jinzhu
c73fe96cfd Fix scan into decimal.Decimal, close #4457 2021-07-13 19:59:31 +08:00
Jinzhu
b616d810eb Fix scan single value to custom type, close #4501 2021-07-13 19:29:10 +08:00
Jinzhu
76cd73cb82 Fix wiping out MySQL global variables from the query, close #4515 2021-07-13 18:48:43 +08:00
Jinzhu
2ec7043818 Respect update permission for OnConflict Create 2021-07-13 18:04:42 +08:00
Burak Demirpolat
0329b800b0
slightly better callback warning (#4495) 2021-07-13 16:38:44 +08:00
wangyuehong
80497f27a6
title foreign schema for many2many to avoid panic (#4496)
Co-authored-by: yuehong.wang <yuehong.wang@dena.jp>
2021-07-13 16:36:22 +08:00
shiyu7
16579e00c6
fix: fix race issue in prepare method (#4487) 2021-07-01 06:27:12 +08:00
Jason Lee
6d64e31965
Merge pull request #4479 from wuwenchi/master
Fix document for Pluck's usage #4473
2021-06-27 11:29:20 +08:00
wuwenchi
8bd8d38fe9 Fix Pluck's usage #4473 2021-06-26 21:23:16 +08:00
Jinzhu
8e67a08774 Fix Scopes with Row, close #4465 2021-06-18 15:38:20 +08:00
Jinzhu
3226937f68 Fix calc gormSourceDir, close #4456 2021-06-13 10:32:03 +08:00
kalle (jag)
25b9f2e26a
Added return names to logger.Interface.Trace (#4450) 2021-06-11 21:51:40 +08:00
Tony
a0bddccfe1
Use count(*) instead of count(1) to include NULL and non-NULL rows (SQL-92). (#4453) 2021-06-11 21:51:18 +08:00
Jinzhu
5b65b02805 Update tests go.mod 2021-06-11 16:00:26 +08:00
Jinzhu
e425ed6f6a Update tests go.mod 2021-06-10 20:26:21 +08:00
heige
50e85e14d4
Code optimize (#4415)
* optimize gormSourceDir replace

* fmt.Errorf adjust and Optimize for-break

* strings trim

* feat: avoid using the same name field and if..else optimization adjustment

* optimization callbacks/create.go Create func if...else logic

* fix: callbacks/create.go Create func

* fix FileWithLineNum func and add gormSourceDir unit test

* remove debug print and utils_filenum_test.go
2021-06-10 10:21:28 +08:00
liamrfell
00b252559f
Fix: FirstOrCreate slice out of bounds error when using 'Assigns' (#4436)
Co-authored-by: Liam Fell <liam@lot.to>
2021-06-07 10:39:24 +08:00
Vitaliy Shein
dd8bf88eb9
add Target where clause for on conflict (#4442)
Co-authored-by: Vitaliy Shein <vitaliy.shein@thebricks.com>
2021-06-07 10:39:00 +08:00
s-takehana
cf079b8b7d
Update version in tests.yml (#4432) 2021-06-02 09:58:22 +08:00
Jinzhu
810058cd55 Fix soft delete with Update 2021-06-01 18:34:38 +08:00
Jinzhu
9abac96546 Fix Eq, Neq support slice of data 2021-05-31 17:21:27 +08:00
Jinzhu
14e96080d8 Eq, Neq support slice of data 2021-05-31 15:25:38 +08:00
heyanfu
363f9b7863
golint standard (#4421) 2021-05-31 10:08:06 +08:00
Ikko Ashimine
bcf2b385a4
Fix typo in associations_test.go (#4407)
occured -> occurred
2021-05-27 17:40:28 +08:00
Brenda Wallace
ac722c16f9
Small grammar fix in error message (#4406) 2021-05-24 10:23:34 +08:00
Jinzhu
ea1bce3771 Only check whether struct value can be addressed or not 2021-05-23 11:21:56 +08:00
Paras Waykole
79f427d862
fixed has_many stopped working if field names are identical (#4387)
* fixed belongs_to & has_one reversed if field same

* hasmany same foreign key bug fixed and test added
2021-05-19 16:05:29 +08:00
Atreya
cf93b16730
Fix ErrInvalidTransaction error message (#4380) 2021-05-17 15:53:48 +08:00
Jinzhu
92c3ba9dcc Fix create new db sessions in scopes 2021-05-17 15:36:07 +08:00
Chen Quan
a480bd8545
Update Optimize schema (#4364) 2021-05-10 09:51:50 +08:00
Jinzhu
6b7abc54a2 Fix tests 2021-05-06 13:06:31 +08:00
Jinzhu
2aca96d147 test ignore migration, close #4314, #4315 2021-05-05 08:26:48 +08:00
我的我的
3f359eab9b
slim trace if depth (#4346)
Co-authored-by: gogs <guzzsek@gmail.com>
2021-05-05 08:14:40 +08:00
Paras Waykole
8f7f3ad315
fixed belongs_to & has_one reversed if field same (#4343) 2021-05-05 07:57:54 +08:00
Jinzhu
70e93e73d8 Check if data type is copyable before changing reference field's type 2021-04-30 16:35:55 +08:00
Karolos Lykos
f0d0bbbc10
Added missing white space (#4330)
* Added missing white space

* Added missing white space

* Added missing white space
2021-04-29 07:15:37 +08:00
Jinzhu
6951be0284 Allow customize clauses 2021-04-28 17:19:30 +08:00
Jinzhu
82cb4ebfe2 Fix overwrite Statement in scopes 2021-04-22 13:12:20 +08:00
Sky34gl3
a855fe6402
Fixed naming longer than 64 characters (#4310)
Co-authored-by: Mickael MAUGER <mickael.mauger@almerys.com>
2021-04-22 13:11:19 +08:00
Jinzhu
d327926425 Check ReflectValue.CanAddr before set field value 2021-04-19 21:37:38 +08:00
Chris Faulkner
15a46bc042
Fix some typos (#4294) 2021-04-19 21:03:39 +08:00
Jinzhu
7701c88507 Assign transaction error to db 2021-04-16 19:27:23 +08:00
Jinzhu
d483ffa45c Fix Preload with nil pointer 2021-04-15 10:37:05 +08:00
heige
74e7a9ca07
Optimize reflect value length and method (#4280)
* Respect ignore migration when add column (#4276)

continue https://github.com/go-gorm/gorm/pull/4028

* feat: Optimal value type acquisition for v (#4278)

* feat: optimize reflect value length and value

* feat: optimize ConvertSliceOfMapToValuesForCreate method

Co-authored-by: yrong1997 <yrong1997@gmail.com>
2021-04-14 13:00:54 +08:00
heige
5555b010dc
feat: Optimal value type acquisition for v (#4278) 2021-04-13 09:41:30 +08:00
yrong1997
d7911300f8
Respect ignore migration when add column (#4276)
continue https://github.com/go-gorm/gorm/pull/4028
2021-04-13 09:39:43 +08:00
Jinzhu
d278ca49ef sort GORM options before apply 2021-04-09 11:43:24 +08:00
Jinzhu
ad53074f1d Pass db error to new instance 2021-04-09 11:07:14 +08:00
Jinzhu
f3bdfa8261 Add IgnoreRecordNotFoundError option for logger 2021-04-09 10:21:01 +08:00
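A minimal sketch of constructing a logger with the option added above; the writer and the other settings are illustrative assumptions.

```go
package example

import (
	"log"
	"os"
	"time"

	"gorm.io/gorm/logger"
)

// newQuietLogger builds a logger that does not report ErrRecordNotFound as an error.
func newQuietLogger() logger.Interface {
	return logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{
		SlowThreshold:             200 * time.Millisecond,
		LogLevel:                  logger.Warn,
		IgnoreRecordNotFoundError: true,
	})
}
```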
Jinzhu
673053f56a Fix context cancel error, close #4259, close #4260 2021-04-09 10:21:01 +08:00
gavwu
8cfa9d98f0
Update field.go (#4228)
seems like the `if-else` branches do the same thing, so remove it
2021-04-02 09:56:38 +08:00
Jinzhu
33601dc72f Support Having w/o Group 2021-03-30 18:28:09 +08:00
Jinzhu
73c6d3e64e Add AfterInitialize error 2021-03-29 18:36:01 +08:00
Jinzhu
0eba7a9ed1 Fix apply option 2021-03-26 14:20:42 +08:00
Jinzhu
a8b72546c1 Fix get database connection for prepared stmt, close #4214 2021-03-25 10:17:57 +08:00
Jinzhu
26e0c6fb69 skip test sqlserver due to it will raise data race for invalid sql 2021-03-24 17:12:30 +08:00
Jinzhu
88078e48d0 Remove sqlite_windows test case 2021-03-24 16:56:41 +08:00
Jinzhu
8204d0ada2 Update tests script 2021-03-24 16:44:51 +08:00
Jinzhu
704e53a774 Call scopes before parsing model value, close #4209 2021-03-24 16:35:39 +08:00
Jinzhu
4d5cec8bdd Add golang 1.16 2021-03-24 14:22:36 +08:00
Genta Kamitani
26dd4c980a
Fix: FindInBatches ignores errors (#4203) 2021-03-22 14:11:07 +08:00
Jinzhu
8c92d9694a Fix to call Scopes with using Migrator 2021-03-19 16:34:51 +08:00
Jinzhu
a9fe025ef5 Add GetDBConnector interface 2021-03-19 15:55:38 +08:00
Jinzhu
220349ccf2 Fix omit associations, close #4161 2021-03-19 15:15:26 +08:00
Jinzhu
e85b73e5a5 Fix nested Scopes, close #4196 2021-03-19 13:44:25 +08:00
Jinzhu
a3d9bbfc36 build *clause.Expr 2021-03-19 13:21:43 +08:00
Jinzhu
27bb9137d3 Refactor OnConflict.UpdateAll 2021-03-18 11:44:20 +08:00
heige
07f3795f93
optimize MigrateColumn method for regexp (#4188) 2021-03-17 11:32:17 +08:00
Jinzhu
2055e29eb8 Refactor nested preload all associations 2021-03-14 10:42:58 +08:00
ruozhixian
c575a4e719 support preloading all children in multi-level associations 2021-03-11 16:36:49 +08:00
Jinzhu
912360097a Fix Scopes with Migrator, close #4145 2021-03-11 10:36:14 +08:00
Jinzhu
9fccb17d07 Fix double pointer for where conditions, close #4159 2021-03-10 19:46:59 +08:00
Jinzhu
14b9bd163c Don't panic when using nil pointer, close #4168 2021-03-10 19:32:56 +08:00
Jinzhu
675de6fc16 Clear scopes before invoke scopes methods 2021-03-08 19:21:09 +08:00
Shubhendra Singh Chauhan
0348b1d3c1
chore: improve code quality (#4123)
* Combine multiple `append`s into a single call

* Clean up copied struct fields with type conversion

* Remove unnecessary use of slice
2021-03-08 10:46:43 +08:00
heige
02cb40531e
Optimize parse constraint (#4153)
* for Config.cacheStore store PreparedStmtDB key

* invalid db error and value and invalid value length error (#4151)

* support named params in Select API  (#4142)

* adds support for named arguments in select

* changes clause identifies and adds test

* optimize matching of English letters and dashes

Co-authored-by: Ratan Phayade <ratanphayade@users.noreply.github.com>
2021-03-08 10:21:33 +08:00
heige
221d0a0ec1
optimize value of reflection length (#4152) 2021-03-08 10:20:04 +08:00
Ratan Phayade
a3abb5fedf
support named params in Select API (#4142)
* adds support for named arguments in select

* changes clause identifies and adds test
2021-03-07 10:59:00 +08:00
heige
bc347758e5
for Config.cacheStore store PreparedStmtDB key (#4149) 2021-03-07 10:57:22 +08:00
heige
495ec4bd87
invalid db error and value and invalid value length error (#4151) 2021-03-07 10:56:32 +08:00
Jinzhu
a948c84607 Revert "Revert "Don't override the from clauses, close #4129" close #4139"
This reverts commit d6c23586ae435a124353d3c5dfa6f504c24c5c3c.
2021-03-05 22:19:34 +08:00
Jinzhu
d6c23586ae Revert "Don't override the from clauses, close #4129" close #4139
This reverts commit 664755270ddba77cc669de814afca71ae5575fce.
2021-03-05 19:42:54 +08:00
Jinzhu
294625759c Fix after initialize db callback 2021-03-05 14:12:55 +08:00
Jinzhu
1476b2f7d4 Fix apply config 2021-03-04 20:37:44 +08:00
Sivchari
adf85d5b82
change the method of initializing slice (#4097)
* change the method of initializing slice and fixed the length to be specified as 0

* keep the association.go code in the var group

* keep the association.go code in the var group

* change to initializing in var group
2021-03-04 19:44:15 +08:00
Jinzhu
664755270d Don't override the from clauses, close #4129 2021-03-04 19:16:08 +08:00
Jinzhu
90476fea7a Fix Join with slice IN, close #4133 2021-03-04 18:40:47 +08:00
Jinzhu
42999e9809 Fix overwrite preloading associations, close #4134 2021-03-04 18:28:32 +08:00
Jinzhu
0157099576 Use functional options 2021-03-04 17:40:25 +08:00
Jinzhu
3694ef4a2c Fix get current table 2021-02-26 17:30:00 +08:00
Jinzhu
eb9a704fda Fix update UpdatedAt when full saving associations, close #4115 2021-02-26 17:11:25 +08:00
Jinzhu
189547f615 Fix new session with Begin, close #4120 2021-02-26 16:43:43 +08:00
Jinzhu
ddeb143eb9 Lazy call registered scopes 2021-02-25 22:01:59 +08:00
Jinzhu
6b7d18656d Lazy call registered scopes 2021-02-25 20:06:26 +08:00
Jinzhu
828e6b646b Lazy call registered scopes 2021-02-25 18:49:01 +08:00
Jinzhu
940da051a7 Skip nested associations when create data with Select, close #4108 2021-02-23 19:35:24 +08:00
Jinzhu
79225bfe48 Fix Omit/Select without Model value, close #4098 2021-02-18 10:53:29 +08:00
Jinzhu
73d44a4f97 Fix create duplicated constraint, close #4090 2021-02-16 08:52:56 +08:00
Jinzhu
92a2389450 Fix create duplicated constraint, close #4090 2021-02-16 08:35:19 +08:00
Jinzhu
628a0ae707 Fix foreign key & reference with same name, close #4081 2021-02-15 09:10:51 +08:00
Joel Nordell
5744e29fbd
Replacer interface for more flexible NamingStrategy (#4042)
* Change NameReplacer to an interface, allowing custom Replacers.

* Add NoLowerCase option to skip the snake_casing of names.

* Move sync.Map from global variable into member of NamingStrategy.

This maintains backward compatibility by making the smap optional - the
NamingStrategy still works if it is nil. gorm.Open activates it by
calling Init() if the given Namer is a schema.NamingStrategy.

Also, this changes the key stored in the smap to be the original name,
instead of the replaced name.

* Refactor NamingStrategy tests to add more assertions about how and when Replacers get called.

* Remove the name cache from NamingStrategy.
2021-02-14 08:16:24 +08:00
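A hedged sketch of plugging a custom replacer in after this change; the `NameReplacer` and `NoLowerCase` fields follow the commit description, and the sqlite driver is only an illustrative choice:

```go
package main

import (
	"strings"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/schema"
)

func main() {
	db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{
		NamingStrategy: schema.NamingStrategy{
			// *strings.Replacer satisfies the new Replacer interface:
			// any type with a Replace(string) string method works.
			NameReplacer: strings.NewReplacer("CID", "Cid"),
			NoLowerCase:  false, // set true to skip the snake_casing of names
		},
	})
	if err != nil {
		panic(err)
	}
	_ = db
}
```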
Jinzhu
a13b7a6acb Fix OnConflict where order for postgres, close #4073 2021-02-10 14:11:29 +08:00
Jinzhu
84ea3ec0cc Fix sub query argument order with multiple raw SQL 2021-02-09 19:56:45 +08:00
Jinzhu
df24821896 Fix SubQuery for raw sql 2021-02-09 17:05:50 +08:00
yrong1997
2ba612e805
Add field tag to ignore migration (#4028)
* Add field tag to ignore migration

* Fix null value with space

* refactor migration tag
2021-02-09 16:03:02 +08:00
Jinzhu
883c32e59a Support Unscoped when delete with selected associations, close #4062 2021-02-07 14:36:27 +08:00
Jinzhu
deff0594ee Save associations based on creatable/updatable permission, close #4056 2021-02-07 14:24:11 +08:00
Jinzhu
4373aa01ab Don't call AfterFind hooks if no record found, close #4048 2021-02-07 12:44:59 +08:00
Jinzhu
bb153384d1 Switch driver.Valuer, fmt.Stringer order when format SQL 2021-02-07 11:18:42 +08:00
heige
e80853e7f5
optimize checks in ParseCheckConstraints (#4063) 2021-02-07 10:12:13 +08:00
heige
ef5ef18d4a
recommended to use magic const strings (#4059) 2021-02-07 10:09:32 +08:00
Jinzhu
3d3208ed60 initialize config plugins 2021-02-03 16:27:49 +08:00
Jinzhu
8f37cb0195 Make has to be a const, close #4024 2021-02-01 10:42:13 +08:00
Jinzhu
db0cc4d60b Fix too long foreign key/checker names, close #4026 2021-02-01 10:37:12 +08:00
Jinzhu
7598204dc3 Support FullSaveAssociations for association mode, close #4010 2021-01-29 16:55:26 +08:00
Jinzhu
6e3ac74b7e Fix preloading all associations together with nested associations, close #4016 2021-01-28 20:18:01 +08:00
David Harkness
4267df02af
Fix typo in README (#4012) 2021-01-28 10:21:58 +08:00
Jinzhu
8500380e60 Add name checker test, close #4007 2021-01-27 17:45:59 +08:00
Ben
cc61202fe2
retrieving gorm object support pointer (#4006) 2021-01-27 11:50:15 +08:00
Jinzhu
81aa949105 Remove the unnecessary reflect.Ptr 2021-01-27 11:48:47 +08:00
rorschach
ba59065024 retrieving gorm object support pointer 2021-01-27 11:48:47 +08:00
Manyanda Chitimbo
f6308ed223
refactor: fix typo in tests.yml (#4005) 2021-01-27 11:18:39 +08:00
Jinzhu
7f198ead0e Refactor nested preloading associations, close #3970 2021-01-26 16:33:19 +08:00
Jinzhu
08678106a4 Support replace associations without the creation in association mode, close #3937 2021-01-26 14:34:21 +08:00
Jinzhu
916338a9e1 Test migrate constraints, close #3986 2021-01-26 13:39:34 +08:00
Jinzhu
59c01b7943 Make migrator works with dbresolver, close #3992 2021-01-25 10:30:57 +08:00
Jinzhu
f8bd4c4875 Don't create index if there are error exist, close #3976 2021-01-24 10:23:04 +08:00
Jinzhu
35ebfe6874 Support group conditions with single OR condition 2021-01-20 18:24:05 +08:00
Jinzhu
9790103e68 Fix Where with empty struct, close #3966 2021-01-19 16:37:49 +08:00
Jinzhu
6095dbf939 Fix parse embedded relations, close #3964, #3965 2021-01-19 15:40:04 +08:00
Jinzhu
3d87575e7e make Count compatible with Select with Count func, close #3962 2021-01-18 19:43:04 +08:00
Jinzhu
4a15540504 SkipDefaultTransaction skip CreateInBatches transaction 2021-01-18 11:43:42 +08:00
Jinzhu
59fa07953c Preload with settings, close #3945 2021-01-15 17:15:59 +08:00
Jinzhu
79628be2c2 Fix wrong RowsAffected if not data found 2021-01-14 16:01:23 +08:00
Lisa Casner
ce610a9560
title case schema name (#3940) 2021-01-13 13:05:05 +08:00
Jinzhu
de850edb4f Fix Change UpdatedAt to current time when doing OnConflict UpdateAll 2021-01-11 19:16:47 +08:00
Jinzhu
b864a5457a Allow foreign key following the default naming conventions, close #3928 2021-01-10 17:32:17 +08:00
Jinzhu
fe553a7c1a Fix prepared statement in transaction mode can't be shared in normal operations, close #3927 2021-01-10 16:46:06 +08:00
Jinzhu
7302c8a136 Fix tests and logger 2021-01-10 15:27:53 +08:00
Jinzhu
7ebb320f3e Allow customize join table's table in callback 2021-01-10 14:58:54 +08:00
Qt
f9131e309d
reduce DB's Use method complexity and make it easier to understand (#3930) 2021-01-10 10:15:48 +08:00
Jinzhu
d888c799d7 Change UpdatedAt to current time when doing OnConflict UpdateAll 2021-01-08 19:20:42 +08:00
Jinzhu
a5bfe2f39d Keep Error for new Session 2021-01-07 11:45:40 +08:00
Jinzhu
bf0fd9bef6 Fix logger check LogLevel 2021-01-06 16:07:19 +08:00
Jinzhu
5e72cd9a2b Add ErrPrimaryKeyRequired if schema has no primary key defined 2021-01-06 14:42:42 +08:00
Jinzhu
435bf70865 Add OnConflict OnConstraint support, close #3882 2021-01-05 21:31:51 +08:00
Jinzhu
6d260a86bd Fix Set/Get settings when saving associations, close #3908 2021-01-05 21:12:31 +08:00
Jinzhu
53b3ebdd1d Add invalid data error when building conditions 2021-01-05 21:01:16 +08:00
Jinzhu
00a785cd68 Don't use invalid value to build conditions, close #3912 2021-01-05 18:01:51 +08:00
Jinzhu
60b769c2c8 OnConflict UpdateAll includes fields that specified default values via tag 2021-01-04 15:13:56 +08:00
Philip Sahli
9b8d3b3a0f
fix typo (#3911) 2021-01-04 11:30:05 +08:00
Jinzhu
1b8cb07cf2 Allow Where select fields when searching with struct 2020-12-30 18:13:52 +08:00
Jinzhu
79864af9ff Allow customize auto increment increment 2020-12-30 11:16:40 +08:00
Jinzhu
6c0ee2700a Allow to use Valuer with Eq expression, #3899 2020-12-30 10:42:13 +08:00
Jinzhu
065787c54e Compatible with foreign key with ID suffix #3890 2020-12-28 18:20:55 +08:00
Jinzhu
8bf50a5592 Fix parsing relations when only References is specified, close #3890 2020-12-28 17:58:12 +08:00
Jinzhu
ade0bd6d60 Fix SELECT with sql expression in some cases, close #3889 2020-12-28 10:40:30 +08:00
Jinzhu
ad8a5c0d1a Add QueryFields mode when query many2many relations 2020-12-25 16:35:25 +08:00
Jinzhu
59730417aa Fix auto migrate field with customized field type, close https://github.com/go-gorm/mysql/issues/20 2020-12-23 17:31:47 +08:00
Jinzhu
77bf4aecc6 Create associations w/o nested transaction option 2020-12-18 13:25:52 +08:00
Jinzhu
468152d45b Add DisableNestedTransaction support 2020-12-16 19:33:35 +08:00
Jinzhu
6848ae872f Fix gorm.Expr with SubQuery, fix #3857 2020-12-15 15:50:35 +08:00
Jinzhu
0f00493c50 Continue to update tracking fields even if not selected with Select, but skip them if omitted with Omit, fix #3856 2020-12-15 11:18:29 +08:00
Jinzhu
14a0976dd4 populate the DeletedAt field when soft delete, fix #3855 2020-12-15 10:39:20 +08:00
Jinzhu
21c3f05aa2 Use transaction's conn when preparing statement 2020-12-14 18:31:18 +08:00
vellotis
51b5208599
Fix building of clause.Eq and clause.Neq expressions that fail to handle (*T)(nil) use cases correctly (#3848)
* Update tests to cover building `clause.Eq` and `clause.Neq` when value could be a nil pointer of a primitive

* Fix use cases for `clause.Eq` and `clause.Neq` when value is nil pointer of a primitive type
2020-12-11 14:07:23 +08:00
Jinzhu
e1952924e2 Support named Joins, close #3833 2020-12-07 10:31:06 +08:00
Jinzhu
6a0fca2195 Return error for invalid relations definition, close #3830 2020-12-06 18:07:16 +08:00
Jinzhu
1ef1f0bfe4 Fix Count with complicated Select, close #3826 2020-12-06 14:30:42 +08:00
Jinzhu
f655041908 Allow overwrite ignored field's permission, close #3829 2020-12-06 11:07:05 +08:00
Andy Bursavich
61d3a4d6ea
Fix schema initialization paths (#3825)
* Fix schema initialization paths

The initialized channel was only closed if the schema's cacheStore did not contain the embeddedCacheKey and there were no errors parsing relations.  If the key existed or an error occurred, it would not be closed. This could leave other goroutines waiting for synchronization that will never occur.

Additionally, the other code paths that wait for initialization to complete did not return the possible error.

* Unnest common schema initialization

This makes the common code path less deeply nested and the flow control easier to follow.
2020-12-04 11:28:38 +08:00
Andrei Baibaratsky
f2321ca164
Fixed creation of associated records with composite primary keys (go-gorm#3817) (#3818) 2020-12-03 15:00:26 +08:00
Jinzhu
51568ba4ab Delete select clause after Count, close #3814 2020-12-02 17:27:07 +08:00
Jinzhu
0c12a4c360 Add CreateBatchSize option 2020-12-02 14:59:50 +08:00
SmallTianTian
41e52f343a
fix: scan more base type and sql.NullXXX (#3813) 2020-12-02 14:00:16 +08:00
Dakatan
acedbb8310
Fix Scan int32, uint32 (#3801) 2020-11-30 10:09:08 +08:00
Jinzhu
0f77500917 Waiting for schema to be initialized, close #3790 2020-11-27 17:05:45 +08:00
Jinzhu
6950007d6a Fix failure to parse relations when using goroutines, close #3790
commit ee0ec43e8dfa85c1c1a562c2d0d47776cf8abd92
Author: Jinzhu <wosmvp@gmail.com>
Date:   Fri Nov 27 14:31:57 2020 +0800

    Fix failure to parse relations when using goroutines, close #3790

commit 590e73ff95d8af6bd14f0a0da687dd7d12e5f94e
Author: rokeyzhao <rokeyzhao@tencent.com>
Date:   Thu Nov 26 20:27:55 2020 +0800

    test: no cache preload in goroutine
2020-11-27 14:32:20 +08:00
Jinzhu
557b874ee3 Fix check field's precision 2020-11-25 14:55:53 +08:00
Jinzhu
66e8a72bf1 Support NameReplace for NamingStrategy, close #3779 2020-11-23 11:24:07 +08:00
Jinzhu
6186a4daa7 allow SkipHooks when preload & save associations 2020-11-20 16:56:52 +08:00
Jinzhu
dec8748512 Refactor QueryFields Option 2020-11-20 15:44:39 +08:00
Luis Guillermo Gómez
47ffd0bef4
Select all fields in SQL queries avoiding the SELECT * FROM (#3731)
* Select all fields in SQL queries avoiding the SELECT * FROM

* Select table name with fields in SQL queries

* Use QueryFields to execute the SQL query with all fields of the table
2020-11-20 15:38:25 +08:00
Jinzhu
e3b4e0418f Inherit SkipHooks option when preloading associations, close #3772 2020-11-20 15:11:47 +08:00
Deviller
d66af581b4
Fix Association.Replace() error returning (#3766)
* Fix Association.Replace() error returning

* Fallback to gorm.Model at TestAssociationNotNullClear()
2020-11-19 19:24:34 +08:00
Jinzhu
e7f45d5b01 Add error check for Transaction 2020-11-19 10:45:17 +08:00
Jinzhu
a1a30c38de Allow to omit fields when upsert associations, close #3762 2020-11-18 19:06:49 +08:00
Jinzhu
54b80b18bc Allow to omit fields in associations, close #3752 2020-11-17 21:49:40 +08:00
Jinzhu
50df9da6a1 Allow to skip associations when creating join table for many2many, close #3605 2020-11-17 20:24:08 +08:00
Jinzhu
694e42d6a1 Fix clause.IN with only one value of multiple rows 2020-11-17 19:11:24 +08:00
Jinzhu
9df9f7688b Change UpdatingColumn to SkipHooks 2020-11-17 17:49:43 +08:00
Jinzhu
26504f5cae Use NewDB to replace WithConditions for Session 2020-11-17 16:28:37 +08:00
Jinzhu
f6e1786ca2 Add skip hooks support 2020-11-17 15:19:58 +08:00
Jinzhu
f5c2126c29 Fix FindInBatches tests 2020-11-17 13:14:34 +08:00
Jinzhu
320f33061c Fix FindInBatches to modify the query conditions, close #3734 2020-11-17 11:19:04 +08:00
Jinzhu
a8db54afd6 Add CreateInBatches supports 2020-11-16 21:42:30 +08:00
Jinzhu
62be27d3ca Add OnConflict UpdateAll support 2020-11-16 20:22:08 +08:00
alresvor
a4c0c6b400
cache converted name (#3736)
BenchmarkToName-8     	  2322307	       521 ns/op	      88 B/op	       5 allocs/op
↓
BenchmarkToName-8     	19997366	        55.0 ns/op	       0 B/op	       0 allocs/op
2020-11-16 15:16:15 +08:00
Jinzhu
a9f54d53fb Don't preload when there are any error happened 2020-11-16 12:23:13 +08:00
Jinzhu
c1bb8e4551 Should not display the record not found error when using FirstOrXXX, close #3748 2020-11-16 11:20:13 +08:00
Jinzhu
1e241aa645 Reduce GC alloc 2020-11-10 21:23:20 +08:00
LeoZhan
832abda7a4
refactor: simplify the writing instead of using struct literal (#3728) 2020-11-08 09:41:43 +08:00
Jinzhu
85e9f66d26 Fix create index for other database/schema, close #3698 2020-11-05 11:43:21 +08:00
Jinzhu
fcf2ab6c0e Add deleted_at check when soft deleting, fix #3720 2020-11-05 11:20:08 +08:00
Jinzhu
560d303e71 Fix Scan with soft delete, close #3712 2020-11-04 11:03:22 +08:00
Jinzhu
c915471169 Support Expression for OrderBy clause 2020-11-03 10:30:05 +08:00
Amit Basuri
57b033e2dd
Marshalling zero-valued DeletedAt to null, https://github.com/go-gorm/gorm/issues/3693 (#3695) 2020-11-02 10:03:39 +08:00
Jinzhu
3ebdcbdb18 Marshal invalid DeletedAt as null, fix #3693 2020-10-30 19:08:20 +08:00
Jinzhu
a8141b6cc9 Fix DeletedAt marshal and unmarshal, close #3693 2020-10-30 18:15:07 +08:00
Jinzhu
4009ec5816 Fix call hook methods when updating with struct 2020-10-27 18:14:36 +08:00
Jinzhu
d011ebe7af Fix clone statement for Unscoped, UpdatingColumn, close #3681 2020-10-26 10:17:25 +08:00
Jinzhu
cb591a7129 Fix panic when using FirstOrCreate with soft delete, close #3671 2020-10-23 18:40:05 +08:00
Jinzhu
dd92f8bdc0 Allow create table for other database/schema #3640 2020-10-23 11:01:45 +08:00
Jinzhu
db2630cb3a Fix data race problem when using Scan, close #3662 2020-10-22 17:32:39 +08:00
Jinzhu
0aef8acc11 Add smart auto migrate tests 2020-10-22 16:36:27 +08:00
qifengzhang007
6d90d09cb8
When the Recorder trace function is called at line 371 of the finish_api file (inside the scan function defined at line 358), BeginAt has not been assigned, so its default value 0001-0:0:0 makes the SQL elapsed time shown in the trace log appear infinitely large. (#3657)
Co-authored-by: 张奇峰 <10515935zwj>
2020-10-22 14:09:09 +08:00
Jinzhu
231aba53c5 Fix count with order by 2020-10-22 11:28:43 +08:00
Jinzhu
5fee5b1b24 Add option tag support for index 2020-10-21 20:18:21 +08:00
Michelle
635dcc9ad4
add gorm ColumnType interface, remove sql one (#3647) 2020-10-21 18:35:33 +08:00
Jinzhu
bdb30da0a7 Fix copy lock for prepared statement, close #3642, #3607 2020-10-21 15:47:46 +08:00
Jinzhu
33a11767ea Upgrade test go.mod dependencies 2020-10-20 19:13:31 +08:00
Jinzhu
9b2181199d Fix soft delete with OrCondition, close #3627 2020-10-19 14:50:11 +08:00
Jinzhu
9dbef26feb Fix feature request label 2020-10-19 11:49:03 +08:00
Jinzhu
5731e632db Merge branch 'tebrizetayi-null-in-logger' 2020-10-19 11:04:35 +08:00
Jinzhu
a1ea1713b0 Fix log Stringer 2020-10-19 11:04:18 +08:00
TABRIZ ATAYI
d825554307 log nil pointer as NULL instead of '<nil>' #3604 2020-10-18 00:05:43 +02:00
Jinzhu
08ecef8e0b Fix NamedArguments with nested struct, close #3596 2020-10-13 15:32:29 +08:00
Jinzhu
689d6e2331 Fix DeletedAt marshalling, close #3598 2020-10-13 14:12:03 +08:00
Jinzhu
063b1ca0c4 Refactor SlowSQL log 2020-10-10 10:56:00 +08:00
Jinzhu
3d846957cd Compatible with tag notNull 2020-10-09 17:42:28 +08:00
Jinzhu
7faf1ca80f Fix Select with AS, close #3581, #3567 2020-10-09 11:52:12 +08:00
Jinzhu
dbc6b34dce Add detailed error information when missing table name 2020-09-29 15:43:31 +08:00
Jinzhu
a2faa41cbe Refactor NamingStrategy, close #3540 2020-09-28 10:55:27 +08:00
Jinzhu
9eec6ae066 Fix affected rows for Scan, change affected rows count for row/rows to '-', close #3532 2020-09-27 12:25:38 +08:00
Jinzhu
ba253982bf Fix Pluck with Time and Scanner 2020-09-24 20:08:24 +08:00
Jinzhu
c0de3c5051 Support FullSaveAssociations Mode, close #3487, #3506 2020-09-24 19:29:15 +08:00
Jinzhu
5228735915 Don't build IN condition if value implemented Valuer interface, #3517 2020-09-24 15:00:13 +08:00
Jinzhu
1a526e6802 Fix NamingStrategy with embedded struct, close #3513 2020-09-24 11:32:38 +08:00
caelansar
68920449f9
Fix format sql log (#3492) 2020-09-19 13:48:34 +08:00
Jinzhu
089939c767 AutoMigrate should auto create indexes, close #3486 2020-09-18 21:50:11 +08:00
Jinzhu
c9165fe3ca Don't panic when using unmatched vars in query, close #3488 2020-09-18 21:42:27 +08:00
Jinzhu
072f1de83a Add DryRunModeUnsupported Error for Row/Rows 2020-09-18 21:35:46 +08:00
Jinzhu
d002c70cf6 Support named argument for struct 2020-09-17 21:52:41 +08:00
Jinzhu
a932175ccf Refactor cascade delete associations 2020-09-15 14:28:26 +08:00
Jinzhu
06d534d6ea Cascade delete associations, close #3473 2020-09-15 12:41:45 +08:00
Jinzhu
1d5f910b6e Update workflows template 2020-09-14 15:30:55 +08:00
Jinzhu
0ec10d4907 Fix format SQL log, close #3465 2020-09-14 12:37:16 +08:00
Jinzhu
ed1b134e1c Fix use uint to for autoCreateTime, autoUpdateTime 2020-09-11 17:33:31 +08:00
Jinzhu
02fb382ec0 Support scan into int, string data types 2020-09-11 15:01:02 +08:00
Jinzhu
e583dfa196 Allow negative number for limit 2020-09-11 11:54:21 +08:00
Jinzhu
b8a74a80d7 Fix embedded struct with default value, close #3451 2020-09-11 11:18:54 +08:00
Jinzhu
70a7bd52ca Support delete associations with Select when deleting 2020-09-10 21:46:18 +08:00
Jinzhu
53caa85cf4 Use db's Logger for callbacks logs, close #3448, #3447 2020-09-10 19:20:47 +08:00
Jinzhu
231effe119 Fix parse blank default value, close #3442 2020-09-10 11:59:18 +08:00
Jinzhu
619d306cef ignore (-) when creating default values, #3434 2020-09-10 10:55:02 +08:00
Jinzhu
f6ed895caf Build relationships if fields are not ignored, fix #3181 2020-09-09 16:37:05 +08:00
Jinzhu
f6117b7f3d Should not display SubQuery SQL log, close #3437 2020-09-09 16:26:16 +08:00
Jinzhu
0b6ef3cb87 Merge branch 'jsternberg-migrator-release-conn' 2020-09-09 10:56:07 +08:00
Jinzhu
567597f000 Fix fail on sqlserver, #3433 2020-09-09 10:53:13 +08:00
Jinzhu
e7188c04ca Fix tests & refactor for PR #3429 2020-09-09 10:42:13 +08:00
caelansar
839e09e985 correct generated sql 2020-09-09 10:42:13 +08:00
Jinzhu
2242ac6c0e Fix tests & refactor for PR #3429 2020-09-09 10:31:48 +08:00
Jonathan A. Sternberg
222427c474
Release the connection when discovering the column types in the migrator
When the migrator is used to discover the column types, such as when
used with `AutoMigrate()`, it does not close the query result. This
changes the migrator to close the query result and it also changes the
query to use `LIMIT 1` to prevent additional work against the database
when only discovering the schema.

Fixes #3432.
2020-09-08 18:12:14 -05:00
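A hedged sketch of the pattern described above, using database/sql directly; the sqlite driver, table name, and DSN are placeholders:

```go
package main

import (
	"database/sql"
	"fmt"

	_ "github.com/mattn/go-sqlite3"
)

// columnTypes discovers column metadata with a LIMIT 1 query and always
// closes the rows, so the underlying connection is returned to the pool.
// The table name comes from trusted schema metadata here, not user input.
func columnTypes(db *sql.DB, table string) ([]*sql.ColumnType, error) {
	rows, err := db.Query(fmt.Sprintf("SELECT * FROM %s LIMIT 1", table))
	if err != nil {
		return nil, err
	}
	defer rows.Close() // releasing the result set is the point of the fix

	return rows.ColumnTypes()
}

func main() {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		panic(err)
	}
	defer db.Close()

	if _, err := db.Exec(`CREATE TABLE users (id INTEGER, name TEXT)`); err != nil {
		panic(err)
	}

	types, err := columnTypes(db, "users")
	if err != nil {
		panic(err)
	}
	for _, t := range types {
		fmt.Println(t.Name(), t.DatabaseTypeName())
	}
}
```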
caelansar
aceb3dad3b correct generated sql 2020-09-08 21:28:04 +08:00
Jinzhu
c70c097e88 Refactor format SQL for driver.Valuer 2020-09-08 19:11:29 +08:00
Jinzhu
c9d5c0b07a Fix create database foreign keys for same type having has many/one & many2many relationships, close #3424 2020-09-08 18:25:29 +08:00
egenchen
6de0356a57
Fix monochrome log output inconsistent with colorful log (#3425) 2020-09-08 16:59:47 +08:00
Jinzhu
05794298bd Fix Save with specified table, close #3396 2020-09-06 12:22:08 +08:00
Jinzhu
6e38a2c2d5 Fix many2many join table name rule 2020-09-06 10:51:21 +08:00
Jinzhu
d8ddccf147 Don't marshal to null for associations after preloading, close #3395 2020-09-04 19:09:51 +08:00
Jinzhu
f121622228 Don't add prefix for invalid embedded fields 2020-09-04 14:35:44 +08:00
Jinzhu
28121d4455 Fix panic when batch creating from slice contains invalid data, close #3385 2020-09-03 20:59:41 +08:00
Jinzhu
6a86646469 Fix use db function as integer's default value, close #3384 2020-09-03 20:41:00 +08:00
Jinzhu
dd0d74fad0 Fix transaction on closed conn when using prepared statement, close #3380 2020-09-03 19:16:55 +08:00
Jinzhu
3cd81ff646 Fix query with specified table and conditions, close #3382 2020-09-03 18:42:32 +08:00
Jinzhu
78e9c9b748 raise error when failed to parse default value, close #3378 2020-09-03 18:20:57 +08:00
Jinzhu
f2adb088c5 Set field size from primary fields to foreign fields 2020-09-03 16:11:15 +08:00
Jinzhu
cf31508095 Fix tests_all.sh 2020-09-03 15:02:04 +08:00
Jinzhu
3cc7a30712 Fix tests/go.mod 2020-09-03 13:28:37 +08:00
Jinzhu
98e15e0b95 Setup DB's ConnPool in PrepareStmt mode, fix #3362 2020-09-03 12:54:26 +08:00
Jinzhu
ff3880292d Update missing playground template 2020-09-03 11:48:44 +08:00
Jinzhu
48b395b760 returns ErrEmptySlice when creating with zero length slice 2020-09-03 11:32:30 +08:00
Jinzhu
fcb666cfa3 Fix associations using composite primary keys without ID field, close #3365 2020-09-03 10:58:48 +08:00
Jinzhu
130f24090d update default_value_test 2020-09-02 21:03:47 +08:00
Jinzhu
dbe0f4d8d7 Allow use NULL as default value for string, close #3363 2020-09-02 20:15:12 +08:00
Jinzhu
680dda2c15 Fix combine conditions when using string conditions, close #3358 2020-09-02 20:09:51 +08:00
Jinzhu
dbaa6b0ec3 Fix Scan struct with primary key, close #3357 2020-09-02 16:14:26 +08:00
aimuz
9a101c8a08
fmt.Sprint() to strconv.Format (#3354) 2020-09-01 21:03:37 +08:00
Jinzhu
d1e17d549f request ColumnTypes after new session method 2020-09-01 20:53:54 +08:00
Jinzhu
22317b43c0 Fix migrate field, failed to migrate when field size changed 2020-09-01 18:58:16 +08:00
Jinzhu
bf6123b01e Fix duplicated soft delete clause 2020-09-01 18:05:26 +08:00
Jinzhu
e73147fa8e Better support for scan into map, fix unfriendly data type for interface, close #3351 2020-09-01 17:45:14 +08:00
宋小北
e6f4b711a7
fix order case (#3350) 2020-09-01 15:50:53 +08:00
Jinzhu
e98a4a3a4e Change default timeout interval to avoid test fail on CI 2020-09-01 14:01:59 +08:00
Jinzhu
308d22b166 Clean up associations before Preload, close #3345 2020-09-01 13:48:37 +08:00
Jinzhu
162367be7d Fix multiple M2M relations on one table, close #3347 2020-09-01 11:30:16 +08:00
Jinzhu
0273856e4d Don't alter column with full column data type, close #3339 2020-08-31 16:27:27 +08:00
Jinzhu
496db1f13e Fix named argument with multiple line SQL, fix #3336 2020-08-31 15:45:56 +08:00
Jinzhu
9b0ad4730f Squashed commit of the following:
commit 759038a126122d5b3323979fdd7d867a4ab85585
Author: Jinzhu <wosmvp@gmail.com>
Date:   Mon Aug 31 12:06:31 2020 +0800

    Add PreparedStmt tests

commit 066d54db1fc93ea58c190195104a2d7086623f69
Author: 王岚 <wanglan.backend@bytedance.com>
Date:   Fri Aug 28 18:40:59 2020 +0800

    prepare_stmt add ctx
2020-08-31 12:08:33 +08:00
Jinzhu
53f8c9fc1c More compatible prioritized primary field #3156 2020-08-30 20:58:14 +08:00
Jinzhu
b4166d9515 Fix V2 Save compatibility, close #3332 2020-08-30 10:21:11 +08:00
Jinzhu
59586dcd31 Fix unnecessary duplicated primary condition when using Save, close #3330 2020-08-29 23:02:19 +08:00
Jinzhu
677edf9d9e ignore AS when alias table as it doesn't work on oracle db, close #3328 2020-08-29 22:09:07 +08:00
Jinzhu
06461b3254 GORM V2.0.0 2020-08-28 21:16:47 +08:00
Jinzhu
94c6bb980b Refactor association 2020-08-28 17:32:19 +08:00
Jinzhu
c19a3abefb Fix self-referential belongs to, close #3319 2020-08-28 11:31:13 +08:00
Jinzhu
dacbaa5f02 Fix update attrs order 2020-08-27 19:52:01 +08:00
Jinzhu
d50dbb0896 Fix check valid db name, close #3315 2020-08-27 19:15:40 +08:00
Jinzhu
cd54dddd94 Test update with GormValuer 2020-08-27 18:42:40 +08:00
Jinzhu
7a90496701 Test create from sql expr with map 2020-08-27 16:27:59 +08:00
Jinzhu
ce8853e7a6 Add GormValuer interface support 2020-08-27 15:03:57 +08:00
Jinzhu
0d96f99499 Update README 2020-08-26 12:22:11 +08:00
Jinzhu
3195ae1207 Allow override alias table in preload conditions 2020-08-25 18:59:19 +08:00
Jinzhu
0f3201e73b friendly invalid field error message 2020-08-25 18:18:16 +08:00
Jinzhu
3dfa8a66f1 Fix panic when deleting without a pointer, close #3308 2020-08-25 17:27:28 +08:00
Jinzhu
84dbb36d3b Add Golang v1.15 2020-08-24 20:24:25 +08:00
Jinzhu
ebdb4edda8 Add AllowGlobalUpdate mode 2020-08-23 20:08:23 +08:00
Jinzhu
cc6a64adfb Support smart migrate, close #3078 2020-08-23 18:16:12 +08:00
Jinzhu
3a97639880 Fix unordered joins, close #3267 2020-08-23 10:45:10 +08:00
Jinzhu
2b510d6423 Don't create index for join table, close #3294 2020-08-21 15:40:50 +08:00
Jinzhu
f88e8b072c Check valid pointer before use it as Valuer 2020-08-20 18:13:29 +08:00
Jinzhu
06de6e8834 Test same field name from embedded field, close #3291 2020-08-20 10:58:35 +08:00
Jinzhu
0c9870d1ae Test Association Mode with conditions 2020-08-20 10:39:01 +08:00
Jinzhu
528e5ba5c4 Cleanup Model after Count 2020-08-19 20:30:39 +08:00
Jinzhu
3313c11888 Fix embedded struct containing field named ID, close #3286 2020-08-19 19:15:27 +08:00
Jinzhu
c1782d60c1 Fix embedded scanner/valuer, close #3283 2020-08-19 15:47:08 +08:00
deepoli
3411425d65
fix return value and delete unused default (#3280) 2020-08-18 19:03:09 +08:00
Jinzhu
b5de8aeb42 Fix overwriting SELECT clause 2020-08-18 18:58:53 +08:00
Jinzhu
50826742fd Add error gorm.ErrInvalidData 2020-08-18 18:00:36 +08:00
Jinzhu
dc48e04896 Fix nested embedded struct, close #3278 2020-08-18 11:21:40 +08:00
Jinzhu
9fcc337bd1 Fix create from map 2020-08-17 17:41:36 +08:00
Jinzhu
681268cc43 Refactor Create/Query/Update/DeleteClauses interface 2020-08-17 16:31:09 +08:00
Jinzhu
2a716e04e6 Avoid panic for invalid transaction, close #3271 2020-08-17 12:16:42 +08:00
Jinzhu
6834c25cec Fix stack overflow for embedded self-referred associations, close #3269 2020-08-17 12:02:46 +08:00
Jinzhu
2faff25dfb Fix FirstOr(Init/Create) when assigning with association 2020-08-13 18:38:39 +08:00
Jinzhu
2c4e857125 Should ignore association conditions when querying with struct 2020-08-13 18:09:04 +08:00
Jinzhu
dea93edb6a Copy TableExpr when clone statement 2020-08-13 16:28:21 +08:00
Jinzhu
ecc946be6e Test update from sub query 2020-08-13 16:05:06 +08:00
Jinzhu
045d5f8538 Fix count with join and no model, close #3255 2020-08-13 12:18:36 +08:00
Jinzhu
ec82da396b Merge branch 'caelansar-scanner-valuer-test' 2020-08-13 12:06:07 +08:00
Jinzhu
7d45833f3e Fix driver.Valuer interface returns nil, close #3248 2020-08-13 12:05:55 +08:00
Jinzhu
a3dda47afa Don't parse ignored anonymous field 2020-08-13 10:23:23 +08:00
Jinzhu
4a9d3a688a Don't parse ignored anonymous field 2020-08-11 21:22:51 +08:00
Caelansar
15b96ed3f4 add testcase 2020-08-10 15:34:20 +08:00
Jinzhu
39c8d6220b Fix soft delete panic when using unaddressable value 2020-08-06 17:48:46 +08:00
Jinzhu
3df249c127 Use table expr when inserting table, close #3239 2020-08-06 17:12:31 +08:00
Jinzhu
da1e54d5ab Add sql-cli 2020-08-06 15:37:36 +08:00
Jinzhu
f962872b48 Fix labeler 2020-08-05 14:22:35 +08:00
Jinzhu
ff985b90cc Fix failed to guess relations for embedded types, close #3224 2020-08-04 12:25:34 +08:00
Jinzhu
c11c939b95 callbacks support sort with wildcard 2020-08-03 21:48:36 +08:00
Jinzhu
f83b00d20d Fix Count with Select when Model not specified, close #3220 2020-08-03 10:30:25 +08:00
Jinzhu
2676fa4fb8 Remove autoincrement tag for join table, close #3217 2020-07-31 18:19:25 +08:00
Jinzhu
dc299b900f Use specified table when preloading data with Join 2020-07-31 14:47:47 +08:00
Jinzhu
81c68db87f Fix zero time failed on mysql 8 2020-07-30 17:56:16 +08:00
Jinzhu
07ce8caf7d Remove labeler workflows 2020-07-30 17:42:41 +08:00
lninl
7bb883b665
Auto creating/updating time with unix (milli) second (#3213)
* Auto creating/updating time with unix (milli) second

* add test for 'Auto creating/updating time with unix (milli) second'
2020-07-30 17:39:57 +08:00
Jinzhu
47a5196734 Fix uninitialized Valuer return time.Time, close #3214 2020-07-30 17:37:07 +08:00
Jinzhu
7c2ecdfc1c Fix use pointer of Valuer as foreign key, close #3212 2020-07-30 10:23:35 +08:00
Jinzhu
2cbdd29f26 Returns error for invalid embedded field, close #3209 2020-07-29 10:23:14 +08:00
Qt
a140908839
refactor function convertParams's default case (#3208) 2020-07-28 17:25:03 +08:00
Jinzhu
c7667e9299 Refactor Prepared Statement 2020-07-28 14:46:48 +08:00
Qt
f4cfa9411b
define err with the same code style (#3199) 2020-07-26 10:03:58 +08:00
Jinzhu
69d8111893 Fix panic when using invalid data, close #3193 2020-07-24 08:32:50 +08:00
Jinzhu
c3f52cee8b Don't scan last insert id 0 2020-07-23 23:56:13 +08:00
Jinzhu
6ed697dd02 TestFirstOrCreateWithPrimaryKey, close #3192 2020-07-23 23:41:56 +08:00
Jinzhu
7021db3655 Fix FieldsWithDefaultDBValue for primary field, close #3187 2020-07-22 19:03:19 +08:00
Jinzhu
87112ab1c7 Fix row callback name 2020-07-22 15:05:38 +08:00
Jinzhu
da16f7b475 Create extension uuid-ossp for postgres test database 2020-07-22 12:13:40 +08:00
Jinzhu
0546b59743 Fix save many2many associations with UUID primary key, close #3182 2020-07-22 11:28:00 +08:00
Jinzhu
ef002fd7ac Add GORMDataType to Field, close #3171 2020-07-20 19:00:03 +08:00
Jinzhu
5d05441067 Test From SubQuery with vars 2020-07-20 08:12:41 +08:00
Jinzhu
a0477f94dd Allow Omit with Query, close #3165 2020-07-19 21:48:58 +08:00
Jinzhu
90183fadde Allow advanced table with args 2020-07-19 21:30:24 +08:00
Jinzhu
de764d9e3d Replace FullTable with TableExpr 2020-07-17 21:19:21 +08:00
Jinzhu
e77156980c Fix panic when using Select/Omit Associations with no schema, close #3160 2020-07-17 15:50:17 +08:00
Jinzhu
6dc583869b Don't use value's first field to guess data type for structs implementing GormDataTypeInterface 2020-07-17 12:02:00 +08:00
Jinzhu
362779575c Fix Select with specific symbol, close #3157 2020-07-17 11:24:24 +08:00
Jinzhu
58e3241544 Fix Select with specific symbol, close #3158 2020-07-17 11:06:20 +08:00
Jinzhu
b8692c7671 Allow temporarily disable default transaction 2020-07-16 18:05:55 +08:00
Jinzhu
e83e210971 Update postgres DSN 2020-07-16 17:15:57 +08:00
Jinzhu
2595402507 Add reviewdog 2020-07-16 13:37:02 +08:00
Jinzhu
4456df7a5d Lint with golangci-lint 2020-07-16 11:27:12 +08:00
Jinzhu
0028246ea5 Don't set DefaultValueInterface when DefaultValue not set, close #3152 2020-07-16 10:19:24 +08:00
Jinzhu
72a64bef11 Don't merge clause From 2020-07-15 10:25:10 +08:00
Jinzhu
1f05cb7e55 Handle Associations with pointer of pointer, close #3130 2020-07-10 22:53:03 +08:00
Jinzhu
d4b462a351 Fix alias keyword with Table, close #3104 2020-07-10 21:11:28 +08:00
Jinzhu
33c48611b6 Fix customize table with Delete, close #3129 2020-07-10 13:08:29 +08:00
Jinzhu
c0319f6eed Test map with named argument for raw sql 2020-07-10 12:52:01 +08:00
Jinzhu
bba569af2b Add NamedArg support 2020-07-10 12:28:24 +08:00
Jinzhu
bc3728a18f Fix concurrent map writes, close #3126 2020-07-10 07:14:37 +08:00
Jinzhu
c091cd6aa4 Update stale action 2020-07-09 22:14:11 +08:00
Jinzhu
d04984323f Add stale for v1 action 2020-07-09 22:02:29 +08:00
Jinzhu
a8655f7947 Fix auto select with smaller struct for slices 2020-07-09 12:15:35 +08:00
Jinzhu
0790ff6937 Update tests helper to check time 2020-07-09 09:42:27 +08:00
Jinzhu
2ae0653af2 Fix ambiguous column when using same column name in join table, close #3120 2020-07-09 09:03:48 +08:00
Jinzhu
e1084e78d0 Allow customize AutoIncrement for primary field 2020-07-08 18:50:49 +08:00
Jinzhu
30188e7aa4 CHECK constraint without parentheses 2020-07-08 18:15:45 +08:00
Jinzhu
619cd332ec Add index priority supports 2020-07-08 17:59:40 +08:00
Jinzhu
de482f57ff Test raw sql with gorm.Expr 2020-07-06 15:47:33 +08:00
Jinzhu
b5725940e9 Test Select with Update Struct 2020-07-06 11:20:43 +08:00
Jinzhu
9a4941ba70 Test Order/GroupBy 2020-07-06 09:47:14 +08:00
Jinzhu
4e066c9590 Test Or 2020-07-05 12:23:45 +08:00
Jinzhu
1a2fabb34d Test Not 2020-07-05 11:56:12 +08:00
Jinzhu
89ea62077d DryRun for RowQuery, Exec, close #3106 2020-07-04 08:35:11 +08:00
Jinzhu
90a40361ed Fix set bool field from null 2020-07-04 08:21:23 +08:00
Jinzhu
f835a4deac Add health check for github action databases 2020-07-04 07:57:33 +08:00
Jinzhu
6b98ced13d Fix set time field from null, close #3108 2020-07-04 07:45:07 +08:00
Jinzhu
d4f8a52442 Fix join table foreign key in snake_case 2020-07-04 07:24:46 +08:00
Jinzhu
2416eabd3f Change unique_index to UniqueIndex 2020-07-04 00:36:27 +08:00
Jinzhu
f93345afa8 Close cached prepared stmt when got error 2020-07-03 10:26:18 +08:00
Jinzhu
8100ac7663 Change default postgres DSN for github action 2020-07-03 09:27:24 +08:00
Jinzhu
2d945a9641 Switch pgx as default driver 2020-07-03 08:54:12 +08:00
SmallTianTian
3c03b6e527
fix no limit no offset. (#3101)
* fix no limit no offset.

* add test for playground.
2020-07-02 18:14:33 +08:00
Jinzhu
3f355dc050 Refactor 2020-07-02 10:14:30 +08:00
Jinzhu
63e48191a8 Test failed to save association should rollback, close #3100 2020-07-01 21:28:19 +08:00
Jinzhu
b0aae504ab Merge branch 'mojocn-master' 2020-07-01 19:53:50 +08:00
Jinzhu
322c6a36ee Fix .github config 2020-07-01 19:52:16 +08:00
Jinzhu
65d6c19d73 Test multiple index tags 2020-07-01 19:49:04 +08:00
Jinzhu
d342f4122a Better support Count in chain 2020-07-01 19:49:04 +08:00
Jinzhu
9d7df71332 Query with smaller struct 2020-07-01 19:49:04 +08:00
Jinzhu
7aaac3a580 Allow to use sql function in Group, Pluck 2020-07-01 19:49:04 +08:00
Jinzhu
3e4dbde920 Test Hooks For Slice 2020-07-01 19:49:04 +08:00
Jinzhu
66dcd7e3ca Add SetColumn, Changed method 2020-07-01 19:49:04 +08:00
Jinzhu
e308b103c0 SingularTable for JoinTable 2020-07-01 19:49:04 +08:00
Jinzhu
eeee014500 Only query with readable fields 2020-07-01 19:49:04 +08:00
Jinzhu
d5d31b38a7 Test group with table name 2020-07-01 19:49:04 +08:00
Jinzhu
a550a05882 Set db type after autotime 2020-07-01 19:49:04 +08:00
Jinzhu
81f4fafae4 Test group by with multiple columns 2020-07-01 19:49:04 +08:00
Jinzhu
af632199cf Test set string field's default value to blank string 2020-07-01 19:49:04 +08:00
Jinzhu
c888560a0e Fix go.mod 2020-07-01 19:49:04 +08:00
Jinzhu
dcdcc6fedc Fix create with default value 2020-07-01 19:49:04 +08:00
Jinzhu
4cbd99aa94 Add default value test 2020-07-01 19:49:04 +08:00
Jinzhu
19f56ddc2a Upgrade default mysql driver 2020-07-01 19:49:04 +08:00
Jinzhu
6b92bca664 Update test script 2020-07-01 19:49:04 +08:00
Jinzhu
630f4fe03f Create join table with ReorderModels 2020-07-01 19:49:04 +08:00
Jinzhu
fea181e87c Test multiple index tags 2020-07-01 11:47:46 +08:00
Jinzhu
d02b592c6c Better support Count in chain 2020-07-01 10:19:52 +08:00
Jinzhu
9075b33620 Query with smaller struct 2020-07-01 08:56:21 +08:00
Jinzhu
ee1f46e3a1 Allow to use sql function in Group, Pluck 2020-06-30 23:06:48 +08:00
Jinzhu
929c0c576c Test Hooks For Slice 2020-06-30 22:47:21 +08:00
Jinzhu
f5566288de Add SetColumn, Changed method 2020-06-30 16:53:54 +08:00
Jinzhu
2d048d9ece SingularTable for JoinTable 2020-06-30 07:29:15 +08:00
Jinzhu
9bfe306975 Only query with readable fields 2020-06-27 08:04:12 +08:00
Jinzhu
cb5a35a807 Test group with table name 2020-06-26 08:39:18 +08:00
Jinzhu
2476c0fbb4 Set db type after autotime 2020-06-26 07:26:45 +08:00
Jinzhu
4eae3fea41 Test group by with multiple columns 2020-06-25 23:37:49 +08:00
Jinzhu
f2b49437fb Test set string field's default value to blank string 2020-06-25 22:48:10 +08:00
Jinzhu
c5feff1591 Fix go.mod 2020-06-25 08:08:37 +08:00
Jinzhu
1b28c187c0 Fix create with default value 2020-06-25 08:00:10 +08:00
Jinzhu
fb56fe993a Add default value test 2020-06-25 06:38:07 +08:00
Jinzhu
3ec7ed1d51 Upgrade default mysql driver 2020-06-24 20:19:28 +08:00
Jinzhu
8ce2dd5548 Update test script 2020-06-24 19:09:19 +08:00
Jinzhu
4a01d4c263 Create join table with ReorderModels 2020-06-24 17:19:11 +08:00
EricZhou
eac6d1bdb9 issue 2020-06-24 16:33:25 +08:00
Jinzhu
834cfa2c78 Disable GORM_VERBOSE in github action 2020-06-24 15:04:46 +08:00
Jinzhu
67bd842645 Update tests all script 2020-06-24 14:56:04 +08:00
Jinzhu
90f817db29 Update issue template 2020-06-24 14:48:44 +08:00
Jinzhu
7e1fa4a44d Fix Count after Session 2020-06-23 22:41:41 +08:00
Jinzhu
4201f7bdab Fix create unique index when creating table, close #3081 2020-06-23 22:14:41 +08:00
mojotv
dd7caa9db0
add macos and windows for sqlite unit test and use cache for go mod package download (#3079)
Co-authored-by: EricZhou <zhouqing1@360.cn>
2020-06-23 16:00:04 +08:00
Hinagiku Soranoba
b733d16f56
Create supports Array / ArrayPtr (#3076)
* add Array / ArrayPtr create tests

* support create using array
2020-06-23 14:38:36 +08:00
Jinzhu
1df757113a initialize plugins map 2020-06-23 10:36:45 +08:00
Jinzhu
f4bfc435cc Add register plugin API 2020-06-23 09:38:51 +08:00
Jinzhu
e77e7bb842 Fix nested embedded field with pointer, close #3071 2020-06-23 09:12:57 +08:00
Jinzhu
32bd6b3e8f Fix Count with Select 2020-06-23 08:51:01 +08:00
Jinzhu
c84a8fe571 Switch to github actions 2020-06-22 23:14:17 +08:00
Jinzhu
71ae2ddbee Refactor github actions 2020-06-22 22:51:54 +08:00
Jinzhu
60d1e68567 Update github action CI 2020-06-22 22:37:14 +08:00
Jinzhu
59d7150917 Update README 2020-06-22 20:22:15 +08:00
Jinzhu
5d044642d1 Allow DisableForeignKeyConstraintWhenMigrating 2020-06-22 11:04:44 +08:00
Jinzhu
7851faa094 Allow close prepared statements, double check before prepare 2020-06-21 18:18:23 +08:00
Jinzhu
d0764bead1 Test migrate with comment and check created constraints 2020-06-21 13:59:43 +08:00
Jinzhu
fee1e4aafd Fix create foreign keys for many2many relations 2020-06-21 10:48:23 +08:00
Jinzhu
5883490aa7 Select, Omit, Preload supports clause.Associations 2020-06-20 17:21:01 +08:00
Jinzhu
a1e35bdc94 Support merge batch data some having primary values 2020-06-20 16:52:15 +08:00
Jinzhu
3d8f6f9cf9 Test GroupConditions 2020-06-20 01:55:30 +08:00
Jinzhu
4f19e2a7b3 Test ForeignKeyConstraints 2020-06-20 01:20:18 +08:00
Jinzhu
d4d339f3b5 Handle data type cases 2020-06-19 22:51:46 +08:00
Jinzhu
e3292b3b41 Test with latest driver version 2020-06-19 18:44:19 +08:00
Jinzhu
7dc255acfe Add SavePoint/RollbackTo/NestedTransaction 2020-06-19 18:30:04 +08:00
Jinzhu
2c1b04a2cf Fix failed to create second record in same transaction, close #3060 2020-06-19 12:38:03 +08:00
Jinzhu
07960fe661 Fix []byte support 2020-06-18 11:24:08 +08:00
Jinzhu
96368eb967 Test embedded struct implements Scan & Value interface 2020-06-18 09:15:23 +08:00
Jinzhu
6b2f37189e Fix few cases with postgres 2020-06-18 08:40:41 +08:00
mojotv
ca2c80c8e3
add githubAction CI for tests (#3057) 2020-06-17 20:29:37 +08:00
Jinzhu
e487f355a0 Add DB method 2020-06-17 19:57:54 +08:00
2BFL
d716e456f4
fix broken url (#3053) 2020-06-15 12:28:35 +08:00
Jinzhu
9039e36cfc Allow scan into float close #1373 2020-06-14 19:18:48 +08:00
Jinzhu
1fdc66710e Add table options 2020-06-14 19:13:16 +08:00
Jinzhu
56bdded0f8 Fix statement modifier support 2020-06-14 12:18:46 +08:00
maiyama18
1bbaa43951
fix typos in test method names (#3052) 2020-06-14 09:24:07 +08:00
Razon Yang
537065fbd9
Replace godoc badge with pkg.go.dev (#3051) 2020-06-12 20:00:55 +08:00
Jinzhu
1af325ab4f Upgrade sqlserver driver 2020-06-10 16:06:54 +08:00
Jinzhu
45cb6b49bf Add FindInBatches support 2020-06-10 15:36:34 +08:00
Jinzhu
dbc3f8feb0 Add count soft deleted record test 2020-06-10 13:42:39 +08:00
Jinzhu
0d58d5a3a7 Upsert selected columns 2020-06-10 10:48:48 +08:00
Jinzhu
f3424c6864 Support save slice of data 2020-06-10 00:02:14 +08:00
Jinzhu
22ff8377df Fix Pluck with Table only 2020-06-09 15:36:10 +08:00
Jinzhu
05e6a65ee1 Fix typo 2020-06-09 12:00:43 +08:00
Jinzhu
a42f9bf439 Remove codecov as doesn't support detect code-coverage of separated folders 2020-06-09 11:00:50 +08:00
Jinzhu
c4872cddfd Refactor callbacks 2020-06-09 10:17:24 +08:00
Jinzhu
649d02fddd Add batch upsert tests 2020-06-09 09:04:32 +08:00
Jinzhu
f0b6bd9ee0 Fix typo 2020-06-08 23:25:16 +08:00
Jinzhu
4555796b62 Refactor Execute callbacks 2020-06-08 22:32:35 +08:00
Jinzhu
9f19378304 Grow SQL capacity to reduce allocation 2020-06-08 20:23:47 +08:00
Jinzhu
aaf0725771 Refactor for performance 2020-06-08 17:21:26 +08:00
Douglas Danger Manley
13f96f7a15
Spelling fix for "condtion" -> "condition" (#3042)
This fixes a spelling error in the word "condition"; in particular,
the `BuildCondtion` function should be named `BuildCondition`.
2020-06-08 11:38:51 +08:00
Jinzhu
8f8d549ca3 Refactor merge where exprs 2020-06-08 09:13:34 +08:00
Douglas Danger Manley
72d0fa6196 Fix Statement Where clone array corruption in v2
Method-chaining in gorm is predicated on a `Clause`'s `MergeClause`
method ensuring that the two clauses are disconnected in terms of
pointers (at least in the Where clause case).

However, the original Where implementation used `append`, which
only returns a new instance if the backing array needs to be resized.
In some cases, this is true.  Practically, Go doubles the size of the
slice once it gets full, so the following slice `append` calls would
result in a new slice:

* 0 -> 1
* 1 -> 2
* 2 -> 4
* 4 -> 8
* and so on.

So, when the number of "where" conditions was 0, 1, 2, or 4, method-chaining
would work as expected.  However, when it was 3, 5, 6, or 7, modifying the
copy would modify the original.

This also updates the "order by", "group by" and "set" clauses.
2020-06-07 16:54:01 -04:00
Jinzhu
e7b2e92ce3 Remove RecordNotFound method 2020-06-07 22:03:45 +08:00
Jinzhu
31a0553b82 Fix FileWithLineNum on windows 2020-06-07 18:37:05 +08:00
Jinzhu
d11c424334 Fix typo 2020-06-07 15:26:43 +08:00
Jinzhu
4a4b8234de Update issues template 2020-06-07 13:16:09 +08:00
Jinzhu
82d55b1054 Add OnConflict DoUpdates test 2020-06-07 12:50:00 +08:00
Jinzhu
93043334c3
Create FUNDING.yml 2020-06-07 12:47:26 +08:00
Jinzhu
6937d713c3 Refactor clauses 2020-06-06 22:52:08 +08:00
Jinzhu
38d1cd2bf1 Replace For with Locking 2020-06-06 21:35:28 +08:00
Jinzhu
52b763aab3 Add convert map Assignments helper 2020-06-06 17:47:30 +08:00
Jinzhu
1acbb34406 Update wercker.yml 2020-06-06 15:05:24 +08:00
Jinzhu
ebb8511d59 Add go.sum 2020-06-06 14:28:59 +08:00
Jinzhu
edd4be3fcb Update README 2020-06-06 14:23:47 +08:00
Jinzhu
a954d772d7 Support customize gorm field type 2020-06-06 10:47:32 +08:00
Jinzhu
1490a062db Refactor codebase and add benchmark test 2020-06-05 23:26:56 +08:00
Jinzhu
163200d05f Test Hooks 2020-06-05 20:24:15 +08:00
Jinzhu
eda2f023b0 Add Distinct support 2020-06-05 19:19:12 +08:00
Jinzhu
d50879cc28 Add field permission test 2020-06-05 19:18:22 +08:00
Jinzhu
c8e7878b3e Add PrepareStmt support 2020-06-05 11:40:24 +08:00
Jinzhu
9934207c42 Fix logger panic on windows 2020-06-03 14:39:36 +08:00
Jinzhu
b32658358c Fix can't scan null value into normal data types 2020-06-03 09:00:20 +08:00
Jinzhu
94685d1024 Fix can't scan null value into normal data types 2020-06-02 23:30:26 +08:00
Jinzhu
2218e32999 Allow customize table name with TableName 2020-06-02 15:48:19 +08:00
Jinzhu
e959a67f87 Fix callbacks with Match 2020-06-02 12:46:55 +08:00
Jinzhu
669ce48f19 Fix order by primary key if it is not defined 2020-06-02 11:30:21 +08:00
Jinzhu
64ed645e4d Returns ping error 2020-06-02 11:09:17 +08:00
175 changed files with 26081 additions and 3588 deletions

.github/FUNDING.yml (new file)
@@ -0,0 +1,5 @@
# These are supported funding model platforms
github: [jinzhu]
patreon: jinzhu
open_collective: gorm

@@ -1,45 +0,0 @@
Your issue may already be reported! Please search on the [issue track](https://github.com/jinzhu/gorm/issues) before creating one.
### What version of Go are you using (`go version`)?
### Which database and its version are you using?
### Please provide a complete runnable program to reproduce your issue. **IMPORTANT**
Need to runnable with [GORM's docker compose config](https://github.com/jinzhu/gorm/blob/master/docker-compose.yml) or please provides your config.
```go
package main
import (
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mssql"
_ "github.com/jinzhu/gorm/dialects/mysql"
_ "github.com/jinzhu/gorm/dialects/postgres"
_ "github.com/jinzhu/gorm/dialects/sqlite"
)
var db *gorm.DB
func init() {
var err error
db, err = gorm.Open("sqlite3", "test.db")
// db, err = gorm.Open("postgres", "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable")
// db, err = gorm.Open("mysql", "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True")
// db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm")
if err != nil {
panic(err)
}
db.LogMode(true)
}
func main() {
if /* failure condition */ {
fmt.Println("failed")
} else {
fmt.Println("success")
}
}
```

@@ -1,9 +0,0 @@
Make sure these boxes checked before submitting your pull request.
- [] Do only one thing
- [] No API-breaking changes
- [] New code/logic commented & tested
For significant changes like big bug fixes, new features, please open an issue to make an agreement on an implementation design/plan first before starting it.
### What did this pull request do?

.github/dependabot.yml (new file)
@@ -0,0 +1,15 @@
---
version: 2
updates:
- package-ecosystem: gomod
directory: /
schedule:
interval: weekly
- package-ecosystem: github-actions
directory: /
schedule:
interval: weekly
- package-ecosystem: gomod
directory: /tests
schedule:
interval: weekly

.github/labels.json (new file)
@@ -0,0 +1,166 @@
{
"labels": {
"critical": {
"name": "type:critical",
"colour": "#E84137",
"description": "critical questions"
},
"question": {
"name": "type:question",
"colour": "#EDEDED",
"description": "general questions"
},
"feature": {
"name": "type:feature_request",
"colour": "#43952A",
"description": "feature request"
},
"invalid_question": {
"name": "type:invalid question",
"colour": "#CF2E1F",
"description": "invalid question (not related to GORM or described in document or not enough information provided)"
},
"with_playground": {
"name": "type:with reproduction steps",
"colour": "#00ff00",
"description": "with reproduction steps"
},
"without_playground": {
"name": "type:missing reproduction steps",
"colour": "#CF2E1F",
"description": "missing reproduction steps"
},
"has_pr": {
"name": "type:has pull request",
"colour": "#43952A",
"description": "has pull request"
},
"not_tested": {
"name": "type:not tested",
"colour": "#CF2E1F",
"description": "not tested"
},
"tested": {
"name": "type:tested",
"colour": "#00ff00",
"description": "tested"
},
"breaking_change": {
"name": "type:breaking change",
"colour": "#CF2E1F",
"description": "breaking change"
}
},
"issue": {
"with_playground": {
"requires": 1,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/github.com\/go-gorm\/playground\/pull\/\\d\\d+/s"
}
]
},
"critical": {
"requires": 1,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/(critical|urgent)/i"
},
{
"type": "titleMatches",
"pattern": "/(critical|urgent)/i"
}
]
},
"question": {
"requires": 1,
"conditions": [
{
"type": "titleMatches",
"pattern": "/question/i"
},
{
"type": "descriptionMatches",
"pattern": "/question/i"
}
]
},
"feature": {
"requires": 1,
"conditions": [
{
"type": "titleMatches",
"pattern": "/feature/i"
},
{
"type": "descriptionMatches",
"pattern": "/Describe the feature/i"
}
]
},
"without_playground": {
"requires": 6,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/^((?!github.com\/go-gorm\/playground\/pull\/\\d\\d+).)*$/s"
},
{
"type": "titleMatches",
"pattern": "/^((?!question).)*$/s"
},
{
"type": "descriptionMatches",
"pattern": "/^((?!question).)*$/is"
},
{
"type": "descriptionMatches",
"pattern": "/^((?!Describe the feature).)*$/is"
},
{
"type": "titleMatches",
"pattern": "/^((?!critical|urgent).)*$/s"
},
{
"type": "descriptionMatches",
"pattern": "/^((?!critical|urgent).)*$/s"
}
]
}
},
"pr": {
"critical": {
"requires": 1,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/(critical|urgent)/i"
},
{
"type": "titleMatches",
"pattern": "/(critical|urgent)/i"
}
]
},
"not_tested": {
"requires": 1,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/\\[\\] Tested/"
}
]
},
"breaking_change": {
"requires": 1,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/\\[\\] Non breaking API changes/"
}
]
}
}
}

.github/release-drafter.yml (new file)
@@ -0,0 +1,20 @@
name-template: 'v Release $NEXT_PATCH_VERSION 🌈'
tag-template: 'v$NEXT_PATCH_VERSION'
categories:
- title: '🚀 Features'
labels:
- 'feature'
- 'enhancement'
- title: '🐛 Bug Fixes'
labels:
- 'fix'
- 'bugfix'
- 'bug'
- title: '🧰 Maintenance'
label: 'chore'
change-template: '- $TITLE @$AUTHOR (#$NUMBER)'
change-title-escapes: '\<*_&'
template: |
## Changes
$CHANGES

.github/workflows/create-release.yml (new file)
@@ -0,0 +1,31 @@
name: Create Release
on:
push:
tags:
- 'v*.*.*'
permissions:
contents: write
pull-requests: read
jobs:
create_release:
name: Create Release
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Generate Release Notes and Publish
id: generate_release_notes
uses: release-drafter/release-drafter@v6
with:
config-name: 'release-drafter.yml'
name: "Release ${{ github.ref_name }}"
tag: ${{ github.ref_name }}
publish: true
prerelease: false
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

.github/workflows/golangci-lint.yml (new file)
@@ -0,0 +1,26 @@
name: golangci-lint
on:
push:
branches:
- main
- master
pull_request:
permissions:
contents: read
pull-requests: read
jobs:
golangci:
name: lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: stable
- name: golangci-lint
uses: golangci/golangci-lint-action@v7
with:
version: v2.0
only-new-issues: true

.github/workflows/invalid_question.yml (new file)
@@ -0,0 +1,28 @@
name: "Close invalid questions issues"
on:
schedule:
- cron: "*/10 * * * *"
permissions:
contents: read
jobs:
stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v8
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-label: "status:stale"
days-before-stale: 0
days-before-close: 30
remove-stale-when-updated: true
only-labels: "type:invalid question"

.github/workflows/labeler.yml (new file)
@@ -0,0 +1,19 @@
name: "Issue Labeler"
on:
issues:
types: [opened, edited, reopened]
pull_request:
types: [opened, edited, reopened]
jobs:
triage:
runs-on: ubuntu-latest
name: Label issues and pull requests
steps:
- name: check out
uses: actions/checkout@v4
- name: labeler
uses: jinzhu/super-labeler-action@develop
with:
GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}"

@@ -0,0 +1,27 @@
name: "Close Missing Playground issues"
on:
schedule:
- cron: "*/10 * * * *"
permissions:
contents: read
jobs:
stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v8
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-label: "status:stale"
days-before-stale: 0
days-before-close: 30
remove-stale-when-updated: true
only-labels: "type:missing reproduction steps"

.github/workflows/stale.yml (new file)
@@ -0,0 +1,28 @@
name: "Stale"
on:
schedule:
- cron: "0 2 * * *"
permissions:
contents: read
jobs:
stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v8
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"
days-before-stale: 360
days-before-close: 180
stale-issue-label: "status:stale"
exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request'
stale-pr-label: 'status:stale'
exempt-pr-labels: 'type:feature,type:with reproduction steps,type:has pull request'

.github/workflows/tests.yml (new file)
@@ -0,0 +1,310 @@
name: tests
on:
push:
branches-ignore:
- 'gh-pages'
pull_request:
branches-ignore:
- 'gh-pages'
permissions:
contents: read
jobs:
# Label of the container job
sqlite:
strategy:
matrix:
go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run in windows OS
runs-on: ${{ matrix.platform }}
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=sqlite ./tests/tests_all.sh
mysql:
strategy:
matrix:
dbversion: ['mysql:9', 'mysql:8', 'mysql:5.7']
go: ['1.23', '1.24']
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
services:
mysql:
image: ${{ matrix.dbversion }}
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
ports:
- 9910:3306
options: >-
--health-cmd "mysqladmin ping -ugorm -pgorm"
--health-interval 10s
--health-start-period 10s
--health-timeout 5s
--health-retries 10
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
mariadb:
strategy:
matrix:
dbversion: [ 'mariadb:latest' ]
go: ['1.23', '1.24']
platform: [ ubuntu-latest ]
runs-on: ${{ matrix.platform }}
services:
mysql:
image: ${{ matrix.dbversion }}
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
ports:
- 9910:3306
options: >-
--health-cmd "mariadb-admin ping -ugorm -pgorm"
--health-interval 10s
--health-start-period 10s
--health-timeout 5s
--health-retries 10
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
postgres:
strategy:
matrix:
dbversion: ['postgres:latest', 'postgres:15', 'postgres:14', 'postgres:13']
go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }}
services:
postgres:
image: ${{ matrix.dbversion }}
env:
POSTGRES_PASSWORD: gorm
POSTGRES_USER: gorm
POSTGRES_DB: gorm
TZ: Asia/Shanghai
ports:
- 9920:5432
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
sqlserver:
strategy:
matrix:
go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run test in macOS and windows
runs-on: ${{ matrix.platform }}
services:
mssql:
image: mcr.microsoft.com/mssql/server:2022-latest
env:
TZ: Asia/Shanghai
ACCEPT_EULA: Y
MSSQL_SA_PASSWORD: LoremIpsum86
ports:
- 9930:1433
options: >-
--health-cmd="/opt/mssql-tools18/bin/sqlcmd -S localhost -U sa -P ${MSSQL_SA_PASSWORD} -N -C -l 30 -Q \"SELECT 1\" || exit 1"
--health-start-period 10s
--health-interval 10s
--health-timeout 5s
--health-retries 10
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://sa:LoremIpsum86@localhost:9930?database=master" ./tests/tests_all.sh
tidb:
strategy:
matrix:
dbversion: [ 'v6.5.0' ]
go: ['1.23', '1.24']
platform: [ ubuntu-latest ]
runs-on: ${{ matrix.platform }}
steps:
- name: Setup TiDB
uses: Icemap/tidb-action@main
with:
port: 9940
version: ${{matrix.dbversion}}
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh
gaussdb:
strategy:
matrix:
dbversion: ['opengauss/opengauss:7.0.0-RC1.B023']
go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }}
services:
gaussdb:
image: ${{ matrix.dbversion }}
env:
# GaussDB has password limitations
GS_PASSWORD: Gaussdb@123
TZ: Asia/Shanghai
ports:
- 9950:5432
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: Waiting for GaussDB to be ready
run: |
container_name=$(docker ps --filter "ancestor=opengauss/opengauss:7.0.0-RC1.B023" --format "{{.Names}}")
if [ -z "$container_name" ]; then
echo "Error: failed to find a container created from the 'opengauss/opengauss:7.0.0-RC1.B023' image."
exit 1
fi
max_retries=12
retry_count=0
if [ -t 0 ]; then
TTY_FLAG="-t"
else
TTY_FLAG=""
fi
while [ $retry_count -lt $max_retries ]; do
if docker exec -i "${container_name}" bash -c "su - omm -c 'gsql -U omm -c \"select 1;\"'"
then
echo "Creating database gorm..."
sql_file='/tmp/create_database.sql'
echo "CREATE DATABASE gorm DBCOMPATIBILITY 'PG';" > ${sql_file}
docker cp "${sql_file}" "${container_name}":"${sql_file}"
docker exec -i ${TTY_FLAG} "${container_name}" bash -c "su - omm -c 'gsql -U omm -f ${sql_file}'"
echo "Database initialization completed."
break
fi
echo "Waiting for database to be ready... (attempt $((retry_count + 1))/$max_retries)"
sleep 10
((++retry_count))
done
exit 0
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=gaussdb GORM_DSN="user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
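Every job in tests.yml hands the test suite a dialect name through GORM_DIALECT and a connection string through GORM_DSN, with each database exposed on a fixed host port by its service container. As a hedged local-usage sketch (not part of the workflow itself), the snippet below opens the same Postgres service the postgres job maps to host port 9920, using the standard gorm.io/driver/postgres driver and the DSN copied from that job:

```go
package main

import (
	"fmt"

	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)

func main() {
	// Same DSN the postgres job in tests.yml exports via GORM_DSN
	// (postgres service container mapped to host port 9920).
	dsn := "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai"

	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if err != nil {
		panic(err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		panic(err)
	}
	defer sqlDB.Close()

	fmt.Println("connected:", sqlDB.Ping() == nil)
}
```

The other jobs work the same way; only the driver import and the DSN change per dialect.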

.gitignore (3 lines added)

@ -2,3 +2,6 @@ TODO*
documents documents
coverage.txt coverage.txt
_book _book
.idea
vendor
.vscode

.golangci.yml (new file, 19 lines)

@ -0,0 +1,19 @@
version: "2"
linters:
default: standard
enable:
- cyclop
- gocritic
- gosec
- ineffassign
- misspell
- prealloc
- unconvert
- unparam
- whitespace
formatters:
enable:
- gofumpt
- goimports

CODE_OF_CONDUCT.md (new file, 128 lines)

@ -0,0 +1,128 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period. This
includes avoiding interactions in community spaces and external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any interaction or public
communication with the community for a specified period. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

LICENSE

@ -1,6 +1,6 @@
The MIT License (MIT) The MIT License (MIT)
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com> Copyright (c) 2013-present Jinzhu <wosmvp@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal

README.md

@ -2,41 +2,43 @@
The fantastic ORM library for Golang, aims to be developer friendly. The fantastic ORM library for Golang, aims to be developer friendly.
[![go report card](https://goreportcard.com/badge/gorm.io/gorm "go report card")](https://goreportcard.com/report/gorm.io/gorm) [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm)
[![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b) [![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions)
[![codecov](https://codecov.io/gh/jinzhu/gorm/branch/master/graph/badge.svg)](https://codecov.io/gh/jinzhu/gorm)
[![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm)
[![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm)
[![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT)
[![GoDoc](https://godoc.org/gorm.io/gorm?status.svg)](https://godoc.org/gorm.io/gorm) [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc)
## Overview ## Overview
* Full-Featured ORM (almost) * Full-Featured ORM
* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism) * Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism, Single-table inheritance)
* Hooks (Before/After Create/Save/Update/Delete/Find) * Hooks (Before/After Create/Save/Update/Delete/Find)
* Preloading (eager loading) * Eager loading with `Preload`, `Joins`
* Transactions * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
* Context, Prepared Statement Mode, DryRun Mode
* Batch Insert, FindInBatches, Find To Map
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr
* Composite Primary Key * Composite Primary Key
* SQL Builder
* Auto Migrations * Auto Migrations
* Logger * Logger
* Extendable, write Plugins based on GORM callbacks * Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus…
* Every feature comes with tests * Every feature comes with tests
* Developer Friendly * Developer Friendly
## Getting Started ## Getting Started
* GORM Guides [https://gorm.io](https://gorm.io) * GORM Guides [https://gorm.io](https://gorm.io)
* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
## Contributing ## Contributing
[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html) [You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
## Contributors
[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework!
## License ## License
© Jinzhu, 2013~time.Now © Jinzhu, 2013~time.Now
Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License) Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)

association.go

@ -1,7 +1,6 @@
package gorm package gorm
import ( import (
"errors"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -15,6 +14,7 @@ import (
type Association struct { type Association struct {
DB *DB DB *DB
Relationship *schema.Relationship Relationship *schema.Relationship
Unscope bool
Error error Error error
} }
@ -27,10 +27,13 @@ func (db *DB) Association(column string) *Association {
association.Relationship = db.Statement.Schema.Relationships.Relations[column] association.Relationship = db.Statement.Schema.Relationships.Relations[column]
if association.Relationship == nil { if association.Relationship == nil {
association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column) association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column)
} }
db.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(db.Statement.Model)) db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
for db.Statement.ReflectValue.Kind() == reflect.Ptr {
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
}
} else { } else {
association.Error = err association.Error = err
} }
@ -38,34 +41,19 @@ func (db *DB) Association(column string) *Association {
return association return association
} }
func (association *Association) Unscoped() *Association {
return &Association{
DB: association.DB,
Relationship: association.Relationship,
Error: association.Error,
Unscope: true,
}
}
func (association *Association) Find(out interface{}, conds ...interface{}) error { func (association *Association) Find(out interface{}, conds ...interface{}) error {
if association.Error == nil { if association.Error == nil {
var ( association.Error = association.buildCondition().Find(out, conds...).Error
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
tx = association.DB.Model(out)
)
if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause)
} }
joinStmt.Build("WHERE", "LIMIT")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
tx.Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: queryConds},
}}})
} else {
tx.Clauses(clause.Where{Exprs: queryConds})
}
association.Error = tx.Find(out, conds...).Error
}
return association.Error return association.Error
} }
@ -77,7 +65,7 @@ func (association *Association) Append(values ...interface{}) error {
association.Error = association.Replace(values...) association.Error = association.Replace(values...)
} }
default: default:
association.saveAssociation(false, values...) association.saveAssociation( /*clear*/ false, values...)
} }
} }
@ -86,12 +74,30 @@ func (association *Association) Append(values ...interface{}) error {
func (association *Association) Replace(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error {
if association.Error == nil { if association.Error == nil {
// save associations
association.saveAssociation(true, values...)
// set old associations's foreign key to null
reflectValue := association.DB.Statement.ReflectValue reflectValue := association.DB.Statement.ReflectValue
rel := association.Relationship rel := association.Relationship
var oldBelongsToExpr clause.Expression
// we have to record the old BelongsTo value
if association.Unscope && rel.Type == schema.BelongsTo {
var foreignFields []*schema.Field
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
oldBelongsToExpr = clause.IN{Column: column, Values: values}
}
}
// save associations
if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
return association.Error
}
// set old associations's foreign key to null
switch rel.Type { switch rel.Type {
case schema.BelongsTo: case schema.BelongsTo:
if len(values) == 0 { if len(values) == 0 {
@ -99,30 +105,33 @@ func (association *Association) Replace(values ...interface{}) error {
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
} }
case reflect.Struct: case reflect.Struct:
rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
} }
for _, ref := range rel.References { for _, ref := range rel.References {
updateMap[ref.ForeignKey.DBName] = nil updateMap[ref.ForeignKey.DBName] = nil
} }
association.DB.UpdateColumns(updateMap) association.Error = association.DB.UpdateColumns(updateMap).Error
}
if association.Unscope && oldBelongsToExpr != nil {
association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
} }
case schema.HasOne, schema.HasMany: case schema.HasOne, schema.HasMany:
var ( var (
primaryFields []*schema.Field primaryFields []*schema.Field
foreignKeys []string foreignKeys []string
updateMap = map[string]interface{}{} updateMap = map[string]interface{}{}
relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel}) relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel})
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue) tx = association.DB.Model(modelValue)
) )
if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
if column, values := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
tx.Not(clause.IN{Column: column, Values: values}) tx.Not(clause.IN{Column: column, Values: values})
} }
} }
@ -137,9 +146,13 @@ func (association *Association) Replace(values ...interface{}) error {
} }
} }
if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
column, values := schema.ToQueryValues(foreignKeys, pvs) column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap) if association.Unscope {
association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error
} else {
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
}
} }
case schema.Many2Many: case schema.Many2Many:
var ( var (
@ -163,19 +176,19 @@ func (association *Association) Replace(values ...interface{}) error {
} }
} }
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
if column, values := schema.ToQueryValues(joinPrimaryKeys, pvs); len(values) > 0 { if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
tx.Where(clause.IN{Column: column, Values: values}) tx.Where(clause.IN{Column: column, Values: values})
} else { } else {
return ErrorPrimaryKeyRequired return ErrPrimaryKeyRequired
} }
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
if relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, rvs); len(relValues) > 0 { if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
} }
tx.Delete(modelValue) association.Error = tx.Delete(modelValue).Error
} }
} }
return association.Error return association.Error
@ -186,7 +199,7 @@ func (association *Association) Delete(values ...interface{}) error {
var ( var (
reflectValue = association.DB.Statement.ReflectValue reflectValue = association.DB.Statement.ReflectValue
rel = association.Relationship rel = association.Relationship
primaryFields, foreignFields []*schema.Field primaryFields []*schema.Field
foreignKeys []string foreignKeys []string
updateAttrs = map[string]interface{}{} updateAttrs = map[string]interface{}{}
conds []clause.Expression conds []clause.Expression
@ -195,7 +208,6 @@ func (association *Association) Delete(values ...interface{}) error {
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.PrimaryValue == "" { if ref.PrimaryValue == "" {
primaryFields = append(primaryFields, ref.PrimaryKey) primaryFields = append(primaryFields, ref.PrimaryKey)
foreignFields = append(foreignFields, ref.ForeignKey)
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
updateAttrs[ref.ForeignKey.DBName] = nil updateAttrs[ref.ForeignKey.DBName] = nil
} else { } else {
@ -205,34 +217,58 @@ func (association *Association) Delete(values ...interface{}) error {
switch rel.Type { switch rel.Type {
case schema.BelongsTo: case schema.BelongsTo:
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) associationDB := association.DB.Session(&Session{})
tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.Schema.PrimaryFieldDBNames, pvs) if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields)
relColumn, relValues := schema.ToQueryValues(foreignKeys, rvs) relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
if association.Unscope {
var foreignFields []*schema.Field
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
}
}
case schema.HasOne, schema.HasMany: case schema.HasOne, schema.HasMany:
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) model := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := association.DB.Model(model)
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(foreignKeys, pvs) if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, rvs) relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
if association.Unscope {
association.Error = tx.Clauses(conds...).Delete(model).Error
} else {
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
}
case schema.Many2Many: case schema.Many2Many:
var ( var (
primaryFields, relPrimaryFields []*schema.Field primaryFields, relPrimaryFields []*schema.Field
joinPrimaryKeys, joinRelPrimaryKeys []string joinPrimaryKeys, joinRelPrimaryKeys []string
modelValue = reflect.New(rel.JoinTable.ModelType).Interface() joinValue = reflect.New(rel.JoinTable.ModelType).Interface()
) )
for _, ref := range rel.References { for _, ref := range rel.References {
@ -249,23 +285,27 @@ func (association *Association) Delete(values ...interface{}) error {
} }
} }
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(joinPrimaryKeys, pvs) if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, rvs) relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
} }
if association.Error == nil { if association.Error == nil {
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) // clean up deleted values's foreign key
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
cleanUpDeletedRelations := func(data reflect.Value) { cleanUpDeletedRelations := func(data reflect.Value) {
if _, zero := rel.Field.ValueOf(data); !zero { if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero {
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data))
primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
switch fieldValue.Kind() { switch fieldValue.Kind() {
@ -273,7 +313,7 @@ func (association *Association) Delete(values ...interface{}) error {
validFieldValues := reflect.Zero(rel.Field.IndirectFieldType) validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
for i := 0; i < fieldValue.Len(); i++ { for i := 0; i < fieldValue.Len(); i++ {
for idx, field := range rel.FieldSchema.PrimaryFields { for idx, field := range rel.FieldSchema.PrimaryFields {
primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i)) primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i))
} }
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok { if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
@ -281,21 +321,23 @@ func (association *Association) Delete(values ...interface{}) error {
} }
} }
rel.Field.Set(data, validFieldValues.Interface()) association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface())
case reflect.Struct: case reflect.Struct:
for idx, field := range rel.FieldSchema.PrimaryFields { for idx, field := range rel.FieldSchema.PrimaryFields {
primaryValues[idx], _ = field.ValueOf(fieldValue) primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue)
} }
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()) if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
break
}
if rel.JoinTable == nil { if rel.JoinTable == nil {
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey || ref.PrimaryValue != "" { if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
} else { } else {
ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
} }
} }
} }
@ -324,33 +366,8 @@ func (association *Association) Clear() error {
func (association *Association) Count() (count int64) { func (association *Association) Count() (count int64) {
if association.Error == nil { if association.Error == nil {
var ( association.Error = association.buildCondition().Count(&count).Error
conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause)
} }
joinStmt.Build("WHERE", "LIMIT")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
tx.Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: conds},
}}})
} else {
tx.Clauses(clause.Where{Exprs: conds})
}
association.Error = tx.Count(&count).Error
}
return return
} }
@ -372,14 +389,18 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
switch rv.Kind() { switch rv.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if rv.Len() > 0 { if rv.Len() > 0 {
association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct { if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
} }
} }
case reflect.Struct: case reflect.Struct:
association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) if !rv.CanAddr() {
association.Error = ErrInvalidValue
return
}
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct { if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
@ -387,9 +408,13 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} }
case schema.HasMany, schema.Many2Many: case schema.HasMany, schema.Many2Many:
elemType := association.Relationship.Field.IndirectFieldType.Elem() elemType := association.Relationship.Field.IndirectFieldType.Elem()
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source)) oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
var fieldValue reflect.Value
if clear { if clear {
fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap())
} else {
fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap())
reflect.Copy(fieldValue, oldFieldValue)
} }
appendToFieldValues := func(ev reflect.Value) { appendToFieldValues := func(ev reflect.Value) {
@ -398,7 +423,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} else if ev.Type().Elem().AssignableTo(elemType) { } else if ev.Type().Elem().AssignableTo(elemType) {
fieldValue = reflect.Append(fieldValue, ev.Elem()) fieldValue = reflect.Append(fieldValue, ev.Elem())
} else { } else {
association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name) association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name)
} }
if elemType.Kind() == reflect.Struct { if elemType.Kind() == reflect.Struct {
@ -412,33 +437,74 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
} }
case reflect.Struct: case reflect.Struct:
if !rv.CanAddr() {
association.Error = ErrInvalidValue
return
}
appendToFieldValues(rv.Addr()) appendToFieldValues(rv.Addr())
} }
if association.Error == nil { if association.Error == nil {
association.Error = association.Relationship.Field.Set(source, fieldValue.Interface()) association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface())
} }
} }
} }
selectedSaveColumns := []string{association.Relationship.Name} selectedSaveColumns := []string{association.Relationship.Name}
omitColumns := []string{}
selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false)
for name, ok := range selectColumns {
columnName := ""
if strings.HasPrefix(name, association.Relationship.Name) {
if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" {
columnName = name
}
} else if strings.HasPrefix(name, clause.Associations) {
columnName = name
}
if columnName != "" {
if ok {
selectedSaveColumns = append(selectedSaveColumns, columnName)
} else {
omitColumns = append(omitColumns, columnName)
}
}
}
for _, ref := range association.Relationship.References { for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey { if !ref.OwnPrimaryKey {
selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name) selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name)
} }
} }
associationDB := association.DB.Session(&Session{}).Model(nil)
if !association.DB.FullSaveAssociations {
associationDB.Select(selectedSaveColumns)
}
if len(omitColumns) > 0 {
associationDB.Omit(omitColumns...)
}
associationDB = associationDB.Session(&Session{})
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if len(values) != reflectValue.Len() { if len(values) != reflectValue.Len() {
// clear old data
if clear && len(values) == 0 { if clear && len(values) == 0 {
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
association.Error = err
break
}
if association.Relationship.JoinTable == nil { if association.Relationship.JoinTable == nil {
for _, ref := range association.Relationship.References { for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()) if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
association.Error = err
break
}
} }
} }
} }
@ -446,24 +512,28 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
break break
} }
association.Error = errors.New("invalid association values, length doesn't match") association.Error = ErrInvalidValueOfLength
return return
} }
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
if association.Error != nil {
return
}
// TODO support save slice data, sql with case? // TODO support save slice data, sql with case?
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
} }
case reflect.Struct: case reflect.Struct:
// clear old data
if clear && len(values) == 0 { if clear && len(values) == 0 {
association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
if association.Relationship.JoinTable == nil { if association.Relationship.JoinTable == nil && association.Error == nil {
for _, ref := range association.Relationship.References { for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
} }
} }
} }
@ -472,15 +542,18 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
for idx, value := range values { for idx, value := range values {
rv := reflect.Indirect(reflect.ValueOf(value)) rv := reflect.Indirect(reflect.ValueOf(value))
appendToRelations(reflectValue, rv, clear && idx == 0) appendToRelations(reflectValue, rv, clear && idx == 0)
if association.Error != nil {
return
}
} }
if len(values) > 0 { if len(values) > 0 {
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Addr().Interface()).Error association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error
} }
} }
for _, assignBack := range assignBacks { for _, assignBack := range assignBacks {
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source)) fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source))
if assignBack.Index > 0 { if assignBack.Index > 0 {
reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
} else { } else {
@ -488,3 +561,33 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} }
} }
} }
func (association *Association) buildCondition() *DB {
var (
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue)
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause)
}
joinStmt.Build("WHERE")
if len(joinStmt.SQL.String()) > 0 {
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
}
tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: queryConds},
}}})
} else {
tx.Clauses(clause.Where{Exprs: queryConds})
}
return tx
}
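The new Unscope flag changes what Replace, Delete, and Clear do with records that drop out of an association: instead of only setting the foreign key to NULL (or removing join-table rows), the previously associated rows themselves are deleted. A hedged usage sketch follows; the User/Pet models and helper names are illustrative and not taken from the diff:

```go
package example

import "gorm.io/gorm"

// Illustrative has-many models, not part of the diff.
type Pet struct {
	ID     uint
	Name   string
	UserID uint
}

type User struct {
	ID   uint
	Name string
	Pets []Pet
}

// replacePetsHard swaps the user's pets. For this has-many relation, plain
// Replace would only NULL pets.user_id on the old rows; with Unscoped() the
// old rows are deleted.
func replacePetsHard(db *gorm.DB, user *User, pets []Pet) error {
	return db.Model(user).Association("Pets").Unscoped().Replace(pets)
}

// clearPetsHard removes every currently associated pet row outright
// (Clear routes through Replace with no values, so Unscope applies here too).
func clearPetsHard(db *gorm.DB, user *User) error {
	return db.Model(user).Association("Pets").Unscoped().Clear()
}
```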

callbacks.go

@ -5,9 +5,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"sort"
"time" "time"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
@ -15,12 +15,12 @@ import (
func initializeCallbacks(db *DB) *callbacks { func initializeCallbacks(db *DB) *callbacks {
return &callbacks{ return &callbacks{
processors: map[string]*processor{ processors: map[string]*processor{
"create": &processor{db: db}, "create": {db: db},
"query": &processor{db: db}, "query": {db: db},
"update": &processor{db: db}, "update": {db: db},
"delete": &processor{db: db}, "delete": {db: db},
"row": &processor{db: db}, "row": {db: db},
"raw": &processor{db: db}, "raw": {db: db},
}, },
} }
} }
@ -32,6 +32,7 @@ type callbacks struct {
type processor struct { type processor struct {
db *DB db *DB
Clauses []string
fns []func(*DB) fns []func(*DB)
callbacks []*callback callbacks []*callback
} }
@ -71,28 +72,57 @@ func (cs *callbacks) Raw() *processor {
return cs.processors["raw"] return cs.processors["raw"]
} }
func (p *processor) Execute(db *DB) { func (p *processor) Execute(db *DB) *DB {
curTime := time.Now() // call scopes
db.RowsAffected = 0 for len(db.Statement.scopes) > 0 {
if stmt := db.Statement; stmt != nil { db = db.executeScopes()
if stmt.Model == nil {
stmt.Model = stmt.Dest
} }
var (
curTime = time.Now()
stmt = db.Statement
resetBuildClauses bool
)
if len(stmt.BuildClauses) == 0 {
stmt.BuildClauses = p.Clauses
resetBuildClauses = true
}
if optimizer, ok := db.Statement.Dest.(StatementModifier); ok {
optimizer.ModifyStatement(stmt)
}
// assign model values
if stmt.Model == nil {
stmt.Model = stmt.Dest
} else if stmt.Dest == nil {
stmt.Dest = stmt.Model
}
// parse model values
if stmt.Model != nil { if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil {
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
} else {
db.AddError(err) db.AddError(err)
} }
} }
}
// assign stmt.ReflectValue
if stmt.Dest != nil { if stmt.Dest != nil {
stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest)) stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
for stmt.ReflectValue.Kind() == reflect.Ptr { for stmt.ReflectValue.Kind() == reflect.Ptr {
if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() {
stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
}
stmt.ReflectValue = stmt.ReflectValue.Elem() stmt.ReflectValue = stmt.ReflectValue.Elem()
} }
if !stmt.ReflectValue.IsValid() { if !stmt.ReflectValue.IsValid() {
db.AddError(fmt.Errorf("invalid value")) db.AddError(ErrInvalidValue)
}
} }
} }
@ -100,14 +130,26 @@ func (p *processor) Execute(db *DB) {
f(db) f(db)
} }
if stmt := db.Statement; stmt != nil { if stmt.SQL.Len() > 0 {
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected sql, vars := stmt.SQL.String(), stmt.Vars
}, db.Error) if filter, ok := db.Logger.(ParamsFilter); ok {
sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
stmt.reinit()
// db.Config.statementPool.Put(stmt)
} }
return db.Dialector.Explain(sql, vars...), db.RowsAffected
}, db.Error)
}
if !stmt.DB.DryRun {
stmt.SQL.Reset()
stmt.Vars = nil
}
if resetBuildClauses {
stmt.BuildClauses = nil
}
return db
} }
func (p *processor) Get(name string) func(*DB) { func (p *processor) Get(name string) func(*DB) {
@ -145,14 +187,23 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
func (p *processor) compile() (err error) { func (p *processor) compile() (err error) {
var callbacks []*callback var callbacks []*callback
removedMap := map[string]bool{}
for _, callback := range p.callbacks { for _, callback := range p.callbacks {
if callback.match == nil || callback.match(p.db) { if callback.match == nil || callback.match(p.db) {
callbacks = append(callbacks, callback) callbacks = append(callbacks, callback)
} }
if callback.remove {
removedMap[callback.name] = true
}
} }
if len(removedMap) > 0 {
callbacks = removeCallbacks(callbacks, removedMap)
}
p.callbacks = callbacks
if p.fns, err = sortCallbacks(p.callbacks); err != nil { if p.fns, err = sortCallbacks(p.callbacks); err != nil {
logger.Default.Error(context.Background(), "Got error when compile callbacks, got %v", err) p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
} }
return return
} }
@ -175,7 +226,7 @@ func (c *callback) Register(name string, fn func(*DB)) error {
} }
func (c *callback) Remove(name string) error { func (c *callback) Remove(name string) error {
logger.Default.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())
c.name = name c.name = name
c.remove = true c.remove = true
c.processor.callbacks = append(c.processor.callbacks, c) c.processor.callbacks = append(c.processor.callbacks, c)
@ -183,7 +234,7 @@ func (c *callback) Remove(name string) error {
} }
func (c *callback) Replace(name string, fn func(*DB)) error { func (c *callback) Replace(name string, fn func(*DB)) error {
logger.Default.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())
c.name = name c.name = name
c.handler = fn c.handler = fn
c.replace = true c.replace = true
@ -206,23 +257,36 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
names, sorted []string names, sorted []string
sortCallback func(*callback) error sortCallback func(*callback) error
) )
sort.SliceStable(cs, func(i, j int) bool {
if cs[j].before == "*" && cs[i].before != "*" {
return true
}
if cs[j].after == "*" && cs[i].after != "*" {
return true
}
return false
})
for _, c := range cs { for _, c := range cs {
// show warning message the callback name already exists // show warning message the callback name already exists
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
logger.Default.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
} }
names = append(names, c.name) names = append(names, c.name)
} }
sortCallback = func(c *callback) error { sortCallback = func(c *callback) error {
if c.before != "" { // if defined before callback if c.before != "" { // if defined before callback
if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { if c.before == "*" && len(sorted) > 0 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
sorted = append([]string{c.name}, sorted...)
}
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 { if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
// if before callback already sorted, append current callback just after it // if before callback already sorted, append current callback just after it
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
} else if curIdx > sortedIdx { } else if curIdx > sortedIdx {
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before) return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
} }
} else if idx := getRIndex(names, c.before); idx != -1 { } else if idx := getRIndex(names, c.before); idx != -1 {
// if before callback exists // if before callback exists
@ -231,12 +295,16 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
} }
if c.after != "" { // if defined after callback if c.after != "" { // if defined after callback
if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { if c.after == "*" && len(sorted) > 0 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
sorted = append(sorted, c.name)
}
} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 { if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
// if after callback sorted, append current callback to last // if after callback sorted, append current callback to last
sorted = append(sorted, c.name) sorted = append(sorted, c.name)
} else if curIdx < sortedIdx { } else if curIdx < sortedIdx {
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after) return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after)
} }
} else if idx := getRIndex(names, c.after); idx != -1 { } else if idx := getRIndex(names, c.after); idx != -1 {
// if after callback exists but haven't sorted // if after callback exists but haven't sorted
@ -279,3 +347,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
return return
} }
func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
callbacks := make([]*callback, 0, len(cs))
for _, callback := range cs {
if nameMap[callback.name] {
continue
}
callbacks = append(callbacks, callback)
}
return callbacks
}
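The updated sortCallbacks treats `before: "*"` and `after: "*"` as "run before/after everything already sorted", so a plugin can pin a callback to the edges of a processor. A minimal sketch of how that might look from the public callback API; the callback names and tracing behaviour are illustrative:

```go
package example

import (
	"fmt"

	"gorm.io/gorm"
)

// registerTracing pins two callbacks using the "*" markers the updated
// sortCallbacks understands: one runs before every create callback, the
// other after all of them.
func registerTracing(db *gorm.DB) error {
	if err := db.Callback().Create().Before("*").Register("trace:start", func(tx *gorm.DB) {
		fmt.Println("create starting for", tx.Statement.Table)
	}); err != nil {
		return err
	}
	return db.Callback().Create().After("*").Register("trace:finish", func(tx *gorm.DB) {
		fmt.Println("create finished, rows affected:", tx.RowsAffected)
	})
}
```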

callbacks/associations.go

@ -2,6 +2,7 @@ package callbacks
import ( import (
"reflect" "reflect"
"strings"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -9,21 +10,22 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
func SaveBeforeAssociations(db *gorm.DB) { func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
return func(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil { if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
// Save Belongs To associations // Save Belongs To associations
for _, rel := range db.Statement.Schema.Relationships.BelongsTo { for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if !saveAssociationCheck(db, rel, selectColumns, restricted) { if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue continue
} }
setupReferences := func(obj reflect.Value, elem reflect.Value) { setupReferences := func(obj reflect.Value, elem reflect.Value) {
for _, ref := range rel.References { for _, ref := range rel.References {
if !ref.OwnPrimaryKey { if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(elem) pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
ref.ForeignKey.Set(obj, pv) db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv))
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
dest[ref.ForeignKey.DBName] = pv dest[ref.ForeignKey.DBName] = pv
@@ -38,64 +40,81 @@ func SaveBeforeAssociations(db *gorm.DB) {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
var ( var (
objs []reflect.Value rValLen = db.Statement.ReflectValue.Len()
objs = make([]reflect.Value, 0, rValLen)
fieldType = rel.Field.FieldType fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr isPtr = fieldType.Kind() == reflect.Ptr
) )
if !isPtr { if !isPtr {
fieldType = reflect.PtrTo(fieldType) fieldType = reflect.PointerTo(fieldType)
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{}
for i := 0; i < rValLen; i++ {
obj := db.Statement.ReflectValue.Index(i) obj := db.Statement.ReflectValue.Index(i)
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value if reflect.Indirect(obj).Kind() != reflect.Struct {
rv := rel.Field.ReflectValueOf(obj) // relation reflect value break
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
objs = append(objs, obj)
if isPtr {
elems = reflect.Append(elems, rv)
} else {
elems = reflect.Append(elems, rv.Addr())
} }
} else { if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
setupReferences(obj, rv) rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
if !isPtr {
rv = rv.Addr()
}
objs = append(objs, obj)
elems = reflect.Append(elems, rv)
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, rv)
} }
} }
} }
if elems.Len() > 0 { if elems.Len() > 0 {
if db.AddError(db.Session(&gorm.Session{}).Create(elems.Interface()).Error) == nil { if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
for i := 0; i < elems.Len(); i++ { for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i)) setupReferences(objs[i], elems.Index(i))
} }
} }
} }
case reflect.Struct: case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value
if rv.Kind() != reflect.Ptr { if rv.Kind() != reflect.Ptr {
rv = rv.Addr() rv = rv.Addr()
} }
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
db.Session(&gorm.Session{}).Create(rv.Interface())
}
setupReferences(db.Statement.ReflectValue, rv) setupReferences(db.Statement.ReflectValue, rv)
} }
} }
} }
} }
}
}
} }
func SaveAfterAssociations(db *gorm.DB) { func SaveAfterAssociations(create bool) func(db *gorm.DB) {
return func(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil { if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
// Save Has One associations // Save Has One associations
for _, rel := range db.Statement.Schema.Relationships.HasOne { for _, rel := range db.Statement.Schema.Relationships.HasOne {
if !saveAssociationCheck(db, rel, selectColumns, restricted) { if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue continue
} }
@@ -107,100 +126,112 @@ func SaveAfterAssociations(db *gorm.DB) {
) )
if !isPtr { if !isPtr {
fieldType = reflect.PtrTo(fieldType) fieldType = reflect.PointerTo(fieldType)
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i) obj := db.Statement.ReflectValue.Index(i)
if _, zero := rel.Field.ValueOf(obj); !zero { if reflect.Indirect(obj).Kind() == reflect.Struct {
rv := rel.Field.ReflectValueOf(obj) if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj)
if rv.Kind() != reflect.Ptr { if rv.Kind() != reflect.Ptr {
rv = rv.Addr() rv = rv.Addr()
} }
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj) fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
ref.ForeignKey.Set(rv, fv) db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv))
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(rv, ref.PrimaryValue) db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue))
} }
} }
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
elems = reflect.Append(elems, rv) elems = reflect.Append(elems, rv)
} else {
db.Session(&gorm.Session{}).Save(rv.Addr().Interface())
} }
} }
} }
if elems.Len() > 0 { if elems.Len() > 0 {
db.Session(&gorm.Session{}).Create(elems.Interface()) assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
} }
case reflect.Struct: case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
if f.Kind() != reflect.Ptr { if f.Kind() != reflect.Ptr {
f = f.Addr() f = f.Addr()
} }
assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
ref.ForeignKey.Set(f, fv) db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv))
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(f, ref.PrimaryValue) db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue))
} }
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
} }
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f); isZero { saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
db.Session(&gorm.Session{}).Create(f.Interface())
} else {
db.Session(&gorm.Session{}).Save(f.Interface())
}
} }
} }
} }
// Save Has Many associations // Save Has Many associations
for _, rel := range db.Statement.Schema.Relationships.HasMany { for _, rel := range db.Statement.Schema.Relationships.HasMany {
if !saveAssociationCheck(db, rel, selectColumns, restricted) { if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue continue
} }
fieldType := rel.Field.IndirectFieldType.Elem() fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr { if !isPtr {
fieldType = reflect.PtrTo(fieldType) fieldType = reflect.PointerTo(fieldType)
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) { appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v)) f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ { for i := 0; i < f.Len(); i++ {
elem := f.Index(i) elem := f.Index(i)
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(v) pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v)
ref.ForeignKey.Set(elem, pv) db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv))
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(elem, ref.PrimaryValue) db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue))
} }
} }
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero { relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
if isPtr { if isPtr {
elems = reflect.Append(elems, elem) elems = reflect.Append(elems, elem)
} else { } else {
elems = reflect.Append(elems, elem.Addr()) elems = reflect.Append(elems, elem.Addr())
} }
} else {
db.Session(&gorm.Session{}).Save(elem.Addr().Interface())
} }
} }
} }
@@ -209,65 +240,85 @@ func SaveAfterAssociations(db *gorm.DB) {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
appendToElems(db.Statement.ReflectValue.Index(i)) obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
appendToElems(obj)
}
} }
case reflect.Struct: case reflect.Struct:
appendToElems(db.Statement.ReflectValue) appendToElems(db.Statement.ReflectValue)
} }
if elems.Len() > 0 { if elems.Len() > 0 {
db.Session(&gorm.Session{}).Create(elems.Interface()) assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
} }
} }
// Save Many2Many associations // Save Many2Many associations
for _, rel := range db.Statement.Schema.Relationships.Many2Many { for _, rel := range db.Statement.Schema.Relationships.Many2Many {
if !saveAssociationCheck(db, rel, selectColumns, restricted) { if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue continue
} }
fieldType := rel.Field.IndirectFieldType.Elem() fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr { if !isPtr {
fieldType = reflect.PtrTo(fieldType) fieldType = reflect.PointerTo(fieldType)
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 0) distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.JoinTable.ModelType)), 0, 10)
objs := []reflect.Value{} objs := []reflect.Value{}
appendToJoins := func(obj reflect.Value, elem reflect.Value) { appendToJoins := func(obj reflect.Value, elem reflect.Value) {
joinValue := reflect.New(rel.JoinTable.ModelType) joinValue := reflect.New(rel.JoinTable.ModelType)
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj) fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
ref.ForeignKey.Set(joinValue, fv) db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(joinValue, ref.PrimaryValue) db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue))
} else { } else {
fv, _ := ref.PrimaryKey.ValueOf(elem) fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
ref.ForeignKey.Set(joinValue, fv) db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
} }
} }
joins = reflect.Append(joins, joinValue) joins = reflect.Append(joins, joinValue)
} }
identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) { appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v)) f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ { for i := 0; i < f.Len(); i++ {
elem := f.Index(i) elem := f.Index(i)
if !isPtr {
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero { elem = elem.Addr()
}
objs = append(objs, v) objs = append(objs, v)
if isPtr {
elems = reflect.Append(elems, elem) elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr()) relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
} }
} else {
appendToJoins(v, elem)
} }
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, elem)
}
} }
} }
} }
@@ -275,38 +326,128 @@ func SaveAfterAssociations(db *gorm.DB) {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
appendToElems(db.Statement.ReflectValue.Index(i)) obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
appendToElems(obj)
}
} }
case reflect.Struct: case reflect.Struct:
appendToElems(db.Statement.ReflectValue) appendToElems(db.Statement.ReflectValue)
} }
if elems.Len() > 0 { // optimize elems of reflect value length
db.Session(&gorm.Session{}).Create(elems.Interface()) if elemLen := elems.Len(); elemLen > 0 {
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
}
for i := 0; i < elems.Len(); i++ { for i := 0; i < elemLen; i++ {
appendToJoins(objs[i], elems.Index(i)) appendToJoins(objs[i], elems.Index(i))
} }
} }
if joins.Len() > 0 { if joins.Len() > 0 {
db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()) db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
SkipHooks: db.Statement.SkipHooks,
DisableNestedTransaction: true,
}).Create(joins.Interface()).Error)
}
} }
} }
} }
} }
func saveAssociationCheck(db *gorm.DB, rel *schema.Relationship, selectColumns map[string]bool, restricted bool) bool { func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
savable := true if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
if value, ok := db.Get("gorm:save_association"); ok { onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
savable = utils.CheckTruth(value) for _, dbName := range s.PrimaryFieldDBNames {
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName})
} }
if savable { onConflict.UpdateAll = stmt.DB.FullSaveAssociations
if v, ok := selectColumns[rel.Name]; (ok && v) || (!ok && !restricted) { if !onConflict.UpdateAll {
onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns)
}
} else {
onConflict.DoNothing = true
}
return
}
func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
// stop save association loop
if checkAssociationsSaved(db, rValues) {
return nil
}
var (
selects, omits []string
onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns)
refName = rel.Name + "."
values = rValues.Interface()
)
for name, ok := range selectColumns {
columnName := ""
if strings.HasPrefix(name, refName) {
columnName = strings.TrimPrefix(name, refName)
}
if columnName != "" {
if ok {
selects = append(selects, columnName)
} else {
omits = append(omits, columnName)
}
}
}
tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{
FullSaveAssociations: db.FullSaveAssociations,
SkipHooks: db.Statement.SkipHooks,
DisableNestedTransaction: true,
})
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if tx.Statement.FullSaveAssociations {
tx = tx.Set("gorm:update_track_time", true)
}
if len(selects) > 0 {
tx = tx.Select(selects)
} else if restricted && len(omits) == 0 {
tx = tx.Omit(clause.Associations)
}
if len(omits) > 0 {
tx = tx.Omit(omits...)
}
return db.AddError(tx.Create(values).Error)
}
// check association values has been saved
// if values kind is Struct, check it has been saved
// if values kind is Slice/Array, check all items have been saved
var visitMapStoreKey = "gorm:saved_association_map"
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
if visit, ok := db.Get(visitMapStoreKey); ok {
if v, ok := visit.(*visitMap); ok {
if loadOrStoreVisitMap(v, values) {
return true return true
} }
} }
} else {
vistMap := make(visitMap)
loadOrStoreVisitMap(&vistMap, values)
db.Set(visitMapStoreKey, &vistMap)
}
return false return false
} }
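
A rough caller-side sketch of what saveAssociations and onConflictOption produce: by default associated rows are inserted with ON CONFLICT DO NOTHING, while FullSaveAssociations switches the generated clause to an upsert of all columns. The Company/Employee schema below is invented for illustration.

package main

import (
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

type Company struct {
	ID   uint
	Name string
}

type Employee struct {
	ID        uint
	Name      string
	CompanyID uint
	Company   Company // belongs-to, handled by SaveBeforeAssociations
}

func main() {
	db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	if err := db.AutoMigrate(&Company{}, &Employee{}); err != nil {
		panic(err)
	}

	e := Employee{Name: "alice", Company: Company{Name: "acme"}}

	// Default: the Company row is created with ON CONFLICT DO NOTHING,
	// so an existing company with the same primary key is left untouched.
	db.Create(&e)

	// FullSaveAssociations: onConflictOption builds an UpdateAll upsert,
	// so the associated Company row is updated as well.
	e.Company.Name = "acme inc"
	db.Session(&gorm.Session{FullSaveAssociations: true}).Updates(&e)
}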


@@ -4,9 +4,19 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
var (
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
updateClauses = []string{"UPDATE", "SET", "WHERE"}
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
)
type Config struct { type Config struct {
LastInsertIDReversed bool LastInsertIDReversed bool
WithReturning bool CreateClauses []string
QueryClauses []string
UpdateClauses []string
DeleteClauses []string
} }
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
@@ -14,37 +24,60 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
return !db.SkipDefaultTransaction return !db.SkipDefaultTransaction
} }
if len(config.CreateClauses) == 0 {
config.CreateClauses = createClauses
}
if len(config.QueryClauses) == 0 {
config.QueryClauses = queryClauses
}
if len(config.DeleteClauses) == 0 {
config.DeleteClauses = deleteClauses
}
if len(config.UpdateClauses) == 0 {
config.UpdateClauses = updateClauses
}
createCallback := db.Callback().Create() createCallback := db.Callback().Create()
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
createCallback.Register("gorm:before_create", BeforeCreate) createCallback.Register("gorm:before_create", BeforeCreate)
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
createCallback.Register("gorm:create", Create(config)) createCallback.Register("gorm:create", Create(config))
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations) createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
createCallback.Register("gorm:after_create", AfterCreate) createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
createCallback.Clauses = config.CreateClauses
queryCallback := db.Callback().Query() queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query) queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload) queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery) queryCallback.Register("gorm:after_query", AfterQuery)
queryCallback.Clauses = config.QueryClauses
deleteCallback := db.Callback().Delete() deleteCallback := db.Callback().Delete()
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
deleteCallback.Register("gorm:before_delete", BeforeDelete) deleteCallback.Register("gorm:before_delete", BeforeDelete)
deleteCallback.Register("gorm:delete", Delete) deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
deleteCallback.Register("gorm:delete", Delete(config))
deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
deleteCallback.Clauses = config.DeleteClauses
updateCallback := db.Callback().Update() updateCallback := db.Callback().Update()
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
updateCallback.Register("gorm:before_update", BeforeUpdate) updateCallback.Register("gorm:before_update", BeforeUpdate)
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
updateCallback.Register("gorm:update", Update) updateCallback.Register("gorm:update", Update(config))
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations) updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Register("gorm:after_update", AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
updateCallback.Clauses = config.UpdateClauses
db.Callback().Row().Register("gorm:raw", RowQuery) rowCallback := db.Callback().Row()
db.Callback().Raw().Register("gorm:raw", RawExec) rowCallback.Register("gorm:row", RowQuery)
rowCallback.Clauses = config.QueryClauses
rawCallback := db.Callback().Raw()
rawCallback.Register("gorm:raw", RawExec)
rawCallback.Clauses = config.QueryClauses
} }
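
The new Config clause lists are meant to be filled in by the dialector; a hedged sketch of how a driver for a database with RETURNING support might wire this up (package and function names here are hypothetical, not taken from any real driver):

// Package mydialect sketches the callback-registration step of a Dialector.
package mydialect

import (
	"gorm.io/gorm"
	"gorm.io/gorm/callbacks"
)

// Initialize would be invoked by gorm.Open through the Dialector interface.
// Listing RETURNING here assumes the target database supports it; drivers for
// databases that do not simply omit it, and Create falls back to LastInsertId
// (see the Create callback further down).
func Initialize(db *gorm.DB) error {
	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
		CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
		UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
		DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
	})
	return nil
}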

callbacks/callmethod.go

@@ -0,0 +1,32 @@
package callbacks
import (
"reflect"
"gorm.io/gorm"
)
func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
tx := db.Session(&gorm.Session{NewDB: true})
if called := fc(db.Statement.ReflectValue.Interface(), tx); !called {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
db.Statement.CurDestIndex = 0
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() {
fc(value.Addr().Interface(), tx)
} else {
db.AddError(gorm.ErrInvalidValue)
return
}
db.Statement.CurDestIndex++
}
case reflect.Struct:
if db.Statement.ReflectValue.CanAddr() {
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
} else {
db.AddError(gorm.ErrInvalidValue)
}
}
}
}
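
callMethod is the dispatcher behind the hook interfaces used by the create and delete callbacks below. A small sketch of the model-side hook it ends up invoking, with an invented User model; SkipHooks on the session bypasses the dispatch entirely:

package main

import (
	"errors"
	"strings"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

type User struct {
	ID   uint
	Name string
}

// BeforeCreate is called through callMethod: once per element when creating a
// slice, or once for a single struct value.
func (u *User) BeforeCreate(tx *gorm.DB) error {
	if strings.TrimSpace(u.Name) == "" {
		return errors.New("name must not be empty")
	}
	return nil
}

func main() {
	db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	if err := db.AutoMigrate(&User{}); err != nil {
		panic(err)
	}

	// The hook runs inside the create callback chain.
	db.Create(&User{Name: "alice"})

	// With SkipHooks the dispatcher is never called, so the empty name slips through.
	db.Session(&gorm.Session{SkipHooks: true}).Create(&User{Name: ""})
}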


@@ -1,237 +1,269 @@
package callbacks package callbacks
import ( import (
"fmt"
"reflect" "reflect"
"strings"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
// BeforeCreate before create hooks
func BeforeCreate(db *gorm.DB) { func BeforeCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
tx := db.Session(&gorm.Session{}) callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.BeforeSave { if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok { if i, ok := value.(BeforeSaveInterface); ok {
ok = true called = true
db.AddError(i.BeforeSave(tx)) db.AddError(i.BeforeSave(tx))
} }
} }
if db.Statement.Schema.BeforeCreate { if db.Statement.Schema.BeforeCreate {
if i, ok := value.(gorm.BeforeCreateInterface); ok { if i, ok := value.(BeforeCreateInterface); ok {
ok = true called = true
db.AddError(i.BeforeCreate(tx)) db.AddError(i.BeforeCreate(tx))
} }
} }
return ok return called
} })
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
} }
} }
// Create create hook
func Create(config *Config) func(db *gorm.DB) { func Create(config *Config) func(db *gorm.DB) {
if config.WithReturning { supportReturning := utils.Contains(config.CreateClauses, "RETURNING")
return CreateWithReturning
} else {
return func(db *gorm.DB) { return func(db *gorm.DB) {
if db.Error == nil { if db.Error != nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped { return
}
if db.Statement.Schema != nil {
if !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses { for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c) db.Statement.AddClause(c)
} }
} }
if db.Statement.SQL.String() == "" { if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
db.Statement.AddClauseIfNotExists(clause.Insert{ if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
Table: clause.Table{Name: db.Statement.Table}, fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
}) for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
db.Statement.AddClause(ConvertToCreateValues(db.Statement)) if field.Readable {
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") }
}
if len(fromColumns) > 0 {
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
}
}
}
} }
if !db.DryRun { if db.Statement.SQL.Len() == 0 {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build(db.Statement.BuildClauses...)
}
isDryRun := !db.DryRun && db.Error == nil
if !isDryRun {
return
}
ok, mode := hasReturning(db, supportReturning)
if ok {
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing {
mode |= gorm.ScanOnConflictDoNothing
}
}
rows, err := db.Statement.ConnPool.QueryContext(
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)
if db.AddError(err) == nil {
defer func() {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, mode)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
return
}
result, err := db.Statement.ConnPool.ExecContext(
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)
if err != nil {
db.AddError(err)
return
}
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
if db.RowsAffected == 0 {
return
}
var (
pkField *schema.Field
pkFieldName = "@id"
)
if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil ||
!db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue ||
!db.Statement.Schema.PrioritizedPrimaryField.Readable {
return
}
pkField = db.Statement.Schema.PrioritizedPrimaryField
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
if !supportReturning {
db.AddError(err)
}
return
}
// append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
switch values := db.Statement.Dest.(type) {
case map[string]interface{}:
values[pkFieldName] = insertID
case *map[string]interface{}:
(*values)[pkFieldName] = insertID
case []map[string]interface{}, *[]map[string]interface{}:
mapValues, ok := values.([]map[string]interface{})
if !ok {
if v, ok := values.(*[]map[string]interface{}); ok {
if *v != nil {
mapValues = *v
}
}
}
if config.LastInsertIDReversed {
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
}
for _, mapValue := range mapValues {
if mapValue != nil {
mapValue[pkFieldName] = insertID
}
insertID += schema.DefaultAutoIncrementIncrement
}
default:
if pkField == nil {
return
}
if err == nil {
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil {
if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok {
if insertID, err := result.LastInsertId(); err == nil {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if config.LastInsertIDReversed { if config.LastInsertIDReversed {
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) rv := db.Statement.ReflectValue.Index(i)
insertID-- if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
if isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID -= pkField.AutoIncrementIncrement
}
} }
} else { } else {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) rv := db.Statement.ReflectValue.Index(i)
insertID++ if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID += pkField.AutoIncrementIncrement
}
} }
} }
case reflect.Struct: case reflect.Struct:
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) _, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
} if isZero {
} else { db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
db.AddError(err)
}
}
}
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
} }
} }
} }
} }
} }
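
In practice, the LastInsertId branch above means a batch insert back-fills zero-valued auto-increment primary keys into the destination slice, stepping by the field's AutoIncrementIncrement. A short sketch, assuming the sqlite driver:

package main

import (
	"fmt"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

type Product struct {
	ID   uint `gorm:"primaryKey"`
	Code string
}

func main() {
	db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	if err := db.AutoMigrate(&Product{}); err != nil {
		panic(err)
	}

	products := []Product{{Code: "A"}, {Code: "B"}, {Code: "C"}}
	if err := db.Create(&products).Error; err != nil {
		panic(err)
	}

	// IDs were assigned from LastInsertId (or scanned from RETURNING on
	// drivers that support it) without issuing a second query.
	for _, p := range products {
		fmt.Println(p.Code, p.ID)
	}
}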
func CreateWithReturning(db *gorm.DB) { // AfterCreate after create hooks
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Insert{
Table: clause.Table{Name: db.Statement.Table},
})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
}
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
db.Statement.WriteString(" RETURNING ")
var (
idx int
fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
)
for dbName, field := range sch.FieldsWithDefaultDBValue {
if idx != 0 {
db.Statement.WriteByte(',')
}
fields[idx] = field
db.Statement.WriteQuoted(dbName)
idx++
}
if !db.DryRun {
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
defer rows.Close()
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for rows.Next() {
for idx, field := range fields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
}
if err := rows.Scan(values...); err != nil {
db.AddError(err)
}
db.RowsAffected++
}
case reflect.Struct:
for idx, field := range fields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
}
if rows.Next() {
db.RowsAffected++
err = rows.Scan(values...)
}
}
}
if err != nil {
db.AddError(err)
}
}
} else if !db.DryRun {
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
}
}
}
func AfterCreate(db *gorm.DB) { func AfterCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
tx := db.Session(&gorm.Session{}) callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok {
ok = true
db.AddError(i.AfterSave(tx))
}
}
if db.Statement.Schema.AfterCreate { if db.Statement.Schema.AfterCreate {
if i, ok := value.(gorm.AfterCreateInterface); ok { if i, ok := value.(AfterCreateInterface); ok {
ok = true called = true
db.AddError(i.AfterCreate(tx)) db.AddError(i.AfterCreate(tx))
} }
} }
return ok
}
if ok := callMethod(db.Statement.Dest); !ok { if db.Statement.Schema.AfterSave {
switch db.Statement.ReflectValue.Kind() { if i, ok := value.(AfterSaveInterface); ok {
case reflect.Slice, reflect.Array: called = true
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { db.AddError(i.AfterSave(tx))
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
} }
} }
return called
})
} }
} }
// ConvertToCreateValues convert to create values // ConvertToCreateValues convert to create values
func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
curTime := stmt.DB.NowFunc()
switch value := stmt.Dest.(type) { switch value := stmt.Dest.(type) {
case map[string]interface{}: case map[string]interface{}:
return ConvertMapToValuesForCreate(stmt, value) values = ConvertMapToValuesForCreate(stmt, value)
case *map[string]interface{}:
values = ConvertMapToValuesForCreate(stmt, *value)
case []map[string]interface{}: case []map[string]interface{}:
return ConvertSliceOfMapToValuesForCreate(stmt, value) values = ConvertSliceOfMapToValuesForCreate(stmt, value)
case *[]map[string]interface{}:
values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
default: default:
var ( var (
values = clause.Values{} selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) _, updateTrackTime = stmt.Get("gorm:update_track_time")
curTime = stmt.DB.NowFunc() isZero bool
isZero = false
) )
stmt.Settings.Delete("gorm:update_track_time")
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
for _, db := range stmt.Schema.DBNames { for _, db := range stmt.Schema.DBNames {
if stmt.Schema.FieldsWithDefaultDBValue[db] == nil { if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
values.Columns = append(values.Columns, clause.Column{Name: db}) values.Columns = append(values.Columns, clause.Column{Name: db})
} }
} }
@@ -239,71 +271,142 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
switch stmt.ReflectValue.Kind() { switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
values.Values = make([][]interface{}, stmt.ReflectValue.Len()) rValLen := stmt.ReflectValue.Len()
defaultValueFieldsHavingValue := map[string][]interface{}{} if rValLen == 0 {
for i := 0; i < stmt.ReflectValue.Len(); i++ { stmt.AddError(gorm.ErrEmptySlice)
return
}
stmt.SQL.Grow(rValLen * 18)
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
values.Values = make([][]interface{}, rValLen)
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
for i := 0; i < rValLen; i++ {
rv := reflect.Indirect(stmt.ReflectValue.Index(i)) rv := reflect.Indirect(stmt.ReflectValue.Index(i))
if !rv.IsValid() {
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
return
}
values.Values[i] = make([]interface{}, len(values.Columns)) values.Values[i] = make([]interface{}, len(values.Columns))
for idx, column := range values.Columns { for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name] field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[i][idx], isZero = field.ValueOf(rv); isZero { if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
if field.DefaultValueInterface != nil { if field.DefaultValueInterface != nil {
values.Values[i][idx] = field.DefaultValueInterface values.Values[i][idx] = field.DefaultValueInterface
field.Set(rv, field.DefaultValueInterface) stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
field.Set(rv, curTime) stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(rv) values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
} }
} else if field.AutoUpdateTime > 0 && updateTrackTime {
stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
} }
} }
for db, field := range stmt.Schema.FieldsWithDefaultDBValue { for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, isZero := field.ValueOf(rv); !isZero { if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
if len(defaultValueFieldsHavingValue[db]) == 0 { if len(defaultValueFieldsHavingValue[field]) == 0 {
defaultValueFieldsHavingValue[db] = make([]interface{}, stmt.ReflectValue.Len()) defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
} }
defaultValueFieldsHavingValue[db][i] = v defaultValueFieldsHavingValue[field][i] = rvOfvalue
} }
} }
} }
} }
for db, vs := range defaultValueFieldsHavingValue { for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
values.Columns = append(values.Columns, clause.Column{Name: db}) if vs, ok := defaultValueFieldsHavingValue[field]; ok {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
for idx := range values.Values { for idx := range values.Values {
if vs[idx] == nil { if vs[idx] == nil {
values.Values[idx] = append(values.Values[idx], clause.Expr{SQL: "DEFAULT"}) values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
} else { } else {
values.Values[idx] = append(values.Values[idx], vs[idx]) values.Values[idx] = append(values.Values[idx], vs[idx])
} }
} }
} }
}
case reflect.Struct: case reflect.Struct:
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
for idx, column := range values.Columns { for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name] field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero { if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
if field.DefaultValueInterface != nil { if field.DefaultValueInterface != nil {
values.Values[0][idx] = field.DefaultValueInterface values.Values[0][idx] = field.DefaultValueInterface
field.Set(stmt.ReflectValue, field.DefaultValueInterface) stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
field.Set(stmt.ReflectValue, curTime) stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
} else if field.AutoUpdateTime > 0 && updateTrackTime {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], rvOfvalue)
}
}
}
default:
stmt.AddError(gorm.ErrInvalidData)
}
}
if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
if stmt.Schema != nil && len(values.Columns) >= 1 {
selectColumns, restricted := stmt.SelectAndOmitColumns(true, true)
columns := make([]string, 0, len(values.Columns)-1)
for _, column := range values.Columns {
if field := stmt.Schema.LookUpField(column.Name); field != nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
if field.AutoUpdateTime > 0 {
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
switch field.AutoUpdateTime {
case schema.UnixNanosecond:
assignment.Value = curTime.UnixNano()
case schema.UnixMillisecond:
assignment.Value = curTime.UnixMilli()
case schema.UnixSecond:
assignment.Value = curTime.Unix()
}
onConflict.DoUpdates = append(onConflict.DoUpdates, assignment)
} else {
columns = append(columns, column.Name)
}
}
} }
} }
} }
for db, field := range stmt.Schema.FieldsWithDefaultDBValue { onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { if len(onConflict.DoUpdates) == 0 {
if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero { onConflict.DoNothing = true
values.Columns = append(values.Columns, clause.Column{Name: db})
values.Values[0] = append(values.Values[0], v)
} }
// use primary fields as default OnConflict columns
if len(onConflict.Columns) == 0 {
for _, field := range stmt.Schema.PrimaryFields {
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName})
}
}
stmt.AddClause(onConflict)
} }
} }
} }
return values return values
}
} }
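
The ON CONFLICT handling at the end of ConvertToCreateValues is what turns clause.OnConflict{UpdateAll: true} into concrete assignments, skipping primary keys, fields with database defaults, and AutoCreateTime fields, and re-stamping AutoUpdateTime fields. A usage sketch with an invented Setting model:

package main

import (
	"time"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

type Setting struct {
	Key       string `gorm:"primaryKey"`
	Value     string
	CreatedAt time.Time // AutoCreateTime: excluded from the generated DO UPDATE SET
	UpdatedAt time.Time // AutoUpdateTime: re-stamped on conflict
}

func main() {
	db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	if err := db.AutoMigrate(&Setting{}); err != nil {
		panic(err)
	}

	row := Setting{Key: "theme", Value: "dark"}
	db.Create(&row)

	// Upsert: UpdateAll is expanded into DO UPDATE SET value=..., updated_at=...
	row.Value = "light"
	db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&row)
}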

callbacks/create_test.go

@@ -0,0 +1,71 @@
package callbacks
import (
"reflect"
"sync"
"testing"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
var schemaCache = &sync.Map{}
func TestConvertToCreateValues_DestType_Slice(t *testing.T) {
type user struct {
ID int `gorm:"primaryKey"`
Name string
Email string `gorm:"default:(-)"`
Age int `gorm:"default:(-)"`
}
s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{})
if err != nil {
t.Errorf("parse schema error: %v, is not expected", err)
return
}
dest := []*user{
{
ID: 1,
Name: "alice",
Email: "email",
Age: 18,
},
{
ID: 2,
Name: "bob",
Email: "email",
Age: 19,
},
}
stmt := &gorm.Statement{
DB: &gorm.DB{
Config: &gorm.Config{
NowFunc: func() time.Time { return time.Time{} },
},
Statement: &gorm.Statement{
Settings: sync.Map{},
Schema: s,
},
},
ReflectValue: reflect.ValueOf(dest),
Dest: dest,
}
stmt.Schema = s
values := ConvertToCreateValues(stmt)
expected := clause.Values{
// column has value + defaultValue column has value (which should have a stable order)
Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}},
Values: [][]interface{}{
{"alice", "email", 18, 1},
{"bob", "email", 19, 2},
},
}
if !reflect.DeepEqual(expected, values) {
t.Errorf("expected: %v got %v", expected, values)
}
}


@ -2,60 +2,143 @@ package callbacks
import ( import (
"reflect" "reflect"
"strings"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
func BeforeDelete(db *gorm.DB) { func BeforeDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete {
tx := db.Session(&gorm.Session{}) callMethod(db, func(value interface{}, tx *gorm.DB) bool {
callMethod := func(value interface{}) bool { if i, ok := value.(BeforeDeleteInterface); ok {
if db.Statement.Schema.BeforeDelete {
if i, ok := value.(gorm.BeforeDeleteInterface); ok {
db.AddError(i.BeforeDelete(tx)) db.AddError(i.BeforeDelete(tx))
return true return true
} }
}
return false
}
if ok := callMethod(db.Statement.Dest); !ok { return false
switch db.Statement.ReflectValue.Kind() { })
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
} }
} }
func Delete(db *gorm.DB) { func DeleteBeforeAssociations(db *gorm.DB) {
if db.Error == nil { if db.Error == nil && db.Statement.Schema != nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped { selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
if !restricted {
return
}
for column, v := range selectColumns {
if !v {
continue
}
rel, ok := db.Statement.Schema.Relationships.Relations[column]
if !ok {
continue
}
switch rel.Type {
case schema.HasOne, schema.HasMany:
queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue)
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
withoutConditions := false
if db.Statement.Unscoped {
tx = tx.Unscoped()
}
if len(db.Statement.Selects) > 0 {
selects := make([]string, 0, len(db.Statement.Selects))
for _, s := range db.Statement.Selects {
if s == clause.Associations {
selects = append(selects, s)
} else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) {
selects = append(selects, strings.TrimPrefix(s, columnPrefix))
}
}
if len(selects) > 0 {
tx = tx.Select(selects)
}
}
for _, cond := range queryConds {
if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
withoutConditions = true
break
}
}
if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
case schema.Many2Many:
var (
queryConds = make([]clause.Expression, 0, len(rel.References))
foreignFields = make([]*schema.Field, 0, len(rel.References))
relForeignKeys = make([]string, 0, len(rel.References))
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
table = rel.JoinTable.Table
tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table)
)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
} else if ref.PrimaryValue != "" {
queryConds = append(queryConds, clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
}
}
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields)
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
}
}
}
}
func Delete(config *Config) func(db *gorm.DB) {
supportReturning := utils.Contains(config.DeleteClauses, "RETURNING")
return func(db *gorm.DB) {
if db.Error != nil {
return
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.DeleteClauses { for _, c := range db.Statement.Schema.DeleteClauses {
db.Statement.AddClause(c) db.Statement.AddClause(c)
} }
} }
if db.Statement.SQL.String() == "" { if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(100)
db.Statement.AddClauseIfNotExists(clause.Delete{}) db.Statement.AddClauseIfNotExists(clause.Delete{})
if db.Statement.Schema != nil { if db.Statement.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
} }
if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) _, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
@@ -63,49 +146,50 @@ func Delete(db *gorm.DB) {
} }
} }
if _, ok := db.Statement.Clauses["WHERE"]; !ok {
db.AddError(gorm.ErrMissingWhereClause)
return
}
db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.AddClauseIfNotExists(clause.From{})
db.Statement.Build("DELETE", "FROM", "WHERE")
db.Statement.Build(db.Statement.BuildClauses...)
} }
if !db.DryRun { checkMissingWhereConditions(db)
if !db.DryRun && db.Error == nil {
ok, mode := hasReturning(db, supportReturning)
if !ok {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil { if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected() db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err) if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
return
}
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
gorm.Scan(rows, db, mode)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
db.AddError(rows.Close())
} }
} }
} }
} }
func AfterDelete(db *gorm.DB) { func AfterDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete {
tx := db.Session(&gorm.Session{}) callMethod(db, func(value interface{}, tx *gorm.DB) bool {
callMethod := func(value interface{}) bool { if i, ok := value.(AfterDeleteInterface); ok {
if db.Statement.Schema.AfterDelete {
if i, ok := value.(gorm.AfterDeleteInterface); ok {
db.AddError(i.AfterDelete(tx)) db.AddError(i.AfterDelete(tx))
return true return true
} }
}
return false return false
} })
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
} }
} }
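
DeleteBeforeAssociations backs deleting selected associations together with the parent record, and checkMissingWhereConditions is what rejects unconditioned deletes. A caller-side sketch with an invented schema:

package main

import (
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

type CreditCard struct {
	ID     uint
	Number string
	UserID uint
}

type User struct {
	ID         uint
	Name       string
	CreditCard CreditCard // has-one, removed by DeleteBeforeAssociations when selected
}

func main() {
	db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	if err := db.AutoMigrate(&User{}, &CreditCard{}); err != nil {
		panic(err)
	}

	u := User{Name: "alice", CreditCard: CreditCard{Number: "411111"}}
	db.Create(&u)

	// Delete the user's credit card in the same call; without the Select,
	// only the users row is deleted. clause.Associations selects all of them.
	db.Select("CreditCard").Delete(&u)

	// Deletes without any WHERE condition fail with gorm.ErrMissingWhereClause
	// unless AllowGlobalUpdate is enabled on the session.
	db.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&User{})
}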


@@ -1,80 +1,38 @@
package callbacks package callbacks
import ( import (
"reflect"
"sort" "sort"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate bool) (map[string]bool, bool) {
results := map[string]bool{}
notRestricted := false
// select columns
for _, column := range stmt.Selects {
if column == "*" {
notRestricted = true
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = true
}
break
}
if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = true
} else {
results[column] = true
}
}
// omit columns
for _, omit := range stmt.Omits {
if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
results[field.DBName] = false
} else {
results[omit] = false
}
}
if stmt.Schema != nil {
for _, field := range stmt.Schema.Fields {
name := field.DBName
if name == "" {
name = field.Name
}
if requireCreate && !field.Creatable {
results[name] = false
} else if requireUpdate && !field.Updatable {
results[name] = false
}
}
}
return results, !notRestricted && len(stmt.Selects) > 0
}
// ConvertMapToValuesForCreate convert map to values // ConvertMapToValuesForCreate convert map to values
func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) {
columns := make([]string, 0, len(mapValue)) values.Columns = make([]clause.Column, 0, len(mapValue))
selectColumns, restricted := SelectAndOmitColumns(stmt, true, false) selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
var keys []string keys := make([]string, 0, len(mapValue))
for k, _ := range mapValue { for k := range mapValue {
keys = append(keys, k) keys = append(keys, k)
} }
sort.Strings(keys) sort.Strings(keys)
for _, k := range keys { for _, k := range keys {
value := mapValue[k] value := mapValue[k]
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil { if field := stmt.Schema.LookUpField(k); field != nil {
k = field.DBName k = field.DBName
} }
}
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
columns = append(columns, k) values.Columns = append(values.Columns, clause.Column{Name: k})
if len(values.Values) == 0 {
values.Values = [][]interface{}{{}}
}
values.Values[0] = append(values.Values[0], value) values.Values[0] = append(values.Values[0], value)
} }
} }
@ -83,17 +41,27 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter
// ConvertSliceOfMapToValuesForCreate convert slice of map to values // ConvertSliceOfMapToValuesForCreate convert slice of map to values
func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) {
columns := make([]string, 0, len(mapValues))
// when the length of mapValues is zero, return directly here
// no need to call stmt.SelectAndOmitColumns method
if len(mapValues) == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
var ( var (
columns = []string{} result = make(map[string][]interface{}, len(mapValues))
result = map[string][]interface{}{} selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false)
) )
for idx, mapValue := range mapValues { for idx, mapValue := range mapValues {
for k, v := range mapValue { for k, v := range mapValue {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil { if field := stmt.Schema.LookUpField(k); field != nil {
k = field.DBName k = field.DBName
} }
}
if _, ok := result[k]; !ok { if _, ok := result[k]; !ok {
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
@ -110,13 +78,75 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
sort.Strings(columns) sort.Strings(columns)
values.Values = make([][]interface{}, len(mapValues)) values.Values = make([][]interface{}, len(mapValues))
values.Columns = make([]clause.Column, len(columns))
for idx, column := range columns { for idx, column := range columns {
values.Columns[idx] = clause.Column{Name: column}
for i, v := range result[column] { for i, v := range result[column] {
if i == 0 { if len(values.Values[i]) == 0 {
values.Values[i] = make([]interface{}, len(columns)) values.Values[i] = make([]interface{}, len(columns))
} }
values.Values[i][idx] = v values.Values[i][idx] = v
} }
} }
return return
} }
func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
if supportReturning {
if c, ok := tx.Statement.Clauses["RETURNING"]; ok {
returning, _ := c.Expression.(clause.Returning)
if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") {
return true, 0
}
return true, gorm.ScanUpdate
}
}
return false, 0
}
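hasReturning only reports a scan mode when the dialect supports RETURNING and the statement carries a RETURNING clause. A hedged sketch of how such a clause typically reaches the statement from user code; it assumes a dialector whose delete clauses include "RETURNING" (for example the postgres driver), and User is a placeholder model.

package example

import (
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

type User struct {
	ID   uint
	Name string
}

// deleteReturning issues DELETE ... RETURNING * and scans the removed rows
// back into the slice when the dialect renders the RETURNING clause.
func deleteReturning(db *gorm.DB, ids []uint) ([]User, error) {
	var deleted []User
	err := db.Clauses(clause.Returning{}).
		Where("id IN ?", ids).
		Delete(&deleted).Error
	return deleted, err
}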
func checkMissingWhereConditions(db *gorm.DB) {
if !db.AllowGlobalUpdate && db.Error == nil {
where, withCondition := db.Statement.Clauses["WHERE"]
if withCondition {
if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
whereClause, _ := where.Expression.(clause.Where)
withCondition = len(whereClause.Exprs) > 1
}
}
if !withCondition {
db.AddError(gorm.ErrMissingWhereClause)
}
return
}
}
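checkMissingWhereConditions is what turns a condition-less UPDATE or DELETE into gorm.ErrMissingWhereClause unless the caller opts in to global updates. A small illustration, with an assumed User model:

package example

import (
	"time"

	"gorm.io/gorm"
)

type User struct {
	ID        uint
	CreatedAt time.Time
}

func purgeOldUsers(db *gorm.DB, cutoff time.Time) error {
	// db.Delete(&User{}) with no conditions is rejected with
	// gorm.ErrMissingWhereClause; a session created with
	// AllowGlobalUpdate: true, or any explicit condition as below,
	// passes the check.
	return db.Where("created_at < ?", cutoff).Delete(&User{}).Error
}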
type visitMap = map[reflect.Value]bool
// Check for circular values; returns true if the value was already loaded
func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
switch v.Kind() {
case reflect.Slice, reflect.Array:
loaded = true
for i := 0; i < v.Len(); i++ {
if !loadOrStoreVisitMap(visitMap, v.Index(i)) {
loaded = false
}
}
case reflect.Struct, reflect.Interface:
if v.CanAddr() {
p := v.Addr()
if _, ok := (*visitMap)[p]; ok {
return true
}
(*visitMap)[p] = true
}
}
return
}

callbacks/helper_test.go (new file, 157 lines)

@ -0,0 +1,157 @@
package callbacks
import (
"reflect"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
func TestLoadOrStoreVisitMap(t *testing.T) {
var vm visitMap
var loaded bool
type testM struct {
Name string
}
t1 := testM{Name: "t1"}
t2 := testM{Name: "t2"}
t3 := testM{Name: "t3"}
vm = make(visitMap)
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
t.Fatalf("loaded should be true")
}
// t1 already exists but t2 does not
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
t.Fatalf("loaded should be true")
}
}
func TestConvertMapToValuesForCreate(t *testing.T) {
testCase := []struct {
name string
input map[string]interface{}
expect clause.Values
}{
{
name: "Test convert string value",
input: map[string]interface{}{
"name": "my name",
},
expect: clause.Values{
Columns: []clause.Column{{Name: "name"}},
Values: [][]interface{}{{"my name"}},
},
},
{
name: "Test convert int value",
input: map[string]interface{}{
"age": 18,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "age"}},
Values: [][]interface{}{{18}},
},
},
{
name: "Test convert float value",
input: map[string]interface{}{
"score": 99.5,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "score"}},
Values: [][]interface{}{{99.5}},
},
},
{
name: "Test convert bool value",
input: map[string]interface{}{
"active": true,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "active"}},
Values: [][]interface{}{{true}},
},
},
}
for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
actual := ConvertMapToValuesForCreate(&gorm.Statement{}, tc.input)
if !reflect.DeepEqual(actual, tc.expect) {
t.Errorf("expect %v got %v", tc.expect, actual)
}
})
}
}
func TestConvertSliceOfMapToValuesForCreate(t *testing.T) {
testCase := []struct {
name string
input []map[string]interface{}
expect clause.Values
}{
{
name: "Test convert slice of string value",
input: []map[string]interface{}{
{"name": "my name"},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "name"}},
Values: [][]interface{}{{"my name"}},
},
},
{
name: "Test convert slice of int value",
input: []map[string]interface{}{
{"age": 18},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "age"}},
Values: [][]interface{}{{18}},
},
},
{
name: "Test convert slice of float value",
input: []map[string]interface{}{
{"score": 99.5},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "score"}},
Values: [][]interface{}{{99.5}},
},
},
{
name: "Test convert slice of bool value",
input: []map[string]interface{}{
{"active": true},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "active"}},
Values: [][]interface{}{{true}},
},
},
}
for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
actual := ConvertSliceOfMapToValuesForCreate(&gorm.Statement{}, tc.input)
if !reflect.DeepEqual(actual, tc.expect) {
t.Errorf("expected %v but got %v", tc.expect, actual)
}
})
}
}


@ -1,11 +0,0 @@
package callbacks
import "gorm.io/gorm"
type beforeSaveInterface interface {
BeforeSave(*gorm.DB) error
}
type beforeCreateInterface interface {
BeforeCreate(*gorm.DB) error
}

callbacks/interfaces.go (new file, 39 lines)

@ -0,0 +1,39 @@
package callbacks
import "gorm.io/gorm"
type BeforeCreateInterface interface {
BeforeCreate(*gorm.DB) error
}
type AfterCreateInterface interface {
AfterCreate(*gorm.DB) error
}
type BeforeUpdateInterface interface {
BeforeUpdate(*gorm.DB) error
}
type AfterUpdateInterface interface {
AfterUpdate(*gorm.DB) error
}
type BeforeSaveInterface interface {
BeforeSave(*gorm.DB) error
}
type AfterSaveInterface interface {
AfterSave(*gorm.DB) error
}
type BeforeDeleteInterface interface {
BeforeDelete(*gorm.DB) error
}
type AfterDeleteInterface interface {
AfterDelete(*gorm.DB) error
}
type AfterFindInterface interface {
AfterFind(*gorm.DB) error
}
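These exported interfaces replace the old package-private ones; a model opts into a hook simply by implementing the matching method. A short sketch follows; the Product type and its validation rule are assumptions, not part of this diff.

package model

import (
	"errors"

	"gorm.io/gorm"
)

type Product struct {
	ID    uint
	Code  string
	Price int
}

// BeforeCreate matches BeforeCreateInterface above; a returned error aborts
// the INSERT and rolls back the surrounding callback transaction.
func (p *Product) BeforeCreate(tx *gorm.DB) error {
	if p.Price < 0 {
		return errors.New("price must not be negative")
	}
	return nil
}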


@ -1,7 +1,10 @@
package callbacks package callbacks
import ( import (
"fmt"
"reflect" "reflect"
"sort"
"strings"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -9,11 +12,179 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { // parsePreloadMap extracts nested preloads. e.g.
//
// // schema has a "k0" relation and a "k7.k8" embedded relation
// parsePreloadMap(schema, map[string][]interface{}{
// clause.Associations: {"arg1"},
// "k1": {"arg2"},
// "k2.k3": {"arg3"},
// "k4.k5.k6": {"arg4"},
// })
// // preloadMap is
// map[string]map[string][]interface{}{
// "k0": {},
// "k7": {
// "k8": {},
// },
// "k1": {},
// "k2": {
// "k3": {"arg3"},
// },
// "k4": {
// "k5.k6": {"arg4"},
// },
// }
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
preloadMap := map[string]map[string][]interface{}{}
setPreloadMap := func(name, value string, args []interface{}) {
if _, ok := preloadMap[name]; !ok {
preloadMap[name] = map[string][]interface{}{}
}
if value != "" {
preloadMap[name][value] = args
}
}
for name, args := range preloads {
preloadFields := strings.Split(name, ".")
value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
if preloadFields[0] == clause.Associations {
for _, relation := range s.Relationships.Relations {
if relation.Schema == s {
setPreloadMap(relation.Name, value, args)
}
}
for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
for _, value := range embeddedValues(embeddedRelations) {
setPreloadMap(embedded, value, args)
}
}
} else {
setPreloadMap(preloadFields[0], value, args)
}
}
return preloadMap
}
func embeddedValues(embeddedRelations *schema.Relationships) []string {
if embeddedRelations == nil {
return nil
}
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
for _, relation := range embeddedRelations.Relations {
// skip first struct name
names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
}
for _, relations := range embeddedRelations.EmbeddedRelations {
names = append(names, embeddedValues(relations)...)
}
return names
}
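Tying the parsePreloadMap doc comment above back to user-facing calls, here is a hedged sketch of the Preload chain that would produce such a nested map. It assumes a User schema with Profile and Orders -> Items -> Product relations (model definitions omitted), with db being a *gorm.DB from gorm.io/gorm.

func loadUsers(db *gorm.DB) ([]User, error) {
	var users []User
	// Statement.Preloads ends up as {"Profile": [], "Orders.Items.Product": []};
	// parsePreloadMap regroups it as {"Profile": {}, "Orders": {"Items.Product": []}},
	// so Orders is preloaded at this level and Items.Product by the nested call.
	err := db.
		Preload("Profile").
		Preload("Orders.Items.Product").
		Find(&users).Error
	return users, err
}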
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
// If the current relationship is embedded or joined, current query will be ignored.
//
//nolint:cyclop
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
// avoid random traversal of the map
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames = append(preloadNames, key)
}
sort.Strings(preloadNames)
isJoined := func(name string) (joined bool, nestedJoins []string) {
for _, join := range joins {
if _, ok := relationships.Relations[join]; ok && name == join {
joined = true
continue
}
join0, join1, cut := strings.Cut(join, ".")
if cut {
if _, ok := relationships.Relations[join0]; ok && name == join0 {
joined = true
nestedJoins = append(nestedJoins, join1)
}
}
}
return joined, nestedJoins
}
for _, name := range preloadNames {
if relations := relationships.EmbeddedRelations[name]; relations != nil {
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
return err
}
} else if rel := relationships.Relations[name]; rel != nil {
if joined, nestedJoins := isJoined(name); joined {
switch rv := db.Statement.ReflectValue; rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() > 0 {
reflectValue := rel.FieldSchema.MakeSlice().Elem()
for i := 0; i < rv.Len(); i++ {
frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
if frv.Kind() != reflect.Ptr {
reflectValue = reflect.Append(reflectValue, frv.Addr())
} else {
if frv.IsNil() {
continue
}
reflectValue = reflect.Append(reflectValue, frv)
}
}
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
}
case reflect.Struct, reflect.Pointer:
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
default:
return gorm.ErrInvalidData
}
} else {
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
tx.Statement.ReflectValue = db.Statement.ReflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
return err
}
}
} else {
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
}
}
return nil
}
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if err := tx.Statement.Parse(dest); err != nil {
tx.AddError(err)
return tx
}
tx.Statement.ReflectValue = reflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
return tx
}
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
var ( var (
reflectValue = db.Statement.ReflectValue reflectValue = tx.Statement.ReflectValue
rel = rels[len(rels)-1]
tx = db.Session(&gorm.Session{})
relForeignKeys []string relForeignKeys []string
relForeignFields []*schema.Field relForeignFields []*schema.Field
foreignFields []*schema.Field foreignFields []*schema.Field
@ -22,13 +193,13 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
inlineConds []interface{} inlineConds []interface{}
) )
if len(rels) > 1 {
reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)-1])
}
if rel.JoinTable != nil { if rel.JoinTable != nil {
var joinForeignFields, joinRelForeignFields []*schema.Field var (
var joinForeignKeys []string joinForeignFields = make([]*schema.Field, 0, len(rel.References))
joinRelForeignFields = make([]*schema.Field, 0, len(rel.References))
joinForeignKeys = make([]string, 0, len(rel.References))
)
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName) joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
@ -43,25 +214,28 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
} }
} }
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
if len(joinForeignValues) == 0 { if len(joinForeignValues) == 0 {
return return nil
} }
joinResults := rel.JoinTable.MakeSlice().Elem() joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(joinForeignKeys, joinForeignValues) column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()) if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil {
return err
}
// convert join identity map to relation identity map // convert join identity map to relation identity map
fieldValues := make([]interface{}, len(joinForeignFields)) fieldValues := make([]interface{}, len(joinForeignFields))
joinFieldValues := make([]interface{}, len(joinRelForeignFields)) joinFieldValues := make([]interface{}, len(joinRelForeignFields))
for i := 0; i < joinResults.Len(); i++ { for i := 0; i < joinResults.Len(); i++ {
joinIndexValue := joinResults.Index(i)
for idx, field := range joinForeignFields { for idx, field := range joinForeignFields {
fieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
} }
for idx, field := range joinRelForeignFields { for idx, field := range joinRelForeignFields {
joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
} }
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
@ -70,7 +244,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
} }
} }
_, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields) _, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields)
} else { } else {
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
@ -86,14 +260,22 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
} }
} }
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
if len(foreignValues) == 0 { if len(foreignValues) == 0 {
return return nil
} }
} }
// nested preload
for p, pvs := range preloads {
tx = tx.Preload(p, pvs...)
}
reflectResults := rel.FieldSchema.MakeSlice().Elem() reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(relForeignKeys, foreignValues) column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
if len(values) != 0 {
tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})
for _, cond := range conds { for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
@ -103,18 +285,50 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
} }
} }
tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...) if len(inlineConds) > 0 {
tx = tx.Where(inlineConds[0], inlineConds[1:]...)
}
if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
return err
}
}
fieldValues := make([]interface{}, len(relForeignFields)) fieldValues := make([]interface{}, len(relForeignFields))
// clean up old values before preloading
switch reflectValue.Kind() {
case reflect.Struct:
switch rel.Type {
case schema.HasMany, schema.Many2Many:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()))
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type {
case schema.HasMany, schema.Many2Many:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()))
}
}
}
for i := 0; i < reflectResults.Len(); i++ { for i := 0; i < reflectResults.Len(); i++ {
elem := reflectResults.Index(i) elem := reflectResults.Index(i)
for idx, field := range relForeignFields { for idx, field := range relForeignFields {
fieldValues[idx], _ = field.ValueOf(elem) fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem)
} }
for _, data := range identityMap[utils.ToStringKey(fieldValues...)] { datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
reflectFieldValue := rel.Field.ReflectValueOf(data) if !ok {
return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())
}
for _, data := range datas {
reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data)
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
} }
@ -122,14 +336,16 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
reflectFieldValue = reflect.Indirect(reflectFieldValue) reflectFieldValue = reflect.Indirect(reflectFieldValue)
switch reflectFieldValue.Kind() { switch reflectFieldValue.Kind() {
case reflect.Struct: case reflect.Struct:
rel.Field.Set(data, reflectResults.Index(i).Interface()) tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface()))
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()))
} else { } else {
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()))
} }
} }
} }
} }
return tx.Error
} }
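From the caller's side, the conds handled by preload above are just the extra arguments passed to Preload; both inline conditions and scope functions flow through. A sketch under the same assumed User/Orders/Items schema as above (models omitted):

func loadPaidOrders(db *gorm.DB) ([]User, error) {
	var users []User
	err := db.
		// inline condition: appended to the IN (...) query built by preload
		Preload("Orders", "state = ?", "paid").
		// scope function: receives the preload tx and may add ordering, limits, ...
		Preload("Orders.Items", func(tx *gorm.DB) *gorm.DB {
			return tx.Order("items.created_at DESC")
		}).
		Find(&users).Error
	return users, err
}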


@ -3,46 +3,51 @@ package callbacks
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"sort"
"strings" "strings"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
func Query(db *gorm.DB) { func Query(db *gorm.DB) {
if db.Error == nil { if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.QueryClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
BuildQuerySQL(db) BuildQuerySQL(db)
}
if !db.DryRun { if !db.DryRun && db.Error == nil {
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil { if err != nil {
db.AddError(err) db.AddError(err)
return return
} }
defer rows.Close() defer func() {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, 0)
gorm.Scan(rows, db, false) if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
} }
} }
} }
func BuildQuerySQL(db *gorm.DB) { func BuildQuerySQL(db *gorm.DB) {
clauseSelect := clause.Select{} if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.QueryClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.ReflectValue.Kind() == reflect.Struct { if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(100)
clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
var conds []clause.Expression var conds []clause.Expression
for _, primaryField := range db.Statement.Schema.PrimaryFields { for _, primaryField := range db.Statement.Schema.PrimaryFields {
if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero {
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
} }
} }
@ -53,161 +58,257 @@ func BuildQuerySQL(db *gorm.DB) {
} }
if len(db.Statement.Selects) > 0 { if len(db.Statement.Selects) > 0 {
for _, name := range db.Statement.Selects { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
for idx, name := range db.Statement.Selects {
if db.Statement.Schema == nil { if db.Statement.Schema == nil {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
Name: name,
Raw: true,
})
} else if f := db.Statement.Schema.LookUpField(name); f != nil { } else if f := db.Statement.Schema.LookUpField(name); f != nil {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
Name: f.DBName,
})
} else { } else {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
Name: name, }
Raw: true, }
}) } else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 {
selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false)
clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames))
for _, dbName := range db.Statement.Schema.DBNames {
if v, ok := selectColumns[dbName]; (ok && v) || !ok {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName})
}
}
} else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() {
queryFields := db.QueryFields
if !queryFields {
switch db.Statement.ReflectValue.Kind() {
case reflect.Struct:
queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
case reflect.Slice:
queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
}
}
if queryFields {
stmt := gorm.Statement{DB: db}
// smaller struct
if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) {
clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))
for idx, dbName := range stmt.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
}
} }
} }
} }
// inline joins // inline joins
if len(db.Statement.Joins) != 0 { fromClause := clause.From{}
joins := []clause.Join{} if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
fromClause = v
}
if len(db.Statement.Selects) == 0 { if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
for _, dbName := range db.Statement.Schema.DBNames { if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
Table: db.Statement.Table, for idx, dbName := range db.Statement.Schema.DBNames {
Name: dbName, clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
})
} }
} }
for name, conds := range db.Statement.Joins { specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
if db.Statement.Schema == nil { for _, join := range db.Statement.Joins {
joins = append(joins, clause.Join{ if db.Statement.Schema != nil {
Expression: clause.Expr{SQL: name, Vars: conds}, var isRelations bool // is relations or raw sql
}) var relations []*schema.Relationship
} else if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
tableAliasName := relation.Name if ok {
isRelations = true
relations = append(relations, relation)
} else {
// handle nested join like "Manager.Company"
nestedJoinNames := strings.Split(join.Name, ".")
if len(nestedJoinNames) > 1 {
isNestedJoin := true
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
currentRelations := db.Statement.Schema.Relationships.Relations
for _, relname := range nestedJoinNames {
// incomplete match, only treated as raw sql
if relation, ok = currentRelations[relname]; ok {
guessNestedRelations = append(guessNestedRelations, relation)
currentRelations = relation.FieldSchema.Relationships.Relations
} else {
isNestedJoin = false
break
}
}
if isNestedJoin {
isRelations = true
relations = guessNestedRelations
}
}
}
if isRelations {
genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
}
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
for _, s := range relation.FieldSchema.DBNames { for _, s := range relation.FieldSchema.DBNames {
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName, Table: tableAliasName,
Name: s, Name: s,
Alias: tableAliasName + "__" + s, Alias: utils.NestedRelationName(tableAliasName, s),
}) })
} }
}
var exprs []clause.Expression if join.Expression != nil {
for _, ref := range relation.References { return clause.Join{
Type: join.JoinType,
Expression: join.Expression,
}
}
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
exprs = append(exprs, clause.Eq{ exprs[idx] = clause.Eq{
Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName}, Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
}) }
} else { } else {
if ref.PrimaryValue == "" { if ref.PrimaryValue == "" {
exprs = append(exprs, clause.Eq{ exprs[idx] = clause.Eq{
Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName}, Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}) }
} else { } else {
exprs = append(exprs, clause.Eq{ exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue, Value: ref.PrimaryValue,
}) }
} }
} }
} }
joins = append(joins, clause.Join{ {
Type: clause.LeftJoin, onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
for _, c := range relation.FieldSchema.QueryClauses {
onStmt.AddClause(c)
}
if join.On != nil {
onStmt.AddClause(join.On)
}
if cs, ok := onStmt.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
where.Build(&onStmt)
if onSQL := onStmt.SQL.String(); onSQL != "" {
vars := onStmt.Vars
for idx, v := range vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
}
}
}
return clause.Join{
Type: joinType,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs}, ON: clause.Where{Exprs: exprs},
}) }
}
parentTableName := clause.CurrentTable
for idx, rel := range relations {
// joins table alias like "Manager, Company, Manager__Company"
curAliasName := rel.Name
if parentTableName != clause.CurrentTable {
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
}
if _, ok := specifiedRelationsName[curAliasName]; !ok {
aliasName := curAliasName
if idx == len(relations)-1 && join.Alias != "" {
aliasName = join.Alias
}
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
specifiedRelationsName[curAliasName] = aliasName
}
parentTableName = curAliasName
}
} else { } else {
joins = append(joins, clause.Join{ fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.Expr{SQL: name, Vars: conds}, Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
}) })
} }
} }
db.Statement.AddClause(clause.From{Joins: joins}) db.Statement.AddClause(fromClause)
} else { } else {
db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.AddClauseIfNotExists(clause.From{})
} }
db.Statement.AddClauseIfNotExists(clauseSelect) db.Statement.AddClauseIfNotExists(clauseSelect)
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") db.Statement.Build(db.Statement.BuildClauses...)
}
} }
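The relation matching added to BuildQuerySQL is what lets a single Joins call reference a nested relation such as "Manager.Company", generating the chained LEFT JOINs and the double-underscore alias convention used by the scanner. A hedged usage sketch; it assumes a User model with a Manager relation whose schema in turn has a Company relation (definitions omitted).

func usersWithManagerCompany(db *gorm.DB) ([]User, error) {
	var users []User
	// One nested Joins("Manager.Company") emits both joins and aliases the
	// joined columns as Manager__<col> and Manager__Company__<col>, which the
	// row scanner maps back onto the nested structs.
	err := db.Joins("Manager.Company").Find(&users).Error
	return users, err
}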
func Preload(db *gorm.DB) { func Preload(db *gorm.DB) {
if db.Error == nil { if db.Error == nil && len(db.Statement.Preloads) > 0 {
if len(db.Statement.Preloads) > 0 { if db.Statement.Schema == nil {
preloadMap := map[string][]string{} db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired))
for name := range db.Statement.Preloads { return
preloadFields := strings.Split(name, ".")
for idx := range preloadFields {
preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1]
}
} }
preloadNames := make([]string, len(preloadMap)) joins := make([]string, 0, len(db.Statement.Joins))
idx := 0 for _, join := range db.Statement.Joins {
for key := range preloadMap { joins = append(joins, join.Name)
preloadNames[idx] = key
idx++
}
sort.Strings(preloadNames)
for _, name := range preloadNames {
var (
curSchema = db.Statement.Schema
preloadFields = preloadMap[name]
rels = make([]*schema.Relationship, len(preloadFields))
)
for idx, preloadField := range preloadFields {
if rel := curSchema.Relationships.Relations[preloadField]; rel != nil {
rels[idx] = rel
curSchema = rel.FieldSchema
} else {
db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation))
}
} }
preload(db, rels, db.Statement.Preloads[name]) tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
} if tx.Error != nil {
return
} }
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
} }
} }
func AfterQuery(db *gorm.DB) { func AfterQuery(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { // clear the joins after query because preload need it
tx := db.Session(&gorm.Session{}) if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
callMethod := func(value interface{}) bool { fromClause := db.Statement.Clauses["FROM"]
if db.Statement.Schema.AfterFind { fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins
if i, ok := value.(gorm.AfterFindInterface); ok { db.Statement.Clauses["FROM"] = fromClause
}
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterFindInterface); ok {
db.AddError(i.AfterFind(tx)) db.AddError(i.AfterFind(tx))
return true return true
} }
}
return false return false
} })
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
} }
} }


@ -5,12 +5,18 @@ import (
) )
func RawExec(db *gorm.DB) { func RawExec(db *gorm.DB) {
if db.Error == nil { if db.Error == nil && !db.DryRun {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil { if err != nil {
db.AddError(err) db.AddError(err)
} else { return
}
db.RowsAffected, _ = result.RowsAffected() db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
} }
} }
} }
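RawExec is the callback behind db.Exec; with this change a DryRun session never reaches the connection pool, and RowsAffected plus Statement.Result are only populated after a successful ExecContext. A usage sketch (the "users" table is assumed):

import "gorm.io/gorm"

func renameUser(db *gorm.DB, id uint, name string) (int64, error) {
	// Exec goes through RawExec: on success RowsAffected is filled in,
	// in a DryRun session the statement is built but never executed.
	res := db.Exec("UPDATE users SET name = ? WHERE id = ?", name, id)
	if res.Error != nil {
		return 0, res.Error
	}
	return res.RowsAffected, nil
}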


@ -6,14 +6,18 @@ import (
func RowQuery(db *gorm.DB) { func RowQuery(db *gorm.DB) {
if db.Error == nil { if db.Error == nil {
if db.Statement.SQL.String() == "" {
BuildQuerySQL(db) BuildQuerySQL(db)
if db.DryRun || db.Error != nil {
return
} }
if _, ok := db.Get("rows"); ok { if isRows, ok := db.Get("rows"); ok && isRows.(bool) {
db.Statement.Settings.Delete("rows")
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} else { } else {
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} }
db.RowsAffected = -1
} }
} }
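RowQuery now reuses an already-built SQL string, honours DryRun, and distinguishes Rows() from Row() via the "rows" setting. From the public API side (a hedged sketch; the query and table are illustrative):

import "gorm.io/gorm"

func countUsers(db *gorm.DB) (int64, error) {
	var n int64
	// Raw(...).Row() reaches RowQuery with the "rows" flag false, so the
	// callback uses QueryRowContext; Rows() would route to QueryContext instead.
	err := db.Raw("SELECT count(1) FROM users").Row().Scan(&n)
	return n, err
}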


@ -5,21 +5,28 @@ import (
) )
func BeginTransaction(db *gorm.DB) { func BeginTransaction(db *gorm.DB) {
if !db.Config.SkipDefaultTransaction && db.Error == nil {
if tx := db.Begin(); tx.Error == nil { if tx := db.Begin(); tx.Error == nil {
db.Statement.ConnPool = tx.Statement.ConnPool db.Statement.ConnPool = tx.Statement.ConnPool
tx.InstanceSet("gorm:started_transaction", true) db.InstanceSet("gorm:started_transaction", true)
} else { } else if tx.Error == gorm.ErrInvalidTransaction {
tx.Error = nil tx.Error = nil
} else {
db.Error = tx.Error
}
} }
} }
func CommitOrRollbackTransaction(db *gorm.DB) { func CommitOrRollbackTransaction(db *gorm.DB) {
if !db.Config.SkipDefaultTransaction {
if _, ok := db.InstanceGet("gorm:started_transaction"); ok { if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
if db.Error == nil { if db.Error != nil {
db.Commit()
} else {
db.Rollback() db.Rollback()
} else {
db.Commit()
} }
db.Statement.ConnPool = db.ConnPool db.Statement.ConnPool = db.ConnPool
} }
}
} }
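Both transaction callbacks are now gated on SkipDefaultTransaction, so statements outside an explicit Transaction block only get the automatic begin/commit wrapper when the config allows it. A hedged configuration sketch; the sqlite driver and file name are chosen only for brevity.

import (
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

func open() (*gorm.DB, error) {
	// Skip the per-statement transaction wrapper; callers that need atomicity
	// use db.Transaction(func(tx *gorm.DB) error { ... }) explicitly.
	return gorm.Open(sqlite.Open("app.db"), &gorm.Config{
		SkipDefaultTransaction: true,
	})
}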


@ -7,10 +7,11 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
func SetupUpdateReflectValue(db *gorm.DB) { func SetupUpdateReflectValue(db *gorm.DB) {
if db.Error == nil { if db.Error == nil && db.Statement.Schema != nil {
if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest { if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest {
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
for db.Statement.ReflectValue.Kind() == reflect.Ptr { for db.Statement.ReflectValue.Kind() == reflect.Ptr {
@ -20,7 +21,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
for _, rel := range db.Statement.Schema.Relationships.BelongsTo { for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if _, ok := dest[rel.Name]; ok { if _, ok := dest[rel.Name]; ok {
rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]))
} }
} }
} }
@ -28,113 +29,117 @@ func SetupUpdateReflectValue(db *gorm.DB) {
} }
} }
// BeforeUpdate before update hooks
func BeforeUpdate(db *gorm.DB) { func BeforeUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
tx := db.Session(&gorm.Session{}) callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.BeforeSave { if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok { if i, ok := value.(BeforeSaveInterface); ok {
ok = true called = true
db.AddError(i.BeforeSave(tx)) db.AddError(i.BeforeSave(tx))
} }
} }
if db.Statement.Schema.BeforeUpdate { if db.Statement.Schema.BeforeUpdate {
if i, ok := value.(gorm.BeforeUpdateInterface); ok { if i, ok := value.(BeforeUpdateInterface); ok {
ok = true called = true
db.AddError(i.BeforeUpdate(tx)) db.AddError(i.BeforeUpdate(tx))
} }
} }
return ok
}
if ok := callMethod(db.Statement.Dest); !ok { return called
switch db.Statement.ReflectValue.Kind() { })
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
} }
} }
func Update(db *gorm.DB) { // Update update hook
if db.Error == nil { func Update(config *Config) func(db *gorm.DB) {
if db.Statement.Schema != nil && !db.Statement.Unscoped { supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")
return func(db *gorm.DB) {
if db.Error != nil {
return
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.UpdateClauses { for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c) db.Statement.AddClause(c)
} }
} }
if db.Statement.SQL.String() == "" { if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{}) db.Statement.AddClauseIfNotExists(clause.Update{})
if _, ok := db.Statement.Clauses["SET"]; !ok {
if set := ConvertToAssignments(db.Statement); len(set) != 0 { if set := ConvertToAssignments(db.Statement); len(set) != 0 {
defer delete(db.Statement.Clauses, "SET")
db.Statement.AddClause(set) db.Statement.AddClause(set)
} else { } else {
return return
} }
db.Statement.Build("UPDATE", "SET", "WHERE")
} }
if _, ok := db.Statement.Clauses["WHERE"]; !ok { db.Statement.Build(db.Statement.BuildClauses...)
db.AddError(gorm.ErrMissingWhereClause)
return
} }
if !db.DryRun { checkMissingWhereConditions(db)
if !db.DryRun && db.Error == nil {
if ok, mode := hasReturning(db, supportReturning); ok {
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
dest := db.Statement.Dest
db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface()
gorm.Scan(rows, db, mode)
db.Statement.Dest = dest
db.AddError(rows.Close())
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
} else {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil { if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected() db.RowsAffected, _ = result.RowsAffected()
} else { }
db.AddError(err)
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
} }
} }
} }
} }
// AfterUpdate after update hooks
func AfterUpdate(db *gorm.DB) { func AfterUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
tx := db.Session(&gorm.Session{}) callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
callMethod := func(value interface{}) bool { if db.Statement.Schema.AfterUpdate {
var ok bool if i, ok := value.(AfterUpdateInterface); ok {
called = true
db.AddError(i.AfterUpdate(tx))
}
}
if db.Statement.Schema.AfterSave { if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok { if i, ok := value.(AfterSaveInterface); ok {
ok = true called = true
db.AddError(i.AfterSave(tx)) db.AddError(i.AfterSave(tx))
} }
} }
if db.Statement.Schema.AfterUpdate { return called
if i, ok := value.(gorm.AfterUpdateInterface); ok { })
ok = true
db.AddError(i.AfterUpdate(tx))
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
} }
} }
// ConvertToAssignments convert to update assignments // ConvertToAssignments convert to update assignments
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
var ( var (
selectColumns, restricted = SelectAndOmitColumns(stmt, false, true) selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
assignValue func(field *schema.Field, value interface{}) assignValue func(field *schema.Field, value interface{})
) )
@ -142,13 +147,15 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
assignValue = func(field *schema.Field, value interface{}) { assignValue = func(field *schema.Field, value interface{}) {
for i := 0; i < stmt.ReflectValue.Len(); i++ { for i := 0; i < stmt.ReflectValue.Len(); i++ {
field.Set(stmt.ReflectValue.Index(i), value) if stmt.ReflectValue.CanAddr() {
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
}
} }
} }
case reflect.Struct: case reflect.Struct:
assignValue = func(field *schema.Field, value interface{}) { assignValue = func(field *schema.Field, value interface{}) {
if stmt.ReflectValue.CanAddr() { if stmt.ReflectValue.CanAddr() {
field.Set(stmt.ReflectValue, value) field.Set(stmt.Context, stmt.ReflectValue, value)
} }
} }
default: default:
@ -161,92 +168,144 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
updatingValue = updatingValue.Elem() updatingValue = updatingValue.Elem()
} }
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if size := stmt.ReflectValue.Len(); size > 0 {
var isZero bool
for i := 0; i < size; i++ {
for _, field := range stmt.Schema.PrimaryFields {
_, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
if !isZero {
break
}
}
}
if !isZero {
_, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
}
case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
}
}
switch value := updatingValue.Interface().(type) { switch value := updatingValue.Interface().(type) {
case map[string]interface{}: case map[string]interface{}:
set = make([]clause.Assignment, 0, len(value)) set = make([]clause.Assignment, 0, len(value))
var keys []string keys := make([]string, 0, len(value))
for k, _ := range value { for k := range value {
keys = append(keys, k) keys = append(keys, k)
} }
sort.Strings(keys) sort.Strings(keys)
for _, k := range keys { for _, k := range keys {
kv := value[k]
if _, ok := kv.(*gorm.DB); ok {
kv = []interface{}{kv}
}
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil { if field := stmt.Schema.LookUpField(k); field != nil {
if field.DBName != "" { if field.DBName != "" {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
assignValue(field, value[k]) assignValue(field, value[k])
} }
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
assignValue(field, value[k]) assignValue(field, value[k])
} }
} else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { continue
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]})
} }
} }
if !stmt.DisableUpdateTime { if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
for _, field := range stmt.Schema.FieldsByDBName { set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
}
}
if !stmt.SkipHooks && stmt.Schema != nil {
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.LookUpField(dbName)
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || !ok {
now := stmt.DB.NowFunc() now := stmt.DB.NowFunc()
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
assignValue(field, now) assignValue(field, now)
if field.AutoUpdateTime == schema.UnixNanosecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
} else if field.AutoUpdateTime == schema.UnixMillisecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
} else if field.AutoUpdateTime == schema.UnixSecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
} else {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
}
}
} }
} }
} }
default: default:
updatingSchema := stmt.Schema
var isDiffSchema bool
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
// different schema
updatingStmt := &gorm.Statement{DB: stmt.DB}
if err := updatingStmt.Parse(stmt.Dest); err == nil {
updatingSchema = updatingStmt.Schema
isDiffSchema = true
}
}
switch updatingValue.Kind() { switch updatingValue.Kind() {
case reflect.Struct: case reflect.Struct:
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
for _, field := range stmt.Schema.FieldsByDBName { for _, dbName := range stmt.Schema.DBNames {
if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { if field := updatingSchema.LookUpField(dbName); field != nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
value, isZero := field.ValueOf(updatingValue) if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
if !stmt.DisableUpdateTime { value, isZero := field.ValueOf(stmt.Context, updatingValue)
if field.AutoUpdateTime > 0 { if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixMilli()
} else if field.AutoUpdateTime == schema.UnixSecond {
value = stmt.DB.NowFunc().Unix()
} else {
value = stmt.DB.NowFunc() value = stmt.DB.NowFunc()
}
isZero = false isZero = false
} }
}
if ok || !isZero { if (ok || !isZero) && field.Updatable {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
assignValue(field, value) assignField := field
if isDiffSchema {
if originField := stmt.Schema.LookUpField(dbName); originField != nil {
assignField = originField
}
}
assignValue(assignField, value)
} }
} }
} else { } else {
if value, isZero := field.ValueOf(updatingValue); !isZero { if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
} }
} }
} }
} }
} default:
stmt.AddError(gorm.ErrInvalidData)
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var priamryKeyExprs []clause.Expression
for i := 0; i < stmt.ReflectValue.Len(); i++ {
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
var notZero bool
for idx, field := range stmt.Schema.PrimaryFields {
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
notZero = notZero || !isZero
}
if notZero {
priamryKeyExprs = append(priamryKeyExprs, clause.And(exprs...))
}
}
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}})
case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
} }
} }
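The assignment conversion above is easiest to see from the two public entry points it serves: map updates keep zero values and pass *gorm.DB expressions through, while struct updates skip zero-valued and non-updatable fields but still refresh AutoUpdateTime columns. A sketch with an illustrative User model:

import (
	"time"

	"gorm.io/gorm"
)

type User struct {
	ID        uint
	Name      string
	Age       int
	UpdatedAt time.Time
}

func demoUpdates(db *gorm.DB, user *User) error {
	// Map form: zero values such as "age": 0 are kept as SET assignments.
	if err := db.Model(user).Updates(map[string]interface{}{
		"name": "jinzhu",
		"age":  0,
	}).Error; err != nil {
		return err
	}
	// Struct form: zero-value fields are skipped, while UpdatedAt is still
	// refreshed by the AutoUpdateTime handling above.
	return db.Model(user).Updates(User{Name: "jinzhu 2"}).Error
}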


@ -2,6 +2,7 @@ package gorm
import ( import (
"fmt" "fmt"
"regexp"
"strings" "strings"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -9,9 +10,10 @@ import (
) )
// Model specify the model you would like to run db operations // Model specify the model you would like to run db operations
//
// // update all users's name to `hello` // // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello") // db.Model(&User{}).Update("name", "hello")
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` // // if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello`
// db.Model(&user).Update("name", "hello") // db.Model(&user).Update("name", "hello")
func (db *DB) Model(value interface{}) (tx *DB) { func (db *DB) Model(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
@ -20,6 +22,19 @@ func (db *DB) Model(value interface{}) (tx *DB) {
} }
// Clauses Add clauses // Clauses Add clauses
//
// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more
// advanced techniques like specifying lock strength and optimizer hints. See the
// [docs] for more depth.
//
// // add a simple limit clause
// db.Clauses(clause.Limit{Limit: 1}).Find(&User{})
// // tell the optimizer to use the `idx_user_name` index
// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{})
// // specify the lock strength to UPDATE
// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users)
//
// [docs]: https://gorm.io/docs/sql_builder.html#Clauses
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
var whereConds []interface{} var whereConds []interface{}
@ -27,25 +42,73 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
for _, cond := range conds { for _, cond := range conds {
if c, ok := cond.(clause.Interface); ok { if c, ok := cond.(clause.Interface); ok {
tx.Statement.AddClause(c) tx.Statement.AddClause(c)
} else if optimizer, ok := cond.(StatementModifier); ok {
optimizer.ModifyStatement(tx.Statement)
} else { } else {
whereConds = append(whereConds, cond) whereConds = append(whereConds, cond)
} }
} }
if len(whereConds) > 0 { if len(whereConds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...)}) tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)})
} }
return return
} }
var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`)
// Table specify the table you would like to run db operations // Table specify the table you would like to run db operations
func (db *DB) Table(name string) (tx *DB) { //
// // Get a user
// db.Table("users").Take(&result)
func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 {
if results[1] != "" {
tx.Statement.Table = results[1]
} else {
tx.Statement.Table = results[2]
}
}
} else if tables := strings.Split(name, "."); len(tables) == 2 {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = tables[1]
} else if name != "" {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = name tx.Statement.Table = name
} else {
tx.Statement.TableExpr = nil
tx.Statement.Table = ""
}
return
}
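
The rewritten Table now recognises aliases via tableRegexp and schema-qualified names, and records the resolved name in Statement.Table so later conditions and quoting use the right identifier. A small sketch of the accepted forms, assuming a recent gorm.io/gorm, an initialized *gorm.DB, and illustrative table/column names:

package example

import "gorm.io/gorm"

type User struct {
    ID   uint
    Name string
    Age  int
}

func tableForms(db *gorm.DB) {
    row := map[string]interface{}{}

    // plain name: Statement.Table = "users"
    db.Table("users").Take(&row)

    // alias (with or without AS): Statement.Table becomes the alias "u"
    db.Table("users AS u").Where("u.name = ?", "jinzhu").Take(&row)

    // schema-qualified name: Statement.Table = "profiles"
    db.Table("public.profiles").Take(&row)

    // on recent GORM versions the new variadic args let a subquery act as the table
    db.Table("(?) AS u", db.Model(&User{}).Where("age > ?", 18)).Take(&row)
}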
// Distinct specify distinct fields that you want querying
//
// // Select distinct names of users
// db.Distinct("name").Find(&results)
// // Select distinct name/age pairs from users
// db.Distinct("name", "age").Find(&results)
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Distinct = true
if len(args) > 0 {
tx = tx.Select(args[0], args[1:]...)
}
return return
} }
// Select specify fields that you want when querying, creating, updating // Select specify fields that you want when querying, creating, updating
//
// Use Select when you only want a subset of the fields. By default, GORM will select all fields.
// Select accepts both string arguments and arrays.
//
// // Select name and age of user using multiple arguments
// db.Select("name", "age").Find(&users)
// // Select name and age of user using an array
// db.Select([]string{"name", "age"}).Find(&users)
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
@ -64,12 +127,24 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
return return
} }
} }
case string:
fields := strings.FieldsFunc(v, utils.IsChar)
// normal field names if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { clause.Expression = nil
tx.Statement.Selects = fields tx.Statement.Clauses["SELECT"] = clause
}
case string:
if strings.Count(v, "?") >= len(args) && len(args) > 0 {
tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
Expression: clause.Expr{SQL: v, Vars: args},
})
} else if strings.Count(v, "@") > 0 && len(args) > 0 {
tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
Expression: clause.NamedExpr{SQL: v, Vars: args},
})
} else {
tx.Statement.Selects = []string{v}
for _, arg := range args { for _, arg := range args {
switch arg := arg.(type) { switch arg := arg.(type) {
@ -79,15 +154,17 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
tx.Statement.Selects = append(tx.Statement.Selects, arg...) tx.Statement.Selects = append(tx.Statement.Selects, arg...)
default: default:
tx.Statement.AddClause(clause.Select{ tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
Expression: clause.Expr{SQL: v, Vars: args}, Expression: clause.Expr{SQL: v, Vars: args},
}) })
return return
} }
} }
} else {
tx.Statement.AddClause(clause.Select{ if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
Expression: clause.Expr{SQL: v, Vars: args}, clause.Expression = nil
}) tx.Statement.Clauses["SELECT"] = clause
}
} }
default: default:
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
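
Select now inspects the string form of its first argument: with "?" placeholders (and matching args) it builds a clause.Expr, with "@" named placeholders it builds a clause.NamedExpr, and otherwise it is kept as a plain column list; the Distinct flag already set on the statement is carried onto the generated clause. A hedged usage sketch, assuming a User model with an Age column:

package example

import (
    "database/sql"

    "gorm.io/gorm"
)

type User struct {
    ID   uint
    Name string
    Age  int
}

func selectForms(db *gorm.DB) {
    var users []User

    // plain column list: SELECT `name`,`age` FROM `users`
    db.Select("name", "age").Find(&users)

    // "?" placeholders -> clause.Expr with bind variables
    db.Model(&User{}).Select("COALESCE(age, ?) AS age", 0).Find(&users)

    // "@" placeholders -> clause.NamedExpr
    db.Model(&User{}).Select("COALESCE(age, @def) AS age", sql.Named("def", 0)).Find(&users)
}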
@ -101,99 +178,182 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar) tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
} else { } else {
tx.Statement.Omits = columns tx.Statement.Omits = columns
} }
return return
} }
// MapColumns modify the column names in the query results to facilitate align to the corresponding structural fields
func (db *DB) MapColumns(m map[string]string) (tx *DB) {
tx = db.getInstance()
tx.Statement.ColumnMapping = m
return
}
// Where add conditions // Where add conditions
//
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
//
// // Find the first user with name jinzhu
// db.Where("name = ?", "jinzhu").First(&user)
// // Find the first user with name jinzhu and age 20
// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
// // Find the first user with name jinzhu and age not equal to 20
// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user)
//
// [docs]: https://gorm.io/docs/query.html#Conditions
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if conds := tx.Statement.BuildCondtion(query, args...); len(conds) > 0 { if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: conds}) tx.Statement.AddClause(clause.Where{Exprs: conds})
} }
return return
} }
// Not add NOT conditions // Not add NOT conditions
//
// Not works similarly to where, and has the same syntax.
//
// // Find the first user with name not equal to jinzhu
// db.Not("name = ?", "jinzhu").First(&user)
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if conds := tx.Statement.BuildCondtion(query, args...); len(conds) > 0 { if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}}) tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}})
} }
return return
} }
// Or add OR conditions // Or add OR conditions
//
// Or is used to chain together queries with an OR.
//
// // Find the first user with name equal to jinzhu or john
// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user)
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if conds := tx.Statement.BuildCondtion(query, args...); len(conds) > 0 { if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(conds...)}}) tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}})
} }
return return
} }
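
Note the change from clause.Or(conds...) to clause.Or(clause.And(conds...)): everything passed to a single Or call is grouped with AND before being OR'ed against the existing WHERE clause. A small sketch of the resulting grouping, assuming an illustrative User model (generated SQL shown approximately):

package example

import "gorm.io/gorm"

type User struct {
    ID   uint
    Name string
    Age  int
}

func orGrouping(db *gorm.DB) ([]User, error) {
    var users []User
    // roughly: SELECT * FROM `users`
    //          WHERE name = 'jinzhu' OR (name = 'jinzhu 2' AND age = 18)
    err := db.Where("name = ?", "jinzhu").
        Or(User{Name: "jinzhu 2", Age: 18}).
        Find(&users).Error
    return users, err
}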
// Joins specify Joins conditions // Joins specify Joins conditions
//
// db.Joins("Account").Find(&user) // db.Joins("Account").Find(&user)
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
return joins(db, clause.LeftJoin, query, args...)
}
// InnerJoins specify inner joins conditions
// db.InnerJoins("Account").Find(&user)
func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) {
return joins(db, clause.InnerJoin, query, args...)
}
func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if tx.Statement.Joins == nil {
tx.Statement.Joins = map[string][]interface{}{} if len(args) == 1 {
if db, ok := args[0].(*DB); ok {
j := join{
Name: query, Conds: args, Selects: db.Statement.Selects,
Omits: db.Statement.Omits, JoinType: joinType,
} }
tx.Statement.Joins[query] = args if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
j.On = &where
}
tx.Statement.Joins = append(tx.Statement.Joins, j)
return
}
}
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType})
return return
} }
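
Joins and the new InnerJoins now go through a shared helper that records one join struct per call; when the single extra argument is a *gorm.DB, its WHERE clause becomes extra ON conditions and its Select/Omit lists narrow the joined columns. A hedged sketch, assuming illustrative User/Account models where User has an Account association:

package example

import "gorm.io/gorm"

type Account struct {
    ID     uint
    UserID uint
    Name   string
}

type User struct {
    ID      uint
    Name    string
    Account Account
}

func joinForms(db *gorm.DB) {
    var users []User

    // LEFT JOIN on the Account association
    db.Joins("Account").Find(&users)

    // same join, but the inner *gorm.DB contributes extra ON conditions
    db.Joins("Account", db.Where(&Account{Name: "premium"})).Find(&users)

    // INNER JOIN variant of the association join
    db.InnerJoins("Account").Find(&users)
}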
// Group specify the group method on the find // Group specify the group method on the find
//
// // Select the sum age of users with given names
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results)
func (db *DB) Group(name string) (tx *DB) { func (db *DB) Group(name string) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
fields := strings.FieldsFunc(name, utils.IsValidDBNameChar)
tx.Statement.AddClause(clause.GroupBy{ tx.Statement.AddClause(clause.GroupBy{
Columns: []clause.Column{{Name: name}}, Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}},
}) })
return return
} }
// Having specify HAVING conditions for GROUP BY // Having specify HAVING conditions for GROUP BY
//
// // Select the sum age of users with name jinzhu
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result)
func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.GroupBy{ tx.Statement.AddClause(clause.GroupBy{
Having: tx.Statement.BuildCondtion(query, args...), Having: tx.Statement.BuildCondition(query, args...),
}) })
return return
} }
// Order specify order when retrieve records from database // Order specify order when retrieving records from database
//
// db.Order("name DESC") // db.Order("name DESC")
// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression // db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
// db.Order(clause.OrderBy{Columns: []clause.OrderByColumn{
// {Column: clause.Column{Name: "name"}, Desc: true},
// {Column: clause.Column{Name: "age"}, Desc: true},
// }})
func (db *DB) Order(value interface{}) (tx *DB) { func (db *DB) Order(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
switch v := value.(type) { switch v := value.(type) {
case clause.OrderBy:
tx.Statement.AddClause(v)
case clause.OrderByColumn: case clause.OrderByColumn:
tx.Statement.AddClause(clause.OrderBy{ tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{v}, Columns: []clause.OrderByColumn{v},
}) })
default: case string:
if v != "" {
tx.Statement.AddClause(clause.OrderBy{ tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{{ Columns: []clause.OrderByColumn{{
Column: clause.Column{Name: fmt.Sprint(value), Raw: true}, Column: clause.Column{Name: v, Raw: true},
}}, }},
}) })
} }
}
return return
} }
// Limit specify the number of records to be retrieved // Limit specify the number of records to be retrieved
//
// Limit conditions can be cancelled by using `Limit(-1)`.
//
// // retrieve 3 users
// db.Limit(3).Find(&users)
// // retrieve 3 users into users1, and all users into users2
// db.Limit(3).Find(&users1).Limit(-1).Find(&users2)
func (db *DB) Limit(limit int) (tx *DB) { func (db *DB) Limit(limit int) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Limit: limit}) tx.Statement.AddClause(clause.Limit{Limit: &limit})
return return
} }
// Offset specify the number of records to skip before starting to return the records // Offset specify the number of records to skip before starting to return the records
//
// Offset conditions can be cancelled by using `Offset(-1)`.
//
// // select the third user
// db.Offset(2).First(&user)
// // select the first user by cancelling an earlier chained offset
// db.Offset(5).Offset(-1).First(&user)
func (db *DB) Offset(offset int) (tx *DB) { func (db *DB) Offset(offset int) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Offset: offset}) tx.Statement.AddClause(clause.Limit{Offset: offset})
@ -201,6 +361,7 @@ func (db *DB) Offset(offset int) (tx *DB) {
} }
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically // Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
//
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { // func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000) // return db.Where("amount > ?", 1000)
// } // }
@ -212,14 +373,24 @@ func (db *DB) Offset(offset int) (tx *DB) {
// } // }
// //
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) // db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB { func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
for _, f := range funcs { tx = db.getInstance()
db = f(db) tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
return tx
}
func (db *DB) executeScopes() (tx *DB) {
scopes := db.Statement.scopes
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
} }
return db return db
} }
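
Scopes no longer runs the functions immediately: they are collected on the statement and applied by executeScopes when the finisher (Find, Update, ...) runs, so a scope sees the fully-built statement and reusable scopes compose cleanly. A sketch with hypothetical scope functions and an illustrative Order model:

package example

import "gorm.io/gorm"

type Order struct {
    ID     uint
    Amount int
    Status string
}

func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
    return db.Where("amount > ?", 1000)
}

func WithStatus(status string) func(*gorm.DB) *gorm.DB {
    return func(db *gorm.DB) *gorm.DB {
        return db.Where("status = ?", status)
    }
}

func scopedOrders(db *gorm.DB) ([]Order, error) {
    var orders []Order
    // the scope functions are only queued here; executeScopes applies them
    // when Find builds the query
    err := db.Scopes(AmountGreaterThan1000, WithStatus("paid")).Find(&orders).Error
    return orders, err
}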
// Preload preload associations with given conditions // Preload preload associations with given conditions
//
// // get all users, and preload all non-cancelled orders
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
@ -230,18 +401,57 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
return return
} }
// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Attrs only adds attributes if the record is not found.
//
// // assign an email if the record is not found
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign an email if the record is not found, otherwise ignore provided email
// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.attrs = attrs tx.Statement.attrs = attrs
return return
} }
// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that
// records will be updated even if they are found.
//
// // assign an email regardless of if the record is not found
// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign email regardless of if record is found
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Assign(attrs ...interface{}) (tx *DB) { func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.assigns = attrs tx.Statement.assigns = attrs
return return
} }
// Unscoped disables the global scope of soft deletion in a query.
// By default, GORM uses soft deletion, marking records as "deleted"
// by setting a timestamp on a specific field (e.g., `deleted_at`).
// Unscoped allows queries to include records marked as deleted,
// overriding the soft deletion behavior.
// Example:
//
// var users []User
// db.Unscoped().Find(&users)
// // Retrieves all users, including deleted ones.
func (db *DB) Unscoped() (tx *DB) { func (db *DB) Unscoped() (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Unscoped = true tx.Statement.Unscoped = true
@ -251,6 +461,11 @@ func (db *DB) Unscoped() (tx *DB) {
func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.SQL = strings.Builder{} tx.Statement.SQL = strings.Builder{}
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
} else {
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
}
return return
} }
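
Raw now switches on the presence of "@" in the SQL: with named placeholders it builds a clause.NamedExpr, otherwise the existing clause.Expr path is used. Named values can come from sql.Named, a map, or an exported-field struct (see the NamedExpr builder later in this diff). A usage sketch, assuming an illustrative User model:

package example

import (
    "database/sql"

    "gorm.io/gorm"
)

type User struct {
    ID   uint
    Name string
    Age  int
}

func rawNamed(db *gorm.DB) ([]User, error) {
    var users []User
    // "@name" / "@age" are resolved from the named values below;
    // plain "?" placeholders keep working as before
    err := db.Raw(
        "SELECT * FROM users WHERE name = @name OR age >= @age",
        sql.Named("name", "jinzhu"),
        map[string]interface{}{"age": 18},
    ).Scan(&users).Error
    return users, err
}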


@ -29,10 +29,12 @@ func BenchmarkSelect(b *testing.B) {
func BenchmarkComplexSelect(b *testing.B) { func BenchmarkComplexSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
limit10 := 10
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clauses := []clause.Interface{ clauses := []clause.Interface{
clause.Select{}, clause.From{}, clause.Select{},
clause.From{},
clause.Where{Exprs: []clause.Expression{ clause.Where{Exprs: []clause.Expression{
clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.Gt{Column: "age", Value: 18}, clause.Gt{Column: "age", Value: 18},
@ -42,7 +44,7 @@ func BenchmarkComplexSelect(b *testing.B) {
clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}), clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}),
}}, }},
clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}}, clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}},
clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Limit: &limit10, Offset: 20},
clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}}, clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}},
} }


@ -7,7 +7,7 @@ type Interface interface {
MergeClause(*Clause) MergeClause(*Clause)
} }
// ClauseBuilder clause builder, allows to custmize how to build clause // ClauseBuilder clause builder, allows to customize how to build clause
type ClauseBuilder func(Clause, Builder) type ClauseBuilder func(Clause, Builder)
type Writer interface { type Writer interface {
@ -18,17 +18,17 @@ type Writer interface {
// Builder builder interface // Builder builder interface
type Builder interface { type Builder interface {
Writer Writer
WriteQuoted(field interface{}) error WriteQuoted(field interface{})
AddVar(Writer, ...interface{}) AddVar(Writer, ...interface{})
AddError(error) error
} }
// Clause // Clause
type Clause struct { type Clause struct {
Name string // WHERE Name string // WHERE
Priority float64 BeforeExpression Expression
BeforeExpressions []Expression AfterNameExpression Expression
AfterNameExpressions []Expression AfterExpression Expression
AfterExpressions []Expression
Expression Expression Expression Expression
Builder ClauseBuilder Builder ClauseBuilder
} }
@ -37,29 +37,35 @@ type Clause struct {
func (c Clause) Build(builder Builder) { func (c Clause) Build(builder Builder) {
if c.Builder != nil { if c.Builder != nil {
c.Builder(c, builder) c.Builder(c, builder)
} else { } else if c.Expression != nil {
builders := c.BeforeExpressions if c.BeforeExpression != nil {
if c.Name != "" { c.BeforeExpression.Build(builder)
builders = append(builders, Expr{SQL: c.Name})
}
builders = append(builders, c.AfterNameExpressions...)
if c.Expression != nil {
builders = append(builders, c.Expression)
}
for idx, expr := range append(builders, c.AfterExpressions...) {
if idx != 0 {
builder.WriteByte(' ') builder.WriteByte(' ')
} }
expr.Build(builder)
if c.Name != "" {
builder.WriteString(c.Name)
builder.WriteByte(' ')
}
if c.AfterNameExpression != nil {
c.AfterNameExpression.Build(builder)
builder.WriteByte(' ')
}
c.Expression.Build(builder)
if c.AfterExpression != nil {
builder.WriteByte(' ')
c.AfterExpression.Build(builder)
} }
} }
} }
const ( const (
PrimaryKey string = "@@@priamry_key@@@" PrimaryKey string = "~~~py~~~" // primary key
CurrentTable string = "@@@table@@@" CurrentTable string = "~~~ct~~~" // current table
Associations string = "~~~as~~~" // associations
) )
var ( var (


@ -1,7 +1,9 @@
package clause package clause
import ( import (
"database/sql"
"database/sql/driver" "database/sql/driver"
"go/ast"
"reflect" "reflect"
) )
@ -19,6 +21,7 @@ type NegationExpressionBuilder interface {
type Expr struct { type Expr struct {
SQL string SQL string
Vars []interface{} Vars []interface{}
WithoutParentheses bool
} }
// Build build raw expression // Build build raw expression
@ -29,19 +32,23 @@ func (expr Expr) Build(builder Builder) {
) )
for _, v := range []byte(expr.SQL) { for _, v := range []byte(expr.SQL) {
if v == '?' { if v == '?' && len(expr.Vars) > idx {
if afterParenthesis { if afterParenthesis || expr.WithoutParentheses {
if _, ok := expr.Vars[idx].(driver.Valuer); ok { if _, ok := expr.Vars[idx].(driver.Valuer); ok {
builder.AddVar(builder, expr.Vars[idx]) builder.AddVar(builder, expr.Vars[idx])
} else { } else {
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
builder.AddVar(builder, nil)
} else {
for i := 0; i < rv.Len(); i++ { for i := 0; i < rv.Len(); i++ {
if i > 0 { if i > 0 {
builder.WriteByte(',') builder.WriteByte(',')
} }
builder.AddVar(builder, rv.Index(i).Interface()) builder.AddVar(builder, rv.Index(i).Interface())
} }
}
default: default:
builder.AddVar(builder, expr.Vars[idx]) builder.AddVar(builder, expr.Vars[idx])
} }
@ -60,6 +67,125 @@ func (expr Expr) Build(builder Builder) {
builder.WriteByte(v) builder.WriteByte(v)
} }
} }
if idx < len(expr.Vars) {
for _, v := range expr.Vars[idx:] {
builder.AddVar(builder, sql.NamedArg{Value: v})
}
}
}
// NamedExpr raw expression for named expr
type NamedExpr struct {
SQL string
Vars []interface{}
}
// Build build raw expression
func (expr NamedExpr) Build(builder Builder) {
var (
idx int
inName bool
afterParenthesis bool
namedMap = make(map[string]interface{}, len(expr.Vars))
)
for _, v := range expr.Vars {
switch value := v.(type) {
case sql.NamedArg:
namedMap[value.Name] = value.Value
case map[string]interface{}:
for k, v := range value {
namedMap[k] = v
}
default:
var appendFieldsToMap func(reflect.Value)
appendFieldsToMap = func(reflectValue reflect.Value) {
reflectValue = reflect.Indirect(reflectValue)
switch reflectValue.Kind() {
case reflect.Struct:
modelType := reflectValue.Type()
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface()
if fieldStruct.Anonymous {
appendFieldsToMap(reflectValue.Field(i))
}
}
}
}
}
appendFieldsToMap(reflect.ValueOf(value))
}
}
name := make([]byte, 0, 10)
for _, v := range []byte(expr.SQL) {
if v == '@' && !inName {
inName = true
name = name[:0]
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
if inName {
if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv)
} else {
builder.WriteByte('@')
builder.WriteString(string(name))
}
inName = false
}
afterParenthesis = false
builder.WriteByte(v)
} else if v == '?' && len(expr.Vars) > idx {
if afterParenthesis {
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
builder.AddVar(builder, expr.Vars[idx])
} else {
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
builder.AddVar(builder, nil)
} else {
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
}
default:
builder.AddVar(builder, expr.Vars[idx])
}
}
} else {
builder.AddVar(builder, expr.Vars[idx])
}
idx++
} else if inName {
name = append(name, v)
} else {
if v == '(' {
afterParenthesis = true
} else {
afterParenthesis = false
}
builder.WriteByte(v)
}
}
if inName {
if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv)
} else {
builder.WriteByte('@')
builder.WriteString(string(name))
}
}
} }
// IN Whether a value is within a set of values // IN Whether a value is within a set of values
@ -75,8 +201,13 @@ func (in IN) Build(builder Builder) {
case 0: case 0:
builder.WriteString(" IN (NULL)") builder.WriteString(" IN (NULL)")
case 1: case 1:
if _, ok := in.Values[0].([]interface{}); !ok {
builder.WriteString(" = ") builder.WriteString(" = ")
builder.AddVar(builder, in.Values...) builder.AddVar(builder, in.Values[0])
break
}
fallthrough
default: default:
builder.WriteString(" IN (") builder.WriteString(" IN (")
builder.AddVar(builder, in.Values...) builder.AddVar(builder, in.Values...)
@ -85,14 +216,19 @@ func (in IN) Build(builder Builder) {
} }
func (in IN) NegationBuild(builder Builder) { func (in IN) NegationBuild(builder Builder) {
builder.WriteQuoted(in.Column)
switch len(in.Values) { switch len(in.Values) {
case 0: case 0:
builder.WriteString(" IS NOT NULL")
case 1: case 1:
builder.WriteQuoted(in.Column) if _, ok := in.Values[0].([]interface{}); !ok {
builder.WriteString(" <> ") builder.WriteString(" <> ")
builder.AddVar(builder, in.Values...) builder.AddVar(builder, in.Values[0])
break
}
fallthrough
default: default:
builder.WriteQuoted(in.Column)
builder.WriteString(" NOT IN (") builder.WriteString(" NOT IN (")
builder.AddVar(builder, in.Values...) builder.AddVar(builder, in.Values...)
builder.WriteByte(')') builder.WriteByte(')')
@ -108,16 +244,33 @@ type Eq struct {
func (eq Eq) Build(builder Builder) { func (eq Eq) Build(builder Builder) {
builder.WriteQuoted(eq.Column) builder.WriteQuoted(eq.Column)
if eq.Value == nil { switch eq.Value.(type) {
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
rv := reflect.ValueOf(eq.Value)
if rv.Len() == 0 {
builder.WriteString(" IN (NULL)")
} else {
builder.WriteString(" IN (")
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
builder.WriteByte(')')
}
default:
if eqNil(eq.Value) {
builder.WriteString(" IS NULL") builder.WriteString(" IS NULL")
} else { } else {
builder.WriteString(" = ") builder.WriteString(" = ")
builder.AddVar(builder, eq.Value) builder.AddVar(builder, eq.Value)
} }
}
} }
func (eq Eq) NegationBuild(builder Builder) { func (eq Eq) NegationBuild(builder Builder) {
Neq{eq.Column, eq.Value}.Build(builder) Neq(eq).Build(builder)
} }
// Neq not equal to for where // Neq not equal to for where
@ -126,16 +279,29 @@ type Neq Eq
func (neq Neq) Build(builder Builder) { func (neq Neq) Build(builder Builder) {
builder.WriteQuoted(neq.Column) builder.WriteQuoted(neq.Column)
if neq.Value == nil { switch neq.Value.(type) {
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
builder.WriteString(" NOT IN (")
rv := reflect.ValueOf(neq.Value)
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
builder.WriteByte(')')
default:
if eqNil(neq.Value) {
builder.WriteString(" IS NOT NULL") builder.WriteString(" IS NOT NULL")
} else { } else {
builder.WriteString(" <> ") builder.WriteString(" <> ")
builder.AddVar(builder, neq.Value) builder.AddVar(builder, neq.Value)
} }
}
} }
func (neq Neq) NegationBuild(builder Builder) { func (neq Neq) NegationBuild(builder Builder) {
Eq{neq.Column, neq.Value}.Build(builder) Eq(neq).Build(builder)
} }
// Gt greater than for where // Gt greater than for where
@ -148,7 +314,7 @@ func (gt Gt) Build(builder Builder) {
} }
func (gt Gt) NegationBuild(builder Builder) { func (gt Gt) NegationBuild(builder Builder) {
Lte{gt.Column, gt.Value}.Build(builder) Lte(gt).Build(builder)
} }
// Gte greater than or equal to for where // Gte greater than or equal to for where
@ -161,7 +327,7 @@ func (gte Gte) Build(builder Builder) {
} }
func (gte Gte) NegationBuild(builder Builder) { func (gte Gte) NegationBuild(builder Builder) {
Lt{gte.Column, gte.Value}.Build(builder) Lt(gte).Build(builder)
} }
// Lt less than for where // Lt less than for where
@ -174,7 +340,7 @@ func (lt Lt) Build(builder Builder) {
} }
func (lt Lt) NegationBuild(builder Builder) { func (lt Lt) NegationBuild(builder Builder) {
Gte{lt.Column, lt.Value}.Build(builder) Gte(lt).Build(builder)
} }
// Lte less than or equal to for where // Lte less than or equal to for where
@ -187,7 +353,7 @@ func (lte Lte) Build(builder Builder) {
} }
func (lte Lte) NegationBuild(builder Builder) { func (lte Lte) NegationBuild(builder Builder) {
Gt{lte.Column, lte.Value}.Build(builder) Gt(lte).Build(builder)
} }
// Like whether string matches regular expression // Like whether string matches regular expression
@ -204,3 +370,16 @@ func (like Like) NegationBuild(builder Builder) {
builder.WriteString(" NOT LIKE ") builder.WriteString(" NOT LIKE ")
builder.AddVar(builder, like.Value) builder.AddVar(builder, like.Value)
} }
func eqNil(value interface{}) bool {
if valuer, ok := value.(driver.Valuer); ok && !eqNilReflect(valuer) {
value, _ = valuer.Value()
}
return value == nil || eqNilReflect(value)
}
func eqNilReflect(value interface{}) bool {
reflectValue := reflect.ValueOf(value)
return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil()
}
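
Two behavioural notes on the Eq/Neq changes above: a slice value now renders as IN (...) / NOT IN (...) (an empty slice becomes IN (NULL)), and eqNil treats untyped nil, typed nil pointers, and driver.Valuer values that resolve to nil (for example sql.NullString{Valid: false}) as IS NULL / IS NOT NULL. An application-side sketch, assuming an illustrative User model:

package example

import (
    "gorm.io/gorm"
    "gorm.io/gorm/clause"
)

type User struct {
    ID   uint
    Name *string
    Age  int
}

func eqForms(db *gorm.DB) {
    var users []User

    // slice value -> `name` IN (?,?)
    db.Where(clause.Eq{Column: "name", Value: []string{"jinzhu", "jinzhu 2"}}).Find(&users)

    // nil pointer -> `name` IS NULL (via eqNil / eqNilReflect)
    var name *string
    db.Where(clause.Eq{Column: "name", Value: name}).Find(&users)

    // Neq mirrors both cases: `age` NOT IN (?,?) / IS NOT NULL
    db.Where(clause.Neq{Column: "age", Value: []int{18, 20}}).Find(&users)
}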


@ -1,7 +1,9 @@
package clause_test package clause_test
import ( import (
"database/sql"
"fmt" "fmt"
"reflect"
"sync" "sync"
"testing" "testing"
@ -33,3 +35,203 @@ func TestExpr(t *testing.T) {
}) })
} }
} }
func TestNamedExpr(t *testing.T) {
type Base struct {
Name2 string
}
type NamedArgument struct {
Name1 string
Base
}
results := []struct {
SQL string
Result string
Vars []interface{}
ExpectedVars []interface{}
}{{
SQL: "create table ? (? ?, ? ?)",
Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}},
Result: "create table `users` (`id` int, `name` text)",
}, {
SQL: "name1 = @name AND name2 = @name",
Vars: []interface{}{sql.Named("name", "jinzhu")},
Result: "name1 = ? AND name2 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "name1 = @name AND name2 = @@name",
Vars: []interface{}{map[string]interface{}{"name": "jinzhu"}},
Result: "name1 = ? AND name2 = @@name",
ExpectedVars: []interface{}{"jinzhu"},
}, {
SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1",
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
Result: "name1 = ? AND name2 = ? AND name3 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
}, {
SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1",
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu2"}},
Result: "name1 = ? AND name2 = ? AND name3 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
}, {
SQL: "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist",
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? @notexist",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
}, {
SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @notexist",
Vars: []interface{}{NamedArgument{Name1: "jinzhu", Base: Base{Name2: "jinzhu2"}}},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? @notexist",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
}, {
SQL: "create table ? (? ?, ? ?)",
Vars: []interface{}{},
Result: "create table ? (? ?, ? ?)",
}, {
SQL: "name1 = @name AND name2 = @name;",
Vars: []interface{}{sql.Named("name", "jinzhu")},
Result: "name1 = ? AND name2 = ?;",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "name1 = @name1\r\n AND name2 = @name2",
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}},
Result: "name1 = ?\r\n AND name2 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "name1 = @name1\r AND name2 = @name2",
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}},
Result: "name1 = ?\r AND name2 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: "col"}},
Result: "`table`.`col`",
}, {
SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: "col", Raw: true}},
Result: "table.col",
}, {
SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: clause.PrimaryKey, Raw: true}},
Result: "table.id",
}, {
SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: "col", Alias: "alias"}},
Result: "`table`.`col` AS `alias`",
}, {
SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: "col", Alias: "alias", Raw: true}},
Result: "table.col AS alias",
}, {
SQL: "?",
Vars: []interface{}{clause.Table{Name: "table", Alias: "alias"}},
Result: "`table` `alias`",
}, {
SQL: "?",
Vars: []interface{}{clause.Table{Name: "table", Alias: "alias", Raw: true}},
Result: "table alias",
}}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clause.NamedExpr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
if stmt.SQL.String() != result.Result {
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
}
if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) {
t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars)
}
})
}
}
func TestExpression(t *testing.T) {
column := "column-name"
results := []struct {
Expressions []clause.Expression
ExpectedVars []interface{}
Result string
}{{
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: "column-value"},
},
ExpectedVars: []interface{}{"column-value"},
Result: "`column-name` = ?",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: nil},
clause.Eq{Column: column, Value: (*string)(nil)},
clause.Eq{Column: column, Value: (*int)(nil)},
clause.Eq{Column: column, Value: (*bool)(nil)},
clause.Eq{Column: column, Value: (interface{})(nil)},
clause.Eq{Column: column, Value: sql.NullString{String: "", Valid: false}},
},
Result: "`column-name` IS NULL",
}, {
Expressions: []clause.Expression{
clause.Neq{Column: column, Value: "column-value"},
},
ExpectedVars: []interface{}{"column-value"},
Result: "`column-name` <> ?",
}, {
Expressions: []clause.Expression{
clause.Neq{Column: column, Value: nil},
clause.Neq{Column: column, Value: (*string)(nil)},
clause.Neq{Column: column, Value: (*int)(nil)},
clause.Neq{Column: column, Value: (*bool)(nil)},
clause.Neq{Column: column, Value: (interface{})(nil)},
},
Result: "`column-name` IS NOT NULL",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: []string{"a", "b"}},
},
ExpectedVars: []interface{}{"a", "b"},
Result: "`column-name` IN (?,?)",
}, {
Expressions: []clause.Expression{
clause.Neq{Column: column, Value: []string{"a", "b"}},
},
ExpectedVars: []interface{}{"a", "b"},
Result: "`column-name` NOT IN (?,?)",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: []string{}},
},
Result: "`column-name` IN (NULL)",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100},
},
ExpectedVars: []interface{}{100},
Result: "SUM(`id`) = ?",
}, {
Expressions: []clause.Expression{
clause.Gte{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Table: "users", Name: "id"}}}, Value: 100},
},
ExpectedVars: []interface{}{100},
Result: "SUM(`users`.`id`) >= ?",
}}
for idx, result := range results {
for idy, expression := range result.Expressions {
t.Run(fmt.Sprintf("case #%v.%v", idx, idy), func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
expression.Build(stmt)
if stmt.SQL.String() != result.Result {
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
}
if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) {
t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars)
}
})
}
}
}


@ -33,9 +33,5 @@ func (from From) Build(builder Builder) {
// MergeClause merge from clause // MergeClause merge from clause
func (from From) MergeClause(clause *Clause) { func (from From) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(From); ok {
from.Tables = append(v.Tables, from.Tables...)
from.Joins = append(v.Joins, from.Joins...)
}
clause.Expression = from clause.Expression = from
} }


@ -38,6 +38,16 @@ func TestFrom(t *testing.T) {
[]clause.Interface{ []clause.Interface{
clause.Select{}, clause.From{ clause.Select{}, clause.From{
Tables: []clause.Table{{Name: "users"}}, Tables: []clause.Table{{Name: "users"}},
Joins: []clause.Join{
{
Type: clause.RightJoin,
Table: clause.Table{Name: "profiles"},
ON: clause.Where{
[]clause.Expression{clause.Eq{clause.Column{Table: "profiles", Name: "email"}, clause.Column{Table: clause.CurrentTable, Name: "email"}}},
},
},
},
}, clause.From{
Joins: []clause.Join{ Joins: []clause.Join{
{ {
Type: clause.InnerJoin, Type: clause.InnerJoin,
@ -51,19 +61,9 @@ func TestFrom(t *testing.T) {
Using: []string{"company_name"}, Using: []string{"company_name"},
}, },
}, },
}, clause.From{
Joins: []clause.Join{
{
Type: clause.RightJoin,
Table: clause.Table{Name: "profiles"},
ON: clause.Where{
[]clause.Expression{clause.Eq{clause.Column{Table: "profiles", Name: "email"}, clause.Column{Table: clause.CurrentTable, Name: "email"}}},
}, },
}, },
}, "SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id` LEFT JOIN `companies` USING (`company_name`)", nil,
},
},
"SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id` LEFT JOIN `companies` USING (`company_name`) RIGHT JOIN `profiles` ON `profiles`.`email` = `users`.`email`", nil,
}, },
} }


@ -30,8 +30,19 @@ func (groupBy GroupBy) Build(builder Builder) {
// MergeClause merge group by clause // MergeClause merge group by clause
func (groupBy GroupBy) MergeClause(clause *Clause) { func (groupBy GroupBy) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(GroupBy); ok { if v, ok := clause.Expression.(GroupBy); ok {
groupBy.Columns = append(v.Columns, groupBy.Columns...) copiedColumns := make([]Column, len(v.Columns))
groupBy.Having = append(v.Having, groupBy.Having...) copy(copiedColumns, v.Columns)
groupBy.Columns = append(copiedColumns, groupBy.Columns...)
copiedHaving := make([]Expression, len(v.Having))
copy(copiedHaving, v.Having)
groupBy.Having = append(copiedHaving, groupBy.Having...)
} }
clause.Expression = groupBy clause.Expression = groupBy
if len(groupBy.Columns) == 0 {
clause.Name = ""
} else {
clause.Name = groupBy.Name()
}
} }


@ -18,7 +18,8 @@ func TestGroupBy(t *testing.T) {
Columns: []clause.Column{{Name: "role"}}, Columns: []clause.Column{{Name: "role"}},
Having: []clause.Expression{clause.Eq{"role", "admin"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}},
}}, }},
"SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", []interface{}{"admin"}, "SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?",
[]interface{}{"admin"},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{ []clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{
@ -28,7 +29,8 @@ func TestGroupBy(t *testing.T) {
Columns: []clause.Column{{Name: "gender"}}, Columns: []clause.Column{{Name: "gender"}},
Having: []clause.Expression{clause.Neq{"gender", "U"}}, Having: []clause.Expression{clause.Neq{"gender", "U"}},
}}, }},
"SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", []interface{}{"admin", "U"}, "SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?",
[]interface{}{"admin", "U"},
}, },
} }


@ -1,15 +1,41 @@
package clause package clause
import "gorm.io/gorm/utils"
type JoinType string type JoinType string
const ( const (
CrossJoin JoinType = "CROSS" CrossJoin JoinType = "CROSS"
InnerJoin = "INNER" InnerJoin JoinType = "INNER"
LeftJoin = "LEFT" LeftJoin JoinType = "LEFT"
RightJoin = "RIGHT" RightJoin JoinType = "RIGHT"
) )
// Join join clause for from type JoinTarget struct {
Type JoinType
Association string
Subquery Expression
Table string
}
func Has(name string) JoinTarget {
return JoinTarget{Type: InnerJoin, Association: name}
}
func (jt JoinType) Association(name string) JoinTarget {
return JoinTarget{Type: jt, Association: name}
}
func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
return JoinTarget{Type: jt, Association: name, Subquery: subquery}
}
func (jt JoinTarget) As(name string) JoinTarget {
jt.Table = name
return jt
}
// Join clause for from
type Join struct { type Join struct {
Type JoinType Type JoinType
Table Table Table Table
@ -18,6 +44,12 @@ type Join struct {
Expression Expression Expression Expression
} }
func JoinTable(names ...string) Table {
return Table{
Name: utils.JoinNestedRelationNames(names),
}
}
func (join Join) Build(builder Builder) { func (join Join) Build(builder Builder) {
if join.Expression != nil { if join.Expression != nil {
join.Expression.Build(builder) join.Expression.Build(builder)

clause/joins_test.go (new file, 101 lines)

@ -0,0 +1,101 @@
package clause_test
import (
"sync"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
func TestJoin(t *testing.T) {
results := []struct {
name string
join clause.Join
sql string
}{
{
name: "LEFT JOIN",
join: clause.Join{
Type: clause.LeftJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "RIGHT JOIN",
join: clause.Join{
Type: clause.RightJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "INNER JOIN",
join: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "CROSS JOIN",
join: clause.Join{
Type: clause.CrossJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "USING",
join: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
Using: []string{"id"},
},
sql: "INNER JOIN `user` USING (`id`)",
},
{
name: "Expression",
join: clause.Join{
// Invalid
Type: clause.LeftJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
// Valid
Expression: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
Using: []string{"id"},
},
},
sql: "INNER JOIN `user` USING (`id`)",
},
}
for _, result := range results {
t.Run(result.name, func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
result.join.Build(stmt)
if result.sql != stmt.SQL.String() {
t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String())
}
})
}
}


@ -1,10 +1,8 @@
package clause package clause
import "strconv"
// Limit limit clause // Limit limit clause
type Limit struct { type Limit struct {
Limit int Limit *int
Offset int Offset int
} }
@ -15,14 +13,16 @@ func (limit Limit) Name() string {
// Build build where clause // Build build where clause
func (limit Limit) Build(builder Builder) { func (limit Limit) Build(builder Builder) {
if limit.Limit > 0 { if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteString("LIMIT ") builder.WriteString("LIMIT ")
builder.WriteString(strconv.Itoa(limit.Limit)) builder.AddVar(builder, *limit.Limit)
if limit.Offset > 0 {
builder.WriteString(" OFFSET ")
builder.WriteString(strconv.Itoa(limit.Offset))
} }
if limit.Offset > 0 {
if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteByte(' ')
}
builder.WriteString("OFFSET ")
builder.AddVar(builder, limit.Offset)
} }
} }
@ -31,10 +31,8 @@ func (limit Limit) MergeClause(clause *Clause) {
clause.Name = "" clause.Name = ""
if v, ok := clause.Expression.(Limit); ok { if v, ok := clause.Expression.(Limit); ok {
if limit.Limit == 0 && v.Limit > 0 { if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil {
limit.Limit = v.Limit limit.Limit = v.Limit
} else if limit.Limit < 0 {
limit.Limit = 0
} }
if limit.Offset == 0 && v.Offset > 0 { if limit.Offset == 0 && v.Offset > 0 {


@ -8,6 +8,10 @@ import (
) )
func TestLimit(t *testing.T) { func TestLimit(t *testing.T) {
limit0 := 0
limit10 := 10
limit50 := 50
limitNeg10 := -10
results := []struct { results := []struct {
Clauses []clause.Interface Clauses []clause.Interface
Result string Result string
@ -15,26 +19,56 @@ func TestLimit(t *testing.T) {
}{ }{
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{ []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{
Limit: 10, Limit: &limit10,
Offset: 20, Offset: 20,
}}, }},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, "SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit10, 20},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
"SELECT * FROM `users` LIMIT 10 OFFSET 30", nil, "SELECT * FROM `users` LIMIT ?",
[]interface{}{limit0},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}},
"SELECT * FROM `users` LIMIT 10", nil, "SELECT * FROM `users` LIMIT ?",
[]interface{}{limit0},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
"SELECT * FROM `users`", nil, "SELECT * FROM `users` OFFSET ?",
[]interface{}{20},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` LIMIT 50 OFFSET 30", nil, "SELECT * FROM `users` OFFSET ?",
[]interface{}{30},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}},
"SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit10, 20},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit10, 30},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
"SELECT * FROM `users` LIMIT ?",
[]interface{}{limit10},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}},
"SELECT * FROM `users` OFFSET ?",
[]interface{}{30},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}},
"SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit50, 30},
}, },
} }


@ -1,8 +1,11 @@
package clause package clause
type For struct { const (
Lockings []Locking LockingStrengthUpdate = "UPDATE"
} LockingStrengthShare = "SHARE"
LockingOptionsSkipLocked = "SKIP LOCKED"
LockingOptionsNoWait = "NOWAIT"
)
type Locking struct { type Locking struct {
Strength string Strength string
@ -11,18 +14,12 @@ type Locking struct {
} }
// Name where clause name // Name where clause name
func (f For) Name() string { func (locking Locking) Name() string {
return "FOR" return "FOR"
} }
// Build build where clause // Build build where clause
func (f For) Build(builder Builder) { func (locking Locking) Build(builder Builder) {
for idx, locking := range f.Lockings {
if idx > 0 {
builder.WriteByte(' ')
}
builder.WriteString("FOR ")
builder.WriteString(locking.Strength) builder.WriteString(locking.Strength)
if locking.Table.Name != "" { if locking.Table.Name != "" {
builder.WriteString(" OF ") builder.WriteString(" OF ")
@ -33,16 +30,9 @@ func (f For) Build(builder Builder) {
builder.WriteByte(' ') builder.WriteByte(' ')
builder.WriteString(locking.Options) builder.WriteString(locking.Options)
} }
}
} }
// MergeClause merge order by clauses // MergeClause merge order by clauses
func (f For) MergeClause(clause *Clause) { func (locking Locking) MergeClause(clause *Clause) {
clause.Name = "" clause.Expression = locking
if v, ok := clause.Expression.(For); ok {
f.Lockings = append(v.Lockings, f.Lockings...)
}
clause.Expression = f
} }
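
The For wrapper is gone: Locking is now a clause of its own with Strength/Options string constants, and MergeClause simply replaces any earlier locking clause instead of accumulating them. A usage sketch via db.Clauses, assuming a dialect that supports FOR UPDATE / FOR SHARE (e.g. MySQL or PostgreSQL) and an illustrative User model:

package example

import (
    "gorm.io/gorm"
    "gorm.io/gorm/clause"
)

type User struct {
    ID   uint
    Name string
}

func lockingForms(db *gorm.DB) {
    var users []User

    // SELECT ... FOR UPDATE
    db.Clauses(clause.Locking{Strength: clause.LockingStrengthUpdate}).Find(&users)

    // SELECT ... FOR SHARE OF `users`
    db.Clauses(clause.Locking{
        Strength: clause.LockingStrengthShare,
        Table:    clause.Table{Name: clause.CurrentTable},
    }).Find(&users)

    // SELECT ... FOR UPDATE NOWAIT -- error out instead of waiting on locked rows
    db.Clauses(clause.Locking{
        Strength: clause.LockingStrengthUpdate,
        Options:  clause.LockingOptionsNoWait,
    }).Find(&users)
}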


@ -7,31 +7,27 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
func TestFor(t *testing.T) { func TestLocking(t *testing.T) {
results := []struct { results := []struct {
Clauses []clause.Interface Clauses []clause.Interface
Result string Result string
Vars []interface{} Vars []interface{}
}{ }{
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.For{ []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate}},
Lockings: []clause.Locking{{Strength: "UPDATE"}},
}},
"SELECT * FROM `users` FOR UPDATE", nil, "SELECT * FROM `users` FOR UPDATE", nil,
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.For{ []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthShare, Table: clause.Table{Name: clause.CurrentTable}}},
Lockings: []clause.Locking{{Strength: "UPDATE"}, {Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}}, "SELECT * FROM `users` FOR SHARE OF `users`", nil,
}},
"SELECT * FROM `users` FOR UPDATE FOR SHARE OF `users`", nil,
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.For{ []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsNoWait}},
Lockings: []clause.Locking{{Strength: "UPDATE"}, {Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}}, "SELECT * FROM `users` FOR UPDATE NOWAIT", nil,
}, clause.For{ },
Lockings: []clause.Locking{{Strength: "UPDATE", Options: "NOWAIT"}}, {
}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsSkipLocked}},
"SELECT * FROM `users` FOR UPDATE FOR SHARE OF `users` FOR UPDATE NOWAIT", nil, "SELECT * FROM `users` FOR UPDATE SKIP LOCKED", nil,
}, },
} }


@ -3,8 +3,11 @@ package clause
type OnConflict struct { type OnConflict struct {
Columns []Column Columns []Column
Where Where Where Where
TargetWhere Where
OnConstraint string
DoNothing bool DoNothing bool
DoUpdates Set DoUpdates Set
UpdateAll bool
} }
func (OnConflict) Name() string { func (OnConflict) Name() string {
@ -13,16 +16,28 @@ func (OnConflict) Name() string {
// Build build onConflict clause // Build build onConflict clause
func (onConflict OnConflict) Build(builder Builder) { func (onConflict OnConflict) Build(builder Builder) {
if len(onConflict.Columns) > 0 { if onConflict.OnConstraint != "" {
builder.WriteQuoted(onConflict.Columns) // FIXME columns builder.WriteString("ON CONSTRAINT ")
builder.WriteString(onConflict.OnConstraint)
builder.WriteByte(' ') builder.WriteByte(' ')
} else {
if len(onConflict.Columns) > 0 {
builder.WriteByte('(')
for idx, column := range onConflict.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
builder.WriteString(`) `)
} }
if len(onConflict.Where.Exprs) > 0 { if len(onConflict.TargetWhere.Exprs) > 0 {
builder.WriteString("WHERE ") builder.WriteString(" WHERE ")
onConflict.Where.Build(builder) onConflict.TargetWhere.Build(builder)
builder.WriteByte(' ') builder.WriteByte(' ')
} }
}
if onConflict.DoNothing { if onConflict.DoNothing {
builder.WriteString("DO NOTHING") builder.WriteString("DO NOTHING")
@ -30,6 +45,12 @@ func (onConflict OnConflict) Build(builder Builder) {
builder.WriteString("DO UPDATE SET ") builder.WriteString("DO UPDATE SET ")
onConflict.DoUpdates.Build(builder) onConflict.DoUpdates.Build(builder)
} }
if len(onConflict.Where.Exprs) > 0 {
builder.WriteString(" WHERE ")
onConflict.Where.Build(builder)
builder.WriteByte(' ')
}
} }
// MergeClause merge onConflict clauses // MergeClause merge onConflict clauses
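
OnConflict gains OnConstraint (conflict target by constraint name), TargetWhere (a predicate on the conflict target, useful with partial indexes), and UpdateAll, alongside the existing DoNothing/DoUpdates. A hedged upsert sketch, assuming clause.AssignmentColumns from the same package, a dialect with upsert support, and an illustrative User model; UpdateAll is expanded into a full DoUpdates set by the create callbacks rather than by this clause itself:

package example

import (
    "gorm.io/gorm"
    "gorm.io/gorm/clause"
)

type User struct {
    ID   uint
    Name string
    Age  int
}

func upserts(db *gorm.DB, user *User) error {
    // roughly: ON CONFLICT (`id`) DO UPDATE SET `name` = ..., `age` = ...
    if err := db.Clauses(clause.OnConflict{
        Columns:   []clause.Column{{Name: "id"}},
        DoUpdates: clause.AssignmentColumns([]string{"name", "age"}),
    }).Create(user).Error; err != nil {
        return err
    }

    // ON CONFLICT DO NOTHING
    if err := db.Clauses(clause.OnConflict{DoNothing: true}).Create(user).Error; err != nil {
        return err
    }

    // update every column from the new value on conflict
    return db.Clauses(clause.OnConflict{UpdateAll: true}).Create(user).Error
}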


@ -8,6 +8,7 @@ type OrderByColumn struct {
type OrderBy struct { type OrderBy struct {
Columns []OrderByColumn Columns []OrderByColumn
Expression Expression
} }
// Name where clause name // Name where clause name
@ -17,6 +18,9 @@ func (orderBy OrderBy) Name() string {
// Build build where clause // Build build where clause
func (orderBy OrderBy) Build(builder Builder) { func (orderBy OrderBy) Build(builder Builder) {
if orderBy.Expression != nil {
orderBy.Expression.Build(builder)
} else {
for idx, column := range orderBy.Columns { for idx, column := range orderBy.Columns {
if idx > 0 { if idx > 0 {
builder.WriteByte(',') builder.WriteByte(',')
@ -27,6 +31,7 @@ func (orderBy OrderBy) Build(builder Builder) {
builder.WriteString(" DESC") builder.WriteString(" DESC")
} }
} }
}
} }
// MergeClause merge order by clauses // MergeClause merge order by clauses
@ -40,7 +45,9 @@ func (orderBy OrderBy) MergeClause(clause *Clause) {
} }
} }
orderBy.Columns = append(v.Columns, orderBy.Columns...) copiedColumns := make([]OrderByColumn, len(v.Columns))
copy(copiedColumns, v.Columns)
orderBy.Columns = append(copiedColumns, orderBy.Columns...)
} }
clause.Expression = orderBy clause.Expression = orderBy


@ -39,6 +39,15 @@ func TestOrderBy(t *testing.T) {
}, },
"SELECT * FROM `users` ORDER BY `name`", nil, "SELECT * FROM `users` ORDER BY `name`", nil,
}, },
{
[]clause.Interface{
clause.Select{}, clause.From{}, clause.OrderBy{
Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true},
},
},
"SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)",
[]interface{}{1, 2, 3},
},
} }
for idx, result := range results { for idx, result := range results {


@ -11,6 +11,7 @@ func (returning Returning) Name() string {
// Build build where clause
func (returning Returning) Build(builder Builder) {
	if len(returning.Columns) > 0 {
		for idx, column := range returning.Columns {
			if idx > 0 {
				builder.WriteByte(',')
@ -18,13 +19,19 @@ func (returning Returning) Build(builder Builder) {
			builder.WriteQuoted(column)
		}
	} else {
		builder.WriteByte('*')
	}
}

// MergeClause merge order by clauses
func (returning Returning) MergeClause(clause *Clause) {
-	if v, ok := clause.Expression.(Returning); ok {
+	if v, ok := clause.Expression.(Returning); ok && len(returning.Columns) > 0 {
+		if v.Columns != nil {
			returning.Columns = append(v.Columns, returning.Columns...)
+		} else {
+			returning.Columns = nil
+		}
	}

	clause.Expression = returning
}


@ -26,6 +26,22 @@ func TestReturning(t *testing.T) {
		}},
		"SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil,
}, },
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
[]clause.Column{clause.PrimaryColumn},
}, clause.Returning{}, clause.Returning{
[]clause.Column{{Name: "name"}, {Name: "age"}},
}},
"SELECT * FROM `users` RETURNING *", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
[]clause.Column{clause.PrimaryColumn},
}, clause.Returning{
[]clause.Column{{Name: "name"}, {Name: "age"}},
}, clause.Returning{}},
"SELECT * FROM `users` RETURNING *", nil,
},
	}

	for idx, result := range results {

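With the change above, a Returning clause with no Columns renders RETURNING *, and merging an empty Returning resets previously requested columns, which is what the two new test cases check. A sketch of the usual caller-side form on a dialect that supports RETURNING (PostgreSQL here; the model is illustrative):

// Delete and receive the deleted rows in one round trip.
var deleted []User
db.Clauses(clause.Returning{}).Where("age < ?", 18).Delete(&deleted)

// Or only specific columns:
db.Clauses(clause.Returning{
	Columns: []clause.Column{{Name: "id"}, {Name: "name"}},
}).Where("age < ?", 18).Delete(&deleted)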

@ -2,6 +2,7 @@ package clause
// Select select attrs when querying, updating, creating
type Select struct {
	Distinct   bool
	Columns    []Column
	Expression Expression
}

@ -12,6 +13,10 @@ func (s Select) Name() string {
func (s Select) Build(builder Builder) {
	if len(s.Columns) > 0 {
		if s.Distinct {
			builder.WriteString("DISTINCT ")
		}
		for idx, column := range s.Columns {
			if idx > 0 {
				builder.WriteByte(',')
@ -25,8 +30,30 @@ func (s Select) Build(builder Builder) {
func (s Select) MergeClause(clause *Clause) {
	if s.Expression != nil {
		if s.Distinct {
			if expr, ok := s.Expression.(Expr); ok {
				expr.SQL = "DISTINCT " + expr.SQL
				clause.Expression = expr
				return
			}
		}
		clause.Expression = s.Expression
	} else {
		clause.Expression = s
	}
}

// CommaExpression represents a group of expressions separated by commas.
type CommaExpression struct {
	Exprs []Expression
}

func (comma CommaExpression) Build(builder Builder) {
	for idx, expr := range comma.Exprs {
		if idx > 0 {
			_, _ = builder.WriteString(", ")
		}
		expr.Build(builder)
	}
}


@ -31,6 +31,37 @@ func TestSelect(t *testing.T) {
		}, clause.From{}},
		"SELECT `name` FROM `users`", nil,
}, },
{
[]clause.Interface{clause.Select{
Expression: clause.CommaExpression{
Exprs: []clause.Expression{
clause.NamedExpr{"?", []interface{}{clause.Column{Name: "id"}}},
clause.NamedExpr{"?", []interface{}{clause.Column{Name: "name"}}},
clause.NamedExpr{"LENGTH(?)", []interface{}{clause.Column{Name: "mobile"}}},
},
},
}, clause.From{}},
"SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil,
},
{
[]clause.Interface{clause.Select{
Expression: clause.CommaExpression{
Exprs: []clause.Expression{
clause.Expr{
SQL: "? as name",
Vars: []interface{}{
clause.Eq{
Column: clause.Column{Name: "age"},
Value: 18,
},
},
},
},
},
}, clause.From{}},
"SELECT `age` = ? as name FROM `users`",
[]interface{}{18},
},
	}

	for idx, result := range results {

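The Distinct flag and CommaExpression added above are what the finisher APIs lean on: CommaExpression is mostly assembled internally (as in the tests), while Distinct is usually reached through the DB API. A small sketch with an assumed User model:

// SELECT DISTINCT `name`,`age` FROM `users`
var users []User
db.Distinct("name", "age").Find(&users)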

@ -1,5 +1,7 @@
package clause

import "sort"

type Set []Assignment

type Assignment struct {
@ -22,13 +24,37 @@ func (set Set) Build(builder Builder) {
			builder.AddVar(builder, assignment.Value)
		}
	} else {
-		builder.WriteQuoted(PrimaryColumn)
+		builder.WriteQuoted(Column{Name: PrimaryKey})
		builder.WriteByte('=')
-		builder.WriteQuoted(PrimaryColumn)
+		builder.WriteQuoted(Column{Name: PrimaryKey})
	}
}

// MergeClause merge assignments clauses
func (set Set) MergeClause(clause *Clause) {
-	clause.Expression = set
+	copiedAssignments := make([]Assignment, len(set))
+	copy(copiedAssignments, set)
+	clause.Expression = Set(copiedAssignments)
}

func Assignments(values map[string]interface{}) Set {
	keys := make([]string, 0, len(values))
	for key := range values {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	assignments := make([]Assignment, len(keys))
	for idx, key := range keys {
		assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]}
	}
	return assignments
}

func AssignmentColumns(values []string) Set {
	assignments := make([]Assignment, len(values))
	for idx, value := range values {
		assignments[idx] = Assignment{Column: Column{Name: value}, Value: Column{Table: "excluded", Name: value}}
	}
	return assignments
}

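The new Assignments and AssignmentColumns helpers turn a map or a column list into a Set; AssignmentColumns references the excluded pseudo-table, which is what an upsert's DO UPDATE needs. A hedged sketch, with db, users and the column names assumed:

// Take name/age from the row that failed to insert.
db.Clauses(clause.OnConflict{
	Columns:   []clause.Column{{Name: "id"}},
	DoUpdates: clause.AssignmentColumns([]string{"name", "age"}),
}).Create(&users)

// Or write fixed values on conflict (map keys are sorted, so the generated SQL is stable).
db.Clauses(clause.OnConflict{
	Columns:   []clause.Column{{Name: "id"}},
	DoUpdates: clause.Assignments(map[string]interface{}{"role": "user"}),
}).Create(&users)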

@ -2,6 +2,8 @@ package clause_test
import (
	"fmt"
	"sort"
	"strings"
	"testing"

	"gorm.io/gorm/clause"
@ -18,7 +20,8 @@ func TestSet(t *testing.T) {
				clause.Update{},
				clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}),
			},
			"UPDATE `users` SET `users`.`id`=?",
			[]interface{}{1},
		},
		{
			[]clause.Interface{
@ -26,7 +29,8 @@ func TestSet(t *testing.T) {
				clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}),
				clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}),
			},
			"UPDATE `users` SET `name`=?",
			[]interface{}{"jinzhu"},
		},
	}
@ -36,3 +40,20 @@ func TestSet(t *testing.T) {
		})
	}
}
func TestAssignments(t *testing.T) {
set := clause.Assignments(map[string]interface{}{
"name": "jinzhu",
"age": 18,
})
assignments := []clause.Assignment(set)
sort.Slice(assignments, func(i, j int) bool {
return strings.Compare(assignments[i].Column.Name, assignments[j].Column.Name) > 0
})
if len(assignments) != 2 || assignments[0].Column.Name != "name" || assignments[0].Value.(string) != "jinzhu" || assignments[1].Column.Name != "age" || assignments[1].Value.(int) != 18 {
t.Errorf("invalid assignments, got %v", assignments)
}
}


@ -21,7 +21,8 @@ func TestValues(t *testing.T) {
				Values: [][]interface{}{{"jinzhu", 18}, {"josh", 1}},
			},
		},
		"INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)",
		[]interface{}{"jinzhu", 18, "josh", 1},
	},
} }


@ -1,5 +1,14 @@
package clause

import (
	"strings"
)

const (
	AndWithSpace = " AND "
	OrWithSpace  = " OR "
)

// Where where clause
type Where struct {
	Exprs []Expression
@ -12,9 +21,15 @@ func (where Where) Name() string {
// Build build where clause
func (where Where) Build(builder Builder) {
	if len(where.Exprs) == 1 {
		if andCondition, ok := where.Exprs[0].(AndConditions); ok {
			where.Exprs = andCondition.Exprs
		}
	}

	// Switch position if the first query expression is a single Or condition
	for idx, expr := range where.Exprs {
-		if v, ok := expr.(OrConditions); (!ok && expr != nil) || len(v.Exprs) > 1 {
+		if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
			if idx != 0 {
				where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0]
			}
@ -22,27 +37,64 @@ func (where Where) Build(builder Builder) {
} }
} }
for idx, expr := range where.Exprs { buildExprs(where.Exprs, builder, AndWithSpace)
if expr != nil { }
func buildExprs(exprs []Expression, builder Builder, joinCond string) {
wrapInParentheses := false
for idx, expr := range exprs {
if idx > 0 { if idx > 0 {
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
builder.WriteString(" OR ") builder.WriteString(OrWithSpace)
} else { } else {
builder.WriteString(" AND ") builder.WriteString(joinCond)
} }
} }
if len(exprs) > 1 {
switch v := expr.(type) {
case OrConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToUpper(e.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
}
}
case AndConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToUpper(e.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
}
}
case Expr:
sql := strings.ToUpper(v.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
case NamedExpr:
sql := strings.ToUpper(v.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
}
}
if wrapInParentheses {
builder.WriteByte('(')
expr.Build(builder)
builder.WriteByte(')')
wrapInParentheses = false
} else {
expr.Build(builder) expr.Build(builder)
} }
} }
return
} }
// MergeClause merge where clauses // MergeClause merge where clauses
func (where Where) MergeClause(clause *Clause) { func (where Where) MergeClause(clause *Clause) {
if w, ok := clause.Expression.(Where); ok { if w, ok := clause.Expression.(Where); ok {
where.Exprs = append(w.Exprs, where.Exprs...) exprs := make([]Expression, len(w.Exprs)+len(where.Exprs))
copy(exprs, w.Exprs)
copy(exprs[len(w.Exprs):], where.Exprs)
where.Exprs = exprs
} }
clause.Expression = where clause.Expression = where
@ -52,6 +104,13 @@ func And(exprs ...Expression) Expression {
if len(exprs) == 0 { if len(exprs) == 0 {
return nil return nil
} }
if len(exprs) == 1 {
if _, ok := exprs[0].(OrConditions); !ok {
return exprs[0]
}
}
return AndConditions{Exprs: exprs} return AndConditions{Exprs: exprs}
} }
@ -62,15 +121,10 @@ type AndConditions struct {
func (and AndConditions) Build(builder Builder) { func (and AndConditions) Build(builder Builder) {
if len(and.Exprs) > 1 { if len(and.Exprs) > 1 {
builder.WriteByte('(') builder.WriteByte('(')
} buildExprs(and.Exprs, builder, AndWithSpace)
for idx, c := range and.Exprs {
if idx > 0 {
builder.WriteString(" AND ")
}
c.Build(builder)
}
if len(and.Exprs) > 1 {
builder.WriteByte(')') builder.WriteByte(')')
} else {
buildExprs(and.Exprs, builder, AndWithSpace)
} }
} }
@ -88,15 +142,10 @@ type OrConditions struct {
func (or OrConditions) Build(builder Builder) { func (or OrConditions) Build(builder Builder) {
if len(or.Exprs) > 1 { if len(or.Exprs) > 1 {
builder.WriteByte('(') builder.WriteByte('(')
} buildExprs(or.Exprs, builder, OrWithSpace)
for idx, c := range or.Exprs {
if idx > 0 {
builder.WriteString(" OR ")
}
c.Build(builder)
}
if len(or.Exprs) > 1 {
builder.WriteByte(')') builder.WriteByte(')')
} else {
buildExprs(or.Exprs, builder, OrWithSpace)
} }
} }
@ -104,6 +153,11 @@ func Not(exprs ...Expression) Expression {
if len(exprs) == 0 { if len(exprs) == 0 {
return nil return nil
} }
if len(exprs) == 1 {
if andCondition, ok := exprs[0].(AndConditions); ok {
exprs = andCondition.Exprs
}
}
return NotConditions{Exprs: exprs} return NotConditions{Exprs: exprs}
} }
@ -112,22 +166,80 @@ type NotConditions struct {
} }
func (not NotConditions) Build(builder Builder) { func (not NotConditions) Build(builder Builder) {
anyNegationBuilder := false
for _, c := range not.Exprs {
if _, ok := c.(NegationExpressionBuilder); ok {
anyNegationBuilder = true
break
}
}
if anyNegationBuilder {
if len(not.Exprs) > 1 { if len(not.Exprs) > 1 {
builder.WriteByte('(') builder.WriteByte('(')
} }
for idx, c := range not.Exprs { for idx, c := range not.Exprs {
if idx > 0 { if idx > 0 {
builder.WriteString(" AND ") builder.WriteString(AndWithSpace)
} }
if negationBuilder, ok := c.(NegationExpressionBuilder); ok { if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder) negationBuilder.NegationBuild(builder)
} else { } else {
builder.WriteString(" NOT ") builder.WriteString("NOT ")
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToUpper(e.SQL)
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
builder.WriteByte('(')
}
}
c.Build(builder) c.Build(builder)
if wrapInParentheses {
builder.WriteByte(')')
} }
} }
}
if len(not.Exprs) > 1 { if len(not.Exprs) > 1 {
builder.WriteByte(')') builder.WriteByte(')')
} }
} else {
builder.WriteString("NOT ")
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}
for idx, c := range not.Exprs {
if idx > 0 {
switch c.(type) {
case OrConditions:
builder.WriteString(OrWithSpace)
default:
builder.WriteString(AndWithSpace)
}
}
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToUpper(e.SQL)
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
builder.WriteByte('(')
}
}
c.Build(builder)
if wrapInParentheses {
builder.WriteByte(')')
}
}
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
}
} }


@ -17,25 +17,29 @@ func TestWhere(t *testing.T) {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{ []clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
}}, }},
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?", []interface{}{"1", 18, "jinzhu"}, "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?",
[]interface{}{"1", 18, "jinzhu"},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{ []clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
}}, }},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18}, "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?",
[]interface{}{"1", "jinzhu", 18},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{ []clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
}}, }},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18}, "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?",
[]interface{}{"1", "jinzhu", 18},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{ []clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, Exprs: []clause.Expression{clause.Or(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
}}, }},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?", []interface{}{"1", "jinzhu"}, "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?",
[]interface{}{"1", "jinzhu"},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{ []clause.Interface{clause.Select{}, clause.From{}, clause.Where{
@ -43,7 +47,8 @@ func TestWhere(t *testing.T) {
}, clause.Where{ }, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"})}, Exprs: []clause.Expression{clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"})},
}}, }},
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)",
[]interface{}{"1", 18, "jinzhu", 100, "%linus%"},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{ []clause.Interface{clause.Select{}, clause.From{}, clause.Where{
@ -51,7 +56,78 @@ func TestWhere(t *testing.T) {
}, clause.Where{ }, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Not(clause.Gt{Column: "score", Value: 100}), clause.Like{Column: "name", Value: "%linus%"})}, Exprs: []clause.Expression{clause.Or(clause.Not(clause.Gt{Column: "score", Value: 100}), clause.Like{Column: "name", Value: "%linus%"})},
}}, }},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)",
[]interface{}{"1", 18, "jinzhu", 100, "%linus%"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))},
}},
"SELECT * FROM `users` WHERE `age` = ? OR `name` <> ?",
[]interface{}{18, "jinzhu"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{
clause.And(clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}),
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
},
}},
"SELECT * FROM `users` WHERE `users`.`id` <> ? AND `score` <= ?",
[]interface{}{"1", 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}))},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
[]interface{}{"1", 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}},
clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})},
}},
"SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)",
[]interface{}{100, 60},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{
clause.Not(clause.AndConditions{
Exprs: []clause.Expression{
clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.Gt{Column: "age", Value: 18},
}}, clause.OrConditions{
Exprs: []clause.Expression{
clause.Lt{Column: "score", Value: 100},
},
}),
}}},
"SELECT * FROM `users` WHERE NOT ((`users`.`id` = ? AND `age` > ?) OR `score` < ?)",
[]interface{}{"1", 18, 100},
}, },
} }

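buildExprs wraps raw SQL that itself contains AND/OR in parentheses, and And/Not flatten single-expression arguments, which is what the grouping in the new test cases reflects. At the DB level the same shapes are normally produced with nested conditions; a sketch with an assumed User model:

// WHERE name = ? AND (age > ? OR role = ?)
var users []User
db.Where("name = ?", "jinzhu").
	Where(db.Where("age > ?", 18).Or("role = ?", "admin")).
	Find(&users)

// Raw SQL containing AND is parenthesized under NOT:
// WHERE NOT (score <= ? AND age <= ?)
db.Not("score <= ? AND age <= ?", 100, 60).Find(&users)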

@ -1,4 +1,3 @@
package clause

-type With struct {
-}
+type With struct{}


@ -2,25 +2,53 @@ package gorm
import (
	"errors"

	"gorm.io/gorm/logger"
)

var (
	// ErrRecordNotFound record not found error
-	ErrRecordNotFound = errors.New("record not found")
+	ErrRecordNotFound = logger.ErrRecordNotFound
-	// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
-	ErrInvalidSQL = errors.New("invalid SQL")
	// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
-	ErrInvalidTransaction = errors.New("no valid transaction")
+	ErrInvalidTransaction = errors.New("invalid transaction")
-	// ErrUnaddressable unaddressable value
-	ErrUnaddressable = errors.New("using unaddressable value")
	// ErrNotImplemented not implemented
	ErrNotImplemented = errors.New("not implemented")
	// ErrMissingWhereClause missing where clause
	ErrMissingWhereClause = errors.New("WHERE conditions required")
	// ErrUnsupportedRelation unsupported relations
	ErrUnsupportedRelation = errors.New("unsupported relations")
-	// ErrPtrStructSupported only ptr of struct supported
-	ErrPtrStructSupported = errors.New("only ptr of struct supported")
-	// ErrorPrimaryKeyRequired primary keys required
-	ErrorPrimaryKeyRequired = errors.New("primary key required")
+	// ErrPrimaryKeyRequired primary keys required
+	ErrPrimaryKeyRequired = errors.New("primary key required")
+	// ErrModelValueRequired model value required
+	ErrModelValueRequired = errors.New("model value required")
	// ErrModelAccessibleFieldsRequired model accessible fields required
	ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required")
	// ErrSubQueryRequired sub query required
	ErrSubQueryRequired = errors.New("sub query required")
	// ErrInvalidData unsupported data
	ErrInvalidData = errors.New("unsupported data")
	// ErrUnsupportedDriver unsupported driver
	ErrUnsupportedDriver = errors.New("unsupported driver")
	// ErrRegistered registered
	ErrRegistered = errors.New("registered")
	// ErrInvalidField invalid field
	ErrInvalidField = errors.New("invalid field")
	// ErrEmptySlice empty slice found
	ErrEmptySlice = errors.New("empty slice found")
	// ErrDryRunModeUnsupported dry run mode unsupported
	ErrDryRunModeUnsupported = errors.New("dry run mode unsupported")
	// ErrInvalidDB invalid db
	ErrInvalidDB = errors.New("invalid db")
	// ErrInvalidValue invalid value
	ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice")
	// ErrInvalidValueOfLength invalid values do not match length
	ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
	// ErrPreloadNotAllowed preload is not allowed when count is used
	ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
	// ErrDuplicatedKey occurs when there is a unique key constraint violation
	ErrDuplicatedKey = errors.New("duplicated key not allowed")
	// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
	ErrForeignKeyViolated = errors.New("violates foreign key constraint")
	// ErrCheckConstraintViolated occurs when there is a check constraint violation
	ErrCheckConstraintViolated = errors.New("violates check constraint")
)

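The widened error list gives callers stable sentinels to branch on; the constraint-violation errors are produced when the dialector's error translation is enabled (TranslateError). A sketch, with fmt and errors imports and a User model assumed:

func createUser(db *gorm.DB, u *User) error {
	err := db.Create(u).Error
	switch {
	case err == nil:
		return nil
	case errors.Is(err, gorm.ErrDuplicatedKey):
		return fmt.Errorf("user already exists: %w", err)
	case errors.Is(err, gorm.ErrForeignKeyViolated):
		return fmt.Errorf("referenced record missing: %w", err)
	default:
		return err
	}
}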

@ -1,183 +1,394 @@
package gorm package gorm
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt"
"hash/maphash"
"reflect" "reflect"
"strings" "strings"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
// Create insert the value into database // Create inserts value, returning the inserted data's primary key in value's id
func (db *DB) Create(value interface{}) (tx *DB) { func (db *DB) Create(value interface{}) (tx *DB) {
if db.CreateBatchSize > 0 {
return db.CreateInBatches(value, db.CreateBatchSize)
}
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = value tx.Statement.Dest = value
tx.callbacks.Create().Execute(tx) return tx.callbacks.Create().Execute(tx)
}
// CreateInBatches inserts value in batches of batchSize
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
var rowsAffected int64
tx = db.getInstance()
// the reflection length judgment of the optimized value
reflectLen := reflectValue.Len()
callFc := func(tx *DB) error {
for i := 0; i < reflectLen; i += batchSize {
ends := i + batchSize
if ends > reflectLen {
ends = reflectLen
}
subtx := tx.getInstance()
subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface()
subtx.callbacks.Create().Execute(subtx)
if subtx.Error != nil {
return subtx.Error
}
rowsAffected += subtx.RowsAffected
}
return nil
}
if tx.SkipDefaultTransaction || reflectLen <= batchSize {
tx.AddError(callFc(tx.Session(&Session{})))
} else {
tx.AddError(tx.Transaction(callFc))
}
tx.RowsAffected = rowsAffected
default:
tx = db.getInstance()
tx.Statement.Dest = value
tx = tx.callbacks.Create().Execute(tx)
}
return return
} }
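CreateInBatches above slices the destination and, unless SkipDefaultTransaction is set or everything fits in one batch, wraps the inserts in a single transaction, accumulating RowsAffected across batches. A usage sketch (User model, fmt and log imports assumed):

users := make([]User, 0, 10000)
for i := 0; i < 10000; i++ {
	users = append(users, User{Name: fmt.Sprintf("user-%d", i)})
}

// Insert in chunks of 500 rows per INSERT statement.
if err := db.CreateInBatches(&users, 500).Error; err != nil {
	log.Fatal(err)
}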
// Save update value in database, if the value doesn't have primary key, will insert it // Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
func (db *DB) Save(value interface{}) (tx *DB) { func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = value tx.Statement.Dest = value
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
reflectValue := reflect.Indirect(reflect.ValueOf(value)) reflectValue := reflect.Indirect(reflect.ValueOf(value))
for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface {
reflectValue = reflect.Indirect(reflectValue)
}
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
tx.AddError(ErrPtrStructSupported) if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
}
tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true))
case reflect.Struct: case reflect.Struct:
for idx, pf := range tx.Statement.Schema.PrimaryFields { if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
if pv, isZero := pf.ValueOf(reflectValue); isZero { for _, pf := range tx.Statement.Schema.PrimaryFields {
tx.callbacks.Create().Execute(tx) if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero {
where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} return tx.callbacks.Create().Execute(tx)
return
} }
} }
} }
tx.Statement.AddClause(where) fallthrough
} default:
selectedUpdate := len(tx.Statement.Selects) != 0
if len(tx.Statement.Selects) == 0 { // when updating, use all fields including those zero-value fields
if !selectedUpdate {
tx.Statement.Selects = append(tx.Statement.Selects, "*") tx.Statement.Selects = append(tx.Statement.Selects, "*")
} }
tx.callbacks.Update().Execute(tx)
updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true}))
if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate {
return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
}
return updateTx
}
return return
} }
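Save now behaves as an upsert: a struct whose primary key is zero is created; otherwise every field is updated, and if that UPDATE matches no rows (and no explicit Select limits the fields) it falls back to a Create with ON CONFLICT ... UpdateAll. A sketch with an assumed User model:

user := User{Name: "jinzhu"}
db.Save(&user) // INSERT; user.ID is populated afterwards

user.Name = "jinzhu 2"
db.Save(&user) // UPDATE ... WHERE id = user.ID, writing all fields

// Slices go straight to Create with an implicit ON CONFLICT DO UPDATE (UpdateAll).
db.Save(&[]User{{ID: 1, Name: "a"}, {ID: 2, Name: "b"}})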
// First find first record that match given conditions, order by primary key // First finds the first record ordered by primary key, matching given conditions conds
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}) })
if len(conds) > 0 { if len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
} }
tx.Statement.RaiseErrorOnNotFound = true tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
return
} }
// Take return a record that match given conditions, the order will depend on the database implementation // Take finds the first record returned by the database in no specified order, matching given conditions conds
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance().Limit(1) tx = db.Limit(1)
if len(conds) > 0 { if len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
} }
tx.Statement.RaiseErrorOnNotFound = true tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
return
} }
// Last find last record that match given conditions, order by primary key // Last finds the last record ordered by primary key, matching given conditions conds
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
Desc: true, Desc: true,
}) })
if len(conds) > 0 { if len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
} }
tx.Statement.RaiseErrorOnNotFound = true tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
return
} }
// Find find records that match given conditions // Find finds all records matching given conditions conds
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(conds) > 0 { if len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
} }
tx.Statement.Dest = dest tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
return
} }
func (tx *DB) assignExprsToValue(exprs []clause.Expression) { // FindInBatches finds all records in batches of batchSize
for _, expr := range exprs { func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
var (
tx = db.Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}).Session(&Session{})
queryDB = tx
rowsAffected int64
batch int
)
// user specified offset or limit
var totalSize int
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
if limit, ok := c.Expression.(clause.Limit); ok {
if limit.Limit != nil {
totalSize = *limit.Limit
}
if totalSize > 0 && batchSize > totalSize {
batchSize = totalSize
}
// reset to offset to 0 in next batch
tx = tx.Offset(-1).Session(&Session{})
}
}
for {
result := queryDB.Limit(batchSize).Find(dest)
rowsAffected += result.RowsAffected
batch++
if result.Error == nil && result.RowsAffected != 0 {
fcTx := result.Session(&Session{NewDB: true})
fcTx.RowsAffected = result.RowsAffected
tx.AddError(fc(fcTx, batch))
} else if result.Error != nil {
tx.AddError(result.Error)
}
if tx.Error != nil || int(result.RowsAffected) < batchSize {
break
}
if totalSize > 0 {
if totalSize <= int(rowsAffected) {
break
}
if totalSize/batchSize == batch {
batchSize = totalSize % batchSize
}
}
// Optimize for-break
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
if result.Statement.Schema.PrioritizedPrimaryField == nil {
tx.AddError(ErrPrimaryKeyRequired)
break
}
primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
if zero {
tx.AddError(ErrPrimaryKeyRequired)
break
}
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
}
tx.RowsAffected = rowsAffected
return tx
}
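FindInBatches pages by primary key (each round adds a "greater than last seen id" condition) instead of using OFFSET, and treats a caller-supplied LIMIT as a total cap. A sketch; the Processed field and log import are illustrative:

var batch []User
result := db.Where("processed = ?", false).FindInBatches(&batch, 200, func(tx *gorm.DB, n int) error {
	for i := range batch {
		batch[i].Processed = true
	}
	// Returning an error stops the remaining batches and is reported on result.
	return tx.Save(&batch).Error
})
log.Printf("handled %d rows in total", result.RowsAffected)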
func (db *DB) assignInterfacesToValue(values ...interface{}) {
for _, value := range values {
switch v := value.(type) {
case []clause.Expression:
for _, expr := range v {
if eq, ok := expr.(clause.Eq); ok { if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) { switch column := eq.Column.(type) {
case string: case string:
if field := tx.Statement.Schema.LookUpField(column); field != nil { if field := db.Statement.Schema.LookUpField(column); field != nil {
field.Set(tx.Statement.ReflectValue, eq.Value) db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
} }
case clause.Column: case clause.Column:
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { if field := db.Statement.Schema.LookUpField(column.Name); field != nil {
field.Set(tx.Statement.ReflectValue, eq.Value) db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
}
}
} else if andCond, ok := expr.(clause.AndConditions); ok {
db.assignInterfacesToValue(andCond.Exprs)
}
}
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 {
db.assignInterfacesToValue(exprs)
} }
default: default:
if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:
for _, f := range s.Fields {
if f.Readable {
if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero {
if field := db.Statement.Schema.LookUpField(f.Name); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v))
}
}
}
}
}
} else if len(values) > 0 {
if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
db.assignInterfacesToValue(exprs)
}
return
} }
} }
} }
} }
// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
// Each conds must be a struct or map.
//
// FirstOrInit never modifies the database. It is often used with Assign and Attrs.
//
// // assign an email if the record is not found
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign email regardless of if record is found
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() queryTx := db.Limit(1).Order(clause.OrderByColumn{
if tx = tx.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
if c, ok := tx.Statement.Clauses["WHERE"]; ok { if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok { if where, ok := c.Expression.(clause.Where); ok {
tx.assignExprsToValue(where.Exprs) tx.assignInterfacesToValue(where.Exprs)
} }
} }
// initialize with attrs, conds // initialize with attrs, conds
if len(tx.Statement.attrs) > 0 { if len(tx.Statement.attrs) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) tx.assignInterfacesToValue(tx.Statement.attrs...)
tx.assignExprsToValue(exprs)
} }
tx.Error = nil
} }
// initialize with attrs, conds // initialize with attrs, conds
if len(tx.Statement.assigns) > 0 { if len(tx.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) tx.assignInterfacesToValue(tx.Statement.assigns...)
tx.assignExprsToValue(exprs)
} }
return return
} }
// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
// Each conds must be a struct or map.
//
// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists.
//
// // assign an email if the record is not found
// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
// // result.RowsAffected -> 1
//
// // assign email regardless of if record is found
// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
// // result.RowsAffected -> 1
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if err := tx.First(dest, conds...).Error; errors.Is(err, ErrRecordNotFound) { queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
tx.Error = nil Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
if c, ok := tx.Statement.Clauses["WHERE"]; ok { result := queryTx.Find(dest, conds...)
if result.Error != nil {
tx.Error = result.Error
return tx
}
if result.RowsAffected == 0 {
if c, ok := result.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok { if where, ok := c.Expression.(clause.Where); ok {
tx.assignExprsToValue(where.Exprs) result.assignInterfacesToValue(where.Exprs)
} }
} }
// initialize with attrs, conds // initialize with attrs, conds
if len(tx.Statement.attrs) > 0 { if len(db.Statement.attrs) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) result.assignInterfacesToValue(db.Statement.attrs...)
tx.assignExprsToValue(exprs)
} }
// initialize with attrs, conds // initialize with attrs, conds
if len(tx.Statement.assigns) > 0 { if len(db.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) result.assignInterfacesToValue(db.Statement.assigns...)
tx.assignExprsToValue(exprs)
} }
return tx.Create(dest) return tx.Create(dest)
} else if len(tx.Statement.assigns) > 0 { } else if len(db.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]) exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
assigns := map[string]interface{}{} assigns := map[string]interface{}{}
for _, expr := range exprs { for i := 0; i < len(exprs); i++ {
if eq, ok := expr.(clause.Eq); ok { expr := exprs[i]
if eq, ok := expr.(clause.AndConditions); ok {
exprs = append(exprs, eq.Exprs...)
} else if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) { switch column := eq.Column.(type) {
case string: case string:
assigns[column] = eq.Value assigns[column] = eq.Value
case clause.Column: case clause.Column:
assigns[column.Name] = eq.Value assigns[column.Name] = eq.Value
default:
} }
} }
} }
@ -185,113 +396,254 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.Model(dest).Updates(assigns) return tx.Model(dest).Updates(assigns)
} }
return return tx
} }
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update // Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Update(column string, value interface{}) (tx *DB) { func (db *DB) Update(column string, value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value} tx.Statement.Dest = map[string]interface{}{column: value}
tx.callbacks.Update().Execute(tx) return tx.callbacks.Update().Execute(tx)
return
} }
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update // Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Updates(values interface{}) (tx *DB) { func (db *DB) Updates(values interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = values tx.Statement.Dest = values
tx.callbacks.Update().Execute(tx) return tx.callbacks.Update().Execute(tx)
return
} }
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value} tx.Statement.Dest = map[string]interface{}{column: value}
tx.Statement.DisableUpdateTime = true tx.Statement.SkipHooks = true
tx.callbacks.Update().Execute(tx) return tx.callbacks.Update().Execute(tx)
return
} }
func (db *DB) UpdateColumns(values interface{}) (tx *DB) { func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = values tx.Statement.Dest = values
tx.Statement.DisableUpdateTime = true tx.Statement.SkipHooks = true
tx.callbacks.Update().Execute(tx) return tx.callbacks.Update().Execute(tx)
return
} }
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition // Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
// time if null.
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(conds) > 0 { if len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
} }
tx.Statement.Dest = value tx.Statement.Dest = value
tx.callbacks.Delete().Execute(tx) return tx.callbacks.Delete().Execute(tx)
return
} }
func (db *DB) Count(count *int64) (tx *DB) { func (db *DB) Count(count *int64) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if s, ok := tx.Statement.Clauses["SELECT"].Expression.(clause.Select); !ok || len(s.Columns) == 0 {
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
}
if tx.Statement.Model == nil { if tx.Statement.Model == nil {
tx.Statement.Model = tx.Statement.Dest tx.Statement.Model = tx.Statement.Dest
defer func() {
tx.Statement.Model = nil
}()
} }
if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
defer func() {
tx.Statement.Clauses["SELECT"] = selectClause
}()
} else {
defer delete(tx.Statement.Clauses, "SELECT")
}
if len(tx.Statement.Selects) == 0 {
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}})
} else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") {
expr := clause.Expr{SQL: "count(*)"}
if len(tx.Statement.Selects) == 1 {
dbName := tx.Statement.Selects[0]
fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) {
if tx.Statement.Parse(tx.Statement.Model) == nil {
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
dbName = f.DBName
}
}
if tx.Statement.Distinct {
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
} else if dbName != "*" {
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
}
}
}
tx.Statement.AddClause(clause.Select{Expression: expr})
}
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
delete(tx.Statement.Clauses, "ORDER BY")
defer func() {
tx.Statement.Clauses["ORDER BY"] = orderByClause
}()
}
}
tx.Statement.Dest = count tx.Statement.Dest = count
tx.callbacks.Query().Execute(tx) tx = tx.callbacks.Query().Execute(tx)
if db.RowsAffected != 1 {
*count = db.RowsAffected if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 {
*count = tx.RowsAffected
} }
return return
} }
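Count now substitutes count(*) (or COUNT(DISTINCT(col)) when Distinct is set with a single selected column), temporarily drops ORDER BY, and restores the original SELECT clause afterwards. A short sketch, model assumed:

var total, distinctNames int64
db.Model(&User{}).Where("age > ?", 18).Count(&total)
db.Model(&User{}).Distinct("name").Count(&distinctNames) // COUNT(DISTINCT(`name`))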
func (db *DB) Row() *sql.Row { func (db *DB) Row() *sql.Row {
tx := db.getInstance() tx := db.getInstance().Set("rows", false)
tx.callbacks.Row().Execute(tx) tx = tx.callbacks.Row().Execute(tx)
return tx.Statement.Dest.(*sql.Row) row, ok := tx.Statement.Dest.(*sql.Row)
if !ok && tx.DryRun {
db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
}
return row
} }
func (db *DB) Rows() (*sql.Rows, error) { func (db *DB) Rows() (*sql.Rows, error) {
tx := db.Set("rows", true) tx := db.getInstance().Set("rows", true)
tx.callbacks.Row().Execute(tx) tx = tx.callbacks.Row().Execute(tx)
return tx.Statement.Dest.(*sql.Rows), tx.Error rows, ok := tx.Statement.Dest.(*sql.Rows)
if !ok && tx.DryRun && tx.Error == nil {
tx.Error = ErrDryRunModeUnsupported
}
return rows, tx.Error
} }
// Scan scan value to a struct // Scan scans selected value to the struct dest
func (db *DB) Scan(dest interface{}) (tx *DB) { func (db *DB) Scan(dest interface{}) (tx *DB) {
config := *db.Config
currentLogger, newLogger := config.Logger, logger.Recorder.New()
config.Logger = newLogger
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = dest tx.Config = &config
tx.callbacks.Query().Execute(tx)
if rows, err := tx.Rows(); err == nil {
if rows.Next() {
tx.ScanRows(rows, dest)
} else {
tx.RowsAffected = 0
tx.AddError(rows.Err())
}
tx.AddError(rows.Close())
}
currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
return newLogger.SQL, tx.RowsAffected
}, tx.Error)
tx.Logger = currentLogger
return return
} }
// Pluck used to query single column from a model as a map // Pluck queries a single column from a model, returning in the slice dest. E.g.:
//
// var ages []int64 // var ages []int64
// db.Find(&users).Pluck("age", &ages) // db.Model(&users).Pluck("age", &ages)
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClauseIfNotExists(clause.Select{Columns: []clause.Column{{Name: column}}}) if tx.Statement.Model != nil {
if tx.Statement.Parse(tx.Statement.Model) == nil {
if f := tx.Statement.Schema.LookUpField(column); f != nil {
column = f.DBName
}
}
}
if len(tx.Statement.Selects) != 1 {
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
tx.Statement.AddClauseIfNotExists(clause.Select{
Distinct: tx.Statement.Distinct,
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
})
}
tx.Statement.Dest = dest tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
return
} }
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
tx := db.getInstance() tx := db.getInstance()
tx.Error = tx.Statement.Parse(dest) if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
tx.AddError(err)
}
tx.Statement.Dest = dest tx.Statement.Dest = dest
tx.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(dest)) tx.Statement.ReflectValue = reflect.ValueOf(dest)
Scan(rows, tx, true) for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
elem := tx.Statement.ReflectValue.Elem()
if !elem.IsValid() {
elem = reflect.New(tx.Statement.ReflectValue.Type().Elem())
tx.Statement.ReflectValue.Set(elem)
}
tx.Statement.ReflectValue = elem
}
Scan(rows, tx, ScanInitialized)
return tx.Error return tx.Error
} }
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. // Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is
// returned to the connection pool.
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
if db.Error != nil {
return db.Error
}
tx := db.getInstance()
sqlDB, err := tx.DB()
if err != nil {
return
}
conn, err := sqlDB.Conn(tx.Statement.Context)
if err != nil {
return
}
defer conn.Close()
tx.Statement.ConnPool = conn
return fc(tx)
}
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
// they are rolled back.
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true panicked := true
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction
if !db.DisableNestedTransaction {
spID := new(maphash.Hash).Sum64()
err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error
if err != nil {
return
}
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
db.RollbackTo(fmt.Sprintf("sp%d", spID))
}
}()
}
err = fc(db.Session(&Session{NewDB: db.clone == 1}))
} else {
tx := db.Begin(opts...) tx := db.Begin(opts...)
if tx.Error != nil {
return tx.Error
}
defer func() { defer func() {
// Make sure to rollback when panic, Block error or Commit error // Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil { if panicked || err != nil {
@ -299,64 +651,130 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
} }
}() }()
err = fc(tx.Session(&Session{})) if err = fc(tx); err == nil {
panicked = false
if err == nil { return tx.Commit().Error
err = tx.Commit().Error }
} }
panicked = false panicked = false
return return
} }
// Begin begins a transaction // Begin begins a transaction with any transaction options opts
func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
tx = db.getInstance() var (
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { // clone statement
var opt *sql.TxOptions tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})
var err error opt *sql.TxOptions
err error
)
if len(opts) > 0 { if len(opts) > 0 {
opt = opts[0] opt = opts[0]
} }
if tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt); err != nil { ctx := tx.Statement.Context
if _, ok := ctx.Deadline(); !ok {
if db.Config.DefaultTransactionTimeout > 0 {
ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout)
}
}
switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
default:
err = ErrInvalidTransaction
}
if err != nil {
tx.AddError(err) tx.AddError(err)
} }
} else {
tx.AddError(ErrInvalidTransaction) return tx
}
return
} }
// Commit commit a transaction // Commit commits the changes in a transaction
func (db *DB) Commit() *DB { func (db *DB) Commit() *DB {
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
db.AddError(comminter.Commit()) db.AddError(committer.Commit())
} else { } else {
db.AddError(ErrInvalidTransaction) db.AddError(ErrInvalidTransaction)
} }
return db return db
} }
// Rollback rollback a transaction // Rollback rollbacks the changes in a transaction
func (db *DB) Rollback() *DB { func (db *DB) Rollback() *DB {
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
db.AddError(comminter.Rollback()) if !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Rollback())
}
} else { } else {
db.AddError(ErrInvalidTransaction) db.AddError(ErrInvalidTransaction)
} }
return db return db
} }
// Exec execute raw sql func (db *DB) SavePoint(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
// close prepared statement, because SavePoint not support prepared statement.
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
var (
preparedStmtTx *PreparedStmtTX
isPreparedStmtTx bool
)
// close prepared statement, because SavePoint not support prepared statement.
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx.Tx
}
db.AddError(savePointer.SavePoint(db, name))
// restore prepared statement
if isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx
}
} else {
db.AddError(ErrUnsupportedDriver)
}
return db
}
func (db *DB) RollbackTo(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
// close prepared statement, because RollbackTo not support prepared statement.
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
var (
preparedStmtTx *PreparedStmtTX
isPreparedStmtTx bool
)
// close prepared statement, because SavePoint not support prepared statement.
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx.Tx
}
db.AddError(savePointer.RollbackTo(db, name))
// restore prepared statement
if isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx
}
} else {
db.AddError(ErrUnsupportedDriver)
}
return db
}
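Transaction commits fc on success and rolls back on error or panic; when called on a connection that is already inside a transaction it degrades to SAVEPOINT/ROLLBACK TO unless DisableNestedTransaction is set, and the same savepoint calls are exposed directly. A sketch with an assumed User model:

err := db.Transaction(func(tx *gorm.DB) error {
	if err := tx.Create(&User{Name: "a"}).Error; err != nil {
		return err // rolls back everything
	}
	// Nested call becomes a savepoint; its failure only undoes the inner work.
	_ = tx.Transaction(func(inner *gorm.DB) error {
		return inner.Create(&User{Name: "maybe"}).Error
	})
	return nil
})

// Manual control:
tx := db.Begin()
tx.Create(&User{Name: "b"})
tx.SavePoint("sp1")
tx.Create(&User{Name: "c"})
tx.RollbackTo("sp1") // discards only "c"
err = tx.Commit().Error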
// Exec executes raw sql
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.SQL = strings.Builder{}

	if strings.Contains(sql, "@") {
		clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
	} else {
		clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
	}

	return tx.callbacks.Raw().Execute(tx)
}
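A short sketch, not part of the changeset: because the SQL above contains "@", the new branch builds a clause.NamedExpr, so named arguments can be passed to Exec.

package sketches

import (
	"database/sql"

	"gorm.io/gorm"
)

// execNamedSketch assumes an already-opened *gorm.DB and a users table.
func execNamedSketch(db *gorm.DB) error {
	return db.Exec("UPDATE users SET name = @name WHERE id = @id",
		sql.Named("name", "jinzhu"), sql.Named("id", 1)).Error
}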
generics.go Normal file
@ -0,0 +1,605 @@
package gorm
import (
"context"
"database/sql"
"fmt"
"sort"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
)
type result struct {
Result sql.Result
RowsAffected int64
}
func (info *result) ModifyStatement(stmt *Statement) {
stmt.Result = info
}
// Build implements clause.Expression interface
func (result) Build(clause.Builder) {
}
func WithResult() *result {
return &result{}
}
type Interface[T any] interface {
Raw(sql string, values ...interface{}) ExecInterface[T]
Exec(ctx context.Context, sql string, values ...interface{}) error
CreateInterface[T]
}
type CreateInterface[T any] interface {
ChainInterface[T]
Table(name string, args ...interface{}) CreateInterface[T]
Create(ctx context.Context, r *T) error
CreateInBatches(ctx context.Context, r *[]T, batchSize int) error
}
type ChainInterface[T any] interface {
ExecInterface[T]
Scopes(scopes ...func(db *Statement)) ChainInterface[T]
Where(query interface{}, args ...interface{}) ChainInterface[T]
Not(query interface{}, args ...interface{}) ChainInterface[T]
Or(query interface{}, args ...interface{}) ChainInterface[T]
Limit(offset int) ChainInterface[T]
Offset(offset int) ChainInterface[T]
Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T]
Select(query string, args ...interface{}) ChainInterface[T]
Omit(columns ...string) ChainInterface[T]
MapColumns(m map[string]string) ChainInterface[T]
Distinct(args ...interface{}) ChainInterface[T]
Group(name string) ChainInterface[T]
Having(query interface{}, args ...interface{}) ChainInterface[T]
Order(value interface{}) ChainInterface[T]
Build(builder clause.Builder)
Delete(ctx context.Context) (rowsAffected int, err error)
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
Updates(ctx context.Context, t T) (rowsAffected int, err error)
Count(ctx context.Context, column string) (result int64, err error)
}
type ExecInterface[T any] interface {
Scan(ctx context.Context, r interface{}) error
First(context.Context) (T, error)
Last(ctx context.Context) (T, error)
Take(context.Context) (T, error)
Find(ctx context.Context) ([]T, error)
FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error
Row(ctx context.Context) *sql.Row
Rows(ctx context.Context) (*sql.Rows, error)
}
type JoinBuilder interface {
Select(...string) JoinBuilder
Omit(...string) JoinBuilder
Where(query interface{}, args ...interface{}) JoinBuilder
Not(query interface{}, args ...interface{}) JoinBuilder
Or(query interface{}, args ...interface{}) JoinBuilder
}
type PreloadBuilder interface {
Select(...string) PreloadBuilder
Omit(...string) PreloadBuilder
Where(query interface{}, args ...interface{}) PreloadBuilder
Not(query interface{}, args ...interface{}) PreloadBuilder
Or(query interface{}, args ...interface{}) PreloadBuilder
Limit(offset int) PreloadBuilder
Offset(offset int) PreloadBuilder
Order(value interface{}) PreloadBuilder
LimitPerRecord(num int) PreloadBuilder
}
type op func(*DB) *DB
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
v := &g[T]{
db: db,
ops: make([]op, 0, 5),
}
if len(opts) > 0 {
v.ops = append(v.ops, func(db *DB) *DB {
return db.Clauses(opts...)
})
}
v.createG = &createG[T]{
chainG: chainG[T]{
execG: execG[T]{g: v},
},
}
return v
}
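A minimal usage sketch, not part of the changeset, assuming the hypothetical User model from the earlier savepoint sketch: G[T] wraps an existing *gorm.DB in the type-safe, context-first API declared above, and Find returns a []User directly.

package sketches

import (
	"context"

	"gorm.io/gorm"
)

func genericsQuerySketch(ctx context.Context, db *gorm.DB) ([]User, error) {
	return gorm.G[User](db).Where("age > ?", 18).Order("name").Find(ctx)
}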
type g[T any] struct {
*createG[T]
db *DB
ops []op
}
func (g *g[T]) apply(ctx context.Context) *DB {
db := g.db
if !db.DryRun {
db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
}
for _, op := range g.ops {
db = op(db)
}
return db
}
func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] {
return execG[T]{g: &g[T]{
db: c.db,
ops: append(c.ops, func(db *DB) *DB {
return db.Raw(sql, values...)
}),
}}
}
func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error {
return c.apply(ctx).Exec(sql, values...).Error
}
type createG[T any] struct {
chainG[T]
}
func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] {
return createG[T]{c.with(func(db *DB) *DB {
return db.Table(name, args...)
})}
}
func (c createG[T]) Create(ctx context.Context, r *T) error {
return c.g.apply(ctx).Create(r).Error
}
func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error {
return c.g.apply(ctx).CreateInBatches(r, batchSize).Error
}
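A usage sketch for the create path, not part of the changeset (same hypothetical User model as above): Create and CreateInBatches take the context explicitly instead of relying on WithContext.

package sketches

import (
	"context"

	"gorm.io/gorm"
)

func genericsCreateSketch(ctx context.Context, db *gorm.DB) error {
	if err := gorm.G[User](db).Create(ctx, &User{Name: "alice"}); err != nil {
		return err
	}
	users := []User{{Name: "bob"}, {Name: "carol"}}
	return gorm.G[User](db).CreateInBatches(ctx, &users, 100)
}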
type chainG[T any] struct {
execG[T]
}
func (c chainG[T]) getInstance() *DB {
var r T
return c.g.apply(context.Background()).Model(r).getInstance()
}
func (c chainG[T]) with(v op) chainG[T] {
return chainG[T]{
execG: execG[T]{g: &g[T]{
db: c.g.db,
ops: append(append([]op(nil), c.g.ops...), v),
}},
}
}
func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] {
return c.with(func(db *DB) *DB {
for _, fc := range scopes {
fc(db.Statement)
}
return db
})
}
func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Table(name, args...)
})
}
func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Where(query, args...)
})
}
func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Not(query, args...)
})
}
func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Or(query, args...)
})
}
func (c chainG[T]) Limit(offset int) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Limit(offset)
})
}
func (c chainG[T]) Offset(offset int) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Offset(offset)
})
}
type joinBuilder struct {
db *DB
}
func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Select(columns ...string) JoinBuilder {
q.db.Select(columns)
return q
}
func (q *joinBuilder) Omit(columns ...string) JoinBuilder {
q.db.Omit(columns...)
return q
}
type preloadBuilder struct {
limitPerRecord int
db *DB
}
func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Select(columns ...string) PreloadBuilder {
q.db.Select(columns)
return q
}
func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder {
q.db.Omit(columns...)
return q
}
func (q *preloadBuilder) Limit(limit int) PreloadBuilder {
q.db.Limit(limit)
return q
}
func (q *preloadBuilder) Offset(offset int) PreloadBuilder {
q.db.Offset(offset)
return q
}
func (q *preloadBuilder) Order(value interface{}) PreloadBuilder {
q.db.Order(value)
return q
}
func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder {
q.limitPerRecord = num
return q
}
func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
return c.with(func(db *DB) *DB {
if jt.Table == "" {
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
}
q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
if on != nil {
if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
db.AddError(err)
}
}
j := join{
Name: jt.Association,
Alias: jt.Table,
Selects: q.db.Statement.Selects,
Omits: q.db.Statement.Omits,
JoinType: jt.Type,
}
if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
j.On = &where
}
if jt.Subquery != nil {
joinType := j.JoinType
if joinType == "" {
joinType = clause.LeftJoin
}
if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok {
stmt := db.getInstance().Statement
if len(j.Selects) == 0 {
j.Selects = stmt.Selects
}
if len(j.Omits) == 0 {
j.Omits = stmt.Omits
}
}
expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}
if j.On != nil {
expr.SQL += " ON ?"
expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
}
j.Expression = expr
}
db.Statement.Joins = append(db.Statement.Joins, j)
sort.Slice(db.Statement.Joins, func(i, j int) bool {
return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name
})
return db
})
}
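A sketch of calling the Joins method above, not part of the changeset. The clause.JoinTarget literal only uses fields the method reads (Type, Association); "Company" is a hypothetical association on the User model, and passing nil for the on callback is allowed.

package sketches

import (
	"context"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

func genericsJoinSketch(ctx context.Context, db *gorm.DB) ([]User, error) {
	return gorm.G[User](db).
		Joins(clause.JoinTarget{Type: clause.LeftJoin, Association: "Company"}, nil).
		Find(ctx)
}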
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Select(query, args...)
})
}
func (c chainG[T]) Omit(columns ...string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Omit(columns...)
})
}
func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.MapColumns(m)
})
}
func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Distinct(args...)
})
}
func (c chainG[T]) Group(name string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Group(name)
})
}
func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Having(query, args...)
})
}
func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Order(value)
})
}
func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Preload(association, func(tx *DB) *DB {
q := preloadBuilder{db: tx.getInstance()}
if query != nil {
if err := query(&q); err != nil {
db.AddError(err)
}
}
relation, ok := db.Statement.Schema.Relationships.Relations[association]
if !ok {
if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 {
relationships := db.Statement.Schema.Relationships
for _, field := range preloadFields {
var ok bool
relation, ok = relationships.Relations[field]
if ok {
relationships = relation.FieldSchema.Relationships
} else {
db.AddError(fmt.Errorf("relation %s not found", association))
return nil
}
}
} else {
db.AddError(fmt.Errorf("relation %s not found", association))
return nil
}
}
if q.limitPerRecord > 0 {
if relation.JoinTable != nil {
tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association))
return tx
}
refColumns := []clause.Column{}
for _, rel := range relation.References {
if rel.OwnPrimaryKey {
refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName})
}
}
if len(refColumns) != 0 {
selectExpr := clause.CommaExpression{}
for _, column := range q.db.Statement.Selects {
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}})
}
if len(selectExpr.Exprs) == 0 {
selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}}
}
partitionBy := clause.CommaExpression{}
for _, column := range refColumns {
partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}})
}
rnnColumn := clause.Column{Name: "gorm_preload_rnn"}
sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)"
vars := []interface{}{partitionBy}
if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok {
vars = append(vars, orderBy)
} else {
vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
}})
}
vars = append(vars, rnnColumn)
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars})
q.db.Clauses(clause.Select{Expression: selectExpr})
return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord)
}
}
return q.db
})
})
}
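A sketch of the Preload builder above, not part of the changeset. "Pets" is a hypothetical has-many association on User; LimitPerRecord(3) caps the preloaded rows per parent through the ROW_NUMBER() window query built above (not supported for many2many relations).

package sketches

import (
	"context"

	"gorm.io/gorm"
)

func genericsPreloadSketch(ctx context.Context, db *gorm.DB) ([]User, error) {
	return gorm.G[User](db).
		Preload("Pets", func(pb gorm.PreloadBuilder) error {
			pb.Order("created_at DESC").LimitPerRecord(3)
			return nil
		}).
		Find(ctx)
}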
func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) {
r := new(T)
res := c.g.apply(ctx).Delete(r)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) {
var r T
res := c.g.apply(ctx).Model(r).Update(name, value)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) {
res := c.g.apply(ctx).Updates(t)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) {
var r T
err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error
return
}
func (c chainG[T]) Build(builder clause.Builder) {
subdb := c.getInstance()
subdb.Logger = logger.Discard
subdb.DryRun = true
if stmt, ok := builder.(*Statement); ok {
if subdb.Statement.SQL.Len() > 0 {
var (
vars = subdb.Statement.Vars
sql = subdb.Statement.SQL.String()
)
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
subdb.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
subdb.Statement.SQL.Reset()
subdb.Statement.Vars = stmt.Vars
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
} else {
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
}
} else {
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
subdb.callbacks.Query().Execute(subdb)
}
builder.WriteString(subdb.Statement.SQL.String())
stmt.Vars = subdb.Statement.Vars
}
}
type execG[T any] struct {
g *g[T]
}
func (g execG[T]) First(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).First(&r).Error
return r, err
}
func (g execG[T]) Scan(ctx context.Context, result interface{}) error {
var r T
err := g.g.apply(ctx).Model(r).Find(result).Error
return err
}
func (g execG[T]) Last(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).Last(&r).Error
return r, err
}
func (g execG[T]) Take(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).Take(&r).Error
return r, err
}
func (g execG[T]) Find(ctx context.Context) ([]T, error) {
var r []T
err := g.g.apply(ctx).Find(&r).Error
return r, err
}
func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error {
var data []T
return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error {
return fc(data, batch)
}).Error
}
func (g execG[T]) Row(ctx context.Context) *sql.Row {
return g.g.apply(ctx).Row()
}
func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) {
return g.g.apply(ctx).Rows()
}
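A sketch of the raw-SQL path, not part of the changeset (same hypothetical User model): Raw returns an ExecInterface[User], so the row is scanned straight into the generic type without a separate destination variable.

package sketches

import (
	"context"

	"gorm.io/gorm"
)

func genericsRawSketch(ctx context.Context, db *gorm.DB) (User, error) {
	return gorm.G[User](db).Raw("SELECT * FROM users WHERE id = ?", 1).First(ctx)
}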
go.mod
@ -1,8 +1,9 @@
module gorm.io/gorm

go 1.18

require (
	github.com/jinzhu/inflection v1.0.0
	github.com/jinzhu/now v1.1.5
	golang.org/x/text v0.20.0
)
go.sum Normal file
@ -0,0 +1,6 @@
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
gorm.go
@ -2,7 +2,10 @@ package gorm
import (
	"context"
	"database/sql"
	"fmt"
	"reflect"
	"sort"
	"sync"
	"time"
@ -11,19 +14,51 @@ import (
	"gorm.io/gorm/schema"
)
// for Config.cacheStore store PreparedStmtDB key
const preparedStmtDBKey = "preparedStmt"
// Config GORM config
type Config struct {
	// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
	// You can disable it by setting `SkipDefaultTransaction` to true
	SkipDefaultTransaction    bool
	DefaultTransactionTimeout time.Duration
	// NamingStrategy tables, columns naming strategy
	NamingStrategy schema.Namer
	// FullSaveAssociations full save associations
	FullSaveAssociations bool
	// Logger
	Logger logger.Interface
	// NowFunc the function to be used when creating a new timestamp
	NowFunc func() time.Time
	// DryRun generate sql without execute
	DryRun bool
	// PrepareStmt executes the given query in cached statement
	PrepareStmt bool
	// PrepareStmt cache support LRU expired,
	// default maxsize=int64 Max value and ttl=1h
	PrepareStmtMaxSize int
	PrepareStmtTTL     time.Duration
	// DisableAutomaticPing
	DisableAutomaticPing bool
	// DisableForeignKeyConstraintWhenMigrating
	DisableForeignKeyConstraintWhenMigrating bool
	// IgnoreRelationshipsWhenMigrating
	IgnoreRelationshipsWhenMigrating bool
	// DisableNestedTransaction disable nested transaction
	DisableNestedTransaction bool
	// AllowGlobalUpdate allow global update
	AllowGlobalUpdate bool
	// QueryFields executes the SQL query with all fields of the table
	QueryFields bool
	// CreateBatchSize default create batch size
	CreateBatchSize int
	// TranslateError enabling error translation
	TranslateError bool
	// PropagateUnscoped propagate Unscoped to every other nested statement
	PropagateUnscoped bool

	// ClauseBuilders clause builder
	ClauseBuilders map[string]clause.ClauseBuilder
@ -31,11 +66,39 @@ type Config struct {
	ConnPool ConnPool
	// Dialector database dialector
	Dialector
	// Plugins registered plugins
	Plugins map[string]Plugin

	callbacks  *callbacks
	cacheStore *sync.Map
}
// Apply update config to new config
func (c *Config) Apply(config *Config) error {
if config != c {
*config = *c
}
return nil
}
// AfterInitialize initialize plugins after db connected
func (c *Config) AfterInitialize(db *DB) error {
if db != nil {
for _, plugin := range c.Plugins {
if err := plugin.Initialize(db); err != nil {
return err
}
}
}
return nil
}
// Option gorm option interface
type Option interface {
Apply(*Config) error
AfterInitialize(*DB) error
}
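A sketch of implementing the Option interface above, not part of the changeset: Apply mutates the Config before the dialector initializes, AfterInitialize runs once the *gorm.DB is ready. It would be passed as gorm.Open(dialector, &gorm.Config{}, utcNow{}); the option itself is hypothetical.

package sketches

import (
	"time"

	"gorm.io/gorm"
)

type utcNow struct{}

func (utcNow) Apply(c *gorm.Config) error {
	c.NowFunc = func() time.Time { return time.Now().UTC() }
	return nil
}

func (utcNow) AfterInitialize(db *gorm.DB) error {
	sqlDB, err := db.DB()
	if err != nil {
		return err
	}
	return sqlDB.Ping()
}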
// DB GORM DB definition
type DB struct {
	*Config
@ -48,20 +111,65 @@ type DB struct {
// Session session config when create session with Session() method
type Session struct {
	DryRun                   bool
	PrepareStmt              bool
	NewDB                    bool
	Initialized              bool
	SkipHooks                bool
	SkipDefaultTransaction   bool
	DisableNestedTransaction bool
	AllowGlobalUpdate        bool
	FullSaveAssociations     bool
	PropagateUnscoped        bool
	QueryFields              bool
	Context                  context.Context
	Logger                   logger.Interface
	NowFunc                  func() time.Time
	CreateBatchSize          int
}
// Open initialize db session based on dialector
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
	config := &Config{}

	sort.Slice(opts, func(i, j int) bool {
		_, isConfig := opts[i].(*Config)
		_, isConfig2 := opts[j].(*Config)
		return isConfig && !isConfig2
	})

	if len(opts) > 0 {
		if c, ok := opts[0].(*Config); ok {
			config = c
		} else {
			opts = append([]Option{config}, opts...)
		}
	}

	var skipAfterInitialize bool
	for _, opt := range opts {
		if opt != nil {
			if applyErr := opt.Apply(config); applyErr != nil {
				return nil, applyErr
			}
			defer func(opt Option) {
				if skipAfterInitialize {
					return
				}
				if errr := opt.AfterInitialize(db); errr != nil {
					err = errr
				}
			}(opt)
		}
	}

	if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
		if err = d.Apply(config); err != nil {
			return
		}
	}

	if config.NamingStrategy == nil {
		config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
	}

	if config.Logger == nil {
@ -76,6 +184,10 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
		config.Dialector = dialector
	}

	if config.Plugins == nil {
		config.Plugins = map[string]Plugin{}
	}

	if config.cacheStore == nil {
		config.cacheStore = &sync.Map{}
	}
@ -88,9 +200,48 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
		config.ClauseBuilders = map[string]clause.ClauseBuilder{}
	}

	if config.Dialector != nil {
		err = config.Dialector.Initialize(db)
		if err != nil {
			if db, _ := db.DB(); db != nil {
				_ = db.Close()
			}

			// DB is not initialized, so we skip AfterInitialize
			skipAfterInitialize = true
			return
		}

		if config.TranslateError {
			if _, ok := db.Dialector.(ErrorTranslator); !ok {
				config.Logger.Warn(context.Background(), "The TranslateError option is enabled, but the Dialector %s does not implement ErrorTranslator.", db.Dialector.Name())
			}
		}
	}

	if config.PrepareStmt {
		preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL)
		db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
		db.ConnPool = preparedStmt
	}

	db.Statement = &Statement{
		DB:       db,
		ConnPool: db.ConnPool,
		Context:  context.Background(),
		Clauses:  map[string]clause.Clause{},
	}

	if err == nil && !config.DisableAutomaticPing {
		if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
			err = pinger.Ping()
		}
	}

	if err != nil {
		config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
	}

	return
}
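A usage sketch for the new Open signature, not part of the changeset: a *gorm.Config passed as an Option enables the LRU-backed prepared-statement cache configured by the fields added above. The sqlite dialector is an assumption; any dialector works.

package sketches

import (
	"time"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

func openSketch() (*gorm.DB, error) {
	return gorm.Open(sqlite.Open("test.db"), &gorm.Config{
		PrepareStmt:        true,
		PrepareStmtMaxSize: 512,
		PrepareStmtTTL:     time.Hour,
	})
}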
@ -101,33 +252,86 @@ func (db *DB) Session(config *Session) *DB {
		tx = &DB{
			Config:    &txConfig,
			Statement: db.Statement,
			Error:     db.Error,
			clone:     1,
		}
	)
	if config.CreateBatchSize > 0 {
		tx.Config.CreateBatchSize = config.CreateBatchSize
	}

	if config.SkipDefaultTransaction {
		tx.Config.SkipDefaultTransaction = true
	}

	if config.AllowGlobalUpdate {
		txConfig.AllowGlobalUpdate = true
	}

	if config.FullSaveAssociations {
		txConfig.FullSaveAssociations = true
	}

	if config.PropagateUnscoped {
		txConfig.PropagateUnscoped = true
	}

	if config.Context != nil || config.PrepareStmt || config.SkipHooks {
		tx.Statement = tx.Statement.clone()
		tx.Statement.DB = tx
	}

	if config.Context != nil {
		tx.Statement.Context = config.Context
	}

	if config.PrepareStmt {
		var preparedStmt *PreparedStmtDB

		if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
			preparedStmt = v.(*PreparedStmtDB)
		} else {
			preparedStmt = NewPreparedStmtDB(db.ConnPool, db.PrepareStmtMaxSize, db.PrepareStmtTTL)
			db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
		}

		switch t := tx.Statement.ConnPool.(type) {
		case Tx:
			tx.Statement.ConnPool = &PreparedStmtTX{
				Tx:             t,
				PreparedStmtDB: preparedStmt,
			}
		default:
			tx.Statement.ConnPool = &PreparedStmtDB{
				ConnPool: db.Config.ConnPool,
				Mux:      preparedStmt.Mux,
				Stmts:    preparedStmt.Stmts,
			}
		}
		txConfig.ConnPool = tx.Statement.ConnPool
		txConfig.PrepareStmt = true
	}

	if config.SkipHooks {
		tx.Statement.SkipHooks = true
	}

	if config.DisableNestedTransaction {
		txConfig.DisableNestedTransaction = true
	}

	if !config.NewDB {
		tx.clone = 2
	}

	if config.DryRun {
		tx.Config.DryRun = true
	}

	if config.QueryFields {
		tx.Config.QueryFields = true
	}

	if config.Logger != nil {
		tx.Config.Logger = config.Logger
	}
@ -136,18 +340,22 @@ func (db *DB) Session(config *Session) *DB {
		tx.Config.NowFunc = config.NowFunc
	}

	if config.Initialized {
		tx = tx.getInstance()
	}

	return tx
}

// WithContext change current instance db's context to ctx
func (db *DB) WithContext(ctx context.Context) *DB {
	return db.Session(&Session{Context: ctx})
}

// Debug start debug mode
func (db *DB) Debug() (tx *DB) {
	tx = db.getInstance()
	return tx.Session(&Session{
		Logger: db.Logger.LogMode(logger.Info),
	})
}
@ -161,10 +369,7 @@ func (db *DB) Set(key string, value interface{}) *DB {
// Get get value with key from current db instance's context
func (db *DB) Get(key string) (interface{}, bool) {
	return db.Statement.Settings.Load(key)
}
// InstanceSet store value with key into current db instance's context
@ -176,47 +381,7 @@ func (db *DB) InstanceSet(key string, value interface{}) *DB {
// InstanceGet get value with key from current db instance's context
func (db *DB) InstanceGet(key string) (interface{}, bool) {
	return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
}

// Callback returns callback manager
@ -224,50 +389,68 @@ func (db *DB) Callback() *callbacks {
	return db.callbacks
}

// AutoMigrate run auto migration for given models
func (db *DB) AutoMigrate(dst ...interface{}) error {
	return db.Migrator().AutoMigrate(dst...)
}

// AddError add error to db
func (db *DB) AddError(err error) error {
	if err != nil {
		if db.Config.TranslateError {
			if errTranslator, ok := db.Dialector.(ErrorTranslator); ok {
				err = errTranslator.Translate(err)
			}
		}

		if db.Error == nil {
			db.Error = err
		} else {
			db.Error = fmt.Errorf("%v; %w", db.Error, err)
		}
	}
	return db.Error
}
// DB returns `*sql.DB`
func (db *DB) DB() (*sql.DB, error) {
connPool := db.ConnPool
if db.Statement != nil && db.Statement.ConnPool != nil {
connPool = db.Statement.ConnPool
}
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
}
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
return sqldb, err
}
}
if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
return sqldb, nil
}
return nil, ErrInvalidDB
}
func (db *DB) getInstance() *DB {
	if db.clone > 0 {
		tx := &DB{Config: db.Config, Error: db.Error}

		if db.clone == 1 {
			// clone with new statement
			tx.Statement = &Statement{
				DB:        tx,
				ConnPool:  db.Statement.ConnPool,
				Context:   db.Statement.Context,
				Clauses:   map[string]clause.Clause{},
				Vars:      make([]interface{}, 0, 8),
				SkipHooks: db.Statement.SkipHooks,
			}
			if db.Config.PropagateUnscoped {
				tx.Statement.Unscoped = db.Statement.Unscoped
			}
		} else {
			// with clone statement
			tx.Statement = db.Statement.clone()
			tx.Statement.DB = tx
		}

		return tx
@ -276,6 +459,86 @@ func (db *DB) getInstance() *DB {
	return db
}

// Expr returns clause.Expr, which can be used to pass SQL expression as params
func Expr(expr string, args ...interface{}) clause.Expr {
	return clause.Expr{SQL: expr, Vars: args}
}
// SetupJoinTable setup join table schema
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
var (
tx = db.getInstance()
stmt = tx.Statement
modelSchema, joinSchema *schema.Schema
)
err := stmt.Parse(model)
if err != nil {
return err
}
modelSchema = stmt.Schema
err = stmt.Parse(joinTable)
if err != nil {
return err
}
joinSchema = stmt.Schema
relation, ok := modelSchema.Relationships.Relations[field]
isRelation := ok && relation.JoinTable != nil
if !isRelation {
return fmt.Errorf("failed to find relation: %s", field)
}
for _, ref := range relation.References {
f := joinSchema.LookUpField(ref.ForeignKey.DBName)
if f == nil {
return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
}
f.DataType = ref.ForeignKey.DataType
f.GORMDataType = ref.ForeignKey.GORMDataType
if f.Size == 0 {
f.Size = ref.ForeignKey.Size
}
ref.ForeignKey = f
}
for name, rel := range relation.JoinTable.Relationships.Relations {
if _, ok := joinSchema.Relationships.Relations[name]; !ok {
rel.Schema = joinSchema
joinSchema.Relationships.Relations[name] = rel
}
}
relation.JoinTable = joinSchema
return nil
}
// Use use plugin
func (db *DB) Use(plugin Plugin) error {
name := plugin.Name()
if _, ok := db.Plugins[name]; ok {
return ErrRegistered
}
if err := plugin.Initialize(db); err != nil {
return err
}
db.Plugins[name] = plugin
return nil
}
// ToSQL for generate SQL string.
//
// db.ToSQL(func(tx *gorm.DB) *gorm.DB {
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
// .Limit(10).Offset(5)
// .Order("name ASC")
// .First(&User{})
// })
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance())
stmt := tx.Statement
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
}
interfaces.go
@ -14,60 +14,79 @@ type Dialector interface {
	Initialize(*DB) error
	Migrator(db *DB) Migrator
	DataTypeOf(*schema.Field) string
	DefaultValueOf(*schema.Field) clause.Expression
	BindVarTo(writer clause.Writer, stmt *Statement, v interface{})
	QuoteTo(clause.Writer, string)
	Explain(sql string, vars ...interface{}) string
}

// Plugin GORM plugin interface
type Plugin interface {
	Name() string
	Initialize(*DB) error
}

type ParamsFilter interface {
	ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{})
}

// ConnPool db conns pool interface
type ConnPool interface {
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}

// SavePointerDialectorInterface save pointer interface
type SavePointerDialectorInterface interface {
	SavePoint(tx *DB, name string) error
	RollbackTo(tx *DB, name string) error
}

// TxBeginner tx beginner
type TxBeginner interface {
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}

// ConnPoolBeginner conn pool beginner
type ConnPoolBeginner interface {
	BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
}

// TxCommitter tx committer
type TxCommitter interface {
	Commit() error
	Rollback() error
}

// Tx sql.Tx interface
type Tx interface {
	ConnPool
	TxCommitter
	StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
}

// Valuer gorm valuer interface
type Valuer interface {
	GormValue(context.Context, *DB) clause.Expr
}

// GetDBConnector SQL db connector
type GetDBConnector interface {
	GetDBConn() (*sql.DB, error)
}

// Rows rows interface
type Rows interface {
	Columns() ([]string, error)
	ColumnTypes() ([]*sql.ColumnType, error)
	Next() bool
	Scan(dest ...interface{}) error
	Err() error
	Close() error
}

type ErrorTranslator interface {
	Translate(err error) error
}
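A sketch of implementing the Plugin interface above, not part of the changeset: Name identifies the plugin in db.Plugins and Initialize registers an after-query callback when the plugin is installed with db.Use(...). The plugin itself is hypothetical.

package sketches

import (
	"sync/atomic"

	"gorm.io/gorm"
)

type queryCounter struct{ count int64 }

func (*queryCounter) Name() string { return "query_counter" }

func (p *queryCounter) Initialize(db *gorm.DB) error {
	return db.Callback().Query().After("gorm:query").
		Register("query_counter:after", func(*gorm.DB) { atomic.AddInt64(&p.count, 1) })
}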
internal/lru/lru.go Normal file
@ -0,0 +1,493 @@
package lru
// golang -lru
// https://github.com/hashicorp/golang-lru
import (
"sync"
"time"
)
// EvictCallback is used to get a callback when a cache entry is evicted
type EvictCallback[K comparable, V any] func(key K, value V)
// LRU implements a thread-safe LRU with expirable entries.
type LRU[K comparable, V any] struct {
size int
evictList *LruList[K, V]
items map[K]*Entry[K, V]
onEvict EvictCallback[K, V]
// expirable options
mu sync.Mutex
ttl time.Duration
done chan struct{}
// buckets for expiration
buckets []bucket[K, V]
// uint8 because it's number between 0 and numBuckets
nextCleanupBucket uint8
}
// bucket is a container for holding entries to be expired
type bucket[K comparable, V any] struct {
entries map[K]*Entry[K, V]
newestEntry time.Time
}
// noEvictionTTL - very long ttl to prevent eviction
const noEvictionTTL = time.Hour * 24 * 365 * 10
// because of uint8 usage for nextCleanupBucket, should not exceed 256.
// casting it as uint8 explicitly requires type conversions in multiple places
const numBuckets = 100
// NewLRU returns a new thread-safe cache with expirable entries.
//
// Size parameter set to 0 makes cache of unlimited size, e.g. turns LRU mechanism off.
//
// Providing 0 TTL turns expiring off.
//
// Delete expired entries every 1/100th of ttl value. Goroutine which deletes expired entries runs indefinitely.
func NewLRU[K comparable, V any](size int, onEvict EvictCallback[K, V], ttl time.Duration) *LRU[K, V] {
if size < 0 {
size = 0
}
if ttl <= 0 {
ttl = noEvictionTTL
}
res := LRU[K, V]{
ttl: ttl,
size: size,
evictList: NewList[K, V](),
items: make(map[K]*Entry[K, V]),
onEvict: onEvict,
done: make(chan struct{}),
}
// initialize the buckets
res.buckets = make([]bucket[K, V], numBuckets)
for i := 0; i < numBuckets; i++ {
res.buckets[i] = bucket[K, V]{entries: make(map[K]*Entry[K, V])}
}
// enable deleteExpired() running in separate goroutine for cache with non-zero TTL
//
// Important: done channel is never closed, so deleteExpired() goroutine will never exit,
// it's decided to add functionality to close it in the version later than v2.
if res.ttl != noEvictionTTL {
go func(done <-chan struct{}) {
ticker := time.NewTicker(res.ttl / numBuckets)
defer ticker.Stop()
for {
select {
case <-done:
return
case <-ticker.C:
res.deleteExpired()
}
}
}(res.done)
}
return &res
}
// Purge clears the cache completely.
// onEvict is called for each evicted key.
func (c *LRU[K, V]) Purge() {
c.mu.Lock()
defer c.mu.Unlock()
for k, v := range c.items {
if c.onEvict != nil {
c.onEvict(k, v.Value)
}
delete(c.items, k)
}
for _, b := range c.buckets {
for _, ent := range b.entries {
delete(b.entries, ent.Key)
}
}
c.evictList.Init()
}
// Add adds a value to the cache. Returns true if an eviction occurred.
// Returns false if there was no eviction: the item was already in the cache,
// or the size was not exceeded.
func (c *LRU[K, V]) Add(key K, value V) (evicted bool) {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
// Check for existing item
if ent, ok := c.items[key]; ok {
c.evictList.MoveToFront(ent)
c.removeFromBucket(ent) // remove the entry from its current bucket as expiresAt is renewed
ent.Value = value
ent.ExpiresAt = now.Add(c.ttl)
c.addToBucket(ent)
return false
}
// Add new item
ent := c.evictList.PushFrontExpirable(key, value, now.Add(c.ttl))
c.items[key] = ent
c.addToBucket(ent) // adds the entry to the appropriate bucket and sets entry.expireBucket
evict := c.size > 0 && c.evictList.Length() > c.size
// Verify size not exceeded
if evict {
c.removeOldest()
}
return evict
}
// Get looks up a key's value from the cache.
func (c *LRU[K, V]) Get(key K) (value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
var ent *Entry[K, V]
if ent, ok = c.items[key]; ok {
// Expired item check
if time.Now().After(ent.ExpiresAt) {
return value, false
}
c.evictList.MoveToFront(ent)
return ent.Value, true
}
return
}
// Contains checks if a key is in the cache, without updating the recent-ness
// or deleting it for being stale.
func (c *LRU[K, V]) Contains(key K) (ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
_, ok = c.items[key]
return ok
}
// Peek returns the key value (or undefined if not found) without updating
// the "recently used"-ness of the key.
func (c *LRU[K, V]) Peek(key K) (value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
var ent *Entry[K, V]
if ent, ok = c.items[key]; ok {
// Expired item check
if time.Now().After(ent.ExpiresAt) {
return value, false
}
return ent.Value, true
}
return
}
// Remove removes the provided key from the cache, returning if the
// key was contained.
func (c *LRU[K, V]) Remove(key K) bool {
c.mu.Lock()
defer c.mu.Unlock()
if ent, ok := c.items[key]; ok {
c.removeElement(ent)
return true
}
return false
}
// RemoveOldest removes the oldest item from the cache.
func (c *LRU[K, V]) RemoveOldest() (key K, value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.evictList.Back(); ent != nil {
c.removeElement(ent)
return ent.Key, ent.Value, true
}
return
}
// GetOldest returns the oldest entry
func (c *LRU[K, V]) GetOldest() (key K, value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.evictList.Back(); ent != nil {
return ent.Key, ent.Value, true
}
return
}
func (c *LRU[K, V]) KeyValues() map[K]V {
c.mu.Lock()
defer c.mu.Unlock()
maps := make(map[K]V)
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
maps[ent.Key] = ent.Value
// keys = append(keys, ent.Key)
}
return maps
}
// Keys returns a slice of the keys in the cache, from oldest to newest.
// Expired entries are filtered out.
func (c *LRU[K, V]) Keys() []K {
c.mu.Lock()
defer c.mu.Unlock()
keys := make([]K, 0, len(c.items))
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
keys = append(keys, ent.Key)
}
return keys
}
// Values returns a slice of the values in the cache, from oldest to newest.
// Expired entries are filtered out.
func (c *LRU[K, V]) Values() []V {
c.mu.Lock()
defer c.mu.Unlock()
values := make([]V, 0, len(c.items))
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
values = append(values, ent.Value)
}
return values
}
// Len returns the number of items in the cache.
func (c *LRU[K, V]) Len() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.evictList.Length()
}
// Resize changes the cache size. Size of 0 means unlimited.
func (c *LRU[K, V]) Resize(size int) (evicted int) {
c.mu.Lock()
defer c.mu.Unlock()
if size <= 0 {
c.size = 0
return 0
}
diff := c.evictList.Length() - size
if diff < 0 {
diff = 0
}
for i := 0; i < diff; i++ {
c.removeOldest()
}
c.size = size
return diff
}
// Close destroys cleanup goroutine. To clean up the cache, run Purge() before Close().
// func (c *LRU[K, V]) Close() {
// c.mu.Lock()
// defer c.mu.Unlock()
// select {
// case <-c.done:
// return
// default:
// }
// close(c.done)
// }
// removeOldest removes the oldest item from the cache. Has to be called with lock!
func (c *LRU[K, V]) removeOldest() {
if ent := c.evictList.Back(); ent != nil {
c.removeElement(ent)
}
}
// removeElement is used to remove a given list element from the cache. Has to be called with lock!
func (c *LRU[K, V]) removeElement(e *Entry[K, V]) {
c.evictList.Remove(e)
delete(c.items, e.Key)
c.removeFromBucket(e)
if c.onEvict != nil {
c.onEvict(e.Key, e.Value)
}
}
// deleteExpired deletes expired records from the oldest bucket, waiting for the newest entry
// in it to expire first.
func (c *LRU[K, V]) deleteExpired() {
c.mu.Lock()
bucketIdx := c.nextCleanupBucket
timeToExpire := time.Until(c.buckets[bucketIdx].newestEntry)
// wait for newest entry to expire before cleanup without holding lock
if timeToExpire > 0 {
c.mu.Unlock()
time.Sleep(timeToExpire)
c.mu.Lock()
}
for _, ent := range c.buckets[bucketIdx].entries {
c.removeElement(ent)
}
c.nextCleanupBucket = (c.nextCleanupBucket + 1) % numBuckets
c.mu.Unlock()
}
// addToBucket adds entry to expire bucket so that it will be cleaned up when the time comes. Has to be called with lock!
func (c *LRU[K, V]) addToBucket(e *Entry[K, V]) {
bucketID := (numBuckets + c.nextCleanupBucket - 1) % numBuckets
e.ExpireBucket = bucketID
c.buckets[bucketID].entries[e.Key] = e
if c.buckets[bucketID].newestEntry.Before(e.ExpiresAt) {
c.buckets[bucketID].newestEntry = e.ExpiresAt
}
}
// removeFromBucket removes the entry from its corresponding bucket. Has to be called with lock!
func (c *LRU[K, V]) removeFromBucket(e *Entry[K, V]) {
delete(c.buckets[e.ExpireBucket].entries, e.Key)
}
// Cap returns the capacity of the cache
func (c *LRU[K, V]) Cap() int {
return c.size
}
// Entry is an LRU Entry
type Entry[K comparable, V any] struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *Entry[K, V]
// The list to which this element belongs.
list *LruList[K, V]
// The LRU Key of this element.
Key K
// The Value stored with this element.
Value V
// The time this element would be cleaned up, optional
ExpiresAt time.Time
// The expiry bucket item was put in, optional
ExpireBucket uint8
}
// PrevEntry returns the previous list element or nil.
func (e *Entry[K, V]) PrevEntry() *Entry[K, V] {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// LruList represents a doubly linked list.
// The zero Value for LruList is an empty list ready to use.
type LruList[K comparable, V any] struct {
root Entry[K, V] // sentinel list element, only &root, root.prev, and root.next are used
len int // current list Length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *LruList[K, V]) Init() *LruList[K, V] {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// NewList returns an initialized list.
func NewList[K comparable, V any]() *LruList[K, V] { return new(LruList[K, V]).Init() }
// Length returns the number of elements of list l.
// The complexity is O(1).
func (l *LruList[K, V]) Length() int { return l.len }
// Back returns the last element of list l or nil if the list is empty.
func (l *LruList[K, V]) Back() *Entry[K, V] {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List Value.
func (l *LruList[K, V]) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *LruList[K, V]) insert(e, at *Entry[K, V]) *Entry[K, V] {
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Entry{Value: v, ExpiresAt: ExpiresAt}, at).
func (l *LruList[K, V]) insertValue(k K, v V, expiresAt time.Time, at *Entry[K, V]) *Entry[K, V] {
return l.insert(&Entry[K, V]{Value: v, Key: k, ExpiresAt: expiresAt}, at)
}
// Remove removes e from its list, decrements l.len
func (l *LruList[K, V]) Remove(e *Entry[K, V]) V {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
return e.Value
}
// move moves e to next to at.
func (l *LruList[K, V]) move(e, at *Entry[K, V]) {
if e == at {
return
}
e.prev.next = e.next
e.next.prev = e.prev
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *LruList[K, V]) PushFront(k K, v V) *Entry[K, V] {
l.lazyInit()
return l.insertValue(k, v, time.Time{}, &l.root)
}
// PushFrontExpirable inserts a new expirable element e with Value v at the front of list l and returns e.
func (l *LruList[K, V]) PushFrontExpirable(k K, v V, expiresAt time.Time) *Entry[K, V] {
l.lazyInit()
return l.insertValue(k, v, expiresAt, &l.root)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *LruList[K, V]) MoveToFront(e *Entry[K, V]) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.move(e, &l.root)
}
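A minimal usage sketch of the expirable LRU above, not part of the changeset: entries expire after the TTL and the least recently used entry is evicted once the size limit is exceeded. Since the package is internal, this only compiles inside the gorm module.

package lru_test

import (
	"fmt"
	"time"

	"gorm.io/gorm/internal/lru"
)

func lruSketch() {
	cache := lru.NewLRU[string, int](2, func(k string, v int) {
		fmt.Println("evicted", k, v)
	}, time.Minute)
	cache.Add("a", 1)
	cache.Add("b", 2)
	cache.Add("c", 3) // evicts "a", the least recently used key
	if v, ok := cache.Get("b"); ok {
		fmt.Println("b =", v)
	}
}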
internal/stmt_store/stmt_store.go Normal file
@ -0,0 +1,183 @@
package stmt_store
import (
"context"
"database/sql"
"math"
"sync"
"time"
"gorm.io/gorm/internal/lru"
)
type Stmt struct {
*sql.Stmt
Transaction bool
prepared chan struct{}
prepareErr error
}
func (stmt *Stmt) Error() error {
return stmt.prepareErr
}
func (stmt *Stmt) Close() error {
<-stmt.prepared
if stmt.Stmt != nil {
return stmt.Stmt.Close()
}
return nil
}
// Store defines an interface for managing the caching operations of SQL statements (Stmt).
// This interface provides methods for creating new statements, retrieving all cache keys,
// getting cached statements, setting cached statements, and deleting cached statements.
type Store interface {
// New creates a new Stmt object and caches it.
// Parameters:
// ctx: The context for the request, which can carry deadlines, cancellation signals, etc.
// key: The key representing the SQL query, used for caching and preparing the statement.
// isTransaction: Indicates whether this operation is part of a transaction, which may affect the caching strategy.
// connPool: A connection pool that provides database connections.
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
// Returns:
// *Stmt: A newly created statement object for executing SQL operations.
// error: An error if the statement preparation fails.
New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error)
// Keys returns a slice of all cache keys in the store.
Keys() []string
// Get retrieves a Stmt object from the store based on the given key.
// Parameters:
// key: The key used to look up the Stmt object.
// Returns:
// *Stmt: The found Stmt object, or nil if not found.
// bool: Indicates whether the corresponding Stmt object was successfully found.
Get(key string) (*Stmt, bool)
// Set stores the given Stmt object in the store and associates it with the specified key.
// Parameters:
// key: The key used to associate the Stmt object.
// value: The Stmt object to be stored.
Set(key string, value *Stmt)
// Delete removes the Stmt object corresponding to the specified key from the store.
// Parameters:
// key: The key associated with the Stmt object to be deleted.
Delete(key string)
}
// defaultMaxSize defines the default maximum capacity of the cache.
// Its value is the maximum value of the int64 type, which means that when the cache size is not specified,
// the cache can theoretically store as many elements as possible.
// (1 << 63) - 1 is the maximum value that an int64 type can represent.
const (
defaultMaxSize = math.MaxInt
// defaultTTL defines the default time-to-live (TTL) for each cache entry.
// When the TTL for cache entries is not specified, each cache entry will expire after 24 hours.
defaultTTL = time.Hour * 24
)
// New creates and returns a new Store instance.
//
// Parameters:
// - size: The maximum capacity of the cache. If the provided size is less than or equal to 0,
// it defaults to defaultMaxSize.
// - ttl: The time-to-live duration for each cache entry. If the provided ttl is less than or equal to 0,
// it defaults to defaultTTL.
//
// This function defines an onEvicted callback that is invoked when a cache entry is evicted.
// The callback ensures that if the evicted value (v) is not nil, its Close method is called asynchronously
// to release associated resources.
//
// Returns:
// - A Store instance implemented by lruStore, which internally uses an LRU cache with the specified size,
// eviction callback, and TTL.
func New(size int, ttl time.Duration) Store {
if size <= 0 {
size = defaultMaxSize
}
if ttl <= 0 {
ttl = defaultTTL
}
onEvicted := func(k string, v *Stmt) {
if v != nil {
go v.Close()
}
}
return &lruStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
}
type lruStore struct {
lru *lru.LRU[string, *Stmt]
}
func (s *lruStore) Keys() []string {
return s.lru.Keys()
}
func (s *lruStore) Get(key string) (*Stmt, bool) {
stmt, ok := s.lru.Get(key)
if ok && stmt != nil {
<-stmt.prepared
}
return stmt, ok
}
func (s *lruStore) Set(key string, value *Stmt) {
s.lru.Add(key, value)
}
func (s *lruStore) Delete(key string) {
s.lru.Remove(key)
}
type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
// New creates a new Stmt object for executing SQL queries.
// It caches the Stmt object for future use and handles preparation and error states.
// Parameters:
//
// ctx: Context for the request, used to carry deadlines, cancellation signals, etc.
// key: The key representing the SQL query, used for caching and preparing the statement.
// isTransaction: Indicates whether this operation is part of a transaction, affecting cache strategy.
// conn: A connection pool that provides database connections.
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
//
// Returns:
//
// *Stmt: A newly created statement object for executing SQL operations.
// error: An error if the statement preparation fails.
func (s *lruStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) {
// Create a Stmt object and set its Transaction property.
// The prepared channel is used to synchronize the statement preparation state.
cacheStmt := &Stmt{
Transaction: isTransaction,
prepared: make(chan struct{}),
}
// Cache the Stmt object with the associated key.
s.Set(key, cacheStmt)
// Unlock after completing initialization to prevent deadlocks.
locker.Unlock()
// Ensure the prepared channel is closed after the function execution completes.
defer close(cacheStmt.prepared)
// Prepare the SQL statement using the provided connection.
cacheStmt.Stmt, err = conn.PrepareContext(ctx, key)
if err != nil {
// If statement preparation fails, record the error and remove the invalid Stmt object from the cache.
cacheStmt.prepareErr = err
s.Delete(key)
return &Stmt{}, err
}
// Return the successfully prepared Stmt object.
return cacheStmt, nil
}
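A sketch of the locking contract documented above, not part of the changeset: the caller locks, New registers the cache entry, unlocks the locker, then prepares the statement on the supplied connection pool (any *sql.DB satisfies ConnPool here). The import path is assumed to follow the internal/ layout used by the lru package.

package stmt_store_test

import (
	"context"
	"database/sql"
	"sync"
	"time"

	"gorm.io/gorm/internal/stmt_store"
)

func prepareSketch(ctx context.Context, db *sql.DB) error {
	store := stmt_store.New(128, time.Hour)

	mu := &sync.Mutex{}
	mu.Lock() // New unlocks mu once the entry is registered in the cache
	stmt, err := store.New(ctx, "SELECT 1", false, db, mu)
	if err != nil {
		return err
	}

	var one int
	return stmt.QueryRowContext(ctx).Scan(&one)
}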
logger/logger.go
@ -2,6 +2,9 @@ package logger
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"os"
	"time"
@ -9,6 +12,9 @@ import (
	"gorm.io/gorm/utils"
)
// ErrRecordNotFound record not found error
var ErrRecordNotFound = errors.New("record not found")
// Colors
const (
	Reset = "\033[0m"
@ -19,18 +25,23 @@ const (
	Magenta     = "\033[35m"
	Cyan        = "\033[36m"
	White       = "\033[37m"
	BlueBold    = "\033[34;1m"
	MagentaBold = "\033[35;1m"
	RedBold     = "\033[31;1m"
	YellowBold  = "\033[33;1m"
)

// LogLevel log level
type LogLevel int

const (
	// Silent silent log level
	Silent LogLevel = iota + 1
	// Error error log level
	Error
	// Warn warn log level
	Warn
	// Info info log level
	Info
)
@ -39,9 +50,12 @@ type Writer interface {
	Printf(string, ...interface{})
}

// Config logger config
type Config struct {
	SlowThreshold             time.Duration
	Colorful                  bool
	IgnoreRecordNotFoundError bool
	ParameterizedQueries      bool
	LogLevel                  LogLevel
}
@ -51,32 +65,46 @@ type Interface interface {
Info(context.Context, string, ...interface{}) Info(context.Context, string, ...interface{})
Warn(context.Context, string, ...interface{}) Warn(context.Context, string, ...interface{})
Error(context.Context, string, ...interface{}) Error(context.Context, string, ...interface{})
Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
} }
var Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ var (
SlowThreshold: 100 * time.Millisecond, // Discard logger will print any log to io.Discard
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
// Default Default logger
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: Warn, LogLevel: Warn,
IgnoreRecordNotFoundError: false,
Colorful: true, Colorful: true,
}) })
// Recorder logger records running SQL into a recorder instance
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
// RecorderParamsFilter defaults to a no-op and can be overridden with a different implementation
RecorderParamsFilter = func(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
return sql, params
}
)
// New initialize logger
func New(writer Writer, config Config) Interface { func New(writer Writer, config Config) Interface {
var ( var (
infoStr = "%s\n[info] " infoStr = "%s\n[info] "
warnStr = "%s\n[warn] " warnStr = "%s\n[warn] "
errStr = "%s\n[error] " errStr = "%s\n[error] "
traceStr = "%s\n[%v] [rows:%d] %s" traceStr = "%s\n[%.3fms] [rows:%v] %s"
traceWarnStr = "%s\n[%v] [rows:%d] %s" traceWarnStr = "%s %s\n[%.3fms] [rows:%v] %s"
traceErrStr = "%s %s\n[%v] [rows:%d] %s" traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s"
) )
if config.Colorful { if config.Colorful {
infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset
warnStr = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset
errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset traceWarnStr = Green + "%s " + Yellow + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
} }
return &logger{ return &logger{
@ -106,40 +134,92 @@ func (l *logger) LogMode(level LogLevel) Interface {
} }
// Info print info // Info print info
func (l logger) Info(ctx context.Context, msg string, data ...interface{}) { func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Info { if l.LogLevel >= Info {
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
} }
} }
// Warn print warn messages // Warn print warn messages
func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) { func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Warn { if l.LogLevel >= Warn {
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
} }
} }
// Error print error messages // Error print error messages
func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Error { if l.LogLevel >= Error {
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
} }
} }
// Trace print sql message // Trace print sql message
func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { //
if l.LogLevel > 0 { //nolint:cyclop
elapsed := time.Now().Sub(begin) func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.LogLevel <= Silent {
return
}
elapsed := time.Since(begin)
switch { switch {
case err != nil && l.LogLevel >= Error: case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError):
sql, rows := fc() sql, rows := fc()
if rows == -1 {
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
sql, rows := fc() sql, rows := fc()
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)
case l.LogLevel >= Info: if rows == -1 {
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
case l.LogLevel == Info:
sql, rows := fc() sql, rows := fc()
if rows == -1 {
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
} }
} }
} }
// ParamsFilter filter params
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
if l.Config.ParameterizedQueries {
return sql, nil
}
return sql, params
}
type traceRecorder struct {
Interface
BeginAt time.Time
SQL string
RowsAffected int64
Err error
}
// New trace recorder
func (l *traceRecorder) New() *traceRecorder {
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
}
// Trace implement logger interface
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
l.BeginAt = begin
l.SQL, l.RowsAffected = fc()
l.Err = err
}
func (l *traceRecorder) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
if RecorderParamsFilter == nil {
return sql, params
}
return RecorderParamsFilter(ctx, sql, params...)
}
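Taken together, the new Config fields change what the logger emits: IgnoreRecordNotFoundError suppresses the error trace for record-not-found lookups, and ParameterizedQueries logs SQL with placeholders instead of parameter values. A minimal sketch of wiring a custom logger into a session; the sqlite driver and file name are illustrative only:

package main

import (
	"log"
	"os"
	"time"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)

func main() {
	newLogger := logger.New(
		log.New(os.Stdout, "\r\n", log.LstdFlags),
		logger.Config{
			SlowThreshold:             200 * time.Millisecond, // queries slower than this are logged as SLOW SQL
			LogLevel:                  logger.Info,            // trace every statement
			IgnoreRecordNotFoundError: true,                   // skip the error trace for record-not-found lookups
			ParameterizedQueries:      true,                   // log SQL with placeholders, not parameter values
			Colorful:                  false,
		},
	)

	db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{Logger: newLogger})
	if err != nil {
		log.Fatal(err)
	}
	_ = db
}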

View File

@ -9,88 +9,172 @@ import (
"strings" "strings"
"time" "time"
"unicode" "unicode"
"gorm.io/gorm/utils"
) )
func isPrintable(s []byte) bool { const (
tmFmtWithMS = "2006-01-02 15:04:05.999"
tmFmtZero = "0000-00-00 00:00:00"
nullStr = "NULL"
)
func isPrintable(s string) bool {
for _, r := range s { for _, r := range s {
if !unicode.IsPrint(rune(r)) { if !unicode.IsPrint(r) {
return false return false
} }
} }
return true return true
} }
var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} // A list of Go types that should be converted to SQL primitives
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
// numericPlaceholderRe matches numeric placeholders of the form $1$, $2$, ...
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
func isNumeric(k reflect.Kind) bool {
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
case reflect.Float32, reflect.Float64:
return true
default:
return false
}
}
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var convertParams func(interface{}, int) var (
var vars = make([]interface{}, len(avars)) convertParams func(interface{}, int)
copy(vars, avars) vars = make([]string, len(avars))
)
convertParams = func(v interface{}, idx int) { convertParams = func(v interface{}, idx int) {
switch v := v.(type) { switch v := v.(type) {
case bool: case bool:
vars[idx] = fmt.Sprint(v) vars[idx] = strconv.FormatBool(v)
case time.Time: case time.Time:
if v.IsZero() { if v.IsZero() {
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper vars[idx] = escaper + tmFmtZero + escaper
} else { } else {
vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
}
case *time.Time:
if v != nil {
if v.IsZero() {
vars[idx] = escaper + tmFmtZero + escaper
} else {
vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
}
} else {
vars[idx] = nullStr
}
case driver.Valuer:
reflectValue := reflect.ValueOf(v)
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
r, _ := v.Value()
convertParams(r, idx)
} else {
vars[idx] = nullStr
}
case fmt.Stringer:
reflectValue := reflect.ValueOf(v)
switch reflectValue.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
vars[idx] = fmt.Sprintf("%d", reflectValue.Interface())
case reflect.Float32, reflect.Float64:
vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface())
case reflect.Bool:
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
case reflect.String:
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
default:
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
} else {
vars[idx] = nullStr
}
} }
case []byte: case []byte:
if isPrintable(v) { if s := string(v); isPrintable(s) {
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
} else { } else {
vars[idx] = escaper + "<binary>" + escaper vars[idx] = escaper + "<binary>" + escaper
} }
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
vars[idx] = fmt.Sprintf("%d", v) vars[idx] = utils.ToString(v)
case float64, float32: case float32:
vars[idx] = fmt.Sprintf("%.6f", v) vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32)
case float64:
vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
case string: case string:
vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper
default: default:
if v == nil {
vars[idx] = "NULL"
} else {
rv := reflect.ValueOf(v) rv := reflect.ValueOf(v)
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
if !rv.IsValid() { vars[idx] = nullStr
vars[idx] = "NULL"
} else if rv.Kind() == reflect.Ptr && rv.IsNil() {
vars[idx] = "NULL"
} else if valuer, ok := v.(driver.Valuer); ok { } else if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value() v, _ = valuer.Value()
convertParams(v, idx) convertParams(v, idx)
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() { } else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
convertParams(reflect.Indirect(rv).Interface(), idx) convertParams(reflect.Indirect(rv).Interface(), idx)
} else if isNumeric(rv.Kind()) {
if rv.CanInt() || rv.CanUint() {
vars[idx] = fmt.Sprintf("%d", rv.Interface())
} else { } else {
for _, t := range convertableTypes { vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
}
} else {
for _, t := range convertibleTypes {
if rv.Type().ConvertibleTo(t) { if rv.Type().ConvertibleTo(t) {
convertParams(rv.Convert(t).Interface(), idx) convertParams(rv.Convert(t).Interface(), idx)
return return
} }
} }
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
}
} }
} }
} }
for idx, v := range vars { for idx, v := range avars {
convertParams(v, idx) convertParams(v, idx)
} }
if numericPlaceholder == nil { if numericPlaceholder == nil {
for _, v := range vars { var idx int
sql = strings.Replace(sql, "?", v.(string), 1) var newSQL strings.Builder
for _, v := range []byte(sql) {
if v == '?' {
if len(vars) > idx {
newSQL.WriteString(vars[idx])
idx++
continue
} }
}
newSQL.WriteByte(v)
}
sql = newSQL.String()
} else { } else {
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
for idx, v := range vars {
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v.(string), 1) sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string {
num := v[1 : len(v)-1]
n, _ := strconv.Atoi(num)
// position var start from 1 ($1, $2)
n -= 1
if n >= 0 && n <= len(vars)-1 {
return vars[n]
} }
return v
})
} }
return sql return sql
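As the doc comment above warns, ExplainSQL output is intended for logging only and must not be executed. A short sketch of the two placeholder styles the function handles:

package main

import (
	"fmt"
	"regexp"

	"gorm.io/gorm/logger"
)

func main() {
	// `?` placeholders: pass nil for numericPlaceholder.
	fmt.Println(logger.ExplainSQL(
		`SELECT * FROM users WHERE name = ? AND age > ?`, nil, `"`, "jinzhu", 18,
	))
	// => SELECT * FROM users WHERE name = "jinzhu" AND age > 18

	// Numbered placeholders such as $1, $2: pass a regexp that captures the index.
	fmt.Println(logger.ExplainSQL(
		`SELECT * FROM users WHERE name = $1 AND age > $2`, regexp.MustCompile(`\$(\d+)`), `'`, "jinzhu", 18,
	))
	// => SELECT * FROM users WHERE name = 'jinzhu' AND age > 18
}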

View File

@ -1,20 +1,54 @@
package logger_test package logger_test
import ( import (
"database/sql/driver"
"encoding/json"
"fmt"
"regexp" "regexp"
"strings"
"testing" "testing"
"github.com/jinzhu/now" "github.com/jinzhu/now"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
) )
type JSON json.RawMessage
func (j JSON) Value() (driver.Value, error) {
if len(j) == 0 {
return nil, nil
}
return json.RawMessage(j).MarshalJSON()
}
type ExampleStruct struct {
Name string
Val string
}
func (s ExampleStruct) Value() (driver.Value, error) {
return json.Marshal(s)
}
func format(v []byte, escaper string) string {
return escaper + strings.ReplaceAll(string(v), escaper, escaper+escaper) + escaper
}
func TestExplainSQL(t *testing.T) { func TestExplainSQL(t *testing.T) {
type role string type role string
type password []byte type password []byte
type intType int
type floatType float64
var ( var (
tt = now.MustParse("2020-02-23 11:10:10") tt = now.MustParse("2020-02-23 11:10:10")
myrole = role("admin") myrole = role("admin")
pwd = password([]byte("pass")) pwd = password("pass")
jsVal = []byte(`{"Name":"test","Val":"test"}`)
js = JSON(jsVal)
esVal = []byte(`{"Name":"test","Val":"test"}`)
es = ExampleStruct{Name: "test", Val: "test"}
intVal intType = 1
floatVal floatType = 1.23
) )
results := []struct { results := []struct {
@ -27,25 +61,67 @@ func TestExplainSQL(t *testing.T) {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil, NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`,
}, },
{ {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
NumericRegexp: regexp.MustCompile("@p(\\d+)"), NumericRegexp: regexp.MustCompile(`@p(\d+)`),
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
}, },
{ {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)",
NumericRegexp: regexp.MustCompile("\\$(\\d+)"), NumericRegexp: regexp.MustCompile(`\$(\d+)`),
Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd}, Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
}, },
{ {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)",
NumericRegexp: regexp.MustCompile("@p(\\d+)"), NumericRegexp: regexp.MustCompile(`@p(\d+)`),
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`,
}, },
} }

View File

@ -1,41 +1,97 @@
package gorm package gorm
import ( import (
"database/sql" "reflect"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
) )
// Migrator returns migrator // Migrator returns migrator
func (db *DB) Migrator() Migrator { func (db *DB) Migrator() Migrator {
return db.Dialector.Migrator(db) tx := db.getInstance()
// apply scopes to migrator
for len(tx.Statement.scopes) > 0 {
tx = tx.executeScopes()
}
return tx.Dialector.Migrator(tx.Session(&Session{}))
}
// AutoMigrate run auto migration for given models
func (db *DB) AutoMigrate(dst ...interface{}) error {
return db.Migrator().AutoMigrate(dst...)
} }
// ViewOption view option // ViewOption view option
type ViewOption struct { type ViewOption struct {
Replace bool Replace bool // If true, exec `CREATE OR REPLACE`; if false, exec `CREATE`
CheckOption string CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION`
Query *DB Query *DB // required subquery.
} }
// ColumnType column type interface
type ColumnType interface {
Name() string
DatabaseTypeName() string // varchar
ColumnType() (columnType string, ok bool) // varchar(64)
PrimaryKey() (isPrimaryKey bool, ok bool)
AutoIncrement() (isAutoIncrement bool, ok bool)
Length() (length int64, ok bool)
DecimalSize() (precision int64, scale int64, ok bool)
Nullable() (nullable bool, ok bool)
Unique() (unique bool, ok bool)
ScanType() reflect.Type
Comment() (value string, ok bool)
DefaultValue() (value string, ok bool)
}
type Index interface {
Table() string
Name() string
Columns() []string
PrimaryKey() (isPrimaryKey bool, ok bool)
Unique() (unique bool, ok bool)
Option() string
}
// TableType table type interface
type TableType interface {
Schema() string
Name() string
Type() string
Comment() (comment string, ok bool)
}
// Migrator migrator interface
type Migrator interface { type Migrator interface {
// AutoMigrate // AutoMigrate
AutoMigrate(dst ...interface{}) error AutoMigrate(dst ...interface{}) error
// Database // Database
CurrentDatabase() string CurrentDatabase() string
FullDataTypeOf(*schema.Field) clause.Expr
GetTypeAliases(databaseTypeName string) []string
// Tables // Tables
CreateTable(dst ...interface{}) error CreateTable(dst ...interface{}) error
DropTable(dst ...interface{}) error DropTable(dst ...interface{}) error
HasTable(dst interface{}) bool HasTable(dst interface{}) bool
RenameTable(oldName, newName interface{}) error RenameTable(oldName, newName interface{}) error
GetTables() (tableList []string, err error)
TableType(dst interface{}) (TableType, error)
// Columns // Columns
AddColumn(dst interface{}, field string) error AddColumn(dst interface{}, field string) error
DropColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error
AlterColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error
MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
HasColumn(dst interface{}, field string) bool HasColumn(dst interface{}, field string) bool
RenameColumn(dst interface{}, oldName, field string) error RenameColumn(dst interface{}, oldName, field string) error
ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) ColumnTypes(dst interface{}) ([]ColumnType, error)
// Views // Views
CreateView(name string, option ViewOption) error CreateView(name string, option ViewOption) error
@ -51,4 +107,5 @@ type Migrator interface {
DropIndex(dst interface{}, name string) error DropIndex(dst interface{}, name string) error
HasIndex(dst interface{}, name string) bool HasIndex(dst interface{}, name string) bool
RenameIndex(dst interface{}, oldName, newName string) error RenameIndex(dst interface{}, oldName, newName string) error
GetIndexes(dst interface{}) ([]Index, error)
} }
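In application code these interfaces are normally reached through AutoMigrate and db.Migrator(). A short sketch, assuming an already-opened *gorm.DB and the usual fmt and gorm imports:

type User struct {
	gorm.Model
	Name string `gorm:"size:64;uniqueIndex"`
}

func migrate(db *gorm.DB) error {
	// Create or update the users table to match the struct definition.
	if err := db.AutoMigrate(&User{}); err != nil {
		return err
	}
	// Inspect the result through the ColumnType interface defined above.
	columnTypes, err := db.Migrator().ColumnTypes(&User{})
	if err != nil {
		return err
	}
	for _, ct := range columnTypes {
		if length, ok := ct.Length(); ok {
			fmt.Printf("%s %s(%d)\n", ct.Name(), ct.DatabaseTypeName(), length)
		} else {
			fmt.Printf("%s %s\n", ct.Name(), ct.DatabaseTypeName())
		}
	}
	return nil
}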

migrator/column_type.go Normal file
View File

@ -0,0 +1,107 @@
package migrator
import (
"database/sql"
"reflect"
)
// ColumnType column type implements ColumnType interface
type ColumnType struct {
SQLColumnType *sql.ColumnType
NameValue sql.NullString
DataTypeValue sql.NullString
ColumnTypeValue sql.NullString
PrimaryKeyValue sql.NullBool
UniqueValue sql.NullBool
AutoIncrementValue sql.NullBool
LengthValue sql.NullInt64
DecimalSizeValue sql.NullInt64
ScaleValue sql.NullInt64
NullableValue sql.NullBool
ScanTypeValue reflect.Type
CommentValue sql.NullString
DefaultValueValue sql.NullString
}
// Name returns the name or alias of the column.
func (ct ColumnType) Name() string {
if ct.NameValue.Valid {
return ct.NameValue.String
}
return ct.SQLColumnType.Name()
}
// DatabaseTypeName returns the database system name of the column type. If an empty
// string is returned, then the driver type name is not supported.
// Consult your driver documentation for a list of driver data types. Length specifiers
// are not included.
// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
// "INT", and "BIGINT".
func (ct ColumnType) DatabaseTypeName() string {
if ct.DataTypeValue.Valid {
return ct.DataTypeValue.String
}
return ct.SQLColumnType.DatabaseTypeName()
}
// ColumnType returns the database type of the column. like `varchar(16)`
func (ct ColumnType) ColumnType() (columnType string, ok bool) {
return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid
}
// PrimaryKey returns the column is primary key or not.
func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) {
return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid
}
// AutoIncrement returns the column is auto increment or not.
func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) {
return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid
}
// Length returns the column type length for variable length column types
func (ct ColumnType) Length() (length int64, ok bool) {
if ct.LengthValue.Valid {
return ct.LengthValue.Int64, true
}
return ct.SQLColumnType.Length()
}
// DecimalSize returns the scale and precision of a decimal type.
func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) {
if ct.DecimalSizeValue.Valid {
return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true
}
return ct.SQLColumnType.DecimalSize()
}
// Nullable reports whether the column may be null.
func (ct ColumnType) Nullable() (nullable bool, ok bool) {
if ct.NullableValue.Valid {
return ct.NullableValue.Bool, true
}
return ct.SQLColumnType.Nullable()
}
// Unique reports whether the column may be unique.
func (ct ColumnType) Unique() (unique bool, ok bool) {
return ct.UniqueValue.Bool, ct.UniqueValue.Valid
}
// ScanType returns a Go type suitable for scanning into using Rows.Scan.
func (ct ColumnType) ScanType() reflect.Type {
if ct.ScanTypeValue != nil {
return ct.ScanTypeValue
}
return ct.SQLColumnType.ScanType()
}
// Comment returns the comment of current column.
func (ct ColumnType) Comment() (value string, ok bool) {
return ct.CommentValue.String, ct.CommentValue.Valid
}
// DefaultValue returns the default value of current column.
func (ct ColumnType) DefaultValue() (value string, ok bool) {
return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid
}
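Each getter prefers the explicitly populated Null* value and only falls back to the wrapped *sql.ColumnType. A sketch of how a dialect might fill one in — the values are illustrative, and since every accessed field is Valid, the nil SQLColumnType fallback is never reached (assumes database/sql, fmt and gorm.io/gorm/migrator imports):

ct := migrator.ColumnType{
	NameValue:       sql.NullString{String: "name", Valid: true},
	DataTypeValue:   sql.NullString{String: "varchar", Valid: true},
	ColumnTypeValue: sql.NullString{String: "varchar(64)", Valid: true},
	LengthValue:     sql.NullInt64{Int64: 64, Valid: true},
	NullableValue:   sql.NullBool{Bool: true, Valid: true},
}
if full, ok := ct.ColumnType(); ok {
	fmt.Println(ct.Name(), full) // name varchar(64)
}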

migrator/index.go Normal file
View File

@ -0,0 +1,43 @@
package migrator
import "database/sql"
// Index implements gorm.Index interface
type Index struct {
TableName string
NameValue string
ColumnList []string
PrimaryKeyValue sql.NullBool
UniqueValue sql.NullBool
OptionValue string
}
// Table return the table name of the index.
func (idx Index) Table() string {
return idx.TableName
}
// Name return the name of the index.
func (idx Index) Name() string {
return idx.NameValue
}
// Columns return the columns of the index
func (idx Index) Columns() []string {
return idx.ColumnList
}
// PrimaryKey returns the index is primary key or not.
func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) {
return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid
}
// Unique returns whether the index is unique or not.
func (idx Index) Unique() (unique bool, ok bool) {
return idx.UniqueValue.Bool, idx.UniqueValue.Valid
}
// Option return the optional attribute of the index
func (idx Index) Option() string {
return idx.OptionValue
}
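Dialect packages return values of this type from GetIndexes; callers only ever see the gorm.Index interface. A small read-side sketch, assuming an open *gorm.DB, a User model and the fmt import:

indexes, err := db.Migrator().GetIndexes(&User{})
if err != nil {
	return err
}
for _, idx := range indexes {
	unique, _ := idx.Unique()
	fmt.Printf("table=%s index=%s columns=%v unique=%t\n",
		idx.Table(), idx.Name(), idx.Columns(), unique)
}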

File diff suppressed because it is too large

migrator/table_type.go Normal file
View File

@ -0,0 +1,33 @@
package migrator
import (
"database/sql"
)
// TableType table type implements TableType interface
type TableType struct {
SchemaValue string
NameValue string
TypeValue string
CommentValue sql.NullString
}
// Schema returns the schema of the table.
func (ct TableType) Schema() string {
return ct.SchemaValue
}
// Name returns the name of the table.
func (ct TableType) Name() string {
return ct.NameValue
}
// Type returns the type of the table.
func (ct TableType) Type() string {
return ct.TypeValue
}
// Comment returns the comment of current table.
func (ct TableType) Comment() (comment string, ok bool) {
return ct.CommentValue.String, ct.CommentValue.Valid
}
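TableType is surfaced through Migrator().TableType; how much of it gets populated (the comment in particular) varies by dialect. A sketch under the same assumptions as above:

tableType, err := db.Migrator().TableType(&User{})
if err != nil {
	return err
}
fmt.Println(tableType.Schema(), tableType.Name(), tableType.Type())
if comment, ok := tableType.Comment(); ok {
	fmt.Println("table comment:", comment)
}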

View File

@ -3,7 +3,8 @@ package gorm
import "time" import "time"
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
// It may be embeded into your model or you may build your own model without it // It may be embedded into your model or you may build your own model without it
//
// type User struct { // type User struct {
// gorm.Model // gorm.Model
// } // }

prepare_stmt.go Normal file
View File

@ -0,0 +1,206 @@
package gorm
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"reflect"
"sync"
"time"
"gorm.io/gorm/internal/stmt_store"
)
type PreparedStmtDB struct {
Stmts stmt_store.Store
Mux *sync.RWMutex
ConnPool
}
// NewPreparedStmtDB creates and initializes a new instance of PreparedStmtDB.
//
// Parameters:
// - connPool: A connection pool that implements the ConnPool interface, used for managing database connections.
// - maxSize: The maximum number of prepared statements that can be stored in the statement store.
// - ttl: The time-to-live duration for each prepared statement in the store. Statements older than this duration will be automatically removed.
//
// Returns:
// - A pointer to a PreparedStmtDB instance, which manages prepared statements using the provided connection pool and configuration.
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
return &PreparedStmtDB{
ConnPool: connPool, // Assigns the provided connection pool to manage database connections.
Stmts: stmt_store.New(maxSize, ttl), // Initializes a new statement store with the specified maximum size and TTL.
Mux: &sync.RWMutex{}, // Sets up a read-write mutex for synchronizing access to the statement store.
}
}
// GetDBConn returns the underlying *sql.DB connection
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil
}
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}
return nil, ErrInvalidDB
}
// Close closes all prepared statements in the store
func (db *PreparedStmtDB) Close() {
db.Mux.Lock()
defer db.Mux.Unlock()
for _, key := range db.Stmts.Keys() {
db.Stmts.Delete(key)
}
}
// Reset Deprecated use Close instead
func (db *PreparedStmtDB) Reset() {
db.Close()
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) {
db.Mux.RLock()
if db.Stmts != nil {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock()
return stmt, stmt.Error()
}
}
db.Mux.RUnlock()
// retry
db.Mux.Lock()
if db.Stmts != nil {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock()
return stmt, stmt.Error()
}
}
return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux)
}
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
if beginner, ok := db.ConnPool.(TxBeginner); ok {
tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
}
beginner, ok := db.ConnPool.(ConnPoolBeginner)
if !ok {
return nil, ErrInvalidTransaction
}
connPool, err := beginner.BeginTx(ctx, opt)
if err != nil {
return nil, err
}
if tx, ok := connPool.(Tx); ok {
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
}
return nil, ErrInvalidTransaction
}
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
result, err = stmt.ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
db.Stmts.Delete(query)
}
}
return result, err
}
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
db.Stmts.Delete(query)
}
}
return rows, err
}
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
return stmt.QueryRowContext(ctx, args...)
}
return &sql.Row{}
}
func (db *PreparedStmtDB) Ping() error {
conn, err := db.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}
type PreparedStmtTX struct {
Tx
PreparedStmtDB *PreparedStmtDB
}
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
return db.PreparedStmtDB.GetDBConn()
}
func (tx *PreparedStmtTX) Commit() error {
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Commit()
}
return ErrInvalidTransaction
}
func (tx *PreparedStmtTX) Rollback() error {
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Rollback()
}
return ErrInvalidTransaction
}
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Stmts.Delete(query)
}
}
return result, err
}
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Stmts.Delete(query)
}
}
return rows, err
}
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...)
}
return &sql.Row{}
}
func (tx *PreparedStmtTX) Ping() error {
conn, err := tx.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}
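In ordinary application code this machinery is switched on by opening GORM with prepared statements enabled; the store can also be wrapped around a plain *sql.DB directly through the constructor shown above. A sketch of the direct route — the driver name, DSN and cache sizing are placeholders:

func openWithStmtCache(ctx context.Context, dsn string) error {
	sqlDB, err := sql.Open("mysql", dsn) // requires the matching driver import
	if err != nil {
		return err
	}
	// Cache at most 256 prepared statements, each kept for up to an hour.
	stmtDB := gorm.NewPreparedStmtDB(sqlDB, 256, time.Hour)
	defer stmtDB.Close()

	rows, err := stmtDB.QueryContext(ctx, "SELECT id, name FROM users WHERE age > ?", 18)
	if err != nil {
		return err
	}
	defer rows.Close()
	return rows.Err()
}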

scan.go
View File

@ -2,148 +2,368 @@ package gorm
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"reflect" "reflect"
"strings" "strings"
"time"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
func Scan(rows *sql.Rows, db *DB, initialized bool) { // prepareValues prepare values slice
columns, _ := rows.Columns() func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
values := make([]interface{}, len(columns)) if db.Statement.Schema != nil {
for idx, name := range columns {
if field := db.Statement.Schema.LookUpField(name); field != nil {
values[idx] = reflect.New(reflect.PointerTo(field.FieldType)).Interface()
continue
}
values[idx] = new(interface{})
}
} else if len(columnTypes) > 0 {
for idx, columnType := range columnTypes {
if columnType.ScanType() != nil {
values[idx] = reflect.New(reflect.PointerTo(columnType.ScanType())).Interface()
} else {
values[idx] = new(interface{})
}
}
} else {
for idx := range columns {
values[idx] = new(interface{})
}
}
}
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
for idx, column := range columns {
if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
mapValue[column] = reflectValue.Interface()
if valuer, ok := mapValue[column].(driver.Valuer); ok {
mapValue[column], _ = valuer.Value()
} else if b, ok := mapValue[column].(sql.RawBytes); ok {
mapValue[column] = string(b)
}
} else {
mapValue[column] = nil
}
}
}
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
for idx, field := range fields {
if field != nil {
values[idx] = field.NewValuePool.Get()
} else if len(fields) == 1 {
if reflectValue.CanAddr() {
values[idx] = reflectValue.Addr().Interface()
} else {
values[idx] = reflectValue.Interface()
}
}
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
joinedNestedSchemaMap := make(map[string]interface{})
for idx, field := range fields {
if field == nil {
continue
}
if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
} else { // joinFields count is larger than 2 when using join
var isNilPtrValue bool
var relValue reflect.Value
// does not contain raw dbname
nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
// current reflect value
currentReflectValue := reflectValue
fullRels := make([]string, 0, len(nestedJoinSchemas))
for _, joinSchema := range nestedJoinSchemas {
fullRels = append(fullRels, joinSchema.Name)
relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
if relValue.Kind() == reflect.Ptr {
fullRelsName := utils.JoinNestedRelationNames(fullRels)
// same nested structure
if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
isNilPtrValue = true
break
}
relValue.Set(reflect.New(relValue.Type().Elem()))
joinedNestedSchemaMap[fullRelsName] = nil
}
}
currentReflectValue = relValue
}
if !isNilPtrValue { // ignore if value is nil
f := joinFields[idx][len(joinFields[idx])-1]
db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
}
}
// release data to pool
field.NewValuePool.Put(values[idx])
}
}
// ScanMode scan data mode
type ScanMode uint8
// scan modes
const (
ScanInitialized ScanMode = 1 << 0 // 1
ScanUpdate ScanMode = 1 << 1 // 2
ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
)
// Scan scan rows into db statement
func Scan(rows Rows, db *DB, mode ScanMode) {
var (
columns, _ = rows.Columns()
values = make([]interface{}, len(columns))
initialized = mode&ScanInitialized != 0
update = mode&ScanUpdate != 0
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
)
if len(db.Statement.ColumnMapping) > 0 {
for i, column := range columns {
v, ok := db.Statement.ColumnMapping[column]
if ok {
columns[i] = v
}
}
}
db.RowsAffected = 0
switch dest := db.Statement.Dest.(type) { switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}: case map[string]interface{}, *map[string]interface{}:
for idx, _ := range columns {
values[idx] = new(interface{})
}
if initialized || rows.Next() { if initialized || rows.Next() {
columnTypes, _ := rows.ColumnTypes()
prepareValues(values, db, columnTypes, columns)
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
}
mapValue, ok := dest.(map[string]interface{}) mapValue, ok := dest.(map[string]interface{})
if ok { if !ok {
if v, ok := dest.(*map[string]interface{}); ok { if v, ok := dest.(*map[string]interface{}); ok {
if *v == nil {
*v = map[string]interface{}{}
}
mapValue = *v mapValue = *v
} }
} }
scanIntoMap(mapValue, values, columns)
for idx, column := range columns {
mapValue[column] = *(values[idx].(*interface{}))
} }
case *[]map[string]interface{}: case *[]map[string]interface{}:
for idx, _ := range columns { columnTypes, _ := rows.ColumnTypes()
values[idx] = new(interface{})
}
for initialized || rows.Next() { for initialized || rows.Next() {
prepareValues(values, db, columnTypes, columns)
initialized = false initialized = false
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
v := map[string]interface{}{} mapValue := map[string]interface{}{}
for idx, column := range columns { scanIntoMap(mapValue, values, columns)
v[column] = *(values[idx].(*interface{})) *dest = append(*dest, mapValue)
} }
*dest = append(*dest, v) case *int, *int8, *int16, *int32, *int64,
} *uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
case *int, *int64, *uint, *uint64: *float32, *float64,
*bool, *string, *time.Time,
*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
*sql.NullBool, *sql.NullString, *sql.NullTime:
for initialized || rows.Next() { for initialized || rows.Next() {
initialized = false initialized = false
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(dest)) db.AddError(rows.Scan(dest))
} }
default: default:
switch db.Statement.ReflectValue.Kind() { var (
case reflect.Slice, reflect.Array: fields = make([]*schema.Field, len(columns))
reflectValueType := db.Statement.ReflectValue.Type().Elem() joinFields [][]*schema.Field
sch = db.Statement.Schema
reflectValue = db.Statement.ReflectValue
)
if reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem()
}
reflectValueType := reflectValue.Type()
switch reflectValueType.Kind() {
case reflect.Array, reflect.Slice:
reflectValueType = reflectValueType.Elem()
}
isPtr := reflectValueType.Kind() == reflect.Ptr isPtr := reflectValueType.Kind() == reflect.Ptr
if isPtr { if isPtr {
reflectValueType = reflectValueType.Elem() reflectValueType = reflectValueType.Elem()
} }
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) if sch != nil {
fields := make([]*schema.Field, len(columns)) if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
joinFields := make([][2]*schema.Field, len(columns)) sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
if len(columns) == 1 {
// Is Pluck
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
sch = nil
}
}
// Not Pluck
if sch != nil {
matchedFieldCount := make(map[string]int, len(columns))
for idx, column := range columns { for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { if field := sch.LookUpField(column); field != nil && field.Readable {
fields[idx] = field fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 { if count, ok := matchedFieldCount[column]; ok {
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { // handle duplicate fields
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { for _, selectField := range sch.Fields {
joinFields[idx] = [2]*schema.Field{rel.Field, field} if selectField.DBName == column && selectField.Readable {
if count == 0 {
matchedFieldCount[column]++
fields[idx] = selectField
break
}
count--
}
}
} else {
matchedFieldCount[column] = 1
}
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1])
for _, join := range db.Statement.Joins {
if join.Alias == aliasName {
names = append(strings.Split(join.Name, "."), names[len(names)-1])
break
}
}
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
subNameCount := len(names)
// nested relation fields
relFields := make([]*schema.Field, 0, subNameCount-1)
relFields = append(relFields, rel.Field)
for _, name := range names[1 : subNameCount-1] {
rel = rel.FieldSchema.Relationships.Relations[name]
relFields = append(relFields, rel.Field)
}
// latest name is raw dbname
dbName := names[subNameCount-1]
if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][]*schema.Field, len(columns))
}
relFields = append(relFields, field)
joinFields[idx] = relFields
continue continue
} }
} }
values[idx] = &sql.RawBytes{} var val interface{}
values[idx] = &val
} else { } else {
values[idx] = &sql.RawBytes{} var val interface{}
values[idx] = &val
}
}
}
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
elem reflect.Value
isArrayKind = reflectValue.Kind() == reflect.Array
)
if !update || reflectValue.Len() == 0 {
update = false
if isArrayKind {
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
} else {
// if the slice was initialized with capacity by the caller, reuse it directly instead of allocating a new one
if reflectValue.Cap() == 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
} else {
reflectValue.SetLen(0)
db.Statement.ReflectValue.Set(reflectValue)
}
} }
} }
for initialized || rows.Next() { for initialized || rows.Next() {
BEGIN:
initialized = false initialized = false
elem := reflect.New(reflectValueType).Elem()
if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 { if update {
values[0] = elem.Addr().Interface() if int(db.RowsAffected) >= reflectValue.Len() {
} else { return
for idx, field := range fields {
if field != nil {
values[idx] = field.ReflectValueOf(elem).Addr().Interface()
} else if joinFields[idx][0] != nil {
relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
relValue.Set(reflect.New(relValue.Type().Elem()))
} }
elem = reflectValue.Index(int(db.RowsAffected))
values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface() if onConflictDonothing {
} for _, field := range fields {
} if _, ok := field.ValueOf(db.Statement.Context, elem); !ok {
}
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) goto BEGIN
}
if isPtr { }
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr())) }
} else { } else {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem)) elem = reflect.New(reflectValueType)
}
}
case reflect.Struct:
for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
relValue.Set(reflect.New(relValue.Type().Elem()))
} }
values[idx] = field.ReflectValueOf(relValue).Addr().Interface() db.scanIntoStruct(rows, elem, values, fields, joinFields)
continue
if !update {
if !isPtr {
elem = elem.Elem()
} }
if isArrayKind {
if reflectValue.Len() >= int(db.RowsAffected) {
reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
} }
values[idx] = &sql.RawBytes{}
} else { } else {
values[idx] = &sql.RawBytes{} reflectValue = reflect.Append(reflectValue, elem)
}
} }
} }
if !update {
db.Statement.ReflectValue.Set(reflectValue)
}
case reflect.Struct, reflect.Ptr:
if initialized || rows.Next() { if initialized || rows.Next() {
db.RowsAffected++ if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct {
db.AddError(rows.Scan(values...)) db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
} }
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
}
default:
db.AddError(rows.Scan(dest))
} }
} }
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { if err := rows.Err(); err != nil && err != db.Error {
db.AddError(err)
}
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
db.AddError(ErrRecordNotFound) db.AddError(ErrRecordNotFound)
} }
} }
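Outside GORM's own callbacks, the usual way into this code is (*DB).ScanRows, which calls Scan with ScanInitialized once the caller has positioned the cursor. A sketch, assuming a User model, an open *gorm.DB and the fmt import:

func listAdults(db *gorm.DB) error {
	rows, err := db.Raw("SELECT id, name, age FROM users WHERE age > ?", 18).Rows()
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		var user User
		// ScanRows scans the row the cursor is already on (ScanInitialized),
		// so Scan does not advance the cursor before the first scan.
		if err := db.ScanRows(rows, &user); err != nil {
			return err
		}
		fmt.Println(user.ID, user.Name)
	}
	return rows.Err()
}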

View File

@ -9,8 +9,7 @@ import (
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
type UserWithCallback struct { type UserWithCallback struct{}
}
func (UserWithCallback) BeforeSave(*gorm.DB) error { func (UserWithCallback) BeforeSave(*gorm.DB) error {
return nil return nil

View File

@ -1,32 +0,0 @@
package schema
import (
"regexp"
"strings"
)
type Check struct {
Name string
Constraint string // length(phone) >= 10
*Field
}
// ParseCheckConstraints parse schema check constraints
func (schema *Schema) ParseCheckConstraints() map[string]Check {
var checks = map[string]Check{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(names[0]) {
checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = Check{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}

schema/constraint.go Normal file
View File

@ -0,0 +1,66 @@
package schema
import (
"regexp"
"strings"
"gorm.io/gorm/clause"
)
// regEnLetterAndMidline matches word characters (letters, digits, underscore) and hyphens
var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`)
type CheckConstraint struct {
Name string
Constraint string // length(phone) >= 10
*Field
}
func (chk *CheckConstraint) GetName() string { return chk.Name }
func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
}
// ParseCheckConstraints parse schema check constraints
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
checks := map[string]CheckConstraint{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}
type UniqueConstraint struct {
Name string
Field *Field
}
func (uni *UniqueConstraint) GetName() string { return uni.Name }
func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
}
// ParseUniqueConstraints parse schema unique constraints
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
uniques := make(map[string]UniqueConstraint)
for _, field := range schema.Fields {
if field.Unique {
name := schema.namer.UniqueName(schema.Table, field.DBName)
uniques[name] = UniqueConstraint{Name: name, Field: field}
}
}
return uniques
}
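Both constraint kinds come straight from struct tags: a `check` tag becomes a CheckConstraint (named explicitly, or via namer.CheckerName when no name is given), and `unique` on a field becomes a UniqueConstraint named by the naming strategy, which is where the uni_* names in the test below come from. A small tag sketch:

type Product struct {
	gorm.Model
	Code  string `gorm:"unique"`                                  // -> CONSTRAINT uni_products_code UNIQUE (code)
	Phone string `gorm:"check:phone_checker,length(phone) >= 10"` // explicitly named check constraint
	Age   int    `gorm:"check:age >= 0"`                          // name generated by namer.CheckerName
}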

View File

@ -6,6 +6,7 @@ import (
"testing" "testing"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
) )
type UserCheck struct { type UserCheck struct {
@ -20,7 +21,7 @@ func TestParseCheck(t *testing.T) {
t.Fatalf("failed to parse user check, got error %v", err) t.Fatalf("failed to parse user check, got error %v", err)
} }
results := map[string]schema.Check{ results := map[string]schema.CheckConstraint{
"name_checker": { "name_checker": {
Name: "name_checker", Name: "name_checker",
Constraint: "name <> 'jinzhu'", Constraint: "name <> 'jinzhu'",
@ -53,3 +54,31 @@ func TestParseCheck(t *testing.T) {
} }
} }
} }
func TestParseUniqueConstraints(t *testing.T) {
type UserUnique struct {
Name1 string `gorm:"unique"`
Name2 string `gorm:"uniqueIndex"`
}
user, err := schema.Parse(&UserUnique{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user unique, got error %v", err)
}
constraints := user.ParseUniqueConstraints()
results := map[string]schema.UniqueConstraint{
"uni_user_uniques_name1": {
Name: "uni_user_uniques_name1",
Field: &schema.Field{Name: "Name1", Unique: true},
},
}
for k, result := range results {
v, ok := constraints[k]
if !ok {
t.Errorf("Failed to found unique constraint %v from parsed constraints %+v", k, constraints)
}
tests.AssertObjEqual(t, result, v, "Name")
tests.AssertObjEqual(t, result.Field, v.Field, "Name", "Unique", "UniqueIndex")
}
}

File diff suppressed because it is too large

View File

@ -1,6 +1,7 @@
package schema_test package schema_test
import ( import (
"context"
"database/sql" "database/sql"
"reflect" "reflect"
"sync" "sync"
@ -19,6 +20,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
Model: gorm.Model{ Model: gorm.Model{
ID: 10, ID: 10,
CreatedAt: time.Now(), CreatedAt: time.Now(),
UpdatedAt: time.Now(),
DeletedAt: gorm.DeletedAt{Time: time.Now(), Valid: true}, DeletedAt: gorm.DeletedAt{Time: time.Now(), Valid: true},
}, },
Name: "valuer_and_setter", Name: "valuer_and_setter",
@ -34,6 +36,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
"name": user.Name, "name": user.Name,
"id": user.ID, "id": user.ID,
"created_at": user.CreatedAt, "created_at": user.CreatedAt,
"updated_at": user.UpdatedAt,
"deleted_at": user.DeletedAt, "deleted_at": user.DeletedAt,
"age": user.Age, "age": user.Age,
"birthday": user.Birthday, "birthday": user.Birthday,
@ -41,30 +44,36 @@ func TestFieldValuerAndSetter(t *testing.T) {
} }
checkField(t, userSchema, reflectValue, values) checkField(t, userSchema, reflectValue, values)
var f *bool
// test setter // test setter
newValues := map[string]interface{}{ newValues := map[string]interface{}{
"name": "valuer_and_setter_2", "name": "valuer_and_setter_2",
"id": 2, "id": 2,
"created_at": time.Now(), "created_at": time.Now(),
"updated_at": nil,
"deleted_at": time.Now(), "deleted_at": time.Now(),
"age": 20, "age": 20,
"birthday": time.Now(), "birthday": time.Now(),
"active": false, "active": f,
} }
for k, v := range newValues { for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
newValues["updated_at"] = time.Time{}
newValues["active"] = false
checkField(t, userSchema, reflectValue, newValues) checkField(t, userSchema, reflectValue, newValues)
// test valuer and other type // test valuer and other type
age := myint(10) age := myint(10)
var nilTime *time.Time
newValues2 := map[string]interface{}{ newValues2 := map[string]interface{}{
"name": sql.NullString{String: "valuer_and_setter_3", Valid: true}, "name": sql.NullString{String: "valuer_and_setter_3", Valid: true},
"id": &sql.NullInt64{Int64: 3, Valid: true}, "id": &sql.NullInt64{Int64: 3, Valid: true},
"created_at": tests.Now(), "created_at": tests.Now(),
"updated_at": nilTime,
"deleted_at": time.Now(), "deleted_at": time.Now(),
"age": &age, "age": &age,
"birthday": mytime(time.Now()), "birthday": mytime(time.Now()),
@ -72,10 +81,11 @@ func TestFieldValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues2 { for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
newValues2["updated_at"] = time.Time{}
checkField(t, userSchema, reflectValue, newValues2) checkField(t, userSchema, reflectValue, newValues2)
} }
@ -123,7 +133,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues { for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -142,7 +152,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues2 { for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -193,7 +203,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues { for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -210,7 +220,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues2 { for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -225,6 +235,8 @@ type UserWithPermissionControl struct {
Name4 string `gorm:"<-:create"` Name4 string `gorm:"<-:create"`
Name5 string `gorm:"<-:update"` Name5 string `gorm:"<-:update"`
Name6 string `gorm:"<-:create,update"` Name6 string `gorm:"<-:create,update"`
Name7 string `gorm:"->:false;<-:create,update"`
Name8 string `gorm:"->;-:migration"`
} }
func TestParseFieldWithPermission(t *testing.T) { func TestParseFieldWithPermission(t *testing.T) {
@ -233,17 +245,90 @@ func TestParseFieldWithPermission(t *testing.T) {
t.Fatalf("Failed to parse user with permission, got error %v", err) t.Fatalf("Failed to parse user with permission, got error %v", err)
} }
fields := []schema.Field{ fields := []*schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true}, {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true, AutoIncrement: true},
{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String, Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false}, {Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false},
{Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true}, {Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true},
{Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: false}, {Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: true},
{Name: "Name4", DBName: "name4", BindNames: []string{"Name4"}, DataType: schema.String, Tag: `gorm:"<-:create"`, Creatable: true, Updatable: false, Readable: false}, {Name: "Name4", DBName: "name4", BindNames: []string{"Name4"}, DataType: schema.String, Tag: `gorm:"<-:create"`, Creatable: true, Updatable: false, Readable: true},
{Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: false}, {Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: true},
{Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: false}, {Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: true},
{Name: "Name7", DBName: "name7", BindNames: []string{"Name7"}, DataType: schema.String, Tag: `gorm:"->:false;<-:create,update"`, Creatable: true, Updatable: true, Readable: false},
{Name: "Name8", DBName: "name8", BindNames: []string{"Name8"}, DataType: schema.String, Tag: `gorm:"->;-:migration"`, Creatable: false, Updatable: false, Readable: true, IgnoreMigration: true},
} }
for _, f := range fields { for _, f := range fields {
checkSchemaField(t, user, &f, func(f *schema.Field) {}) checkSchemaField(t, user, f, func(f *schema.Field) {})
}
}
type (
ID int64
INT int
INT8 int8
INT16 int16
INT32 int32
INT64 int64
UINT uint
UINT8 uint8
UINT16 uint16
UINT32 uint32
UINT64 uint64
FLOAT32 float32
FLOAT64 float64
BOOL bool
STRING string
TIME time.Time
BYTES []byte
TypeAlias struct {
ID
INT `gorm:"column:fint"`
INT8 `gorm:"column:fint8"`
INT16 `gorm:"column:fint16"`
INT32 `gorm:"column:fint32"`
INT64 `gorm:"column:fint64"`
UINT `gorm:"column:fuint"`
UINT8 `gorm:"column:fuint8"`
UINT16 `gorm:"column:fuint16"`
UINT32 `gorm:"column:fuint32"`
UINT64 `gorm:"column:fuint64"`
FLOAT32 `gorm:"column:ffloat32"`
FLOAT64 `gorm:"column:ffloat64"`
BOOL `gorm:"column:fbool"`
STRING `gorm:"column:fstring"`
TIME `gorm:"column:ftime"`
BYTES `gorm:"column:fbytes"`
}
)
func TestTypeAliasField(t *testing.T) {
alias, err := schema.Parse(&TypeAlias{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse TypeAlias with permission, got error %v", err)
}
fields := []*schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, PrimaryKey: true, HasDefaultValue: true, AutoIncrement: true},
{Name: "INT", DBName: "fint", BindNames: []string{"INT"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint"`},
{Name: "INT8", DBName: "fint8", BindNames: []string{"INT8"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fint8"`},
{Name: "INT16", DBName: "fint16", BindNames: []string{"INT16"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fint16"`},
{Name: "INT32", DBName: "fint32", BindNames: []string{"INT32"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fint32"`},
{Name: "INT64", DBName: "fint64", BindNames: []string{"INT64"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint64"`},
{Name: "UINT", DBName: "fuint", BindNames: []string{"UINT"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint"`},
{Name: "UINT8", DBName: "fuint8", BindNames: []string{"UINT8"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fuint8"`},
{Name: "UINT16", DBName: "fuint16", BindNames: []string{"UINT16"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fuint16"`},
{Name: "UINT32", DBName: "fuint32", BindNames: []string{"UINT32"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fuint32"`},
{Name: "UINT64", DBName: "fuint64", BindNames: []string{"UINT64"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint64"`},
{Name: "FLOAT32", DBName: "ffloat32", BindNames: []string{"FLOAT32"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:ffloat32"`},
{Name: "FLOAT64", DBName: "ffloat64", BindNames: []string{"FLOAT64"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:ffloat64"`},
{Name: "BOOL", DBName: "fbool", BindNames: []string{"BOOL"}, DataType: schema.Bool, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbool"`},
{Name: "STRING", DBName: "fstring", BindNames: []string{"STRING"}, DataType: schema.String, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fstring"`},
{Name: "TIME", DBName: "ftime", BindNames: []string{"TIME"}, DataType: schema.Time, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:ftime"`},
{Name: "BYTES", DBName: "fbytes", BindNames: []string{"BYTES"}, DataType: schema.Bytes, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbytes"`},
}
for _, f := range fields {
checkSchemaField(t, alias, f, func(f *schema.Field) {})
} }
} }
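
The field tests above now thread a context.Context through the field setter and valuer. A minimal sketch of the new signatures, assuming a hypothetical Note model (not part of the diff):

package main

import (
	"context"
	"fmt"
	"reflect"
	"sync"

	"gorm.io/gorm/schema"
)

// Note is a hypothetical model used only to show the context-aware
// Set / ValueOf calls exercised by the tests above.
type Note struct {
	ID    uint
	Title string
}

func main() {
	s, err := schema.Parse(&Note{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		panic(err)
	}

	note := Note{}
	rv := reflect.ValueOf(&note)
	ctx := context.Background()

	// Both Set and ValueOf take a context.Context as their first argument.
	field := s.FieldsByDBName["title"]
	if err := field.Set(ctx, rv, "hello"); err != nil {
		panic(err)
	}
	value, isZero := field.ValueOf(ctx, rv)
	fmt.Println(value, isZero) // hello false
}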

View File

@ -1,6 +1,8 @@
package schema package schema
import ( import (
"fmt"
"sort"
"strconv" "strconv"
"strings" "strings"
) )
@ -11,7 +13,8 @@ type Index struct {
Type string // btree, hash, gist, spgist, gin, and brin Type string // btree, hash, gist, spgist, gin, and brin
Where string Where string
Comment string Comment string
Fields []IndexOption Option string // WITH PARSER parser_name
Fields []IndexOption // Note: IndexOption's Field maybe the same
} }
type IndexOption struct { type IndexOption struct {
@ -20,16 +23,28 @@ type IndexOption struct {
Sort string // DESC, ASC Sort string // DESC, ASC
Collate string Collate string
Length int Length int
Priority int
} }
// ParseIndexes parse schema indexes // ParseIndexes parse schema indexes
func (schema *Schema) ParseIndexes() map[string]Index { func (schema *Schema) ParseIndexes() []*Index {
var indexes = map[string]Index{} indexesByName := map[string]*Index{}
indexes := []*Index{}
for _, field := range schema.Fields { for _, field := range schema.Fields {
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE_INDEX"] != "" { if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
for _, index := range parseFieldIndexes(field) { fieldIndexes, err := parseFieldIndexes(field)
idx := indexes[index.Name] if err != nil {
schema.err = err
break
}
for _, index := range fieldIndexes {
idx := indexesByName[index.Name]
if idx == nil {
idx = &Index{Name: index.Name}
indexesByName[index.Name] = idx
indexes = append(indexes, idx)
}
idx.Name = index.Name idx.Name = index.Name
if idx.Class == "" { if idx.Class == "" {
idx.Class = index.Class idx.Class = index.Class
@ -43,25 +58,37 @@ func (schema *Schema) ParseIndexes() map[string]Index {
if idx.Comment == "" { if idx.Comment == "" {
idx.Comment = index.Comment idx.Comment = index.Comment
} }
idx.Fields = append(idx.Fields, index.Fields...) if idx.Option == "" {
indexes[index.Name] = idx idx.Option = index.Option
}
}
} }
idx.Fields = append(idx.Fields, index.Fields...)
sort.Slice(idx.Fields, func(i, j int) bool {
return idx.Fields[i].Priority < idx.Fields[j].Priority
})
}
}
}
for _, index := range indexes {
if index.Class == "UNIQUE" && len(index.Fields) == 1 {
index.Fields[0].Field.UniqueIndex = index.Name
}
}
return indexes return indexes
} }
func (schema *Schema) LookIndex(name string) *Index { func (schema *Schema) LookIndex(name string) *Index {
if schema != nil {
indexes := schema.ParseIndexes() indexes := schema.ParseIndexes()
for _, index := range indexes { for _, index := range indexes {
if index.Name == name { if index.Name == name {
return &index return index
} }
for _, field := range index.Fields { for _, field := range index.Fields {
if field.Name == name { if field.Name == name {
return &index return index
}
} }
} }
} }
@ -69,17 +96,18 @@ func (schema *Schema) LookIndex(name string) *Index {
return nil return nil
} }
func parseFieldIndexes(field *Field) (indexes []Index) { func parseFieldIndexes(field *Field) (indexes []Index, err error) {
for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
if value != "" { if value != "" {
v := strings.Split(value, ":") v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0])) k := strings.TrimSpace(strings.ToUpper(v[0]))
if k == "INDEX" || k == "UNIQUE_INDEX" { if k == "INDEX" || k == "UNIQUEINDEX" {
var ( var (
name string name string
tag = strings.Join(v[1:], ":") tag = strings.Join(v[1:], ":")
idx = strings.Index(tag, ",") idx = strings.IndexByte(tag, ',')
settings = ParseTagSetting(tag, ",") tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",")
settings = ParseTagSetting(tagSetting, ",")
length, _ = strconv.Atoi(settings["LENGTH"]) length, _ = strconv.Atoi(settings["LENGTH"])
) )
@ -87,35 +115,53 @@ func parseFieldIndexes(field *Field) (indexes []Index) {
idx = len(tag) idx = len(tag)
} }
if idx != -1 {
name = tag[0:idx] name = tag[0:idx]
}
if name == "" { if name == "" {
name = field.Schema.namer.IndexName(field.Schema.Table, field.Name) subName := field.Name
const key = "COMPOSITE"
if composite, found := settings[key]; found {
if len(composite) == 0 || composite == key {
err = fmt.Errorf(
"the composite tag of %s.%s cannot be empty",
field.Schema.Name,
field.Name)
return
}
subName = composite
}
name = field.Schema.namer.IndexName(
field.Schema.Table, subName)
} }
if (k == "UNIQUE_INDEX") || settings["UNIQUE"] != "" { if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" {
settings["CLASS"] = "UNIQUE" settings["CLASS"] = "UNIQUE"
} }
priority, err := strconv.Atoi(settings["PRIORITY"])
if err != nil {
priority = 10
}
indexes = append(indexes, Index{ indexes = append(indexes, Index{
Name: name, Name: name,
Class: settings["CLASS"], Class: settings["CLASS"],
Type: settings["TYPE"], Type: settings["TYPE"],
Where: settings["WHERE"], Where: settings["WHERE"],
Comment: settings["COMMENT"], Comment: settings["COMMENT"],
Option: settings["OPTION"],
Fields: []IndexOption{{ Fields: []IndexOption{{
Field: field, Field: field,
Expression: settings["EXPRESSION"], Expression: settings["EXPRESSION"],
Sort: settings["SORT"], Sort: settings["SORT"],
Collate: settings["COLLATE"], Collate: settings["COLLATE"],
Length: length, Length: length,
Priority: priority,
}}, }},
}) })
} }
} }
} }
err = nil
return return
} }
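
The reworked parseFieldIndexes above adds the composite and priority tag settings. A small sketch of how they combine, using an invented Account model (not part of the diff): the two composite fields share one index, and the lower priority column is emitted first.

package main

import (
	"fmt"
	"sync"

	"gorm.io/gorm/schema"
)

// Account is a hypothetical model illustrating composite and priority settings.
type Account struct {
	ID    uint
	Email string `gorm:"uniqueIndex"`                            // single-column UNIQUE index
	First string `gorm:"index:,composite:full_name,priority:2"`  // shares one index with Last
	Last  string `gorm:"index:,composite:full_name,priority:1"`  // lower priority sorts first
}

func main() {
	s, err := schema.Parse(&Account{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		panic(err)
	}
	for _, idx := range s.ParseIndexes() {
		cols := []string{}
		for _, opt := range idx.Fields {
			cols = append(cols, opt.Field.DBName)
		}
		// e.g. idx_accounts_email (UNIQUE): [email]
		//      idx_accounts_full_name (): [last first]
		fmt.Printf("%s (%s): %v\n", idx.Name, idx.Class, cols)
	}
}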

View File

@ -1,23 +1,58 @@
package schema_test package schema_test
import ( import (
"reflect"
"sync" "sync"
"testing" "testing"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
) )
type UserIndex struct { type UserIndex struct {
Name string `gorm:"index"` Name string `gorm:"index"`
Name2 string `gorm:"index:idx_name,unique"` Name2 string `gorm:"index:idx_name,unique"`
Name3 string `gorm:"index:,sort:desc,collate:utf8,type:btree,length:10,where:name3 != 'jinzhu'"` Name3 string `gorm:"index:,sort:desc,collate:utf8,type:btree,length:10,where:name3 != 'jinzhu'"`
Name4 string `gorm:"unique_index"` Name4 string `gorm:"uniqueIndex"`
Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"` Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"`
Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"` Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"`
Age int64 `gorm:"index:profile,expression:ABS(age)"` Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"`
OID int64 `gorm:"index:idx_id"` OID int64 `gorm:"index:idx_id;index:idx_oid,unique"`
MemberNumber string `gorm:"index:idx_id"` MemberNumber string `gorm:"index:idx_id,priority:1"`
Name7 string `gorm:"index:type"`
Name8 string `gorm:"index:,length:10;index:,collate:utf8"`
CompName1 string `gorm:"index:,unique,composite:idx_compname_1,option:NULLS NOT DISTINCT;not null"`
CompName2 string `gorm:"index:,composite:idx_compname_1"`
// Composite Index: Flattened structure.
Data0A string `gorm:"index:,composite:comp_id0"`
Data0B string `gorm:"index:,composite:comp_id0"`
// Composite Index: Nested structure.
Data1A string `gorm:"index:,composite:comp_id1"`
CompIdxLevel1C
// Composite Index: Unique and priority.
Data2A string `gorm:"index:,unique,composite:comp_id2,priority:2"`
CompIdxLevel2C
}
type CompIdxLevel1C struct {
CompIdxLevel1B
Data1C string `gorm:"index:,composite:comp_id1"`
}
type CompIdxLevel1B struct {
Data1B string `gorm:"index:,composite:comp_id1"`
}
type CompIdxLevel2C struct {
CompIdxLevel2B
Data2C string `gorm:"index:,unique,composite:comp_id2,priority:1"`
}
type CompIdxLevel2B struct {
Data2B string `gorm:"index:,unique,composite:comp_id2,priority:3"`
} }
func TestParseIndex(t *testing.T) { func TestParseIndex(t *testing.T) {
@ -26,79 +61,215 @@ func TestParseIndex(t *testing.T) {
t.Fatalf("failed to parse user index, got error %v", err) t.Fatalf("failed to parse user index, got error %v", err)
} }
results := map[string]schema.Index{ results := []*schema.Index{
"idx_user_indices_name": { {
Name: "idx_user_indices_name", Name: "idx_user_indices_name",
Fields: []schema.IndexOption{{}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name"}}},
}, },
"idx_name": { {
Name: "idx_name", Name: "idx_name",
Class: "UNIQUE", Class: "UNIQUE",
Fields: []schema.IndexOption{{}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: "idx_name"}}},
}, },
"idx_user_indices_name3": { {
Name: "idx_user_indices_name3", Name: "idx_user_indices_name3",
Type: "btree", Type: "btree",
Where: "name3 != 'jinzhu'", Where: "name3 != 'jinzhu'",
Fields: []schema.IndexOption{{ Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Name3"},
Sort: "desc", Sort: "desc",
Collate: "utf8", Collate: "utf8",
Length: 10, Length: 10,
}}, }},
}, },
"idx_user_indices_name4": { {
Name: "idx_user_indices_name4", Name: "idx_user_indices_name4",
Class: "UNIQUE", Class: "UNIQUE",
Fields: []schema.IndexOption{{}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: "idx_user_indices_name4"}}},
}, },
"idx_user_indices_name5": { {
Name: "idx_user_indices_name5", Name: "idx_user_indices_name5",
Class: "FULLTEXT", Class: "FULLTEXT",
Comment: "hello , world", Comment: "hello , world",
Where: "age > 10", Where: "age > 10",
Fields: []schema.IndexOption{{}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name5"}}},
}, },
"profile": { {
Name: "profile", Name: "profile",
Comment: "hello , world", Comment: "hello , world",
Where: "age > 10", Where: "age > 10",
Fields: []schema.IndexOption{{}, { Option: "WITH PARSER parser_name",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name6"}}, {
Field: &schema.Field{Name: "Age"},
Expression: "ABS(age)", Expression: "ABS(age)",
}}, }},
}, },
"idx_id": { {
Name: "idx_id", Name: "idx_id",
Fields: []schema.IndexOption{{}, {}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
},
{
Name: "idx_oid",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
},
{
Name: "type",
Type: "",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}},
},
{
Name: "idx_user_indices_name8",
Type: "",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "Name8"}, Length: 10},
// Note: Duplicate Columns
{Field: &schema.Field{Name: "Name8"}, Collate: "utf8"},
},
},
{
Class: "UNIQUE",
Name: "idx_user_indices_idx_compname_1",
Option: "NULLS NOT DISTINCT",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "CompName1", NotNull: true}},
{Field: &schema.Field{Name: "CompName2"}},
},
},
{
Name: "idx_user_indices_comp_id0",
Type: "",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data0A"},
}, {
Field: &schema.Field{Name: "Data0B"},
}},
},
{
Name: "idx_user_indices_comp_id1",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data1A"},
}, {
Field: &schema.Field{Name: "Data1B"},
}, {
Field: &schema.Field{Name: "Data1C"},
}},
},
{
Name: "idx_user_indices_comp_id2",
Class: "UNIQUE",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data2C"},
}, {
Field: &schema.Field{Name: "Data2A"},
}, {
Field: &schema.Field{Name: "Data2B"},
}},
}, },
} }
indices := user.ParseIndexes() CheckIndices(t, results, user.ParseIndexes())
}
for k, result := range results { func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) {
v, ok := indices[k] type IndexTest struct {
if !ok { FieldA string `gorm:"unique;index"` // unique and index
t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices) FieldB string `gorm:"unique"` // unique
FieldC string `gorm:"index:,unique"` // uniqueIndex
FieldD string `gorm:"uniqueIndex;index"` // uniqueIndex and index
FieldE1 string `gorm:"uniqueIndex:uniq_field_e1_e2"` // mul uniqueIndex
FieldE2 string `gorm:"uniqueIndex:uniq_field_e1_e2"`
FieldF1 string `gorm:"uniqueIndex:uniq_field_f1_f2;index"` // mul uniqueIndex and index
FieldF2 string `gorm:"uniqueIndex:uniq_field_f1_f2;"`
FieldG string `gorm:"unique;uniqueIndex"` // unique and uniqueIndex
FieldH1 string `gorm:"unique;uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
FieldH2 string `gorm:"uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
}
indexSchema, err := schema.Parse(&IndexTest{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user index, got error %v", err)
}
indices := indexSchema.ParseIndexes()
expectedIndices := []*schema.Index{
{
Name: "idx_index_tests_field_a",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldA", Unique: true}}},
},
{
Name: "idx_index_tests_field_c",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: "idx_index_tests_field_c"}}},
},
{
Name: "idx_index_tests_field_d",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldD"}},
// Note: Duplicate Columns
{Field: &schema.Field{Name: "FieldD"}},
},
},
{
Name: "uniq_field_e1_e2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldE1"}},
{Field: &schema.Field{Name: "FieldE2"}},
},
},
{
Name: "uniq_field_f1_f2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldF1"}},
{Field: &schema.Field{Name: "FieldF2"}},
},
},
{
Name: "idx_index_tests_field_f1",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldF1"}}},
},
{
Name: "idx_index_tests_field_g",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: "idx_index_tests_field_g"}}},
},
{
Name: "uniq_field_h1_h2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldH1", Unique: true}},
{Field: &schema.Field{Name: "FieldH2"}},
},
},
}
CheckIndices(t, expectedIndices, indices)
}
func CheckIndices(t *testing.T, expected, actual []*schema.Index) {
if len(expected) != len(actual) {
t.Errorf("expected %d indices, but got %d", len(expected), len(actual))
return
} }
for _, name := range []string{"Name", "Class", "Type", "Where", "Comment"} { for i, ei := range expected {
if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() { t.Run(ei.Name, func(t *testing.T) {
t.Errorf( ai := actual[i]
"index %v %v should equal, expects %v, got %v", tests.AssertObjEqual(t, ai, ei, "Name", "Class", "Type", "Where", "Comment", "Option")
k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(),
)
}
}
for idx, ef := range result.Fields { if len(ei.Fields) != len(ai.Fields) {
rf := v.Fields[idx] t.Errorf("expected index %q field length is %d but actual %d", ei.Name, len(ei.Fields), len(ai.Fields))
for _, name := range []string{"Expression", "Sort", "Collate", "Length"} { return
if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() {
t.Errorf(
"index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name,
reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(),
)
}
} }
for i, ef := range ei.Fields {
af := ai.Fields[i]
tests.AssertObjEqual(t, af, ef, "Name", "Unique", "UniqueIndex", "Expression", "Sort", "Collate", "Length", "NotNull")
} }
})
} }
} }
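
To contrast the `unique` and `uniqueIndex` tags exercised by the test above, here is a sketch with an invented Login model (not part of the diff): `unique` surfaces only in ParseUniqueConstraints, while `uniqueIndex` surfaces only in ParseIndexes.

package main

import (
	"fmt"
	"sync"

	"gorm.io/gorm/schema"
)

// Login is a hypothetical model contrasting the two tags.
type Login struct {
	ID       uint
	Email    string `gorm:"unique"`      // unique constraint, no index entry
	Username string `gorm:"uniqueIndex"` // UNIQUE index entry, no unique constraint
}

func main() {
	s, err := schema.Parse(&Login{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		panic(err)
	}
	for name := range s.ParseUniqueConstraints() {
		fmt.Println("unique constraint:", name) // uni_logins_email
	}
	for _, idx := range s.ParseIndexes() {
		fmt.Println("index:", idx.Name, idx.Class) // idx_logins_username UNIQUE
	}
}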

schema/interfaces.go (new file, 42 lines)
View File

@ -0,0 +1,42 @@
package schema
import (
"gorm.io/gorm/clause"
)
// ConstraintInterface database constraint interface
type ConstraintInterface interface {
GetName() string
Build() (sql string, vars []interface{})
}
// GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface {
GormDataType() string
}
// FieldNewValuePool field new scan value pool
type FieldNewValuePool interface {
Get() interface{}
Put(interface{})
}
// CreateClausesInterface create clauses interface
type CreateClausesInterface interface {
CreateClauses(*Field) []clause.Interface
}
// QueryClausesInterface query clauses interface
type QueryClausesInterface interface {
QueryClauses(*Field) []clause.Interface
}
// UpdateClausesInterface update clauses interface
type UpdateClausesInterface interface {
UpdateClauses(*Field) []clause.Interface
}
// DeleteClausesInterface delete clauses interface
type DeleteClausesInterface interface {
DeleteClauses(*Field) []clause.Interface
}
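
A sketch of how a custom type could opt into these interfaces via GormDataTypeInterface (the JSONMap type is invented for illustration and is not part of the diff):

package main

import (
	"database/sql/driver"
	"encoding/json"
	"fmt"
	"sync"

	"gorm.io/gorm/schema"
)

// JSONMap is a hypothetical custom type; GormDataType lets it pick its own
// schema data type instead of the reflection-derived one.
type JSONMap map[string]interface{}

// GormDataType implements schema.GormDataTypeInterface.
func (JSONMap) GormDataType() string { return "json" }

// Value / Scan make the type usable as a SQL value, so the example is self-contained.
func (m JSONMap) Value() (driver.Value, error) { return json.Marshal(m) }
func (m *JSONMap) Scan(src interface{}) error {
	b, ok := src.([]byte)
	if !ok {
		return fmt.Errorf("unexpected type %T", src)
	}
	return json.Unmarshal(b, m)
}

type Document struct {
	ID   uint
	Meta JSONMap
}

func main() {
	s, err := schema.Parse(&Document{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		panic(err)
	}
	fmt.Println(s.FieldsByName["Meta"].DataType) // json
}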

View File

@ -26,9 +26,11 @@ type User struct {
Active *bool
}
-type mytime time.Time
-type myint int
-type mybool = bool
+type (
+mytime time.Time
+myint int
+mybool = bool
+)
type AdvancedDataTypeUser struct {
ID sql.NullInt64
@ -39,3 +41,24 @@ type AdvancedDataTypeUser struct {
Active mybool
Admin *mybool
}
type BaseModel struct {
ID uint
CreatedAt time.Time
CreatedBy *int
Created *VersionUser `gorm:"foreignKey:CreatedBy"`
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
}
type VersionModel struct {
BaseModel
Version int
}
type VersionUser struct {
VersionModel
Name string
Age uint
Birthday *time.Time
}

View File

@ -2,92 +2,149 @@ package schema
import ( import (
"crypto/sha1" "crypto/sha1"
"fmt" "encoding/hex"
"regexp"
"strings" "strings"
"sync"
"unicode/utf8" "unicode/utf8"
"github.com/jinzhu/inflection" "github.com/jinzhu/inflection"
"golang.org/x/text/cases"
"golang.org/x/text/language"
) )
// Namer namer interface // Namer namer interface
type Namer interface { type Namer interface {
TableName(table string) string TableName(table string) string
SchemaName(table string) string
ColumnName(table, column string) string ColumnName(table, column string) string
JoinTableName(table string) string JoinTableName(joinTable string) string
RelationshipFKName(Relationship) string RelationshipFKName(Relationship) string
CheckerName(table, column string) string CheckerName(table, column string) string
IndexName(table, column string) string IndexName(table, column string) string
UniqueName(table, column string) string
} }
// Replacer replacer interface like strings.Replacer
type Replacer interface {
Replace(name string) string
}
var _ Namer = (*NamingStrategy)(nil)
// NamingStrategy tables, columns naming strategy // NamingStrategy tables, columns naming strategy
type NamingStrategy struct { type NamingStrategy struct {
TablePrefix string TablePrefix string
SingularTable bool SingularTable bool
NameReplacer Replacer
NoLowerCase bool
IdentifierMaxLength int
} }
// TableName convert string to table name // TableName convert string to table name
func (ns NamingStrategy) TableName(str string) string { func (ns NamingStrategy) TableName(str string) string {
if ns.SingularTable { if ns.SingularTable {
return ns.TablePrefix + toDBName(str) return ns.TablePrefix + ns.toDBName(str)
} }
return ns.TablePrefix + inflection.Plural(toDBName(str)) return ns.TablePrefix + inflection.Plural(ns.toDBName(str))
}
// SchemaName generate schema name from table name, don't guarantee it is the reverse value of TableName
func (ns NamingStrategy) SchemaName(table string) string {
table = strings.TrimPrefix(table, ns.TablePrefix)
if ns.SingularTable {
return ns.toSchemaName(table)
}
return ns.toSchemaName(inflection.Singular(table))
} }
// ColumnName convert string to column name // ColumnName convert string to column name
func (ns NamingStrategy) ColumnName(table, column string) string { func (ns NamingStrategy) ColumnName(table, column string) string {
return toDBName(column) return ns.toDBName(column)
} }
// JoinTableName convert string to join table name // JoinTableName convert string to join table name
func (ns NamingStrategy) JoinTableName(str string) string { func (ns NamingStrategy) JoinTableName(str string) string {
return ns.TablePrefix + inflection.Plural(toDBName(str)) if !ns.NoLowerCase && strings.ToLower(str) == str {
return ns.TablePrefix + str
}
if ns.SingularTable {
return ns.TablePrefix + ns.toDBName(str)
}
return ns.TablePrefix + inflection.Plural(ns.toDBName(str))
} }
// RelationshipFKName generate fk name for relation // RelationshipFKName generate fk name for relation
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Field.Name)) return ns.formatName("fk", rel.Schema.Table, ns.toDBName(rel.Name))
} }
// CheckerName generate checker name // CheckerName generate checker name
func (ns NamingStrategy) CheckerName(table, column string) string { func (ns NamingStrategy) CheckerName(table, column string) string {
return fmt.Sprintf("chk_%s_%s", table, column) return ns.formatName("chk", table, column)
} }
// IndexName generate index name // IndexName generate index name
func (ns NamingStrategy) IndexName(table, column string) string { func (ns NamingStrategy) IndexName(table, column string) string {
idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column)) return ns.formatName("idx", table, ns.toDBName(column))
}
if utf8.RuneCountInString(idxName) > 64 { // UniqueName generate unique constraint name
func (ns NamingStrategy) UniqueName(table, column string) string {
return ns.formatName("uni", table, ns.toDBName(column))
}
func (ns NamingStrategy) formatName(prefix, table, name string) string {
formattedName := strings.ReplaceAll(strings.Join([]string{
prefix, table, name,
}, "_"), ".", "_")
if ns.IdentifierMaxLength == 0 {
ns.IdentifierMaxLength = 64
}
if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength {
h := sha1.New() h := sha1.New()
h.Write([]byte(idxName)) h.Write([]byte(formattedName))
bs := h.Sum(nil) bs := h.Sum(nil)
idxName = fmt.Sprintf("idx%v%v", table, column)[0:56] + string(bs)[:8] formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8]
} }
return idxName return formattedName
} }
var ( var (
smap sync.Map
// https://github.com/golang/lint/blob/master/lint.go#L770 // https://github.com/golang/lint/blob/master/lint.go#L770
commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
commonInitialismsReplacer *strings.Replacer commonInitialismsReplacer *strings.Replacer
) )
func init() { func init() {
var commonInitialismsForReplacer []string commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
for _, initialism := range commonInitialisms { for _, initialism := range commonInitialisms {
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, cases.Title(language.Und).String(initialism))
} }
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
} }
func toDBName(name string) string { func (ns NamingStrategy) toDBName(name string) string {
if name == "" { if name == "" {
return "" return ""
} else if v, ok := smap.Load(name); ok { }
return fmt.Sprint(v)
if ns.NameReplacer != nil {
tmpName := ns.NameReplacer.Replace(name)
if tmpName == "" {
return name
}
name = tmpName
}
if ns.NoLowerCase {
return name
} }
var ( var (
@ -126,6 +183,14 @@ func toDBName(name string) string {
} else { } else {
buf.WriteByte(value[len(value)-1]) buf.WriteByte(value[len(value)-1])
} }
ret := buf.String()
return buf.String() return ret
}
func (ns NamingStrategy) toSchemaName(name string) string {
result := strings.ReplaceAll(cases.Title(language.Und, cases.NoLower).String(strings.ReplaceAll(name, "_", " ")), " ", "")
for _, initialism := range commonInitialisms {
result = regexp.MustCompile(cases.Title(language.Und, cases.NoLower).String(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
}
return result
} }
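
A usage sketch of the extended NamingStrategy above (the prefix and replacer values are illustrative, not part of the diff):

package main

import (
	"fmt"
	"strings"

	"gorm.io/gorm/schema"
)

func main() {
	ns := schema.NamingStrategy{
		TablePrefix:         "app.",
		SingularTable:       true,
		NameReplacer:        strings.NewReplacer("CID", "Cid"),
		IdentifierMaxLength: 64, // longer identifiers are truncated and suffixed with a sha1 fragment
	}

	fmt.Println(ns.TableName("OrderItem"))               // app.order_item
	fmt.Println(ns.ColumnName("", "UserCID"))            // user_cid
	fmt.Println(ns.IndexName("app.order_item", "Name"))  // idx_app_order_item_name
	fmt.Println(ns.UniqueName("app.order_item", "code")) // uni_app_order_item_code
}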

View File

@ -1,11 +1,12 @@
package schema package schema
import ( import (
"strings"
"testing" "testing"
) )
func TestToDBName(t *testing.T) { func TestToDBName(t *testing.T) {
var maps = map[string]string{ maps := map[string]string{
"": "", "": "",
"x": "x", "x": "x",
"X": "x", "X": "x",
@ -26,9 +27,193 @@ func TestToDBName(t *testing.T) {
"ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
} }
ns := NamingStrategy{}
for key, value := range maps { for key, value := range maps {
if toDBName(key) != value { if ns.toDBName(key) != value {
t.Errorf("%v toName should equal %v, but got %v", key, value, toDBName(key)) t.Errorf("%v toName should equal %v, but got %v", key, value, ns.toDBName(key))
}
}
maps = map[string]string{
"x": "X",
"user_restrictions": "UserRestriction",
"this_is_a_test": "ThisIsATest",
"abc_and_jkl": "AbcAndJkl",
"employee_id": "EmployeeID",
"field_x": "FieldX",
"http_and_smtp": "HTTPAndSMTP",
"http_server_handler_for_url_id": "HTTPServerHandlerForURLID",
"uuid": "UUID",
"http_url": "HTTPURL",
"sha256_hash": "Sha256Hash",
"this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id": "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIDCanBeUsedAtTheEndAsID",
}
for key, value := range maps {
if ns.SchemaName(key) != value {
t.Errorf("%v schema name should equal %v, but got %v", key, value, ns.SchemaName(key))
} }
} }
} }
func TestNamingStrategy(t *testing.T) {
ns := NamingStrategy{
TablePrefix: "public.",
SingularTable: true,
NameReplacer: strings.NewReplacer("CID", "Cid"),
}
idxName := ns.IndexName("public.table", "name")
if idxName != "idx_public_table_name" {
t.Errorf("invalid index name generated, got %v", idxName)
}
chkName := ns.CheckerName("public.table", "name")
if chkName != "chk_public_table_name" {
t.Errorf("invalid checker name generated, got %v", chkName)
}
joinTable := ns.JoinTableName("user_languages")
if joinTable != "public.user_languages" {
t.Errorf("invalid join table generated, got %v", joinTable)
}
joinTable2 := ns.JoinTableName("UserLanguage")
if joinTable2 != "public.user_language" {
t.Errorf("invalid join table generated, got %v", joinTable2)
}
tableName := ns.TableName("Company")
if tableName != "public.company" {
t.Errorf("invalid table name generated, got %v", tableName)
}
columdName := ns.ColumnName("", "NameCID")
if columdName != "name_cid" {
t.Errorf("invalid column name generated, got %v", columdName)
}
}
type CustomReplacer struct {
f func(string) string
}
func (r CustomReplacer) Replace(name string) string {
return r.f(name)
}
func TestCustomReplacer(t *testing.T) {
ns := NamingStrategy{
TablePrefix: "public.",
SingularTable: true,
NameReplacer: CustomReplacer{
func(name string) string {
replaced := "REPLACED_" + strings.ToUpper(name)
return strings.NewReplacer("CID", "_Cid").Replace(replaced)
},
},
NoLowerCase: false,
}
idxName := ns.IndexName("public.table", "name")
if idxName != "idx_public_table_replaced_name" {
t.Errorf("invalid index name generated, got %v", idxName)
}
chkName := ns.CheckerName("public.table", "name")
if chkName != "chk_public_table_name" {
t.Errorf("invalid checker name generated, got %v", chkName)
}
joinTable := ns.JoinTableName("user_languages")
if joinTable != "public.user_languages" { // Seems like a bug in NamingStrategy to skip the Replacer when the name is lowercase here.
t.Errorf("invalid join table generated, got %v", joinTable)
}
joinTable2 := ns.JoinTableName("UserLanguage")
if joinTable2 != "public.replaced_userlanguage" {
t.Errorf("invalid join table generated, got %v", joinTable2)
}
tableName := ns.TableName("Company")
if tableName != "public.replaced_company" {
t.Errorf("invalid table name generated, got %v", tableName)
}
columdName := ns.ColumnName("", "NameCID")
if columdName != "replaced_name_cid" {
t.Errorf("invalid column name generated, got %v", columdName)
}
}
func TestCustomReplacerWithNoLowerCase(t *testing.T) {
ns := NamingStrategy{
TablePrefix: "public.",
SingularTable: true,
NameReplacer: CustomReplacer{
func(name string) string {
replaced := "REPLACED_" + strings.ToUpper(name)
return strings.NewReplacer("CID", "_Cid").Replace(replaced)
},
},
NoLowerCase: true,
}
idxName := ns.IndexName("public.table", "name")
if idxName != "idx_public_table_REPLACED_NAME" {
t.Errorf("invalid index name generated, got %v", idxName)
}
chkName := ns.CheckerName("public.table", "name")
if chkName != "chk_public_table_name" {
t.Errorf("invalid checker name generated, got %v", chkName)
}
joinTable := ns.JoinTableName("user_languages")
if joinTable != "public.REPLACED_USER_LANGUAGES" {
t.Errorf("invalid join table generated, got %v", joinTable)
}
joinTable2 := ns.JoinTableName("UserLanguage")
if joinTable2 != "public.REPLACED_USERLANGUAGE" {
t.Errorf("invalid join table generated, got %v", joinTable2)
}
tableName := ns.TableName("Company")
if tableName != "public.REPLACED_COMPANY" {
t.Errorf("invalid table name generated, got %v", tableName)
}
columdName := ns.ColumnName("", "NameCID")
if columdName != "REPLACED_NAME_Cid" {
t.Errorf("invalid column name generated, got %v", columdName)
}
}
func TestFormatNameWithStringLongerThan63Characters(t *testing.T) {
ns := NamingStrategy{IdentifierMaxLength: 63}
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVer180f2c67" {
t.Errorf("invalid formatted name generated, got %v", formattedName)
}
}
func TestFormatNameWithStringLongerThan64Characters(t *testing.T) {
ns := NamingStrategy{IdentifierMaxLength: 64}
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" {
t.Errorf("invalid formatted name generated, got %v", formattedName)
}
}
func TestReplaceEmptyTableName(t *testing.T) {
ns := NamingStrategy{
SingularTable: true,
NameReplacer: strings.NewReplacer("Model", ""),
}
tableName := ns.TableName("Model")
if tableName != "Model" {
t.Errorf("invalid table name generated, got %v", tableName)
}
}

schema/pool.go (new file, 19 lines)
View File

@ -0,0 +1,19 @@
package schema
import (
"reflect"
"sync"
)
// sync pools
var (
normalPool sync.Map
poolInitializer = func(reflectType reflect.Type) FieldNewValuePool {
v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{
New: func() interface{} {
return reflect.New(reflectType).Interface()
},
})
return v.(FieldNewValuePool)
}
)
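
The pool above keys one sync.Pool per reflect.Type and hands it out through the FieldNewValuePool interface. A standalone sketch of the same pattern (the poolFor helper is hypothetical, not part of the diff), relying only on *sync.Pool satisfying Get/Put:

package main

import (
	"fmt"
	"reflect"
	"sync"

	"gorm.io/gorm/schema"
)

// pools holds one *sync.Pool per reflect.Type, mirroring poolInitializer above.
var pools sync.Map

func poolFor(t reflect.Type) schema.FieldNewValuePool {
	v, _ := pools.LoadOrStore(t, &sync.Pool{
		New: func() interface{} { return reflect.New(t).Interface() },
	})
	return v.(schema.FieldNewValuePool)
}

func main() {
	p := poolFor(reflect.TypeOf("")) // pool of *string scan targets
	dest := p.Get()
	fmt.Printf("%T\n", dest) // *string
	p.Put(dest) // return the value for reuse
}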

View File

@ -1,12 +1,16 @@
package schema package schema
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"regexp"
"strings" "strings"
"sync"
"github.com/jinzhu/inflection" "github.com/jinzhu/inflection"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
@ -18,6 +22,7 @@ const (
HasMany RelationshipType = "has_many" // HasManyRel has many relationship HasMany RelationshipType = "has_many" // HasManyRel has many relationship
BelongsTo RelationshipType = "belongs_to" // BelongsToRel belongs to relationship BelongsTo RelationshipType = "belongs_to" // BelongsToRel belongs to relationship
Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship
has RelationshipType = "has"
) )
type Relationships struct { type Relationships struct {
@ -26,6 +31,10 @@ type Relationships struct {
HasMany []*Relationship HasMany []*Relationship
Many2Many []*Relationship Many2Many []*Relationship
Relations map[string]*Relationship Relations map[string]*Relationship
EmbeddedRelations map[string]*Relationships
Mux sync.RWMutex
} }
type Relationship struct { type Relationship struct {
@ -53,7 +62,7 @@ type Reference struct {
OwnPrimaryKey bool OwnPrimaryKey bool
} }
func (schema *Schema) parseRelation(field *Field) { func (schema *Schema) parseRelation(field *Field) *Relationship {
var ( var (
err error err error
fieldValue = reflect.New(field.IndirectFieldType).Interface() fieldValue = reflect.New(field.IndirectFieldType).Interface()
@ -66,25 +75,38 @@ func (schema *Schema) parseRelation(field *Field) {
} }
) )
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { cacheStore := schema.cacheStore
schema.err = err
return if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err)
return nil
} }
if polymorphic, _ := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { if hasPolymorphicRelation(field.TagSettings) {
schema.buildPolymorphicRelation(relation, field, polymorphic) schema.buildPolymorphicRelation(relation, field)
} else if many2many, _ := field.TagSettings["MANY2MANY"]; many2many != "" { } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
schema.buildMany2ManyRelation(relation, field, many2many) schema.buildMany2ManyRelation(relation, field, many2many)
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
schema.guessRelation(relation, field, guessBelongs)
} else { } else {
switch field.IndirectFieldType.Kind() { switch field.IndirectFieldType.Kind() {
case reflect.Struct, reflect.Slice: case reflect.Struct:
schema.guessRelation(relation, field, true) schema.guessRelation(relation, field, guessGuess)
case reflect.Slice:
schema.guessRelation(relation, field, guessHas)
default: default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema,
field.Name)
} }
} }
if relation.Type == "has" { if relation.Type == has {
if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil {
relation.FieldSchema.Relationships.Mux.Lock()
relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation
relation.FieldSchema.Relationships.Mux.Unlock()
}
switch field.IndirectFieldType.Kind() { switch field.IndirectFieldType.Kind() {
case reflect.Struct: case reflect.Struct:
relation.Type = HasOne relation.Type = HasOne
@ -94,7 +116,7 @@ func (schema *Schema) parseRelation(field *Field) {
} }
if schema.err == nil { if schema.err == nil {
schema.Relationships.Relations[relation.Name] = relation schema.setRelation(relation)
switch relation.Type { switch relation.Type {
case HasOne: case HasOne:
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
@ -106,9 +128,59 @@ func (schema *Schema) parseRelation(field *Field) {
schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation) schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation)
} }
} }
return relation
}
// hasPolymorphicRelation check if has polymorphic relation
// 1. `POLYMORPHIC` tag
// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag
func hasPolymorphicRelation(tagSettings map[string]string) bool {
if _, ok := tagSettings["POLYMORPHIC"]; ok {
return true
}
_, hasType := tagSettings["POLYMORPHICTYPE"]
_, hasId := tagSettings["POLYMORPHICID"]
return hasType && hasId
}
func (schema *Schema) setRelation(relation *Relationship) {
// set non-embedded relation
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
if len(rel.Field.BindNames) > 1 {
schema.Relationships.Relations[relation.Name] = relation
}
} else {
schema.Relationships.Relations[relation.Name] = relation
}
// set embedded relation
if len(relation.Field.EmbeddedBindNames) <= 1 {
return
}
relationships := &schema.Relationships
for i, name := range relation.Field.EmbeddedBindNames {
if i < len(relation.Field.EmbeddedBindNames)-1 {
if relationships.EmbeddedRelations == nil {
relationships.EmbeddedRelations = map[string]*Relationships{}
}
if r := relationships.EmbeddedRelations[name]; r == nil {
relationships.EmbeddedRelations[name] = &Relationships{}
}
relationships = relationships.EmbeddedRelations[name]
} else {
if relationships.Relations == nil {
relationships.Relations = map[string]*Relationship{}
}
relationships.Relations[relation.Name] = relation
}
}
} }
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
//
// type User struct { // type User struct {
// Toys []Toy `gorm:"polymorphic:Owner;"` // Toys []Toy `gorm:"polymorphic:Owner;"`
// } // }
@ -119,23 +191,41 @@ func (schema *Schema) parseRelation(field *Field) {
// OwnerID int // OwnerID int
// OwnerType string // OwnerType string
// } // }
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) { func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
polymorphic := field.TagSettings["POLYMORPHIC"]
relation.Polymorphic = &Polymorphic{ relation.Polymorphic = &Polymorphic{
Value: schema.Table, Value: schema.Table,
PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"],
} }
if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { var (
typeName = polymorphic + "Type"
typeId = polymorphic + "ID"
)
if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
typeName = strings.TrimSpace(value)
}
if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
typeId = strings.TrimSpace(value)
}
relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]
if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
relation.Polymorphic.Value = strings.TrimSpace(value) relation.Polymorphic.Value = strings.TrimSpace(value)
} }
if relation.Polymorphic.PolymorphicType == nil { if relation.Polymorphic.PolymorphicType == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type") schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
relation.FieldSchema, schema, field.Name, polymorphic+"Type")
} }
if relation.Polymorphic.PolymorphicID == nil { if relation.Polymorphic.PolymorphicID == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID") schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
relation.FieldSchema, schema, field.Name, polymorphic+"ID")
} }
if schema.err == nil { if schema.err == nil {
@ -147,12 +237,25 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
primaryKeyField := schema.PrioritizedPrimaryField primaryKeyField := schema.PrioritizedPrimaryField
if len(relation.foreignKeys) > 0 { if len(relation.foreignKeys) > 0 {
if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys,
schema, field.Name)
} }
} }
if primaryKeyField == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
relation.FieldSchema, schema, field.Name)
return
}
// use same data type for foreign keys // use same data type for foreign keys
if copyableDataType(primaryKeyField.DataType) {
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
}
relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType
if relation.Polymorphic.PolymorphicID.Size == 0 {
relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size
}
relation.References = append(relation.References, &Reference{ relation.References = append(relation.References, &Reference{
PrimaryKey: primaryKeyField, PrimaryKey: primaryKeyField,
@ -161,7 +264,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
}) })
} }
relation.Type = "has" relation.Type = has
} }
func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) { func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) {
@ -171,7 +274,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
err error err error
joinTableFields []reflect.StructField joinTableFields []reflect.StructField
fieldsMap = map[string]*Field{} fieldsMap = map[string]*Field{}
ownFieldsMap = map[string]bool{} // fix self join many2many ownFieldsMap = map[string]*Field{} // fix self join many2many
referFieldsMap = map[string]*Field{}
joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) joinReferences = toColumns(field.TagSettings["JOINREFERENCES"])
) )
@ -185,7 +289,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
if field := schema.LookUpField(foreignKey); field != nil { if field := schema.LookUpField(foreignKey); field != nil {
ownForeignFields = append(ownForeignFields, field) ownForeignFields = append(ownForeignFields, field)
} else { } else {
schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
return return
} }
} }
@ -197,33 +301,31 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
if field := relation.FieldSchema.LookUpField(foreignKey); field != nil { if field := relation.FieldSchema.LookUpField(foreignKey); field != nil {
refForeignFields = append(refForeignFields, field) refForeignFields = append(refForeignFields, field)
} else { } else {
schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
return return
} }
} }
} }
for idx, ownField := range ownForeignFields { for idx, ownField := range ownForeignFields {
joinFieldName := schema.Name + ownField.Name joinFieldName := cases.Title(language.Und, cases.NoLower).String(schema.Name) + ownField.Name
if len(joinForeignKeys) > idx { if len(joinForeignKeys) > idx {
joinFieldName = joinForeignKeys[idx] joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinForeignKeys[idx])
} }
ownFieldsMap[joinFieldName] = true ownFieldsMap[joinFieldName] = ownField
fieldsMap[joinFieldName] = ownField fieldsMap[joinFieldName] = ownField
joinTableFields = append(joinTableFields, reflect.StructField{ joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName, Name: joinFieldName,
PkgPath: ownField.StructField.PkgPath, PkgPath: ownField.StructField.PkgPath,
Type: ownField.StructField.Type, Type: ownField.StructField.Type,
Tag: removeSettingFromTag(ownField.StructField.Tag, "column"), Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
}) })
} }
for idx, relField := range refForeignFields { for idx, relField := range refForeignFields {
joinFieldName := relation.FieldSchema.Name + relField.Name joinFieldName := cases.Title(language.Und, cases.NoLower).String(relation.FieldSchema.Name) + relField.Name
if len(joinReferences) > idx {
joinFieldName = joinReferences[idx]
}
if _, ok := ownFieldsMap[joinFieldName]; ok { if _, ok := ownFieldsMap[joinFieldName]; ok {
if field.Name != relation.FieldSchema.Name { if field.Name != relation.FieldSchema.Name {
@ -233,100 +335,254 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
}
}
if len(joinReferences) > idx {
joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinReferences[idx])
}
referFieldsMap[joinFieldName] = relField
if _, ok := fieldsMap[joinFieldName]; !ok {
fieldsMap[joinFieldName] = relField
joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName,
PkgPath: relField.StructField.PkgPath,
Type: relField.StructField.Type,
Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
})
}
}
joinTableFields = append(joinTableFields, reflect.StructField{
Name: cases.Title(language.Und, cases.NoLower).String(schema.Name) + field.Name,
Type: schema.ModelType,
Tag: `gorm:"-"`,
})
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
schema.namer); err != nil {
schema.err = err
}
relation.JoinTable.Name = many2many
relation.JoinTable.Table = schema.namer.JoinTableName(many2many)
relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields))
relName := relation.Schema.Name
relRefName := relation.FieldSchema.Name
if relName == relRefName {
relRefName = relation.Field.Name
}
if _, ok := relation.JoinTable.Relationships.Relations[relName]; !ok {
relation.JoinTable.Relationships.Relations[relName] = &Relationship{
Name: relName,
Type: BelongsTo,
Schema: relation.JoinTable,
FieldSchema: relation.Schema,
}
} else {
relation.JoinTable.Relationships.Relations[relName].References = []*Reference{}
}
if _, ok := relation.JoinTable.Relationships.Relations[relRefName]; !ok {
relation.JoinTable.Relationships.Relations[relRefName] = &Relationship{
Name: relRefName,
Type: BelongsTo,
Schema: relation.JoinTable,
FieldSchema: relation.FieldSchema,
}
} else {
relation.JoinTable.Relationships.Relations[relRefName].References = []*Reference{}
}
// build references
for _, f := range relation.JoinTable.Fields {
if f.Creatable || f.Readable || f.Updatable {
// use same data type for foreign keys
if copyableDataType(fieldsMap[f.Name].DataType) {
f.DataType = fieldsMap[f.Name].DataType
}
f.GORMDataType = fieldsMap[f.Name].GORMDataType
if f.Size == 0 {
f.Size = fieldsMap[f.Name].Size
}
relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
if of, ok := ownFieldsMap[f.Name]; ok {
joinRel := relation.JoinTable.Relationships.Relations[relName]
joinRel.Field = relation.Field
joinRel.References = append(joinRel.References, &Reference{
PrimaryKey: of,
ForeignKey: f,
})
relation.References = append(relation.References, &Reference{
PrimaryKey: of,
ForeignKey: f,
OwnPrimaryKey: true,
})
}
if rf, ok := referFieldsMap[f.Name]; ok {
joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
if joinRefRel.Field == nil {
joinRefRel.Field = relation.Field
}
joinRefRel.References = append(joinRefRel.References, &Reference{
PrimaryKey: rf,
ForeignKey: f,
})
relation.References = append(relation.References, &Reference{
PrimaryKey: rf,
ForeignKey: f,
})
}
}
}
} }
type guessLevel int
const (
guessGuess guessLevel = iota
guessBelongs
guessEmbeddedBelongs
guessHas
guessEmbeddedHas
)
func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) {
var (
primaryFields, foreignFields []*Field
primarySchema, foreignSchema = schema, relation.FieldSchema
gl = cgl
)
if gl == guessGuess {
if field.Schema == relation.FieldSchema {
gl = guessBelongs
} else {
gl = guessHas
}
}
reguessOrErr := func() {
switch cgl {
case guessGuess:
schema.guessRelation(relation, field, guessBelongs)
case guessBelongs:
schema.guessRelation(relation, field, guessEmbeddedBelongs)
case guessEmbeddedBelongs:
schema.guessRelation(relation, field, guessHas)
case guessHas:
schema.guessRelation(relation, field, guessEmbeddedHas)
// case guessEmbeddedHas:
default:
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
schema, field.Name)
}
}
switch gl {
case guessBelongs:
primarySchema, foreignSchema = relation.FieldSchema, schema
case guessEmbeddedBelongs:
if field.OwnerSchema == nil {
reguessOrErr()
return
}
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
case guessHas:
case guessEmbeddedHas:
if field.OwnerSchema == nil {
reguessOrErr()
return
}
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
}
if len(relation.foreignKeys) > 0 {
for _, foreignKey := range relation.foreignKeys {
f := foreignSchema.LookUpField(foreignKey)
if f == nil {
reguessOrErr()
return
}
foreignFields = append(foreignFields, f)
}
} else {
primarySchemaName := primarySchema.Name
if primarySchemaName == "" {
primarySchemaName = relation.FieldSchema.Name
}
if len(relation.primaryKeys) > 0 {
for _, primaryKey := range relation.primaryKeys {
if f := primarySchema.LookUpField(primaryKey); f != nil {
primaryFields = append(primaryFields, f)
}
}
} else {
primaryFields = primarySchema.PrimaryFields
}
primaryFieldLoop:
for _, primaryField := range primaryFields {
lookUpName := primarySchemaName + primaryField.Name
if gl == guessBelongs {
lookUpName = field.Name + primaryField.Name
}
lookUpNames := []string{lookUpName}
if len(primaryFields) == 1 {
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
}
for _, name := range lookUpNames {
if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField)
continue primaryFieldLoop
}
}
for _, name := range lookUpNames {
if f := foreignSchema.LookUpField(name); f != nil {
foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField)
continue primaryFieldLoop
}
}
}
}
switch {
case len(foreignFields) == 0:
reguessOrErr()
return
case len(relation.primaryKeys) > 0:
for idx, primaryKey := range relation.primaryKeys {
if f := primarySchema.LookUpField(primaryKey); f != nil {
if len(primaryFields) < idx+1 {
primaryFields = append(primaryFields, f)
} else if f != primaryFields[idx] {
reguessOrErr()
return
}
} else {
reguessOrErr()
return
}
}
case len(primaryFields) == 0:
if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
} else if len(primarySchema.PrimaryFields) == len(foreignFields) {
primaryFields = append(primaryFields, primarySchema.PrimaryFields...)
} else {
reguessOrErr()
return
}
}
@ -334,22 +590,29 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
// build references
for idx, foreignField := range foreignFields {
// use same data type for foreign keys
if copyableDataType(primaryFields[idx].DataType) {
foreignField.DataType = primaryFields[idx].DataType
}
foreignField.GORMDataType = primaryFields[idx].GORMDataType
if foreignField.Size == 0 {
foreignField.Size = primaryFields[idx].Size
}
relation.References = append(relation.References, &Reference{
PrimaryKey: primaryFields[idx],
ForeignKey: foreignField,
OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas),
})
}
if gl == guessHas || gl == guessEmbeddedHas {
relation.Type = has
} else {
relation.Type = BelongsTo
}
}
// Constraint is ForeignKey Constraint
type Constraint struct {
Name string
Field *Field
@ -361,19 +624,67 @@ type Constraint struct {
OnUpdate string
}
func (constraint *Constraint) GetName() string { return constraint.Name }
func (constraint *Constraint) Build() (sql string, vars []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys))
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
references := make([]interface{}, 0, len(constraint.References))
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}
func (rel *Relationship) ParseConstraint() *Constraint {
str := rel.Field.TagSettings["CONSTRAINT"]
if str == "-" {
return nil
}
if rel.Type == BelongsTo {
for _, r := range rel.FieldSchema.Relationships.Relations {
if r != rel && r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) {
matched := true
for idx, ref := range r.References {
if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey &&
rel.References[idx].PrimaryValue == ref.PrimaryValue) {
matched = false
break
}
}
if matched {
return nil
}
}
}
}
var (
name string
idx = strings.IndexByte(str, ',')
settings = ParseTagSetting(str, ",")
)
// optimized match for English letters plus dash/underscore
// This check is usually executed inside a loop; to avoid the cost of
// recompiling the regular expression on every call, the pattern is
// compiled once at package level and reused here.
if idx != -1 && regEnLetterAndMidline.MatchString(str[0:idx]) {
name = str[0:idx]
} else {
name = rel.Schema.namer.RelationshipFKName(*rel)
@ -384,29 +695,33 @@ func (rel *Relationship) ParseConstraint() *Constraint {
Field: rel.Field,
OnUpdate: settings["ONUPDATE"],
OnDelete: settings["ONDELETE"],
Schema: rel.Schema,
}
for _, ref := range rel.References {
if ref.PrimaryKey != nil && (rel.JoinTable == nil || ref.OwnPrimaryKey) {
constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey)
constraint.References = append(constraint.References, ref.PrimaryKey)
if ref.OwnPrimaryKey {
constraint.Schema = ref.ForeignKey.Schema
constraint.ReferenceSchema = rel.Schema
} else {
constraint.Schema = rel.Schema
constraint.ReferenceSchema = ref.PrimaryKey.Schema
}
}
}
if rel.JoinTable != nil || constraint.ReferenceSchema == nil {
return nil
}
return &constraint
}
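For context, regEnLetterAndMidline is a package-level variable whose declaration is not shown in this diff; a minimal sketch of the precompilation pattern the comment above describes, with the name and pattern string assumed from the usage in ParseConstraint:

	// compiled once at package init, reused by every ParseConstraint call
	var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$")

Hoisting regexp.MustCompile out of the hot path avoids recompiling the same pattern for every relationship while parsing large schemas.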
func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) {
table := rel.FieldSchema.Table
foreignFields := []*Field{}
relForeignKeys := []string{}
if rel.JoinTable != nil {
table = rel.JoinTable.Table
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
@ -440,9 +755,19 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []
}
}
_, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields)
column, values := ToQueryValues(table, relForeignKeys, foreignValues)
conds = append(conds, clause.IN{Column: column, Values: values})
return
}
func copyableDataType(str DataType) bool {
lowerStr := strings.ToLower(string(str))
for _, s := range []string{"auto_increment", "primary key"} {
if strings.Contains(lowerStr, s) {
return false
}
}
return true
}
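copyableDataType above is what gates the data-type copies in this file (join-table columns and guessed foreign keys). A hedged, test-style sketch of the behaviour it encodes; the test function itself is not part of this diff and would live in a _test.go file inside the schema package:

	func TestCopyableDataType(t *testing.T) {
		// types carrying an auto_increment or primary key clause must not be
		// copied onto a foreign-key column; anything else may be reused as-is
		cases := map[DataType]bool{
			"bigint":                true,
			"BIGINT AUTO_INCREMENT": false,
			"integer PRIMARY KEY":   false,
		}
		for dt, want := range cases {
			if got := copyableDataType(dt); got != want {
				t.Errorf("copyableDataType(%q) = %v, want %v", dt, got, want)
			}
		}
	}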

View File

@ -10,7 +10,7 @@ import (
func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) {
if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil {
t.Errorf("Failed to parse schema, got error %v", err)
} else {
for _, rel := range relations {
checkSchemaRelation(t, s, rel)
@ -55,6 +55,95 @@ func TestBelongsToOverrideReferences(t *testing.T) {
})
}
func TestBelongsToWithOnlyReferences(t *testing.T) {
type Profile struct {
gorm.Model
Refer string
Name string
}
type User struct {
gorm.Model
Profile Profile `gorm:"References:Refer"`
ProfileRefer int
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}},
})
}
func TestBelongsToWithOnlyReferences2(t *testing.T) {
type Profile struct {
gorm.Model
Refer string
Name string
}
type User struct {
gorm.Model
Profile Profile `gorm:"References:Refer"`
ProfileID int
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"Refer", "Profile", "ProfileID", "User", "", false}},
})
}
func TestSelfReferentialBelongsTo(t *testing.T) {
type User struct {
ID int32 `gorm:"primaryKey"`
Name string
CreatorID *int32
Creator *User
}
checkStructRelation(t, &User{}, Relation{
Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User",
References: []Reference{{"ID", "User", "CreatorID", "User", "", false}},
})
}
func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
type User struct {
ID int32 `gorm:"primaryKey"`
Name string
CreatedBy *int32
Creator *User `gorm:"foreignKey:CreatedBy;references:ID"`
}
checkStructRelation(t, &User{}, Relation{
Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User",
References: []Reference{{"ID", "User", "CreatedBy", "User", "", false}},
})
}
func TestBelongsToWithMixin(t *testing.T) {
type Profile struct {
gorm.Model
Refer string
Name string
}
type ProfileMixin struct {
Profile Profile `gorm:"References:Refer"`
ProfileRefer int
}
type User struct {
gorm.Model
ProfileMixin
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}},
})
}
func TestHasOneOverrideForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
@ -92,6 +181,62 @@ func TestHasOneOverrideReferences(t *testing.T) {
})
}
func TestHasOneOverrideReferences2(t *testing.T) {
type Profile struct {
gorm.Model
Name string
}
type User struct {
gorm.Model
ProfileID uint `gorm:"column:profile_id"`
Profile *Profile `gorm:"foreignKey:ID;references:ProfileID"`
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"ProfileID", "User", "ID", "Profile", "", true}},
})
}
func TestHasOneWithOnlyReferences(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserRefer uint
}
type User struct {
gorm.Model
Refer string
Profile Profile `gorm:"References:Refer"`
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"Refer", "User", "UserRefer", "Profile", "", true}},
})
}
func TestHasOneWithOnlyReferences2(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserID uint
}
type User struct {
gorm.Model
Refer string
Profile Profile `gorm:"References:Refer"`
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"Refer", "User", "UserID", "Profile", "", true}},
})
}
func TestHasManyOverrideForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
@ -139,6 +284,7 @@ func TestMany2ManyOverrideForeignKeyAndReferences(t *testing.T) {
type User struct {
gorm.Model
Profiles []Profile `gorm:"many2many:user_profiles;ForeignKey:Refer;JoinForeignKey:UserReferID;References:UserRefer;JoinReferences:ProfileRefer"`
Profiles2 []Profile `gorm:"many2many:user_profiles2;ForeignKey:refer;JoinForeignKey:user_refer_id;References:user_refer;JoinReferences:profile_refer"`
Refer uint
}
@ -149,6 +295,13 @@ func TestMany2ManyOverrideForeignKeyAndReferences(t *testing.T) {
{"Refer", "User", "UserReferID", "user_profiles", "", true}, {"Refer", "User", "UserReferID", "user_profiles", "", true},
{"UserRefer", "Profile", "ProfileRefer", "user_profiles", "", false}, {"UserRefer", "Profile", "ProfileRefer", "user_profiles", "", false},
}, },
}, Relation{
Name: "Profiles2", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profiles2", Table: "user_profiles2"},
References: []Reference{
{"Refer", "User", "User_refer_id", "user_profiles2", "", true},
{"UserRefer", "Profile", "Profile_refer", "user_profiles2", "", false},
},
})
}
@ -175,6 +328,33 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) {
})
}
func TestMany2ManySharedForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
Name string
Kind string
ProfileRefer uint
}
type User struct {
gorm.Model
Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"`
Kind string
Refer uint
}
checkStructRelation(t, &User{}, Relation{
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
References: []Reference{
{"Refer", "User", "UserRefer", "user_profiles", "", true},
{"Kind", "User", "Kind", "user_profiles", "", true},
{"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false},
{"Kind", "Profile", "Kind", "user_profiles", "", false},
},
})
}
func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
@ -184,16 +364,39 @@ func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
type User struct {
gorm.Model
Profiles []Profile `gorm:"many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
Refer uint
}
checkStructRelation(t, &User{}, Relation{
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"},
References: []Reference{
{"ID", "User", "UserReferID", "user_profile", "", true},
{"ID", "Profile", "ProfileRefer", "user_profile", "", false},
},
})
}
func TestBuildReadonlyMany2ManyRelation(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserRefer uint
}
type User struct {
gorm.Model
Profiles []Profile `gorm:"->;many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
Refer uint
}
checkStructRelation(t, &User{}, Relation{
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"},
References: []Reference{
{"ID", "User", "UserReferID", "user_profile", "", true},
{"ID", "Profile", "ProfileRefer", "user_profile", "", false},
},
})
}
@ -245,3 +448,551 @@ func TestMany2ManyWithMultiPrimaryKeys(t *testing.T) {
},
)
}
func TestMultipleMany2Many(t *testing.T) {
type Thing struct {
ID int
}
type Person struct {
ID int
Likes []Thing `gorm:"many2many:likes"`
Dislikes []Thing `gorm:"many2many:dislikes"`
}
checkStructRelation(t, &Person{},
Relation{
Name: "Likes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing",
JoinTable: JoinTable{Name: "likes", Table: "likes"},
References: []Reference{
{"ID", "Person", "PersonID", "likes", "", true},
{"ID", "Thing", "ThingID", "likes", "", false},
},
},
Relation{
Name: "Dislikes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing",
JoinTable: JoinTable{Name: "dislikes", Table: "dislikes"},
References: []Reference{
{"ID", "Person", "PersonID", "dislikes", "", true},
{"ID", "Thing", "ThingID", "dislikes", "", false},
},
},
)
}
func TestSelfReferentialMany2Many(t *testing.T) {
type User struct {
ID int32 `gorm:"primaryKey"`
Name string
CreatedBy int32
Creators []User `gorm:"foreignKey:CreatedBy"`
AnotherPro interface{} `gorm:"-"`
}
checkStructRelation(t, &User{}, Relation{
Name: "Creators", Type: schema.HasMany, Schema: "User", FieldSchema: "User",
References: []Reference{{"ID", "User", "CreatedBy", "User", "", true}},
})
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse schema")
}
relSchema := user.Relationships.Relations["Creators"].FieldSchema
if user != relSchema {
t.Fatalf("schema should be same, expects %p but got %p", user, relSchema)
}
}
type CreatedByModel struct {
CreatedByID uint
CreatedBy *CreatedUser
}
type CreatedUser struct {
gorm.Model
CreatedByModel
}
func TestEmbeddedRelation(t *testing.T) {
checkStructRelation(t, &CreatedUser{}, Relation{
Name: "CreatedBy", Type: schema.BelongsTo, Schema: "CreatedUser", FieldSchema: "CreatedUser",
References: []Reference{
{"ID", "CreatedUser", "CreatedByID", "CreatedUser", "", false},
},
})
userSchema, err := schema.Parse(&CreatedUser{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse schema, got error %v", err)
}
if len(userSchema.Relationships.Relations) != 1 {
t.Fatalf("expects 1 relations, but got %v", len(userSchema.Relationships.Relations))
}
if createdByRel, ok := userSchema.Relationships.Relations["CreatedBy"]; ok {
if createdByRel.FieldSchema != userSchema {
t.Fatalf("expects same field schema, but got new %p, old %p", createdByRel.FieldSchema, userSchema)
}
} else {
t.Fatalf("expects created by relations, but not found")
}
}
func TestEmbeddedHas(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type User struct {
ID int
Cat struct {
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
Toys []Toy `gorm:"polymorphic:Owner;"`
} `gorm:"embedded;embeddedPrefix:cat_"`
Dog struct {
ID int
Name string
UserID int
Toy Toy `gorm:"polymorphic:Owner;"`
Toys []Toy `gorm:"polymorphic:Owner;"`
}
Toys []Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
}
func TestPolymorphic(t *testing.T) {
t.Run("has one", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has one with custom polymorphic type and id", func(t *testing.T) {
type Toy struct {
ID int
Name string
RefId int
Type string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has one with only polymorphic type", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
Type string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has many", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type Cat struct {
ID int
Name string
Toys []Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has many with custom polymorphic type and id", func(t *testing.T) {
type Toy struct {
ID int
Name string
RefId int
Type string
}
type Cat struct {
ID int
Name string
Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
}
func TestEmbeddedBelongsTo(t *testing.T) {
type Country struct {
ID int `gorm:"primaryKey"`
Name string
}
type Address struct {
CountryID int
Country Country
}
type NestedAddress struct {
Address
}
type CountryMixin struct {
CountryID int
Country Country
}
type Org struct {
ID int
PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"`
VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"`
AddressID int
Address struct {
ID int
Address
}
NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
CountryMixin
}
s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Errorf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"PostalAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
"VisitingAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
"NestedAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
})
}
func TestVariableRelation(t *testing.T) {
var result struct {
User
}
checkStructRelation(t, &result, Relation{
Name: "Account", Type: schema.HasOne, Schema: "", FieldSchema: "Account",
References: []Reference{
{"ID", "", "UserID", "Account", "", true},
},
})
checkStructRelation(t, &result, Relation{
Name: "Company", Type: schema.BelongsTo, Schema: "", FieldSchema: "Company",
References: []Reference{
{"ID", "Company", "CompanyID", "", "", false},
},
})
}
func TestSameForeignKey(t *testing.T) {
type UserAux struct {
gorm.Model
Aux string
UUID string
}
type User struct {
gorm.Model
Name string
UUID string
Aux *UserAux `gorm:"foreignkey:UUID;references:UUID"`
}
checkStructRelation(t, &User{},
Relation{
Name: "Aux", Type: schema.HasOne, Schema: "User", FieldSchema: "UserAux",
References: []Reference{
{"UUID", "User", "UUID", "UserAux", "", true},
},
},
)
}
func TestBelongsToSameForeignKey(t *testing.T) {
type User struct {
gorm.Model
Name string
UUID string
}
type UserAux struct {
gorm.Model
Aux string
UUID string
User User `gorm:"ForeignKey:UUID;references:UUID;belongsTo"`
}
checkStructRelation(t, &UserAux{},
Relation{
Name: "User", Type: schema.BelongsTo, Schema: "UserAux", FieldSchema: "User",
References: []Reference{
{"UUID", "User", "UUID", "UserAux", "", false},
},
},
)
}
func TestHasOneWithSameForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
Name string
ProfileRefer int // not used in relationship
}
type User struct {
gorm.Model
Profile Profile `gorm:"ForeignKey:ID;references:ProfileRefer"`
ProfileRefer int
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"ProfileRefer", "User", "ID", "Profile", "", true}},
})
}
func TestHasManySameForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserRefer uint
}
type User struct {
gorm.Model
UserRefer uint
Profile []Profile `gorm:"ForeignKey:UserRefer"`
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.HasMany, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}},
})
}
type Author struct {
gorm.Model
}
type Book struct {
gorm.Model
Author Author
AuthorID uint
}
func (Book) TableName() string {
return "my_schema.a_very_very_very_very_very_very_very_very_long_table_name"
}
func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) {
s, err := schema.Parse(
&Book{},
&sync.Map{},
schema.NamingStrategy{IdentifierMaxLength: 64},
)
if err != nil {
t.Fatalf("Failed to parse schema")
}
expectedConstraintName := "fk_my_schema_a_very_very_very_very_very_very_very_very_l4db13eec"
constraint := s.Relationships.Relations["Author"].ParseConstraint()
if constraint.Name != expectedConstraintName {
t.Fatalf(
"expected constraint name %s, got %s",
expectedConstraintName,
constraint.Name,
)
}
}

View File

@ -5,13 +5,29 @@ import (
"errors" "errors"
"fmt" "fmt"
"go/ast" "go/ast"
"path"
"reflect" "reflect"
"strings"
"sync" "sync"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
) )
type callbackType string
const (
callbackTypeBeforeCreate callbackType = "BeforeCreate"
callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
callbackTypeAfterCreate callbackType = "AfterCreate"
callbackTypeAfterUpdate callbackType = "AfterUpdate"
callbackTypeBeforeSave callbackType = "BeforeSave"
callbackTypeAfterSave callbackType = "AfterSave"
callbackTypeBeforeDelete callbackType = "BeforeDelete"
callbackTypeAfterDelete callbackType = "AfterDelete"
callbackTypeAfterFind callbackType = "AfterFind"
)
// ErrUnsupportedDataType unsupported data type
var ErrUnsupportedDataType = errors.New("unsupported data type")
@ -25,8 +41,9 @@ type Schema struct {
PrimaryFieldDBNames []string
Fields []*Field
FieldsByName map[string]*Field
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
Relationships Relationships
CreateClauses []clause.Interface
QueryClauses []clause.Interface
@ -38,37 +55,23 @@ type Schema struct {
BeforeSave, AfterSave bool
AfterFind bool
err error
initialized chan struct{}
namer Namer
cacheStore *sync.Map
}
type CreateClausesInterface interface {
CreateClauses() []clause.Interface
}
type QueryClausesInterface interface {
QueryClauses() []clause.Interface
}
type UpdateClausesInterface interface {
UpdateClauses() []clause.Interface
}
type DeleteClausesInterface interface {
DeleteClauses() []clause.Interface
}
func (schema Schema) String() string {
if schema.ModelType.Name() == "" {
return fmt.Sprintf("%s(%s)", schema.Name, schema.Table)
}
return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name())
}
func (schema Schema) MakeSlice() reflect.Value {
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20)
results := reflect.New(slice.Type())
results.Elem().Set(slice)
return results
}
@ -82,10 +85,57 @@ func (schema Schema) LookUpField(name string) *Field {
return nil
}
// LookUpFieldByBindName looks for the closest field in the embedded struct.
//
// type Struct struct {
// Embedded struct {
// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
// }
// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
// }
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
if len(bindNames) == 0 {
return nil
}
for i := len(bindNames) - 1; i >= 0; i-- {
find := strings.Join(bindNames[:i], ".") + "." + name
if field, ok := schema.FieldsByBindName[find]; ok {
return field
}
}
return nil
}
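A small worked example of the probe order above, replaying the same strings.Join logic for bindNames = []string{"Embedded", "ID"} and name = "ID" (illustrative only; fmt and strings imports assumed):

	bindNames := []string{"Embedded", "ID"}
	name := "ID"
	for i := len(bindNames) - 1; i >= 0; i-- {
		fmt.Println(strings.Join(bindNames[:i], ".") + "." + name)
	}
	// Output:
	// Embedded.ID
	// .ID

The first candidate key present in schema.FieldsByBindName is returned; if none matches, the lookup yields nil.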
type Tabler interface {
TableName() string
}
type TablerWithNamer interface {
TableName(Namer) string
}
// Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
}
// ParseWithSpecialTableName get data type from dialector with extra schema table
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
value := reflect.ValueOf(dest)
if value.Kind() == reflect.Ptr && value.IsNil() {
value = reflect.New(value.Type().Elem())
}
modelType := reflect.Indirect(value).Type()
if modelType.Kind() == reflect.Interface {
modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
}
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
@ -93,30 +143,63 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if modelType.PkgPath() == "" { if modelType.PkgPath() == "" {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
} }
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
} }
// Cache the Schema for performance,
// using the modelType, or modelType + schemaTable when a special table name is present, as the cache key.
var schemaCacheKey interface{}
if specialTableName != "" {
schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
} else {
schemaCacheKey = modelType
}
// Load exist schema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
modelValue := reflect.New(modelType)
tableName := namer.TableName(modelType.Name())
if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName()
}
if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
tableName = tabler.TableName(namer)
}
if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table
}
if specialTableName != "" && specialTableName != tableName {
tableName = specialTableName
}
schema := &Schema{
Name: modelType.Name(),
ModelType: modelType,
Table: tableName,
FieldsByName: map[string]*Field{},
FieldsByBindName: map[string]*Field{},
FieldsByDBName: map[string]*Field{},
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
initialized: make(chan struct{}),
}
// When the schema initialization is completed, the channel will be closed
defer close(schema.initialized)
// Load exist schema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
@ -133,52 +216,67 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
field.DBName = namer.ColumnName(schema.Table, field.Name)
}
bindName := field.BindName()
if field.DBName != "" {
// nonexistence or shortest path or first appear prioritized if has permission
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
schema.DBNames = append(schema.DBNames, field.DBName)
}
schema.FieldsByDBName[field.DBName] = field
schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[bindName] = field
if v != nil && v.PrimaryKey {
if schema.PrioritizedPrimaryField == v {
schema.PrioritizedPrimaryField = nil
}
for idx, f := range schema.PrimaryFields {
if f == v {
schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
}
}
}
if field.PrimaryKey {
if schema.PrioritizedPrimaryField == nil {
schema.PrioritizedPrimaryField = field
}
schema.PrimaryFields = append(schema.PrimaryFields, field)
}
}
}
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByName[field.Name] = field
}
if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByBindName[bindName] = field
}
if f := schema.LookUpField("id"); f != nil { field.setupValuerAndSetter(modelType)
if f.PrimaryKey { }
schema.PrioritizedPrimaryField = f
prioritizedPrimaryField := schema.LookUpField("id")
if prioritizedPrimaryField == nil {
prioritizedPrimaryField = schema.LookUpField("ID")
}
if prioritizedPrimaryField != nil {
if prioritizedPrimaryField.PrimaryKey {
schema.PrioritizedPrimaryField = prioritizedPrimaryField
} else if len(schema.PrimaryFields) == 0 {
prioritizedPrimaryField.PrimaryKey = true
schema.PrioritizedPrimaryField = prioritizedPrimaryField
schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField)
}
}
if schema.PrioritizedPrimaryField == nil {
if len(schema.PrimaryFields) == 1 {
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
} else if len(schema.PrimaryFields) > 1 {
// If there are multiple primary keys, the AUTOINCREMENT field is prioritized
for _, field := range schema.PrimaryFields {
if field.AutoIncrement {
schema.PrioritizedPrimaryField = field
break
}
}
} }
}
@ -186,43 +284,148 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName)
}
for _, field := range schema.Fields {
if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil {
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
}
}
if field := schema.PrioritizedPrimaryField; field != nil {
switch field.GORMDataType {
case Int, Uint:
if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok {
if !field.HasDefaultValue || field.DefaultValueInterface != nil {
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
}
field.HasDefaultValue = true
field.AutoIncrement = true
}
}
}
callbackTypes := []callbackType{
callbackTypeBeforeCreate, callbackTypeAfterCreate,
callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
callbackTypeBeforeSave, callbackTypeAfterSave,
callbackTypeBeforeDelete, callbackTypeAfterDelete,
callbackTypeAfterFind,
}
for _, cbName := range callbackTypes {
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
switch methodValue.Type().String() {
case "func(*gorm.DB) error":
expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath())
if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath {
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
} else {
logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg)
// PASS
}
default:
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
}
}
}
// Cache the schema
if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
defer func() {
if schema.err != nil {
logger.Default.Error(context.Background(), schema.err.Error())
cacheStore.Delete(modelType)
}
}()
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
for _, field := range schema.Fields {
if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) {
if schema.parseRelation(field); schema.err != nil {
return schema, schema.err
} else {
schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[field.BindName()] = field
}
}
fieldValue := reflect.New(field.IndirectFieldType)
fieldInterface := fieldValue.Interface()
if fc, ok := fieldInterface.(CreateClausesInterface); ok {
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
}
if fc, ok := fieldInterface.(QueryClausesInterface); ok {
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
}
if fc, ok := fieldInterface.(UpdateClausesInterface); ok {
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
}
if fc, ok := fieldInterface.(DeleteClausesInterface); ok {
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
} }
} }
} }
return schema, schema.err
}
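A hedged usage sketch of the two entry points above; the User model, the cache variable and the resulting table names are illustrative, only the signatures come from this file:

	cache := &sync.Map{}

	s1, _ := schema.Parse(&User{}, cache, schema.NamingStrategy{})
	s2, _ := schema.ParseWithSpecialTableName(&User{}, cache, schema.NamingStrategy{}, "archived_users")

	// s1 is cached under the model type alone; s2 under modelType + "archived_users",
	// so repeated lookups of either variant hit their own cache entry.
	// s1.Table is the namer-derived name (e.g. "users"), s2.Table is "archived_users".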
// This unrolling is needed to show to the compiler the exact set of methods
// that can be used on the modelType.
// Prior to go1.22 any use of MethodByName would cause the linker to
// abandon dead code elimination for the entire binary.
// As of go1.22 the compiler supports one special case of a string constant
// being passed to MethodByName. For enterprise customers or those building
// large binaries, this gives a significant reduction in binary size.
// https://github.com/golang/go/issues/62257
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
switch cbType {
case callbackTypeBeforeCreate:
return modelType.MethodByName(string(callbackTypeBeforeCreate))
case callbackTypeAfterCreate:
return modelType.MethodByName(string(callbackTypeAfterCreate))
case callbackTypeBeforeUpdate:
return modelType.MethodByName(string(callbackTypeBeforeUpdate))
case callbackTypeAfterUpdate:
return modelType.MethodByName(string(callbackTypeAfterUpdate))
case callbackTypeBeforeSave:
return modelType.MethodByName(string(callbackTypeBeforeSave))
case callbackTypeAfterSave:
return modelType.MethodByName(string(callbackTypeAfterSave))
case callbackTypeBeforeDelete:
return modelType.MethodByName(string(callbackTypeBeforeDelete))
case callbackTypeAfterDelete:
return modelType.MethodByName(string(callbackTypeAfterDelete))
case callbackTypeAfterFind:
return modelType.MethodByName(string(callbackTypeAfterFind))
default:
return reflect.ValueOf(nil)
}
}
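A minimal sketch of the linker behaviour the comment above refers to; the model type and hook are hypothetical, and reflect plus gorm.io/gorm imports are assumed:

	type model struct{}

	func (model) BeforeCreate(*gorm.DB) error { return nil }

	func methodByNameExample() {
		v := reflect.ValueOf(model{})

		// string constant: since go1.22 the linker can still prove which methods
		// are reachable, so dead-code elimination stays enabled
		_ = v.MethodByName("BeforeCreate")

		// non-constant argument: any method could be requested at runtime,
		// so the linker has to keep every method in the binary
		name := "BeforeCreate"
		_ = v.MethodByName(name)
	}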
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
modelType := reflect.ValueOf(dest).Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
}
if v, ok := cacheStore.Load(modelType); ok {
return v.(*Schema), nil
}
return Parse(dest, cacheStore, namer)
}

Some files were not shown because too many files have changed in this diff.