Compare commits

...

575 Commits

Author SHA1 Message Date
贾一饼
49eaeacb89
optimize: field.ReflectValueOf (#7530)
Some checks failed
tests / mysql (mysql:5.7, 1.24, ubuntu-latest) (push) Failing after 59s
tests / mysql (mysql:8, 1.23, ubuntu-latest) (push) Failing after 59s
tests / mysql (mysql:8, 1.24, ubuntu-latest) (push) Failing after 5s
tests / sqlite (1.23, ubuntu-latest) (push) Failing after 1m26s
tests / mysql (mysql:9, 1.23, ubuntu-latest) (push) Failing after 36s
tests / mysql (mysql:9, 1.24, ubuntu-latest) (push) Failing after 32s
tests / mariadb (mariadb:latest, 1.24, ubuntu-latest) (push) Failing after 9s
tests / mariadb (mariadb:latest, 1.23, ubuntu-latest) (push) Failing after 22s
tests / postgres (postgres:13, 1.23, ubuntu-latest) (push) Failing after 24s
tests / mysql (mysql:5.7, 1.23, ubuntu-latest) (push) Failing after 2m1s
tests / postgres (postgres:14, 1.23, ubuntu-latest) (push) Failing after 20s
tests / postgres (postgres:14, 1.24, ubuntu-latest) (push) Failing after 6s
tests / postgres (postgres:15, 1.24, ubuntu-latest) (push) Failing after 14s
tests / postgres (postgres:15, 1.23, ubuntu-latest) (push) Failing after 29s
tests / postgres (postgres:latest, 1.23, ubuntu-latest) (push) Failing after 26s
tests / postgres (postgres:latest, 1.24, ubuntu-latest) (push) Failing after 18s
tests / postgres (postgres:13, 1.24, ubuntu-latest) (push) Failing after 2m55s
tests / sqlite (1.24, ubuntu-latest) (push) Failing after 4m53s
tests / tidb (v6.5.0, 1.24, ubuntu-latest) (push) Failing after 2m6s
golangci-lint / lint (push) Failing after 6m41s
tests / gaussdb (opengauss/opengauss:7.0.0-RC1.B023, 1.23, ubuntu-latest) (push) Failing after 3m31s
tests / gaussdb (opengauss/opengauss:7.0.0-RC1.B023, 1.24, ubuntu-latest) (push) Failing after 3m38s
tests / sqlserver (1.23, ubuntu-latest) (push) Failing after 7m36s
tests / tidb (v6.5.0, 1.23, ubuntu-latest) (push) Failing after 12m41s
tests / sqlserver (1.24, ubuntu-latest) (push) Failing after 14m12s
Stale / stale (push) Successful in 13s
Close Missing Playground issues / stale (push) Successful in 6s
Close invalid questions issues / stale (push) Successful in 5s
2025-07-23 13:02:49 +08:00
贾一饼
52b4744410
optimize: performance optimization (#7526) 2025-07-22 18:32:28 +08:00
Jinzhu
9af6d510b5 Fix query when map keys include table-qualified column names, close #7507 2025-07-22 14:21:04 +08:00
Jinzhu
c63374f5d1 Don't request LastInsertID from database if not necessary, close #7469 2025-07-21 17:55:20 +08:00
jc
b9c7e562b0
fix(schema): check the hook function parameter type (#7468)
* fix(schema): Check the callback function parameter type

* fix log

* fix
2025-07-21 17:06:16 +08:00
Riseif
985940f0d8
should check inner condition length (#7512) 2025-07-21 11:57:12 +08:00
moseszane168
991c2d4891
Add GaussDB Database Support (#7508)
* support gaussdb

* use github CI

* change function name

* use gorm.io/driver/gaussdb

---------

Co-authored-by: bing.ma <bing.ma@daocloud.io>
2025-07-21 10:46:58 +08:00
Jinzhu
751a6dde7a
Call after initialize for gorm.Config (#7518) 2025-07-15 12:05:03 +08:00
贾一饼
2f4925e017
A little optimization for filed.ValueOf (#7499)
Co-authored-by: 贾一饼 <Boyang.Liang@apulis.com>
2025-07-07 11:15:10 +08:00
Eshan-Jogwar
1e8baf5459
fixes #7486 (#7492)
* fixes #7486

* Added A test case for subset changes of model

* completed the test file for check_subset_model_change_test.go
2025-06-25 11:11:08 +08:00
Salent Olivick
842ee527eb
fix decimal migrate error.(#7450) (#7450)
Signed-off-by: Chise1 <chise123@live.com>
2025-06-06 10:35:01 +08:00
enomotodev
23c0d7cf05
test: update MySQL test matrix to use official images and add 9.0, 8.4 versions (#7476)
* test: update MySQL test matrix to use official images and add 9.0, 8.4 versions

* test: use major version tags for MySQL test matrix
2025-06-06 10:10:23 +08:00
Jinzhu
718eae4fdd fix tests for mysql 9.0 2025-06-05 19:34:13 +08:00
Jinzhu
49b01a3e93 Fix Generics Scan, close https://github.com/go-gorm/playground/pull/803 2025-05-29 14:23:57 +08:00
Jinzhu
c44405a25b
Implement Generics API (#7424)
* Implement Generics API

* Add more generics tests

* Add more tests and Take method

* use delayed‑ops pipeline for generics API

* fix generics tests for mysql

* Support SubQuery for Generics

* Add clause.JoinTable helper method

* Fix golangci-lint error

* Complete the design and implementation of generic version Join

* improve generics version Joins support

* allow configuring select/omit columns for joins via subqueries

* finish generic version Preload

* handle error of generics Joins/Preload

* fix tests

* Add LimitPerRecord for generic version Preload

* fix tests for mysql 5.7

* test for nested generic version Join/Preload

* Add WithResult support for generics API

* test reuse generics db conditions

* fix data race

* remove ExampleLRU test

* Add default transaction timeout support

* fix test
2025-05-25 15:40:40 +08:00
Name
751c1d6b45
perf(schema): avoid redundant strings.ToLower call (#7464)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2025-05-25 09:27:21 +08:00
codingplz
8e7ab46c1b
fix: return init dialector error (#7379)
* fix: return init dialector error

* mock defer

* fix: skip AfterInitialize

---------

Co-authored-by: wenyazhou.13 <wenyazhou.13@bytedance.com>
2025-05-22 10:53:47 +08:00
Name
e3037e4ef0
perf: break early on match failure in ParseConstraint (#7402)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2025-05-22 10:49:19 +08:00
pipipipipip
1204330419
feat: error message show field name (#7452)
* feat: error message show field name

* feat: failed to parse field

* feat: failed to parse field

---------

Co-authored-by: zyp <zyp>
2025-05-21 11:13:31 +08:00
Name
9703eb775f
perf: use strings.IndexByte to replace strings.Index (#7454)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2025-05-21 10:35:56 +08:00
Name
1c966e0d25
perf: use strings.Cut to replace strings.SplitN (#7455)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2025-05-21 10:35:23 +08:00
Jinzhu
e5b867e785 remove unnecessary session-level configuration for prepared statements 2025-05-07 14:56:49 +08:00
iTanken
8c4e8e2d2a
fix: int type variable defaultMaxSize overflows in 32-bit environment (#7439)
Refs: #7435
2025-04-27 14:05:16 +08:00
Zhaodong Xie
a827495be1
Preparestmt use LRU Map instead default map (#7435)
* 支持lru淘汰preparestmt cache

* 支持lru淘汰preparestmt cache

* 支持lru淘汰preparestmt cache

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* change const export

* Add stmt_store

* refact prepare stmt store

* Rename lru store

* change const export

* ADD UT

* format code and add session level prepare stmt config

* code format according to golinter ci

* ADD UT

---------

Co-authored-by: xiezhaodong <xiezhaodong@bytedance.com>
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2025-04-25 16:22:26 +08:00
Jinzhu
489a563293 only check new issues for golangci linter 2025-04-17 15:30:17 +08:00
Jinzhu
42bd4f603c
use golangci replace reviewdog (#7426)
* use golangci replace reviewdog

* Update golangci config
2025-04-17 11:55:13 +08:00
a631807682
a9d27293de test: mssql ci 2025-03-11 15:56:46 +08:00
a631807682
3876ffe4bb test: mssql ci 2025-03-11 15:56:46 +08:00
Jinzhu
ee3b549d7d Update tests script 2025-03-09 15:35:17 +08:00
Aman
9f273777f5
fix deprecated reflect.PtrTo reflect.PointerTo usage (#7366)
* fix deprecated reflect.PtrTo reflect.PointerTo usage

* replace all deprecated reflect.PtrTo reflect.PointerTo usage
2025-02-13 14:16:26 +08:00
Vladimir Avtsenov
9ca84b3dde
fix concurrent map writes (#7298) 2025-01-12 19:49:06 +08:00
Max Katz
86b1d22911
chore: update copyright year (#7332) 2025-01-12 18:32:39 +08:00
Evyatar Yaffe
fed49230cb
Enhance db.Scan with ParamsFilter - Issue 7336 - Suggestion (#7337) 2025-01-12 18:19:28 +08:00
aviyam181199
8503287ca4
Fixed Empty Returning Clause Merge Bug (#7339) 2025-01-12 18:18:04 +08:00
nowindexman
4ef3af10ed
feat:Capitalize the priority field of IndexOption so that other systems can access this field from outside the package. (#7342) 2025-01-12 18:16:48 +08:00
Bennett Amodio
f482f25c71
fix: deterministic index ordering when migrating (#7208)
Issue: We observed that, when creating a database based on the same
gORM schema multiple times, indexes could appear in different orders,
hurting determinism for use-cases like schema comparison.

In order to fix this, it's simple to switch the ParseIndexes function
to return a list of indices rather than a map, so the callers will
iterate in deterministic order.
2024-12-06 10:27:44 +08:00
Jinzhu
6bfccf8afa Refactor all tests script 2024-11-21 17:03:31 +08:00
Ivan Ryabov
49bbaa637f
use map look-up for indexes (#7242) 2024-11-14 17:41:43 +08:00
홍성욱
b0d70a26d1
[#6372] Fixed nullable constraint bug for columns during auto migration (#7269)
* [#6372] Fixed nullable constraint bug for columns during auto migration

* [#6372] fix comment

* [#6372] Add test code

* [#6372] Add test code

* [#6372] Fix failed test case

* [#6372] Fix failed test case

* [#6372] wip

* [#6372] wip

* [#6372] wip

* [#6372] wip
2024-11-14 17:40:18 +08:00
omid fth
deceebfab8
Create CODE_OF_CONDUCT.md (#7240) 2024-10-17 14:18:13 +08:00
Yidi Sprei
52e3b353eb
refactor(workflow): update release workflow to enhance automation (#7224)
Replaced the old release workflow with a new setup using Release
Drafter. This refactor allows for more detailed release notes by
categorizing changes and automatically generating release drafts.
The new workflow triggers on semantic version tags and improves
permissions management. This change enhances the release process
by providing better documentation and automation.
2024-10-09 19:31:04 +08:00
Hansu Park
8020e8c166
refactor: improve logging for unimplemented ErrorTranslator in TranslateError config (#7225) 2024-10-09 19:29:48 +08:00
Yidi Sprei
62bd0b9331
Add GitHub Actions workflow to automate release creation on tagged pushes (#7209)
* feat(workflows): add GitHub Action to create release on new tag

This workflow automates the release creation process whenever a new
tag is pushed to the repository. It checks if a release for the tag
already exists and creates one if it doesn't, enhancing the release
management and streamlining the deployment process.

* test

* fix(workflow): improve release existence check in release-on-tag.yml

Refactor the script to improve checking for existing releases by tag.
Return an object instead of using core.setOutput to streamline the
workflow logic. Also, set result-encoding to string for better
compatibility.

* fix(workflow): correct output handling in release-on-tag.yml

Corrected the output handling in the 'release-on-tag.yml' workflow file
by changing 'result-encoding' to 'outputs'. This ensures that the step
correctly checks if a release exists before attempting to create a new
one, thereby avoiding potential errors during the release process.
Added a blank line for better readability.

* fix(workflow): correct output setting in release-on-tag workflow

Refactored how the release_exists output is set to be compatible with
GitHub Actions syntax. This ensures the workflow reliably detects if
a release already exists, improving the robustness of the release
process.

* fix(release): correct output handling in release-on-tag workflow

Improved the way outputs are managed in the 'check_release' step by
returning values instead of direct assignments. This change ensures
better handling of release existence checks and improves code
readability. Added 'result-encoding' to specify string encoding for
results.

* fix(workflow): correct release existence check and add debug output

Refactored the release existence check to return a simple boolean
string ('true'/'false') rather than an object. Added a step to print
the release existence status for debugging purposes. This ensures
correct conditional evaluation and aids in troubleshooting workflow
issues.

* fix(workflow): simplify release process by removing redundant checks

The release-on-tag workflow has been streamlined by eliminating the
redundant steps checking for existing releases. This change reduces
complexity and speeds up the release process by directly creating a
release on every tag push without prior existence verification.
2024-09-30 14:18:13 +08:00
Omkar P
c6ac54812a
Use official SQL Server docker image for tests (#7205)
* Use official SQL Server docker image for tests

* Try with tag 2019-latest instead of latest

* Use platform ubuntu-20.04 for SQL Server

* Switch to 2019-CU18-ubuntu-20.04

* Check with 2022-latest image tag and ubuntu-latest platform

* Update health-cmd

* Try sqlcmd without -N -C

* Re-include -N -C, try with ubuntu-20.04

* Try ubuntu-20.04 without -N -C (last trial)

* Finalize working config

* Remove unused env variables
2024-09-30 11:21:19 +08:00
Jinzhu
68434b76eb fix test script with latest docker release 2024-09-18 20:03:35 +08:00
Kazuki Isogai
c2515ce260
feat: remove version top-level element and rename. (#7086) 2024-09-18 19:42:52 +08:00
Leo Sjöberg
7f75b12bb2
Generate unique savepoint names for nested transactions (#7174)
* Generate unique savepoint names

* Add a test for deeply nested wrapped transactions
2024-09-14 20:58:29 +08:00
abhijeet45
0daaf1747c
fix: AfterQuery using safer right trim while clearing from clause's join added as part of https://github.com/go-gorm/gorm/pull/7027 (#7153)
Co-authored-by: Abhijeet Bhowmik <abhijeet.bhowmik@cambiumnetworks.com>
2024-08-22 19:03:42 +08:00
ivila
0dbfda5d7e
fix memory leaks in PrepareStatementDB (#7142)
* fix memory leaks in PrepareStatementDB

* Fix CR:
1) Fix potential Segmentation Fault in Reset function
2) Setting db.Stmts to nil map when Close to avoid further using

* Add Test:
1) TestPreparedStmtConcurrentReset
2) TestPreparedStmtConcurrentClose

* Fix test, create new connection to keep away from other tests

---------

Co-authored-by: Zehui Chen <zehui@ssc-hn.com>
2024-08-22 19:02:05 +08:00
enomotodev
4a50b36f63
ci: Add PostgreSQL 14 and 15 to GitHub Actions matrix (#7081)
* ci: Add PostgreSQL 14 and 15 to GitHub Actions matrix

* ci: Remove older PostgreSQL versions from test matrix
2024-06-25 10:36:06 +08:00
molon
11c4331058
feat: add MapColumns method (#6901)
* add MapColumns method

* fix MapColumns desc

* add TestMapColumns
2024-06-24 17:42:59 +08:00
Jinzhu
8a0af58cc5 fix map fields with clickhouse driver 2024-06-20 20:19:31 +08:00
Jinzhu
4f6291154b Allow to support other field types 2024-06-20 16:45:38 +08:00
Sérgio Prata Almeida
109f239fae
add DB level propagation for the Unscoped flag (#7007)
* adds PropagateUnscoped to db Config

* adds PropagateUnscoped test

* adds PropagateUnscoped to Session and sets it accordingly
2024-06-17 11:59:06 +08:00
Jinzhu
79bf7f92ed fix CI for sqlserver 2024-06-17 11:58:13 +08:00
Jinzhu
3d09f7947f only listen local port 2024-06-13 18:45:02 +08:00
Waleed Masoom
73a988ceb2
fix(scan): update Scan function to reset structs to zero values for each scan (#7061)
Co-authored-by: waleed.masoom <waleed.masoom@wheniwork.com>
2024-06-12 18:57:36 +08:00
Emilien
05167fd591
fix: use reflect.Append when preloading nested associations (#7014)
Co-authored-by: Emilien Kofman <emilien.kofman@miimosa.com>
2024-06-12 18:52:33 +08:00
Sergei Sadov
78c6dfd712
Fix association replace non-addressable panic (#7012)
* Fix association replace non-addressable panic

* Fix tests

* Add has one panic test

---------

Co-authored-by: sgsv <->
2024-06-12 18:49:45 +08:00
Nico Schäfer
3fe7fcf356
fix: unsupported data on nested joins with preloads (#6957)
* fix: `unsupported data` on nested joins with preloads

* Add test case for pointer join with nested prelaods

* Fix tests
2024-06-12 18:00:47 +08:00
贾一饼
9c4070ed19
fix: AfterQuery should clear FROM Clause's Joins rather than the Statement (#7027) 2024-06-12 17:51:44 +08:00
supergem3000
49d524aaea
feat: chainable order support clause.OrderBy (#7054)
* feat: chainable order support clause.OrderBy

* indent
2024-06-12 17:47:34 +08:00
Jinzhu
49d94c173c upgrade github action tests template 2024-06-12 17:24:34 +08:00
Ryuji Kokubu
0f105ec163
fix: strings.Title -> cases.Title bcs strings.Title library is deprecated (#6999)
* add: add cases library

* fix: strings.Title -> cases.Title

* run goimports to solve the error
2024-06-12 11:46:59 +08:00
hakusai22
5e599a07ec
fix: typo (#7003)
* fix: typo

* fix: covered
2024-05-08 12:07:58 +08:00
Wenli Looi
9d370bcb3e
Fix handling of unknown column types (#6540) 2024-04-26 17:53:11 +08:00
PiexlMax(奇淼
78920199f0
Fix panic bug in migrator due to lack of nil check for stmt.Schema (#6932) 2024-04-26 15:15:49 +08:00
Anıl Şenay
ac59252327
Add new error for "Violation Check Constraint" (#6992) 2024-04-26 10:53:17 +08:00
Cr
207f1ac68f
fix: not clause with or condition (#6984) 2024-04-25 20:22:53 +08:00
Cr
85299bfca7
perf: merge nested preload query when using join (#6990)
* pref: merge nest preload query

* fix: preload test
2024-04-25 20:21:03 +08:00
Jinzhu
5553ff3dcb downgrade mssql driver 2024-04-25 20:12:15 +08:00
kkocdko
bc49365de2
Faster utils.FileWithLineNum (#6981)
* faster FileWithLineNum

* tweak caller skip count
2024-04-22 14:43:02 +08:00
Ivan Chavez
d0b4ceb726
Added comment describing Unscoped() method (#6969) 2024-04-17 11:38:55 +08:00
yetone
9a61ef2af8
fix: duplicated preload (#6948) 2024-04-15 11:20:20 +08:00
Jinzhu
1e13fd7543 Fix duplicated columns in INSERT SQL for some fields with default value 2024-04-08 11:29:55 +08:00
hjwblog.com
1b48aa072d
feat: prepare_stmt support ping (#6924)
* feat: prepare_stmt support ping

* feat: prepare_stmt tx support ping
2024-03-28 16:47:39 +08:00
snackmgmg
26195e6d16
fix: remove callback from callbacks if Remove() called (#6916)
* fix: remove callback from callbacks if Remove() called

* reduce number of loops

* remove unnecessary blank line
2024-03-26 11:33:36 +08:00
givemeafish
956f7ce843
fix: 'type XXXX int' will print wrong sql to terminal (#6917)
Co-authored-by: 王泽平 <zeping.wang@yo-star.com>
2024-03-21 16:00:02 +08:00
Jinzhu
0d6c5345f3 Don't close prepared stmt for normal db error 2024-03-21 15:55:43 +08:00
Jinzhu
57603882ea Only close bad conn prepared stmt 2024-03-20 19:47:20 +08:00
Jinzhu
81536f823c Fix insert id into map results, fix #6812 2024-03-19 11:50:28 +08:00
Jinzhu
1b0aa802df Fix AutoMigrate for bool fields with default value 2024-03-18 19:24:16 +08:00
Jinzhu
e0c3be03fb Fix tests in local 2024-03-18 16:28:46 +08:00
jessetang
303de6e7c8
chore: optimize regEnLetterAndMidline regular (#6908)
* chore: optimize regular

* fix
2024-03-18 15:33:54 +08:00
Jinghao Lu
f7ebf049da
fix(create): fix insert column order (#6855)
* fix(create): fix insert column order

* chore: add ConvertToCreateValues ut for Slice case

* fix: remvoe testify dependency

---------

Co-authored-by: lujinghao <lujinghao@bytedance.com>
2024-03-18 13:48:42 +08:00
jessetang
ab89d54d87
chore: UnixNano convert to UnixMilli (#6907) 2024-03-18 13:44:55 +08:00
Jinzhu
281f3e369a Fix constraint name regexp 2024-03-18 11:32:30 +08:00
jessetang
7b1fb0bd73
fix(scan): array element is set to a zero value (#6890)
* fix(scan): array element is set to a zero value

* add test

* fix test

* optimization
2024-03-15 14:14:48 +08:00
black-06
e4e23d26d2
fix: nested preload with join panic when find (#6877) 2024-03-09 21:27:19 +08:00
jessetang
c4c9aa45e3
fix(scan.go): reflect.MakeSlice passes in the reflect.Array type (#6880) 2024-03-09 17:39:01 +08:00
Cr
9efae659cb
test: namer identifier lenght (#6872) 2024-03-09 17:31:28 +08:00
hishope
f17a75242e Signed-off-by: hishope <csqiye@126.com>
fix some typos in tests

Signed-off-by: hishope <csqiye@126.com>
2024-03-07 16:19:17 +08:00
tsuba3
3e2c4fc446
Fix regression in db.Not introduced in v1.25.6. (#6844)
* Fix regression in db.Not introduced in 940358e.

* Fix
2024-03-05 10:23:51 +08:00
Chef
f118e55db5
Add unittest test helper function ConvertSliceOfMapToValuesForCreate (#6854) 2024-03-05 10:22:57 +08:00
Chef
52404cddbb
CHORE add unittest test function ConvertMapToValueForCreate (#6846)
* CHORE add unittest test function  ConvertMapToValueForCreate

* CHORE move the test cases located in the files convert_map_test.go and visit_map_test.go into the file helper_test.go.
2024-02-27 10:48:04 +08:00
M Dmitry
d81ae6f701
Fixed: panic on nullable value with multiple foreign key usage (#6839)
See: https://github.com/go-gorm/playground/pull/537
2024-02-19 11:42:25 +08:00
black-06
8fb9a31775
refactor: part 2 of distinguish between Unique and UniqueIndex (#6822) 2024-02-06 19:48:40 +08:00
jasonchuan
9514d5f9e6
let limit and offset use bind parameter (#6806)
* let limit and offset use bind parameter

* format

* format limt_test

* try again

* fix test case fro connpool

* adding driverName for postgres  ,if not to do so, the stmt vars will be added  a wrong  one called pgx.QueryExecModeSimpleProtocol  ,  causing the SQL with limit  problem  need 1 parameter ,but given two.

* delete trunk files

* restore the test_test.go

* restore test_test.go

* driver/postgres->v1.5.5

* change postgres version rollback to 1.5.4

---------

Co-authored-by: chenchuan <chenchuan@360.cn>
Co-authored-by: jason_chuan <jason_chuan@126.com>
2024-02-06 10:54:40 +08:00
black-06
46816ad31d
refactor: distinguish between Unique and UniqueIndex (#6386)
* refactor: distinguish between UniqueIndex and Index

* add test

* add ParseIndex test

* modify unique to constraint

* modify unique to constraint

* fix MigrateColumnUnique

* fix test

* fix unit test

* update test mod

* add MigrateColumnUnique to Migrator interface

* fix format lint

* add comment

* go mod tidy

* revert: revert MigrateColumn

* resolve conflicts
2024-02-04 15:49:19 +08:00
black-06
418ee3fc19
fix: preload shouldn't overwrite the value of join (#6771)
* fix: preload shouldn't overwrite the value of join

* fix lint

* fix: join may automatically add nested query
2024-01-29 11:34:57 +08:00
dependabot[bot]
e043924fe7
chore(deps): bump actions/cache from 3 to 4 (#6802)
Bumps [actions/cache](https://github.com/actions/cache) from 3 to 4.
- [Release notes](https://github.com/actions/cache/releases)
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md)
- [Commits](https://github.com/actions/cache/compare/v3...v4)

---
updated-dependencies:
- dependency-name: actions/cache
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-01-29 10:34:20 +08:00
Jacky
0123dd4509
fix: ignore .gen.go suffix in logger to get the real caller when using gen #6697 (#6785) 2024-01-12 17:09:22 +08:00
Jinzhu
940358e0dd Fix tests doesn't follow https://gorm.io/docs/method_chaining.html convention 2024-01-12 16:42:21 +08:00
iTanken
87decced23
fix: ExplainSQL using consecutive pairs of escaper in SQL string represents an escaper (#6766)
Preventing it from being interpreted as the string terminator. This is a widely used escape mechanism in SQL standards and is applicable in most relational databases.
2023-12-28 19:53:36 +08:00
Stephano George
436cca753c
fix: join and select mytable.* not working (#6761)
* fix: select mytable.* not working

* fix: select mytable.*: will not match `mytable."*"`.
feat: increase readability of code matching table name column name
2023-12-23 21:19:41 +08:00
Alexis Viscogliosi
a2cac75218
feature: bring custom type and id column name to polymorphism (#6716)
* feature: bring custom type and id column name to polymorphism

* relationship: better returns for hasPolymorphicRelation

* fix: tests
2023-12-15 16:36:08 +08:00
Maciej Laskowski
b9ebdb13c7
Making locking parameters more intuitive (#6719)
* Making locking parameters more intuitive

* remove dedicated type
2023-12-15 16:32:56 +08:00
BugKillerPro
2fb4928aa8
refactor: Resolve implicit memory aliasing in for loop (#6730) 2023-12-15 16:31:23 +08:00
Franco Liberali
f0af94cd16 add test to show that update from works 2023-12-04 11:50:58 +08:00
FangSqing
3207ad6033
map insert support return increment id (#6662) 2023-11-15 21:32:56 +08:00
Jinzhu
c1e911f6ed Update tests/go.mod 2023-11-09 18:46:39 +08:00
Kijima Daigo
40f4afe8c2
docs: fix broken link (#6673) 2023-11-07 10:20:06 +08:00
Flc゛
d2fb7a942b
chore(logger): optimize (#6675)
* chore(logger): optimize

* chore(logger): optimize
2023-11-07 10:19:41 +08:00
black-06
9fea15ae75
feat: add MigrateColumnUnique (#6640)
* feat: add MigrateColumnUnique

* feat: define new methods

* delete debug in test
2023-10-30 17:15:49 +08:00
Cr
5adc0ce5f6
test: fix TestEmbeddedRelations (#6639) 2023-10-26 11:58:13 +08:00
gleb
78e905919f
tests/sqilte: enable FOREIGN_KEYS inside OpenTestConnection (#6641) 2023-10-26 11:54:15 +08:00
Franco Liberali
6bef318891
add support for returning in sqlserver (#6585) 2023-10-10 15:03:34 +08:00
dependabot[bot]
1b24081010
chore(deps): bump actions/checkout from 3 to 4 (#6586)
Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v3...v4)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-10 14:50:45 +08:00
Jeremy Quirke
8c18714462
Don't call MethodByName with a variable arg (#6602)
Go 1.22 goes somewhat toward addressing the issue using reflect
MethodByName disabling linker deadcode elimination (DCE) and the
resultant large increase in binary size because the linker cannot
prune unused code because it might be reached via reflection.

Go Issue golang/go#62257 reduces the number of incidences of this
problem by leveraging a compiler assist to avoid marking functions
containing calls to MethodByName as ReflectMethods as long as the
arguments are constants.

An analysis of Uber Technologies code base however shows that a number
of transitive imports still contain calls to MethodByName with a
variable argument, including GORM.

In the case of GORM, the solution we are proposing is because the
number of possible methods is finite, we will "unroll" this. This
demonstrably shows that GORM is not longer a problem for DCE.

Before
```
% go version
go version devel go1.22-2f3458a8ce Sat Sep 16 16:26:48 2023 -0700 darwin/arm64
% go  test ./... -ldflags=-dumpdep   2>  >(grep -i -e  '->.*<reflectmethod>')
gorm.io/gorm.(*Statement).BuildCondition -> gorm.io/gorm/schema.ParseWithSpecialTableName <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).Method <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).MethodByName <ReflectMethod>
ok  	gorm.io/gorm	(cached)
ok  	gorm.io/gorm/callbacks	(cached)
gorm.io/gorm/clause_test.BenchmarkComplexSelect -> gorm.io/gorm/schema.ParseWithSpecialTableName <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).Method <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).MethodByName <ReflectMethod>
?   	gorm.io/gorm/migrator	[no test files]
ok  	gorm.io/gorm/clause	(cached)
ok  	gorm.io/gorm/logger	(cached)
gorm.io/gorm/schema_test.TestAdvancedDataTypeValuerAndSetter -> gorm.io/gorm/schema.ParseWithSpecialTableName <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).Method <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).MethodByName <ReflectMethod>
?   	gorm.io/gorm/utils/tests	[no test files]
ok  	gorm.io/gorm/schema	(cached)
ok  	gorm.io/gorm/utils	(cached)
```

After

```
%go version
go version devel go1.22-2f3458a8ce Sat Sep 16 16:26:48 2023 -0700 darwin/arm64
%go  test ./... -ldflags=-dumpdep   2>  >(grep -i -e  '->.*<reflectmethod>')
ok  	gorm.io/gorm	(cached)
ok  	gorm.io/gorm/callbacks	(cached)
?   	gorm.io/gorm/migrator	[no test files]
?   	gorm.io/gorm/utils/tests	[no test files]
ok  	gorm.io/gorm/clause	(cached)
ok  	gorm.io/gorm/logger	(cached)
ok  	gorm.io/gorm/schema	(cached)
ok  	gorm.io/gorm/utils	(cached)
```
2023-10-10 14:50:29 +08:00
Mathias Zeller
12ba285a52
*datatypes.JSON in model causes panic on tx.Statement.Changed (#6611)
* do not panic on nil

* more explanation in comments

* get things compact
2023-10-10 14:46:32 +08:00
hjwblog.com
9d8a5bb208
feat: reuse name (#6626) 2023-10-10 14:45:48 +08:00
Samuel N Cui
2095d42b4c
fix: sqlite dialector cannot apply PRIMARY KEY AUTOINCREMENT type (#6624)
* fix: sqlite dialector cannot apply `PRIMARY KEY AUTOINCREMENT` type

fix #4760

* feat: add auto increment test

* feat: update sqlite

* feat: update tests deps sqlite to v1.5.4
2023-10-09 17:26:27 +08:00
Jinzhu
e57e5d8884 Update go.mod 2023-08-27 15:40:54 +08:00
Jinzhu
653732e1c3 Update go testing versions 2023-08-24 20:19:38 +08:00
Rataj
ac07543962
Fixed error message when dialector fails to initialize (#6509)
Let's say we have a problem with DSN which leads to dialector initialize error. However DB connection is not created and for some reason line 184 error provides <nil> even though "db" doesn't exist.

Previously, this code leads to:
panic: runtime error: invalid memory address or nil pointer dereference

This fix now doesn't attempt to close non-existant database connection and instead continues, so the proper error is shown. In my case:
[error] failed to initialize database, got error default addr for network 'localhost' unknown
2023-08-20 19:46:56 +08:00
龚一涛
7e44f73ad3
fix schema GetIdentityFieldValuesMap interface or ptr (#6417)
Co-authored-by: uptutu <yitao.gong@vzenith.com>
2023-08-19 21:35:14 +08:00
Heliner
2c2089760c
add float32 test case (#6530) 2023-08-19 21:33:57 +08:00
qqxhb
fef42941ba
feat: rm GetDBConnWithContext method (#6535)
* feat: rm contextconnpool method

* feat: nil
2023-08-19 21:33:31 +08:00
weih
bae684b363
fix(clause): when the value of clause.Eq is an empty array, the SQL should be IN (NULL) (#6503) 2023-08-10 13:34:33 +08:00
Jinzhu
15162afaf2 Support GetDBConnWithContext PreparedStmtDB 2023-08-10 13:30:57 +08:00
fayvori
3c34bc2f59
refactor: Regex description (#6507)
* Mirror cleanup

* Regex description

---------

Co-authored-by: Ignat Belousov <ignatbelousov@Ignats-MacBook-Pro.local>
2023-08-07 16:35:19 +08:00
Aayush Acharya
f473761813
fix: added SkipHooks in db getInstance() (#6484) 2023-08-04 10:35:59 +08:00
San Ye
193c454cf4
keep float precision in ExplainSQL (#6495) 2023-08-04 10:31:18 +08:00
Saeid
1fb26ac90e
test: coverage for tabletype added (#6496)
* test: coverage for tabletype added

* test: tidb exclueded

---------

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-08-04 10:30:07 +08:00
Jinzhu
a7f01bd1b2 Test Pluck with customized type 2023-07-25 10:47:19 +08:00
Saeid
c10f807d3c
test: coverage for foreign key violation err (#6403)
* test: coverage for foreign key violation err

* test: enabled foreign keys constraint for sqlite

* test: enabled mysql& mssql for ErrForeignKeyViolate

* test: disabled mysql & updated sqlserver driver version

* test: skipped tidb

---------

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-07-12 21:21:22 +08:00
Saeid
2066138684
ci: fix mariadb mysqladmin (#6401)
Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-06-11 07:42:18 +08:00
Saeid
c2d571cbc8
test: coverage for duplicated key err (#6389)
* test: ErrDuplicatedKey coverage added

* test: updated sqlserver version

* test: removed sqlserver

* test: support added for sqlserver

---------

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-06-10 21:05:19 +08:00
Johannes Riecken
7dd702d379
Fix incorrect documentation comment (has many -> has one) (#6382) 2023-06-07 15:02:30 +08:00
Nuno Cruces
7157b7e375
fix: database/sql.Scanner should not retain references (#6380) 2023-06-07 15:02:07 +08:00
Lev Zakharov
661781a3d7
feat: add *sql.DB connector that uses database context (#6366)
* feat: add SQLConnector

* rename
2023-06-05 16:25:05 +08:00
KantaHasegawa
5eaccaa624
reafactor: add nil detection when sqldb return (#6373)
* reafactor: add null detection when sqldb return

* refactor: Detecting nil in dbConnector.GetDBConn()

* refactor: Revert partial code from c1ea73036715018a1bb55cdb8690441044e13a76

* fix: fix if statement
2023-06-05 16:24:00 +08:00
Lev Zakharov
7a76c042e6
refactor: remove unnecessary prepared statement allocation (#6374) 2023-06-05 16:23:17 +08:00
black
c1ea730367 fix: avoid panic when open fails 2023-06-01 15:22:21 +08:00
东方上人
740f2be453
fix: begin transaction fail, rollback panic (#6365) 2023-05-31 19:21:51 +08:00
mohammad ali
26663ab9bf
max identifier length changed to 63 (#6337)
* max identifier length changed to 63

* default maxIdentifierLength is 64

* renamed License to LICENSE (#6336)

* Added support of "Violates Foreign Key Constraint" (#6329)

* Added support of "Violates Foreign Key Constraint"

Updated the translator and added the support of "foreign key constraint violation". For this, this error type is needed here.

* changed the description of ErrForeignKeyViolated

* refactor: error translator test (#6350)

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>

* fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements (#6220)

* test: add nested transaction and prepareStmt coexist test case

note: please test in the MySQL environment

Change-Id: I0db32adc5f74b0d443e98943d3b182236583b959
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

* fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements

1. SavetPoint SQL Statement not support in Prepared Statements
 e.g. see mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html

Change-Id: I082012db9b140e8ec69764c633724665cc802692
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

* revert(transaction_api): remove savepoint name pool,meaningless

Change-Id: I84aa9924fc54612005a81c83d66fdf8968ee56ad
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

---------

Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>
Co-authored-by: 王柳洋 <wangliuyang.520@bytedance.com>

* fix: save with hook (#6285) (#6294)

---------

Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>
Co-authored-by: Avinaba Bhattacharjee <avinababhattacharjee2002@gmail.com>
Co-authored-by: Muhammad Amir Ejaz <37077032+codingamir@users.noreply.github.com>
Co-authored-by: Saeid <sk.saeidee@yahoo.com>
Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
Co-authored-by: wangliuyang <54885906+wangliuyang520@users.noreply.github.com>
Co-authored-by: 王柳洋 <wangliuyang.520@bytedance.com>
Co-authored-by: black-06 <hello.bug@foxmail.com>
2023-05-30 10:00:48 +08:00
black-06
11fdf46a9f
fix: save with hook (#6285) (#6294) 2023-05-26 10:28:02 +08:00
wangliuyang
812bb20c34
fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements (#6220)
* test: add nested transaction and prepareStmt coexist test case

note: please test in the MySQL environment

Change-Id: I0db32adc5f74b0d443e98943d3b182236583b959
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

* fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements

1. SavetPoint SQL Statement not support in Prepared Statements
 e.g. see mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html

Change-Id: I082012db9b140e8ec69764c633724665cc802692
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

* revert(transaction_api): remove savepoint name pool,meaningless

Change-Id: I84aa9924fc54612005a81c83d66fdf8968ee56ad
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

---------

Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>
Co-authored-by: 王柳洋 <wangliuyang.520@bytedance.com>
2023-05-26 10:24:28 +08:00
Saeid
8197c00def
refactor: error translator test (#6350)
Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-05-25 11:10:00 +08:00
Muhammad Amir Ejaz
001738be49
Added support of "Violates Foreign Key Constraint" (#6329)
* Added support of "Violates Foreign Key Constraint"

Updated the translator and added the support of "foreign key constraint violation". For this, this error type is needed here.

* changed the description of ErrForeignKeyViolated
2023-05-21 21:27:22 +08:00
Avinaba Bhattacharjee
6698ba709e
renamed License to LICENSE (#6336) 2023-05-21 21:24:00 +08:00
201430098137
f5837deef3
fix:clickhouse error not capture(#6277) (#6321)
Co-authored-by: zhuangg <zhuangg@mingyuanyun.com>
2023-05-17 10:15:41 +08:00
Jinzhu
c3d7d08b9a Clear SET clause after build SQL 2023-05-15 15:43:44 +08:00
aclich
63534145fd
fix: 🐛 embedded struct test failed with custom datatypes (#6311)
* fix: 🐛 embedded struct test failed with custom datatypes

Fix the pointer embedded struct within custom datatypes and *time.time
should be nil issue.

* fix: 🐛 change test case to avoid mssql driver issue

change test cases from bytes to string to avoid mssql driver issue
2023-05-15 09:59:26 +08:00
John Mai
e61b98d696
feat: migrator support table comment (#6225)
* feat: migrator support table comment

* feat: migrator support tableType.It like ColumnTypes

* Avoid updating the go.mod file.

* Update tests_all.sh

* Update migrator.go

* remove Catalog() & Engine() methods.

* remove CatalogValue & EngineValue.

---------

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-05-05 15:58:27 +08:00
black-06
32045fdd7d
feat: unscoped association (#5899) (#6246)
* feat: unscoped association (#5899)

* modify name because mysql character is latin1

* work only on has association

* format

* Unscoped on belongs_to association
2023-05-04 19:30:45 +08:00
hykuan
67642abfff
fix: 🐛 numeric types in pointer embedded struct test failed (#6293) 2023-05-04 19:29:31 +08:00
hanwn
aeb298635b
debug: use slice Stale sort (#6263)
Co-authored-by: hanwang <hanwang.7721@bytedance.com>
2023-04-26 22:19:46 +08:00
Cr
407bedae0a
fix: nested joins alias (#6265) 2023-04-26 22:19:32 +08:00
yikakia
1f763c81cb
fix typo chainable_api.go (#6266) 2023-04-26 22:19:06 +08:00
Zhiheng Lin
32fc201554
fix: avoid coroutine leaks when the dialecter initialization fails. (#6249)
Co-authored-by: Kevin Lin <kevin.lin@shopee.com>
2023-04-21 22:17:21 +08:00
black-06
ac20d9e222
fix: unit test (#6250)
* fix: unit test

* fix create test

https://github.com/go-gorm/gorm/pull/6127#discussion_r1171214125

* style: rename to adaptorSerializerModel
2023-04-21 22:09:38 +08:00
Jinzhu
e9637024d3 Update README 2023-04-11 13:16:25 +08:00
black-06
828e22b17f
feat: support embedded preload (#6137)
* feat: support embedded preload

* fix lint and test

* fix test...
2023-04-11 13:10:38 +08:00
black-06
4b0da0e97a
fix cond in scopes (#6152)
* fix cond in scopes

* replace quote

* fix execute scopes
2023-04-11 12:01:23 +08:00
bsmith-auth0
ccc3cb758a
fix: many2many association with duplicate belongs to elem (#6206) 2023-04-11 11:06:13 +08:00
jessetang
05bb9d6106
refactor(migrator): non-standard codes (#6180) 2023-04-11 10:32:46 +08:00
dependabot[bot]
1d9f4b0f55
chore(deps): bump actions/stale from 7 to 8 (#6190)
Bumps [actions/stale](https://github.com/actions/stale) from 7 to 8.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v7...v8)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-11 10:27:05 +08:00
hanwn
59ca46db3c
fix: limit(0).offset(0) return all data (#6191)
Co-authored-by: hanwang <hanwang.7721@bytedance.com>
2023-04-11 10:25:47 +08:00
Cr
f0360dccbf
fix: embedded should be nil if not exists (#6219) 2023-04-11 10:13:25 +08:00
Saeid Kanishka
b444011d09
refactor: translatorError flag added for backward compatibility (#6178)
Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-03-24 10:07:05 +08:00
cyhone
5d1cdfef2e
avoid starting a transaction when performing only one insert operation in CreateInBatches function (#6174) 2023-03-23 14:02:35 +08:00
black-06
1a7ea98ac5
fix: count with group (#6157) (#6160)
* fix: count with group (#6157)

* add an easy-to-understand ut
2023-03-23 11:19:53 +08:00
black-06
0c7e575f19
save should be idempotent #6139 (#6149) 2023-03-23 11:18:57 +08:00
dependabot[bot]
d2dd0ce4a7
chore(deps): bump actions/setup-go from 3 to 4 (#6165)
Bumps [actions/setup-go](https://github.com/actions/setup-go) from 3 to 4.
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](https://github.com/actions/setup-go/compare/v3...v4)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-23 11:18:02 +08:00
Jinzhu
cc2d46e5be reuse name for savepoints from nested transaction, close #6060 2023-03-10 17:42:38 +08:00
Cr
8bf1f269cf
feat: support nested join (#6067)
* feat: support nested join

* fix: empty rel value
2023-03-10 17:21:56 +08:00
Jeffry Luqman
654b5f2006
test: pgsql alter column from smallint or string to boolean (#6107)
* test: pgsql alter column from smallint to boolean

* test: pgsql alter column from string to boolean
2023-03-10 17:11:56 +08:00
Cr
b62192456f
fix: diff schema update assign value (#6096) 2023-03-10 17:04:54 +08:00
Saeid Kanishka
707d70a542
refactor: translate error only when it is not nil (#6133)
* refactor: translate error only when it is not nil

* refactor: fix the error flow

* refactor: update the error if checks

* Update gorm.go

---------

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-03-10 16:51:27 +08:00
Truong Nguyen
ed474152b1
Fix: Composite primary key with auto-increment value returns 0 after insert (#6127)
* Fix #4930 workaround for databases that support auto-increment in composite primary key.

* Add test for composite key with auto-increment.

* schema.go: use field.AutoIncrement instead of field.TagSettings["AUTOINCREMENT"], add test to check autoincrement:false

create_test.go: remove unused code: drop table CompositeKeyProduct

---------

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-03-10 16:50:03 +08:00
Jinzhu
1643a36260 Fix possible concurrency problem for serializer 2023-03-10 16:39:57 +08:00
Cr
e9f25c73ee
fix: on confilct with default null (#6129)
* fix: on confilct with default null

* Update create.go

---------

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-03-10 16:35:26 +08:00
Saeid Kanishka
85eaf9eeda
feat: Unique Constraint Violation error translator for different drivers (#6004)
* feat: duplicated key error translator for different drivers

* test: removed the dependency

* test: fixed broken tests

* refactor: added ErrorTransltor interface

* style: applied styler

---------

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-03-06 14:03:31 +08:00
Jinzhu
f3874339ef Fix Save with stress tests 2023-03-02 17:22:51 +08:00
Jiepeng Cao
877cc9148f
Remove redundant code (#6087) 2023-02-27 15:44:35 +08:00
black-06
a80707de9e
Create and drop view (#6097)
* create view

* add comment

* fix test

* check param and add comment
2023-02-27 15:43:10 +08:00
Jiepeng Cao
391c961c7f
quotes on docker-compose.yml ports (#6089) 2023-02-27 15:39:02 +08:00
Cr
04cbd956eb
test: pgsql migrate unique index (#6028) 2023-02-18 09:21:07 +08:00
black-06
e66a059b82
fix: update panic if model is not ptr (#6037)
* fix: update panic if model is not ptr

* fix: update panic if model is not ptr

* fix: update panic if model is not ptr

* fix: raise an error if the value is not addressable

* fix: return
2023-02-18 09:20:29 +08:00
black-06
42fc75cb2c
fix: association concurrently appending (#6044)
* fix: association concurrently appending

* fix: fix unit test

* fix: fix gofumpt
2023-02-18 09:19:24 +08:00
Cr
aa89736db2
fix: miss join type (#6056) 2023-02-18 09:13:36 +08:00
Michael Anstis
532e9cf4cc
Issue 6054: Unscoped not working with PreLoad on Joins (#6058)
* Issue 6054: Unscoped not working with PreLoad on Joins

* Formatting

---------

Co-authored-by: Michael Anstis <manstis@redhat.com>
2023-02-18 09:06:43 +08:00
Cheese
02b7e26f6b
feat: add tidb integration test cases (#6014)
* feat: support tidb integration test

* feat: update the mysql driver version to test
2023-02-08 16:29:09 +08:00
Cr
878ac51e98
fix:throw model value required error (#6031)
* fix:throw model value required error

* chore:ingore typecheck

* chore:ingore errcheck

* refactor: use other error

* chore: gofumpt style
2023-02-08 13:40:41 +08:00
chyroc
e1f46eb802
fix: ignore nil query (#6021) 2023-02-02 17:54:51 +08:00
Jinzhu
4d6b70ec88 Allow modify statement from dest 2023-02-02 17:15:08 +08:00
qiankunli
cfbcedbf03
fix: support zeroValue tag on DeletedAt (#6011)
* fix: support zeroValue tag on DeletedAt

Signed-off-by: qiankunli <qiankun.li@qq.com>

* Update soft_delete_test.go

* Update tests_test.go

* Update soft_delete.go

---------

Signed-off-by: qiankunli <qiankun.li@qq.com>
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-02-01 14:40:55 +08:00
Jinzhu
d834dd60b7 Remove unnecessary code 2023-01-19 15:22:13 +08:00
Jinzhu
3d35ddba55 Fix use table.* as select/omit columns 2023-01-12 16:52:56 +08:00
Haibo
baf1afa1fc
fix(schema): field is only unique when there is one unique index (#5974) 2023-01-11 14:05:39 +08:00
Jinzhu
2bc913787b support implicit table alias, close #5840 #5940 2023-01-02 21:46:27 +08:00
Jinzhu
3d91802b1d Fix unexpected alter table in auto migration, close #5942, #5943 2023-01-02 21:06:04 +08:00
Jinzhu
b0e13d95b4 update github tests action 2023-01-01 22:27:49 +08:00
Jinzhu
4b768c8aff Upgrade tests deps 2023-01-01 22:22:08 +08:00
Haibo
16a272209a
fix(migrator): Tag default:'null' always causes field migration #5953 (#5954)
* fix(migrator): Tag default:'null' always causes field migration #5953

* Update migrate_test.go

* Update migrate_test.go

* Update migrate_test.go

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-01-01 22:14:28 +08:00
Haibo
da2b2861de
fix(migrator): ignore relationships when migrating #5913 (#5946) 2023-01-01 19:54:28 +08:00
dependabot[bot]
7da24d1d52
chore(deps): bump actions/stale from 6 to 7 (#5945)
Bumps [actions/stale](https://github.com/actions/stale) from 6 to 7.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v6...v7)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-12-27 08:47:17 +08:00
Jinzhu
ddd3cc2502 Add ParameterizedQueries option support for logger, close #5288 2022-12-25 11:37:23 +08:00
Cr
794edad60e
test(MigrateColumn): mock alter column to improve field compare (#5499)
* test(MigrateColumn): mock alter column to improve field compare

* Update migrate_test.go

* Update migrate_test.go

* Update migrate_test.go

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-12-24 17:42:16 +08:00
Cr
1935eb0adb
feat: support inner join (#5583)
* feat: support inner join

* test: mixed inner join and left join

* chore: code comment

* Update statement.go

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-12-24 12:27:38 +08:00
Defoo Li
775fa70af5
DryRun for migrator (#5689)
* DryRun for migrator

* Update migrator.go

* Update migrator.go

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-12-24 12:14:23 +08:00
Ning
bbd2bbe521
fix:Issue migrating field with CURRENT_TIMESTAMP (#5906)
Co-authored-by: ningfei <accelerator314@outlook.com>
2022-12-24 11:02:11 +08:00
Nate Armstrong
f3c6fc2533
Update func comments in chainable_api and FirstOr_ (#5935)
Add comments to functions in chainable_api. Depending on the method,
these comments add some additional context or details that are relevant
when reading the function, link to the actual docs at gorm.io/docs, or
provide examples of use. These comments should make GORM much more
pleasant to use with an IDE that provides hoverable comments, and are
minimal examples.

Also add in-code documentation to FirstOrInit and FirstOrCreate.

Almost all examples are directly pulled from the docs, with short
comments explaining the code. Most examples omit the `db.Model(&User{})`
for brevity, and would not actually work.

Co-authored-by: Nate Armstrong <nate.armstrong@eluv.io>
2022-12-23 16:51:01 +08:00
Edward McFarlane
4ec73c9bf4
Add test case for embedded value selects (#5901)
* Add test case for embedded value selects

* Revert recycle struct optimisation to avoid pointer overwrites
2022-12-19 11:49:05 +08:00
Cr
d9525d4da4
fix: skip append relation field to default db value (#5885)
* fix: relation field returning

* chore: gofumpt style
2022-12-01 20:26:59 +08:00
wjw1758548031
f931def33d
clear code syntax (#5889)
* clear code syntax

* clear code syntax
2022-12-01 20:25:53 +08:00
Jinzhu
f91313436a Fix group by with count logic 2022-11-21 11:10:56 +08:00
Cr
342310fba4
fix(FindInBatches): throw err if pk not exists (#5868) 2022-11-21 10:49:27 +08:00
kvii
b6836c2d3e
fix bug in windows (#5844)
* fix bug in windows

* fix file name bug

* test in unix like platform
2022-11-21 10:48:13 +08:00
jessetang
cef3de694d
cleanup(prepare_stmt.go): unnecessary map delete (#5849) 2022-11-13 11:12:09 +08:00
jessetang
1b9cd56c53
doc(README.md): add contributors (#5847) 2022-11-10 16:30:32 +08:00
kvii
871f1de6b9
fix logger path bug (#5836) 2022-11-05 11:52:08 +08:00
jessetang
fb640cf7da
test(utils): add utils unit test (#5834) 2022-11-05 08:38:14 +08:00
jessetang
5c8ecc3a2a
feat: golangci add goimports and whitespace (#5835) 2022-11-05 08:37:37 +08:00
jessetang
f82e9cfdbe
test(clause/joins): add join unit test (#5832) 2022-11-03 21:03:13 +08:00
Cr
b2f42528a4
fix(Joins): args with select and omit (#5790)
* fix(Joins): args with select and omit

* chore: gofumpt style
2022-11-02 10:28:00 +08:00
Cr
9d82aa5673
test: invalid cache plan with prepare stmt (#5778)
* test: invalid cache plan with prepare stmt

* test: more test cases

* test: drop and rename column
2022-10-20 14:10:47 +08:00
Cr
5dd2bb4827
feat(PreparedStmtDB): support reset (#5782)
* feat(PreparedStmtDB): support reset

* fix: close all stmt

* test: fix test

* fix: delete one by one
2022-10-19 14:46:59 +08:00
Jinzhu
3f20a543fa Support use clause.Interface as query params 2022-10-18 18:01:55 +08:00
viatoriche / Maxim Panfilov
62593cfad0 add test: TestAutoMigrateInt8PG: shouldn't execute ALTER COLUMN TYPE smallint, close #5762 2022-10-18 17:28:06 +08:00
Jinzhu
a0f4d3f7d2 Save as empty string for not nullable nil field serialized into json 2022-10-18 16:25:39 +08:00
Jinzhu
ab5f80a8d8 Save as NULL for nil object serialized into json 2022-10-18 15:44:56 +08:00
Cr
186e8a9e14
fix: association without pks (#5779) 2022-10-18 11:58:42 +08:00
Jinzhu
2a788fb20c Upgrade tests go.mod 2022-10-17 17:01:42 +08:00
Jinzhu
aa4312ee74 Don't display any GORM related package path as source 2022-10-17 17:01:42 +08:00
Jinzhu
08aa2f9888 Update README 2022-10-14 20:30:41 +08:00
Jinzhu
2c56954cb1 tests mariadb with returning support 2022-10-08 20:48:22 +08:00
Jinzhu
e93dc3426e Test postgres autoincrement check 2022-10-08 17:16:32 +08:00
Jinzhu
983e96f142 Add tests for alter column type 2022-10-08 16:04:57 +08:00
Jinzhu
34fbe84580 Add TableName with NamingStrategy support, close #5726 2022-10-07 21:18:37 +08:00
robhafner
e8f48b5c15
fix: limit=0 results (#5735) (#5736) 2022-10-07 20:14:14 +08:00
jesse.tang
4b22a55a75
fix: primaryFields are overwritten (#5721) 2022-10-07 18:29:28 +08:00
Wen Sun
9564b82975
Fix OnConstraint builder (#5738) 2022-10-07 13:46:20 +08:00
Cr
0b7113b618
fix: prepare deadlock (#5568)
* fix: prepare deadlock

* chore[ci skip]: code style

* chore[ci skip]: test remove unnecessary params

* fix: prepare deadlock

* fix: double check prepare

* test: more goroutines

* chore[ci skip]: improve code comments

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-09-30 18:13:36 +08:00
Stephano George
a3cc6c6088 Fix: wrong value when Find with Join with same column name, close #5723, #5711 2022-09-30 17:18:42 +08:00
jesse.tang
be440e7512
fix possible nil panic in tests (#5720)
* fix maybe nil panic

* reset code
2022-09-30 11:14:34 +08:00
dependabot[bot]
e1dd0dcbc4
chore(deps): bump actions/stale from 5 to 6 (#5717)
Bumps [actions/stale](https://github.com/actions/stale) from 5 to 6.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-09-30 11:13:01 +08:00
Nguyen Huu Tuan
328f301982
add some test case which related the logic (#5477) 2022-09-22 18:35:21 +08:00
kinggo
12237454ed fix: use preparestmt in trasaction will use new conn, close #5508 2022-09-22 16:47:43 +08:00
Cr
73bc53f061
feat: migrator support type aliases (#5627)
* feat: migrator support type aliases

* perf: check type
2022-09-22 15:56:32 +08:00
Cr
101a7c789f
fix: scan array (#5624)
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-09-22 15:51:47 +08:00
Jinzhu
3a72ba102e Allow shared foreign key for many2many jointable 2022-09-22 15:03:41 +08:00
jesse.tang
1f634c3937
support scan assign slice cap (#5634)
* support scan assign slice cap

* fix
2022-09-22 14:50:35 +08:00
Cr
5ed7b1a65e
fix: same embedded filed name (#5705) 2022-09-22 11:25:03 +08:00
qqxhb
490625981a
fix: update omit (#5699) 2022-09-16 15:02:44 +08:00
Googol Lee
edb00c10ad
AutoMigrate() should always migrate checks, even there is no relationship constraints. (#5644)
* fix: remove uuid autoincrement

* AutoMigrate() should always migrate checks, even there is no relationship constranits.

Co-authored-by: a631807682 <631807682@qq.com>
2022-09-14 10:26:51 +08:00
Bruce MacKenzie
f29afdd329
Rewrite of finisher_api Godocs (#5618) 2022-09-09 11:16:41 +08:00
Jiepeng Cao
b3eb1c8c51
simplified regexp (#5677) 2022-09-05 15:39:19 +08:00
jesse.tang
f78f635fae
Optimize: code logic db.scanIntoStruct() (#5633) 2022-09-05 15:34:33 +08:00
Cr
d71caef7d9
fix: remove uuid autoincrement (#5620) 2022-09-03 20:00:21 +08:00
Shunsuke Otani
8c3018b96a
Replace ioutil.Discard with io.Discard (#5603) 2022-08-15 10:50:06 +08:00
Shunsuke Otani
3f92b9b0df
Refactor: redundant type from composite literal (#5604) 2022-08-15 10:47:26 +08:00
Aoang
ba227e8939
Add Go 1.19 Support (#5608) 2022-08-15 10:46:57 +08:00
enwawerueli
573b9fa536 fix: correct grammar 2022-08-15 10:28:36 +08:00
Bruce MacKenzie
a35883590b
update Delete Godoc to describe soft delete behaviour (#5554) 2022-08-11 11:38:04 +08:00
Cr
f223279384
chore: fix gorm tag (#5577) 2022-08-10 11:03:42 +08:00
hjwblog.com
6e03b97e26
fix: empty serilizer err #5524 (#5525)
* fix: empty serilizer err #5524

* feat: fix UnixSecondSerializer return nil

* feat: split type case

Co-authored-by: huanjiawei <huanjiawei@bytedance.com>
2022-07-27 13:59:47 +08:00
MJrocker
3c6eb14c92 Fixed some typos in the code comment 2022-07-27 10:22:49 +08:00
Cr
06e174e24d
fix: embedded default value (#5540) 2022-07-25 14:10:30 +08:00
Xudong Zhang
bab3cd1724
fix bad logging performance of bulk create (#5520) (#5521) 2022-07-18 20:47:00 +08:00
Jinzhu
75720099b5 Create a new db in FindInBatches 2022-07-18 18:07:05 +08:00
Goxiaoy
2ba599e8b7
fix empty QueryClauses in association (#5502) (#5503)
* fix empty QueryClauses in association (#5502)

* test: empty QueryClauses in association (#5502)

* style: empty QueryClauses in association (#5502)

* style: empty QueryClauses in association (#5502)
2022-07-15 11:15:18 +08:00
alingse
099813bf11
Adjust ToStringKey use unpack params, fix pass []any as any in variadic function (#5500)
* fix pass []any as any in variadic function

* add .vscode to gitignore
2022-07-14 20:05:22 +08:00
Jinzhu
4d40e34734 Update select tests 2022-07-14 14:55:54 +08:00
Jinzhu
3262daf8d4 Fix select with association column 2022-07-13 18:26:35 +08:00
Jinzhu
cae30e9a50 Fix select with association column 2022-07-13 18:02:11 +08:00
Jinzhu
a7063848ef Fix select with uppercase column name 2022-07-13 17:44:14 +08:00
Jinzhu
08f6d06e47 Fix select with quoted column name 2022-07-13 17:21:19 +08:00
Jinzhu
62fdc2bb3b Fix serializer with empty string 2022-07-11 11:51:05 +08:00
Jinzhu
b13d1757fa Refactor Model with slice data 2022-07-07 15:39:29 +08:00
Jinzhu
9fd73ae4f1 Revert "use callback to handle transaction"
This reverts commit 93f28bc116526ba4decdd969a7b2b0b245ad70f1.
2022-07-07 15:06:48 +08:00
Jinzhu
fe01e1b9f4 Fix Model with slice data 2022-07-07 14:43:33 +08:00
Cr
46bce170ca
test: pg array type (#5480) 2022-07-04 16:42:27 +08:00
Jason Lee
5c4016d9a3
Merge pull request #5455 from longbridgeapp/feat-support-transaction-calllback 2022-07-01 23:34:26 +08:00
Cr
c74bc57add
fix: association many2many duplicate elem (#5473)
* fix: association many2many duplicate elem

* chore: gofumpt style
2022-07-01 15:12:15 +08:00
Joe
2cb4088456 ignore AddError return error 2022-07-01 14:37:38 +08:00
Cr
235c093bb9
fix(MigrateColumn):declared different type without length (#5465) 2022-06-29 10:07:42 +08:00
wws
3e6ab99043
fix:serializer contain field panic (#5461) 2022-06-25 16:32:47 +08:00
Joe
93f28bc116 use callback to handle transaction
- make transaction have before and after hooks, so plugin can have hack before
or after transaction
2022-06-24 10:33:39 +08:00
Jinzhu
a70af2a4c0 Fix Select with digits in column name 2022-06-20 15:35:40 +08:00
qqxhb
1305f637f8
feat: add method GetIndexes (#5436)
* feat: add method GetIndexes

* feat: add default impl for Index interface

* feat: fmt
2022-06-17 11:00:57 +08:00
Cr
8d45714628
fix: reset null value in slice (#5417)
* fix: reset null value in slice

* fix: can not set field in-place in join
2022-06-14 13:48:50 +08:00
Bexanderthebex
d01de7232b
enhancement: Avoid calling reflect.New() when passing in slice of values to Scan() (#5388)
* fix: reduce allocations when slice of values

* chore[test]: Add benchmark for scan

* chore[test]: add bench for scan slice

* chore[test]: add bench for slice pointer and improve tests

* chore[test]: make sure database is empty when doing slice tests

* fix[test]: correct sql delete statement

* enhancement: skip new if rows affected = 0
2022-06-01 11:50:57 +08:00
dependabot[bot]
f4e9904b02
chore(deps): bump gorm.io/driver/mysql from 1.3.3 to 1.3.4 in /tests (#5385)
Bumps [gorm.io/driver/mysql](https://github.com/go-gorm/mysql) from 1.3.3 to 1.3.4.
- [Release notes](https://github.com/go-gorm/mysql/releases)
- [Commits](https://github.com/go-gorm/mysql/compare/v1.3.3...v1.3.4)

---
updated-dependencies:
- dependency-name: gorm.io/driver/mysql
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-06-01 10:26:09 +08:00
Cr
93986de8e4
fix: migrate column default value (#5359)
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-05-28 23:09:13 +08:00
t-inagaki@hum_op
dc1ae394f3
fixed FirstOrCreate not handled error when table is not exists (#5367)
* fixed FirstOrCreate not handled error when table is not exists

* delete useless part
2022-05-28 22:18:43 +08:00
Cr
7e13b03bd4
fix: duplicate column scan (#5369)
* fix: duplicate column scan

* fix: dup filed in inconsistent schema and database

* chore[ci skip]: gofumpt style

* chore[ci skip]: fix typo
2022-05-28 22:18:07 +08:00
Cr
7d1a92d60e
test: test for skip prepared when auto migrate (#5350) 2022-05-22 16:12:28 +08:00
Clark McCauley
540fb49bcb
Fixed #5355 - Named variables don't work when followed by Windows CRLF line endings (#5356)
* Fixed #5355.

* Fixed unit test to test both CRLF and CR line endings
2022-05-22 15:16:01 +08:00
Cr
7496c3a56e
fix: trx in hooks clone stmt (#5338)
* fix: trx in hooks

* chore: format by gofumpt
2022-05-17 14:13:41 +08:00
black-06
f5e77aab2f
fix: quote index when creating table (#5331) 2022-05-17 10:59:53 +08:00
Cr
373bcf7aca
fix: many2many auto migrate (#5322)
* fix: many2many auto migrate

* fix: uuid ossp
2022-05-09 10:07:18 +08:00
Cr
19b8d37ae8
fix: preload with skip hooks (#5310) 2022-05-04 18:57:53 +08:00
Cr
b0104943ed
fix: callbcak sort when using multiple plugin (#5304) 2022-04-30 09:57:16 +08:00
Heliner
d3488ae6bc
fix: add judge result of auto_migrate (#5306)
Co-authored-by: fredhan <fredhan@futunn.com>
2022-04-30 09:50:53 +08:00
Cr
bd7e42ec65
fix: AutoMigrate with special table name (#5301)
* fix: AutoMigrate with special table name

* test: migrate with special table name
2022-04-27 21:13:48 +08:00
Jinzhu
6a6dfdae72 Refactor FirstOrCreate, FirstOrInit 2022-04-26 17:16:48 +08:00
Chiung-Ming Huang
0211ac91a2
index: add composite id (#5269)
* index: add composite id

* index: add test cases of composite id

* index: improve the comments for the test cases of composite id
2022-04-25 11:39:23 +08:00
Cr
a0cc631272
test: test for postgrs serial column (#5234)
* test: test for postgrs sercial column

* test: only for postgres

* chore: spelling mistake

* test: for drop sequence
2022-04-24 12:13:27 +08:00
aelmel
3643f856a3
check for pointer to pointer value (#5278)
* check for pointer to pointer value

* revert to Ptr

Co-authored-by: Alexei Melnic <alexei.melnic@meliora.xyz>
2022-04-24 09:10:36 +08:00
Cr
9b80fe9e96
fix: stmt.Changed zero value filed behavior (#5281)
* fix: stmt.Changed zero value filed behavior

* chore: rename var
2022-04-24 09:08:52 +08:00
glebarez
395606ac7c
fix missing error-check in AutoMigrate (#5283) 2022-04-22 11:19:33 +08:00
Jinzhu
88c26b62ee Support Scopes in group conditions 2022-04-20 17:21:38 +08:00
Cr
b49ae84780
fix: FindInBatches with offset limit (#5255)
* fix: FindInBatches with offset limit

* fix: break first

* fix: FindInBatches Limit zero
2022-04-17 09:58:33 +08:00
ZhangShenao
e0ed3ce400
fix spelling mistake (#5256)
Co-authored-by: Shenao Zhang <shenao.zhang@shopee.com>
2022-04-14 20:32:57 +08:00
Jinzhu
d421c67ef5 Remove ErrRecordNotFound error from log when using Save 2022-04-14 10:51:39 +08:00
dependabot[bot]
ce53ea53ee
chore(deps): bump actions/setup-go from 2 to 3 (#5243)
Bumps [actions/setup-go](https://github.com/actions/setup-go) from 2 to 3.
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](https://github.com/actions/setup-go/compare/v2...v3)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-04-13 15:53:12 +08:00
dependabot[bot]
771cbed755
chore(deps): bump actions/stale from 4 to 5 (#5244)
Bumps [actions/stale](https://github.com/actions/stale) from 4 to 5.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v4...v5)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-04-13 15:52:40 +08:00
Cr
a65912c588
fix: FirstOrCreate RowsAffected (#5250) 2022-04-13 15:52:07 +08:00
Filippo Del Moro
6aa6d37fc4
Fix scanIntoStruct (#5241)
* Reproduces error case

* Fix scanIntoStruct

Co-authored-by: Filippo Del Moro <filippo.delmoro@facile.it>
2022-04-13 15:47:04 +08:00
Jinzhu
74e07b049c Serializer unixtime support ptr of int 2022-04-11 22:07:56 +08:00
Jinzhu
41bef26f13 Remove shared sync pool for Scanner compatibility 2022-04-11 21:37:44 +08:00
Naveen
5c9ef9a843
Set permissions for GitHub actions (#5237)
Restrict the GitHub token permissions only to the required ones; this way, even if the attackers will succeed in compromising your workflow, they won’t be able to do much.

- Included permissions for the action. https://github.com/ossf/scorecard/blob/main/docs/checks.md#token-permissions

https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#permissions

https://docs.github.com/en/actions/using-jobs/assigning-permissions-to-jobs

[Keeping your GitHub Actions and workflows secure Part 1: Preventing pwn requests](https://securitylab.github.com/research/github-actions-preventing-pwn-requests/)

Signed-off-by: naveensrinivasan <172697+naveensrinivasan@users.noreply.github.com>
2022-04-10 09:38:43 +08:00
Jinzhu
0729261b62 Support double ptr for Save 2022-04-08 14:23:25 +08:00
Hasan
81c4024232
Offset issue resolved for scanning results back into struct (#5227) 2022-04-07 23:56:41 +08:00
huangcheng1
38a24606da fix: tables lost when joins exists in from clause, close #5218
commit 7f6a603afa26820e187489b5203f93adc513687c
Author: Jinzhu <wosmvp@gmail.com>
Date:   Sat Apr 2 17:26:48 2022 +0800

    Refactor #5218

commit 95d00e6ff2668233f3eca98aa4917291e3d869bd
Author: huangcheng1 <huangcheng1@sensetime.com>
Date:   Fri Apr 1 16:30:27 2022 +0800

    fix: tables lost when joins exists in from clause
2022-04-02 17:27:53 +08:00
Jinzhu
9144969c83 Allow to use tag to disable auto create/update time 2022-04-02 17:17:47 +08:00
ZhangShenao
f7b52bb649
unify db receiver name (#5215)
Co-authored-by: Shenao Zhang <shenao.zhang@shopee.com>
2022-04-01 08:35:16 +08:00
Goxiaoy
cd0315334b
fix: context missing in association (#5214) 2022-04-01 08:33:39 +08:00
ZhangShenao
8333844f71
fix variable shadowing (#5212)
Co-authored-by: Shenao Zhang <shenao.zhang@shopee.com>
2022-03-31 20:57:20 +08:00
Jinzhu
ea8509b777 Use defer to close rows to avoid scan panic leak rows 2022-03-29 18:48:32 +08:00
Jinzhu
9dd6ed9c65 Scan with Rows interface 2022-03-29 18:14:37 +08:00
dependabot[bot]
6c827ff2e3
chore(deps): bump actions/cache from 2 to 3 (#5196)
Bumps [actions/cache](https://github.com/actions/cache) from 2 to 3.
- [Release notes](https://github.com/actions/cache/releases)
- [Commits](https://github.com/actions/cache/compare/v2...v3)

---
updated-dependencies:
- dependency-name: actions/cache
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-03-28 19:55:05 +08:00
qqxhb
6d40a83432
Update README.md
add gorm gen
2022-03-24 16:30:14 +08:00
Cr
3d7019a7c2
fix: throw err if association model miss primary key (#5187) 2022-03-24 09:34:06 +08:00
Jin
9a4d10be64
style: fix coding typo (#5184) 2022-03-24 09:31:58 +08:00
Jinzhu
f92e6747cb Handle field set value error 2022-03-23 17:24:25 +08:00
Jinzhu
a7b3b5956f Fix hooks order, close https://github.com/go-gorm/gorm.io/pull/519 2022-03-22 22:42:36 +08:00
Jinzhu
d66f37ad32 Add Go 1.18 2022-03-21 10:50:14 +08:00
Jin
2d5cb997ed
style: fix linter check for NamingStrategy and onConflictOption (#5174) 2022-03-20 09:02:45 +08:00
Jinzhu
0097b39a77 Should ignore error when parsing default value for time, close #5176 2022-03-20 08:55:08 +08:00
Jinzhu
540b47571a Fix update select clause with before/after expressions, close #5164 2022-03-18 20:57:33 +08:00
Cr
d402765f69
test: fix utils.AssertEqual (#5172) 2022-03-18 20:11:23 +08:00
ag9920
3c00980e01 fix: serializer use default valueOf in assignInterfacesToValue, close #5168
commit 58e1b2bffbc216f2862d040fb545a8a486e473b6
Author: Jinzhu <wosmvp@gmail.com>
Date:   Fri Mar 18 17:06:43 2022 +0800

    Refactor #5168

commit fb9233011d209174e8223e970f0f732412852908
Author: ag9920 <alexgong7@outlook.com>
Date:   Thu Mar 17 21:23:28 2022 +0800

    fix: serializer use default valueOf in assignInterfacesToValue
2022-03-18 17:12:17 +08:00
Jinzhu
e6f7da0e0d Support Variable Relation 2022-03-18 14:30:30 +08:00
chenrui
5431da8caf fix: preload panic when model and dest different close #5130
commit e8307b5ef5273519a32cd8e4fd29250d1c277f6e
Author: Jinzhu <wosmvp@gmail.com>
Date:   Fri Mar 18 13:37:22 2022 +0800

    Refactor #5130

commit 40cbba49f374c9bae54f80daee16697ae45e905b
Author: chenrui <chenrui@jingdaka.com>
Date:   Sat Mar 5 17:36:56 2022 +0800

    test: fix test fail

commit 66d3f078291102a30532b6a9d97c757228a9b543
Author: chenrui <chenrui@jingdaka.com>
Date:   Sat Mar 5 17:29:09 2022 +0800

    test: drop table and auto migrate

commit 7cbf019a930019476a97ac7ac0f5fc186e8d5b42
Author: chenrui <chenrui@jingdaka.com>
Date:   Sat Mar 5 15:27:45 2022 +0800

    fix: preload panic when model and dest different
2022-03-18 13:38:46 +08:00
chenrui
c2e36ebe62 fix: soft delete for join, close #5132
commit a83023bdfc0dc6eaccc6704b64ff6436c2fe7725
Author: Jinzhu <wosmvp@gmail.com>
Date:   Fri Mar 18 01:05:25 2022 +0800

    Refactor #5132

commit 8559f51102c01be6c19913c0bc3a5771721ff1f5
Author: chenrui <chenrui@jingdaka.com>
Date:   Mon Mar 7 20:33:12 2022 +0800

    fix: should add deleted_at exprs for every joins

commit 2b7a1bdcf3eff9d23253173d21e73c1f056f9be4
Author: chenrui <chenrui@jingdaka.com>
Date:   Mon Mar 7 14:46:48 2022 +0800

    test: move debug flag

commit ce13a2a7bc50d2c23678806acf65dbd589827c77
Author: chenrui <chenrui@jingdaka.com>
Date:   Mon Mar 7 14:39:56 2022 +0800

    fix: soft delete for join.on
2022-03-18 01:09:20 +08:00
chenrui
9b9ae325bb fix: circular reference save, close #5140
commit 2ac099a37ac7bd74f0a98a6fdc42cc8527404144
Author: Jinzhu <wosmvp@gmail.com>
Date:   Thu Mar 17 23:49:21 2022 +0800

    Refactor #5140

commit 6e3ca2d1aa09943dcfb5d9a4b93bea28212f71be
Author: a631807682 <631807682@qq.com>
Date:   Sun Mar 13 12:52:08 2022 +0800

    test: add test for LoadOrStoreVisitMap

commit 9d5c68e41000fd15dea124797dd5f2656bf6b304
Author: chenrui <chenrui@jingdaka.com>
Date:   Thu Mar 10 20:33:47 2022 +0800

    chore: add more comment

commit bfffefb179c883389b72bef8f04469c0a8418043
Author: chenrui <chenrui@jingdaka.com>
Date:   Thu Mar 10 20:28:48 2022 +0800

    fix: should check values has been saved instead of rel.Name

commit e55cdfa4b3fbcf8b80baf009e8ddb2e40d471494
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:48:01 2022 +0800

    chore: go lint

commit fe4715c5bd4ac28950c97dded9848710d8becb88
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:27:24 2022 +0800

    chore: add test comment

commit 326862f3f8980482a09d7d1a7f4d1011bb8a7c59
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:22:33 2022 +0800

    fix: circular reference save
2022-03-17 23:53:31 +08:00
Mikhail Faraponov
2990790fbc
Use WriteByte for single byte operations (#5167)
Co-authored-by: Mikhail Faraponov <mikefaraponov@Mikhails-MacBook-Pro.local>
2022-03-17 22:54:30 +08:00
Hasan
f3e2da5ba3 Added offset when scanning the result back to struct, close #5143
commit 9a2058164d44c98d7b586b87bed1757f89d6fad7
Author: Jinzhu <wosmvp@gmail.com>
Date:   Thu Mar 17 22:34:19 2022 +0800

    Refactor #5143

commit c259de21768936428c9d89f7b31afb95b8acb36a
Author: Hasan <mr.k779@outlook.com>
Date:   Mon Mar 14 20:04:01 2022 +0545

    Update scan_test.go

commit 09f127b49151a52fbb8b354a03e6610d4f70262f
Author: Hasan <mr.k779@outlook.com>
Date:   Mon Mar 14 19:23:47 2022 +0545

    Added test for scanning embedded data into structs

commit aeaca493cf412def7813d36fd6a68acc832bf79f
Author: Hasan <mr.k779@outlook.com>
Date:   Tue Mar 8 04:08:16 2022 +0600

    Added offset when scanning the result back to struct
2022-03-17 22:52:40 +08:00
Jinzhu
63ac66b569 Support default tag for time.Time 2022-03-17 11:34:27 +08:00
Jinzhu
6befa0c947 Refactor preload error check 2022-03-17 11:22:25 +08:00
labulakalia
61b4c31236
fix when index name is "type", parseFieldIndexes will set index TYPE is "TYPE" (#5155)
* fix index name is type, parseFieldIndexes will set index TYPE is "TYPE"

* check TYPE empty
2022-03-14 21:47:59 +08:00
dependabot[bot]
f961bf1c14
chore(deps): bump actions/checkout from 2 to 3 (#5133)
Bumps [actions/checkout](https://github.com/actions/checkout) from 2 to 3.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v2...v3)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-03-12 22:28:18 +08:00
Jason Lee
b566ed7913
Merge pull request #5125 from CaoManhDat/master
ToSQL should enable SkipDefaultTransaction by default
2022-03-04 11:37:37 +08:00
Cao Manh Dat
29a8557384 ToSQL should enable SkipDefaultTransaction by default 2022-03-03 09:18:01 +07:00
Jinzhu
4e523499d1 Refactor Tx interface 2022-03-01 16:59:50 +08:00
lianghuan
996b96e812 Add TxConnPoolBeginner and Tx interface 2022-03-01 16:59:50 +08:00
Jinzhu
e2e802b837 Refactor Scan 2022-02-28 13:00:30 +08:00
Jinzhu
43a72b369e Refactor Scan 2022-02-28 00:16:57 +08:00
Jinzhu
530b0a12b4 Add fast path for ValueOf, ReflectValueOf 2022-02-27 22:16:31 +08:00
Jinzhu
4d14ac39ff Merge branch 'a631807682-i5091' 2022-02-27 09:12:35 +08:00
Jinzhu
68bb5379d9 Refactor scan into struct 2022-02-27 09:09:29 +08:00
Jinzhu
f2edda50e1 Refactor check missing where condition 2022-02-27 08:40:15 +08:00
chenrui
397b583b8e fix: query scanner in single column 2022-02-25 22:38:48 +08:00
Jinzhu
6a18a15c93 Refactor check missing where condition 2022-02-25 10:48:23 +08:00
jing1
3741f258d0
feat: support gob serialize (#5108) 2022-02-24 10:21:27 +08:00
Michael Nussbaum
45ef1da7e4
Fix naming longer then 64 chars with dots in table (#5045)
Ensures that foreign key relationships and indexes are given
syntactically valid names when their name length exceeds 64 characters
and they contained dot characters within the name. This is most often
relevant when a Postgres table name is fully qualified by including its schema
as part of its name
2022-02-24 10:10:20 +08:00
Jinzhu
b1201fce4e Fix update with customized time type, close #5101 2022-02-23 17:48:26 +08:00
Qt
7837fb6fa0
fix typo in TxCommitter interface comment & improve CheckTruth, chek val empty first (#5094)
* fix typo in TxCommitter interface comment

* improve CheckTruth, chek val empty first
2022-02-20 21:19:15 +08:00
codingxh
664c5fb767
strings.replace -> strings.replaceAll (#5095)
Co-authored-by: huquan<xxhh_quan_g@163.com>
2022-02-20 19:55:04 +08:00
Gilad Weiss
f3547e00cc
Inherit clone flag (NewDB) on transaction creation (#5012)
* Inherit clone flag (NewDB) on transaction creation

I find it very reassuring to know that after a finisher API, I get a clean db object for my next queries.
If you look at the example in https://gorm.io/docs i’d see many queries running one after the other.. but in reality they wouldn’t work as the they are portrayed and that’s because in default mode NewDB is false and will make all the clauses stay even after a finisher API.

My solution is just to have the value of the clone flag in the “parent” db object, be injected to its children transactions.

* Fix typo
2022-02-20 08:33:12 +08:00
sammyrnycreal
5edc78116f Fixed the use of "or" to be " OR ", to account for words that contain "or" or "and" (e.g., 'score', 'band') in a sql statement as the name of a field. 2022-02-20 08:22:21 +08:00
Jinzhu
48ced75d1d Improve support for AutoMigrate 2022-02-19 23:42:20 +08:00
Jinzhu
e0b4e0ec8f Update auto stale days 2022-02-19 17:11:23 +08:00
Jinzhu
0af95f509a Enhance migrator Columntype interface (#5088)
* Update Migrator ColumnType interface

* Update MigrateColumn Test

* Upgrade test drivers

* Fix typo
2022-02-19 17:02:53 +08:00
Jinzhu
39d84cba5f Add serializer support (#5078)
* Update context

* Update GormFieldValuer

* Add Serializer

* Add Serializer Interface

* Refactor gorm field

* Refactor setter, valuer

* Add sync.Pool

* Fix test

* Add pool manager

* Fix pool manager

* Add poolInitializer

* Add Serializer Scan support

* Add Serializer Value method

* Add serializer test

* Finish Serializer

* Fix JSONSerializer for postgres

* Fix JSONSerializer for sqlserver

* Test serializer tag

* Add unixtime serializer

* Update go.mod
2022-02-19 17:02:53 +08:00
li-jin-gou
19ac396a22
fix: isPrintable incorrect (#5076)
* fix: isPrintable incorrect

* fix: isPrintable incorrect

* style: use ReplaceAll instead of Replace
2022-02-15 20:32:03 +08:00
Jinzhu
a0aceeb33e Migrator AlterColumn with full data type 2022-02-10 10:40:48 +08:00
Jinzhu
df2365057b Remove uncessary switch case 2022-02-09 17:23:16 +08:00
Jinzhu
4eeb839cea Better support Stringer when explain SQL 2022-02-09 15:17:25 +08:00
li-jin-gou
d22215129e
fix: replace empty table name result in panic (#5048)
* fix: replace empty name result in panic

* fix: replace empty table name result in panic
2022-02-08 17:06:10 +08:00
Jinzhu
416c4d0653 Test query with Or and soft delete 2022-02-08 16:31:24 +08:00
Jason Lee
93b1a6f7ea
Merge pull request #5043 from Saurabh-Thakre/patch-2 2022-02-04 22:31:21 +08:00
Saurabh Thakre
581a879bf1
Added comments to existing methods
Added two comments to describe FirstOrInit and FirstOrCreate methods.
2022-01-31 17:26:28 +05:30
Jinzhu
f19b84d104 Fix github action 2022-01-30 22:46:41 +08:00
Jinzhu
8d293d44dd Fix docker-compose test env for Mac M1 2022-01-30 22:05:38 +08:00
Ning
8c3673286d
preoload not allowd before count (#5023)
Co-authored-by: ningfei <accelerator314@outlook.com>
2022-01-30 18:17:06 +08:00
li-jin-gou
c0bea447b9
fix: omit not work when use join (#5034) 2022-01-28 22:16:42 +08:00
Jinzhu
98c4b78e4d Add Session Initialized option 2022-01-28 19:26:10 +08:00
Jinzhu
cec0d32aec Support use clause.Expression as argument 2022-01-28 18:48:32 +08:00
dependabot[bot]
e5894ca449
chore(deps): bump gorm.io/driver/mysql from 1.2.1 to 1.2.3 in /tests (#4987)
Bumps [gorm.io/driver/mysql](https://github.com/go-gorm/mysql) from 1.2.1 to 1.2.3.
- [Release notes](https://github.com/go-gorm/mysql/releases)
- [Commits](https://github.com/go-gorm/mysql/compare/v1.2.1...v1.2.3)

---
updated-dependencies:
- dependency-name: gorm.io/driver/mysql
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-01-12 13:11:57 +08:00
piyongcai
a0d6ff1fea
time.Time, []byte type add alias support. (rebase master) (#4992)
* time.Time, []byte type add alias support

* reformat
2022-01-12 13:11:40 +08:00
Jinzhu
eae73624ad Fix return failed to begin transaction error when failed to start a transaction 2022-01-07 10:04:35 +08:00
kinggo
0df42e9afc
feat: add Connection to execute multiple commands in a single connection; (#4982) 2022-01-07 09:49:56 +08:00
halfcrazy
f757b8fdc9
fix: auto migration column order unpredictable (#4980) 2022-01-06 18:55:20 +08:00
kinggo
b47cf57f5e
ci: add gofumpt check in reviewdog (#4973) 2022-01-06 15:02:53 +08:00
kinggo
4dd2647967
Merge pull request #4964 from liweitingwt/f_test_error
improve the error handle in tests_test
2021-12-31 14:25:04 +08:00
kinggo
8dde09e0be
fix: generate sql incorrect when use soft_delete and only one OR (#4969)
* fix: generate sql incorrect when use soft_delete and only one OR
2021-12-30 11:47:14 +08:00
liweiting.wt
b9667cb747 fix: fix the error handle in tests_test 2021-12-28 18:22:17 +08:00
Emre Güllü
2c3fc2db28
Fix: Where clauses with named arguments may cause generation of unintended queries (#4937) 2021-12-21 19:50:00 +08:00
liweitingwt
24026bf1fe
modify unscoped judge (#4929)
* modify unscoped judge

* modify unscoped judge

Co-authored-by: liweiting <liweiting1995@gmail.com>
2021-12-16 10:41:34 +08:00
Jinzhu
adf8f70f06 Upgrade go.mod 2021-12-10 17:50:19 +08:00
piyongcai
380cc64ff5
fix type alias AutoMigrate bug(Add Test Case) (#4888)
* fix type alias AutoMigrate bug. eg

```go
package main

type IDer interface{ GetID() int64 }

// ID will add some method to implement some interface eg: GetID
type ID int64
func (z ID) GetID() int64 { return int64(z) }

type Test struct {
	ID
	Code string `gorm:"size:50"`
	Name string `gorm:"size:50"`
}

func main() {
	db, err := gorm.Open(postgres.New(postgres.Config{
		DSN: `dsn`,
		PreferSimpleProtocol: false,
	}), &gorm.Config{
		Logger:                 logger.Default.LogMode(logger.Info),
		SkipDefaultTransaction: true,
	})
	if err != nil {
		log.Fatal(err)
	}

	if err = db.AutoMigrate(&Test{}); err != nil {
		// invalid embedded struct for Test's field ID, should be struct, but got main.ID
		log.Fatal(err)
	}
}
```

* fix type alias AutoMigrate bug. eg

```go
package main

type IDer interface{ GetID() int64 }

// ID will add some method to implement some interface eg: GetID
type ID int64
func (z ID) GetID() int64 { return int64(z) }

type Test struct {
	ID
	Code string `gorm:"size:50"`
	Name string `gorm:"size:50"`
}

func main() {
	db, err := gorm.Open(postgres.New(postgres.Config{
		DSN:                  `dsn`,
		PreferSimpleProtocol: false,
	}), &gorm.Config{
		Logger:                 logger.Default.LogMode(logger.Info),
		SkipDefaultTransaction: true,
	})
	if err != nil {
		log.Fatal(err)
	}

	if err = db.AutoMigrate(&Test{}); err != nil {
		// invalid embedded struct for Test's field ID, should be struct, but got main.ID
		log.Fatal(err)
	}
}
```

* Add typealis test.

* try to fix golangci-lint
2021-12-10 17:45:36 +08:00
Matthieu MOREL
2a578d767f
Use Golangci configuration file (#4896) 2021-12-10 17:44:11 +08:00
kinggo
e5bdd610c3
fix: save not use soft_delete (#4897)
* fix: Save not use soft_delete

* fix: save not use soft_delete

* fix: save not use soft_delete

* fix: save not use soft_delete

Co-authored-by: kinggo <>
2021-12-08 13:58:06 +08:00
Jinzhu
300a23fc31 Check rows.Close error, close #4891 2021-12-02 10:39:24 +08:00
Jinzhu
8627634959 Fix create associations with zero primary key, close #4890 2021-12-02 10:20:16 +08:00
Jinzhu
3a3b82263a Fix auto migration always alert table, close #4198 2021-11-29 20:24:16 +08:00
kinggo
d8a710cba2
fix: count() when use group by and only find one record (#4885)
Co-authored-by: 李龙 <lilong.21@bytedance.com>
2021-11-29 20:14:23 +08:00
Jinzhu
27e2753c9d Fix create duplicated value when updating nested has many relationship, close #4796 2021-11-29 18:43:39 +08:00
Jinzhu
45e804dd3f Fix call valuer interface when using nil value 2021-11-29 16:19:11 +08:00
Jinzhu
92d5a959a0 Fix tests 2021-11-29 15:16:57 +08:00
Jinzhu
270e38c518 Fix duplicated error when Scan, close #4525 2021-11-29 14:23:10 +08:00
Jinzhu
e1b4c066a8 Fix FullSaveAssociations, close #4874 2021-11-29 11:02:44 +08:00
heige
9d5f315b6d
feat: go code style adjust and optimize code for callbacks package (#4861)
* feat: go code style adjust and optimize code for callbacks package

* Update scan.go
2021-11-29 09:33:20 +08:00
Jinzhu
b8f33a42a4
Add unused argument (#4871)
* Append unused argument to gorm statement
2021-11-23 17:11:52 +08:00
dependabot[bot]
cff7845e58
Bump gorm.io/driver/mysql from 1.1.3 to 1.2.0 in /tests (#4856)
Bumps [gorm.io/driver/mysql](https://github.com/go-gorm/mysql) from 1.1.3 to 1.2.0.
- [Release notes](https://github.com/go-gorm/mysql/releases)
- [Commits](https://github.com/go-gorm/mysql/compare/v1.1.3...v1.2.0)

---
updated-dependencies:
- dependency-name: gorm.io/driver/mysql
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-11-23 11:40:18 +08:00
dependabot[bot]
0f8e861597
Bump github.com/jinzhu/now from 1.1.2 to 1.1.3 in /tests (#4866)
Bumps [github.com/jinzhu/now](https://github.com/jinzhu/now) from 1.1.2 to 1.1.3.
- [Release notes](https://github.com/jinzhu/now/releases)
- [Commits](https://github.com/jinzhu/now/compare/v1.1.2...v1.1.3)

---
updated-dependencies:
- dependency-name: github.com/jinzhu/now
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-11-23 11:40:03 +08:00
dependabot[bot]
11d5c346ae
Bump github.com/jinzhu/now from 1.1.2 to 1.1.3 (#4865)
Bumps [github.com/jinzhu/now](https://github.com/jinzhu/now) from 1.1.2 to 1.1.3.
- [Release notes](https://github.com/jinzhu/now/releases)
- [Commits](https://github.com/jinzhu/now/compare/v1.1.2...v1.1.3)

---
updated-dependencies:
- dependency-name: github.com/jinzhu/now
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-11-23 11:39:42 +08:00
dino.ma
5e64ac7de9
feat(migrator,migrator/migrator.go,tests/migrate_test.go) : Get multiple data tables for migrator. (#4841)
* feat(migrator,migrator/migrator.go,tests/migrate_test.go) : Get multiple data tables for migrator.

* feat(migrator.go and migrator/migrator.go) : remove Table Struct replace with []string

* fix(migrator)  : Return all data tables

* Update migrator.go

* fix(migrator/migrator.go):remove var sql

* feat(migrate_test.go/go.mod):update sqlserver,sqlite,postgres,pq version and add getTables test

* fix(migrate_test.go):change GetTables Method Test,use intersection

Co-authored-by: dino.ma <mashengjie03@baidu.com>
2021-11-13 14:03:33 +08:00
riverchu
33bc56cbb5 feat(update): update when has SET clause 2021-11-09 19:55:47 +08:00
Jinzhu
5daa413f41 Stabilize schema.FieldsWithDefaultDBValue's order, close #4643 2021-11-08 20:20:55 +08:00
Jinzhu
ca7accdbf6 Fix preload all associations with inline conditions, close #4836 2021-11-08 19:47:10 +08:00
Jinzhu
b23c3b290e Don't query with primary key when using Save 2021-11-08 18:49:59 +08:00
Mayank Govilla
d9d5c4dce0
Fix self-referential belongs to constraint (#4801)
* create tests for self-ref has one migration

* add relation equality check to avoid skipping self-referential schemas

* remove drop table error check
2021-11-08 09:47:29 +08:00
heige
4c8810a848
Refactor if logic (#4683)
* adjust code for preload

* adjust code for Create
2021-11-04 13:45:44 +08:00
kinggo
c170af11e9
fix connections leak (#4826)
* fix connections leak

* fix connections leak

* fix connections leak

* fix connections leak

Co-authored-by: 李龙 <lilong.21@bytedance.com>
2021-11-03 13:39:52 +08:00
dependabot[bot]
7b927900e9
Bump gorm.io/driver/sqlserver from 1.1.2 to 1.2.0 in /tests (#4820)
Bumps [gorm.io/driver/sqlserver](https://github.com/go-gorm/sqlserver) from 1.1.2 to 1.2.0.
- [Release notes](https://github.com/go-gorm/sqlserver/releases)
- [Commits](https://github.com/go-gorm/sqlserver/compare/v1.1.2...v1.2.0)

---
updated-dependencies:
- dependency-name: gorm.io/driver/sqlserver
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-11-01 17:09:08 +08:00
Jason Lee
8de266b4a7
Add ToSQL support to generate SQL string. (#4787)
* Add db.ToSQL method for generate SQL string.

* Improve sql builder test for all dialects.

Improve assertEqualSQL test helper for ignore quotes in SQL.
2021-11-01 17:08:54 +08:00
Jinzhu
9635d25150 Fix query with uninitialized map 2021-11-01 13:00:52 +08:00
Jinzhu
9f533950a2 Add dest value if current size equal zero 2021-10-28 17:12:31 +08:00
Jinzhu
e953880d19 Add returning tests 2021-10-28 09:17:33 +08:00
Jinzhu
835d7bde59 Add returning support to delete 2021-10-28 07:56:55 +08:00
Jinzhu
af3fbdc2fc Improve returning support 2021-10-26 22:40:14 +08:00
Jason Lee
d3211908a0
Refactor ParseWithSchemaTable method and improve test. (#4789)
* Refactor ParseWithSchemaTable method and improve test.

* Fix schema.ParseWithSchemaTable method for only use schemaTable in migrator and improve test.

* Rename `schemaTable` to `specialTableName` for clearly argument.
2021-10-25 11:26:44 +08:00
Jason Lee
38e55f1117
Merge pull request #4773 from xwjdsh/master
fix: automigrate error caused by indexes while using dynamic table name
2021-10-19 10:11:11 +08:00
Wendell Sun
a3bd9c3ea2 fix: automigrate error caused by indexes while using dynamic table name 2021-10-19 09:59:57 +08:00
Jinzhu
9a5ba37604 Merge branch 'hashicorp-jimlambrt-null-without-ptrs' 2021-10-13 21:02:03 +08:00
Jinzhu
b27095e8a1 Refactor Convert SQL null values to zero values for model fields which are not pointers #4710 2021-10-13 21:01:36 +08:00
Jim
19cf645dbd feat: Convert SQL nulls to zero values (ConvertNullToZeroValues)
Makes it the default behavior to convert SQL null values to zero
values for model fields which are not pointers.
2021-10-13 08:11:22 -04:00
kinggo
696092e287
update tests' go.mod and tests_all.sh (#4774) 2021-10-13 14:41:33 +08:00
kinggo
ec58e3319f
fixed:panic when create value from nil struct pointer. (#4771)
* fixed:create nil pointer

* fixed:panic when create value from nil struct pointer.
2021-10-12 21:19:08 +08:00
kinggo
418c60c83c
fixed: clauseSelect.Columns missed when use Join And execute multiple query. (#4757) 2021-10-09 16:55:45 +08:00
Jinzhu
bfda75d099 Support specify select/omit columns with table 2021-10-09 10:42:41 +08:00
Jinzhu
6312d86c54 Support specify select/omit columns with table 2021-10-08 17:51:27 +08:00
Jinzhu
d4c838c1ce Upgrade sqlite driver 2021-10-08 17:31:58 +08:00
kinggo
b46e2afc4a
fix : update miss where's condition when primary key use "<-:create" tag (#4738)
* fix:update miss where condition

* fix:rename test case
2021-10-08 13:47:01 +08:00
heige
e3fc49a694
feat: ajust PreparedStmtDB unlock location and BuildCondition if logic (#4681) 2021-10-08 11:16:58 +08:00
heige
c13f3011f9
feat: adjust SetupJoinTable func if..else code (#4680) 2021-10-08 11:05:50 +08:00
Paras Waykole
5d91ddac8c
fixed belongs_to & has_one reversed if field same (proper fix) (#4694)
* fixed belongs_to & has_one reversed if field same

* hasmany same foreign key bug fixed and test added

* belongsToSameForeignKey fixed and reverted old fix
2021-10-08 10:59:55 +08:00
dependabot[bot]
57d927d046
Bump gorm.io/driver/postgres from 1.1.1 to 1.1.2 in /tests (#4740)
Bumps [gorm.io/driver/postgres](https://github.com/go-gorm/postgres) from 1.1.1 to 1.1.2.
- [Release notes](https://github.com/go-gorm/postgres/releases)
- [Commits](https://github.com/go-gorm/postgres/compare/v1.1.1...v1.1.2)

---
updated-dependencies:
- dependency-name: gorm.io/driver/postgres
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-10-08 10:54:50 +08:00
s-takehana
0b6bd33934
Update tests.yml (#4741) 2021-10-08 10:51:53 +08:00
River
851fea0221
fix: QuoteTo not fully support raw mode (#4735)
* fix: QuoteTo not fully support raw mode

* fix: table alias without AS

* test: clause.Column/Table quote test

* fix: revert table alias quote
2021-09-29 14:02:35 +08:00
Jinzhu
c4a2e891da Fix Join condition with DB 2021-09-28 22:37:15 +08:00
Jinzhu
002bf78ea7 Fix Join condition with DB, close #4719 2021-09-28 21:43:31 +08:00
kinggo
6864a24150
fix:remove the tableName judgment in pluck (#4731) 2021-09-27 22:11:29 +08:00
Jim
5202529ea1
fix (clause/expression): Allow sql stmt terminator (#4693)
Allow the sql stmt terminator ";" at the end of a named parameter.

Example: select * from table_name where name == @name;
2021-09-20 21:40:48 +08:00
dependabot[bot]
199c8529b6
Bump gorm.io/driver/postgres from 1.1.0 to 1.1.1 in /tests (#4699)
Bumps [gorm.io/driver/postgres](https://github.com/go-gorm/postgres) from 1.1.0 to 1.1.1.
- [Release notes](https://github.com/go-gorm/postgres/releases)
- [Commits](https://github.com/go-gorm/postgres/compare/v1.1.0...v1.1.1)

---
updated-dependencies:
- dependency-name: gorm.io/driver/postgres
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-09-20 21:33:38 +08:00
dependabot[bot]
d67120a155
Bump gorm.io/driver/sqlite from 1.1.4 to 1.1.5 in /tests (#4701)
Bumps [gorm.io/driver/sqlite](https://github.com/go-gorm/sqlite) from 1.1.4 to 1.1.5.
- [Release notes](https://github.com/go-gorm/sqlite/releases)
- [Commits](https://github.com/go-gorm/sqlite/compare/v1.1.4...v1.1.5)

---
updated-dependencies:
- dependency-name: gorm.io/driver/sqlite
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-09-20 21:25:29 +08:00
Jinzhu
ab355336cb Fix scan with interface 2021-09-17 18:35:14 +08:00
Jinzhu
da16a8aac6 Update updated_at when upserting with Create OnConflict 2021-09-17 15:29:49 +08:00
Jinzhu
12bbde89e6 Fix Scan with interface 2021-09-17 14:04:19 +08:00
Jinzhu
61b018cb94 Fix count with selected * 2021-09-16 11:17:54 +08:00
Jinzhu
d41fb3acdc Refactor dummy driver QuoteTo method 2021-09-11 16:22:35 +08:00
Jinzhu
04f049c1da Add tests for RowsAffected 2021-09-09 11:22:55 +08:00
Jinzhu
a16db07945 Refactor Join ON 2021-09-07 21:21:44 +08:00
Jinzhu
ba16b2368f
Refactor update record (#4679) 2021-09-07 20:04:54 +08:00
Jinzhu
6c94b07e98 try to fix fatal error: concurrent map read and map write 2021-09-07 15:30:14 +08:00
Jinzhu
3b6a7c8aec Update sqlserver driver 2021-09-07 12:01:19 +08:00
Adrien Carreira
d047f854e6 PR Comments 2021-09-06 20:13:20 +08:00
Adrien Carreira
c301aeb524 Refactor for readability 2021-09-06 20:13:20 +08:00
Adrien Carreira
52cc438d07 JoinsOn unit test + use all primary keys 2021-09-06 20:13:20 +08:00
Adrien Carreira
895c1178a0 Proposal, Add Specific on for Joins queries 2021-09-06 20:13:20 +08:00
riverchu
eaa63d15e7 feat: copy dest fields to model struct 2021-09-06 20:13:20 +08:00
riverchu
4581e8b590 test: update Save test 2021-09-06 20:13:20 +08:00
riverchu
c898622791 test: add testcase in TestSave 2021-09-06 20:13:20 +08:00
riverchu
1d9e563023 style: prepose error judgement 2021-09-06 20:13:20 +08:00
dependabot[bot]
a89d4d8fd5
Bump github.com/lib/pq from 1.10.2 to 1.10.3 in /tests (#4676)
Bumps [github.com/lib/pq](https://github.com/lib/pq) from 1.10.2 to 1.10.3.
- [Release notes](https://github.com/lib/pq/releases)
- [Commits](https://github.com/lib/pq/compare/v1.10.2...v1.10.3)

---
updated-dependencies:
- dependency-name: github.com/lib/pq
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-09-06 16:26:14 +08:00
dependabot[bot]
5f019f74bf
Bump gorm.io/gorm from 1.21.13 to 1.21.14 in /tests (#4655)
Bumps [gorm.io/gorm](https://github.com/go-gorm/gorm) from 1.21.13 to 1.21.14.
- [Release notes](https://github.com/go-gorm/gorm/releases)
- [Commits](https://github.com/go-gorm/gorm/compare/v1.21.13...v1.21.14)

---
updated-dependencies:
- dependency-name: gorm.io/gorm
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-09-03 17:47:50 +08:00
jxlwqq
15188cf409
Add Go 1.17 (#4666) 2021-09-03 17:47:32 +08:00
Jinzhu
3a8c250180 Refactor calc associations onConflictOption 2021-08-26 13:37:49 +08:00
zkqiang
74746211b8 Test update association with non-updatable 2021-08-26 13:37:49 +08:00
zkqiang
e81833fd11 Fix onConflict with non-updatable in associations 2021-08-26 13:37:49 +08:00
Jinzhu
f21e35f7c5 Fix table not supported error when using unexpected table name 2021-08-26 13:14:16 +08:00
dependabot[bot]
0934b10856
Bump gorm.io/driver/sqlserver from 1.0.7 to 1.0.8 in /tests (#4631)
Bumps [gorm.io/driver/sqlserver](https://github.com/go-gorm/sqlserver) from 1.0.7 to 1.0.8.
- [Release notes](https://github.com/go-gorm/sqlserver/releases)
- [Commits](https://github.com/go-gorm/sqlserver/compare/v1.0.7...v1.0.8)

---
updated-dependencies:
- dependency-name: gorm.io/driver/sqlserver
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-08-23 15:30:02 +08:00
Sec Cake
093694fbf2
Fix extra 'AND' when len(values) == 0 ON IN.NegationBuild() (#4618) 2021-08-20 18:06:48 +08:00
dependabot[bot]
7a53d8e46b
Bump gorm.io/driver/mysql from 1.1.1 to 1.1.2 in /tests (#4615)
Bumps [gorm.io/driver/mysql](https://github.com/go-gorm/mysql) from 1.1.1 to 1.1.2.
- [Release notes](https://github.com/go-gorm/mysql/releases)
- [Commits](https://github.com/go-gorm/mysql/compare/v1.1.1...v1.1.2)

---
updated-dependencies:
- dependency-name: gorm.io/driver/mysql
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-08-20 17:52:56 +08:00
dependabot[bot]
e076e9e0fb
Bump gorm.io/gorm from 1.21.12 to 1.21.13 in /tests (#4616)
Bumps [gorm.io/gorm](https://github.com/go-gorm/gorm) from 1.21.12 to 1.21.13.
- [Release notes](https://github.com/go-gorm/gorm/releases)
- [Commits](https://github.com/go-gorm/gorm/compare/v1.21.12...v1.21.13)

---
updated-dependencies:
- dependency-name: gorm.io/gorm
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-08-20 17:52:48 +08:00
River
1bb0d8732d
feat: count accpet db.table (#4626)
* feat: count accpet `db`.`table`

* fix: logic fix
2021-08-20 17:37:21 +08:00
River
25f561a742
feat: QuoteTo accept clause.Expr (#4621)
* feat: QuoteTo accept clause.Expr

* test: update Expr build test
2021-08-19 14:33:18 +08:00
Jinzhu
2b2f6e77af Add SchemaName to NamingStrategy 2021-08-11 16:20:29 +08:00
Sungyun Hur
a83d25e25e
chore(logger): explicitly set config of Default Logger (#4605) 2021-08-11 11:49:46 +08:00
张凯强
21e85b89d6
Fix create with ignore migration (#4571) 2021-08-09 13:23:44 +08:00
SmallTianTian
82fe815303
fix: table couln't be reentrant (#4556) 2021-08-09 13:20:22 +08:00
Matthieu MOREL
cbe72751ac
Update Dependencies (#4582)
* Create dependabot.yml

* Bump reviewdog/action-golangci-lint from 1 to 2 (#1)

Bumps [reviewdog/action-golangci-lint](https://github.com/reviewdog/action-golangci-lint) from 1 to 2.
- [Release notes](https://github.com/reviewdog/action-golangci-lint/releases)
- [Commits](https://github.com/reviewdog/action-golangci-lint/compare/v1...v2)

---
updated-dependencies:
- dependency-name: reviewdog/action-golangci-lint
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump actions/stale from 3.0.7 to 4 (#2)

Bumps [actions/stale](https://github.com/actions/stale) from 3.0.7 to 4.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v3.0.7...v4)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump gorm.io/gorm from 1.21.9 to 1.21.12 in /tests (#3)

Bumps [gorm.io/gorm](https://github.com/go-gorm/gorm) from 1.21.9 to 1.21.12.
- [Release notes](https://github.com/go-gorm/gorm/releases)
- [Commits](https://github.com/go-gorm/gorm/compare/v1.21.9...v1.21.12)

---
updated-dependencies:
- dependency-name: gorm.io/gorm
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump gorm.io/driver/mysql from 1.0.5 to 1.1.1 in /tests (#4)

Bumps [gorm.io/driver/mysql](https://github.com/go-gorm/mysql) from 1.0.5 to 1.1.1.
- [Release notes](https://github.com/go-gorm/mysql/releases)
- [Commits](https://github.com/go-gorm/mysql/compare/v1.0.5...v1.1.1)

---
updated-dependencies:
- dependency-name: gorm.io/driver/mysql
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump github.com/lib/pq from 1.6.0 to 1.10.2 in /tests (#5)

Bumps [github.com/lib/pq](https://github.com/lib/pq) from 1.6.0 to 1.10.2.
- [Release notes](https://github.com/lib/pq/releases)
- [Commits](https://github.com/lib/pq/compare/v1.6.0...v1.10.2)

---
updated-dependencies:
- dependency-name: github.com/lib/pq
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump github.com/google/uuid from 1.2.0 to 1.3.0 in /tests (#6)

Bumps [github.com/google/uuid](https://github.com/google/uuid) from 1.2.0 to 1.3.0.
- [Release notes](https://github.com/google/uuid/releases)
- [Commits](https://github.com/google/uuid/compare/v1.2.0...v1.3.0)

---
updated-dependencies:
- dependency-name: github.com/google/uuid
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-08-09 13:16:25 +08:00
Walter Scheper
a870486c4f
Do not emit ORDER BY for empty values (#4592)
This restores the behavior from gorm v1, where calling `DB.Order` with
an empty string, nil, or any unexpected type is a no-op.
2021-08-09 13:14:23 +08:00
heige
9e5a4e30b4
Fix migrator GuessConstraintAndTable method for return value for *schema.Check (#4527) 2021-08-03 11:40:57 +08:00
heige
413fe587c6
Optimize migrator.go MigrateColumn and ColumnTypes func. (#4532) 2021-08-02 18:44:10 +08:00
daheige
7a49629fd1 optimize Parse func for fieldValue.Interface 2021-07-28 19:00:34 +08:00
daheige
41ac73b6a1 update comment for ConvertSliceOfMapToValuesForCreate func 2021-07-28 18:56:39 +08:00
Jason Lee
3e15478534
Merge pull request #4530 from daheige/optimize-setupValuerAndSetter
optimize setupValuerAndSetter func
2021-07-28 18:51:09 +08:00
heige
5115813c50
Fix preload fmt.Errorf formatter (#4531) 2021-07-28 18:50:08 +08:00
s-takehana
2202e99cbf
Fix create index with comments in MySQL (#4521)
* Fix create index with comments in MySQL

* Fix tests
2021-07-18 11:47:44 +08:00
daheige
a70254609d optimize setupValuerAndSetter func 2021-07-14 22:03:17 +08:00
Jinzhu
74752018dc Fix hang when closing a prepared statement 2021-07-14 18:35:10 +08:00
River
ac97aec513
New Comma Expression (#4524)
* Add new comma expression

* Add comma expression unit test
2021-07-14 15:51:24 +08:00
Jinzhu
d4f3c109d6 Fix OnConflict with one column, close #4370 2021-07-13 21:29:31 +08:00
Jinzhu
83530ec659 Fix delete order by clause when counting, close #4478 2021-07-13 21:17:43 +08:00
Jinzhu
52b72d7ef2 Add error explanations when preloading assocations w/o foreign fields, close #4356 2021-07-13 21:00:13 +08:00
Jinzhu
b13732c450 Fix invalid preload SQL when no data found, close #4443 2021-07-13 20:23:05 +08:00
Jinzhu
c73fe96cfd Fix scan into decimal.Decimal, close #4457 2021-07-13 19:59:31 +08:00
Jinzhu
b616d810eb Fix scan single value to custom type, close #4501 2021-07-13 19:29:10 +08:00
Jinzhu
76cd73cb82 Fix wipes out MySQL global variables from the query, close #4515 2021-07-13 18:48:43 +08:00
Jinzhu
2ec7043818 Respect update permission for OnConflict Create 2021-07-13 18:04:42 +08:00
Burak Demirpolat
0329b800b0
slightly better callback warning (#4495) 2021-07-13 16:38:44 +08:00
wangyuehong
80497f27a6
title foreign schema for many2many to avoid panic (#4496)
Co-authored-by: yuehong.wang <yuehong.wang@dena.jp>
2021-07-13 16:36:22 +08:00
shiyu7
16579e00c6
fix: fix race issue in prepare method (#4487) 2021-07-01 06:27:12 +08:00
Jason Lee
6d64e31965
Merge pull request #4479 from wuwenchi/master
Fix document for Pluck's usage #4473
2021-06-27 11:29:20 +08:00
wuwenchi
8bd8d38fe9 Fix Pluck's usage #4473 2021-06-26 21:23:16 +08:00
Jinzhu
8e67a08774 Fix Scopes with Row, close #4465 2021-06-18 15:38:20 +08:00
Jinzhu
3226937f68 Fix calc gormSourceDir, close #4456 2021-06-13 10:32:03 +08:00
kalle (jag)
25b9f2e26a
Added return names to logger.Interface.Trace (#4450) 2021-06-11 21:51:40 +08:00
Tony
a0bddccfe1
Use count(*) instead of count(1) include NULL and non-NULL rows(SQL-92). (#4453) 2021-06-11 21:51:18 +08:00
Jinzhu
5b65b02805 Update tests go.mod 2021-06-11 16:00:26 +08:00
Jinzhu
e425ed6f6a Update tests go.mod 2021-06-10 20:26:21 +08:00
heige
50e85e14d4
Code optimize (#4415)
* optimize gormSourceDir replace

* fmt.Errorf adjust and Optimize for-break

* strings trim

* feat: avoid using the same name field and if..else optimization adjustment

* optimization callbacks/create.go Create func if...else logic

* fix: callbacks/create.go Create func

* fix FileWithLineNum func and add gormSourceDir unit test

* remove debug print and utils_filenum_test.go
2021-06-10 10:21:28 +08:00
liamrfell
00b252559f
Fix: FirstOrCreate slice out of bounds error when using 'Assigns' (#4436)
Co-authored-by: Liam Fell <liam@lot.to>
2021-06-07 10:39:24 +08:00
Vitaliy Shein
dd8bf88eb9
add Target where clause for on conflict (#4442)
Co-authored-by: Vitaliy Shein <vitaliy.shein@thebricks.com>
2021-06-07 10:39:00 +08:00
s-takehana
cf079b8b7d
Update version in tests.yml (#4432) 2021-06-02 09:58:22 +08:00
Jinzhu
810058cd55 Fix soft delete with Update 2021-06-01 18:34:38 +08:00
Jinzhu
9abac96546 Fix Eq, Neq support slice of data 2021-05-31 17:21:27 +08:00
Jinzhu
14e96080d8 Eq, Neq support slice of data 2021-05-31 15:25:38 +08:00
heyanfu
363f9b7863
golint standard (#4421) 2021-05-31 10:08:06 +08:00
Ikko Ashimine
bcf2b385a4
Fix typo in associations_test.go (#4407)
occured -> occurred
2021-05-27 17:40:28 +08:00
Brenda Wallace
ac722c16f9
Small grammar fix in error message (#4406) 2021-05-24 10:23:34 +08:00
Jinzhu
ea1bce3771 Only check struct value can address or not 2021-05-23 11:21:56 +08:00
Paras Waykole
79f427d862
fixed has_many stopped working if field names are identical (#4387)
* fixed belongs_to & has_one reversed if field same

* hasmany same foreign key bug fixed and test added
2021-05-19 16:05:29 +08:00
Atreya
cf93b16730
Fix ErrInvalidTransaction error message (#4380) 2021-05-17 15:53:48 +08:00
Jinzhu
92c3ba9dcc Fix create new db sessions in scopes 2021-05-17 15:36:07 +08:00
Chen Quan
a480bd8545
Update Optimize schema (#4364) 2021-05-10 09:51:50 +08:00
Jinzhu
6b7abc54a2 Fix tests 2021-05-06 13:06:31 +08:00
Jinzhu
2aca96d147 test ignore migration, close #4314, #4315 2021-05-05 08:26:48 +08:00
我的我的
3f359eab9b
slim trace if depth (#4346)
Co-authored-by: gogs <guzzsek@gmail.com>
2021-05-05 08:14:40 +08:00
Paras Waykole
8f7f3ad315
fixed belongs_to & has_one reversed if field same (#4343) 2021-05-05 07:57:54 +08:00
Jinzhu
70e93e73d8 Check data type if copyable before change reference field's type 2021-04-30 16:35:55 +08:00
Karolos Lykos
f0d0bbbc10
Added missing white space (#4330)
* Added missing white space

* Added missing white space

* Added missing white space
2021-04-29 07:15:37 +08:00
Jinzhu
6951be0284 Allow customize clauses 2021-04-28 17:19:30 +08:00
Jinzhu
82cb4ebfe2 Fix overwrite Statement in scopes 2021-04-22 13:12:20 +08:00
Sky34gl3
a855fe6402
Fixed naming longer than 64 characters (#4310)
Co-authored-by: Mickael MAUGER <mickael.mauger@almerys.com>
2021-04-22 13:11:19 +08:00
Jinzhu
d327926425 Check ReflectValue.CanAddr before set field value 2021-04-19 21:37:38 +08:00
Chris Faulkner
15a46bc042
Fix some typos (#4294) 2021-04-19 21:03:39 +08:00
Jinzhu
7701c88507 Assign transaction error to db 2021-04-16 19:27:23 +08:00
Jinzhu
d483ffa45c Fix Preload with nil pointer 2021-04-15 10:37:05 +08:00
heige
74e7a9ca07
Optimize reflect value length and method (#4280)
* Respect ignore migration when add column (#4276)

continue https://github.com/go-gorm/gorm/pull/4028

* feat: Optimal value type acquisition for v (#4278)

* feat: optimize relect value length and value

* feat: optimize ConvertSliceOfMapToValuesForCreate method

Co-authored-by: yrong1997 <yrong1997@gmail.com>
2021-04-14 13:00:54 +08:00
heige
5555b010dc
feat: Optimal value type acquisition for v (#4278) 2021-04-13 09:41:30 +08:00
yrong1997
d7911300f8
Respect ignore migration when add column (#4276)
continue https://github.com/go-gorm/gorm/pull/4028
2021-04-13 09:39:43 +08:00
Jinzhu
d278ca49ef sort GORM options before apply 2021-04-09 11:43:24 +08:00
Jinzhu
ad53074f1d Pass db error to new instance 2021-04-09 11:07:14 +08:00
Jinzhu
f3bdfa8261 Add IgnoreRecordNotFoundError option for logger 2021-04-09 10:21:01 +08:00
Jinzhu
673053f56a Fix context cancel error, close #4259, close #4260 2021-04-09 10:21:01 +08:00
gavwu
8cfa9d98f0
Update field.go (#4228)
seems like the `if-else` branch do the same thing, so remove it
2021-04-02 09:56:38 +08:00
Jinzhu
33601dc72f Support Having w/o Group 2021-03-30 18:28:09 +08:00
Jinzhu
73c6d3e64e Add AfterInitialize error 2021-03-29 18:36:01 +08:00
Jinzhu
0eba7a9ed1 Fix apply option 2021-03-26 14:20:42 +08:00
Jinzhu
a8b72546c1 Fix get database connection for prepared stmt, close #4214 2021-03-25 10:17:57 +08:00
Jinzhu
26e0c6fb69 skip test sqlserver due to it will raise data race for invalid sql 2021-03-24 17:12:30 +08:00
Jinzhu
88078e48d0 Remove sqlite_windows test case 2021-03-24 16:56:41 +08:00
Jinzhu
8204d0ada2 Update tests script 2021-03-24 16:44:51 +08:00
Jinzhu
704e53a774 Call scopes before parse model value, close #4209 2021-03-24 16:35:39 +08:00
Jinzhu
4d5cec8bdd Add golang 1.16 2021-03-24 14:22:36 +08:00
Genta Kamitani
26dd4c980a
Fix: FindInBatches ignores errors (#4203) 2021-03-22 14:11:07 +08:00
Jinzhu
8c92d9694a Fix to call Scopes with using Migrator 2021-03-19 16:34:51 +08:00
Jinzhu
a9fe025ef5 Add GetDBConnector interface 2021-03-19 15:55:38 +08:00
160 changed files with 17710 additions and 2562 deletions

15
.github/dependabot.yml vendored Normal file
View File

@ -0,0 +1,15 @@
---
version: 2
updates:
- package-ecosystem: gomod
directory: /
schedule:
interval: weekly
- package-ecosystem: github-actions
directory: /
schedule:
interval: weekly
- package-ecosystem: gomod
directory: /tests
schedule:
interval: weekly

20
.github/release-drafter.yml vendored Normal file
View File

@ -0,0 +1,20 @@
name-template: 'v Release $NEXT_PATCH_VERSION 🌈'
tag-template: 'v$NEXT_PATCH_VERSION'
categories:
- title: '🚀 Features'
labels:
- 'feature'
- 'enhancement'
- title: '🐛 Bug Fixes'
labels:
- 'fix'
- 'bugfix'
- 'bug'
- title: '🧰 Maintenance'
label: 'chore'
change-template: '- $TITLE @$AUTHOR (#$NUMBER)'
change-title-escapes: '\<*_&'
template: |
## Changes
$CHANGES

31
.github/workflows/create-release.yml vendored Normal file
View File

@ -0,0 +1,31 @@
name: Create Release
on:
push:
tags:
- 'v*.*.*'
permissions:
contents: write
pull-requests: read
jobs:
create_release:
name: Create Release
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Generate Release Notes and Publish
id: generate_release_notes
uses: release-drafter/release-drafter@v6
with:
config-name: 'release-drafter.yml'
name: "Release ${{ github.ref_name }}"
tag: ${{ github.ref_name }}
publish: true
prerelease: false
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

26
.github/workflows/golangci-lint.yml vendored Normal file
View File

@ -0,0 +1,26 @@
name: golangci-lint
on:
push:
branches:
- main
- master
pull_request:
permissions:
contents: read
pull-requests: read
jobs:
golangci:
name: lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: stable
- name: golangci-lint
uses: golangci/golangci-lint-action@v7
with:
version: v2.0
only-new-issues: true

View File

@ -3,20 +3,26 @@ on:
schedule:
- cron: "*/10 * * * *"
permissions:
contents: read
jobs:
stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v3.0.7
uses: actions/stale@v8
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-label: "status:stale"
days-before-stale: 0
days-before-close: 2
days-before-close: 30
remove-stale-when-updated: true
only-labels: "type:invalid question"

View File

@ -11,7 +11,7 @@ jobs:
name: Label issues and pull requests
steps:
- name: check out
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: labeler
uses: jinzhu/super-labeler-action@develop

View File

@ -3,19 +3,25 @@ on:
schedule:
- cron: "*/10 * * * *"
permissions:
contents: read
jobs:
stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v3.0.7
uses: actions/stale@v8
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-label: "status:stale"
days-before-stale: 0
days-before-close: 2
days-before-close: 30
remove-stale-when-updated: true
only-labels: "type:missing reproduction steps"

View File

@ -1,11 +0,0 @@
name: reviewdog
on: [pull_request]
jobs:
golangci-lint:
name: runner / golangci-lint
runs-on: ubuntu-latest
steps:
- name: Check out code into the Go module directory
uses: actions/checkout@v1
- name: golangci-lint
uses: reviewdog/action-golangci-lint@v1

View File

@ -3,19 +3,25 @@ on:
schedule:
- cron: "0 2 * * *"
permissions:
contents: read
jobs:
stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v3.0.7
uses: actions/stale@v8
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been automatically marked as stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 30 days"
days-before-stale: 60
days-before-close: 30
stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"
days-before-stale: 360
days-before-close: 180
stale-issue-label: "status:stale"
exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request'
stale-pr-label: 'status:stale'

View File

@ -8,63 +8,41 @@ on:
branches-ignore:
- 'gh-pages'
permissions:
contents: read
jobs:
# Label of the container job
sqlite:
strategy:
matrix:
go: ['1.15', '1.14', '1.13']
platform: [ubuntu-latest, macos-latest] # can not run in windows OS
go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run in windows OS
runs-on: ${{ matrix.platform }}
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v2
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GORM_DIALECT=sqlite ./tests/tests_all.sh
sqlite_windows:
strategy:
matrix:
go: ['1.15', '1.14', '1.13']
platform: [windows-latest]
runs-on: ${{ matrix.platform }}
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- name: go mod package cache
uses: actions/cache@v2
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: cd tests && set GORM_DIALECT=sqlite && go test $race -count=1 -v ./... #run the line in widnows's CMD, default GORM_DIALECT is sqlite
run: GITHUB_ACTION=true GORM_DIALECT=sqlite ./tests/tests_all.sh
mysql:
strategy:
matrix:
dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest']
go: ['1.15', '1.14', '1.13']
dbversion: ['mysql:9', 'mysql:8', 'mysql:5.7']
go: ['1.23', '1.24']
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
@ -87,29 +65,71 @@ jobs:
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v2
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
mariadb:
strategy:
matrix:
dbversion: [ 'mariadb:latest' ]
go: ['1.23', '1.24']
platform: [ ubuntu-latest ]
runs-on: ${{ matrix.platform }}
services:
mysql:
image: ${{ matrix.dbversion }}
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
ports:
- 9910:3306
options: >-
--health-cmd "mariadb-admin ping -ugorm -pgorm"
--health-interval 10s
--health-start-period 10s
--health-timeout 5s
--health-retries 10
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
postgres:
strategy:
matrix:
dbversion: ['postgres:latest', 'postgres:11', 'postgres:10']
go: ['1.15', '1.14', '1.13']
platform: [ubuntu-latest] # can not run in macOS and widnowsOS
dbversion: ['postgres:latest', 'postgres:15', 'postgres:14', 'postgres:13']
go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }}
services:
@ -131,42 +151,40 @@ jobs:
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v2
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
run: GITHUB_ACTION=true GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
sqlserver:
strategy:
matrix:
go: ['1.15', '1.14', '1.13']
go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run test in macOS and windows
runs-on: ${{ matrix.platform }}
services:
mssql:
image: mcmoe/mssqldocker:latest
image: mcr.microsoft.com/mssql/server:2022-latest
env:
TZ: Asia/Shanghai
ACCEPT_EULA: Y
SA_PASSWORD: LoremIpsum86
MSSQL_DB: gorm
MSSQL_USER: gorm
MSSQL_PASSWORD: LoremIpsum86
MSSQL_SA_PASSWORD: LoremIpsum86
ports:
- 9930:1433
options: >-
--health-cmd="/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P LoremIpsum86 -l 30 -Q \"SELECT 1\" || exit 1"
--health-cmd="/opt/mssql-tools18/bin/sqlcmd -S localhost -U sa -P ${MSSQL_SA_PASSWORD} -N -C -l 30 -Q \"SELECT 1\" || exit 1"
--health-start-period 10s
--health-interval 10s
--health-timeout 5s
@ -174,18 +192,119 @@ jobs:
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v2
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh
run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://sa:LoremIpsum86@localhost:9930?database=master" ./tests/tests_all.sh
tidb:
strategy:
matrix:
dbversion: [ 'v6.5.0' ]
go: ['1.23', '1.24']
platform: [ ubuntu-latest ]
runs-on: ${{ matrix.platform }}
steps:
- name: Setup TiDB
uses: Icemap/tidb-action@main
with:
port: 9940
version: ${{matrix.dbversion}}
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh
gaussdb:
strategy:
matrix:
dbversion: ['opengauss/opengauss:7.0.0-RC1.B023']
go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }}
services:
gaussdb:
image: ${{ matrix.dbversion }}
env:
# GaussDB has password limitations
GS_PASSWORD: Gaussdb@123
TZ: Asia/Shanghai
ports:
- 9950:5432
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: Waiting for GaussDB to be ready
run: |
container_name=$(docker ps --filter "ancestor=opengauss/opengauss:7.0.0-RC1.B023" --format "{{.Names}}")
if [ -z "$container_name" ]; then
echo "Error: failed to find a container created from the 'opengauss/opengauss:7.0.0-RC1.B023' image."
exit 1
fi
max_retries=12
retry_count=0
if [ -t 0 ]; then
TTY_FLAG="-t"
else
TTY_FLAG=""
fi
while [ $retry_count -lt $max_retries ]; do
if docker exec -i "${container_name}" bash -c "su - omm -c 'gsql -U omm -c \"select 1;\"'"
then
echo "Creating database gorm..."
sql_file='/tmp/create_database.sql'
echo "CREATE DATABASE gorm DBCOMPATIBILITY 'PG';" > ${sql_file}
docker cp "${sql_file}" "${container_name}":"${sql_file}"
docker exec -i ${TTY_FLAG} "${container_name}" bash -c "su - omm -c 'gsql -U omm -f ${sql_file}'"
echo "Database initialization completed."
break
fi
echo "Waiting for database to be ready... (attempt $((retry_count + 1))/$max_retries)"
sleep 10
((++retry_count))
done
exit 0
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=gaussdb GORM_DSN="user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh

2
.gitignore vendored
View File

@ -3,3 +3,5 @@ documents
coverage.txt
_book
.idea
vendor
.vscode

19
.golangci.yml Normal file
View File

@ -0,0 +1,19 @@
version: "2"
linters:
default: standard
enable:
- cyclop
- gocritic
- gosec
- ineffassign
- misspell
- prealloc
- unconvert
- unparam
- whitespace
formatters:
enable:
- gofumpt
- goimports

128
CODE_OF_CONDUCT.md Normal file
View File

@ -0,0 +1,128 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to participate in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community includes:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period. This
includes avoiding interactions in community spaces and external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any interaction or public
communication with the community for a specified period. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

View File

@ -1,6 +1,6 @@
The MIT License (MIT)
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
Copyright (c) 2013-present Jinzhu <wosmvp@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@ -4,9 +4,6 @@ The fantastic ORM library for Golang, aims to be developer friendly.
[![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm)
[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions)
[![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm)
[![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm)
[![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT)
[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc)
@ -30,13 +27,18 @@ The fantastic ORM library for Golang, aims to be developer friendly.
## Getting Started
* GORM Guides [https://gorm.io](https://gorm.io)
* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
## Contributing
[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
## Contributors
[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework!
## License
© Jinzhu, 2013~time.Now
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)

View File

@ -14,6 +14,7 @@ import (
type Association struct {
DB *DB
Relationship *schema.Relationship
Unscope bool
Error error
}
@ -26,7 +27,7 @@ func (db *DB) Association(column string) *Association {
association.Relationship = db.Statement.Schema.Relationships.Relations[column]
if association.Relationship == nil {
association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column)
association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column)
}
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
@ -40,6 +41,15 @@ func (db *DB) Association(column string) *Association {
return association
}
func (association *Association) Unscoped() *Association {
return &Association{
DB: association.DB,
Relationship: association.Relationship,
Error: association.Error,
Unscope: true,
}
}
func (association *Association) Find(out interface{}, conds ...interface{}) error {
if association.Error == nil {
association.Error = association.buildCondition().Find(out, conds...).Error
@ -64,14 +74,30 @@ func (association *Association) Append(values ...interface{}) error {
func (association *Association) Replace(values ...interface{}) error {
if association.Error == nil {
reflectValue := association.DB.Statement.ReflectValue
rel := association.Relationship
var oldBelongsToExpr clause.Expression
// we have to record the old BelongsTo value
if association.Unscope && rel.Type == schema.BelongsTo {
var foreignFields []*schema.Field
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
oldBelongsToExpr = clause.IN{Column: column, Values: values}
}
}
// save associations
if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
return association.Error
}
// set old associations's foreign key to null
reflectValue := association.DB.Statement.ReflectValue
rel := association.Relationship
switch rel.Type {
case schema.BelongsTo:
if len(values) == 0 {
@ -79,10 +105,10 @@ func (association *Association) Replace(values ...interface{}) error {
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
}
case reflect.Struct:
association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
}
for _, ref := range rel.References {
@ -91,17 +117,20 @@ func (association *Association) Replace(values ...interface{}) error {
association.Error = association.DB.UpdateColumns(updateMap).Error
}
if association.Unscope && oldBelongsToExpr != nil {
association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
}
case schema.HasOne, schema.HasMany:
var (
primaryFields []*schema.Field
foreignKeys []string
updateMap = map[string]interface{}{}
relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel})
relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel})
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
tx.Not(clause.IN{Column: column, Values: values})
}
@ -117,9 +146,13 @@ func (association *Association) Replace(values ...interface{}) error {
}
}
if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
if association.Unscope {
association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error
} else {
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
}
}
case schema.Many2Many:
var (
@ -143,14 +176,14 @@ func (association *Association) Replace(values ...interface{}) error {
}
}
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
tx.Where(clause.IN{Column: column, Values: values})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
}
@ -184,29 +217,53 @@ func (association *Association) Delete(values ...interface{}) error {
switch rel.Type {
case schema.BelongsTo:
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
associationDB := association.DB.Session(&Session{})
tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields)
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields)
relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
if association.Unscope {
var foreignFields []*schema.Field
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
}
}
case schema.HasOne, schema.HasMany:
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
model := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := association.DB.Model(model)
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
if association.Unscope {
association.Error = tx.Clauses(conds...).Delete(model).Error
} else {
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
}
case schema.Many2Many:
var (
primaryFields, relPrimaryFields []*schema.Field
@ -228,11 +285,14 @@ func (association *Association) Delete(values ...interface{}) error {
}
}
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
@ -241,11 +301,11 @@ func (association *Association) Delete(values ...interface{}) error {
if association.Error == nil {
// clean up deleted values's foreign key
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
cleanUpDeletedRelations := func(data reflect.Value) {
if _, zero := rel.Field.ValueOf(data); !zero {
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data))
if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero {
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data))
primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
switch fieldValue.Kind() {
@ -253,7 +313,7 @@ func (association *Association) Delete(values ...interface{}) error {
validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
for i := 0; i < fieldValue.Len(); i++ {
for idx, field := range rel.FieldSchema.PrimaryFields {
primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i))
primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i))
}
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
@ -261,23 +321,23 @@ func (association *Association) Delete(values ...interface{}) error {
}
}
association.Error = rel.Field.Set(data, validFieldValues.Interface())
association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface())
case reflect.Struct:
for idx, field := range rel.FieldSchema.PrimaryFields {
primaryValues[idx], _ = field.ValueOf(fieldValue)
primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue)
}
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
break
}
if rel.JoinTable == nil {
for _, ref := range rel.References {
if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
} else {
association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
}
}
}
@ -329,14 +389,18 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
switch rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() > 0 {
association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface())
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
}
}
case reflect.Struct:
association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface())
if !rv.CanAddr() {
association.Error = ErrInvalidValue
return
}
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
@ -344,9 +408,13 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
}
case schema.HasMany, schema.Many2Many:
elemType := association.Relationship.Field.IndirectFieldType.Elem()
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source))
oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
var fieldValue reflect.Value
if clear {
fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap())
} else {
fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap())
reflect.Copy(fieldValue, oldFieldValue)
}
appendToFieldValues := func(ev reflect.Value) {
@ -355,7 +423,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} else if ev.Type().Elem().AssignableTo(elemType) {
fieldValue = reflect.Append(fieldValue, ev.Elem())
} else {
association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name)
association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name)
}
if elemType.Kind() == reflect.Struct {
@ -369,11 +437,15 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
}
case reflect.Struct:
if !rv.CanAddr() {
association.Error = ErrInvalidValue
return
}
appendToFieldValues(rv.Addr())
}
if association.Error == nil {
association.Error = association.Relationship.Field.Set(source, fieldValue.Interface())
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface())
}
}
}
@ -421,7 +493,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
// clear old data
if clear && len(values) == 0 {
for i := 0; i < reflectValue.Len(); i++ {
if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
association.Error = err
break
}
@ -429,7 +501,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
if association.Relationship.JoinTable == nil {
for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
association.Error = err
break
}
@ -446,6 +518,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
for i := 0; i < reflectValue.Len(); i++ {
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
if association.Error != nil {
return
}
// TODO support save slice data, sql with case?
association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
@ -453,12 +528,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
case reflect.Struct:
// clear old data
if clear && len(values) == 0 {
association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
if association.Relationship.JoinTable == nil && association.Error == nil {
for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
}
}
}
@ -467,6 +542,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
for idx, value := range values {
rv := reflect.Indirect(reflect.ValueOf(value))
appendToRelations(reflectValue, rv, clear && idx == 0)
if association.Error != nil {
return
}
}
if len(values) > 0 {
@ -475,7 +553,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
}
for _, assignBack := range assignBacks {
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source))
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source))
if assignBack.Index > 0 {
reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
} else {
@ -486,19 +564,21 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
func (association *Association) buildCondition() *DB {
var (
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue)
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
joinStmt := Statement{DB: tx, Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause)
}
joinStmt.Build("WHERE")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
if len(joinStmt.SQL.String()) > 0 {
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
}
tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{

View File

@ -32,6 +32,7 @@ type callbacks struct {
type processor struct {
db *DB
Clauses []string
fns []func(*DB)
callbacks []*callback
}
@ -71,21 +72,38 @@ func (cs *callbacks) Raw() *processor {
return cs.processors["raw"]
}
func (p *processor) Execute(db *DB) {
func (p *processor) Execute(db *DB) *DB {
// call scopes
for len(db.Statement.scopes) > 0 {
db = db.executeScopes()
}
var (
curTime = time.Now()
stmt = db.Statement
curTime = time.Now()
stmt = db.Statement
resetBuildClauses bool
)
if len(stmt.BuildClauses) == 0 {
stmt.BuildClauses = p.Clauses
resetBuildClauses = true
}
if optimizer, ok := db.Statement.Dest.(StatementModifier); ok {
optimizer.ModifyStatement(stmt)
}
// assign model values
if stmt.Model == nil {
stmt.Model = stmt.Dest
} else if stmt.Dest == nil {
stmt.Dest = stmt.Model
}
// parse model values
if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil {
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
} else {
db.AddError(err)
@ -93,12 +111,12 @@ func (p *processor) Execute(db *DB) {
}
}
// assign stmt.ReflectValue
if stmt.Dest != nil {
stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
for stmt.ReflectValue.Kind() == reflect.Ptr {
if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() {
stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
break
}
stmt.ReflectValue = stmt.ReflectValue.Elem()
@ -108,27 +126,30 @@ func (p *processor) Execute(db *DB) {
}
}
// call scopes
for len(stmt.scopes) > 0 {
scopes := stmt.scopes
stmt.scopes = nil
for _, scope := range scopes {
db = scope(db)
}
}
for _, f := range p.fns {
f(db)
}
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
}, db.Error)
if stmt.SQL.Len() > 0 {
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
sql, vars := stmt.SQL.String(), stmt.Vars
if filter, ok := db.Logger.(ParamsFilter); ok {
sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
}
return db.Dialector.Explain(sql, vars...), db.RowsAffected
}, db.Error)
}
if !stmt.DB.DryRun {
stmt.SQL.Reset()
stmt.Vars = nil
}
if resetBuildClauses {
stmt.BuildClauses = nil
}
return db
}
func (p *processor) Get(name string) func(*DB) {
@ -166,10 +187,18 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
func (p *processor) compile() (err error) {
var callbacks []*callback
removedMap := map[string]bool{}
for _, callback := range p.callbacks {
if callback.match == nil || callback.match(p.db) {
callbacks = append(callbacks, callback)
}
if callback.remove {
removedMap[callback.name] = true
}
}
if len(removedMap) > 0 {
callbacks = removeCallbacks(callbacks, removedMap)
}
p.callbacks = callbacks
@ -197,7 +226,7 @@ func (c *callback) Register(name string, fn func(*DB)) error {
}
func (c *callback) Remove(name string) error {
c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())
c.name = name
c.remove = true
c.processor.callbacks = append(c.processor.callbacks, c)
@ -205,7 +234,7 @@ func (c *callback) Remove(name string) error {
}
func (c *callback) Replace(name string, fn func(*DB)) error {
c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())
c.name = name
c.handler = fn
c.replace = true
@ -228,14 +257,20 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
names, sorted []string
sortCallback func(*callback) error
)
sort.Slice(cs, func(i, j int) bool {
return cs[j].before == "*" || cs[j].after == "*"
sort.SliceStable(cs, func(i, j int) bool {
if cs[j].before == "*" && cs[i].before != "*" {
return true
}
if cs[j].after == "*" && cs[i].after != "*" {
return true
}
return false
})
for _, c := range cs {
// show warning message the callback name already exists
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
}
names = append(names, c.name)
}
@ -251,7 +286,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
// if before callback already sorted, append current callback just after it
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
} else if curIdx > sortedIdx {
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before)
return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
}
} else if idx := getRIndex(names, c.before); idx != -1 {
// if before callback exists
@ -269,7 +304,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
// if after callback sorted, append current callback to last
sorted = append(sorted, c.name)
} else if curIdx < sortedIdx {
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after)
return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after)
}
} else if idx := getRIndex(names, c.after); idx != -1 {
// if after callback exists but haven't sorted
@ -312,3 +347,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
return
}
func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
callbacks := make([]*callback, 0, len(cs))
for _, callback := range cs {
if nameMap[callback.name] {
continue
}
callbacks = append(callbacks, callback)
}
return callbacks
}

View File

@ -7,6 +7,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
@ -23,8 +24,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
setupReferences := func(obj reflect.Value, elem reflect.Value) {
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(elem)
db.AddError(ref.ForeignKey.Set(obj, pv))
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv))
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
dest[ref.ForeignKey.DBName] = pv
@ -39,49 +40,64 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
objs = make([]reflect.Value, 0, db.Statement.ReflectValue.Len())
rValLen = db.Statement.ReflectValue.Len()
objs = make([]reflect.Value, 0, rValLen)
fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr
)
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
fieldType = reflect.PointerTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{}
for i := 0; i < rValLen; i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() != reflect.Struct {
break
}
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
if !isPtr {
rv = rv.Addr()
}
objs = append(objs, obj)
elems = reflect.Append(elems, rv)
if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
objs = append(objs, obj)
if isPtr {
elems = reflect.Append(elems, rv)
} else {
elems = reflect.Append(elems, rv.Addr())
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
} else {
break
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, rv)
}
}
}
if elems.Len() > 0 {
if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i))
}
}
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil {
if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
setupReferences(db.Statement.ReflectValue, rv)
}
}
@ -110,7 +126,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
)
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
fieldType = reflect.PointerTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
@ -119,18 +135,18 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(obj); !zero {
rv := rel.Field.ReflectValueOf(obj)
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj)
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj)
db.AddError(ref.ForeignKey.Set(rv, fv))
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv))
} else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue))
db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue))
}
}
@ -145,11 +161,11 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
if f.Kind() != reflect.Ptr {
f = f.Addr()
}
@ -157,15 +173,15 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
ref.ForeignKey.Set(f, fv)
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv))
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(f, ref.PrimaryValue)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue))
}
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
}
}
}
@ -179,28 +195,43 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
fieldType = reflect.PointerTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(v)
ref.ForeignKey.Set(elem, pv)
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv))
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(elem, ref.PrimaryValue)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue))
}
}
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
}
}
}
}
@ -224,7 +255,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
}
}
@ -237,41 +268,57 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
fieldType = reflect.PointerTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.JoinTable.ModelType)), 0, 10)
objs := []reflect.Value{}
appendToJoins := func(obj reflect.Value, elem reflect.Value) {
joinValue := reflect.New(rel.JoinTable.ModelType)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj)
ref.ForeignKey.Set(joinValue, fv)
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(joinValue, ref.PrimaryValue)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue))
} else {
fv, _ := ref.PrimaryKey.ValueOf(elem)
ref.ForeignKey.Set(joinValue, fv)
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
}
}
joins = reflect.Append(joins, joinValue)
}
identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
objs = append(objs, v)
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
if !isPtr {
elem = elem.Addr()
}
objs = append(objs, v)
elems = reflect.Append(elems, elem)
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, elem)
}
}
}
}
@ -288,12 +335,13 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
appendToElems(db.Statement.ReflectValue)
}
if elems.Len() > 0 {
// optimize elems of reflect value length
if elemLen := elems.Len(); elemLen > 0 {
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
}
for i := 0; i < elems.Len(); i++ {
for i := 0; i < elemLen; i++ {
appendToJoins(objs[i], elems.Index(i))
}
}
@ -309,40 +357,35 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
}
}
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) clause.OnConflict {
if stmt.DB.FullSaveAssociations {
defaultUpdatingColumns = make([]string, 0, len(s.DBNames))
for _, dbName := range s.DBNames {
if v, ok := selectColumns[dbName]; (ok && !v) || (!ok && restricted) {
continue
}
if !s.LookUpField(dbName).PrimaryKey {
defaultUpdatingColumns = append(defaultUpdatingColumns, dbName)
}
}
}
if len(defaultUpdatingColumns) > 0 {
columns := make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
for _, dbName := range s.PrimaryFieldDBNames {
columns = append(columns, clause.Column{Name: dbName})
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName})
}
return clause.OnConflict{
Columns: columns,
DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns),
onConflict.UpdateAll = stmt.DB.FullSaveAssociations
if !onConflict.UpdateAll {
onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns)
}
} else {
onConflict.DoNothing = true
}
return clause.OnConflict{DoNothing: true}
return
}
func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
// stop save association loop
if checkAssociationsSaved(db, rValues) {
return nil
}
var (
selects, omits []string
onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns)
refName = rel.Name + "."
values = rValues.Interface()
)
for name, ok := range selectColumns {
@ -372,7 +415,7 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{},
})
if tx.Statement.FullSaveAssociations {
tx = tx.InstanceSet("gorm:update_track_time", true)
tx = tx.Set("gorm:update_track_time", true)
}
if len(selects) > 0 {
@ -387,3 +430,24 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{},
return db.AddError(tx.Create(values).Error)
}
// check association values has been saved
// if values kind is Struct, check it has been saved
// if values kind is Slice/Array, check all items have been saved
var visitMapStoreKey = "gorm:saved_association_map"
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
if visit, ok := db.Get(visitMapStoreKey); ok {
if v, ok := visit.(*visitMap); ok {
if loadOrStoreVisitMap(v, values) {
return true
}
}
} else {
vistMap := make(visitMap)
loadOrStoreVisitMap(&vistMap, values)
db.Set(visitMapStoreKey, &vistMap)
}
return false
}

View File

@ -4,9 +4,19 @@ import (
"gorm.io/gorm"
)
var (
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
updateClauses = []string{"UPDATE", "SET", "WHERE"}
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
)
type Config struct {
LastInsertIDReversed bool
WithReturning bool
CreateClauses []string
QueryClauses []string
UpdateClauses []string
DeleteClauses []string
}
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
@ -14,6 +24,19 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
return !db.SkipDefaultTransaction
}
if len(config.CreateClauses) == 0 {
config.CreateClauses = createClauses
}
if len(config.QueryClauses) == 0 {
config.QueryClauses = queryClauses
}
if len(config.DeleteClauses) == 0 {
config.DeleteClauses = deleteClauses
}
if len(config.UpdateClauses) == 0 {
config.UpdateClauses = updateClauses
}
createCallback := db.Callback().Create()
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
createCallback.Register("gorm:before_create", BeforeCreate)
@ -22,30 +45,39 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
createCallback.Clauses = config.CreateClauses
queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery)
queryCallback.Clauses = config.QueryClauses
deleteCallback := db.Callback().Delete()
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
deleteCallback.Register("gorm:before_delete", BeforeDelete)
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
deleteCallback.Register("gorm:delete", Delete)
deleteCallback.Register("gorm:delete", Delete(config))
deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
deleteCallback.Clauses = config.DeleteClauses
updateCallback := db.Callback().Update()
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
updateCallback.Register("gorm:before_update", BeforeUpdate)
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
updateCallback.Register("gorm:update", Update)
updateCallback.Register("gorm:update", Update(config))
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
updateCallback.Register("gorm:after_update", AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
updateCallback.Clauses = config.UpdateClauses
db.Callback().Row().Register("gorm:row", RowQuery)
db.Callback().Raw().Register("gorm:raw", RawExec)
rowCallback := db.Callback().Row()
rowCallback.Register("gorm:row", RowQuery)
rowCallback.Clauses = config.QueryClauses
rawCallback := db.Callback().Raw()
rawCallback.Register("gorm:raw", RawExec)
rawCallback.Clauses = config.QueryClauses
}

View File

@ -13,11 +13,20 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
case reflect.Slice, reflect.Array:
db.Statement.CurDestIndex = 0
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx)
if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() {
fc(value.Addr().Interface(), tx)
} else {
db.AddError(gorm.ErrInvalidValue)
return
}
db.Statement.CurDestIndex++
}
case reflect.Struct:
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
if db.Statement.ReflectValue.CanAddr() {
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
} else {
db.AddError(gorm.ErrInvalidValue)
}
}
}
}

View File

@ -3,12 +3,15 @@ package callbacks
import (
"fmt"
"reflect"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
// BeforeCreate before create hooks
func BeforeCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
@ -30,194 +33,206 @@ func BeforeCreate(db *gorm.DB) {
}
}
// Create create hook
func Create(config *Config) func(db *gorm.DB) {
if config.WithReturning {
return CreateWithReturning
} else {
return func(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
supportReturning := utils.Contains(config.CreateClauses, "RETURNING")
return func(db *gorm.DB) {
if db.Error != nil {
return
}
if db.Statement.Schema != nil {
if !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
}
if !db.DryRun && db.Error == nil {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
db.RowsAffected, _ = result.RowsAffected()
if db.RowsAffected > 0 {
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if config.LastInsertIDReversed {
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
if isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
} else {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
}
case reflect.Struct:
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
}
}
} else {
db.AddError(err)
}
}
if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
if field.Readable {
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
}
} else {
db.AddError(err)
}
if len(fromColumns) > 0 {
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
}
}
}
}
}
}
func CreateWithReturning(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
db.Statement.Build(db.Statement.BuildClauses...)
}
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
db.Statement.WriteString(" RETURNING ")
isDryRun := !db.DryRun && db.Error == nil
if !isDryRun {
return
}
var (
fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
)
for idx, field := range sch.FieldsWithDefaultDBValue {
if idx > 0 {
db.Statement.WriteByte(',')
ok, mode := hasReturning(db, supportReturning)
if ok {
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing {
mode |= gorm.ScanOnConflictDoNothing
}
fields[idx] = field
db.Statement.WriteQuoted(field.DBName)
}
if !db.DryRun && db.Error == nil {
db.RowsAffected = 0
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
rows, err := db.Statement.ConnPool.QueryContext(
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)
if db.AddError(err) == nil {
defer func() {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, mode)
if err == nil {
defer rows.Close()
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
c := db.Statement.Clauses["ON CONFLICT"]
onConflict, _ := c.Expression.(clause.OnConflict)
return
}
for rows.Next() {
BEGIN:
reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected))
if reflect.Indirect(reflectValue).Kind() != reflect.Struct {
break
}
result, err := db.Statement.ConnPool.ExecContext(
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)
if err != nil {
db.AddError(err)
return
}
for idx, field := range fields {
fieldValue := field.ReflectValueOf(reflectValue)
db.RowsAffected, _ = result.RowsAffected()
if onConflict.DoNothing && !fieldValue.IsZero() {
db.RowsAffected++
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() {
return
}
if db.RowsAffected == 0 {
return
}
goto BEGIN
}
var (
pkField *schema.Field
pkFieldName = "@id"
)
values[idx] = fieldValue.Addr().Interface()
}
if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil ||
!db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue ||
!db.Statement.Schema.PrioritizedPrimaryField.Readable {
return
}
pkField = db.Statement.Schema.PrioritizedPrimaryField
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}
db.RowsAffected++
if err := rows.Scan(values...); err != nil {
db.AddError(err)
}
}
case reflect.Struct:
for idx, field := range fields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
if !supportReturning {
db.AddError(err)
}
return
}
// append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
switch values := db.Statement.Dest.(type) {
case map[string]interface{}:
values[pkFieldName] = insertID
case *map[string]interface{}:
(*values)[pkFieldName] = insertID
case []map[string]interface{}, *[]map[string]interface{}:
mapValues, ok := values.([]map[string]interface{})
if !ok {
if v, ok := values.(*[]map[string]interface{}); ok {
if *v != nil {
mapValues = *v
}
}
}
if config.LastInsertIDReversed {
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
}
for _, mapValue := range mapValues {
if mapValue != nil {
mapValue[pkFieldName] = insertID
}
insertID += schema.DefaultAutoIncrementIncrement
}
default:
if pkField == nil {
return
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if config.LastInsertIDReversed {
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
if rows.Next() {
db.RowsAffected++
db.AddError(rows.Scan(values...))
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
if isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID -= pkField.AutoIncrementIncrement
}
}
} else {
db.AddError(err)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID += pkField.AutoIncrementIncrement
}
}
}
case reflect.Struct:
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero {
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
}
}
} else if !db.DryRun && db.Error == nil {
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
}
}
}
// AfterCreate after create hooks
func AfterCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok {
called = true
db.AddError(i.AfterSave(tx))
}
}
if db.Statement.Schema.AfterCreate {
if i, ok := value.(AfterCreateInterface); ok {
called = true
db.AddError(i.AfterCreate(tx))
}
}
if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok {
called = true
db.AddError(i.AfterSave(tx))
}
}
return called
})
}
@ -225,6 +240,8 @@ func AfterCreate(db *gorm.DB) {
// ConvertToCreateValues convert to create values
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
curTime := stmt.DB.NowFunc()
switch value := stmt.Dest.(type) {
case map[string]interface{}:
values = ConvertMapToValuesForCreate(stmt, value)
@ -237,9 +254,11 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
default:
var (
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
curTime = stmt.DB.NowFunc()
_, updateTrackTime = stmt.Get("gorm:update_track_time")
isZero bool
)
stmt.Settings.Delete("gorm:update_track_time")
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
for _, db := range stmt.Schema.DBNames {
@ -252,15 +271,18 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
stmt.SQL.Grow(stmt.ReflectValue.Len() * 18)
values.Values = make([][]interface{}, stmt.ReflectValue.Len())
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
if stmt.ReflectValue.Len() == 0 {
rValLen := stmt.ReflectValue.Len()
if rValLen == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
for i := 0; i < stmt.ReflectValue.Len(); i++ {
stmt.SQL.Grow(rValLen * 18)
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
values.Values = make([][]interface{}, rValLen)
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
for i := 0; i < rValLen; i++ {
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
if !rv.IsValid() {
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
@ -270,41 +292,41 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
values.Values[i] = make([]interface{}, len(values.Columns))
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[i][idx], isZero = field.ValueOf(rv); isZero {
if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
if field.DefaultValueInterface != nil {
values.Values[i][idx] = field.DefaultValueInterface
field.Set(rv, field.DefaultValueInterface)
stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
field.Set(rv, curTime)
values.Values[i][idx], _ = field.ValueOf(rv)
}
} else if field.AutoUpdateTime > 0 {
if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok {
field.Set(rv, curTime)
values.Values[i][idx], _ = field.ValueOf(rv)
stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
}
} else if field.AutoUpdateTime > 0 && updateTrackTime {
stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, isZero := field.ValueOf(rv); !isZero {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
if len(defaultValueFieldsHavingValue[field]) == 0 {
defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len())
defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
}
defaultValueFieldsHavingValue[field][i] = v
defaultValueFieldsHavingValue[field][i] = rvOfvalue
}
}
}
}
for field, vs := range defaultValueFieldsHavingValue {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
for idx := range values.Values {
if vs[idx] == nil {
values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
} else {
values.Values[idx] = append(values.Values[idx], vs[idx])
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if vs, ok := defaultValueFieldsHavingValue[field]; ok {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
for idx := range values.Values {
if vs[idx] == nil {
values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
} else {
values.Values[idx] = append(values.Values[idx], vs[idx])
}
}
}
}
@ -312,27 +334,25 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero {
if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
if field.DefaultValueInterface != nil {
values.Values[0][idx] = field.DefaultValueInterface
field.Set(stmt.ReflectValue, field.DefaultValueInterface)
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
field.Set(stmt.ReflectValue, curTime)
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
}
} else if field.AutoUpdateTime > 0 {
if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok {
field.Set(stmt.ReflectValue, curTime)
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
} else if field.AutoUpdateTime > 0 && updateTrackTime {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], v)
values.Values[0] = append(values.Values[0], rvOfvalue)
}
}
}
@ -343,17 +363,39 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
if stmt.Schema != nil && len(values.Columns) > 1 {
if stmt.Schema != nil && len(values.Columns) >= 1 {
selectColumns, restricted := stmt.SelectAndOmitColumns(true, true)
columns := make([]string, 0, len(values.Columns)-1)
for _, column := range values.Columns {
if field := stmt.Schema.LookUpField(column.Name); field != nil {
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
columns = append(columns, column.Name)
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
if field.AutoUpdateTime > 0 {
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
switch field.AutoUpdateTime {
case schema.UnixNanosecond:
assignment.Value = curTime.UnixNano()
case schema.UnixMillisecond:
assignment.Value = curTime.UnixMilli()
case schema.UnixSecond:
assignment.Value = curTime.Unix()
}
onConflict.DoUpdates = append(onConflict.DoUpdates, assignment)
} else {
columns = append(columns, column.Name)
}
}
}
}
}
onConflict.DoUpdates = clause.AssignmentColumns(columns)
onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
if len(onConflict.DoUpdates) == 0 {
onConflict.DoNothing = true
}
// use primary fields as default OnConflict columns
if len(onConflict.Columns) == 0 {

71
callbacks/create_test.go Normal file
View File

@ -0,0 +1,71 @@
package callbacks
import (
"reflect"
"sync"
"testing"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
var schemaCache = &sync.Map{}
func TestConvertToCreateValues_DestType_Slice(t *testing.T) {
type user struct {
ID int `gorm:"primaryKey"`
Name string
Email string `gorm:"default:(-)"`
Age int `gorm:"default:(-)"`
}
s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{})
if err != nil {
t.Errorf("parse schema error: %v, is not expected", err)
return
}
dest := []*user{
{
ID: 1,
Name: "alice",
Email: "email",
Age: 18,
},
{
ID: 2,
Name: "bob",
Email: "email",
Age: 19,
},
}
stmt := &gorm.Statement{
DB: &gorm.DB{
Config: &gorm.Config{
NowFunc: func() time.Time { return time.Time{} },
},
Statement: &gorm.Statement{
Settings: sync.Map{},
Schema: s,
},
},
ReflectValue: reflect.ValueOf(dest),
Dest: dest,
}
stmt.Schema = s
values := ConvertToCreateValues(stmt)
expected := clause.Values{
// column has value + defaultValue column has value (which should have a stable order)
Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}},
Values: [][]interface{}{
{"alice", "email", 18, 1},
{"bob", "email", 19, 2},
},
}
if !reflect.DeepEqual(expected, values) {
t.Errorf("expected: %v got %v", expected, values)
}
}

View File

@ -7,6 +7,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
func BeforeDelete(db *gorm.DB) {
@ -25,99 +26,110 @@ func BeforeDelete(db *gorm.DB) {
func DeleteBeforeAssociations(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
if !restricted {
return
}
if restricted {
for column, v := range selectColumns {
if v {
if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok {
switch rel.Type {
case schema.HasOne, schema.HasMany:
queryConds := rel.ToQueryConditions(db.Statement.ReflectValue)
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
withoutConditions := false
if db.Statement.Unscoped {
tx = tx.Unscoped()
}
for column, v := range selectColumns {
if !v {
continue
}
if len(db.Statement.Selects) > 0 {
selects := make([]string, 0, len(db.Statement.Selects))
for _, s := range db.Statement.Selects {
if s == clause.Associations {
selects = append(selects, s)
} else if strings.HasPrefix(s, column+".") {
selects = append(selects, strings.TrimPrefix(s, column+"."))
}
}
rel, ok := db.Statement.Schema.Relationships.Relations[column]
if !ok {
continue
}
if len(selects) > 0 {
tx = tx.Select(selects)
}
}
switch rel.Type {
case schema.HasOne, schema.HasMany:
queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue)
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
withoutConditions := false
if db.Statement.Unscoped {
tx = tx.Unscoped()
}
for _, cond := range queryConds {
if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
withoutConditions = true
break
}
}
if !withoutConditions {
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
}
case schema.Many2Many:
var (
queryConds = make([]clause.Expression, 0, len(rel.References))
foreignFields = make([]*schema.Field, 0, len(rel.References))
relForeignKeys = make([]string, 0, len(rel.References))
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
table = rel.JoinTable.Table
tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table)
)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
} else if ref.PrimaryValue != "" {
queryConds = append(queryConds, clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
}
}
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields)
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
if len(db.Statement.Selects) > 0 {
selects := make([]string, 0, len(db.Statement.Selects))
for _, s := range db.Statement.Selects {
if s == clause.Associations {
selects = append(selects, s)
} else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) {
selects = append(selects, strings.TrimPrefix(s, columnPrefix))
}
}
if len(selects) > 0 {
tx = tx.Select(selects)
}
}
for _, cond := range queryConds {
if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
withoutConditions = true
break
}
}
if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
case schema.Many2Many:
var (
queryConds = make([]clause.Expression, 0, len(rel.References))
foreignFields = make([]*schema.Field, 0, len(rel.References))
relForeignKeys = make([]string, 0, len(rel.References))
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
table = rel.JoinTable.Table
tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table)
)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
} else if ref.PrimaryValue != "" {
queryConds = append(queryConds, clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
}
}
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields)
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
}
}
}
}
func Delete(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
func Delete(config *Config) func(db *gorm.DB) {
supportReturning := utils.Contains(config.DeleteClauses, "RETURNING")
return func(db *gorm.DB) {
if db.Error != nil {
return
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.DeleteClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(100)
db.Statement.AddClauseIfNotExists(clause.Delete{})
if db.Statement.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
@ -125,7 +137,7 @@ func Delete(db *gorm.DB) {
}
if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
_, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
@ -135,21 +147,36 @@ func Delete(db *gorm.DB) {
}
db.Statement.AddClauseIfNotExists(clause.From{})
db.Statement.Build("DELETE", "FROM", "WHERE")
db.Statement.Build(db.Statement.BuildClauses...)
}
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil {
db.AddError(gorm.ErrMissingWhereClause)
return
}
checkMissingWhereConditions(db)
if !db.DryRun && db.Error == nil {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
ok, mode := hasReturning(db, supportReturning)
if !ok {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
return
}
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
gorm.Scan(rows, db, mode)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
db.AddError(rows.Close())
}
}
}

View File

@ -1,6 +1,7 @@
package callbacks
import (
"reflect"
"sort"
"gorm.io/gorm"
@ -12,7 +13,7 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter
values.Columns = make([]clause.Column, 0, len(mapValue))
selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
var keys = make([]string, 0, len(mapValue))
keys := make([]string, 0, len(mapValue))
for k := range mapValue {
keys = append(keys, k)
}
@ -40,17 +41,20 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter
// ConvertSliceOfMapToValuesForCreate convert slice of map to values
func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) {
var (
columns = make([]string, 0, len(mapValues))
result = map[string][]interface{}{}
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
)
columns := make([]string, 0, len(mapValues))
// when the length of mapValues is zero,return directly here
// no need to call stmt.SelectAndOmitColumns method
if len(mapValues) == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
var (
result = make(map[string][]interface{}, len(mapValues))
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
)
for idx, mapValue := range mapValues {
for k, v := range mapValue {
if stmt.Schema != nil {
@ -88,3 +92,61 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
}
return
}
func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
if supportReturning {
if c, ok := tx.Statement.Clauses["RETURNING"]; ok {
returning, _ := c.Expression.(clause.Returning)
if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") {
return true, 0
}
return true, gorm.ScanUpdate
}
}
return false, 0
}
func checkMissingWhereConditions(db *gorm.DB) {
if !db.AllowGlobalUpdate && db.Error == nil {
where, withCondition := db.Statement.Clauses["WHERE"]
if withCondition {
if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
whereClause, _ := where.Expression.(clause.Where)
withCondition = len(whereClause.Exprs) > 1
}
}
if !withCondition {
db.AddError(gorm.ErrMissingWhereClause)
}
return
}
}
type visitMap = map[reflect.Value]bool
// Check if circular values, return true if loaded
func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
switch v.Kind() {
case reflect.Slice, reflect.Array:
loaded = true
for i := 0; i < v.Len(); i++ {
if !loadOrStoreVisitMap(visitMap, v.Index(i)) {
loaded = false
}
}
case reflect.Struct, reflect.Interface:
if v.CanAddr() {
p := v.Addr()
if _, ok := (*visitMap)[p]; ok {
return true
}
(*visitMap)[p] = true
}
}
return
}

157
callbacks/helper_test.go Normal file
View File

@ -0,0 +1,157 @@
package callbacks
import (
"reflect"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
func TestLoadOrStoreVisitMap(t *testing.T) {
var vm visitMap
var loaded bool
type testM struct {
Name string
}
t1 := testM{Name: "t1"}
t2 := testM{Name: "t2"}
t3 := testM{Name: "t3"}
vm = make(visitMap)
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
t.Fatalf("loaded should be true")
}
// t1 already exist but t2 not
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
t.Fatalf("loaded should be true")
}
}
func TestConvertMapToValuesForCreate(t *testing.T) {
testCase := []struct {
name string
input map[string]interface{}
expect clause.Values
}{
{
name: "Test convert string value",
input: map[string]interface{}{
"name": "my name",
},
expect: clause.Values{
Columns: []clause.Column{{Name: "name"}},
Values: [][]interface{}{{"my name"}},
},
},
{
name: "Test convert int value",
input: map[string]interface{}{
"age": 18,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "age"}},
Values: [][]interface{}{{18}},
},
},
{
name: "Test convert float value",
input: map[string]interface{}{
"score": 99.5,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "score"}},
Values: [][]interface{}{{99.5}},
},
},
{
name: "Test convert bool value",
input: map[string]interface{}{
"active": true,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "active"}},
Values: [][]interface{}{{true}},
},
},
}
for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
actual := ConvertMapToValuesForCreate(&gorm.Statement{}, tc.input)
if !reflect.DeepEqual(actual, tc.expect) {
t.Errorf("expect %v got %v", tc.expect, actual)
}
})
}
}
func TestConvertSliceOfMapToValuesForCreate(t *testing.T) {
testCase := []struct {
name string
input []map[string]interface{}
expect clause.Values
}{
{
name: "Test convert slice of string value",
input: []map[string]interface{}{
{"name": "my name"},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "name"}},
Values: [][]interface{}{{"my name"}},
},
},
{
name: "Test convert slice of int value",
input: []map[string]interface{}{
{"age": 18},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "age"}},
Values: [][]interface{}{{18}},
},
},
{
name: "Test convert slice of float value",
input: []map[string]interface{}{
{"score": 99.5},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "score"}},
Values: [][]interface{}{{99.5}},
},
},
{
name: "Test convert slice of bool value",
input: []map[string]interface{}{
{"active": true},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "active"}},
Values: [][]interface{}{{true}},
},
},
}
for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
actual := ConvertSliceOfMapToValuesForCreate(&gorm.Statement{}, tc.input)
if !reflect.DeepEqual(actual, tc.expect) {
t.Errorf("expected %v but got %v", tc.expect, actual)
}
})
}
}

View File

@ -1,7 +1,10 @@
package callbacks
import (
"fmt"
"reflect"
"sort"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
@ -9,10 +12,179 @@ import (
"gorm.io/gorm/utils"
)
func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) {
// parsePreloadMap extracts nested preloads. e.g.
//
// // schema has a "k0" relation and a "k7.k8" embedded relation
// parsePreloadMap(schema, map[string][]interface{}{
// clause.Associations: {"arg1"},
// "k1": {"arg2"},
// "k2.k3": {"arg3"},
// "k4.k5.k6": {"arg4"},
// })
// // preloadMap is
// map[string]map[string][]interface{}{
// "k0": {},
// "k7": {
// "k8": {},
// },
// "k1": {},
// "k2": {
// "k3": {"arg3"},
// },
// "k4": {
// "k5.k6": {"arg4"},
// },
// }
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
preloadMap := map[string]map[string][]interface{}{}
setPreloadMap := func(name, value string, args []interface{}) {
if _, ok := preloadMap[name]; !ok {
preloadMap[name] = map[string][]interface{}{}
}
if value != "" {
preloadMap[name][value] = args
}
}
for name, args := range preloads {
preloadFields := strings.Split(name, ".")
value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
if preloadFields[0] == clause.Associations {
for _, relation := range s.Relationships.Relations {
if relation.Schema == s {
setPreloadMap(relation.Name, value, args)
}
}
for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
for _, value := range embeddedValues(embeddedRelations) {
setPreloadMap(embedded, value, args)
}
}
} else {
setPreloadMap(preloadFields[0], value, args)
}
}
return preloadMap
}
func embeddedValues(embeddedRelations *schema.Relationships) []string {
if embeddedRelations == nil {
return nil
}
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
for _, relation := range embeddedRelations.Relations {
// skip first struct name
names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
}
for _, relations := range embeddedRelations.EmbeddedRelations {
names = append(names, embeddedValues(relations)...)
}
return names
}
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
// If the current relationship is embedded or joined, current query will be ignored.
//
//nolint:cyclop
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
// avoid random traversal of the map
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames = append(preloadNames, key)
}
sort.Strings(preloadNames)
isJoined := func(name string) (joined bool, nestedJoins []string) {
for _, join := range joins {
if _, ok := relationships.Relations[join]; ok && name == join {
joined = true
continue
}
join0, join1, cut := strings.Cut(join, ".")
if cut {
if _, ok := relationships.Relations[join0]; ok && name == join0 {
joined = true
nestedJoins = append(nestedJoins, join1)
}
}
}
return joined, nestedJoins
}
for _, name := range preloadNames {
if relations := relationships.EmbeddedRelations[name]; relations != nil {
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
return err
}
} else if rel := relationships.Relations[name]; rel != nil {
if joined, nestedJoins := isJoined(name); joined {
switch rv := db.Statement.ReflectValue; rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() > 0 {
reflectValue := rel.FieldSchema.MakeSlice().Elem()
for i := 0; i < rv.Len(); i++ {
frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
if frv.Kind() != reflect.Ptr {
reflectValue = reflect.Append(reflectValue, frv.Addr())
} else {
if frv.IsNil() {
continue
}
reflectValue = reflect.Append(reflectValue, frv)
}
}
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
}
case reflect.Struct, reflect.Pointer:
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
default:
return gorm.ErrInvalidData
}
} else {
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
tx.Statement.ReflectValue = db.Statement.ReflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
return err
}
}
} else {
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
}
}
return nil
}
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if err := tx.Statement.Parse(dest); err != nil {
tx.AddError(err)
return tx
}
tx.Statement.ReflectValue = reflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
return tx
}
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
var (
reflectValue = db.Statement.ReflectValue
tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
reflectValue = tx.Statement.ReflectValue
relForeignKeys []string
relForeignFields []*schema.Field
foreignFields []*schema.Field
@ -21,13 +193,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
inlineConds []interface{}
)
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if rel.JoinTable != nil {
var (
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
joinRelForeignFields = make([]*schema.Field, 0, len(rel.References))
@ -48,25 +214,28 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
}
}
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
if len(joinForeignValues) == 0 {
return
return nil
}
joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error)
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil {
return err
}
// convert join identity map to relation identity map
fieldValues := make([]interface{}, len(joinForeignFields))
joinFieldValues := make([]interface{}, len(joinRelForeignFields))
for i := 0; i < joinResults.Len(); i++ {
joinIndexValue := joinResults.Index(i)
for idx, field := range joinForeignFields {
fieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
}
for idx, field := range joinRelForeignFields {
joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
}
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
@ -75,7 +244,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
}
}
_, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields)
_, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields)
} else {
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
@ -91,9 +260,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
}
}
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
if len(foreignValues) == 0 {
return
return nil
}
}
@ -105,16 +274,26 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx)
} else {
inlineConds = append(inlineConds, cond)
if len(values) != 0 {
tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx)
} else {
inlineConds = append(inlineConds, cond)
}
}
if len(inlineConds) > 0 {
tx = tx.Where(inlineConds[0], inlineConds[1:]...)
}
if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
return err
}
}
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error)
fieldValues := make([]interface{}, len(relForeignFields))
// clean up old values before preloading
@ -122,17 +301,17 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
case reflect.Struct:
switch rel.Type {
case schema.HasMany, schema.Many2Many:
rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default:
rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()))
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type {
case schema.HasMany, schema.Many2Many:
rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default:
rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()))
}
}
}
@ -140,11 +319,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
for i := 0; i < reflectResults.Len(); i++ {
elem := reflectResults.Index(i)
for idx, field := range relForeignFields {
fieldValues[idx], _ = field.ValueOf(elem)
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem)
}
for _, data := range identityMap[utils.ToStringKey(fieldValues...)] {
reflectFieldValue := rel.Field.ReflectValueOf(data)
datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
if !ok {
return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())
}
for _, data := range datas {
reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data)
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
}
@ -152,14 +336,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
reflectFieldValue = reflect.Indirect(reflectFieldValue)
switch reflectFieldValue.Kind() {
case reflect.Struct:
rel.Field.Set(data, reflectResults.Index(i).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface()))
case reflect.Slice, reflect.Array:
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()))
} else {
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()))
}
}
}
}
return tx.Error
}

View File

@ -3,11 +3,12 @@ package callbacks
import (
"fmt"
"reflect"
"sort"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
func Query(db *gorm.DB) {
@ -20,28 +21,33 @@ func Query(db *gorm.DB) {
db.AddError(err)
return
}
defer rows.Close()
defer func() {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, 0)
gorm.Scan(rows, db, false)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
}
}
func BuildQuerySQL(db *gorm.DB) {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.QueryClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(100)
clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
var conds []clause.Expression
for _, primaryField := range db.Statement.Schema.PrimaryFields {
if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero {
if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero {
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
}
}
@ -95,127 +101,207 @@ func BuildQuerySQL(db *gorm.DB) {
}
// inline joins
if len(db.Statement.Joins) != 0 {
if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil {
fromClause := clause.From{}
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
fromClause = v
}
if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
for idx, dbName := range db.Statement.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
}
}
joins := []clause.Join{}
if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
joins = fromClause.Joins
}
specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
for _, join := range db.Statement.Joins {
if db.Statement.Schema == nil {
joins = append(joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
tableAliasName := relation.Name
for _, s := range relation.FieldSchema.DBNames {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName,
Name: s,
Alias: tableAliasName + "__" + s,
})
}
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
if db.Statement.Schema != nil {
var isRelations bool // is relations or raw sql
var relations []*schema.Relationship
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
if ok {
isRelations = true
relations = append(relations, relation)
} else {
// handle nested join like "Manager.Company"
nestedJoinNames := strings.Split(join.Name, ".")
if len(nestedJoinNames) > 1 {
isNestedJoin := true
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
currentRelations := db.Statement.Schema.Relationships.Relations
for _, relname := range nestedJoinNames {
// incomplete match, only treated as raw sql
if relation, ok = currentRelations[relname]; ok {
guessNestedRelations = append(guessNestedRelations, relation)
currentRelations = relation.FieldSchema.Relationships.Relations
} else {
isNestedJoin = false
break
}
}
} else {
if ref.PrimaryValue == "" {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
if isNestedJoin {
isRelations = true
relations = guessNestedRelations
}
}
}
joins = append(joins, clause.Join{
Type: clause.LeftJoin,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
})
if isRelations {
genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
}
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
for _, s := range relation.FieldSchema.DBNames {
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName,
Name: s,
Alias: utils.NestedRelationName(tableAliasName, s),
})
}
}
if join.Expression != nil {
return clause.Join{
Type: join.JoinType,
Expression: join.Expression,
}
}
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
}
} else {
if ref.PrimaryValue == "" {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
}
}
}
{
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
for _, c := range relation.FieldSchema.QueryClauses {
onStmt.AddClause(c)
}
if join.On != nil {
onStmt.AddClause(join.On)
}
if cs, ok := onStmt.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
where.Build(&onStmt)
if onSQL := onStmt.SQL.String(); onSQL != "" {
vars := onStmt.Vars
for idx, v := range vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
}
}
}
return clause.Join{
Type: joinType,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
}
}
parentTableName := clause.CurrentTable
for idx, rel := range relations {
// joins table alias like "Manager, Company, Manager__Company"
curAliasName := rel.Name
if parentTableName != clause.CurrentTable {
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
}
if _, ok := specifiedRelationsName[curAliasName]; !ok {
aliasName := curAliasName
if idx == len(relations)-1 && join.Alias != "" {
aliasName = join.Alias
}
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
specifiedRelationsName[curAliasName] = aliasName
}
parentTableName = curAliasName
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
} else {
joins = append(joins, clause.Join{
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
}
db.Statement.Joins = nil
db.Statement.AddClause(clause.From{Joins: joins})
db.Statement.AddClause(fromClause)
} else {
db.Statement.AddClauseIfNotExists(clause.From{})
}
db.Statement.AddClauseIfNotExists(clauseSelect)
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
db.Statement.Build(db.Statement.BuildClauses...)
}
}
func Preload(db *gorm.DB) {
if db.Error == nil && len(db.Statement.Preloads) > 0 {
preloadMap := map[string]map[string][]interface{}{}
for name := range db.Statement.Preloads {
preloadFields := strings.Split(name, ".")
if preloadFields[0] == clause.Associations {
for _, rel := range db.Statement.Schema.Relationships.Relations {
if rel.Schema == db.Statement.Schema {
if _, ok := preloadMap[rel.Name]; !ok {
preloadMap[rel.Name] = map[string][]interface{}{}
}
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[rel.Name][value] = db.Statement.Preloads[name]
}
}
}
} else {
if _, ok := preloadMap[preloadFields[0]]; !ok {
preloadMap[preloadFields[0]] = map[string][]interface{}{}
}
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
}
}
if db.Statement.Schema == nil {
db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired))
return
}
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames = append(preloadNames, key)
joins := make([]string, 0, len(db.Statement.Joins))
for _, join := range db.Statement.Joins {
joins = append(joins, join.Name)
}
sort.Strings(preloadNames)
for _, name := range preloadNames {
if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil {
preload(db, rel, db.Statement.Preloads[name], preloadMap[name])
} else {
db.AddError(fmt.Errorf("%v: %w for schema %v", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
}
tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
if tx.Error != nil {
return
}
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
}
}
func AfterQuery(db *gorm.DB) {
// clear the joins after query because preload need it
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
fromClause := db.Statement.Clauses["FROM"]
fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins
db.Statement.Clauses["FROM"] = fromClause
}
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterFindInterface); ok {

View File

@ -9,8 +9,14 @@ func RawExec(db *gorm.DB) {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
db.AddError(err)
} else {
db.RowsAffected, _ = result.RowsAffected()
return
}
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
}

View File

@ -7,15 +7,17 @@ import (
func RowQuery(db *gorm.DB) {
if db.Error == nil {
BuildQuerySQL(db)
if !db.DryRun {
if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) {
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} else {
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
}
db.RowsAffected = -1
if db.DryRun || db.Error != nil {
return
}
if isRows, ok := db.Get("rows"); ok && isRows.(bool) {
db.Statement.Settings.Delete("rows")
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} else {
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
}
db.RowsAffected = -1
}
}

View File

@ -5,12 +5,14 @@ import (
)
func BeginTransaction(db *gorm.DB) {
if !db.Config.SkipDefaultTransaction {
if !db.Config.SkipDefaultTransaction && db.Error == nil {
if tx := db.Begin(); tx.Error == nil {
db.Statement.ConnPool = tx.Statement.ConnPool
db.InstanceSet("gorm:started_transaction", true)
} else if tx.Error == gorm.ErrInvalidTransaction {
tx.Error = nil
} else {
db.Error = tx.Error
}
}
}
@ -18,11 +20,12 @@ func BeginTransaction(db *gorm.DB) {
func CommitOrRollbackTransaction(db *gorm.DB) {
if !db.Config.SkipDefaultTransaction {
if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
if db.Error == nil {
db.Commit()
} else {
if db.Error != nil {
db.Rollback()
} else {
db.Commit()
}
db.Statement.ConnPool = db.ConnPool
}
}

View File

@ -7,6 +7,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
func SetupUpdateReflectValue(db *gorm.DB) {
@ -20,7 +21,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if _, ok := dest[rel.Name]; ok {
rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name])
db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]))
}
}
}
@ -28,6 +29,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
}
}
// BeforeUpdate before update hooks
func BeforeUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
@ -50,45 +52,78 @@ func BeforeUpdate(db *gorm.DB) {
}
}
func Update(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
// Update update hook
func Update(config *Config) func(db *gorm.DB) {
supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")
return func(db *gorm.DB) {
if db.Error != nil {
return
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{})
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
db.Statement.AddClause(set)
} else {
return
if _, ok := db.Statement.Clauses["SET"]; !ok {
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
defer delete(db.Statement.Clauses, "SET")
db.Statement.AddClause(set)
} else {
return
}
}
db.Statement.Build("UPDATE", "SET", "WHERE")
db.Statement.Build(db.Statement.BuildClauses...)
}
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
db.AddError(gorm.ErrMissingWhereClause)
return
}
checkMissingWhereConditions(db)
if !db.DryRun && db.Error == nil {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if ok, mode := hasReturning(db, supportReturning); ok {
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
dest := db.Statement.Dest
db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface()
gorm.Scan(rows, db, mode)
db.Statement.Dest = dest
db.AddError(rows.Close())
if err == nil {
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
} else {
db.AddError(err)
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
}
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
}
}
}
// AfterUpdate after update hooks
func AfterUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(AfterUpdateInterface); ok {
called = true
db.AddError(i.AfterUpdate(tx))
}
}
if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok {
called = true
@ -96,12 +131,6 @@ func AfterUpdate(db *gorm.DB) {
}
}
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(AfterUpdateInterface); ok {
called = true
db.AddError(i.AfterUpdate(tx))
}
}
return called
})
}
@ -118,13 +147,15 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
case reflect.Slice, reflect.Array:
assignValue = func(field *schema.Field, value interface{}) {
for i := 0; i < stmt.ReflectValue.Len(); i++ {
field.Set(stmt.ReflectValue.Index(i), value)
if stmt.ReflectValue.CanAddr() {
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
}
}
}
case reflect.Struct:
assignValue = func(field *schema.Field, value interface{}) {
if stmt.ReflectValue.CanAddr() {
field.Set(stmt.ReflectValue, value)
field.Set(stmt.Context, stmt.ReflectValue, value)
}
}
default:
@ -140,23 +171,26 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var primaryKeyExprs []clause.Expression
for i := 0; i < stmt.ReflectValue.Len(); i++ {
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
var notZero bool
for idx, field := range stmt.Schema.PrimaryFields {
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
notZero = notZero || !isZero
if size := stmt.ReflectValue.Len(); size > 0 {
var isZero bool
for i := 0; i < size; i++ {
for _, field := range stmt.Schema.PrimaryFields {
_, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
if !isZero {
break
}
}
}
if notZero {
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
if !isZero {
_, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
}
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
@ -209,46 +243,64 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field.AutoUpdateTime == schema.UnixNanosecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
} else if field.AutoUpdateTime == schema.UnixMillisecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
} else if field.GORMDataType == schema.Time {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
} else {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
} else if field.AutoUpdateTime == schema.UnixSecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
} else {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
}
}
}
}
}
default:
updatingSchema := stmt.Schema
var isDiffSchema bool
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
// different schema
updatingStmt := &gorm.Statement{DB: stmt.DB}
if err := updatingStmt.Parse(stmt.Dest); err == nil {
updatingSchema = updatingStmt.Schema
isDiffSchema = true
}
}
switch updatingValue.Kind() {
case reflect.Struct:
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.LookUpField(dbName)
if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
value, isZero := field.ValueOf(updatingValue)
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixNano() / 1e6
} else if field.GORMDataType == schema.Time {
value = stmt.DB.NowFunc()
} else {
value = stmt.DB.NowFunc().Unix()
if field := updatingSchema.LookUpField(dbName); field != nil {
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
value, isZero := field.ValueOf(stmt.Context, updatingValue)
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixMilli()
} else if field.AutoUpdateTime == schema.UnixSecond {
value = stmt.DB.NowFunc().Unix()
} else {
value = stmt.DB.NowFunc()
}
isZero = false
}
isZero = false
}
if ok || !isZero {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
assignValue(field, value)
if (ok || !isZero) && field.Updatable {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
assignField := field
if isDiffSchema {
if originField := stmt.Schema.LookUpField(dbName); originField != nil {
assignField = originField
}
}
assignValue(assignField, value)
}
}
} else {
if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
} else {
if value, isZero := field.ValueOf(updatingValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
}

View File

@ -10,10 +10,11 @@ import (
)
// Model specify the model you would like to run db operations
// // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello")
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
// db.Model(&user).Update("name", "hello")
//
// // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello")
// // if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello`
// db.Model(&user).Update("name", "hello")
func (db *DB) Model(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Model = value
@ -21,6 +22,19 @@ func (db *DB) Model(value interface{}) (tx *DB) {
}
// Clauses Add clauses
//
// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more
// advanced techniques like specifying lock strength and optimizer hints. See the
// [docs] for more depth.
//
// // add a simple limit clause
// db.Clauses(clause.Limit{Limit: 1}).Find(&User{})
// // tell the optimizer to use the `idx_user_name` index
// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{})
// // specify the lock strength to UPDATE
// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users)
//
// [docs]: https://gorm.io/docs/sql_builder.html#Clauses
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
tx = db.getInstance()
var whereConds []interface{}
@ -41,28 +55,42 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
return
}
var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`)
var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`)
// Table specify the table you would like to run db operations
//
// // Get a user
// db.Table("users").Take(&result)
func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 {
tx.Statement.Table = results[1]
return
if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 {
if results[1] != "" {
tx.Statement.Table = results[1]
} else {
tx.Statement.Table = results[2]
}
}
} else if tables := strings.Split(name, "."); len(tables) == 2 {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = tables[1]
return
} else if name != "" {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = name
} else {
tx.Statement.TableExpr = nil
tx.Statement.Table = ""
}
tx.Statement.Table = name
return
}
// Distinct specify distinct fields that you want querying
//
// // Select distinct names of users
// db.Distinct("name").Find(&results)
// // Select distinct name/age pairs from users
// db.Distinct("name", "age").Find(&results)
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Distinct = true
@ -73,6 +101,14 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) {
}
// Select specify fields that you want when querying, creating, updating
//
// Use Select when you only want a subset of the fields. By default, GORM will select all fields.
// Select accepts both string arguments and arrays.
//
// // Select name and age of user using multiple arguments
// db.Select("name", "age").Find(&users)
// // Select name and age of user using an array
// db.Select([]string{"name", "age"}).Find(&users)
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
@ -91,7 +127,11 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
return
}
}
delete(tx.Statement.Clauses, "SELECT")
if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
clause.Expression = nil
tx.Statement.Clauses["SELECT"] = clause
}
case string:
if strings.Count(v, "?") >= len(args) && len(args) > 0 {
tx.Statement.AddClause(clause.Select{
@ -103,7 +143,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
Distinct: db.Statement.Distinct,
Expression: clause.NamedExpr{SQL: v, Vars: args},
})
} else {
} else {
tx.Statement.Selects = []string{v}
for _, arg := range args {
@ -121,7 +161,10 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
}
}
delete(tx.Statement.Clauses, "SELECT")
if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
clause.Expression = nil
tx.Statement.Clauses["SELECT"] = clause
}
}
default:
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
@ -142,7 +185,25 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
return
}
// MapColumns modify the column names in the query results to facilitate align to the corresponding structural fields
func (db *DB) MapColumns(m map[string]string) (tx *DB) {
tx = db.getInstance()
tx.Statement.ColumnMapping = m
return
}
// Where add conditions
//
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
//
// // Find the first user with name jinzhu
// db.Where("name = ?", "jinzhu").First(&user)
// // Find the first user with name jinzhu and age 20
// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
// // Find the first user with name jinzhu and age not equal to 20
// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user)
//
// [docs]: https://gorm.io/docs/query.html#Conditions
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
@ -152,6 +213,11 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
}
// Not add NOT conditions
//
// Not works similarly to where, and has the same syntax.
//
// // Find the first user with name not equal to jinzhu
// db.Not("name = ?", "jinzhu").First(&user)
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
@ -161,6 +227,11 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
}
// Or add OR conditions
//
// Or is used to chain together queries with an OR.
//
// // Find the first user with name equal to jinzhu or john
// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user)
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
@ -170,15 +241,45 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
}
// Joins specify Joins conditions
// db.Joins("Account").Find(&user)
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
//
// db.Joins("Account").Find(&user)
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
return joins(db, clause.LeftJoin, query, args...)
}
// InnerJoins specify inner joins conditions
// db.InnerJoins("Account").Find(&user)
func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) {
return joins(db, clause.InnerJoin, query, args...)
}
func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
if len(args) == 1 {
if db, ok := args[0].(*DB); ok {
j := join{
Name: query, Conds: args, Selects: db.Statement.Selects,
Omits: db.Statement.Omits, JoinType: joinType,
}
if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
j.On = &where
}
tx.Statement.Joins = append(tx.Statement.Joins, j)
return
}
}
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType})
return
}
// Group specify the group method on the find
//
// // Select the sum age of users with given names
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results)
func (db *DB) Group(name string) (tx *DB) {
tx = db.getInstance()
@ -190,6 +291,9 @@ func (db *DB) Group(name string) (tx *DB) {
}
// Having specify HAVING conditions for GROUP BY
//
// // Select the sum age of users with name jinzhu
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result)
func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.GroupBy{
@ -198,35 +302,58 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
return
}
// Order specify order when retrieve records from database
// db.Order("name DESC")
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
// Order specify order when retrieving records from database
//
// db.Order("name DESC")
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
// db.Order(clause.OrderBy{Columns: []clause.OrderByColumn{
// {Column: clause.Column{Name: "name"}, Desc: true},
// {Column: clause.Column{Name: "age"}, Desc: true},
// }})
func (db *DB) Order(value interface{}) (tx *DB) {
tx = db.getInstance()
switch v := value.(type) {
case clause.OrderBy:
tx.Statement.AddClause(v)
case clause.OrderByColumn:
tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{v},
})
default:
tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{{
Column: clause.Column{Name: fmt.Sprint(value), Raw: true},
}},
})
case string:
if v != "" {
tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{{
Column: clause.Column{Name: v, Raw: true},
}},
})
}
}
return
}
// Limit specify the number of records to be retrieved
//
// Limit conditions can be cancelled by using `Limit(-1)`.
//
// // retrieve 3 users
// db.Limit(3).Find(&users)
// // retrieve 3 users into users1, and all users into users2
// db.Limit(3).Find(&users1).Limit(-1).Find(&users2)
func (db *DB) Limit(limit int) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Limit: limit})
tx.Statement.AddClause(clause.Limit{Limit: &limit})
return
}
// Offset specify the number of records to skip before starting to return the records
//
// Offset conditions can be cancelled by using `Offset(-1)`.
//
// // select the third user
// db.Offset(2).First(&user)
// // select the first user by cancelling an earlier chained offset
// db.Offset(5).Offset(-1).First(&user)
func (db *DB) Offset(offset int) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Offset: offset})
@ -234,25 +361,37 @@ func (db *DB) Offset(offset int) (tx *DB) {
}
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000)
// }
//
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
// return func (db *gorm.DB) *gorm.DB {
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
// }
// }
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000)
// }
//
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
// return func (db *gorm.DB) *gorm.DB {
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
// }
// }
//
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
tx = db.getInstance()
tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
return tx
}
func (db *DB) executeScopes() (tx *DB) {
scopes := db.Statement.scopes
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
}
return db
}
// Preload preload associations with given conditions
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
//
// // get all users, and preload all non-cancelled orders
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if tx.Statement.Preloads == nil {
@ -262,18 +401,57 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
return
}
// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Attrs only adds attributes if the record is not found.
//
// // assign an email if the record is not found
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign an email if the record is not found, otherwise ignore provided email
// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.attrs = attrs
return
}
// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that
// records will be updated even if they are found.
//
// // assign an email regardless of if the record is not found
// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign email regardless of if record is found
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.assigns = attrs
return
}
// Unscoped disables the global scope of soft deletion in a query.
// By default, GORM uses soft deletion, marking records as "deleted"
// by setting a timestamp on a specific field (e.g., `deleted_at`).
// Unscoped allows queries to include records marked as deleted,
// overriding the soft deletion behavior.
// Example:
//
// var users []User
// db.Unscoped().Find(&users)
// // Retrieves all users, including deleted ones.
func (db *DB) Unscoped() (tx *DB) {
tx = db.getInstance()
tx.Statement.Unscoped = true

View File

@ -29,10 +29,12 @@ func BenchmarkSelect(b *testing.B) {
func BenchmarkComplexSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
limit10 := 10
for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clauses := []clause.Interface{
clause.Select{}, clause.From{},
clause.Select{},
clause.From{},
clause.Where{Exprs: []clause.Expression{
clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.Gt{Column: "age", Value: 18},
@ -42,7 +44,7 @@ func BenchmarkComplexSelect(b *testing.B) {
clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}),
}},
clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}},
clause.Limit{Limit: 10, Offset: 20},
clause.Limit{Limit: &limit10, Offset: 20},
clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}},
}

View File

@ -20,6 +20,7 @@ type Builder interface {
Writer
WriteQuoted(field interface{})
AddVar(Writer, ...interface{})
AddError(error) error
}
// Clause

View File

@ -67,6 +67,12 @@ func (expr Expr) Build(builder Builder) {
builder.WriteByte(v)
}
}
if idx < len(expr.Vars) {
for _, v := range expr.Vars[idx:] {
builder.AddVar(builder, sql.NamedArg{Value: v})
}
}
}
// NamedExpr raw expression for named expr
@ -120,8 +126,8 @@ func (expr NamedExpr) Build(builder Builder) {
for _, v := range []byte(expr.SQL) {
if v == '@' && !inName {
inName = true
name = []byte{}
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' {
name = name[:0]
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
if inName {
if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv)
@ -173,7 +179,12 @@ func (expr NamedExpr) Build(builder Builder) {
}
if inName {
builder.AddVar(builder, namedMap[string(name)])
if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv)
} else {
builder.WriteByte('@')
builder.WriteString(string(name))
}
}
}
@ -205,11 +216,12 @@ func (in IN) Build(builder Builder) {
}
func (in IN) NegationBuild(builder Builder) {
builder.WriteQuoted(in.Column)
switch len(in.Values) {
case 0:
builder.WriteString(" IS NOT NULL")
case 1:
if _, ok := in.Values[0].([]interface{}); !ok {
builder.WriteQuoted(in.Column)
builder.WriteString(" <> ")
builder.AddVar(builder, in.Values[0])
break
@ -217,7 +229,6 @@ func (in IN) NegationBuild(builder Builder) {
fallthrough
default:
builder.WriteQuoted(in.Column)
builder.WriteString(" NOT IN (")
builder.AddVar(builder, in.Values...)
builder.WriteByte(')')
@ -233,11 +244,28 @@ type Eq struct {
func (eq Eq) Build(builder Builder) {
builder.WriteQuoted(eq.Column)
if eqNil(eq.Value) {
builder.WriteString(" IS NULL")
} else {
builder.WriteString(" = ")
builder.AddVar(builder, eq.Value)
switch eq.Value.(type) {
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
rv := reflect.ValueOf(eq.Value)
if rv.Len() == 0 {
builder.WriteString(" IN (NULL)")
} else {
builder.WriteString(" IN (")
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
builder.WriteByte(')')
}
default:
if eqNil(eq.Value) {
builder.WriteString(" IS NULL")
} else {
builder.WriteString(" = ")
builder.AddVar(builder, eq.Value)
}
}
}
@ -251,11 +279,24 @@ type Neq Eq
func (neq Neq) Build(builder Builder) {
builder.WriteQuoted(neq.Column)
if eqNil(neq.Value) {
builder.WriteString(" IS NOT NULL")
} else {
builder.WriteString(" <> ")
builder.AddVar(builder, neq.Value)
switch neq.Value.(type) {
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
builder.WriteString(" NOT IN (")
rv := reflect.ValueOf(neq.Value)
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
builder.WriteByte(')')
default:
if eqNil(neq.Value) {
builder.WriteString(" IS NOT NULL")
} else {
builder.WriteString(" <> ")
builder.AddVar(builder, neq.Value)
}
}
}
@ -331,7 +372,7 @@ func (like Like) NegationBuild(builder Builder) {
}
func eqNil(value interface{}) bool {
if valuer, ok := value.(driver.Valuer); ok {
if valuer, ok := value.(driver.Valuer); ok && !eqNilReflect(valuer) {
value, _ = valuer.Value()
}

View File

@ -60,6 +60,11 @@ func TestNamedExpr(t *testing.T) {
Vars: []interface{}{sql.Named("name", "jinzhu")},
Result: "name1 = ? AND name2 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "name1 = @name AND name2 = @@name",
Vars: []interface{}{map[string]interface{}{"name": "jinzhu"}},
Result: "name1 = ? AND name2 = @@name",
ExpectedVars: []interface{}{"jinzhu"},
}, {
SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1",
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
@ -73,17 +78,60 @@ func TestNamedExpr(t *testing.T) {
}, {
SQL: "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist",
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? @notexist",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
}, {
SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist",
SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @notexist",
Vars: []interface{}{NamedArgument{Name1: "jinzhu", Base: Base{Name2: "jinzhu2"}}},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? @notexist",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
}, {
SQL: "create table ? (? ?, ? ?)",
Vars: []interface{}{},
Result: "create table ? (? ?, ? ?)",
}, {
SQL: "name1 = @name AND name2 = @name;",
Vars: []interface{}{sql.Named("name", "jinzhu")},
Result: "name1 = ? AND name2 = ?;",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "name1 = @name1\r\n AND name2 = @name2",
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}},
Result: "name1 = ?\r\n AND name2 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "name1 = @name1\r AND name2 = @name2",
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}},
Result: "name1 = ?\r AND name2 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: "col"}},
Result: "`table`.`col`",
}, {
SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: "col", Raw: true}},
Result: "table.col",
}, {
SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: clause.PrimaryKey, Raw: true}},
Result: "table.id",
}, {
SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: "col", Alias: "alias"}},
Result: "`table`.`col` AS `alias`",
}, {
SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: "col", Alias: "alias", Raw: true}},
Result: "table.col AS alias",
}, {
SQL: "?",
Vars: []interface{}{clause.Table{Name: "table", Alias: "alias"}},
Result: "`table` `alias`",
}, {
SQL: "?",
Vars: []interface{}{clause.Table{Name: "table", Alias: "alias", Raw: true}},
Result: "table alias",
}}
for idx, result := range results {
@ -105,13 +153,15 @@ func TestNamedExpr(t *testing.T) {
func TestExpression(t *testing.T) {
column := "column-name"
results := []struct {
Expressions []clause.Expression
Result string
Expressions []clause.Expression
ExpectedVars []interface{}
Result string
}{{
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: "column-value"},
},
Result: "`column-name` = ?",
ExpectedVars: []interface{}{"column-value"},
Result: "`column-name` = ?",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: nil},
@ -126,7 +176,8 @@ func TestExpression(t *testing.T) {
Expressions: []clause.Expression{
clause.Neq{Column: column, Value: "column-value"},
},
Result: "`column-name` <> ?",
ExpectedVars: []interface{}{"column-value"},
Result: "`column-name` <> ?",
}, {
Expressions: []clause.Expression{
clause.Neq{Column: column, Value: nil},
@ -136,6 +187,35 @@ func TestExpression(t *testing.T) {
clause.Neq{Column: column, Value: (interface{})(nil)},
},
Result: "`column-name` IS NOT NULL",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: []string{"a", "b"}},
},
ExpectedVars: []interface{}{"a", "b"},
Result: "`column-name` IN (?,?)",
}, {
Expressions: []clause.Expression{
clause.Neq{Column: column, Value: []string{"a", "b"}},
},
ExpectedVars: []interface{}{"a", "b"},
Result: "`column-name` NOT IN (?,?)",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: []string{}},
},
Result: "`column-name` IN (NULL)",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100},
},
ExpectedVars: []interface{}{100},
Result: "SUM(`id`) = ?",
}, {
Expressions: []clause.Expression{
clause.Gte{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Table: "users", Name: "id"}}}, Value: 100},
},
ExpectedVars: []interface{}{100},
Result: "SUM(`users`.`id`) >= ?",
}}
for idx, result := range results {
@ -147,6 +227,10 @@ func TestExpression(t *testing.T) {
if stmt.SQL.String() != result.Result {
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
}
if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) {
t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars)
}
})
}
}

View File

@ -39,4 +39,10 @@ func (groupBy GroupBy) MergeClause(clause *Clause) {
groupBy.Having = append(copiedHaving, groupBy.Having...)
}
clause.Expression = groupBy
if len(groupBy.Columns) == 0 {
clause.Name = ""
} else {
clause.Name = groupBy.Name()
}
}

View File

@ -18,7 +18,8 @@ func TestGroupBy(t *testing.T) {
Columns: []clause.Column{{Name: "role"}},
Having: []clause.Expression{clause.Eq{"role", "admin"}},
}},
"SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", []interface{}{"admin"},
"SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?",
[]interface{}{"admin"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{
@ -28,7 +29,8 @@ func TestGroupBy(t *testing.T) {
Columns: []clause.Column{{Name: "gender"}},
Having: []clause.Expression{clause.Neq{"gender", "U"}},
}},
"SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", []interface{}{"admin", "U"},
"SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?",
[]interface{}{"admin", "U"},
},
}

View File

@ -1,5 +1,7 @@
package clause
import "gorm.io/gorm/utils"
type JoinType string
const (
@ -9,7 +11,31 @@ const (
RightJoin JoinType = "RIGHT"
)
// Join join clause for from
type JoinTarget struct {
Type JoinType
Association string
Subquery Expression
Table string
}
func Has(name string) JoinTarget {
return JoinTarget{Type: InnerJoin, Association: name}
}
func (jt JoinType) Association(name string) JoinTarget {
return JoinTarget{Type: jt, Association: name}
}
func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
return JoinTarget{Type: jt, Association: name, Subquery: subquery}
}
func (jt JoinTarget) As(name string) JoinTarget {
jt.Table = name
return jt
}
// Join clause for from
type Join struct {
Type JoinType
Table Table
@ -18,6 +44,12 @@ type Join struct {
Expression Expression
}
func JoinTable(names ...string) Table {
return Table{
Name: utils.JoinNestedRelationNames(names),
}
}
func (join Join) Build(builder Builder) {
if join.Expression != nil {
join.Expression.Build(builder)

101
clause/joins_test.go Normal file
View File

@ -0,0 +1,101 @@
package clause_test
import (
"sync"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
func TestJoin(t *testing.T) {
results := []struct {
name string
join clause.Join
sql string
}{
{
name: "LEFT JOIN",
join: clause.Join{
Type: clause.LeftJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "RIGHT JOIN",
join: clause.Join{
Type: clause.RightJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "INNER JOIN",
join: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "CROSS JOIN",
join: clause.Join{
Type: clause.CrossJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "USING",
join: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
Using: []string{"id"},
},
sql: "INNER JOIN `user` USING (`id`)",
},
{
name: "Expression",
join: clause.Join{
// Invalid
Type: clause.LeftJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
// Valid
Expression: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
Using: []string{"id"},
},
},
sql: "INNER JOIN `user` USING (`id`)",
},
}
for _, result := range results {
t.Run(result.name, func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
result.join.Build(stmt)
if result.sql != stmt.SQL.String() {
t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String())
}
})
}
}

View File

@ -1,10 +1,8 @@
package clause
import "strconv"
// Limit limit clause
type Limit struct {
Limit int
Limit *int
Offset int
}
@ -15,16 +13,16 @@ func (limit Limit) Name() string {
// Build build where clause
func (limit Limit) Build(builder Builder) {
if limit.Limit > 0 {
if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteString("LIMIT ")
builder.WriteString(strconv.Itoa(limit.Limit))
builder.AddVar(builder, *limit.Limit)
}
if limit.Offset > 0 {
if limit.Limit > 0 {
builder.WriteString(" ")
if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteByte(' ')
}
builder.WriteString("OFFSET ")
builder.WriteString(strconv.Itoa(limit.Offset))
builder.AddVar(builder, limit.Offset)
}
}
@ -33,7 +31,7 @@ func (limit Limit) MergeClause(clause *Clause) {
clause.Name = ""
if v, ok := clause.Expression.(Limit); ok {
if limit.Limit == 0 && v.Limit != 0 {
if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil {
limit.Limit = v.Limit
}

View File

@ -8,6 +8,10 @@ import (
)
func TestLimit(t *testing.T) {
limit0 := 0
limit10 := 10
limit50 := 50
limitNeg10 := -10
results := []struct {
Clauses []clause.Interface
Result string
@ -15,38 +19,56 @@ func TestLimit(t *testing.T) {
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{
Limit: 10,
Limit: &limit10,
Offset: 20,
}},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
"SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit10, 20},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
"SELECT * FROM `users` LIMIT ?",
[]interface{}{limit0},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}},
"SELECT * FROM `users` LIMIT ?",
[]interface{}{limit0},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
"SELECT * FROM `users` OFFSET 20", nil,
"SELECT * FROM `users` OFFSET ?",
[]interface{}{20},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` OFFSET 30", nil,
"SELECT * FROM `users` OFFSET ?",
[]interface{}{30},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}},
"SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit10, 20},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` LIMIT 10 OFFSET 30", nil,
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit10, 30},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
"SELECT * FROM `users` LIMIT 10", nil,
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
"SELECT * FROM `users` LIMIT ?",
[]interface{}{limit10},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}},
"SELECT * FROM `users` OFFSET 30", nil,
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}},
"SELECT * FROM `users` OFFSET ?",
[]interface{}{30},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}},
"SELECT * FROM `users` LIMIT 50 OFFSET 30", nil,
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}},
"SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit50, 30},
},
}

View File

@ -1,5 +1,12 @@
package clause
const (
LockingStrengthUpdate = "UPDATE"
LockingStrengthShare = "SHARE"
LockingOptionsSkipLocked = "SKIP LOCKED"
LockingOptionsNoWait = "NOWAIT"
)
type Locking struct {
Strength string
Table Table

View File

@ -14,17 +14,21 @@ func TestLocking(t *testing.T) {
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}},
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate}},
"SELECT * FROM `users` FOR UPDATE", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}},
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthShare, Table: clause.Table{Name: clause.CurrentTable}}},
"SELECT * FROM `users` FOR SHARE OF `users`", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}},
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsNoWait}},
"SELECT * FROM `users` FOR UPDATE NOWAIT", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsSkipLocked}},
"SELECT * FROM `users` FOR UPDATE SKIP LOCKED", nil,
},
}
for idx, result := range results {

View File

@ -3,6 +3,7 @@ package clause
type OnConflict struct {
Columns []Column
Where Where
TargetWhere Where
OnConstraint string
DoNothing bool
DoUpdates Set
@ -15,21 +16,27 @@ func (OnConflict) Name() string {
// Build build onConflict clause
func (onConflict OnConflict) Build(builder Builder) {
if len(onConflict.Columns) > 0 {
builder.WriteByte('(')
for idx, column := range onConflict.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
builder.WriteString(`) `)
}
if onConflict.OnConstraint != "" {
builder.WriteString("ON CONSTRAINT ")
builder.WriteString(onConflict.OnConstraint)
builder.WriteByte(' ')
} else {
if len(onConflict.Columns) > 0 {
builder.WriteByte('(')
for idx, column := range onConflict.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
builder.WriteString(`) `)
}
if len(onConflict.TargetWhere.Exprs) > 0 {
builder.WriteString(" WHERE ")
onConflict.TargetWhere.Build(builder)
builder.WriteByte(' ')
}
}
if onConflict.DoNothing {
@ -40,7 +47,7 @@ func (onConflict OnConflict) Build(builder Builder) {
}
if len(onConflict.Where.Exprs) > 0 {
builder.WriteString("WHERE ")
builder.WriteString(" WHERE ")
onConflict.Where.Build(builder)
builder.WriteByte(' ')
}

View File

@ -45,7 +45,8 @@ func TestOrderBy(t *testing.T) {
Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true},
},
},
"SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)", []interface{}{1, 2, 3},
"SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)",
[]interface{}{1, 2, 3},
},
}

View File

@ -11,20 +11,27 @@ func (returning Returning) Name() string {
// Build build where clause
func (returning Returning) Build(builder Builder) {
for idx, column := range returning.Columns {
if idx > 0 {
builder.WriteByte(',')
}
if len(returning.Columns) > 0 {
for idx, column := range returning.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
builder.WriteQuoted(column)
}
} else {
builder.WriteByte('*')
}
}
// MergeClause merge order by clauses
func (returning Returning) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(Returning); ok {
returning.Columns = append(v.Columns, returning.Columns...)
if v, ok := clause.Expression.(Returning); ok && len(returning.Columns) > 0 {
if v.Columns != nil {
returning.Columns = append(v.Columns, returning.Columns...)
} else {
returning.Columns = nil
}
}
clause.Expression = returning
}

View File

@ -26,6 +26,22 @@ func TestReturning(t *testing.T) {
}},
"SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
[]clause.Column{clause.PrimaryColumn},
}, clause.Returning{}, clause.Returning{
[]clause.Column{{Name: "name"}, {Name: "age"}},
}},
"SELECT * FROM `users` RETURNING *", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
[]clause.Column{clause.PrimaryColumn},
}, clause.Returning{
[]clause.Column{{Name: "name"}, {Name: "age"}},
}, clause.Returning{}},
"SELECT * FROM `users` RETURNING *", nil,
},
}
for idx, result := range results {

View File

@ -43,3 +43,17 @@ func (s Select) MergeClause(clause *Clause) {
clause.Expression = s
}
}
// CommaExpression represents a group of expressions separated by commas.
type CommaExpression struct {
Exprs []Expression
}
func (comma CommaExpression) Build(builder Builder) {
for idx, expr := range comma.Exprs {
if idx > 0 {
_, _ = builder.WriteString(", ")
}
expr.Build(builder)
}
}

View File

@ -31,6 +31,37 @@ func TestSelect(t *testing.T) {
}, clause.From{}},
"SELECT `name` FROM `users`", nil,
},
{
[]clause.Interface{clause.Select{
Expression: clause.CommaExpression{
Exprs: []clause.Expression{
clause.NamedExpr{"?", []interface{}{clause.Column{Name: "id"}}},
clause.NamedExpr{"?", []interface{}{clause.Column{Name: "name"}}},
clause.NamedExpr{"LENGTH(?)", []interface{}{clause.Column{Name: "mobile"}}},
},
},
}, clause.From{}},
"SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil,
},
{
[]clause.Interface{clause.Select{
Expression: clause.CommaExpression{
Exprs: []clause.Expression{
clause.Expr{
SQL: "? as name",
Vars: []interface{}{
clause.Eq{
Column: clause.Column{Name: "age"},
Value: 18,
},
},
},
},
},
}, clause.From{}},
"SELECT `age` = ? as name FROM `users`",
[]interface{}{18},
},
}
for idx, result := range results {

View File

@ -24,9 +24,9 @@ func (set Set) Build(builder Builder) {
builder.AddVar(builder, assignment.Value)
}
} else {
builder.WriteQuoted(PrimaryColumn)
builder.WriteQuoted(Column{Name: PrimaryKey})
builder.WriteByte('=')
builder.WriteQuoted(PrimaryColumn)
builder.WriteQuoted(Column{Name: PrimaryKey})
}
}

View File

@ -20,7 +20,8 @@ func TestSet(t *testing.T) {
clause.Update{},
clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}),
},
"UPDATE `users` SET `users`.`id`=?", []interface{}{1},
"UPDATE `users` SET `users`.`id`=?",
[]interface{}{1},
},
{
[]clause.Interface{
@ -28,7 +29,8 @@ func TestSet(t *testing.T) {
clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}),
clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}),
},
"UPDATE `users` SET `name`=?", []interface{}{"jinzhu"},
"UPDATE `users` SET `name`=?",
[]interface{}{"jinzhu"},
},
}

View File

@ -21,7 +21,8 @@ func TestValues(t *testing.T) {
Values: [][]interface{}{{"jinzhu", 18}, {"josh", 1}},
},
},
"INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)", []interface{}{"jinzhu", 18, "josh", 1},
"INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)",
[]interface{}{"jinzhu", 18, "josh", 1},
},
}

View File

@ -4,6 +4,11 @@ import (
"strings"
)
const (
AndWithSpace = " AND "
OrWithSpace = " OR "
)
// Where where clause
type Where struct {
Exprs []Expression
@ -16,6 +21,12 @@ func (where Where) Name() string {
// Build build where clause
func (where Where) Build(builder Builder) {
if len(where.Exprs) == 1 {
if andCondition, ok := where.Exprs[0].(AndConditions); ok {
where.Exprs = andCondition.Exprs
}
}
// Switch position if the first query expression is a single Or condition
for idx, expr := range where.Exprs {
if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
@ -26,7 +37,7 @@ func (where Where) Build(builder Builder) {
}
}
buildExprs(where.Exprs, builder, " AND ")
buildExprs(where.Exprs, builder, AndWithSpace)
}
func buildExprs(exprs []Expression, builder Builder, joinCond string) {
@ -35,7 +46,7 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) {
for idx, expr := range exprs {
if idx > 0 {
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
builder.WriteString(" OR ")
builder.WriteString(OrWithSpace)
} else {
builder.WriteString(joinCond)
}
@ -46,27 +57,30 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) {
case OrConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToLower(e.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
sql := strings.ToUpper(e.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
}
}
case AndConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToLower(e.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
sql := strings.ToUpper(e.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
}
}
case Expr:
sql := strings.ToLower(v.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
sql := strings.ToUpper(v.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
case NamedExpr:
sql := strings.ToUpper(v.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
}
}
if wrapInParentheses {
builder.WriteString(`(`)
builder.WriteByte('(')
expr.Build(builder)
builder.WriteString(`)`)
builder.WriteByte(')')
wrapInParentheses = false
} else {
expr.Build(builder)
@ -89,9 +103,14 @@ func (where Where) MergeClause(clause *Clause) {
func And(exprs ...Expression) Expression {
if len(exprs) == 0 {
return nil
} else if len(exprs) == 1 {
return exprs[0]
}
if len(exprs) == 1 {
if _, ok := exprs[0].(OrConditions); !ok {
return exprs[0]
}
}
return AndConditions{Exprs: exprs}
}
@ -102,10 +121,10 @@ type AndConditions struct {
func (and AndConditions) Build(builder Builder) {
if len(and.Exprs) > 1 {
builder.WriteByte('(')
buildExprs(and.Exprs, builder, " AND ")
buildExprs(and.Exprs, builder, AndWithSpace)
builder.WriteByte(')')
} else {
buildExprs(and.Exprs, builder, " AND ")
buildExprs(and.Exprs, builder, AndWithSpace)
}
}
@ -123,10 +142,10 @@ type OrConditions struct {
func (or OrConditions) Build(builder Builder) {
if len(or.Exprs) > 1 {
builder.WriteByte('(')
buildExprs(or.Exprs, builder, " OR ")
buildExprs(or.Exprs, builder, OrWithSpace)
builder.WriteByte(')')
} else {
buildExprs(or.Exprs, builder, " OR ")
buildExprs(or.Exprs, builder, OrWithSpace)
}
}
@ -134,6 +153,11 @@ func Not(exprs ...Expression) Expression {
if len(exprs) == 0 {
return nil
}
if len(exprs) == 1 {
if andCondition, ok := exprs[0].(AndConditions); ok {
exprs = andCondition.Exprs
}
}
return NotConditions{Exprs: exprs}
}
@ -142,23 +166,67 @@ type NotConditions struct {
}
func (not NotConditions) Build(builder Builder) {
if len(not.Exprs) > 1 {
builder.WriteByte('(')
anyNegationBuilder := false
for _, c := range not.Exprs {
if _, ok := c.(NegationExpressionBuilder); ok {
anyNegationBuilder = true
break
}
}
for idx, c := range not.Exprs {
if idx > 0 {
builder.WriteString(" AND ")
if anyNegationBuilder {
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder)
} else {
builder.WriteString("NOT ")
for idx, c := range not.Exprs {
if idx > 0 {
builder.WriteString(AndWithSpace)
}
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder)
} else {
builder.WriteString("NOT ")
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToUpper(e.SQL)
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
builder.WriteByte('(')
}
}
c.Build(builder)
if wrapInParentheses {
builder.WriteByte(')')
}
}
}
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
} else {
builder.WriteString("NOT ")
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}
for idx, c := range not.Exprs {
if idx > 0 {
switch c.(type) {
case OrConditions:
builder.WriteString(OrWithSpace)
default:
builder.WriteString(AndWithSpace)
}
}
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToLower(e.SQL)
if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses {
sql := strings.ToUpper(e.SQL)
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
builder.WriteByte('(')
}
}
@ -169,9 +237,9 @@ func (not NotConditions) Build(builder Builder) {
builder.WriteByte(')')
}
}
}
if len(not.Exprs) > 1 {
builder.WriteByte(')')
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
}
}

View File

@ -17,25 +17,29 @@ func TestWhere(t *testing.T) {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?", []interface{}{"1", 18, "jinzhu"},
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?",
[]interface{}{"1", 18, "jinzhu"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?",
[]interface{}{"1", "jinzhu", 18},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?",
[]interface{}{"1", "jinzhu", 18},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?", []interface{}{"1", "jinzhu"},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?",
[]interface{}{"1", "jinzhu"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
@ -43,7 +47,8 @@ func TestWhere(t *testing.T) {
}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"})},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"},
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)",
[]interface{}{"1", 18, "jinzhu", 100, "%linus%"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
@ -51,13 +56,78 @@ func TestWhere(t *testing.T) {
}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Not(clause.Gt{Column: "score", Value: 100}), clause.Like{Column: "name", Value: "%linus%"})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)",
[]interface{}{"1", 18, "jinzhu", 100, "%linus%"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))},
}},
"SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", []interface{}{18, "jinzhu"},
"SELECT * FROM `users` WHERE `age` = ? OR `name` <> ?",
[]interface{}{18, "jinzhu"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{
clause.And(clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}),
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
},
}},
"SELECT * FROM `users` WHERE `users`.`id` <> ? AND `score` <= ?",
[]interface{}{"1", 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}))},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
[]interface{}{"1", 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}},
clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})},
}},
"SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)",
[]interface{}{100, 60},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{
clause.Not(clause.AndConditions{
Exprs: []clause.Expression{
clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.Gt{Column: "age", Value: 18},
}}, clause.OrConditions{
Exprs: []clause.Expression{
clause.Lt{Column: "score", Value: 100},
},
}),
}}},
"SELECT * FROM `users` WHERE NOT ((`users`.`id` = ? AND `age` > ?) OR `score` < ?)",
[]interface{}{"1", 18, 100},
},
}

View File

@ -1,4 +1,3 @@
package clause
type With struct {
}
type With struct{}

View File

@ -2,13 +2,15 @@ package gorm
import (
"errors"
"gorm.io/gorm/logger"
)
var (
// ErrRecordNotFound record not found error
ErrRecordNotFound = errors.New("record not found")
ErrRecordNotFound = logger.ErrRecordNotFound
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
ErrInvalidTransaction = errors.New("no valid transaction")
ErrInvalidTransaction = errors.New("invalid transaction")
// ErrNotImplemented not implemented
ErrNotImplemented = errors.New("not implemented")
// ErrMissingWhereClause missing where clause
@ -19,6 +21,10 @@ var (
ErrPrimaryKeyRequired = errors.New("primary key required")
// ErrModelValueRequired model value required
ErrModelValueRequired = errors.New("model value required")
// ErrModelAccessibleFieldsRequired model accessible fields required
ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required")
// ErrSubQueryRequired sub query required
ErrSubQueryRequired = errors.New("sub query required")
// ErrInvalidData unsupported data
ErrInvalidData = errors.New("unsupported data")
// ErrUnsupportedDriver unsupported driver
@ -31,10 +37,18 @@ var (
ErrEmptySlice = errors.New("empty slice found")
// ErrDryRunModeUnsupported dry run mode unsupported
ErrDryRunModeUnsupported = errors.New("dry run mode unsupported")
// ErrInvaildDB invalid db
ErrInvaildDB = errors.New("invalid db")
// ErrInvalidDB invalid db
ErrInvalidDB = errors.New("invalid db")
// ErrInvalidValue invalid value
ErrInvalidValue = errors.New("invalid value")
ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice")
// ErrInvalidValueOfLength invalid values do not match length
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
// ErrPreloadNotAllowed preload is not allowed when count is used
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
// ErrDuplicatedKey occurs when there is a unique key constraint violation
ErrDuplicatedKey = errors.New("duplicated key not allowed")
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
// ErrCheckConstraintViolated occurs when there is a check constraint violation
ErrCheckConstraintViolated = errors.New("violates check constraint")
)

View File

@ -1,9 +1,11 @@
package gorm
import (
"context"
"database/sql"
"errors"
"fmt"
"hash/maphash"
"reflect"
"strings"
@ -13,7 +15,7 @@ import (
"gorm.io/gorm/utils"
)
// Create insert the value into database
// Create inserts value, returning the inserted data's primary key in value's id
func (db *DB) Create(value interface{}) (tx *DB) {
if db.CreateBatchSize > 0 {
return db.CreateInBatches(value, db.CreateBatchSize)
@ -21,11 +23,10 @@ func (db *DB) Create(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = value
tx.callbacks.Create().Execute(tx)
return
return tx.callbacks.Create().Execute(tx)
}
// CreateInBatches insert the value in batches into database
// CreateInBatches inserts value in batches of batchSize
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
@ -34,9 +35,10 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
var rowsAffected int64
tx = db.getInstance()
// the reflection length judgment of the optimized value
reflectLen := reflectValue.Len()
callFc := func(tx *DB) error {
// the reflection length judgment of the optimized value
reflectLen := reflectValue.Len()
for i := 0; i < reflectLen; i += batchSize {
ends := i + batchSize
if ends > reflectLen {
@ -54,7 +56,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
return nil
}
if tx.SkipDefaultTransaction {
if tx.SkipDefaultTransaction || reflectLen <= batchSize {
tx.AddError(callFc(tx.Session(&Session{})))
} else {
tx.AddError(tx.Transaction(callFc))
@ -64,29 +66,32 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
default:
tx = db.getInstance()
tx.Statement.Dest = value
tx.callbacks.Create().Execute(tx)
tx = tx.callbacks.Create().Execute(tx)
}
return
}
// Save update value in database, if the value doesn't have primary key, will insert it
// Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = value
reflectValue := reflect.Indirect(reflect.ValueOf(value))
for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface {
reflectValue = reflect.Indirect(reflectValue)
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
}
tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true))
tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true))
case reflect.Struct:
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
for _, pf := range tx.Statement.Schema.PrimaryFields {
if _, isZero := pf.ValueOf(reflectValue); isZero {
tx.callbacks.Create().Execute(tx)
return
if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero {
return tx.callbacks.Create().Execute(tx)
}
}
}
@ -99,20 +104,19 @@ func (db *DB) Save(value interface{}) (tx *DB) {
tx.Statement.Selects = append(tx.Statement.Selects, "*")
}
tx.callbacks.Update().Execute(tx)
updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true}))
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) {
return tx.Create(value)
}
if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate {
return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
}
return updateTx
}
return
}
// First find first record that match given conditions, order by primary key
// First finds the first record ordered by primary key, matching given conditions conds
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@ -124,11 +128,10 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
}
tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
return
return tx.callbacks.Query().Execute(tx)
}
// Take return a record that match given conditions, the order will depend on the database implementation
// Take finds the first record returned by the database in no specified order, matching given conditions conds
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1)
if len(conds) > 0 {
@ -138,11 +141,10 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
}
tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
return
return tx.callbacks.Query().Execute(tx)
}
// Last find last record that match given conditions, order by primary key
// Last finds the last record ordered by primary key, matching given conditions conds
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@ -155,11 +157,10 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
}
tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
return
return tx.callbacks.Query().Execute(tx)
}
// Find find records that match given conditions
// Find finds all records matching given conditions conds
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance()
if len(conds) > 0 {
@ -168,11 +169,10 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
}
}
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
return
return tx.callbacks.Query().Execute(tx)
}
// FindInBatches find records in batches
// FindInBatches finds all records in batches of batchSize
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
var (
tx = db.Order(clause.OrderByColumn{
@ -183,34 +183,69 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
batch int
)
// user specified offset or limit
var totalSize int
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
if limit, ok := c.Expression.(clause.Limit); ok {
if limit.Limit != nil {
totalSize = *limit.Limit
}
if totalSize > 0 && batchSize > totalSize {
batchSize = totalSize
}
// reset to offset to 0 in next batch
tx = tx.Offset(-1).Session(&Session{})
}
}
for {
result := queryDB.Limit(batchSize).Find(dest)
rowsAffected += result.RowsAffected
batch++
if result.Error == nil && result.RowsAffected != 0 {
tx.AddError(fc(result, batch))
fcTx := result.Session(&Session{NewDB: true})
fcTx.RowsAffected = result.RowsAffected
tx.AddError(fc(fcTx, batch))
} else if result.Error != nil {
tx.AddError(result.Error)
}
if tx.Error != nil || int(result.RowsAffected) < batchSize {
break
} else {
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
if result.Statement.Schema.PrioritizedPrimaryField == nil {
tx.AddError(ErrPrimaryKeyRequired)
}
if totalSize > 0 {
if totalSize <= int(rowsAffected) {
break
} else {
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
}
if totalSize/batchSize == batch {
batchSize = totalSize % batchSize
}
}
// Optimize for-break
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
if result.Statement.Schema.PrioritizedPrimaryField == nil {
tx.AddError(ErrPrimaryKeyRequired)
break
}
primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
if zero {
tx.AddError(ErrPrimaryKeyRequired)
break
}
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
}
tx.RowsAffected = rowsAffected
return tx
}
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
func (db *DB) assignInterfacesToValue(values ...interface{}) {
for _, value := range values {
switch v := value.(type) {
case []clause.Expression:
@ -218,40 +253,40 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
if field := tx.Statement.Schema.LookUpField(column); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
if field := db.Statement.Schema.LookUpField(column); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
}
case clause.Column:
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
if field := db.Statement.Schema.LookUpField(column.Name); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
}
}
} else if andCond, ok := expr.(clause.AndConditions); ok {
tx.assignInterfacesToValue(andCond.Exprs)
db.assignInterfacesToValue(andCond.Exprs)
}
}
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
if exprs := tx.Statement.BuildCondition(value); len(exprs) > 0 {
tx.assignInterfacesToValue(exprs)
if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 {
db.assignInterfacesToValue(exprs)
}
default:
if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:
for _, f := range s.Fields {
if f.Readable {
if v, isZero := f.ValueOf(reflectValue); !isZero {
if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, v))
if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero {
if field := db.Statement.Schema.LookUpField(f.Name); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v))
}
}
}
}
}
} else if len(values) > 0 {
if exprs := tx.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
tx.assignInterfacesToValue(exprs)
if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
db.assignInterfacesToValue(exprs)
}
return
}
@ -259,12 +294,24 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
}
}
// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
// Each conds must be a struct or map.
//
// FirstOrInit never modifies the database. It is often used with Assign and Attrs.
//
// // assign an email if the record is not found
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign email regardless of if record is found
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 {
if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs)
@ -284,40 +331,64 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
return
}
// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
// Each conds must be a struct or map.
//
// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists.
//
// // assign an email if the record is not found
// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
// // result.RowsAffected -> 1
//
// // assign email regardless of if record is found
// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
// // result.RowsAffected -> 1
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{
tx = db.getInstance()
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
result := queryTx.Find(dest, conds...)
if result.Error != nil {
tx.Error = result.Error
return tx
}
if result.RowsAffected == 0 {
if c, ok := result.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs)
result.assignInterfacesToValue(where.Exprs)
}
}
// initialize with attrs, conds
if len(tx.Statement.attrs) > 0 {
tx.assignInterfacesToValue(tx.Statement.attrs...)
if len(db.Statement.attrs) > 0 {
result.assignInterfacesToValue(db.Statement.attrs...)
}
// initialize with attrs, conds
if len(tx.Statement.assigns) > 0 {
tx.assignInterfacesToValue(tx.Statement.assigns...)
if len(db.Statement.assigns) > 0 {
result.assignInterfacesToValue(db.Statement.assigns...)
}
return tx.Create(dest)
} else if len(db.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
assigns := map[string]interface{}{}
for _, expr := range exprs {
if eq, ok := expr.(clause.Eq); ok {
for i := 0; i < len(exprs); i++ {
expr := exprs[i]
if eq, ok := expr.(clause.AndConditions); ok {
exprs = append(exprs, eq.Exprs...)
} else if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
assigns[column] = eq.Value
case clause.Column:
assigns[column.Name] = eq.Value
default:
}
}
}
@ -325,42 +396,40 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.Model(dest).Updates(assigns)
}
return db
return tx
}
// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
// Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Update(column string, value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value}
tx.callbacks.Update().Execute(tx)
return
return tx.callbacks.Update().Execute(tx)
}
// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Updates(values interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = values
tx.callbacks.Update().Execute(tx)
return
return tx.callbacks.Update().Execute(tx)
}
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value}
tx.Statement.SkipHooks = true
tx.callbacks.Update().Execute(tx)
return
return tx.callbacks.Update().Execute(tx)
}
func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = values
tx.Statement.SkipHooks = true
tx.callbacks.Update().Execute(tx)
return
return tx.callbacks.Update().Execute(tx)
}
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
// time if null.
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance()
if len(conds) > 0 {
@ -369,8 +438,7 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
}
}
tx.Statement.Dest = value
tx.callbacks.Delete().Execute(tx)
return
return tx.callbacks.Delete().Execute(tx)
}
func (db *DB) Count(count *int64) (tx *DB) {
@ -384,21 +452,21 @@ func (db *DB) Count(count *int64) (tx *DB) {
if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
defer func() {
db.Statement.Clauses["SELECT"] = selectClause
tx.Statement.Clauses["SELECT"] = selectClause
}()
} else {
defer delete(tx.Statement.Clauses, "SELECT")
}
if len(tx.Statement.Selects) == 0 {
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}})
} else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") {
expr := clause.Expr{SQL: "count(1)"}
expr := clause.Expr{SQL: "count(*)"}
if len(tx.Statement.Selects) == 1 {
dbName := tx.Statement.Selects[0]
fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) {
if tx.Statement.Parse(tx.Statement.Model) == nil {
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
dbName = f.DBName
@ -407,7 +475,7 @@ func (db *DB) Count(count *int64) (tx *DB) {
if tx.Statement.Distinct {
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
} else {
} else if dbName != "*" {
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
}
}
@ -418,24 +486,26 @@ func (db *DB) Count(count *int64) (tx *DB) {
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
delete(db.Statement.Clauses, "ORDER BY")
delete(tx.Statement.Clauses, "ORDER BY")
defer func() {
db.Statement.Clauses["ORDER BY"] = orderByClause
tx.Statement.Clauses["ORDER BY"] = orderByClause
}()
}
}
tx.Statement.Dest = count
tx.callbacks.Query().Execute(tx)
if tx.RowsAffected != 1 {
tx = tx.callbacks.Query().Execute(tx)
if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 {
*count = tx.RowsAffected
}
return
}
func (db *DB) Row() *sql.Row {
tx := db.getInstance().InstanceSet("rows", false)
tx.callbacks.Row().Execute(tx)
tx := db.getInstance().Set("rows", false)
tx = tx.callbacks.Row().Execute(tx)
row, ok := tx.Statement.Dest.(*sql.Row)
if !ok && tx.DryRun {
db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
@ -444,8 +514,8 @@ func (db *DB) Row() *sql.Row {
}
func (db *DB) Rows() (*sql.Rows, error) {
tx := db.getInstance().InstanceSet("rows", true)
tx.callbacks.Row().Execute(tx)
tx := db.getInstance().Set("rows", true)
tx = tx.callbacks.Row().Execute(tx)
rows, ok := tx.Statement.Dest.(*sql.Rows)
if !ok && tx.DryRun && tx.Error == nil {
tx.Error = ErrDryRunModeUnsupported
@ -453,7 +523,7 @@ func (db *DB) Rows() (*sql.Rows, error) {
return rows, tx.Error
}
// Scan scan value to a struct
// Scan scans selected value to the struct dest
func (db *DB) Scan(dest interface{}) (tx *DB) {
config := *db.Config
currentLogger, newLogger := config.Logger, logger.Recorder.New()
@ -462,15 +532,14 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
tx = db.getInstance()
tx.Config = &config
if rows, err := tx.Rows(); err != nil {
tx.AddError(err)
} else {
defer rows.Close()
if rows, err := tx.Rows(); err == nil {
if rows.Next() {
tx.ScanRows(rows, dest)
} else {
tx.RowsAffected = 0
tx.AddError(rows.Err())
}
tx.AddError(rows.Close())
}
currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
@ -480,9 +549,10 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
return
}
// Pluck used to query single column from a model as a map
// var ages []int64
// db.Find(&users).Pluck("age", &ages)
// Pluck queries a single column from a model, returning in the slice dest. E.g.:
//
// var ages []int64
// db.Model(&users).Pluck("age", &ages)
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
tx = db.getInstance()
if tx.Statement.Model != nil {
@ -491,8 +561,6 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
column = f.DBName
}
}
} else if tx.Statement.Table == "" {
tx.AddError(ErrModelValueRequired)
}
if len(tx.Statement.Selects) != 1 {
@ -503,8 +571,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
})
}
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
return
return tx.callbacks.Query().Execute(tx)
}
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
@ -515,33 +582,67 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
tx.Statement.Dest = dest
tx.Statement.ReflectValue = reflect.ValueOf(dest)
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem()
elem := tx.Statement.ReflectValue.Elem()
if !elem.IsValid() {
elem = reflect.New(tx.Statement.ReflectValue.Type().Elem())
tx.Statement.ReflectValue.Set(elem)
}
tx.Statement.ReflectValue = elem
}
Scan(rows, tx, true)
Scan(rows, tx, ScanInitialized)
return tx.Error
}
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is
// returned to the connection pool.
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
if db.Error != nil {
return db.Error
}
tx := db.getInstance()
sqlDB, err := tx.DB()
if err != nil {
return
}
conn, err := sqlDB.Conn(tx.Statement.Context)
if err != nil {
return
}
defer conn.Close()
tx.Statement.ConnPool = conn
return fc(tx)
}
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
// they are rolled back.
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction
if !db.DisableNestedTransaction {
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
spID := new(maphash.Hash).Sum64()
err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error
if err != nil {
return
}
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
db.RollbackTo(fmt.Sprintf("sp%p", fc))
db.RollbackTo(fmt.Sprintf("sp%d", spID))
}
}()
}
if err == nil {
err = fc(db.Session(&Session{}))
}
err = fc(db.Session(&Session{NewDB: db.clone == 1}))
} else {
tx := db.Begin(opts...)
if tx.Error != nil {
return tx.Error
}
defer func() {
// Make sure to rollback when panic, Block error or Commit error
@ -550,12 +651,9 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
}
}()
if err = tx.Error; err == nil {
err = fc(tx)
}
if err == nil {
err = tx.Commit().Error
if err = fc(tx); err == nil {
panicked = false
return tx.Commit().Error
}
}
@ -563,11 +661,11 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
return
}
// Begin begins a transaction
// Begin begins a transaction with any transaction options opts
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
var (
// clone statement
tx = db.getInstance().Session(&Session{Context: db.Statement.Context})
tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})
opt *sql.TxOptions
err error
)
@ -576,11 +674,19 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
opt = opts[0]
}
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else {
ctx := tx.Statement.Context
if _, ok := ctx.Deadline(); !ok {
if db.Config.DefaultTransactionTimeout > 0 {
ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout)
}
}
switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
default:
err = ErrInvalidTransaction
}
@ -591,7 +697,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
return tx
}
// Commit commit a transaction
// Commit commits the changes in a transaction
func (db *DB) Commit() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Commit())
@ -601,7 +707,7 @@ func (db *DB) Commit() *DB {
return db
}
// Rollback rollback a transaction
// Rollback rollbacks the changes in a transaction
func (db *DB) Rollback() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
if !reflect.ValueOf(committer).IsNil() {
@ -615,7 +721,21 @@ func (db *DB) Rollback() *DB {
func (db *DB) SavePoint(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
// close prepared statement, because SavePoint not support prepared statement.
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
var (
preparedStmtTx *PreparedStmtTX
isPreparedStmtTx bool
)
// close prepared statement, because SavePoint not support prepared statement.
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx.Tx
}
db.AddError(savePointer.SavePoint(db, name))
// restore prepared statement
if isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx
}
} else {
db.AddError(ErrUnsupportedDriver)
}
@ -624,14 +744,28 @@ func (db *DB) SavePoint(name string) *DB {
func (db *DB) RollbackTo(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
// close prepared statement, because RollbackTo not support prepared statement.
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
var (
preparedStmtTx *PreparedStmtTX
isPreparedStmtTx bool
)
// close prepared statement, because SavePoint not support prepared statement.
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx.Tx
}
db.AddError(savePointer.RollbackTo(db, name))
// restore prepared statement
if isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx
}
} else {
db.AddError(ErrUnsupportedDriver)
}
return db
}
// Exec execute raw sql
// Exec executes raw sql
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.SQL = strings.Builder{}
@ -642,6 +776,5 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
}
tx.callbacks.Raw().Execute(tx)
return
return tx.callbacks.Raw().Execute(tx)
}

605
generics.go Normal file
View File

@ -0,0 +1,605 @@
package gorm
import (
"context"
"database/sql"
"fmt"
"sort"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
)
type result struct {
Result sql.Result
RowsAffected int64
}
func (info *result) ModifyStatement(stmt *Statement) {
stmt.Result = info
}
// Build implements clause.Expression interface
func (result) Build(clause.Builder) {
}
func WithResult() *result {
return &result{}
}
type Interface[T any] interface {
Raw(sql string, values ...interface{}) ExecInterface[T]
Exec(ctx context.Context, sql string, values ...interface{}) error
CreateInterface[T]
}
type CreateInterface[T any] interface {
ChainInterface[T]
Table(name string, args ...interface{}) CreateInterface[T]
Create(ctx context.Context, r *T) error
CreateInBatches(ctx context.Context, r *[]T, batchSize int) error
}
type ChainInterface[T any] interface {
ExecInterface[T]
Scopes(scopes ...func(db *Statement)) ChainInterface[T]
Where(query interface{}, args ...interface{}) ChainInterface[T]
Not(query interface{}, args ...interface{}) ChainInterface[T]
Or(query interface{}, args ...interface{}) ChainInterface[T]
Limit(offset int) ChainInterface[T]
Offset(offset int) ChainInterface[T]
Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T]
Select(query string, args ...interface{}) ChainInterface[T]
Omit(columns ...string) ChainInterface[T]
MapColumns(m map[string]string) ChainInterface[T]
Distinct(args ...interface{}) ChainInterface[T]
Group(name string) ChainInterface[T]
Having(query interface{}, args ...interface{}) ChainInterface[T]
Order(value interface{}) ChainInterface[T]
Build(builder clause.Builder)
Delete(ctx context.Context) (rowsAffected int, err error)
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
Updates(ctx context.Context, t T) (rowsAffected int, err error)
Count(ctx context.Context, column string) (result int64, err error)
}
type ExecInterface[T any] interface {
Scan(ctx context.Context, r interface{}) error
First(context.Context) (T, error)
Last(ctx context.Context) (T, error)
Take(context.Context) (T, error)
Find(ctx context.Context) ([]T, error)
FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error
Row(ctx context.Context) *sql.Row
Rows(ctx context.Context) (*sql.Rows, error)
}
type JoinBuilder interface {
Select(...string) JoinBuilder
Omit(...string) JoinBuilder
Where(query interface{}, args ...interface{}) JoinBuilder
Not(query interface{}, args ...interface{}) JoinBuilder
Or(query interface{}, args ...interface{}) JoinBuilder
}
type PreloadBuilder interface {
Select(...string) PreloadBuilder
Omit(...string) PreloadBuilder
Where(query interface{}, args ...interface{}) PreloadBuilder
Not(query interface{}, args ...interface{}) PreloadBuilder
Or(query interface{}, args ...interface{}) PreloadBuilder
Limit(offset int) PreloadBuilder
Offset(offset int) PreloadBuilder
Order(value interface{}) PreloadBuilder
LimitPerRecord(num int) PreloadBuilder
}
type op func(*DB) *DB
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
v := &g[T]{
db: db,
ops: make([]op, 0, 5),
}
if len(opts) > 0 {
v.ops = append(v.ops, func(db *DB) *DB {
return db.Clauses(opts...)
})
}
v.createG = &createG[T]{
chainG: chainG[T]{
execG: execG[T]{g: v},
},
}
return v
}
type g[T any] struct {
*createG[T]
db *DB
ops []op
}
func (g *g[T]) apply(ctx context.Context) *DB {
db := g.db
if !db.DryRun {
db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
}
for _, op := range g.ops {
db = op(db)
}
return db
}
func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] {
return execG[T]{g: &g[T]{
db: c.db,
ops: append(c.ops, func(db *DB) *DB {
return db.Raw(sql, values...)
}),
}}
}
func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error {
return c.apply(ctx).Exec(sql, values...).Error
}
type createG[T any] struct {
chainG[T]
}
func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] {
return createG[T]{c.with(func(db *DB) *DB {
return db.Table(name, args...)
})}
}
func (c createG[T]) Create(ctx context.Context, r *T) error {
return c.g.apply(ctx).Create(r).Error
}
func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error {
return c.g.apply(ctx).CreateInBatches(r, batchSize).Error
}
type chainG[T any] struct {
execG[T]
}
func (c chainG[T]) getInstance() *DB {
var r T
return c.g.apply(context.Background()).Model(r).getInstance()
}
func (c chainG[T]) with(v op) chainG[T] {
return chainG[T]{
execG: execG[T]{g: &g[T]{
db: c.g.db,
ops: append(append([]op(nil), c.g.ops...), v),
}},
}
}
func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] {
return c.with(func(db *DB) *DB {
for _, fc := range scopes {
fc(db.Statement)
}
return db
})
}
func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Table(name, args...)
})
}
func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Where(query, args...)
})
}
func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Not(query, args...)
})
}
func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Or(query, args...)
})
}
func (c chainG[T]) Limit(offset int) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Limit(offset)
})
}
func (c chainG[T]) Offset(offset int) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Offset(offset)
})
}
type joinBuilder struct {
db *DB
}
func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Select(columns ...string) JoinBuilder {
q.db.Select(columns)
return q
}
func (q *joinBuilder) Omit(columns ...string) JoinBuilder {
q.db.Omit(columns...)
return q
}
type preloadBuilder struct {
limitPerRecord int
db *DB
}
func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Select(columns ...string) PreloadBuilder {
q.db.Select(columns)
return q
}
func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder {
q.db.Omit(columns...)
return q
}
func (q *preloadBuilder) Limit(limit int) PreloadBuilder {
q.db.Limit(limit)
return q
}
func (q *preloadBuilder) Offset(offset int) PreloadBuilder {
q.db.Offset(offset)
return q
}
func (q *preloadBuilder) Order(value interface{}) PreloadBuilder {
q.db.Order(value)
return q
}
func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder {
q.limitPerRecord = num
return q
}
func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
return c.with(func(db *DB) *DB {
if jt.Table == "" {
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
}
q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
if on != nil {
if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
db.AddError(err)
}
}
j := join{
Name: jt.Association,
Alias: jt.Table,
Selects: q.db.Statement.Selects,
Omits: q.db.Statement.Omits,
JoinType: jt.Type,
}
if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
j.On = &where
}
if jt.Subquery != nil {
joinType := j.JoinType
if joinType == "" {
joinType = clause.LeftJoin
}
if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok {
stmt := db.getInstance().Statement
if len(j.Selects) == 0 {
j.Selects = stmt.Selects
}
if len(j.Omits) == 0 {
j.Omits = stmt.Omits
}
}
expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}
if j.On != nil {
expr.SQL += " ON ?"
expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
}
j.Expression = expr
}
db.Statement.Joins = append(db.Statement.Joins, j)
sort.Slice(db.Statement.Joins, func(i, j int) bool {
return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name
})
return db
})
}
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Select(query, args...)
})
}
func (c chainG[T]) Omit(columns ...string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Omit(columns...)
})
}
func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.MapColumns(m)
})
}
func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Distinct(args...)
})
}
func (c chainG[T]) Group(name string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Group(name)
})
}
func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Having(query, args...)
})
}
func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Order(value)
})
}
func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Preload(association, func(tx *DB) *DB {
q := preloadBuilder{db: tx.getInstance()}
if query != nil {
if err := query(&q); err != nil {
db.AddError(err)
}
}
relation, ok := db.Statement.Schema.Relationships.Relations[association]
if !ok {
if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 {
relationships := db.Statement.Schema.Relationships
for _, field := range preloadFields {
var ok bool
relation, ok = relationships.Relations[field]
if ok {
relationships = relation.FieldSchema.Relationships
} else {
db.AddError(fmt.Errorf("relation %s not found", association))
return nil
}
}
} else {
db.AddError(fmt.Errorf("relation %s not found", association))
return nil
}
}
if q.limitPerRecord > 0 {
if relation.JoinTable != nil {
tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association))
return tx
}
refColumns := []clause.Column{}
for _, rel := range relation.References {
if rel.OwnPrimaryKey {
refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName})
}
}
if len(refColumns) != 0 {
selectExpr := clause.CommaExpression{}
for _, column := range q.db.Statement.Selects {
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}})
}
if len(selectExpr.Exprs) == 0 {
selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}}
}
partitionBy := clause.CommaExpression{}
for _, column := range refColumns {
partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}})
}
rnnColumn := clause.Column{Name: "gorm_preload_rnn"}
sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)"
vars := []interface{}{partitionBy}
if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok {
vars = append(vars, orderBy)
} else {
vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
}})
}
vars = append(vars, rnnColumn)
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars})
q.db.Clauses(clause.Select{Expression: selectExpr})
return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord)
}
}
return q.db
})
})
}
func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) {
r := new(T)
res := c.g.apply(ctx).Delete(r)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) {
var r T
res := c.g.apply(ctx).Model(r).Update(name, value)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) {
res := c.g.apply(ctx).Updates(t)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) {
var r T
err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error
return
}
func (c chainG[T]) Build(builder clause.Builder) {
subdb := c.getInstance()
subdb.Logger = logger.Discard
subdb.DryRun = true
if stmt, ok := builder.(*Statement); ok {
if subdb.Statement.SQL.Len() > 0 {
var (
vars = subdb.Statement.Vars
sql = subdb.Statement.SQL.String()
)
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
subdb.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
subdb.Statement.SQL.Reset()
subdb.Statement.Vars = stmt.Vars
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
} else {
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
}
} else {
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
subdb.callbacks.Query().Execute(subdb)
}
builder.WriteString(subdb.Statement.SQL.String())
stmt.Vars = subdb.Statement.Vars
}
}
type execG[T any] struct {
g *g[T]
}
func (g execG[T]) First(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).First(&r).Error
return r, err
}
func (g execG[T]) Scan(ctx context.Context, result interface{}) error {
var r T
err := g.g.apply(ctx).Model(r).Find(result).Error
return err
}
func (g execG[T]) Last(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).Last(&r).Error
return r, err
}
func (g execG[T]) Take(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).Take(&r).Error
return r, err
}
func (g execG[T]) Find(ctx context.Context) ([]T, error) {
var r []T
err := g.g.apply(ctx).Find(&r).Error
return r, err
}
func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error {
var data []T
return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error {
return fc(data, batch)
}).Error
}
func (g execG[T]) Row(ctx context.Context) *sql.Row {
return g.g.apply(ctx).Row()
}
func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) {
return g.g.apply(ctx).Rows()
}

5
go.mod
View File

@ -1,8 +1,9 @@
module gorm.io/gorm
go 1.14
go 1.18
require (
github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.1.1
github.com/jinzhu/now v1.1.5
golang.org/x/text v0.20.0
)

6
go.sum
View File

@ -1,4 +1,6 @@
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.1 h1:g39TucaRWyV3dwDO++eEc6qf8TVIQ/Da48WmqjZ3i7E=
github.com/jinzhu/now v1.1.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=

238
gorm.go
View File

@ -4,6 +4,8 @@ import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
"sync"
"time"
@ -19,7 +21,9 @@ const preparedStmtDBKey = "preparedStmt"
type Config struct {
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool
SkipDefaultTransaction bool
DefaultTransactionTimeout time.Duration
// NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer
// FullSaveAssociations full save associations
@ -32,10 +36,17 @@ type Config struct {
DryRun bool
// PrepareStmt executes the given query in cached statement
PrepareStmt bool
// PrepareStmt cache support LRU expired,
// default maxsize=int64 Max value and ttl=1h
PrepareStmtMaxSize int
PrepareStmtTTL time.Duration
// DisableAutomaticPing
DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating
DisableForeignKeyConstraintWhenMigrating bool
// IgnoreRelationshipsWhenMigrating
IgnoreRelationshipsWhenMigrating bool
// DisableNestedTransaction disable nested transaction
DisableNestedTransaction bool
// AllowGlobalUpdate allow global update
@ -44,6 +55,10 @@ type Config struct {
QueryFields bool
// CreateBatchSize default create batch size
CreateBatchSize int
// TranslateError enabling error translation
TranslateError bool
// PropagateUnscoped propagate Unscoped to every other nested statement
PropagateUnscoped bool
// ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
@ -58,6 +73,7 @@ type Config struct {
cacheStore *sync.Map
}
// Apply update config to new config
func (c *Config) Apply(config *Config) error {
if config != c {
*config = *c
@ -65,6 +81,7 @@ func (c *Config) Apply(config *Config) error {
return nil
}
// AfterInitialize initialize plugins after db connected
func (c *Config) AfterInitialize(db *DB) error {
if db != nil {
for _, plugin := range c.Plugins {
@ -76,6 +93,7 @@ func (c *Config) AfterInitialize(db *DB) error {
return nil
}
// Option gorm option interface
type Option interface {
Apply(*Config) error
AfterInitialize(*DB) error
@ -95,11 +113,13 @@ type Session struct {
DryRun bool
PrepareStmt bool
NewDB bool
Initialized bool
SkipHooks bool
SkipDefaultTransaction bool
DisableNestedTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
PropagateUnscoped bool
QueryFields bool
Context context.Context
Logger logger.Interface
@ -111,14 +131,34 @@ type Session struct {
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
config := &Config{}
sort.Slice(opts, func(i, j int) bool {
_, isConfig := opts[i].(*Config)
_, isConfig2 := opts[j].(*Config)
return isConfig && !isConfig2
})
if len(opts) > 0 {
if c, ok := opts[0].(*Config); ok {
config = c
} else {
opts = append([]Option{config}, opts...)
}
}
var skipAfterInitialize bool
for _, opt := range opts {
if opt != nil {
if err := opt.Apply(config); err != nil {
return nil, err
if applyErr := opt.Apply(config); applyErr != nil {
return nil, applyErr
}
defer func() {
opt.AfterInitialize(db)
}()
defer func(opt Option) {
if skipAfterInitialize {
return
}
if errr := opt.AfterInitialize(db); errr != nil {
err = errr
}
}(opt)
}
}
@ -129,7 +169,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
}
if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{}
config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
}
if config.Logger == nil {
@ -162,17 +202,26 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
if config.Dialector != nil {
err = config.Dialector.Initialize(db)
}
if err != nil {
if db, _ := db.DB(); db != nil {
_ = db.Close()
}
preparedStmt := &PreparedStmtDB{
ConnPool: db.ConnPool,
Stmts: map[string]Stmt{},
Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100),
// DB is not initialized, so we skip AfterInitialize
skipAfterInitialize = true
return
}
if config.TranslateError {
if _, ok := db.Dialector.(ErrorTranslator); !ok {
config.Logger.Warn(context.Background(), "The TranslateError option is enabled, but the Dialector %s does not implement ErrorTranslator.", db.Dialector.Name())
}
}
}
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
if config.PrepareStmt {
preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
db.ConnPool = preparedStmt
}
@ -223,6 +272,10 @@ func (db *DB) Session(config *Session) *DB {
txConfig.FullSaveAssociations = true
}
if config.PropagateUnscoped {
txConfig.PropagateUnscoped = true
}
if config.Context != nil || config.PrepareStmt || config.SkipHooks {
tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx
@ -233,16 +286,30 @@ func (db *DB) Session(config *Session) *DB {
}
if config.PrepareStmt {
var preparedStmt *PreparedStmtDB
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
preparedStmt := v.(*PreparedStmtDB)
preparedStmt = v.(*PreparedStmtDB)
} else {
preparedStmt = NewPreparedStmtDB(db.ConnPool, db.PrepareStmtMaxSize, db.PrepareStmtTTL)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
}
switch t := tx.Statement.ConnPool.(type) {
case Tx:
tx.Statement.ConnPool = &PreparedStmtTX{
Tx: t,
PreparedStmtDB: preparedStmt,
}
default:
tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool,
Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts,
}
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
}
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
}
if config.SkipHooks {
@ -273,6 +340,10 @@ func (db *DB) Session(config *Session) *DB {
tx.Config.NowFunc = config.NowFunc
}
if config.Initialized {
tx = tx.getInstance()
}
return tx
}
@ -283,7 +354,8 @@ func (db *DB) WithContext(ctx context.Context) *DB {
// Debug start debug mode
func (db *DB) Debug() (tx *DB) {
return db.Session(&Session{
tx = db.getInstance()
return tx.Session(&Session{
Logger: db.Logger.LogMode(logger.Info),
})
}
@ -319,10 +391,18 @@ func (db *DB) Callback() *callbacks {
// AddError add error to db
func (db *DB) AddError(err error) error {
if db.Error == nil {
db.Error = err
} else if err != nil {
db.Error = fmt.Errorf("%v; %w", db.Error, err)
if err != nil {
if db.Config.TranslateError {
if errTranslator, ok := db.Dialector.(ErrorTranslator); ok {
err = errTranslator.Translate(err)
}
}
if db.Error == nil {
db.Error = err
} else {
db.Error = fmt.Errorf("%v; %w", db.Error, err)
}
}
return db.Error
}
@ -330,30 +410,42 @@ func (db *DB) AddError(err error) error {
// DB returns `*sql.DB`
func (db *DB) DB() (*sql.DB, error) {
connPool := db.ConnPool
if stmtDB, ok := connPool.(*PreparedStmtDB); ok {
connPool = stmtDB.ConnPool
if db.Statement != nil && db.Statement.ConnPool != nil {
connPool = db.Statement.ConnPool
}
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
}
if sqldb, ok := connPool.(*sql.DB); ok {
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
return sqldb, err
}
}
if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
return sqldb, nil
}
return nil, ErrInvaildDB
return nil, ErrInvalidDB
}
func (db *DB) getInstance() *DB {
if db.clone > 0 {
tx := &DB{Config: db.Config}
tx := &DB{Config: db.Config, Error: db.Error}
if db.clone == 1 {
// clone with new statement
tx.Statement = &Statement{
DB: tx,
ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context,
Clauses: map[string]clause.Clause{},
Vars: make([]interface{}, 0, 8),
DB: tx,
ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context,
Clauses: map[string]clause.Clause{},
Vars: make([]interface{}, 0, 8),
SkipHooks: db.Statement.SkipHooks,
}
if db.Config.PropagateUnscoped {
tx.Statement.Unscoped = db.Statement.Unscoped
}
} else {
// with clone statement
@ -367,10 +459,12 @@ func (db *DB) getInstance() *DB {
return db
}
// Expr returns clause.Expr, which can be used to pass SQL expression as params
func Expr(expr string, args ...interface{}) clause.Expr {
return clause.Expr{SQL: expr, Vars: args}
}
// SetupJoinTable setup join table schema
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
var (
tx = db.getInstance()
@ -378,47 +472,50 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
modelSchema, joinSchema *schema.Schema
)
if err := stmt.Parse(model); err == nil {
modelSchema = stmt.Schema
} else {
err := stmt.Parse(model)
if err != nil {
return err
}
modelSchema = stmt.Schema
if err := stmt.Parse(joinTable); err == nil {
joinSchema = stmt.Schema
} else {
err = stmt.Parse(joinTable)
if err != nil {
return err
}
joinSchema = stmt.Schema
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
for _, ref := range relation.References {
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
f.DataType = ref.ForeignKey.DataType
f.GORMDataType = ref.ForeignKey.GORMDataType
if f.Size == 0 {
f.Size = ref.ForeignKey.Size
}
ref.ForeignKey = f
} else {
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
}
}
for name, rel := range relation.JoinTable.Relationships.Relations {
if _, ok := joinSchema.Relationships.Relations[name]; !ok {
rel.Schema = joinSchema
joinSchema.Relationships.Relations[name] = rel
}
}
relation.JoinTable = joinSchema
} else {
return fmt.Errorf("failed to found relation: %v", field)
relation, ok := modelSchema.Relationships.Relations[field]
isRelation := ok && relation.JoinTable != nil
if !isRelation {
return fmt.Errorf("failed to find relation: %s", field)
}
for _, ref := range relation.References {
f := joinSchema.LookUpField(ref.ForeignKey.DBName)
if f == nil {
return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
}
f.DataType = ref.ForeignKey.DataType
f.GORMDataType = ref.ForeignKey.GORMDataType
if f.Size == 0 {
f.Size = ref.ForeignKey.Size
}
ref.ForeignKey = f
}
for name, rel := range relation.JoinTable.Relationships.Relations {
if _, ok := joinSchema.Relationships.Relations[name]; !ok {
rel.Schema = joinSchema
joinSchema.Relationships.Relations[name] = rel
}
}
relation.JoinTable = joinSchema
return nil
}
// Use use plugin
func (db *DB) Use(plugin Plugin) error {
name := plugin.Name()
if _, ok := db.Plugins[name]; ok {
@ -430,3 +527,18 @@ func (db *DB) Use(plugin Plugin) error {
db.Plugins[name] = plugin
return nil
}
// ToSQL for generate SQL string.
//
// db.ToSQL(func(tx *gorm.DB) *gorm.DB {
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
// .Limit(10).Offset(5)
// .Order("name ASC")
// .First(&User{})
// })
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance())
stmt := tx.Statement
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
}

View File

@ -26,6 +26,10 @@ type Plugin interface {
Initialize(*DB) error
}
type ParamsFilter interface {
ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{})
}
// ConnPool db conns pool interface
type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
@ -40,20 +44,49 @@ type SavePointerDialectorInterface interface {
RollbackTo(tx *DB, name string) error
}
// TxBeginner tx beginner
type TxBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
// ConnPoolBeginner conn pool beginner
type ConnPoolBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
}
// TxCommitter tx committer
type TxCommitter interface {
Commit() error
Rollback() error
}
// Tx sql.Tx interface
type Tx interface {
ConnPool
TxCommitter
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
}
// Valuer gorm valuer interface
type Valuer interface {
GormValue(context.Context, *DB) clause.Expr
}
// GetDBConnector SQL db connector
type GetDBConnector interface {
GetDBConn() (*sql.DB, error)
}
// Rows rows interface
type Rows interface {
Columns() ([]string, error)
ColumnTypes() ([]*sql.ColumnType, error)
Next() bool
Scan(dest ...interface{}) error
Err() error
Close() error
}
type ErrorTranslator interface {
Translate(err error) error
}

493
internal/lru/lru.go Normal file
View File

@ -0,0 +1,493 @@
package lru
// golang -lru
// https://github.com/hashicorp/golang-lru
import (
"sync"
"time"
)
// EvictCallback is used to get a callback when a cache entry is evicted
type EvictCallback[K comparable, V any] func(key K, value V)
// LRU implements a thread-safe LRU with expirable entries.
type LRU[K comparable, V any] struct {
size int
evictList *LruList[K, V]
items map[K]*Entry[K, V]
onEvict EvictCallback[K, V]
// expirable options
mu sync.Mutex
ttl time.Duration
done chan struct{}
// buckets for expiration
buckets []bucket[K, V]
// uint8 because it's number between 0 and numBuckets
nextCleanupBucket uint8
}
// bucket is a container for holding entries to be expired
type bucket[K comparable, V any] struct {
entries map[K]*Entry[K, V]
newestEntry time.Time
}
// noEvictionTTL - very long ttl to prevent eviction
const noEvictionTTL = time.Hour * 24 * 365 * 10
// because of uint8 usage for nextCleanupBucket, should not exceed 256.
// casting it as uint8 explicitly requires type conversions in multiple places
const numBuckets = 100
// NewLRU returns a new thread-safe cache with expirable entries.
//
// Size parameter set to 0 makes cache of unlimited size, e.g. turns LRU mechanism off.
//
// Providing 0 TTL turns expiring off.
//
// Delete expired entries every 1/100th of ttl value. Goroutine which deletes expired entries runs indefinitely.
func NewLRU[K comparable, V any](size int, onEvict EvictCallback[K, V], ttl time.Duration) *LRU[K, V] {
if size < 0 {
size = 0
}
if ttl <= 0 {
ttl = noEvictionTTL
}
res := LRU[K, V]{
ttl: ttl,
size: size,
evictList: NewList[K, V](),
items: make(map[K]*Entry[K, V]),
onEvict: onEvict,
done: make(chan struct{}),
}
// initialize the buckets
res.buckets = make([]bucket[K, V], numBuckets)
for i := 0; i < numBuckets; i++ {
res.buckets[i] = bucket[K, V]{entries: make(map[K]*Entry[K, V])}
}
// enable deleteExpired() running in separate goroutine for cache with non-zero TTL
//
// Important: done channel is never closed, so deleteExpired() goroutine will never exit,
// it's decided to add functionality to close it in the version later than v2.
if res.ttl != noEvictionTTL {
go func(done <-chan struct{}) {
ticker := time.NewTicker(res.ttl / numBuckets)
defer ticker.Stop()
for {
select {
case <-done:
return
case <-ticker.C:
res.deleteExpired()
}
}
}(res.done)
}
return &res
}
// Purge clears the cache completely.
// onEvict is called for each evicted key.
func (c *LRU[K, V]) Purge() {
c.mu.Lock()
defer c.mu.Unlock()
for k, v := range c.items {
if c.onEvict != nil {
c.onEvict(k, v.Value)
}
delete(c.items, k)
}
for _, b := range c.buckets {
for _, ent := range b.entries {
delete(b.entries, ent.Key)
}
}
c.evictList.Init()
}
// Add adds a value to the cache. Returns true if an eviction occurred.
// Returns false if there was no eviction: the item was already in the cache,
// or the size was not exceeded.
func (c *LRU[K, V]) Add(key K, value V) (evicted bool) {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
// Check for existing item
if ent, ok := c.items[key]; ok {
c.evictList.MoveToFront(ent)
c.removeFromBucket(ent) // remove the entry from its current bucket as expiresAt is renewed
ent.Value = value
ent.ExpiresAt = now.Add(c.ttl)
c.addToBucket(ent)
return false
}
// Add new item
ent := c.evictList.PushFrontExpirable(key, value, now.Add(c.ttl))
c.items[key] = ent
c.addToBucket(ent) // adds the entry to the appropriate bucket and sets entry.expireBucket
evict := c.size > 0 && c.evictList.Length() > c.size
// Verify size not exceeded
if evict {
c.removeOldest()
}
return evict
}
// Get looks up a key's value from the cache.
func (c *LRU[K, V]) Get(key K) (value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
var ent *Entry[K, V]
if ent, ok = c.items[key]; ok {
// Expired item check
if time.Now().After(ent.ExpiresAt) {
return value, false
}
c.evictList.MoveToFront(ent)
return ent.Value, true
}
return
}
// Contains checks if a key is in the cache, without updating the recent-ness
// or deleting it for being stale.
func (c *LRU[K, V]) Contains(key K) (ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
_, ok = c.items[key]
return ok
}
// Peek returns the key value (or undefined if not found) without updating
// the "recently used"-ness of the key.
func (c *LRU[K, V]) Peek(key K) (value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
var ent *Entry[K, V]
if ent, ok = c.items[key]; ok {
// Expired item check
if time.Now().After(ent.ExpiresAt) {
return value, false
}
return ent.Value, true
}
return
}
// Remove removes the provided key from the cache, returning if the
// key was contained.
func (c *LRU[K, V]) Remove(key K) bool {
c.mu.Lock()
defer c.mu.Unlock()
if ent, ok := c.items[key]; ok {
c.removeElement(ent)
return true
}
return false
}
// RemoveOldest removes the oldest item from the cache.
func (c *LRU[K, V]) RemoveOldest() (key K, value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.evictList.Back(); ent != nil {
c.removeElement(ent)
return ent.Key, ent.Value, true
}
return
}
// GetOldest returns the oldest entry
func (c *LRU[K, V]) GetOldest() (key K, value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.evictList.Back(); ent != nil {
return ent.Key, ent.Value, true
}
return
}
func (c *LRU[K, V]) KeyValues() map[K]V {
c.mu.Lock()
defer c.mu.Unlock()
maps := make(map[K]V)
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
maps[ent.Key] = ent.Value
// keys = append(keys, ent.Key)
}
return maps
}
// Keys returns a slice of the keys in the cache, from oldest to newest.
// Expired entries are filtered out.
func (c *LRU[K, V]) Keys() []K {
c.mu.Lock()
defer c.mu.Unlock()
keys := make([]K, 0, len(c.items))
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
keys = append(keys, ent.Key)
}
return keys
}
// Values returns a slice of the values in the cache, from oldest to newest.
// Expired entries are filtered out.
func (c *LRU[K, V]) Values() []V {
c.mu.Lock()
defer c.mu.Unlock()
values := make([]V, 0, len(c.items))
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
values = append(values, ent.Value)
}
return values
}
// Len returns the number of items in the cache.
func (c *LRU[K, V]) Len() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.evictList.Length()
}
// Resize changes the cache size. Size of 0 means unlimited.
func (c *LRU[K, V]) Resize(size int) (evicted int) {
c.mu.Lock()
defer c.mu.Unlock()
if size <= 0 {
c.size = 0
return 0
}
diff := c.evictList.Length() - size
if diff < 0 {
diff = 0
}
for i := 0; i < diff; i++ {
c.removeOldest()
}
c.size = size
return diff
}
// Close destroys cleanup goroutine. To clean up the cache, run Purge() before Close().
// func (c *LRU[K, V]) Close() {
// c.mu.Lock()
// defer c.mu.Unlock()
// select {
// case <-c.done:
// return
// default:
// }
// close(c.done)
// }
// removeOldest removes the oldest item from the cache. Has to be called with lock!
func (c *LRU[K, V]) removeOldest() {
if ent := c.evictList.Back(); ent != nil {
c.removeElement(ent)
}
}
// removeElement is used to remove a given list element from the cache. Has to be called with lock!
func (c *LRU[K, V]) removeElement(e *Entry[K, V]) {
c.evictList.Remove(e)
delete(c.items, e.Key)
c.removeFromBucket(e)
if c.onEvict != nil {
c.onEvict(e.Key, e.Value)
}
}
// deleteExpired deletes expired records from the oldest bucket, waiting for the newest entry
// in it to expire first.
func (c *LRU[K, V]) deleteExpired() {
c.mu.Lock()
bucketIdx := c.nextCleanupBucket
timeToExpire := time.Until(c.buckets[bucketIdx].newestEntry)
// wait for newest entry to expire before cleanup without holding lock
if timeToExpire > 0 {
c.mu.Unlock()
time.Sleep(timeToExpire)
c.mu.Lock()
}
for _, ent := range c.buckets[bucketIdx].entries {
c.removeElement(ent)
}
c.nextCleanupBucket = (c.nextCleanupBucket + 1) % numBuckets
c.mu.Unlock()
}
// addToBucket adds entry to expire bucket so that it will be cleaned up when the time comes. Has to be called with lock!
func (c *LRU[K, V]) addToBucket(e *Entry[K, V]) {
bucketID := (numBuckets + c.nextCleanupBucket - 1) % numBuckets
e.ExpireBucket = bucketID
c.buckets[bucketID].entries[e.Key] = e
if c.buckets[bucketID].newestEntry.Before(e.ExpiresAt) {
c.buckets[bucketID].newestEntry = e.ExpiresAt
}
}
// removeFromBucket removes the entry from its corresponding bucket. Has to be called with lock!
func (c *LRU[K, V]) removeFromBucket(e *Entry[K, V]) {
delete(c.buckets[e.ExpireBucket].entries, e.Key)
}
// Cap returns the capacity of the cache
func (c *LRU[K, V]) Cap() int {
return c.size
}
// Entry is an LRU Entry
type Entry[K comparable, V any] struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *Entry[K, V]
// The list to which this element belongs.
list *LruList[K, V]
// The LRU Key of this element.
Key K
// The Value stored with this element.
Value V
// The time this element would be cleaned up, optional
ExpiresAt time.Time
// The expiry bucket item was put in, optional
ExpireBucket uint8
}
// PrevEntry returns the previous list element or nil.
func (e *Entry[K, V]) PrevEntry() *Entry[K, V] {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// LruList represents a doubly linked list.
// The zero Value for LruList is an empty list ready to use.
type LruList[K comparable, V any] struct {
root Entry[K, V] // sentinel list element, only &root, root.prev, and root.next are used
len int // current list Length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *LruList[K, V]) Init() *LruList[K, V] {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// NewList returns an initialized list.
func NewList[K comparable, V any]() *LruList[K, V] { return new(LruList[K, V]).Init() }
// Length returns the number of elements of list l.
// The complexity is O(1).
func (l *LruList[K, V]) Length() int { return l.len }
// Back returns the last element of list l or nil if the list is empty.
func (l *LruList[K, V]) Back() *Entry[K, V] {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List Value.
func (l *LruList[K, V]) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *LruList[K, V]) insert(e, at *Entry[K, V]) *Entry[K, V] {
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Entry{Value: v, ExpiresAt: ExpiresAt}, at).
func (l *LruList[K, V]) insertValue(k K, v V, expiresAt time.Time, at *Entry[K, V]) *Entry[K, V] {
return l.insert(&Entry[K, V]{Value: v, Key: k, ExpiresAt: expiresAt}, at)
}
// Remove removes e from its list, decrements l.len
func (l *LruList[K, V]) Remove(e *Entry[K, V]) V {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
return e.Value
}
// move moves e to next to at.
func (l *LruList[K, V]) move(e, at *Entry[K, V]) {
if e == at {
return
}
e.prev.next = e.next
e.next.prev = e.prev
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *LruList[K, V]) PushFront(k K, v V) *Entry[K, V] {
l.lazyInit()
return l.insertValue(k, v, time.Time{}, &l.root)
}
// PushFrontExpirable inserts a new expirable element e with Value v at the front of list l and returns e.
func (l *LruList[K, V]) PushFrontExpirable(k K, v V, expiresAt time.Time) *Entry[K, V] {
l.lazyInit()
return l.insertValue(k, v, expiresAt, &l.root)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *LruList[K, V]) MoveToFront(e *Entry[K, V]) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.move(e, &l.root)
}

View File

@ -0,0 +1,183 @@
package stmt_store
import (
"context"
"database/sql"
"math"
"sync"
"time"
"gorm.io/gorm/internal/lru"
)
type Stmt struct {
*sql.Stmt
Transaction bool
prepared chan struct{}
prepareErr error
}
func (stmt *Stmt) Error() error {
return stmt.prepareErr
}
func (stmt *Stmt) Close() error {
<-stmt.prepared
if stmt.Stmt != nil {
return stmt.Stmt.Close()
}
return nil
}
// Store defines an interface for managing the caching operations of SQL statements (Stmt).
// This interface provides methods for creating new statements, retrieving all cache keys,
// getting cached statements, setting cached statements, and deleting cached statements.
type Store interface {
// New creates a new Stmt object and caches it.
// Parameters:
// ctx: The context for the request, which can carry deadlines, cancellation signals, etc.
// key: The key representing the SQL query, used for caching and preparing the statement.
// isTransaction: Indicates whether this operation is part of a transaction, which may affect the caching strategy.
// connPool: A connection pool that provides database connections.
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
// Returns:
// *Stmt: A newly created statement object for executing SQL operations.
// error: An error if the statement preparation fails.
New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error)
// Keys returns a slice of all cache keys in the store.
Keys() []string
// Get retrieves a Stmt object from the store based on the given key.
// Parameters:
// key: The key used to look up the Stmt object.
// Returns:
// *Stmt: The found Stmt object, or nil if not found.
// bool: Indicates whether the corresponding Stmt object was successfully found.
Get(key string) (*Stmt, bool)
// Set stores the given Stmt object in the store and associates it with the specified key.
// Parameters:
// key: The key used to associate the Stmt object.
// value: The Stmt object to be stored.
Set(key string, value *Stmt)
// Delete removes the Stmt object corresponding to the specified key from the store.
// Parameters:
// key: The key associated with the Stmt object to be deleted.
Delete(key string)
}
// defaultMaxSize defines the default maximum capacity of the cache.
// Its value is the maximum value of the int64 type, which means that when the cache size is not specified,
// the cache can theoretically store as many elements as possible.
// (1 << 63) - 1 is the maximum value that an int64 type can represent.
const (
defaultMaxSize = math.MaxInt
// defaultTTL defines the default time-to-live (TTL) for each cache entry.
// When the TTL for cache entries is not specified, each cache entry will expire after 24 hours.
defaultTTL = time.Hour * 24
)
// New creates and returns a new Store instance.
//
// Parameters:
// - size: The maximum capacity of the cache. If the provided size is less than or equal to 0,
// it defaults to defaultMaxSize.
// - ttl: The time-to-live duration for each cache entry. If the provided ttl is less than or equal to 0,
// it defaults to defaultTTL.
//
// This function defines an onEvicted callback that is invoked when a cache entry is evicted.
// The callback ensures that if the evicted value (v) is not nil, its Close method is called asynchronously
// to release associated resources.
//
// Returns:
// - A Store instance implemented by lruStore, which internally uses an LRU cache with the specified size,
// eviction callback, and TTL.
func New(size int, ttl time.Duration) Store {
if size <= 0 {
size = defaultMaxSize
}
if ttl <= 0 {
ttl = defaultTTL
}
onEvicted := func(k string, v *Stmt) {
if v != nil {
go v.Close()
}
}
return &lruStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
}
type lruStore struct {
lru *lru.LRU[string, *Stmt]
}
func (s *lruStore) Keys() []string {
return s.lru.Keys()
}
func (s *lruStore) Get(key string) (*Stmt, bool) {
stmt, ok := s.lru.Get(key)
if ok && stmt != nil {
<-stmt.prepared
}
return stmt, ok
}
func (s *lruStore) Set(key string, value *Stmt) {
s.lru.Add(key, value)
}
func (s *lruStore) Delete(key string) {
s.lru.Remove(key)
}
type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
// New creates a new Stmt object for executing SQL queries.
// It caches the Stmt object for future use and handles preparation and error states.
// Parameters:
//
// ctx: Context for the request, used to carry deadlines, cancellation signals, etc.
// key: The key representing the SQL query, used for caching and preparing the statement.
// isTransaction: Indicates whether this operation is part of a transaction, affecting cache strategy.
// conn: A connection pool that provides database connections.
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
//
// Returns:
//
// *Stmt: A newly created statement object for executing SQL operations.
// error: An error if the statement preparation fails.
func (s *lruStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) {
// Create a Stmt object and set its Transaction property.
// The prepared channel is used to synchronize the statement preparation state.
cacheStmt := &Stmt{
Transaction: isTransaction,
prepared: make(chan struct{}),
}
// Cache the Stmt object with the associated key.
s.Set(key, cacheStmt)
// Unlock after completing initialization to prevent deadlocks.
locker.Unlock()
// Ensure the prepared channel is closed after the function execution completes.
defer close(cacheStmt.prepared)
// Prepare the SQL statement using the provided connection.
cacheStmt.Stmt, err = conn.PrepareContext(ctx, key)
if err != nil {
// If statement preparation fails, record the error and remove the invalid Stmt object from the cache.
cacheStmt.prepareErr = err
s.Delete(key)
return &Stmt{}, err
}
// Return the successfully prepared Stmt object.
return cacheStmt, nil
}

View File

@ -2,8 +2,9 @@ package logger
import (
"context"
"errors"
"fmt"
"io/ioutil"
"io"
"log"
"os"
"time"
@ -11,6 +12,9 @@ import (
"gorm.io/gorm/utils"
)
// ErrRecordNotFound record not found error
var ErrRecordNotFound = errors.New("record not found")
// Colors
const (
Reset = "\033[0m"
@ -27,13 +31,17 @@ const (
YellowBold = "\033[33;1m"
)
// LogLevel
// LogLevel log level
type LogLevel int
const (
// Silent silent log level
Silent LogLevel = iota + 1
// Error error log level
Error
// Warn warn log level
Warn
// Info info log level
Info
)
@ -42,10 +50,13 @@ type Writer interface {
Printf(string, ...interface{})
}
// Config logger config
type Config struct {
SlowThreshold time.Duration
Colorful bool
LogLevel LogLevel
SlowThreshold time.Duration
Colorful bool
IgnoreRecordNotFoundError bool
ParameterizedQueries bool
LogLevel LogLevel
}
// Interface logger interface
@ -54,19 +65,29 @@ type Interface interface {
Info(context.Context, string, ...interface{})
Warn(context.Context, string, ...interface{})
Error(context.Context, string, ...interface{})
Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error)
Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
}
var (
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
// Discard logger will print any log to io.Discard
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
// Default Default logger
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: Warn,
Colorful: true,
SlowThreshold: 200 * time.Millisecond,
LogLevel: Warn,
IgnoreRecordNotFoundError: false,
Colorful: true,
})
// Recorder logger records running SQL into a recorder instance
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
// RecorderParamsFilter defaults to no-op, allows to be run-over by a different implementation
RecorderParamsFilter = func(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
return sql, params
}
)
// New initialize logger
func New(writer Writer, config Config) Interface {
var (
infoStr = "%s\n[info] "
@ -113,57 +134,69 @@ func (l *logger) LogMode(level LogLevel) Interface {
}
// Info print info
func (l logger) Info(ctx context.Context, msg string, data ...interface{}) {
func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Info {
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}
// Warn print warn messages
func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) {
func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Warn {
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}
// Error print error messages
func (l logger) Error(ctx context.Context, msg string, data ...interface{}) {
func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Error {
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}
// Trace print sql message
func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.LogLevel > Silent {
elapsed := time.Since(begin)
switch {
case err != nil && l.LogLevel >= Error:
sql, rows := fc()
if rows == -1 {
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
sql, rows := fc()
slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)
if rows == -1 {
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
case l.LogLevel == Info:
sql, rows := fc()
if rows == -1 {
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
//
//nolint:cyclop
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.LogLevel <= Silent {
return
}
elapsed := time.Since(begin)
switch {
case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError):
sql, rows := fc()
if rows == -1 {
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
sql, rows := fc()
slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)
if rows == -1 {
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
case l.LogLevel == Info:
sql, rows := fc()
if rows == -1 {
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
}
}
// ParamsFilter filter params
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
if l.Config.ParameterizedQueries {
return sql, nil
}
return sql, params
}
type traceRecorder struct {
Interface
BeginAt time.Time
@ -172,12 +205,21 @@ type traceRecorder struct {
Err error
}
func (l traceRecorder) New() *traceRecorder {
// New trace recorder
func (l *traceRecorder) New() *traceRecorder {
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
}
// Trace implement logger interface
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
l.BeginAt = begin
l.SQL, l.RowsAffected = fc()
l.Err = err
}
func (l *traceRecorder) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
if RecorderParamsFilter == nil {
return sql, params
}
return RecorderParamsFilter(ctx, sql, params...)
}

View File

@ -19,20 +19,40 @@ const (
nullStr = "NULL"
)
func isPrintable(s []byte) bool {
func isPrintable(s string) bool {
for _, r := range s {
if !unicode.IsPrint(rune(r)) {
if !unicode.IsPrint(r) {
return false
}
}
return true
}
var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
// A list of Go types that should be converted to SQL primitives
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
// RegEx matches only numeric values
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
func isNumeric(k reflect.Kind) bool {
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
case reflect.Float32, reflect.Float64:
return true
default:
return false
}
}
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var convertParams func(interface{}, int)
var vars = make([]string, len(avars))
var (
convertParams func(interface{}, int)
vars = make([]string, len(avars))
)
convertParams = func(v interface{}, idx int) {
switch v := v.(type) {
@ -64,23 +84,36 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
}
case fmt.Stringer:
reflectValue := reflect.ValueOf(v)
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
} else {
vars[idx] = nullStr
switch reflectValue.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
vars[idx] = fmt.Sprintf("%d", reflectValue.Interface())
case reflect.Float32, reflect.Float64:
vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface())
case reflect.Bool:
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
case reflect.String:
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
default:
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
} else {
vars[idx] = nullStr
}
}
case []byte:
if isPrintable(v) {
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
if s := string(v); isPrintable(s) {
vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
} else {
vars[idx] = escaper + "<binary>" + escaper
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
vars[idx] = utils.ToString(v)
case float64, float32:
vars[idx] = fmt.Sprintf("%.6f", v)
case float32:
vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32)
case float64:
vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
case string:
vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper
vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper
default:
rv := reflect.ValueOf(v)
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
@ -90,14 +123,20 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
convertParams(v, idx)
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
convertParams(reflect.Indirect(rv).Interface(), idx)
} else if isNumeric(rv.Kind()) {
if rv.CanInt() || rv.CanUint() {
vars[idx] = fmt.Sprintf("%d", rv.Interface())
} else {
vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
}
} else {
for _, t := range convertableTypes {
for _, t := range convertibleTypes {
if rv.Type().ConvertibleTo(t) {
convertParams(rv.Convert(t).Interface(), idx)
return
}
}
vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
}
}
}
@ -124,9 +163,18 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
sql = newSQL.String()
} else {
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
for idx, v := range vars {
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1)
}
sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string {
num := v[1 : len(v)-1]
n, _ := strconv.Atoi(num)
// position var start from 1 ($1, $2)
n -= 1
if n >= 0 && n <= len(vars)-1 {
return vars[n]
}
return v
})
}
return sql

View File

@ -31,20 +31,24 @@ func (s ExampleStruct) Value() (driver.Value, error) {
}
func format(v []byte, escaper string) string {
return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
return escaper + strings.ReplaceAll(string(v), escaper, escaper+escaper) + escaper
}
func TestExplainSQL(t *testing.T) {
type role string
type password []byte
type intType int
type floatType float64
var (
tt = now.MustParse("2020-02-23 11:10:10")
myrole = role("admin")
pwd = password([]byte("pass"))
jsVal = []byte(`{"Name":"test","Val":"test"}`)
js = JSON(jsVal)
esVal = []byte(`{"Name":"test","Val":"test"}`)
es = ExampleStruct{Name: "test", Val: "test"}
tt = now.MustParse("2020-02-23 11:10:10")
myrole = role("admin")
pwd = password("pass")
jsVal = []byte(`{"Name":"test","Val":"test"}`)
js = JSON(jsVal)
esVal = []byte(`{"Name":"test","Val":"test"}`)
es = ExampleStruct{Name: "test", Val: "test"}
intVal intType = 1
floatVal floatType = 1.23
)
results := []struct {
@ -57,43 +61,67 @@ func TestExplainSQL(t *testing.T) {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)",
NumericRegexp: regexp.MustCompile(`\$(\d+)`),
Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)",
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`,
},
}

View File

@ -1,22 +1,22 @@
package gorm
import (
"reflect"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
// Migrator returns migrator
func (db *DB) Migrator() Migrator {
tx := db.getInstance()
// apply scopes to migrator
for len(db.Statement.scopes) > 0 {
scopes := db.Statement.scopes
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
}
for len(tx.Statement.scopes) > 0 {
tx = tx.executeScopes()
}
return db.Dialector.Migrator(db.Session(&Session{}))
return tx.Dialector.Migrator(tx.Session(&Session{}))
}
// AutoMigrate run auto migration for given models
@ -26,19 +26,45 @@ func (db *DB) AutoMigrate(dst ...interface{}) error {
// ViewOption view option
type ViewOption struct {
Replace bool
CheckOption string
Query *DB
Replace bool // If true, exec `CREATE`. If false, exec `CREATE OR REPLACE`
CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION`
Query *DB // required subquery.
}
// ColumnType column type interface
type ColumnType interface {
Name() string
DatabaseTypeName() string
DatabaseTypeName() string // varchar
ColumnType() (columnType string, ok bool) // varchar(64)
PrimaryKey() (isPrimaryKey bool, ok bool)
AutoIncrement() (isAutoIncrement bool, ok bool)
Length() (length int64, ok bool)
DecimalSize() (precision int64, scale int64, ok bool)
Nullable() (nullable bool, ok bool)
Unique() (unique bool, ok bool)
ScanType() reflect.Type
Comment() (value string, ok bool)
DefaultValue() (value string, ok bool)
}
type Index interface {
Table() string
Name() string
Columns() []string
PrimaryKey() (isPrimaryKey bool, ok bool)
Unique() (unique bool, ok bool)
Option() string
}
// TableType table type interface
type TableType interface {
Schema() string
Name() string
Type() string
Comment() (comment string, ok bool)
}
// Migrator migrator interface
type Migrator interface {
// AutoMigrate
AutoMigrate(dst ...interface{}) error
@ -46,18 +72,23 @@ type Migrator interface {
// Database
CurrentDatabase() string
FullDataTypeOf(*schema.Field) clause.Expr
GetTypeAliases(databaseTypeName string) []string
// Tables
CreateTable(dst ...interface{}) error
DropTable(dst ...interface{}) error
HasTable(dst interface{}) bool
RenameTable(oldName, newName interface{}) error
GetTables() (tableList []string, err error)
TableType(dst interface{}) (TableType, error)
// Columns
AddColumn(dst interface{}, field string) error
DropColumn(dst interface{}, field string) error
AlterColumn(dst interface{}, field string) error
MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
HasColumn(dst interface{}, field string) bool
RenameColumn(dst interface{}, oldName, field string) error
ColumnTypes(dst interface{}) ([]ColumnType, error)
@ -76,4 +107,5 @@ type Migrator interface {
DropIndex(dst interface{}, name string) error
HasIndex(dst interface{}, name string) bool
RenameIndex(dst interface{}, oldName, newName string) error
GetIndexes(dst interface{}) ([]Index, error)
}

107
migrator/column_type.go Normal file
View File

@ -0,0 +1,107 @@
package migrator
import (
"database/sql"
"reflect"
)
// ColumnType column type implements ColumnType interface
type ColumnType struct {
SQLColumnType *sql.ColumnType
NameValue sql.NullString
DataTypeValue sql.NullString
ColumnTypeValue sql.NullString
PrimaryKeyValue sql.NullBool
UniqueValue sql.NullBool
AutoIncrementValue sql.NullBool
LengthValue sql.NullInt64
DecimalSizeValue sql.NullInt64
ScaleValue sql.NullInt64
NullableValue sql.NullBool
ScanTypeValue reflect.Type
CommentValue sql.NullString
DefaultValueValue sql.NullString
}
// Name returns the name or alias of the column.
func (ct ColumnType) Name() string {
if ct.NameValue.Valid {
return ct.NameValue.String
}
return ct.SQLColumnType.Name()
}
// DatabaseTypeName returns the database system name of the column type. If an empty
// string is returned, then the driver type name is not supported.
// Consult your driver documentation for a list of driver data types. Length specifiers
// are not included.
// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
// "INT", and "BIGINT".
func (ct ColumnType) DatabaseTypeName() string {
if ct.DataTypeValue.Valid {
return ct.DataTypeValue.String
}
return ct.SQLColumnType.DatabaseTypeName()
}
// ColumnType returns the database type of the column. like `varchar(16)`
func (ct ColumnType) ColumnType() (columnType string, ok bool) {
return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid
}
// PrimaryKey returns the column is primary key or not.
func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) {
return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid
}
// AutoIncrement returns the column is auto increment or not.
func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) {
return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid
}
// Length returns the column type length for variable length column types
func (ct ColumnType) Length() (length int64, ok bool) {
if ct.LengthValue.Valid {
return ct.LengthValue.Int64, true
}
return ct.SQLColumnType.Length()
}
// DecimalSize returns the scale and precision of a decimal type.
func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) {
if ct.DecimalSizeValue.Valid {
return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true
}
return ct.SQLColumnType.DecimalSize()
}
// Nullable reports whether the column may be null.
func (ct ColumnType) Nullable() (nullable bool, ok bool) {
if ct.NullableValue.Valid {
return ct.NullableValue.Bool, true
}
return ct.SQLColumnType.Nullable()
}
// Unique reports whether the column may be unique.
func (ct ColumnType) Unique() (unique bool, ok bool) {
return ct.UniqueValue.Bool, ct.UniqueValue.Valid
}
// ScanType returns a Go type suitable for scanning into using Rows.Scan.
func (ct ColumnType) ScanType() reflect.Type {
if ct.ScanTypeValue != nil {
return ct.ScanTypeValue
}
return ct.SQLColumnType.ScanType()
}
// Comment returns the comment of current column.
func (ct ColumnType) Comment() (value string, ok bool) {
return ct.CommentValue.String, ct.CommentValue.Valid
}
// DefaultValue returns the default value of current column.
func (ct ColumnType) DefaultValue() (value string, ok bool) {
return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid
}

43
migrator/index.go Normal file
View File

@ -0,0 +1,43 @@
package migrator
import "database/sql"
// Index implements gorm.Index interface
type Index struct {
TableName string
NameValue string
ColumnList []string
PrimaryKeyValue sql.NullBool
UniqueValue sql.NullBool
OptionValue string
}
// Table return the table name of the index.
func (idx Index) Table() string {
return idx.TableName
}
// Name return the name of the index.
func (idx Index) Name() string {
return idx.NameValue
}
// Columns return the columns of the index
func (idx Index) Columns() []string {
return idx.ColumnList
}
// PrimaryKey returns the index is primary key or not.
func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) {
return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid
}
// Unique returns whether the index is unique or not.
func (idx Index) Unique() (unique bool, ok bool) {
return idx.UniqueValue.Bool, idx.UniqueValue.Valid
}
// Option return the optional attribute of the index
func (idx Index) Option() string {
return idx.OptionValue
}

File diff suppressed because it is too large Load Diff

33
migrator/table_type.go Normal file
View File

@ -0,0 +1,33 @@
package migrator
import (
"database/sql"
)
// TableType table type implements TableType interface
type TableType struct {
SchemaValue string
NameValue string
TypeValue string
CommentValue sql.NullString
}
// Schema returns the schema of the table.
func (ct TableType) Schema() string {
return ct.SchemaValue
}
// Name returns the name of the table.
func (ct TableType) Name() string {
return ct.NameValue
}
// Type returns the type of the table.
func (ct TableType) Type() string {
return ct.TypeValue
}
// Comment returns the comment of current table.
func (ct TableType) Comment() (comment string, ok bool) {
return ct.CommentValue.String, ct.CommentValue.Valid
}

View File

@ -4,9 +4,10 @@ import "time"
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
// It may be embedded into your model or you may build your own model without it
// type User struct {
// gorm.Model
// }
//
// type User struct {
// gorm.Model
// }
type Model struct {
ID uint `gorm:"primarykey"`
CreatedAt time.Time

View File

@ -3,58 +3,86 @@ package gorm
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"reflect"
"sync"
"time"
"gorm.io/gorm/internal/stmt_store"
)
type Stmt struct {
*sql.Stmt
Transaction bool
}
type PreparedStmtDB struct {
Stmts map[string]Stmt
PreparedSQL []string
Mux *sync.RWMutex
Stmts stmt_store.Store
Mux *sync.RWMutex
ConnPool
}
func (db *PreparedStmtDB) Close() {
db.Mux.Lock()
for _, query := range db.PreparedSQL {
if stmt, ok := db.Stmts[query]; ok {
delete(db.Stmts, query)
stmt.Close()
}
// NewPreparedStmtDB creates and initializes a new instance of PreparedStmtDB.
//
// Parameters:
// - connPool: A connection pool that implements the ConnPool interface, used for managing database connections.
// - maxSize: The maximum number of prepared statements that can be stored in the statement store.
// - ttl: The time-to-live duration for each prepared statement in the store. Statements older than this duration will be automatically removed.
//
// Returns:
// - A pointer to a PreparedStmtDB instance, which manages prepared statements using the provided connection pool and configuration.
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
return &PreparedStmtDB{
ConnPool: connPool, // Assigns the provided connection pool to manage database connections.
Stmts: stmt_store.New(maxSize, ttl), // Initializes a new statement store with the specified maximum size and TTL.
Mux: &sync.RWMutex{}, // Sets up a read-write mutex for synchronizing access to the statement store.
}
db.Mux.Unlock()
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
// GetDBConn returns the underlying *sql.DB connection
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil
}
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}
return nil, ErrInvalidDB
}
// Close closes all prepared statements in the store
func (db *PreparedStmtDB) Close() {
db.Mux.Lock()
defer db.Mux.Unlock()
for _, key := range db.Stmts.Keys() {
db.Stmts.Delete(key)
}
}
// Reset Deprecated use Close instead
func (db *PreparedStmtDB) Reset() {
db.Close()
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) {
db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock()
return stmt, nil
if db.Stmts != nil {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock()
return stmt, stmt.Error()
}
}
db.Mux.RUnlock()
// retry
db.Mux.Lock()
// double check
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock()
return stmt, nil
} else if ok {
stmt.Close()
if db.Stmts != nil {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock()
return stmt, stmt.Error()
}
}
stmt, err := conn.PrepareContext(ctx, query)
if err == nil {
db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction}
db.PreparedSQL = append(db.PreparedSQL, query)
}
db.Mux.Unlock()
return db.Stmts[query], err
return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux)
}
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
@ -62,6 +90,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
}
beginner, ok := db.ConnPool.(ConnPoolBeginner)
if !ok {
return nil, ErrInvalidTransaction
}
connPool, err := beginner.BeginTx(ctx, opt)
if err != nil {
return nil, err
}
if tx, ok := connPool.(Tx); ok {
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
}
return nil, ErrInvalidTransaction
}
@ -69,11 +110,8 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
result, err = stmt.ExecContext(ctx, args...)
if err != nil {
db.Mux.Lock()
stmt.Close()
delete(db.Stmts, query)
db.Mux.Unlock()
if errors.Is(err, driver.ErrBadConn) {
db.Stmts.Delete(query)
}
}
return result, err
@ -83,11 +121,8 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
if err != nil {
db.Mux.Lock()
stmt.Close()
delete(db.Stmts, query)
db.Mux.Unlock()
if errors.Is(err, driver.ErrBadConn) {
db.Stmts.Delete(query)
}
}
return rows, err
@ -101,20 +136,32 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
return &sql.Row{}
}
func (db *PreparedStmtDB) Ping() error {
conn, err := db.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}
type PreparedStmtTX struct {
*sql.Tx
Tx
PreparedStmtDB *PreparedStmtDB
}
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
return db.PreparedStmtDB.GetDBConn()
}
func (tx *PreparedStmtTX) Commit() error {
if tx.Tx != nil {
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Commit()
}
return ErrInvalidTransaction
}
func (tx *PreparedStmtTX) Rollback() error {
if tx.Tx != nil {
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Rollback()
}
return ErrInvalidTransaction
@ -124,11 +171,8 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.Mux.Unlock()
if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Stmts.Delete(query)
}
}
return result, err
@ -137,12 +181,9 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.Mux.Unlock()
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Stmts.Delete(query)
}
}
return rows, err
@ -155,3 +196,11 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg
}
return &sql.Row{}
}
func (tx *PreparedStmtTX) Ping() error {
conn, err := tx.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}

380
scan.go
View File

@ -8,13 +8,15 @@ import (
"time"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
// prepareValues prepare values slice
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
if db.Statement.Schema != nil {
for idx, name := range columns {
if field := db.Statement.Schema.LookUpField(name); field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
values[idx] = reflect.New(reflect.PointerTo(field.FieldType)).Interface()
continue
}
values[idx] = new(interface{})
@ -22,7 +24,7 @@ func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType,
} else if len(columnTypes) > 0 {
for idx, columnType := range columnTypes {
if columnType.ScanType() != nil {
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
values[idx] = reflect.New(reflect.PointerTo(columnType.ScanType())).Interface()
} else {
values[idx] = new(interface{})
}
@ -49,9 +51,96 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
}
}
func Scan(rows *sql.Rows, db *DB, initialized bool) {
columns, _ := rows.Columns()
values := make([]interface{}, len(columns))
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
for idx, field := range fields {
if field != nil {
values[idx] = field.NewValuePool.Get()
} else if len(fields) == 1 {
if reflectValue.CanAddr() {
values[idx] = reflectValue.Addr().Interface()
} else {
values[idx] = reflectValue.Interface()
}
}
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
joinedNestedSchemaMap := make(map[string]interface{})
for idx, field := range fields {
if field == nil {
continue
}
if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
} else { // joinFields count is larger than 2 when using join
var isNilPtrValue bool
var relValue reflect.Value
// does not contain raw dbname
nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
// current reflect value
currentReflectValue := reflectValue
fullRels := make([]string, 0, len(nestedJoinSchemas))
for _, joinSchema := range nestedJoinSchemas {
fullRels = append(fullRels, joinSchema.Name)
relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
if relValue.Kind() == reflect.Ptr {
fullRelsName := utils.JoinNestedRelationNames(fullRels)
// same nested structure
if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
isNilPtrValue = true
break
}
relValue.Set(reflect.New(relValue.Type().Elem()))
joinedNestedSchemaMap[fullRelsName] = nil
}
}
currentReflectValue = relValue
}
if !isNilPtrValue { // ignore if value is nil
f := joinFields[idx][len(joinFields[idx])-1]
db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
}
}
// release data to pool
field.NewValuePool.Put(values[idx])
}
}
// ScanMode scan data mode
type ScanMode uint8
// scan modes
const (
ScanInitialized ScanMode = 1 << 0 // 1
ScanUpdate ScanMode = 1 << 1 // 2
ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
)
// Scan scan rows into db statement
func Scan(rows Rows, db *DB, mode ScanMode) {
var (
columns, _ = rows.Columns()
values = make([]interface{}, len(columns))
initialized = mode&ScanInitialized != 0
update = mode&ScanUpdate != 0
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
)
if len(db.Statement.ColumnMapping) > 0 {
for i, column := range columns {
v, ok := db.Statement.ColumnMapping[column]
if ok {
columns[i] = v
}
}
}
db.RowsAffected = 0
switch dest := db.Statement.Dest.(type) {
@ -66,6 +155,9 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
mapValue, ok := dest.(map[string]interface{})
if !ok {
if v, ok := dest.(*map[string]interface{}); ok {
if *v == nil {
*v = map[string]interface{}{}
}
mapValue = *v
}
}
@ -96,152 +188,182 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
db.AddError(rows.Scan(dest))
}
default:
Schema := db.Statement.Schema
var (
fields = make([]*schema.Field, len(columns))
joinFields [][]*schema.Field
sch = db.Statement.Schema
reflectValue = db.Statement.ReflectValue
)
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
reflectValueType = db.Statement.ReflectValue.Type().Elem()
isPtr = reflectValueType.Kind() == reflect.Ptr
fields = make([]*schema.Field, len(columns))
joinFields [][2]*schema.Field
)
if reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem()
}
if isPtr {
reflectValueType = reflectValueType.Elem()
reflectValueType := reflectValue.Type()
switch reflectValueType.Kind() {
case reflect.Array, reflect.Slice:
reflectValueType = reflectValueType.Elem()
}
isPtr := reflectValueType.Kind() == reflect.Ptr
if isPtr {
reflectValueType = reflectValueType.Elem()
}
if sch != nil {
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20))
if Schema != nil {
if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
if len(columns) == 1 {
// Is Pluck
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
sch = nil
}
}
// Not Pluck
if sch != nil {
matchedFieldCount := make(map[string]int, len(columns))
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
if field := sch.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
if count, ok := matchedFieldCount[column]; ok {
// handle duplicate fields
for _, selectField := range sch.Fields {
if selectField.DBName == column && selectField.Readable {
if count == 0 {
matchedFieldCount[column]++
fields[idx] = selectField
break
}
count--
}
}
} else {
matchedFieldCount[column] = 1
}
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1])
for _, join := range db.Statement.Joins {
if join.Alias == aliasName {
names = append(strings.Split(join.Name, "."), names[len(names)-1])
break
}
}
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
subNameCount := len(names)
// nested relation fields
relFields := make([]*schema.Field, 0, subNameCount-1)
relFields = append(relFields, rel.Field)
for _, name := range names[1 : subNameCount-1] {
rel = rel.FieldSchema.Relationships.Relations[name]
relFields = append(relFields, rel.Field)
}
// latest name is raw dbname
dbName := names[subNameCount-1]
if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
joinFields = make([][]*schema.Field, len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
relFields = append(relFields, field)
joinFields[idx] = relFields
continue
}
}
values[idx] = &sql.RawBytes{}
var val interface{}
values[idx] = &val
} else {
values[idx] = &sql.RawBytes{}
}
}
}
// pluck values into slice of data
isPluck := false
if len(fields) == 1 {
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
isPluck = true
}
}
for initialized || rows.Next() {
initialized = false
db.RowsAffected++
elem := reflect.New(reflectValueType)
if isPluck {
db.AddError(rows.Scan(elem.Interface()))
} else {
for idx, field := range fields {
if field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
}
}
db.AddError(rows.Scan(values...))
for idx, field := range fields {
if len(joinFields) != 0 && joinFields[idx][0] != nil {
value := reflect.ValueOf(values[idx]).Elem()
relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
field.Set(relValue, values[idx])
} else if field != nil {
field.Set(elem, values[idx])
}
}
}
if isPtr {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem))
} else {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem()))
}
}
case reflect.Struct, reflect.Ptr:
if db.Statement.ReflectValue.Type() != Schema.ModelType {
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
if initialized || rows.Next() {
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
field.Set(db.Statement.ReflectValue, values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
value := reflect.ValueOf(values[idx]).Elem()
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
field.Set(relValue, values[idx])
}
}
var val interface{}
values[idx] = &val
}
}
}
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
elem reflect.Value
isArrayKind = reflectValue.Kind() == reflect.Array
)
if !update || reflectValue.Len() == 0 {
update = false
if isArrayKind {
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
} else {
// if the slice cap is externally initialized, the externally initialized slice is directly used here
if reflectValue.Cap() == 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
} else {
reflectValue.SetLen(0)
db.Statement.ReflectValue.Set(reflectValue)
}
}
}
for initialized || rows.Next() {
BEGIN:
initialized = false
if update {
if int(db.RowsAffected) >= reflectValue.Len() {
return
}
elem = reflectValue.Index(int(db.RowsAffected))
if onConflictDonothing {
for _, field := range fields {
if _, ok := field.ValueOf(db.Statement.Context, elem); !ok {
db.RowsAffected++
goto BEGIN
}
}
}
} else {
elem = reflect.New(reflectValueType)
}
db.scanIntoStruct(rows, elem, values, fields, joinFields)
if !update {
if !isPtr {
elem = elem.Elem()
}
if isArrayKind {
if reflectValue.Len() >= int(db.RowsAffected) {
reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
}
} else {
reflectValue = reflect.Append(reflectValue, elem)
}
}
}
if !update {
db.Statement.ReflectValue.Set(reflectValue)
}
case reflect.Struct, reflect.Ptr:
if initialized || rows.Next() {
if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct {
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
}
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
}
default:
db.AddError(rows.Scan(dest))
}
}
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound {
if err := rows.Err(); err != nil && err != db.Error {
db.AddError(err)
}
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
db.AddError(ErrRecordNotFound)
}
}

View File

@ -9,8 +9,7 @@ import (
"gorm.io/gorm/schema"
)
type UserWithCallback struct {
}
type UserWithCallback struct{}
func (UserWithCallback) BeforeSave(*gorm.DB) error {
return nil

View File

@ -1,37 +0,0 @@
package schema
import (
"regexp"
"strings"
)
var (
// reg match english letters and midline
regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$")
)
type Check struct {
Name string
Constraint string // length(phone) >= 10
*Field
}
// ParseCheckConstraints parse schema check constraints
func (schema *Schema) ParseCheckConstraints() map[string]Check {
var checks = map[string]Check{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = Check{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}

66
schema/constraint.go Normal file
View File

@ -0,0 +1,66 @@
package schema
import (
"regexp"
"strings"
"gorm.io/gorm/clause"
)
// reg match english letters and midline
var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`)
type CheckConstraint struct {
Name string
Constraint string // length(phone) >= 10
*Field
}
func (chk *CheckConstraint) GetName() string { return chk.Name }
func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
}
// ParseCheckConstraints parse schema check constraints
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
checks := map[string]CheckConstraint{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}
type UniqueConstraint struct {
Name string
Field *Field
}
func (uni *UniqueConstraint) GetName() string { return uni.Name }
func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
}
// ParseUniqueConstraints parse schema unique constraints
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
uniques := make(map[string]UniqueConstraint)
for _, field := range schema.Fields {
if field.Unique {
name := schema.namer.UniqueName(schema.Table, field.DBName)
uniques[name] = UniqueConstraint{Name: name, Field: field}
}
}
return uniques
}

View File

@ -6,6 +6,7 @@ import (
"testing"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
type UserCheck struct {
@ -20,7 +21,7 @@ func TestParseCheck(t *testing.T) {
t.Fatalf("failed to parse user check, got error %v", err)
}
results := map[string]schema.Check{
results := map[string]schema.CheckConstraint{
"name_checker": {
Name: "name_checker",
Constraint: "name <> 'jinzhu'",
@ -53,3 +54,31 @@ func TestParseCheck(t *testing.T) {
}
}
}
func TestParseUniqueConstraints(t *testing.T) {
type UserUnique struct {
Name1 string `gorm:"unique"`
Name2 string `gorm:"uniqueIndex"`
}
user, err := schema.Parse(&UserUnique{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user unique, got error %v", err)
}
constraints := user.ParseUniqueConstraints()
results := map[string]schema.UniqueConstraint{
"uni_user_uniques_name1": {
Name: "uni_user_uniques_name1",
Field: &schema.Field{Name: "Name1", Unique: true},
},
}
for k, result := range results {
v, ok := constraints[k]
if !ok {
t.Errorf("Failed to found unique constraint %v from parsed constraints %+v", k, constraints)
}
tests.AssertObjEqual(t, result, v, "Name")
tests.AssertObjEqual(t, result.Field, v.Field, "Name", "Unique", "UniqueIndex")
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,7 @@
package schema_test
import (
"context"
"database/sql"
"reflect"
"sync"
@ -57,7 +58,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
}
for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
@ -80,7 +81,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
}
for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
@ -132,7 +133,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
}
for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
@ -151,7 +152,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
}
for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
@ -202,7 +203,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
}
for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
@ -219,7 +220,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
}
for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
@ -235,6 +236,7 @@ type UserWithPermissionControl struct {
Name5 string `gorm:"<-:update"`
Name6 string `gorm:"<-:create,update"`
Name7 string `gorm:"->:false;<-:create,update"`
Name8 string `gorm:"->;-:migration"`
}
func TestParseFieldWithPermission(t *testing.T) {
@ -243,7 +245,7 @@ func TestParseFieldWithPermission(t *testing.T) {
t.Fatalf("Failed to parse user with permission, got error %v", err)
}
fields := []schema.Field{
fields := []*schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true, AutoIncrement: true},
{Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false},
{Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true},
@ -252,9 +254,81 @@ func TestParseFieldWithPermission(t *testing.T) {
{Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: true},
{Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: true},
{Name: "Name7", DBName: "name7", BindNames: []string{"Name7"}, DataType: schema.String, Tag: `gorm:"->:false;<-:create,update"`, Creatable: true, Updatable: true, Readable: false},
{Name: "Name8", DBName: "name8", BindNames: []string{"Name8"}, DataType: schema.String, Tag: `gorm:"->;-:migration"`, Creatable: false, Updatable: false, Readable: true, IgnoreMigration: true},
}
for _, f := range fields {
checkSchemaField(t, user, &f, func(f *schema.Field) {})
checkSchemaField(t, user, f, func(f *schema.Field) {})
}
}
type (
ID int64
INT int
INT8 int8
INT16 int16
INT32 int32
INT64 int64
UINT uint
UINT8 uint8
UINT16 uint16
UINT32 uint32
UINT64 uint64
FLOAT32 float32
FLOAT64 float64
BOOL bool
STRING string
TIME time.Time
BYTES []byte
TypeAlias struct {
ID
INT `gorm:"column:fint"`
INT8 `gorm:"column:fint8"`
INT16 `gorm:"column:fint16"`
INT32 `gorm:"column:fint32"`
INT64 `gorm:"column:fint64"`
UINT `gorm:"column:fuint"`
UINT8 `gorm:"column:fuint8"`
UINT16 `gorm:"column:fuint16"`
UINT32 `gorm:"column:fuint32"`
UINT64 `gorm:"column:fuint64"`
FLOAT32 `gorm:"column:ffloat32"`
FLOAT64 `gorm:"column:ffloat64"`
BOOL `gorm:"column:fbool"`
STRING `gorm:"column:fstring"`
TIME `gorm:"column:ftime"`
BYTES `gorm:"column:fbytes"`
}
)
func TestTypeAliasField(t *testing.T) {
alias, err := schema.Parse(&TypeAlias{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse TypeAlias with permission, got error %v", err)
}
fields := []*schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, PrimaryKey: true, HasDefaultValue: true, AutoIncrement: true},
{Name: "INT", DBName: "fint", BindNames: []string{"INT"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint"`},
{Name: "INT8", DBName: "fint8", BindNames: []string{"INT8"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fint8"`},
{Name: "INT16", DBName: "fint16", BindNames: []string{"INT16"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fint16"`},
{Name: "INT32", DBName: "fint32", BindNames: []string{"INT32"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fint32"`},
{Name: "INT64", DBName: "fint64", BindNames: []string{"INT64"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint64"`},
{Name: "UINT", DBName: "fuint", BindNames: []string{"UINT"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint"`},
{Name: "UINT8", DBName: "fuint8", BindNames: []string{"UINT8"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fuint8"`},
{Name: "UINT16", DBName: "fuint16", BindNames: []string{"UINT16"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fuint16"`},
{Name: "UINT32", DBName: "fuint32", BindNames: []string{"UINT32"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fuint32"`},
{Name: "UINT64", DBName: "fuint64", BindNames: []string{"UINT64"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint64"`},
{Name: "FLOAT32", DBName: "ffloat32", BindNames: []string{"FLOAT32"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:ffloat32"`},
{Name: "FLOAT64", DBName: "ffloat64", BindNames: []string{"FLOAT64"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:ffloat64"`},
{Name: "BOOL", DBName: "fbool", BindNames: []string{"BOOL"}, DataType: schema.Bool, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbool"`},
{Name: "STRING", DBName: "fstring", BindNames: []string{"STRING"}, DataType: schema.String, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fstring"`},
{Name: "TIME", DBName: "ftime", BindNames: []string{"TIME"}, DataType: schema.Time, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:ftime"`},
{Name: "BYTES", DBName: "fbytes", BindNames: []string{"BYTES"}, DataType: schema.Bytes, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbytes"`},
}
for _, f := range fields {
checkSchemaField(t, alias, f, func(f *schema.Field) {})
}
}

View File

@ -1,6 +1,7 @@
package schema
import (
"fmt"
"sort"
"strconv"
"strings"
@ -12,8 +13,8 @@ type Index struct {
Type string // btree, hash, gist, spgist, gin, and brin
Where string
Comment string
Option string // WITH PARSER parser_name
Fields []IndexOption
Option string // WITH PARSER parser_name
Fields []IndexOption // Note: IndexOption's Field maybe the same
}
type IndexOption struct {
@ -22,17 +23,28 @@ type IndexOption struct {
Sort string // DESC, ASC
Collate string
Length int
priority int
Priority int
}
// ParseIndexes parse schema indexes
func (schema *Schema) ParseIndexes() map[string]Index {
var indexes = map[string]Index{}
func (schema *Schema) ParseIndexes() []*Index {
indexesByName := map[string]*Index{}
indexes := []*Index{}
for _, field := range schema.Fields {
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
for _, index := range parseFieldIndexes(field) {
idx := indexes[index.Name]
fieldIndexes, err := parseFieldIndexes(field)
if err != nil {
schema.err = err
break
}
for _, index := range fieldIndexes {
idx := indexesByName[index.Name]
if idx == nil {
idx = &Index{Name: index.Name}
indexesByName[index.Name] = idx
indexes = append(indexes, idx)
}
idx.Name = index.Name
if idx.Class == "" {
idx.Class = index.Class
@ -52,14 +64,16 @@ func (schema *Schema) ParseIndexes() map[string]Index {
idx.Fields = append(idx.Fields, index.Fields...)
sort.Slice(idx.Fields, func(i, j int) bool {
return idx.Fields[i].priority < idx.Fields[j].priority
return idx.Fields[i].Priority < idx.Fields[j].Priority
})
indexes[index.Name] = idx
}
}
}
for _, index := range indexes {
if index.Class == "UNIQUE" && len(index.Fields) == 1 {
index.Fields[0].Field.UniqueIndex = index.Name
}
}
return indexes
}
@ -68,12 +82,12 @@ func (schema *Schema) LookIndex(name string) *Index {
indexes := schema.ParseIndexes()
for _, index := range indexes {
if index.Name == name {
return &index
return index
}
for _, field := range index.Fields {
if field.Name == name {
return &index
return index
}
}
}
@ -82,30 +96,41 @@ func (schema *Schema) LookIndex(name string) *Index {
return nil
}
func parseFieldIndexes(field *Field) (indexes []Index) {
func parseFieldIndexes(field *Field) (indexes []Index, err error) {
for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
if value != "" {
v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0]))
if k == "INDEX" || k == "UNIQUEINDEX" {
var (
name string
tag = strings.Join(v[1:], ":")
idx = strings.Index(tag, ",")
settings = ParseTagSetting(tag, ",")
length, _ = strconv.Atoi(settings["LENGTH"])
name string
tag = strings.Join(v[1:], ":")
idx = strings.IndexByte(tag, ',')
tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",")
settings = ParseTagSetting(tagSetting, ",")
length, _ = strconv.Atoi(settings["LENGTH"])
)
if idx == -1 {
idx = len(tag)
}
if idx != -1 {
name = tag[0:idx]
}
name = tag[0:idx]
if name == "" {
name = field.Schema.namer.IndexName(field.Schema.Table, field.Name)
subName := field.Name
const key = "COMPOSITE"
if composite, found := settings[key]; found {
if len(composite) == 0 || composite == key {
err = fmt.Errorf(
"the composite tag of %s.%s cannot be empty",
field.Schema.Name,
field.Name)
return
}
subName = composite
}
name = field.Schema.namer.IndexName(
field.Schema.Table, subName)
}
if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" {
@ -130,12 +155,13 @@ func parseFieldIndexes(field *Field) (indexes []Index) {
Sort: settings["SORT"],
Collate: settings["COLLATE"],
Length: length,
priority: priority,
Priority: priority,
}},
})
}
}
}
err = nil
return
}

View File

@ -1,11 +1,11 @@
package schema_test
import (
"reflect"
"sync"
"testing"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
type UserIndex struct {
@ -18,6 +18,41 @@ type UserIndex struct {
Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"`
OID int64 `gorm:"index:idx_id;index:idx_oid,unique"`
MemberNumber string `gorm:"index:idx_id,priority:1"`
Name7 string `gorm:"index:type"`
Name8 string `gorm:"index:,length:10;index:,collate:utf8"`
CompName1 string `gorm:"index:,unique,composite:idx_compname_1,option:NULLS NOT DISTINCT;not null"`
CompName2 string `gorm:"index:,composite:idx_compname_1"`
// Composite Index: Flattened structure.
Data0A string `gorm:"index:,composite:comp_id0"`
Data0B string `gorm:"index:,composite:comp_id0"`
// Composite Index: Nested structure.
Data1A string `gorm:"index:,composite:comp_id1"`
CompIdxLevel1C
// Composite Index: Unique and priority.
Data2A string `gorm:"index:,unique,composite:comp_id2,priority:2"`
CompIdxLevel2C
}
type CompIdxLevel1C struct {
CompIdxLevel1B
Data1C string `gorm:"index:,composite:comp_id1"`
}
type CompIdxLevel1B struct {
Data1B string `gorm:"index:,composite:comp_id1"`
}
type CompIdxLevel2C struct {
CompIdxLevel2B
Data2C string `gorm:"index:,unique,composite:comp_id2,priority:1"`
}
type CompIdxLevel2B struct {
Data2B string `gorm:"index:,unique,composite:comp_id2,priority:3"`
}
func TestParseIndex(t *testing.T) {
@ -26,17 +61,17 @@ func TestParseIndex(t *testing.T) {
t.Fatalf("failed to parse user index, got error %v", err)
}
results := map[string]schema.Index{
"idx_user_indices_name": {
results := []*schema.Index{
{
Name: "idx_user_indices_name",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name"}}},
},
"idx_name": {
{
Name: "idx_name",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2"}}},
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: "idx_name"}}},
},
"idx_user_indices_name3": {
{
Name: "idx_user_indices_name3",
Type: "btree",
Where: "name3 != 'jinzhu'",
@ -47,19 +82,19 @@ func TestParseIndex(t *testing.T) {
Length: 10,
}},
},
"idx_user_indices_name4": {
{
Name: "idx_user_indices_name4",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4"}}},
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: "idx_user_indices_name4"}}},
},
"idx_user_indices_name5": {
{
Name: "idx_user_indices_name5",
Class: "FULLTEXT",
Comment: "hello , world",
Where: "age > 10",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name5"}}},
},
"profile": {
{
Name: "profile",
Comment: "hello , world",
Where: "age > 10",
@ -69,48 +104,172 @@ func TestParseIndex(t *testing.T) {
Expression: "ABS(age)",
}},
},
"idx_id": {
{
Name: "idx_id",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID"}}},
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
},
"idx_oid": {
{
Name: "idx_oid",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}},
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
},
{
Name: "type",
Type: "",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}},
},
{
Name: "idx_user_indices_name8",
Type: "",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "Name8"}, Length: 10},
// Note: Duplicate Columns
{Field: &schema.Field{Name: "Name8"}, Collate: "utf8"},
},
},
{
Class: "UNIQUE",
Name: "idx_user_indices_idx_compname_1",
Option: "NULLS NOT DISTINCT",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "CompName1", NotNull: true}},
{Field: &schema.Field{Name: "CompName2"}},
},
},
{
Name: "idx_user_indices_comp_id0",
Type: "",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data0A"},
}, {
Field: &schema.Field{Name: "Data0B"},
}},
},
{
Name: "idx_user_indices_comp_id1",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data1A"},
}, {
Field: &schema.Field{Name: "Data1B"},
}, {
Field: &schema.Field{Name: "Data1C"},
}},
},
{
Name: "idx_user_indices_comp_id2",
Class: "UNIQUE",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data2C"},
}, {
Field: &schema.Field{Name: "Data2A"},
}, {
Field: &schema.Field{Name: "Data2B"},
}},
},
}
indices := user.ParseIndexes()
CheckIndices(t, results, user.ParseIndexes())
}
for k, result := range results {
v, ok := indices[k]
if !ok {
t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices)
}
func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) {
type IndexTest struct {
FieldA string `gorm:"unique;index"` // unique and index
FieldB string `gorm:"unique"` // unique
for _, name := range []string{"Name", "Class", "Type", "Where", "Comment", "Option"} {
if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() {
t.Errorf(
"index %v %v should equal, expects %v, got %v",
k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(),
)
FieldC string `gorm:"index:,unique"` // uniqueIndex
FieldD string `gorm:"uniqueIndex;index"` // uniqueIndex and index
FieldE1 string `gorm:"uniqueIndex:uniq_field_e1_e2"` // mul uniqueIndex
FieldE2 string `gorm:"uniqueIndex:uniq_field_e1_e2"`
FieldF1 string `gorm:"uniqueIndex:uniq_field_f1_f2;index"` // mul uniqueIndex and index
FieldF2 string `gorm:"uniqueIndex:uniq_field_f1_f2;"`
FieldG string `gorm:"unique;uniqueIndex"` // unique and uniqueIndex
FieldH1 string `gorm:"unique;uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
FieldH2 string `gorm:"uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
}
indexSchema, err := schema.Parse(&IndexTest{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user index, got error %v", err)
}
indices := indexSchema.ParseIndexes()
expectedIndices := []*schema.Index{
{
Name: "idx_index_tests_field_a",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldA", Unique: true}}},
},
{
Name: "idx_index_tests_field_c",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: "idx_index_tests_field_c"}}},
},
{
Name: "idx_index_tests_field_d",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldD"}},
// Note: Duplicate Columns
{Field: &schema.Field{Name: "FieldD"}},
},
},
{
Name: "uniq_field_e1_e2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldE1"}},
{Field: &schema.Field{Name: "FieldE2"}},
},
},
{
Name: "uniq_field_f1_f2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldF1"}},
{Field: &schema.Field{Name: "FieldF2"}},
},
},
{
Name: "idx_index_tests_field_f1",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldF1"}}},
},
{
Name: "idx_index_tests_field_g",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: "idx_index_tests_field_g"}}},
},
{
Name: "uniq_field_h1_h2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldH1", Unique: true}},
{Field: &schema.Field{Name: "FieldH2"}},
},
},
}
CheckIndices(t, expectedIndices, indices)
}
func CheckIndices(t *testing.T, expected, actual []*schema.Index) {
if len(expected) != len(actual) {
t.Errorf("expected %d indices, but got %d", len(expected), len(actual))
return
}
for i, ei := range expected {
t.Run(ei.Name, func(t *testing.T) {
ai := actual[i]
tests.AssertObjEqual(t, ai, ei, "Name", "Class", "Type", "Where", "Comment", "Option")
if len(ei.Fields) != len(ai.Fields) {
t.Errorf("expected index %q field length is %d but actual %d", ei.Name, len(ei.Fields), len(ai.Fields))
return
}
}
for idx, ef := range result.Fields {
rf := v.Fields[idx]
if rf.Field.Name != ef.Field.Name {
t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name)
for i, ef := range ei.Fields {
af := ai.Fields[i]
tests.AssertObjEqual(t, af, ef, "Name", "Unique", "UniqueIndex", "Expression", "Sort", "Collate", "Length", "NotNull")
}
for _, name := range []string{"Expression", "Sort", "Collate", "Length"} {
if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() {
t.Errorf(
"index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name,
reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(),
)
}
}
}
})
}
}

View File

@ -4,22 +4,39 @@ import (
"gorm.io/gorm/clause"
)
// ConstraintInterface database constraint interface
type ConstraintInterface interface {
GetName() string
Build() (sql string, vars []interface{})
}
// GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface {
GormDataType() string
}
// FieldNewValuePool field new scan value pool
type FieldNewValuePool interface {
Get() interface{}
Put(interface{})
}
// CreateClausesInterface create clauses interface
type CreateClausesInterface interface {
CreateClauses(*Field) []clause.Interface
}
// QueryClausesInterface query clauses interface
type QueryClausesInterface interface {
QueryClauses(*Field) []clause.Interface
}
// UpdateClausesInterface update clauses interface
type UpdateClausesInterface interface {
UpdateClauses(*Field) []clause.Interface
}
// DeleteClausesInterface delete clauses interface
type DeleteClausesInterface interface {
DeleteClauses(*Field) []clause.Interface
}

View File

@ -26,9 +26,11 @@ type User struct {
Active *bool
}
type mytime time.Time
type myint int
type mybool = bool
type (
mytime time.Time
myint int
mybool = bool
)
type AdvancedDataTypeUser struct {
ID sql.NullInt64

View File

@ -2,21 +2,26 @@ package schema
import (
"crypto/sha1"
"fmt"
"encoding/hex"
"regexp"
"strings"
"unicode/utf8"
"github.com/jinzhu/inflection"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
// Namer namer interface
type Namer interface {
TableName(table string) string
SchemaName(table string) string
ColumnName(table, column string) string
JoinTableName(joinTable string) string
RelationshipFKName(Relationship) string
CheckerName(table, column string) string
IndexName(table, column string) string
UniqueName(table, column string) string
}
// Replacer replacer interface like strings.Replacer
@ -24,12 +29,15 @@ type Replacer interface {
Replace(name string) string
}
var _ Namer = (*NamingStrategy)(nil)
// NamingStrategy tables, columns naming strategy
type NamingStrategy struct {
TablePrefix string
SingularTable bool
NameReplacer Replacer
NoLowerCase bool
TablePrefix string
SingularTable bool
NameReplacer Replacer
NoLowerCase bool
IdentifierMaxLength int
}
// TableName convert string to table name
@ -40,6 +48,16 @@ func (ns NamingStrategy) TableName(str string) string {
return ns.TablePrefix + inflection.Plural(ns.toDBName(str))
}
// SchemaName generate schema name from table name, don't guarantee it is the reverse value of TableName
func (ns NamingStrategy) SchemaName(table string) string {
table = strings.TrimPrefix(table, ns.TablePrefix)
if ns.SingularTable {
return ns.toSchemaName(table)
}
return ns.toSchemaName(inflection.Singular(table))
}
// ColumnName convert string to column name
func (ns NamingStrategy) ColumnName(table, column string) string {
return ns.toDBName(column)
@ -72,17 +90,28 @@ func (ns NamingStrategy) IndexName(table, column string) string {
return ns.formatName("idx", table, ns.toDBName(column))
}
func (ns NamingStrategy) formatName(prefix, table, name string) string {
formatedName := strings.Replace(fmt.Sprintf("%v_%v_%v", prefix, table, name), ".", "_", -1)
// UniqueName generate unique constraint name
func (ns NamingStrategy) UniqueName(table, column string) string {
return ns.formatName("uni", table, ns.toDBName(column))
}
if utf8.RuneCountInString(formatedName) > 64 {
func (ns NamingStrategy) formatName(prefix, table, name string) string {
formattedName := strings.ReplaceAll(strings.Join([]string{
prefix, table, name,
}, "_"), ".", "_")
if ns.IdentifierMaxLength == 0 {
ns.IdentifierMaxLength = 64
}
if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength {
h := sha1.New()
h.Write([]byte(formatedName))
h.Write([]byte(formattedName))
bs := h.Sum(nil)
formatedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + string(bs)[:8]
formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8]
}
return formatedName
return formattedName
}
var (
@ -94,7 +123,7 @@ var (
func init() {
commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
for _, initialism := range commonInitialisms {
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, cases.Title(language.Und).String(initialism))
}
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
}
@ -105,7 +134,13 @@ func (ns NamingStrategy) toDBName(name string) string {
}
if ns.NameReplacer != nil {
name = ns.NameReplacer.Replace(name)
tmpName := ns.NameReplacer.Replace(name)
if tmpName == "" {
return name
}
name = tmpName
}
if ns.NoLowerCase {
@ -151,3 +186,11 @@ func (ns NamingStrategy) toDBName(name string) string {
ret := buf.String()
return ret
}
func (ns NamingStrategy) toSchemaName(name string) string {
result := strings.ReplaceAll(cases.Title(language.Und, cases.NoLower).String(strings.ReplaceAll(name, "_", " ")), " ", "")
for _, initialism := range commonInitialisms {
result = regexp.MustCompile(cases.Title(language.Und, cases.NoLower).String(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
}
return result
}

View File

@ -6,7 +6,7 @@ import (
)
func TestToDBName(t *testing.T) {
var maps = map[string]string{
maps := map[string]string{
"": "",
"x": "x",
"X": "x",
@ -33,10 +33,30 @@ func TestToDBName(t *testing.T) {
t.Errorf("%v toName should equal %v, but got %v", key, value, ns.toDBName(key))
}
}
maps = map[string]string{
"x": "X",
"user_restrictions": "UserRestriction",
"this_is_a_test": "ThisIsATest",
"abc_and_jkl": "AbcAndJkl",
"employee_id": "EmployeeID",
"field_x": "FieldX",
"http_and_smtp": "HTTPAndSMTP",
"http_server_handler_for_url_id": "HTTPServerHandlerForURLID",
"uuid": "UUID",
"http_url": "HTTPURL",
"sha256_hash": "Sha256Hash",
"this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id": "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIDCanBeUsedAtTheEndAsID",
}
for key, value := range maps {
if ns.SchemaName(key) != value {
t.Errorf("%v schema name should equal %v, but got %v", key, value, ns.SchemaName(key))
}
}
}
func TestNamingStrategy(t *testing.T) {
var ns = NamingStrategy{
ns := NamingStrategy{
TablePrefix: "public.",
SingularTable: true,
NameReplacer: strings.NewReplacer("CID", "Cid"),
@ -82,7 +102,7 @@ func (r CustomReplacer) Replace(name string) string {
}
func TestCustomReplacer(t *testing.T) {
var ns = NamingStrategy{
ns := NamingStrategy{
TablePrefix: "public.",
SingularTable: true,
NameReplacer: CustomReplacer{
@ -126,7 +146,7 @@ func TestCustomReplacer(t *testing.T) {
}
func TestCustomReplacerWithNoLowerCase(t *testing.T) {
var ns = NamingStrategy{
ns := NamingStrategy{
TablePrefix: "public.",
SingularTable: true,
NameReplacer: CustomReplacer{
@ -168,3 +188,32 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) {
t.Errorf("invalid column name generated, got %v", columdName)
}
}
func TestFormatNameWithStringLongerThan63Characters(t *testing.T) {
ns := NamingStrategy{IdentifierMaxLength: 63}
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVer180f2c67" {
t.Errorf("invalid formatted name generated, got %v", formattedName)
}
}
func TestFormatNameWithStringLongerThan64Characters(t *testing.T) {
ns := NamingStrategy{IdentifierMaxLength: 64}
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" {
t.Errorf("invalid formatted name generated, got %v", formattedName)
}
}
func TestReplaceEmptyTableName(t *testing.T) {
ns := NamingStrategy{
SingularTable: true,
NameReplacer: strings.NewReplacer("Model", ""),
}
tableName := ns.TableName("Model")
if tableName != "Model" {
t.Errorf("invalid table name generated, got %v", tableName)
}
}

19
schema/pool.go Normal file
View File

@ -0,0 +1,19 @@
package schema
import (
"reflect"
"sync"
)
// sync pools
var (
normalPool sync.Map
poolInitializer = func(reflectType reflect.Type) FieldNewValuePool {
v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{
New: func() interface{} {
return reflect.New(reflectType).Interface()
},
})
return v.(FieldNewValuePool)
}
)

View File

@ -1,11 +1,16 @@
package schema
import (
"context"
"fmt"
"reflect"
"strings"
"sync"
"github.com/jinzhu/inflection"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"gorm.io/gorm/clause"
)
@ -26,6 +31,10 @@ type Relationships struct {
HasMany []*Relationship
Many2Many []*Relationship
Relations map[string]*Relationship
EmbeddedRelations map[string]*Relationships
Mux sync.RWMutex
}
type Relationship struct {
@ -69,14 +78,16 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
cacheStore := schema.cacheStore
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
schema.err = err
schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err)
return nil
}
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
schema.buildPolymorphicRelation(relation, field, polymorphic)
if hasPolymorphicRelation(field.TagSettings) {
schema.buildPolymorphicRelation(relation, field)
} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
schema.buildMany2ManyRelation(relation, field, many2many)
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
schema.guessRelation(relation, field, guessBelongs)
} else {
switch field.IndirectFieldType.Kind() {
case reflect.Struct:
@ -84,14 +95,16 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
case reflect.Slice:
schema.guessRelation(relation, field, guessHas)
default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema,
field.Name)
}
}
if relation.Type == has {
// don't add relations to embeded schema, which might be shared
if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil {
relation.FieldSchema.Relationships.Mux.Lock()
relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation
relation.FieldSchema.Relationships.Mux.Unlock()
}
switch field.IndirectFieldType.Kind() {
@ -103,7 +116,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
}
if schema.err == nil {
schema.Relationships.Relations[relation.Name] = relation
schema.setRelation(relation)
switch relation.Type {
case HasOne:
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
@ -119,34 +132,100 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
return relation
}
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
// type User struct {
// Toys []Toy `gorm:"polymorphic:Owner;"`
// }
// type Pet struct {
// Toy Toy `gorm:"polymorphic:Owner;"`
// }
// type Toy struct {
// OwnerID int
// OwnerType string
// }
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
relation.Polymorphic = &Polymorphic{
Value: schema.Table,
PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"],
// hasPolymorphicRelation check if has polymorphic relation
// 1. `POLYMORPHIC` tag
// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag
func hasPolymorphicRelation(tagSettings map[string]string) bool {
if _, ok := tagSettings["POLYMORPHIC"]; ok {
return true
}
_, hasType := tagSettings["POLYMORPHICTYPE"]
_, hasId := tagSettings["POLYMORPHICID"]
return hasType && hasId
}
func (schema *Schema) setRelation(relation *Relationship) {
// set non-embedded relation
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
if len(rel.Field.BindNames) > 1 {
schema.Relationships.Relations[relation.Name] = relation
}
} else {
schema.Relationships.Relations[relation.Name] = relation
}
// set embedded relation
if len(relation.Field.EmbeddedBindNames) <= 1 {
return
}
relationships := &schema.Relationships
for i, name := range relation.Field.EmbeddedBindNames {
if i < len(relation.Field.EmbeddedBindNames)-1 {
if relationships.EmbeddedRelations == nil {
relationships.EmbeddedRelations = map[string]*Relationships{}
}
if r := relationships.EmbeddedRelations[name]; r == nil {
relationships.EmbeddedRelations[name] = &Relationships{}
}
relationships = relationships.EmbeddedRelations[name]
} else {
if relationships.Relations == nil {
relationships.Relations = map[string]*Relationship{}
}
relationships.Relations[relation.Name] = relation
}
}
}
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
//
// type User struct {
// Toys []Toy `gorm:"polymorphic:Owner;"`
// }
// type Pet struct {
// Toy Toy `gorm:"polymorphic:Owner;"`
// }
// type Toy struct {
// OwnerID int
// OwnerType string
// }
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
polymorphic := field.TagSettings["POLYMORPHIC"]
relation.Polymorphic = &Polymorphic{
Value: schema.Table,
}
var (
typeName = polymorphic + "Type"
typeId = polymorphic + "ID"
)
if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
typeName = strings.TrimSpace(value)
}
if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
typeId = strings.TrimSpace(value)
}
relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]
if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
relation.Polymorphic.Value = strings.TrimSpace(value)
}
if relation.Polymorphic.PolymorphicType == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
relation.FieldSchema, schema, field.Name, polymorphic+"Type")
}
if relation.Polymorphic.PolymorphicID == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
relation.FieldSchema, schema, field.Name, polymorphic+"ID")
}
if schema.err == nil {
@ -158,12 +237,21 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
primaryKeyField := schema.PrioritizedPrimaryField
if len(relation.foreignKeys) > 0 {
if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name)
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys,
schema, field.Name)
}
}
if primaryKeyField == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
relation.FieldSchema, schema, field.Name)
return
}
// use same data type for foreign keys
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
if copyableDataType(primaryKeyField.DataType) {
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
}
relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType
if relation.Polymorphic.PolymorphicID.Size == 0 {
relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size
@ -186,7 +274,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
err error
joinTableFields []reflect.StructField
fieldsMap = map[string]*Field{}
ownFieldsMap = map[string]bool{} // fix self join many2many
ownFieldsMap = map[string]*Field{} // fix self join many2many
referFieldsMap = map[string]*Field{}
joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
joinReferences = toColumns(field.TagSettings["JOINREFERENCES"])
)
@ -200,7 +289,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
if field := schema.LookUpField(foreignKey); field != nil {
ownForeignFields = append(ownForeignFields, field)
} else {
schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey)
schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
return
}
}
@ -212,33 +301,31 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
if field := relation.FieldSchema.LookUpField(foreignKey); field != nil {
refForeignFields = append(refForeignFields, field)
} else {
schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey)
schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
return
}
}
}
for idx, ownField := range ownForeignFields {
joinFieldName := strings.Title(schema.Name) + ownField.Name
joinFieldName := cases.Title(language.Und, cases.NoLower).String(schema.Name) + ownField.Name
if len(joinForeignKeys) > idx {
joinFieldName = strings.Title(joinForeignKeys[idx])
joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinForeignKeys[idx])
}
ownFieldsMap[joinFieldName] = true
ownFieldsMap[joinFieldName] = ownField
fieldsMap[joinFieldName] = ownField
joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName,
PkgPath: ownField.StructField.PkgPath,
Type: ownField.StructField.Type,
Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
})
}
for idx, relField := range refForeignFields {
joinFieldName := relation.FieldSchema.Name + relField.Name
if len(joinReferences) > idx {
joinFieldName = strings.Title(joinReferences[idx])
}
joinFieldName := cases.Title(language.Und, cases.NoLower).String(relation.FieldSchema.Name) + relField.Name
if _, ok := ownFieldsMap[joinFieldName]; ok {
if field.Name != relation.FieldSchema.Name {
@ -248,22 +335,32 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
}
}
fieldsMap[joinFieldName] = relField
joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName,
PkgPath: relField.StructField.PkgPath,
Type: relField.StructField.Type,
Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
})
if len(joinReferences) > idx {
joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinReferences[idx])
}
referFieldsMap[joinFieldName] = relField
if _, ok := fieldsMap[joinFieldName]; !ok {
fieldsMap[joinFieldName] = relField
joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName,
PkgPath: relField.StructField.PkgPath,
Type: relField.StructField.Type,
Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
})
}
}
joinTableFields = append(joinTableFields, reflect.StructField{
Name: strings.Title(schema.Name) + field.Name,
Name: cases.Title(language.Und, cases.NoLower).String(schema.Name) + field.Name,
Type: schema.ModelType,
Tag: `gorm:"-"`,
})
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
schema.namer); err != nil {
schema.err = err
}
relation.JoinTable.Name = many2many
@ -302,37 +399,45 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
for _, f := range relation.JoinTable.Fields {
if f.Creatable || f.Readable || f.Updatable {
// use same data type for foreign keys
f.DataType = fieldsMap[f.Name].DataType
if copyableDataType(fieldsMap[f.Name].DataType) {
f.DataType = fieldsMap[f.Name].DataType
}
f.GORMDataType = fieldsMap[f.Name].GORMDataType
if f.Size == 0 {
f.Size = fieldsMap[f.Name].Size
}
relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]
if ownPriamryField {
if of, ok := ownFieldsMap[f.Name]; ok {
joinRel := relation.JoinTable.Relationships.Relations[relName]
joinRel.Field = relation.Field
joinRel.References = append(joinRel.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
PrimaryKey: of,
ForeignKey: f,
})
} else {
relation.References = append(relation.References, &Reference{
PrimaryKey: of,
ForeignKey: f,
OwnPrimaryKey: true,
})
}
if rf, ok := referFieldsMap[f.Name]; ok {
joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
if joinRefRel.Field == nil {
joinRefRel.Field = relation.Field
}
joinRefRel.References = append(joinRefRel.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
PrimaryKey: rf,
ForeignKey: f,
})
relation.References = append(relation.References, &Reference{
PrimaryKey: rf,
ForeignKey: f,
})
}
relation.References = append(relation.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
ForeignKey: f,
OwnPrimaryKey: ownPriamryField,
})
}
}
}
@ -374,7 +479,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
schema.guessRelation(relation, field, guessEmbeddedHas)
// case guessEmbeddedHas:
default:
schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a valid foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name)
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
schema, field.Name)
}
}
@ -382,33 +488,34 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
case guessBelongs:
primarySchema, foreignSchema = relation.FieldSchema, schema
case guessEmbeddedBelongs:
if field.OwnerSchema != nil {
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
} else {
if field.OwnerSchema == nil {
reguessOrErr()
return
}
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
case guessHas:
case guessEmbeddedHas:
if field.OwnerSchema != nil {
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} else {
if field.OwnerSchema == nil {
reguessOrErr()
return
}
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
}
if len(relation.foreignKeys) > 0 {
for _, foreignKey := range relation.foreignKeys {
if f := foreignSchema.LookUpField(foreignKey); f != nil {
foreignFields = append(foreignFields, f)
} else {
f := foreignSchema.LookUpField(foreignKey)
if f == nil {
reguessOrErr()
return
}
foreignFields = append(foreignFields, f)
}
} else {
var primaryFields []*Field
primarySchemaName := primarySchema.Name
if primarySchemaName == "" {
primarySchemaName = relation.FieldSchema.Name
}
if len(relation.primaryKeys) > 0 {
for _, primaryKey := range relation.primaryKeys {
@ -420,31 +527,42 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
primaryFields = primarySchema.PrimaryFields
}
primaryFieldLoop:
for _, primaryField := range primaryFields {
lookUpName := primarySchema.Name + primaryField.Name
lookUpName := primarySchemaName + primaryField.Name
if gl == guessBelongs {
lookUpName = field.Name + primaryField.Name
}
lookUpNames := []string{lookUpName}
if len(primaryFields) == 1 {
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
}
for _, name := range lookUpNames {
if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField)
continue primaryFieldLoop
}
}
for _, name := range lookUpNames {
if f := foreignSchema.LookUpField(name); f != nil {
foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField)
break
continue primaryFieldLoop
}
}
}
}
if len(foreignFields) == 0 {
switch {
case len(foreignFields) == 0:
reguessOrErr()
return
} else if len(relation.primaryKeys) > 0 {
case len(relation.primaryKeys) > 0:
for idx, primaryKey := range relation.primaryKeys {
if f := primarySchema.LookUpField(primaryKey); f != nil {
if len(primaryFields) < idx+1 {
@ -458,7 +576,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
return
}
}
} else if len(primaryFields) == 0 {
case len(primaryFields) == 0:
if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
} else if len(primarySchema.PrimaryFields) == len(foreignFields) {
@ -472,7 +590,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
// build references
for idx, foreignField := range foreignFields {
// use same data type for foreign keys
foreignField.DataType = primaryFields[idx].DataType
if copyableDataType(primaryFields[idx].DataType) {
foreignField.DataType = primaryFields[idx].DataType
}
foreignField.GORMDataType = primaryFields[idx].GORMDataType
if foreignField.Size == 0 {
foreignField.Size = primaryFields[idx].Size
@ -492,6 +612,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
}
}
// Constraint is ForeignKey Constraint
type Constraint struct {
Name string
Field *Field
@ -503,6 +624,31 @@ type Constraint struct {
OnUpdate string
}
func (constraint *Constraint) GetName() string { return constraint.Name }
func (constraint *Constraint) Build() (sql string, vars []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys))
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
references := make([]interface{}, 0, len(constraint.References))
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}
func (rel *Relationship) ParseConstraint() *Constraint {
str := rel.Field.TagSettings["CONSTRAINT"]
if str == "-" {
@ -511,12 +657,13 @@ func (rel *Relationship) ParseConstraint() *Constraint {
if rel.Type == BelongsTo {
for _, r := range rel.FieldSchema.Relationships.Relations {
if r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) {
if r != rel && r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) {
matched := true
for idx, ref := range r.References {
if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey &&
rel.References[idx].PrimaryValue == ref.PrimaryValue) {
matched = false
break
}
}
@ -529,7 +676,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
var (
name string
idx = strings.Index(str, ",")
idx = strings.IndexByte(str, ',')
settings = ParseTagSetting(str, ",")
)
@ -568,7 +715,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
return &constraint
}
func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) {
func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) {
table := rel.FieldSchema.Table
foreignFields := []*Field{}
relForeignKeys := []string{}
@ -608,9 +755,19 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []
}
}
_, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields)
_, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields)
column, values := ToQueryValues(table, relForeignKeys, foreignValues)
conds = append(conds, clause.IN{Column: column, Values: values})
return
}
func copyableDataType(str DataType) bool {
lowerStr := strings.ToLower(string(str))
for _, s := range []string{"auto_increment", "primary key"} {
if strings.Contains(lowerStr, s) {
return false
}
}
return true
}

View File

@ -10,7 +10,7 @@ import (
func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) {
if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil {
t.Errorf("Failed to parse schema")
t.Errorf("Failed to parse schema, got error %v", err)
} else {
for _, rel := range relations {
checkSchemaRelation(t, s, rel)
@ -93,6 +93,20 @@ func TestBelongsToWithOnlyReferences2(t *testing.T) {
})
}
func TestSelfReferentialBelongsTo(t *testing.T) {
type User struct {
ID int32 `gorm:"primaryKey"`
Name string
CreatorID *int32
Creator *User
}
checkStructRelation(t, &User{}, Relation{
Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User",
References: []Reference{{"ID", "User", "CreatorID", "User", "", false}},
})
}
func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
type User struct {
ID int32 `gorm:"primaryKey"`
@ -107,6 +121,29 @@ func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
})
}
func TestBelongsToWithMixin(t *testing.T) {
type Profile struct {
gorm.Model
Refer string
Name string
}
type ProfileMixin struct {
Profile Profile `gorm:"References:Refer"`
ProfileRefer int
}
type User struct {
gorm.Model
ProfileMixin
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}},
})
}
func TestHasOneOverrideForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
@ -144,6 +181,24 @@ func TestHasOneOverrideReferences(t *testing.T) {
})
}
func TestHasOneOverrideReferences2(t *testing.T) {
type Profile struct {
gorm.Model
Name string
}
type User struct {
gorm.Model
ProfileID uint `gorm:"column:profile_id"`
Profile *Profile `gorm:"foreignKey:ID;references:ProfileID"`
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"ProfileID", "User", "ID", "Profile", "", true}},
})
}
func TestHasOneWithOnlyReferences(t *testing.T) {
type Profile struct {
gorm.Model
@ -273,6 +328,33 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) {
})
}
func TestMany2ManySharedForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
Name string
Kind string
ProfileRefer uint
}
type User struct {
gorm.Model
Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"`
Kind string
Refer uint
}
checkStructRelation(t, &User{}, Relation{
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
References: []Reference{
{"Refer", "User", "UserRefer", "user_profiles", "", true},
{"Kind", "User", "Kind", "user_profiles", "", true},
{"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false},
{"Kind", "Profile", "Kind", "user_profiles", "", false},
},
})
}
func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
@ -459,6 +541,340 @@ func TestEmbeddedRelation(t *testing.T) {
}
}
func TestEmbeddedHas(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type User struct {
ID int
Cat struct {
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
Toys []Toy `gorm:"polymorphic:Owner;"`
} `gorm:"embedded;embeddedPrefix:cat_"`
Dog struct {
ID int
Name string
UserID int
Toy Toy `gorm:"polymorphic:Owner;"`
Toys []Toy `gorm:"polymorphic:Owner;"`
}
Toys []Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
}
func TestPolymorphic(t *testing.T) {
t.Run("has one", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has one with custom polymorphic type and id", func(t *testing.T) {
type Toy struct {
ID int
Name string
RefId int
Type string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has one with only polymorphic type", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
Type string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has many", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type Cat struct {
ID int
Name string
Toys []Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has many with custom polymorphic type and id", func(t *testing.T) {
type Toy struct {
ID int
Name string
RefId int
Type string
}
type Cat struct {
ID int
Name string
Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
}
func TestEmbeddedBelongsTo(t *testing.T) {
type Country struct {
ID int `gorm:"primaryKey"`
Name string
}
type Address struct {
CountryID int
Country Country
}
type NestedAddress struct {
Address
}
type CountryMixin struct {
CountryID int
Country Country
}
type Org struct {
ID int
PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"`
VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"`
AddressID int
Address struct {
ID int
Address
}
NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
CountryMixin
}
s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Errorf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"PostalAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
"VisitingAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
"NestedAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
})
}
func TestVariableRelation(t *testing.T) {
var result struct {
User
}
checkStructRelation(t, &result, Relation{
Name: "Account", Type: schema.HasOne, Schema: "", FieldSchema: "Account",
References: []Reference{
{"ID", "", "UserID", "Account", "", true},
},
})
checkStructRelation(t, &result, Relation{
Name: "Company", Type: schema.BelongsTo, Schema: "", FieldSchema: "Company",
References: []Reference{
{"ID", "Company", "CompanyID", "", "", false},
},
})
}
func TestSameForeignKey(t *testing.T) {
type UserAux struct {
gorm.Model
@ -482,3 +898,101 @@ func TestSameForeignKey(t *testing.T) {
},
)
}
func TestBelongsToSameForeignKey(t *testing.T) {
type User struct {
gorm.Model
Name string
UUID string
}
type UserAux struct {
gorm.Model
Aux string
UUID string
User User `gorm:"ForeignKey:UUID;references:UUID;belongsTo"`
}
checkStructRelation(t, &UserAux{},
Relation{
Name: "User", Type: schema.BelongsTo, Schema: "UserAux", FieldSchema: "User",
References: []Reference{
{"UUID", "User", "UUID", "UserAux", "", false},
},
},
)
}
func TestHasOneWithSameForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
Name string
ProfileRefer int // not used in relationship
}
type User struct {
gorm.Model
Profile Profile `gorm:"ForeignKey:ID;references:ProfileRefer"`
ProfileRefer int
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"ProfileRefer", "User", "ID", "Profile", "", true}},
})
}
func TestHasManySameForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserRefer uint
}
type User struct {
gorm.Model
UserRefer uint
Profile []Profile `gorm:"ForeignKey:UserRefer"`
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.HasMany, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}},
})
}
type Author struct {
gorm.Model
}
type Book struct {
gorm.Model
Author Author
AuthorID uint
}
func (Book) TableName() string {
return "my_schema.a_very_very_very_very_very_very_very_very_long_table_name"
}
func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) {
s, err := schema.Parse(
&Book{},
&sync.Map{},
schema.NamingStrategy{IdentifierMaxLength: 64},
)
if err != nil {
t.Fatalf("Failed to parse schema")
}
expectedConstraintName := "fk_my_schema_a_very_very_very_very_very_very_very_very_l4db13eec"
constraint := s.Relationships.Relations["Author"].ParseConstraint()
if constraint.Name != expectedConstraintName {
t.Fatalf(
"expected constraint name %s, got %s",
expectedConstraintName,
constraint.Name,
)
}
}

View File

@ -5,13 +5,29 @@ import (
"errors"
"fmt"
"go/ast"
"path"
"reflect"
"strings"
"sync"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
)
type callbackType string
const (
callbackTypeBeforeCreate callbackType = "BeforeCreate"
callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
callbackTypeAfterCreate callbackType = "AfterCreate"
callbackTypeAfterUpdate callbackType = "AfterUpdate"
callbackTypeBeforeSave callbackType = "BeforeSave"
callbackTypeAfterSave callbackType = "AfterSave"
callbackTypeBeforeDelete callbackType = "BeforeDelete"
callbackTypeAfterDelete callbackType = "AfterDelete"
callbackTypeAfterFind callbackType = "AfterFind"
)
// ErrUnsupportedDataType unsupported data type
var ErrUnsupportedDataType = errors.New("unsupported data type")
@ -25,6 +41,7 @@ type Schema struct {
PrimaryFieldDBNames []string
Fields []*Field
FieldsByName map[string]*Field
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
Relationships Relationships
@ -45,15 +62,16 @@ type Schema struct {
func (schema Schema) String() string {
if schema.ModelType.Name() == "" {
return fmt.Sprintf("%v(%v)", schema.Name, schema.Table)
return fmt.Sprintf("%s(%s)", schema.Name, schema.Table)
}
return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name())
return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name())
}
func (schema Schema) MakeSlice() reflect.Value {
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20)
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20)
results := reflect.New(slice.Type())
results.Elem().Set(slice)
return results
}
@ -67,17 +85,56 @@ func (schema Schema) LookUpField(name string) *Field {
return nil
}
// LookUpFieldByBindName looks for the closest field in the embedded struct.
//
// type Struct struct {
// Embedded struct {
// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
// }
// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
// }
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
if len(bindNames) == 0 {
return nil
}
for i := len(bindNames) - 1; i >= 0; i-- {
find := strings.Join(bindNames[:i], ".") + "." + name
if field, ok := schema.FieldsByBindName[find]; ok {
return field
}
}
return nil
}
type Tabler interface {
TableName() string
}
// get data type from dialector
type TablerWithNamer interface {
TableName(Namer) string
}
// Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
}
// ParseWithSpecialTableName get data type from dialector with extra schema table
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
modelType := reflect.ValueOf(dest).Type()
value := reflect.ValueOf(dest)
if value.Kind() == reflect.Ptr && value.IsNil() {
value = reflect.New(value.Type().Elem())
}
modelType := reflect.Indirect(value).Type()
if modelType.Kind() == reflect.Interface {
modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
}
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
@ -86,11 +143,22 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if modelType.PkgPath() == "" {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
}
if v, ok := cacheStore.Load(modelType); ok {
// Cache the Schema for performance,
// Use the modelType or modelType + schemaTable (if it present) as cache key.
var schemaCacheKey interface{}
if specialTableName != "" {
schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
} else {
schemaCacheKey = modelType
}
// Load exist schema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
@ -100,28 +168,38 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName()
}
if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
tableName = tabler.TableName(namer)
}
if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table
}
schema := &Schema{
Name: modelType.Name(),
ModelType: modelType,
Table: tableName,
FieldsByName: map[string]*Field{},
FieldsByDBName: map[string]*Field{},
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
initialized: make(chan struct{}),
if specialTableName != "" && specialTableName != tableName {
tableName = specialTableName
}
defer func() {
if schema.err != nil {
logger.Default.Error(context.Background(), schema.err.Error())
cacheStore.Delete(modelType)
}
}()
schema := &Schema{
Name: modelType.Name(),
ModelType: modelType,
Table: tableName,
FieldsByName: map[string]*Field{},
FieldsByBindName: map[string]*Field{},
FieldsByDBName: map[string]*Field{},
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
initialized: make(chan struct{}),
}
// When the schema initialization is completed, the channel will be closed
defer close(schema.initialized)
// Load exist schema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
@ -138,6 +216,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
field.DBName = namer.ColumnName(schema.Table, field.Name)
}
bindName := field.BindName()
if field.DBName != "" {
// nonexistence or shortest path or first appear prioritized if has permission
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
@ -146,6 +225,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
}
schema.FieldsByDBName[field.DBName] = field
schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[bindName] = field
if v != nil && v.PrimaryKey {
for idx, f := range schema.PrimaryFields {
@ -164,8 +244,11 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByName[field.Name] = field
}
if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByBindName[bindName] = field
}
field.setupValuerAndSetter()
field.setupValuerAndSetter(modelType)
}
prioritizedPrimaryField := schema.LookUpField("id")
@ -183,16 +266,26 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
}
}
if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 {
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
if schema.PrioritizedPrimaryField == nil {
if len(schema.PrimaryFields) == 1 {
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
} else if len(schema.PrimaryFields) > 1 {
// If there are multiple primary keys, the AUTOINCREMENT field is prioritized
for _, field := range schema.PrimaryFields {
if field.AutoIncrement {
schema.PrioritizedPrimaryField = field
break
}
}
}
}
for _, field := range schema.PrimaryFields {
schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName)
}
for _, field := range schema.FieldsByDBName {
if field.HasDefaultValue && field.DefaultValueInterface == nil {
for _, field := range schema.Fields {
if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil {
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
}
}
@ -211,49 +304,71 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
}
}
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
for _, name := range callbacks {
if methodValue := modelValue.MethodByName(name); methodValue.IsValid() {
callbackTypes := []callbackType{
callbackTypeBeforeCreate, callbackTypeAfterCreate,
callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
callbackTypeBeforeSave, callbackTypeAfterSave,
callbackTypeBeforeDelete, callbackTypeAfterDelete,
callbackTypeAfterFind,
}
for _, cbName := range callbackTypes {
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
switch methodValue.Type().String() {
case "func(*gorm.DB) error": // TODO hack
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
case "func(*gorm.DB) error":
expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath())
if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath {
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
} else {
logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg)
// PASS
}
default:
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name)
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
}
}
}
if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded {
// Cache the schema
if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
defer close(schema.initialized)
defer func() {
if schema.err != nil {
logger.Default.Error(context.Background(), schema.err.Error())
cacheStore.Delete(modelType)
}
}()
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
for _, field := range schema.Fields {
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) {
if schema.parseRelation(field); schema.err != nil {
return schema, schema.err
} else {
schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[field.BindName()] = field
}
}
fieldValue := reflect.New(field.IndirectFieldType)
if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok {
fieldInterface := fieldValue.Interface()
if fc, ok := fieldInterface.(CreateClausesInterface); ok {
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
}
if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok {
if fc, ok := fieldInterface.(QueryClausesInterface); ok {
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
}
if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok {
if fc, ok := fieldInterface.(UpdateClausesInterface); ok {
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
}
if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok {
if fc, ok := fieldInterface.(DeleteClausesInterface); ok {
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
}
}
@ -262,6 +377,39 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
return schema, schema.err
}
// This unrolling is needed to show to the compiler the exact set of methods
// that can be used on the modelType.
// Prior to go1.22 any use of MethodByName would cause the linker to
// abandon dead code elimination for the entire binary.
// As of go1.22 the compiler supports one special case of a string constant
// being passed to MethodByName. For enterprise customers or those building
// large binaries, this gives a significant reduction in binary size.
// https://github.com/golang/go/issues/62257
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
switch cbType {
case callbackTypeBeforeCreate:
return modelType.MethodByName(string(callbackTypeBeforeCreate))
case callbackTypeAfterCreate:
return modelType.MethodByName(string(callbackTypeAfterCreate))
case callbackTypeBeforeUpdate:
return modelType.MethodByName(string(callbackTypeBeforeUpdate))
case callbackTypeAfterUpdate:
return modelType.MethodByName(string(callbackTypeAfterUpdate))
case callbackTypeBeforeSave:
return modelType.MethodByName(string(callbackTypeBeforeSave))
case callbackTypeAfterSave:
return modelType.MethodByName(string(callbackTypeAfterSave))
case callbackTypeBeforeDelete:
return modelType.MethodByName(string(callbackTypeBeforeDelete))
case callbackTypeAfterDelete:
return modelType.MethodByName(string(callbackTypeAfterDelete))
case callbackTypeAfterFind:
return modelType.MethodByName(string(callbackTypeAfterFind))
default:
return reflect.ValueOf(nil)
}
}
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
modelType := reflect.ValueOf(dest).Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
@ -272,7 +420,7 @@ func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, e
if modelType.PkgPath() == "" {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
}
if v, ok := cacheStore.Load(modelType); ok {

View File

@ -1,6 +1,7 @@
package schema_test
import (
"context"
"fmt"
"reflect"
"strings"
@ -29,7 +30,7 @@ func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields
}
if !found {
t.Errorf("schema %v failed to found priamry key: %v", s, field)
t.Errorf("schema %v failed to found primary key: %v", s, field)
}
}
})
@ -162,8 +163,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table)
}
for _, f := range relation.JoinTable.Fields {
checkSchemaField(t, r.JoinTable, &f, nil)
for i := range relation.JoinTable.Fields {
checkSchemaField(t, r.JoinTable, &relation.JoinTable.Fields[i], nil)
}
}
@ -200,10 +201,41 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
})
}
type EmbeddedRelations struct {
Relations map[string]Relation
EmbeddedRelations map[string]EmbeddedRelations
}
func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) {
for name, relations := range actual {
rs := expected[name]
t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) {
if len(relations.Relations) != len(rs.Relations) {
t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations))
}
if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) {
t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), len(relations.EmbeddedRelations))
}
for n, rel := range relations.Relations {
if r, ok := rs.Relations[n]; !ok {
t.Errorf("failed to find relation by name %s", n)
} else {
checkSchemaRelation(t, &schema.Schema{
Relationships: schema.Relationships{
Relations: map[string]*schema.Relationship{n: rel},
},
}, r)
}
}
checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations)
})
}
}
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
for k, v := range values {
t.Run("CheckField/"+k, func(t *testing.T) {
fv, _ := s.FieldsByDBName[k].ValueOf(value)
fv, _ := s.FieldsByDBName[k].ValueOf(context.Background(), value)
tests.AssertEqual(t, v, fv)
})
}

View File

@ -19,6 +19,22 @@ func TestParseSchema(t *testing.T) {
checkUserSchema(t, user)
}
func TestParseSchemaWithMap(t *testing.T) {
type User struct {
tests.User
Attrs map[string]string `gorm:"type:Map(String,String);"`
}
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user with map, got error %v", err)
}
if field := user.FieldsByName["Attrs"]; field.DataType != "Map(String,String)" {
t.Errorf("failed to parse user field Attrs")
}
}
func TestParseSchemaWithPointerFields(t *testing.T) {
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
@ -46,8 +62,8 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
}
for _, f := range fields {
checkSchemaField(t, user, &f, func(f *schema.Field) {
for i := range fields {
checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
f.Creatable = true
f.Updatable = true
f.Readable = true
@ -136,8 +152,8 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) {
{Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool},
}
for _, f := range fields {
checkSchemaField(t, user, &f, func(f *schema.Field) {
for i := range fields {
checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
f.Creatable = true
f.Updatable = true
f.Readable = true
@ -145,8 +161,7 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) {
}
}
type CustomizeTable struct {
}
type CustomizeTable struct{}
func (CustomizeTable) TableName() string {
return "customize"
@ -165,7 +180,6 @@ func TestCustomizeTableName(t *testing.T) {
func TestNestedModel(t *testing.T) {
versionUser, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse nested user, got error %v", err)
}
@ -204,7 +218,6 @@ func TestEmbeddedStruct(t *testing.T) {
}
cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse embedded struct with primary key, got error %v", err)
}
@ -273,7 +286,6 @@ func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) {
}
cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}})
if err != nil {
t.Fatalf("failed to parse embedded struct with primary key, got error %v", err)
}
@ -297,3 +309,44 @@ func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) {
})
}
}
func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) {
type Product struct {
ProductID uint `gorm:"primaryKey;autoIncrement"`
LanguageCode uint `gorm:"primaryKey"`
Code string
Name string
}
type ProductNonAutoIncrement struct {
ProductID uint `gorm:"primaryKey;autoIncrement:false"`
LanguageCode uint `gorm:"primaryKey"`
Code string
Name string
}
product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse product struct with composite primary key, got error %v", err)
}
prioritizedPrimaryField := schema.Field{
Name: "ProductID", DBName: "product_id", BindNames: []string{"ProductID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY", "AUTOINCREMENT": "AUTOINCREMENT"},
}
product.Fields = []*schema.Field{product.PrioritizedPrimaryField}
checkSchemaField(t, product, &prioritizedPrimaryField, func(f *schema.Field) {
f.Creatable = true
f.Updatable = true
f.Readable = true
})
productNonAutoIncrement, err := schema.Parse(&ProductNonAutoIncrement{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse productNonAutoIncrement struct with composite primary key, got error %v", err)
}
if productNonAutoIncrement.PrioritizedPrimaryField != nil {
t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil")
}
}

173
schema/serializer.go Normal file
View File

@ -0,0 +1,173 @@
package schema
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"encoding/gob"
"encoding/json"
"fmt"
"reflect"
"strings"
"sync"
"time"
)
var serializerMap = sync.Map{}
// RegisterSerializer register serializer
func RegisterSerializer(name string, serializer SerializerInterface) {
serializerMap.Store(strings.ToLower(name), serializer)
}
// GetSerializer get serializer
func GetSerializer(name string) (serializer SerializerInterface, ok bool) {
v, ok := serializerMap.Load(strings.ToLower(name))
if ok {
serializer, ok = v.(SerializerInterface)
}
return serializer, ok
}
func init() {
RegisterSerializer("json", JSONSerializer{})
RegisterSerializer("unixtime", UnixSecondSerializer{})
RegisterSerializer("gob", GobSerializer{})
}
// Serializer field value serializer
type serializer struct {
Field *Field
Serializer SerializerInterface
SerializeValuer SerializerValuerInterface
Destination reflect.Value
Context context.Context
value interface{}
fieldValue interface{}
}
// Scan implements sql.Scanner interface
func (s *serializer) Scan(value interface{}) error {
s.value = value
return nil
}
// Value implements driver.Valuer interface
func (s serializer) Value() (driver.Value, error) {
return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue)
}
// SerializerInterface serializer interface
type SerializerInterface interface {
Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error
SerializerValuerInterface
}
// SerializerValuerInterface serializer valuer interface
type SerializerValuerInterface interface {
Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error)
}
// JSONSerializer json serializer
type JSONSerializer struct{}
// Scan implements serializer interface
func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
fieldValue := reflect.New(field.FieldType)
if dbValue != nil {
var bytes []byte
switch v := dbValue.(type) {
case []byte:
bytes = v
case string:
bytes = []byte(v)
default:
bytes, err = json.Marshal(v)
if err != nil {
return err
}
}
if len(bytes) > 0 {
err = json.Unmarshal(bytes, fieldValue.Interface())
}
}
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
return
}
// Value implements serializer interface
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
result, err := json.Marshal(fieldValue)
if string(result) == "null" {
if field.TagSettings["NOT NULL"] != "" {
return "", nil
}
return nil, err
}
return string(result), err
}
// UnixSecondSerializer json serializer
type UnixSecondSerializer struct{}
// Scan implements serializer interface
func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
t := sql.NullTime{}
if err = t.Scan(dbValue); err == nil && t.Valid {
err = field.Set(ctx, dst, t.Time.Unix())
}
return
}
// Value implements serializer interface
func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
rv := reflect.ValueOf(fieldValue)
switch v := fieldValue.(type) {
case int64, int, uint, uint64, int32, uint32, int16, uint16:
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
if rv.IsZero() {
return nil, nil
}
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
default:
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
}
return
}
// GobSerializer gob serializer
type GobSerializer struct{}
// Scan implements serializer interface
func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
fieldValue := reflect.New(field.FieldType)
if dbValue != nil {
var bytesValue []byte
switch v := dbValue.(type) {
case []byte:
bytesValue = v
default:
return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
}
if len(bytesValue) > 0 {
decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
err = decoder.Decode(fieldValue.Interface())
}
}
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
return
}
// Value implements serializer interface
func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
buf := new(bytes.Buffer)
err := gob.NewEncoder(buf).Encode(fieldValue)
return buf.Bytes(), err
}

View File

@ -1,6 +1,8 @@
package schema
import (
"context"
"fmt"
"reflect"
"regexp"
"strings"
@ -58,23 +60,31 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct
return tag
}
func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag {
t := tag.Get("gorm")
if strings.Contains(t, value) {
return tag
}
return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t))
}
// GetRelationsValues get relations's values from a reflect value
func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
for _, rel := range rels {
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1)
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.FieldSchema.ModelType)), 0, 1)
appendToResults := func(value reflect.Value) {
if _, isZero := rel.Field.ValueOf(value); !isZero {
result := reflect.Indirect(rel.Field.ReflectValueOf(value))
if _, isZero := rel.Field.ValueOf(ctx, value); !isZero {
result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value))
switch result.Kind() {
case reflect.Struct:
reflectResults = reflect.Append(reflectResults, result.Addr())
case reflect.Slice, reflect.Array:
for i := 0; i < result.Len(); i++ {
if result.Index(i).Kind() == reflect.Ptr {
reflectResults = reflect.Append(reflectResults, result.Index(i))
if elem := result.Index(i); elem.Kind() == reflect.Ptr {
reflectResults = reflect.Append(reflectResults, elem)
} else {
reflectResults = reflect.Append(reflectResults, result.Index(i).Addr())
reflectResults = reflect.Append(reflectResults, elem.Addr())
}
}
}
@ -97,7 +107,7 @@ func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (refle
}
// GetIdentityFieldValuesMap get identity map from fields
func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
var (
results = [][]interface{}{}
dataResults = map[string][]reflect.Value{}
@ -105,12 +115,17 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
notZero, zero bool
)
if reflectValue.Kind() == reflect.Ptr ||
reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem()
}
switch reflectValue.Kind() {
case reflect.Struct:
results = [][]interface{}{make([]interface{}, len(fields))}
for idx, field := range fields {
results[0][idx], zero = field.ValueOf(reflectValue)
results[0][idx], zero = field.ValueOf(ctx, reflectValue)
notZero = notZero || !zero
}
@ -123,7 +138,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
for i := 0; i < reflectValue.Len(); i++ {
elem := reflectValue.Index(i)
elemKey := elem.Interface()
if elem.Kind() != reflect.Ptr {
if elem.Kind() != reflect.Ptr && elem.CanAddr() {
elemKey = elem.Addr().Interface()
}
@ -135,7 +150,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
fieldValues := make([]interface{}, len(fields))
notZero = false
for idx, field := range fields {
fieldValues[idx], zero = field.ValueOf(elem)
fieldValues[idx], zero = field.ValueOf(ctx, elem)
notZero = notZero || !zero
}
@ -155,12 +170,12 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
}
// GetIdentityFieldValuesMapFromValues get identity map from fields
func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
resultsMap := map[string][]reflect.Value{}
results := [][]interface{}{}
for _, v := range values {
rm, rs := GetIdentityFieldValuesMap(reflect.Indirect(reflect.ValueOf(v)), fields)
rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields)
for k, v := range rm {
resultsMap[k] = append(resultsMap[k], v...)
}
@ -178,17 +193,18 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa
}
return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues
} else {
columns := make([]clause.Column, len(foreignKeys))
for idx, key := range foreignKeys {
columns[idx] = clause.Column{Table: table, Name: key}
}
for idx, r := range foreignValues {
queryValues[idx] = r
}
return columns, queryValues
}
columns := make([]clause.Column, len(foreignKeys))
for idx, key := range foreignKeys {
columns[idx] = clause.Column{Table: table, Name: key}
}
for idx, r := range foreignValues {
queryValues[idx] = r
}
return columns, queryValues
}
type embeddedNamer struct {

View File

@ -6,6 +6,7 @@ import (
"encoding/json"
"reflect"
"github.com/jinzhu/now"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
@ -45,11 +46,21 @@ func (n *DeletedAt) UnmarshalJSON(b []byte) error {
}
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteQueryClause{Field: f}}
return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
}
func parseZeroValueTag(f *schema.Field) sql.NullString {
if v, ok := f.TagSettings["ZEROVALUE"]; ok {
if _, err := now.Parse(v); err == nil {
return sql.NullString{String: v, Valid: true}
}
}
return sql.NullString{Valid: false}
}
type SoftDeleteQueryClause struct {
Field *schema.Field
ZeroValue sql.NullString
Field *schema.Field
}
func (sd SoftDeleteQueryClause) Name() string {
@ -63,9 +74,9 @@ func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
}
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok {
if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok && !stmt.Statement.Unscoped {
if c, ok := stmt.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 {
if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) >= 1 {
for _, expr := range where.Exprs {
if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 {
where.Exprs = []clause.Expression{clause.And(where.Exprs...)}
@ -78,18 +89,44 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
}
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue},
}})
stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
}
}
func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
}
type SoftDeleteUpdateClause struct {
ZeroValue sql.NullString
Field *schema.Field
}
func (sd SoftDeleteUpdateClause) Name() string {
return ""
}
func (sd SoftDeleteUpdateClause) Build(clause.Builder) {
}
func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) {
}
func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
}
}
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteDeleteClause{Field: f}}
return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
}
type SoftDeleteDeleteClause struct {
Field *schema.Field
ZeroValue sql.NullString
Field *schema.Field
}
func (sd SoftDeleteDeleteClause) Name() string {
@ -103,13 +140,13 @@ func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
}
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
if stmt.SQL.String() == "" {
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
curTime := stmt.DB.NowFunc()
stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}})
stmt.SetColumn(sd.Field.DBName, curTime, true)
if stmt.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
@ -117,7 +154,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
}
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
_, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
@ -126,13 +163,8 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
}
}
if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok {
stmt.DB.AddError(ErrMissingWhereClause)
} else {
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
}
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
stmt.AddClauseIfNotExists(clause.Update{})
stmt.Build("UPDATE", "SET", "WHERE")
stmt.Build(stmt.DB.Callback().Update().Clauses...)
}
}

View File

@ -6,6 +6,7 @@ import (
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
@ -27,9 +28,11 @@ type Statement struct {
Dest interface{}
ReflectValue reflect.Value
Clauses map[string]clause.Clause
BuildClauses []string
Distinct bool
Selects []string // selected columns
Omits []string // omit columns
Selects []string // selected columns
Omits []string // omit columns
ColumnMapping map[string]string // map columns
Joins []join
Preloads map[string][]interface{}
Settings sync.Map
@ -44,11 +47,18 @@ type Statement struct {
attrs []interface{}
assigns []interface{}
scopes []func(*DB) *DB
Result *result
}
type join struct {
Name string
Conds []interface{}
Name string
Alias string
Conds []interface{}
On *clause.Where
Selects []string
Omits []string
Expression clause.Expression
JoinType clause.JoinType
}
// StatementModifier statement modifier interface
@ -56,12 +66,12 @@ type StatementModifier interface {
ModifyStatement(*Statement)
}
// Write write string
// WriteString write string
func (stmt *Statement) WriteString(str string) (int, error) {
return stmt.SQL.WriteString(str)
}
// Write write string
// WriteByte write byte
func (stmt *Statement) WriteByte(c byte) error {
return stmt.SQL.WriteByte(c)
}
@ -73,30 +83,36 @@ func (stmt *Statement) WriteQuoted(value interface{}) {
// QuoteTo write quoted value to writer
func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
write := func(raw bool, str string) {
if raw {
writer.WriteString(str)
} else {
stmt.DB.Dialector.QuoteTo(writer, str)
}
}
switch v := field.(type) {
case clause.Table:
if v.Name == clause.CurrentTable {
if stmt.TableExpr != nil {
stmt.TableExpr.Build(stmt)
} else {
stmt.DB.Dialector.QuoteTo(writer, stmt.Table)
write(v.Raw, stmt.Table)
}
} else if v.Raw {
writer.WriteString(v.Name)
} else {
stmt.DB.Dialector.QuoteTo(writer, v.Name)
write(v.Raw, v.Name)
}
if v.Alias != "" {
writer.WriteByte(' ')
stmt.DB.Dialector.QuoteTo(writer, v.Alias)
write(v.Raw, v.Alias)
}
case clause.Column:
if v.Table != "" {
if v.Table == clause.CurrentTable {
stmt.DB.Dialector.QuoteTo(writer, stmt.Table)
write(v.Raw, stmt.Table)
} else {
stmt.DB.Dialector.QuoteTo(writer, v.Table)
write(v.Raw, v.Table)
}
writer.WriteByte('.')
}
@ -105,36 +121,38 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
if stmt.Schema == nil {
stmt.DB.AddError(ErrModelValueRequired)
} else if stmt.Schema.PrioritizedPrimaryField != nil {
stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName)
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
} else if len(stmt.Schema.DBNames) > 0 {
stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.DBNames[0])
write(v.Raw, stmt.Schema.DBNames[0])
} else {
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
}
} else if v.Raw {
writer.WriteString(v.Name)
} else {
stmt.DB.Dialector.QuoteTo(writer, v.Name)
write(v.Raw, v.Name)
}
if v.Alias != "" {
writer.WriteString(" AS ")
stmt.DB.Dialector.QuoteTo(writer, v.Alias)
write(v.Raw, v.Alias)
}
case []clause.Column:
writer.WriteByte('(')
for idx, d := range v {
if idx > 0 {
writer.WriteString(",")
writer.WriteByte(',')
}
stmt.QuoteTo(writer, d)
}
writer.WriteByte(')')
case clause.Expr:
v.Build(stmt)
case string:
stmt.DB.Dialector.QuoteTo(writer, v)
case []string:
writer.WriteByte('(')
for idx, d := range v {
if idx > 0 {
writer.WriteString(",")
writer.WriteByte(',')
}
stmt.DB.Dialector.QuoteTo(writer, d)
}
@ -151,7 +169,7 @@ func (stmt *Statement) Quote(field interface{}) string {
return builder.String()
}
// Write write string
// AddVar add var
func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
for idx, v := range vars {
if idx > 0 {
@ -164,10 +182,17 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
case clause.Column, clause.Table:
stmt.QuoteTo(writer, v)
case Valuer:
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
case clause.Expr:
v.Build(stmt)
case *clause.Expr:
reflectValue := reflect.ValueOf(v)
if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() {
stmt.AddVar(writer, nil)
} else {
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
}
case clause.Interface:
c := clause.Clause{Name: v.Name()}
v.MergeClause(&c)
c.Build(stmt)
case clause.Expression:
v.Build(stmt)
case driver.Valuer:
stmt.Vars = append(stmt.Vars, v)
@ -183,19 +208,21 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
} else {
writer.WriteString("(NULL)")
}
case *DB:
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
if v.Statement.SQL.Len() > 0 {
case interface{ getInstance() *DB }:
cv := v.getInstance()
subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
if cv.Statement.SQL.Len() > 0 {
var (
vars = subdb.Statement.Vars
sql = v.Statement.SQL.String()
sql = cv.Statement.SQL.String()
)
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
cv.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
@ -218,6 +245,9 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
writer.WriteString("(NULL)")
} else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) {
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
} else {
writer.WriteByte('(')
for i := 0; i < rv.Len(); i++ {
@ -263,13 +293,24 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if _, err := strconv.Atoi(s); err != nil {
if s == "" && len(args) == 0 {
return nil
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
}
if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
// looks like a where condition
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
} else if len(args) > 0 && strings.Contains(s, "@") {
}
if len(args) > 0 && strings.Contains(s, "@") {
// looks like a named query
return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
} else if len(args) == 1 {
}
if strings.Contains(strings.TrimSpace(s), " ") {
// looks like a where condition
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
}
if len(args) == 1 {
return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
}
}
@ -278,19 +319,31 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
conds := make([]clause.Expression, 0, 4)
args = append([]interface{}{query}, args...)
for idx, arg := range args {
if arg == nil {
continue
}
if valuer, ok := arg.(driver.Valuer); ok {
arg, _ = valuer.Value()
}
curTable := stmt.Table
if curTable == "" {
curTable = clause.CurrentTable
}
switch v := arg.(type) {
case clause.Expression:
conds = append(conds, v)
case *DB:
v.executeScopes()
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
if len(where.Exprs) == 1 {
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
where.Exprs[0] = clause.AndConditions(orConds)
if len(orConds.Exprs) == 1 {
where.Exprs[0] = clause.AndConditions(orConds)
}
}
}
conds = append(conds, clause.And(where.Exprs...))
@ -303,17 +356,21 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
conds = append(conds, clause.Eq{Column: i, Value: j})
}
case map[string]string:
var keys = make([]string, 0, len(v))
keys := make([]string, 0, len(v))
for i := range v {
keys = append(keys, i)
}
sort.Strings(keys)
for _, key := range keys {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
column := clause.Column{Name: key, Table: curTable}
if strings.Contains(key, ".") {
column = clause.Column{Name: key}
}
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
}
case map[string]interface{}:
var keys = make([]string, 0, len(v))
keys := make([]string, 0, len(v))
for i := range v {
keys = append(keys, i)
}
@ -321,22 +378,28 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, key := range keys {
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
column := clause.Column{Name: key, Table: curTable}
if strings.Contains(key, ".") {
column = clause.Column{Name: key}
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if _, ok := v[key].(driver.Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
} else if _, ok := v[key].(Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
} else {
values := make([]interface{}, reflectValue.Len())
for i := 0; i < reflectValue.Len(); i++ {
// optimize reflect value length
valueLen := reflectValue.Len()
values := make([]interface{}, valueLen)
for i := 0; i < valueLen; i++ {
values[i] = reflectValue.Index(i).Interface()
}
conds = append(conds, clause.IN{Column: key, Values: values})
conds = append(conds, clause.IN{Column: column, Values: values})
}
default:
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
}
}
default:
@ -361,11 +424,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(reflectValue); !isZero || selected {
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
}
}
}
@ -375,11 +438,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected {
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
}
}
}
@ -396,24 +459,30 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if len(args) == 1 {
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
values := make([]interface{}, reflectValue.Len())
for i := 0; i < reflectValue.Len(); i++ {
// optimize reflect value length
valueLen := reflectValue.Len()
values := make([]interface{}, valueLen)
for i := 0; i < valueLen; i++ {
values[i] = reflectValue.Index(i).Interface()
}
if len(values) > 0 {
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values})
return []clause.Expression{clause.And(conds...)}
}
return conds
return nil
}
}
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args})
}
}
}
return conds
if len(conds) > 0 {
return []clause.Expression{clause.And(conds...)}
}
return nil
}
// Build build sql with clauses names
@ -437,7 +506,11 @@ func (stmt *Statement) Build(clauses ...string) {
}
func (stmt *Statement) Parse(value interface{}) (err error) {
if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
return stmt.ParseWithSpecialTableName(value, "")
}
func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
stmt.Table = tables[1]
@ -461,12 +534,14 @@ func (stmt *Statement) clone() *Statement {
Distinct: stmt.Distinct,
Selects: stmt.Selects,
Omits: stmt.Omits,
ColumnMapping: stmt.ColumnMapping,
Preloads: map[string][]interface{}{},
ConnPool: stmt.ConnPool,
Schema: stmt.Schema,
Context: stmt.Context,
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
SkipHooks: stmt.SkipHooks,
Result: stmt.Result,
}
if stmt.SQL.Len() > 0 {
@ -501,10 +576,10 @@ func (stmt *Statement) clone() *Statement {
return newStmt
}
// Helpers
// SetColumn set column's value
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
//
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
if v, ok := stmt.Dest.(map[string]interface{}); ok {
v[name] = value
@ -529,7 +604,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
switch destValue.Kind() {
case reflect.Struct:
field.Set(destValue, value)
stmt.AddError(field.Set(stmt.Context, destValue, value))
default:
stmt.AddError(ErrInvalidData)
}
@ -539,13 +614,18 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
case reflect.Slice, reflect.Array:
if len(fromCallbacks) > 0 {
for i := 0; i < stmt.ReflectValue.Len(); i++ {
field.Set(stmt.ReflectValue.Index(i), value)
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value))
}
} else {
field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value)
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value))
}
case reflect.Struct:
field.Set(stmt.ReflectValue, value)
if !stmt.ReflectValue.CanAddr() {
stmt.AddError(ErrInvalidValue)
return
}
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value))
}
} else {
stmt.AddError(ErrInvalidField)
@ -565,12 +645,12 @@ func (stmt *Statement) Changed(fields ...string) bool {
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
changed := func(field *schema.Field) bool {
fieldValue, _ := field.ValueOf(modelValue)
fieldValue, _ := field.ValueOf(stmt.Context, modelValue)
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, ok := stmt.Dest.(map[string]interface{}); ok {
if fv, ok := v[field.Name]; ok {
if mv, mok := stmt.Dest.(map[string]interface{}); mok {
if fv, ok := mv[field.Name]; ok {
return !utils.AssertEqual(fv, fieldValue)
} else if fv, ok := v[field.DBName]; ok {
} else if fv, ok := mv[field.DBName]; ok {
return !utils.AssertEqual(fv, fieldValue)
}
} else {
@ -579,7 +659,10 @@ func (stmt *Statement) Changed(fields ...string) bool {
destValue = destValue.Elem()
}
changedValue, zero := field.ValueOf(destValue)
changedValue, zero := field.ValueOf(stmt.Context, destValue)
if v {
return !utils.AssertEqual(changedValue, fieldValue)
}
return !zero && !utils.AssertEqual(changedValue, fieldValue)
}
}
@ -605,44 +688,62 @@ func (stmt *Statement) Changed(fields ...string) bool {
return false
}
var matchName = func() func(tableColumn string) (table, column string) {
nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
return func(tableColumn string) (table, column string) {
if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 {
table = matches[1]
star := matches[2]
columnName := matches[3]
if star != "" {
return table, star
}
return table, columnName
}
return "", ""
}
}()
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
results := map[string]bool{}
notRestricted := false
// select columns
for _, column := range stmt.Selects {
processColumn := func(column string, result bool) {
if stmt.Schema == nil {
results[column] = true
results[column] = result
} else if column == "*" {
notRestricted = true
notRestricted = result
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = true
results[dbName] = result
}
} else if column == clause.Associations {
for _, rel := range stmt.Schema.Relationships.Relations {
results[rel.Name] = true
results[rel.Name] = result
}
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = true
results[field.DBName] = result
} else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
if col == "*" {
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = result
}
} else {
results[col] = result
}
} else {
results[column] = true
results[column] = result
}
}
// select columns
for _, column := range stmt.Selects {
processColumn(column, true)
}
// omit columns
for _, omit := range stmt.Omits {
if stmt.Schema == nil {
results[omit] = false
} else if omit == clause.Associations {
for _, rel := range stmt.Schema.Relationships.Relations {
results[rel.Name] = false
}
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
results[field.DBName] = false
} else {
results[omit] = false
}
for _, column := range stmt.Omits {
processColumn(column, false)
}
if stmt.Schema != nil {

View File

@ -34,3 +34,37 @@ func TestWhereCloneCorruption(t *testing.T) {
})
}
}
func TestNilCondition(t *testing.T) {
s := new(Statement)
if len(s.BuildCondition(nil)) != 0 {
t.Errorf("Nil condition should be empty")
}
}
func TestNameMatcher(t *testing.T) {
for k, v := range map[string][]string{
"table.name": {"table", "name"},
"`table`.`name`": {"table", "name"},
"'table'.'name'": {"table", "name"},
"'table'.name": {"table", "name"},
"table1.name_23": {"table1", "name_23"},
"`table_1`.`name23`": {"table_1", "name23"},
"'table23'.'name_1'": {"table23", "name_1"},
"'table23'.name1": {"table23", "name1"},
"'name1'": {"", "name1"},
"`name_1`": {"", "name_1"},
"`Name_1`": {"", "Name_1"},
"`Table`.`nAme`": {"Table", "nAme"},
"my_table.*": {"my_table", "*"},
"`my_table`.*": {"my_table", "*"},
"User__Company.*": {"User__Company", "*"},
"`User__Company`.*": {"User__Company", "*"},
`"User__Company".*`: {"User__Company", "*"},
`"table"."*"`: {"", ""},
} {
if table, column := matchName(k); table != v[0] || column != v[1] {
t.Errorf("failed to match value: %v, got %v, expect: %v", k, []string{table, column}, v)
}
}
}

View File

@ -3,11 +3,12 @@ package tests_test
import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
func TestBelongsToAssociation(t *testing.T) {
var user = *GetUser("belongs-to", Config{Company: true, Manager: true})
user := *GetUser("belongs-to", Config{Company: true, Manager: true})
if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
@ -31,8 +32,8 @@ func TestBelongsToAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Manager", 1, "")
// Append
var company = Company{Name: "company-belongs-to-append"}
var manager = GetUser("manager-belongs-to-append", Config{})
company := Company{Name: "company-belongs-to-append"}
manager := GetUser("manager-belongs-to-append", Config{})
if err := DB.Model(&user2).Association("Company").Append(&company); err != nil {
t.Fatalf("Error happened when append Company, got %v", err)
@ -60,8 +61,8 @@ func TestBelongsToAssociation(t *testing.T) {
AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend")
// Replace
var company2 = Company{Name: "company-belongs-to-replace"}
var manager2 = GetUser("manager-belongs-to-replace", Config{})
company2 := Company{Name: "company-belongs-to-replace"}
manager2 := GetUser("manager-belongs-to-replace", Config{})
if err := DB.Model(&user2).Association("Company").Replace(&company2); err != nil {
t.Fatalf("Error happened when replace Company, got %v", err)
@ -132,10 +133,18 @@ func TestBelongsToAssociation(t *testing.T) {
AssertAssociationCount(t, user2, "Company", 0, "after clear")
AssertAssociationCount(t, user2, "Manager", 0, "after clear")
// unexist company id
unexistCompanyID := company.ID + 9999999
user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID}
if err := DB.Create(&user).Error; err == nil {
tidbSkip(t, "not support the foreign key feature")
t.Errorf("should have gotten foreign key violation error")
}
}
func TestBelongsToAssociationForSlice(t *testing.T) {
var users = []User{
users := []User{
*GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}),
*GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}),
*GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}),
@ -217,3 +226,81 @@ func TestBelongsToAssociationForSlice(t *testing.T) {
AssertAssociationCount(t, users[0], "Company", 0, "After Delete")
AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete")
}
func TestBelongsToDefaultValue(t *testing.T) {
type Org struct {
ID string
}
type BelongsToUser struct {
OrgID string
Org Org `gorm:"default:NULL"`
}
tx := DB.Session(&gorm.Session{})
tx.Config.DisableForeignKeyConstraintWhenMigrating = true
AssertEqual(t, DB.Config.DisableForeignKeyConstraintWhenMigrating, false)
tx.Migrator().DropTable(&BelongsToUser{}, &Org{})
tx.AutoMigrate(&BelongsToUser{}, &Org{})
user := &BelongsToUser{
Org: Org{
ID: "BelongsToUser_Org_1",
},
}
err := DB.Create(&user).Error
AssertEqual(t, err, nil)
}
func TestBelongsToAssociationUnscoped(t *testing.T) {
type ItemParent struct {
gorm.Model
Logo string `gorm:"not null;type:varchar(50)"`
}
type ItemChild struct {
gorm.Model
Name string `gorm:"type:varchar(50)"`
ItemParentID uint
ItemParent ItemParent
}
tx := DB.Session(&gorm.Session{})
tx.Migrator().DropTable(&ItemParent{}, &ItemChild{})
tx.AutoMigrate(&ItemParent{}, &ItemChild{})
item := ItemChild{
Name: "name",
ItemParent: ItemParent{
Logo: "logo",
},
}
if err := tx.Create(&item).Error; err != nil {
t.Fatalf("failed to create items, got error: %v", err)
}
// test replace
if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{
Logo: "updated logo",
}); err != nil {
t.Errorf("failed to replace item parent, got error: %v", err)
}
var parents []ItemParent
if err := tx.Find(&parents).Error; err != nil {
t.Errorf("failed to find item parent, got error: %v", err)
}
if len(parents) != 1 {
t.Errorf("expected %d parents, got %d", 1, len(parents))
}
// test delete
if err := tx.Model(&item).Association("ItemParent").Unscoped().Delete(&parents); err != nil {
t.Errorf("failed to delete item parent, got error: %v", err)
}
if err := tx.Find(&parents).Error; err != nil {
t.Errorf("failed to find item parent, got error: %v", err)
}
if len(parents) != 0 {
t.Errorf("expected %d parents, got %d", 0, len(parents))
}
}

Some files were not shown because too many files have changed in this diff Show More