Compare commits

...

326 Commits

Author SHA1 Message Date
贾一饼
49eaeacb89
optimize: field.ReflectValueOf (#7530)
Some checks failed
tests / mysql (mysql:5.7, 1.24, ubuntu-latest) (push) Failing after 59s
tests / mysql (mysql:8, 1.23, ubuntu-latest) (push) Failing after 59s
tests / mysql (mysql:8, 1.24, ubuntu-latest) (push) Failing after 5s
tests / sqlite (1.23, ubuntu-latest) (push) Failing after 1m26s
tests / mysql (mysql:9, 1.23, ubuntu-latest) (push) Failing after 36s
tests / mysql (mysql:9, 1.24, ubuntu-latest) (push) Failing after 32s
tests / mariadb (mariadb:latest, 1.24, ubuntu-latest) (push) Failing after 9s
tests / mariadb (mariadb:latest, 1.23, ubuntu-latest) (push) Failing after 22s
tests / postgres (postgres:13, 1.23, ubuntu-latest) (push) Failing after 24s
tests / mysql (mysql:5.7, 1.23, ubuntu-latest) (push) Failing after 2m1s
tests / postgres (postgres:14, 1.23, ubuntu-latest) (push) Failing after 20s
tests / postgres (postgres:14, 1.24, ubuntu-latest) (push) Failing after 6s
tests / postgres (postgres:15, 1.24, ubuntu-latest) (push) Failing after 14s
tests / postgres (postgres:15, 1.23, ubuntu-latest) (push) Failing after 29s
tests / postgres (postgres:latest, 1.23, ubuntu-latest) (push) Failing after 26s
tests / postgres (postgres:latest, 1.24, ubuntu-latest) (push) Failing after 18s
tests / postgres (postgres:13, 1.24, ubuntu-latest) (push) Failing after 2m55s
tests / sqlite (1.24, ubuntu-latest) (push) Failing after 4m53s
tests / tidb (v6.5.0, 1.24, ubuntu-latest) (push) Failing after 2m6s
golangci-lint / lint (push) Failing after 6m41s
tests / gaussdb (opengauss/opengauss:7.0.0-RC1.B023, 1.23, ubuntu-latest) (push) Failing after 3m31s
tests / gaussdb (opengauss/opengauss:7.0.0-RC1.B023, 1.24, ubuntu-latest) (push) Failing after 3m38s
tests / sqlserver (1.23, ubuntu-latest) (push) Failing after 7m36s
tests / tidb (v6.5.0, 1.23, ubuntu-latest) (push) Failing after 12m41s
tests / sqlserver (1.24, ubuntu-latest) (push) Failing after 14m12s
Stale / stale (push) Successful in 13s
Close Missing Playground issues / stale (push) Successful in 5s
Close invalid questions issues / stale (push) Successful in 6s
2025-07-23 13:02:49 +08:00
贾一饼
52b4744410
optimize: performance optimization (#7526) 2025-07-22 18:32:28 +08:00
Jinzhu
9af6d510b5 Fix query when map keys include table-qualified column names, close #7507 2025-07-22 14:21:04 +08:00
Jinzhu
c63374f5d1 Don't request LastInsertID from database if not necessary, close #7469 2025-07-21 17:55:20 +08:00
jc
b9c7e562b0
fix(schema): check the hook function parameter type (#7468)
* fix(schema): Check the callback function parameter type

* fix log

* fix
2025-07-21 17:06:16 +08:00
Riseif
985940f0d8
should check inner condition length (#7512) 2025-07-21 11:57:12 +08:00
moseszane168
991c2d4891
Add GaussDB Database Support (#7508)
* support gaussdb

* use github CI

* change function name

* use gorm.io/driver/gaussdb

---------

Co-authored-by: bing.ma <bing.ma@daocloud.io>
2025-07-21 10:46:58 +08:00
Jinzhu
751a6dde7a
Call after initialize for gorm.Config (#7518) 2025-07-15 12:05:03 +08:00
贾一饼
2f4925e017
A little optimization for filed.ValueOf (#7499)
Co-authored-by: 贾一饼 <Boyang.Liang@apulis.com>
2025-07-07 11:15:10 +08:00
Eshan-Jogwar
1e8baf5459
fixes #7486 (#7492)
* fixes #7486

* Added A test case for subset changes of model

* completed the test file for check_subset_model_change_test.go
2025-06-25 11:11:08 +08:00
Salent Olivick
842ee527eb
fix decimal migrate error.(#7450) (#7450)
Signed-off-by: Chise1 <chise123@live.com>
2025-06-06 10:35:01 +08:00
enomotodev
23c0d7cf05
test: update MySQL test matrix to use official images and add 9.0, 8.4 versions (#7476)
* test: update MySQL test matrix to use official images and add 9.0, 8.4 versions

* test: use major version tags for MySQL test matrix
2025-06-06 10:10:23 +08:00
Jinzhu
718eae4fdd fix tests for mysql 9.0 2025-06-05 19:34:13 +08:00
Jinzhu
49b01a3e93 Fix Generics Scan, close https://github.com/go-gorm/playground/pull/803 2025-05-29 14:23:57 +08:00
Jinzhu
c44405a25b
Implement Generics API (#7424)
* Implement Generics API

* Add more generics tests

* Add more tests and Take method

* use delayed‑ops pipeline for generics API

* fix generics tests for mysql

* Support SubQuery for Generics

* Add clause.JoinTable helper method

* Fix golangci-lint error

* Complete the design and implementation of generic version Join

* improve generics version Joins support

* allow configuring select/omit columns for joins via subqueries

* finish generic version Preload

* handle error of generics Joins/Preload

* fix tests

* Add LimitPerRecord for generic version Preload

* fix tests for mysql 5.7

* test for nested generic version Join/Preload

* Add WithResult support for generics API

* test reuse generics db conditions

* fix data race

* remove ExampleLRU test

* Add default transaction timeout support

* fix test
2025-05-25 15:40:40 +08:00
Name
751c1d6b45
perf(schema): avoid redundant strings.ToLower call (#7464)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2025-05-25 09:27:21 +08:00
codingplz
8e7ab46c1b
fix: return init dialector error (#7379)
* fix: return init dialector error

* mock defer

* fix: skip AfterInitialize

---------

Co-authored-by: wenyazhou.13 <wenyazhou.13@bytedance.com>
2025-05-22 10:53:47 +08:00
Name
e3037e4ef0
perf: break early on match failure in ParseConstraint (#7402)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2025-05-22 10:49:19 +08:00
pipipipipip
1204330419
feat: error message show field name (#7452)
* feat: error message show field name

* feat: failed to parse field

* feat: failed to parse field

---------

Co-authored-by: zyp <zyp>
2025-05-21 11:13:31 +08:00
Name
9703eb775f
perf: use strings.IndexByte to replace strings.Index (#7454)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2025-05-21 10:35:56 +08:00
Name
1c966e0d25
perf: use strings.Cut to replace strings.SplitN (#7455)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2025-05-21 10:35:23 +08:00
Jinzhu
e5b867e785 remove unnecessary session-level configuration for prepared statements 2025-05-07 14:56:49 +08:00
iTanken
8c4e8e2d2a
fix: int type variable defaultMaxSize overflows in 32-bit environment (#7439)
Refs: #7435
2025-04-27 14:05:16 +08:00
Zhaodong Xie
a827495be1
Preparestmt use LRU Map instead default map (#7435)
* 支持lru淘汰preparestmt cache

* 支持lru淘汰preparestmt cache

* 支持lru淘汰preparestmt cache

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* change const export

* Add stmt_store

* refact prepare stmt store

* Rename lru store

* change const export

* ADD UT

* format code and add session level prepare stmt config

* code format according to golinter ci

* ADD UT

---------

Co-authored-by: xiezhaodong <xiezhaodong@bytedance.com>
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2025-04-25 16:22:26 +08:00
Jinzhu
489a563293 only check new issues for golangci linter 2025-04-17 15:30:17 +08:00
Jinzhu
42bd4f603c
use golangci replace reviewdog (#7426)
* use golangci replace reviewdog

* Update golangci config
2025-04-17 11:55:13 +08:00
a631807682
a9d27293de test: mssql ci 2025-03-11 15:56:46 +08:00
a631807682
3876ffe4bb test: mssql ci 2025-03-11 15:56:46 +08:00
Jinzhu
ee3b549d7d Update tests script 2025-03-09 15:35:17 +08:00
Aman
9f273777f5
fix deprecated reflect.PtrTo reflect.PointerTo usage (#7366)
* fix deprecated reflect.PtrTo reflect.PointerTo usage

* replace all deprecated reflect.PtrTo reflect.PointerTo usage
2025-02-13 14:16:26 +08:00
Vladimir Avtsenov
9ca84b3dde
fix concurrent map writes (#7298) 2025-01-12 19:49:06 +08:00
Max Katz
86b1d22911
chore: update copyright year (#7332) 2025-01-12 18:32:39 +08:00
Evyatar Yaffe
fed49230cb
Enhance db.Scan with ParamsFilter - Issue 7336 - Suggestion (#7337) 2025-01-12 18:19:28 +08:00
aviyam181199
8503287ca4
Fixed Empty Returning Clause Merge Bug (#7339) 2025-01-12 18:18:04 +08:00
nowindexman
4ef3af10ed
feat:Capitalize the priority field of IndexOption so that other systems can access this field from outside the package. (#7342) 2025-01-12 18:16:48 +08:00
Bennett Amodio
f482f25c71
fix: deterministic index ordering when migrating (#7208)
Issue: We observed that, when creating a database based on the same
gORM schema multiple times, indexes could appear in different orders,
hurting determinism for use-cases like schema comparison.

In order to fix this, it's simple to switch the ParseIndexes function
to return a list of indices rather than a map, so the callers will
iterate in deterministic order.
2024-12-06 10:27:44 +08:00
Jinzhu
6bfccf8afa Refactor all tests script 2024-11-21 17:03:31 +08:00
Ivan Ryabov
49bbaa637f
use map look-up for indexes (#7242) 2024-11-14 17:41:43 +08:00
홍성욱
b0d70a26d1
[#6372] Fixed nullable constraint bug for columns during auto migration (#7269)
* [#6372] Fixed nullable constraint bug for columns during auto migration

* [#6372] fix comment

* [#6372] Add test code

* [#6372] Add test code

* [#6372] Fix failed test case

* [#6372] Fix failed test case

* [#6372] wip

* [#6372] wip

* [#6372] wip

* [#6372] wip
2024-11-14 17:40:18 +08:00
omid fth
deceebfab8
Create CODE_OF_CONDUCT.md (#7240) 2024-10-17 14:18:13 +08:00
Yidi Sprei
52e3b353eb
refactor(workflow): update release workflow to enhance automation (#7224)
Replaced the old release workflow with a new setup using Release
Drafter. This refactor allows for more detailed release notes by
categorizing changes and automatically generating release drafts.
The new workflow triggers on semantic version tags and improves
permissions management. This change enhances the release process
by providing better documentation and automation.
2024-10-09 19:31:04 +08:00
Hansu Park
8020e8c166
refactor: improve logging for unimplemented ErrorTranslator in TranslateError config (#7225) 2024-10-09 19:29:48 +08:00
Yidi Sprei
62bd0b9331
Add GitHub Actions workflow to automate release creation on tagged pushes (#7209)
* feat(workflows): add GitHub Action to create release on new tag

This workflow automates the release creation process whenever a new
tag is pushed to the repository. It checks if a release for the tag
already exists and creates one if it doesn't, enhancing the release
management and streamlining the deployment process.

* test

* fix(workflow): improve release existence check in release-on-tag.yml

Refactor the script to improve checking for existing releases by tag.
Return an object instead of using core.setOutput to streamline the
workflow logic. Also, set result-encoding to string for better
compatibility.

* fix(workflow): correct output handling in release-on-tag.yml

Corrected the output handling in the 'release-on-tag.yml' workflow file
by changing 'result-encoding' to 'outputs'. This ensures that the step
correctly checks if a release exists before attempting to create a new
one, thereby avoiding potential errors during the release process.
Added a blank line for better readability.

* fix(workflow): correct output setting in release-on-tag workflow

Refactored how the release_exists output is set to be compatible with
GitHub Actions syntax. This ensures the workflow reliably detects if
a release already exists, improving the robustness of the release
process.

* fix(release): correct output handling in release-on-tag workflow

Improved the way outputs are managed in the 'check_release' step by
returning values instead of direct assignments. This change ensures
better handling of release existence checks and improves code
readability. Added 'result-encoding' to specify string encoding for
results.

* fix(workflow): correct release existence check and add debug output

Refactored the release existence check to return a simple boolean
string ('true'/'false') rather than an object. Added a step to print
the release existence status for debugging purposes. This ensures
correct conditional evaluation and aids in troubleshooting workflow
issues.

* fix(workflow): simplify release process by removing redundant checks

The release-on-tag workflow has been streamlined by eliminating the
redundant steps checking for existing releases. This change reduces
complexity and speeds up the release process by directly creating a
release on every tag push without prior existence verification.
2024-09-30 14:18:13 +08:00
Omkar P
c6ac54812a
Use official SQL Server docker image for tests (#7205)
* Use official SQL Server docker image for tests

* Try with tag 2019-latest instead of latest

* Use platform ubuntu-20.04 for SQL Server

* Switch to 2019-CU18-ubuntu-20.04

* Check with 2022-latest image tag and ubuntu-latest platform

* Update health-cmd

* Try sqlcmd without -N -C

* Re-include -N -C, try with ubuntu-20.04

* Try ubuntu-20.04 without -N -C (last trial)

* Finalize working config

* Remove unused env variables
2024-09-30 11:21:19 +08:00
Jinzhu
68434b76eb fix test script with latest docker release 2024-09-18 20:03:35 +08:00
Kazuki Isogai
c2515ce260
feat: remove version top-level element and rename. (#7086) 2024-09-18 19:42:52 +08:00
Leo Sjöberg
7f75b12bb2
Generate unique savepoint names for nested transactions (#7174)
* Generate unique savepoint names

* Add a test for deeply nested wrapped transactions
2024-09-14 20:58:29 +08:00
abhijeet45
0daaf1747c
fix: AfterQuery using safer right trim while clearing from clause's join added as part of https://github.com/go-gorm/gorm/pull/7027 (#7153)
Co-authored-by: Abhijeet Bhowmik <abhijeet.bhowmik@cambiumnetworks.com>
2024-08-22 19:03:42 +08:00
ivila
0dbfda5d7e
fix memory leaks in PrepareStatementDB (#7142)
* fix memory leaks in PrepareStatementDB

* Fix CR:
1) Fix potential Segmentation Fault in Reset function
2) Setting db.Stmts to nil map when Close to avoid further using

* Add Test:
1) TestPreparedStmtConcurrentReset
2) TestPreparedStmtConcurrentClose

* Fix test, create new connection to keep away from other tests

---------

Co-authored-by: Zehui Chen <zehui@ssc-hn.com>
2024-08-22 19:02:05 +08:00
enomotodev
4a50b36f63
ci: Add PostgreSQL 14 and 15 to GitHub Actions matrix (#7081)
* ci: Add PostgreSQL 14 and 15 to GitHub Actions matrix

* ci: Remove older PostgreSQL versions from test matrix
2024-06-25 10:36:06 +08:00
molon
11c4331058
feat: add MapColumns method (#6901)
* add MapColumns method

* fix MapColumns desc

* add TestMapColumns
2024-06-24 17:42:59 +08:00
Jinzhu
8a0af58cc5 fix map fields with clickhouse driver 2024-06-20 20:19:31 +08:00
Jinzhu
4f6291154b Allow to support other field types 2024-06-20 16:45:38 +08:00
Sérgio Prata Almeida
109f239fae
add DB level propagation for the Unscoped flag (#7007)
* adds PropagateUnscoped to db Config

* adds PropagateUnscoped test

* adds PropagateUnscoped to Session and sets it accordingly
2024-06-17 11:59:06 +08:00
Jinzhu
79bf7f92ed fix CI for sqlserver 2024-06-17 11:58:13 +08:00
Jinzhu
3d09f7947f only listen local port 2024-06-13 18:45:02 +08:00
Waleed Masoom
73a988ceb2
fix(scan): update Scan function to reset structs to zero values for each scan (#7061)
Co-authored-by: waleed.masoom <waleed.masoom@wheniwork.com>
2024-06-12 18:57:36 +08:00
Emilien
05167fd591
fix: use reflect.Append when preloading nested associations (#7014)
Co-authored-by: Emilien Kofman <emilien.kofman@miimosa.com>
2024-06-12 18:52:33 +08:00
Sergei Sadov
78c6dfd712
Fix association replace non-addressable panic (#7012)
* Fix association replace non-addressable panic

* Fix tests

* Add has one panic test

---------

Co-authored-by: sgsv <->
2024-06-12 18:49:45 +08:00
Nico Schäfer
3fe7fcf356
fix: unsupported data on nested joins with preloads (#6957)
* fix: `unsupported data` on nested joins with preloads

* Add test case for pointer join with nested prelaods

* Fix tests
2024-06-12 18:00:47 +08:00
贾一饼
9c4070ed19
fix: AfterQuery should clear FROM Clause's Joins rather than the Statement (#7027) 2024-06-12 17:51:44 +08:00
supergem3000
49d524aaea
feat: chainable order support clause.OrderBy (#7054)
* feat: chainable order support clause.OrderBy

* indent
2024-06-12 17:47:34 +08:00
Jinzhu
49d94c173c upgrade github action tests template 2024-06-12 17:24:34 +08:00
Ryuji Kokubu
0f105ec163
fix: strings.Title -> cases.Title bcs strings.Title library is deprecated (#6999)
* add: add cases library

* fix: strings.Title -> cases.Title

* run goimports to solve the error
2024-06-12 11:46:59 +08:00
hakusai22
5e599a07ec
fix: typo (#7003)
* fix: typo

* fix: covered
2024-05-08 12:07:58 +08:00
Wenli Looi
9d370bcb3e
Fix handling of unknown column types (#6540) 2024-04-26 17:53:11 +08:00
PiexlMax(奇淼
78920199f0
Fix panic bug in migrator due to lack of nil check for stmt.Schema (#6932) 2024-04-26 15:15:49 +08:00
Anıl Şenay
ac59252327
Add new error for "Violation Check Constraint" (#6992) 2024-04-26 10:53:17 +08:00
Cr
207f1ac68f
fix: not clause with or condition (#6984) 2024-04-25 20:22:53 +08:00
Cr
85299bfca7
perf: merge nested preload query when using join (#6990)
* pref: merge nest preload query

* fix: preload test
2024-04-25 20:21:03 +08:00
Jinzhu
5553ff3dcb downgrade mssql driver 2024-04-25 20:12:15 +08:00
kkocdko
bc49365de2
Faster utils.FileWithLineNum (#6981)
* faster FileWithLineNum

* tweak caller skip count
2024-04-22 14:43:02 +08:00
Ivan Chavez
d0b4ceb726
Added comment describing Unscoped() method (#6969) 2024-04-17 11:38:55 +08:00
yetone
9a61ef2af8
fix: duplicated preload (#6948) 2024-04-15 11:20:20 +08:00
Jinzhu
1e13fd7543 Fix duplicated columns in INSERT SQL for some fields with default value 2024-04-08 11:29:55 +08:00
hjwblog.com
1b48aa072d
feat: prepare_stmt support ping (#6924)
* feat: prepare_stmt support ping

* feat: prepare_stmt tx support ping
2024-03-28 16:47:39 +08:00
snackmgmg
26195e6d16
fix: remove callback from callbacks if Remove() called (#6916)
* fix: remove callback from callbacks if Remove() called

* reduce number of loops

* remove unnecessary blank line
2024-03-26 11:33:36 +08:00
givemeafish
956f7ce843
fix: 'type XXXX int' will print wrong sql to terminal (#6917)
Co-authored-by: 王泽平 <zeping.wang@yo-star.com>
2024-03-21 16:00:02 +08:00
Jinzhu
0d6c5345f3 Don't close prepared stmt for normal db error 2024-03-21 15:55:43 +08:00
Jinzhu
57603882ea Only close bad conn prepared stmt 2024-03-20 19:47:20 +08:00
Jinzhu
81536f823c Fix insert id into map results, fix #6812 2024-03-19 11:50:28 +08:00
Jinzhu
1b0aa802df Fix AutoMigrate for bool fields with default value 2024-03-18 19:24:16 +08:00
Jinzhu
e0c3be03fb Fix tests in local 2024-03-18 16:28:46 +08:00
jessetang
303de6e7c8
chore: optimize regEnLetterAndMidline regular (#6908)
* chore: optimize regular

* fix
2024-03-18 15:33:54 +08:00
Jinghao Lu
f7ebf049da
fix(create): fix insert column order (#6855)
* fix(create): fix insert column order

* chore: add ConvertToCreateValues ut for Slice case

* fix: remvoe testify dependency

---------

Co-authored-by: lujinghao <lujinghao@bytedance.com>
2024-03-18 13:48:42 +08:00
jessetang
ab89d54d87
chore: UnixNano convert to UnixMilli (#6907) 2024-03-18 13:44:55 +08:00
Jinzhu
281f3e369a Fix constraint name regexp 2024-03-18 11:32:30 +08:00
jessetang
7b1fb0bd73
fix(scan): array element is set to a zero value (#6890)
* fix(scan): array element is set to a zero value

* add test

* fix test

* optimization
2024-03-15 14:14:48 +08:00
black-06
e4e23d26d2
fix: nested preload with join panic when find (#6877) 2024-03-09 21:27:19 +08:00
jessetang
c4c9aa45e3
fix(scan.go): reflect.MakeSlice passes in the reflect.Array type (#6880) 2024-03-09 17:39:01 +08:00
Cr
9efae659cb
test: namer identifier lenght (#6872) 2024-03-09 17:31:28 +08:00
hishope
f17a75242e Signed-off-by: hishope <csqiye@126.com>
fix some typos in tests

Signed-off-by: hishope <csqiye@126.com>
2024-03-07 16:19:17 +08:00
tsuba3
3e2c4fc446
Fix regression in db.Not introduced in v1.25.6. (#6844)
* Fix regression in db.Not introduced in 940358e.

* Fix
2024-03-05 10:23:51 +08:00
Chef
f118e55db5
Add unittest test helper function ConvertSliceOfMapToValuesForCreate (#6854) 2024-03-05 10:22:57 +08:00
Chef
52404cddbb
CHORE add unittest test function ConvertMapToValueForCreate (#6846)
* CHORE add unittest test function  ConvertMapToValueForCreate

* CHORE move the test cases located in the files convert_map_test.go and visit_map_test.go into the file helper_test.go.
2024-02-27 10:48:04 +08:00
M Dmitry
d81ae6f701
Fixed: panic on nullable value with multiple foreign key usage (#6839)
See: https://github.com/go-gorm/playground/pull/537
2024-02-19 11:42:25 +08:00
black-06
8fb9a31775
refactor: part 2 of distinguish between Unique and UniqueIndex (#6822) 2024-02-06 19:48:40 +08:00
jasonchuan
9514d5f9e6
let limit and offset use bind parameter (#6806)
* let limit and offset use bind parameter

* format

* format limt_test

* try again

* fix test case fro connpool

* adding driverName for postgres  ,if not to do so, the stmt vars will be added  a wrong  one called pgx.QueryExecModeSimpleProtocol  ,  causing the SQL with limit  problem  need 1 parameter ,but given two.

* delete trunk files

* restore the test_test.go

* restore test_test.go

* driver/postgres->v1.5.5

* change postgres version rollback to 1.5.4

---------

Co-authored-by: chenchuan <chenchuan@360.cn>
Co-authored-by: jason_chuan <jason_chuan@126.com>
2024-02-06 10:54:40 +08:00
black-06
46816ad31d
refactor: distinguish between Unique and UniqueIndex (#6386)
* refactor: distinguish between UniqueIndex and Index

* add test

* add ParseIndex test

* modify unique to constraint

* modify unique to constraint

* fix MigrateColumnUnique

* fix test

* fix unit test

* update test mod

* add MigrateColumnUnique to Migrator interface

* fix format lint

* add comment

* go mod tidy

* revert: revert MigrateColumn

* resolve conflicts
2024-02-04 15:49:19 +08:00
black-06
418ee3fc19
fix: preload shouldn't overwrite the value of join (#6771)
* fix: preload shouldn't overwrite the value of join

* fix lint

* fix: join may automatically add nested query
2024-01-29 11:34:57 +08:00
dependabot[bot]
e043924fe7
chore(deps): bump actions/cache from 3 to 4 (#6802)
Bumps [actions/cache](https://github.com/actions/cache) from 3 to 4.
- [Release notes](https://github.com/actions/cache/releases)
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md)
- [Commits](https://github.com/actions/cache/compare/v3...v4)

---
updated-dependencies:
- dependency-name: actions/cache
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-01-29 10:34:20 +08:00
Jacky
0123dd4509
fix: ignore .gen.go suffix in logger to get the real caller when using gen #6697 (#6785) 2024-01-12 17:09:22 +08:00
Jinzhu
940358e0dd Fix tests doesn't follow https://gorm.io/docs/method_chaining.html convention 2024-01-12 16:42:21 +08:00
iTanken
87decced23
fix: ExplainSQL using consecutive pairs of escaper in SQL string represents an escaper (#6766)
Preventing it from being interpreted as the string terminator. This is a widely used escape mechanism in SQL standards and is applicable in most relational databases.
2023-12-28 19:53:36 +08:00
Stephano George
436cca753c
fix: join and select mytable.* not working (#6761)
* fix: select mytable.* not working

* fix: select mytable.*: will not match `mytable."*"`.
feat: increase readability of code matching table name column name
2023-12-23 21:19:41 +08:00
Alexis Viscogliosi
a2cac75218
feature: bring custom type and id column name to polymorphism (#6716)
* feature: bring custom type and id column name to polymorphism

* relationship: better returns for hasPolymorphicRelation

* fix: tests
2023-12-15 16:36:08 +08:00
Maciej Laskowski
b9ebdb13c7
Making locking parameters more intuitive (#6719)
* Making locking parameters more intuitive

* remove dedicated type
2023-12-15 16:32:56 +08:00
BugKillerPro
2fb4928aa8
refactor: Resolve implicit memory aliasing in for loop (#6730) 2023-12-15 16:31:23 +08:00
Franco Liberali
f0af94cd16 add test to show that update from works 2023-12-04 11:50:58 +08:00
FangSqing
3207ad6033
map insert support return increment id (#6662) 2023-11-15 21:32:56 +08:00
Jinzhu
c1e911f6ed Update tests/go.mod 2023-11-09 18:46:39 +08:00
Kijima Daigo
40f4afe8c2
docs: fix broken link (#6673) 2023-11-07 10:20:06 +08:00
Flc゛
d2fb7a942b
chore(logger): optimize (#6675)
* chore(logger): optimize

* chore(logger): optimize
2023-11-07 10:19:41 +08:00
black-06
9fea15ae75
feat: add MigrateColumnUnique (#6640)
* feat: add MigrateColumnUnique

* feat: define new methods

* delete debug in test
2023-10-30 17:15:49 +08:00
Cr
5adc0ce5f6
test: fix TestEmbeddedRelations (#6639) 2023-10-26 11:58:13 +08:00
gleb
78e905919f
tests/sqilte: enable FOREIGN_KEYS inside OpenTestConnection (#6641) 2023-10-26 11:54:15 +08:00
Franco Liberali
6bef318891
add support for returning in sqlserver (#6585) 2023-10-10 15:03:34 +08:00
dependabot[bot]
1b24081010
chore(deps): bump actions/checkout from 3 to 4 (#6586)
Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v3...v4)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-10 14:50:45 +08:00
Jeremy Quirke
8c18714462
Don't call MethodByName with a variable arg (#6602)
Go 1.22 goes somewhat toward addressing the issue using reflect
MethodByName disabling linker deadcode elimination (DCE) and the
resultant large increase in binary size because the linker cannot
prune unused code because it might be reached via reflection.

Go Issue golang/go#62257 reduces the number of incidences of this
problem by leveraging a compiler assist to avoid marking functions
containing calls to MethodByName as ReflectMethods as long as the
arguments are constants.

An analysis of Uber Technologies code base however shows that a number
of transitive imports still contain calls to MethodByName with a
variable argument, including GORM.

In the case of GORM, the solution we are proposing is because the
number of possible methods is finite, we will "unroll" this. This
demonstrably shows that GORM is not longer a problem for DCE.

Before
```
% go version
go version devel go1.22-2f3458a8ce Sat Sep 16 16:26:48 2023 -0700 darwin/arm64
% go  test ./... -ldflags=-dumpdep   2>  >(grep -i -e  '->.*<reflectmethod>')
gorm.io/gorm.(*Statement).BuildCondition -> gorm.io/gorm/schema.ParseWithSpecialTableName <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).Method <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).MethodByName <ReflectMethod>
ok  	gorm.io/gorm	(cached)
ok  	gorm.io/gorm/callbacks	(cached)
gorm.io/gorm/clause_test.BenchmarkComplexSelect -> gorm.io/gorm/schema.ParseWithSpecialTableName <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).Method <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).MethodByName <ReflectMethod>
?   	gorm.io/gorm/migrator	[no test files]
ok  	gorm.io/gorm/clause	(cached)
ok  	gorm.io/gorm/logger	(cached)
gorm.io/gorm/schema_test.TestAdvancedDataTypeValuerAndSetter -> gorm.io/gorm/schema.ParseWithSpecialTableName <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).Method <ReflectMethod>
type:reflect.Value <UsedInIface> -> reflect.(*Value).MethodByName <ReflectMethod>
?   	gorm.io/gorm/utils/tests	[no test files]
ok  	gorm.io/gorm/schema	(cached)
ok  	gorm.io/gorm/utils	(cached)
```

After

```
%go version
go version devel go1.22-2f3458a8ce Sat Sep 16 16:26:48 2023 -0700 darwin/arm64
%go  test ./... -ldflags=-dumpdep   2>  >(grep -i -e  '->.*<reflectmethod>')
ok  	gorm.io/gorm	(cached)
ok  	gorm.io/gorm/callbacks	(cached)
?   	gorm.io/gorm/migrator	[no test files]
?   	gorm.io/gorm/utils/tests	[no test files]
ok  	gorm.io/gorm/clause	(cached)
ok  	gorm.io/gorm/logger	(cached)
ok  	gorm.io/gorm/schema	(cached)
ok  	gorm.io/gorm/utils	(cached)
```
2023-10-10 14:50:29 +08:00
Mathias Zeller
12ba285a52
*datatypes.JSON in model causes panic on tx.Statement.Changed (#6611)
* do not panic on nil

* more explanation in comments

* get things compact
2023-10-10 14:46:32 +08:00
hjwblog.com
9d8a5bb208
feat: reuse name (#6626) 2023-10-10 14:45:48 +08:00
Samuel N Cui
2095d42b4c
fix: sqlite dialector cannot apply PRIMARY KEY AUTOINCREMENT type (#6624)
* fix: sqlite dialector cannot apply `PRIMARY KEY AUTOINCREMENT` type

fix #4760

* feat: add auto increment test

* feat: update sqlite

* feat: update tests deps sqlite to v1.5.4
2023-10-09 17:26:27 +08:00
Jinzhu
e57e5d8884 Update go.mod 2023-08-27 15:40:54 +08:00
Jinzhu
653732e1c3 Update go testing versions 2023-08-24 20:19:38 +08:00
Rataj
ac07543962
Fixed error message when dialector fails to initialize (#6509)
Let's say we have a problem with DSN which leads to dialector initialize error. However DB connection is not created and for some reason line 184 error provides <nil> even though "db" doesn't exist.

Previously, this code leads to:
panic: runtime error: invalid memory address or nil pointer dereference

This fix now doesn't attempt to close non-existant database connection and instead continues, so the proper error is shown. In my case:
[error] failed to initialize database, got error default addr for network 'localhost' unknown
2023-08-20 19:46:56 +08:00
龚一涛
7e44f73ad3
fix schema GetIdentityFieldValuesMap interface or ptr (#6417)
Co-authored-by: uptutu <yitao.gong@vzenith.com>
2023-08-19 21:35:14 +08:00
Heliner
2c2089760c
add float32 test case (#6530) 2023-08-19 21:33:57 +08:00
qqxhb
fef42941ba
feat: rm GetDBConnWithContext method (#6535)
* feat: rm contextconnpool method

* feat: nil
2023-08-19 21:33:31 +08:00
weih
bae684b363
fix(clause): when the value of clause.Eq is an empty array, the SQL should be IN (NULL) (#6503) 2023-08-10 13:34:33 +08:00
Jinzhu
15162afaf2 Support GetDBConnWithContext PreparedStmtDB 2023-08-10 13:30:57 +08:00
fayvori
3c34bc2f59
refactor: Regex description (#6507)
* Mirror cleanup

* Regex description

---------

Co-authored-by: Ignat Belousov <ignatbelousov@Ignats-MacBook-Pro.local>
2023-08-07 16:35:19 +08:00
Aayush Acharya
f473761813
fix: added SkipHooks in db getInstance() (#6484) 2023-08-04 10:35:59 +08:00
San Ye
193c454cf4
keep float precision in ExplainSQL (#6495) 2023-08-04 10:31:18 +08:00
Saeid
1fb26ac90e
test: coverage for tabletype added (#6496)
* test: coverage for tabletype added

* test: tidb exclueded

---------

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-08-04 10:30:07 +08:00
Jinzhu
a7f01bd1b2 Test Pluck with customized type 2023-07-25 10:47:19 +08:00
Saeid
c10f807d3c
test: coverage for foreign key violation err (#6403)
* test: coverage for foreign key violation err

* test: enabled foreign keys constraint for sqlite

* test: enabled mysql& mssql for ErrForeignKeyViolate

* test: disabled mysql & updated sqlserver driver version

* test: skipped tidb

---------

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-07-12 21:21:22 +08:00
Saeid
2066138684
ci: fix mariadb mysqladmin (#6401)
Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-06-11 07:42:18 +08:00
Saeid
c2d571cbc8
test: coverage for duplicated key err (#6389)
* test: ErrDuplicatedKey coverage added

* test: updated sqlserver version

* test: removed sqlserver

* test: support added for sqlserver

---------

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-06-10 21:05:19 +08:00
Johannes Riecken
7dd702d379
Fix incorrect documentation comment (has many -> has one) (#6382) 2023-06-07 15:02:30 +08:00
Nuno Cruces
7157b7e375
fix: database/sql.Scanner should not retain references (#6380) 2023-06-07 15:02:07 +08:00
Lev Zakharov
661781a3d7
feat: add *sql.DB connector that uses database context (#6366)
* feat: add SQLConnector

* rename
2023-06-05 16:25:05 +08:00
KantaHasegawa
5eaccaa624
reafactor: add nil detection when sqldb return (#6373)
* reafactor: add null detection when sqldb return

* refactor: Detecting nil in dbConnector.GetDBConn()

* refactor: Revert partial code from c1ea73036715018a1bb55cdb8690441044e13a76

* fix: fix if statement
2023-06-05 16:24:00 +08:00
Lev Zakharov
7a76c042e6
refactor: remove unnecessary prepared statement allocation (#6374) 2023-06-05 16:23:17 +08:00
black
c1ea730367 fix: avoid panic when open fails 2023-06-01 15:22:21 +08:00
东方上人
740f2be453
fix: begin transaction fail, rollback panic (#6365) 2023-05-31 19:21:51 +08:00
mohammad ali
26663ab9bf
max identifier length changed to 63 (#6337)
* max identifier length changed to 63

* default maxIdentifierLength is 64

* renamed License to LICENSE (#6336)

* Added support of "Violates Foreign Key Constraint" (#6329)

* Added support of "Violates Foreign Key Constraint"

Updated the translator and added the support of "foreign key constraint violation". For this, this error type is needed here.

* changed the description of ErrForeignKeyViolated

* refactor: error translator test (#6350)

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>

* fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements (#6220)

* test: add nested transaction and prepareStmt coexist test case

note: please test in the MySQL environment

Change-Id: I0db32adc5f74b0d443e98943d3b182236583b959
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

* fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements

1. SavetPoint SQL Statement not support in Prepared Statements
 e.g. see mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html

Change-Id: I082012db9b140e8ec69764c633724665cc802692
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

* revert(transaction_api): remove savepoint name pool,meaningless

Change-Id: I84aa9924fc54612005a81c83d66fdf8968ee56ad
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

---------

Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>
Co-authored-by: 王柳洋 <wangliuyang.520@bytedance.com>

* fix: save with hook (#6285) (#6294)

---------

Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>
Co-authored-by: Avinaba Bhattacharjee <avinababhattacharjee2002@gmail.com>
Co-authored-by: Muhammad Amir Ejaz <37077032+codingamir@users.noreply.github.com>
Co-authored-by: Saeid <sk.saeidee@yahoo.com>
Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
Co-authored-by: wangliuyang <54885906+wangliuyang520@users.noreply.github.com>
Co-authored-by: 王柳洋 <wangliuyang.520@bytedance.com>
Co-authored-by: black-06 <hello.bug@foxmail.com>
2023-05-30 10:00:48 +08:00
black-06
11fdf46a9f
fix: save with hook (#6285) (#6294) 2023-05-26 10:28:02 +08:00
wangliuyang
812bb20c34
fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements (#6220)
* test: add nested transaction and prepareStmt coexist test case

note: please test in the MySQL environment

Change-Id: I0db32adc5f74b0d443e98943d3b182236583b959
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

* fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements

1. SavetPoint SQL Statement not support in Prepared Statements
 e.g. see mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html

Change-Id: I082012db9b140e8ec69764c633724665cc802692
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

* revert(transaction_api): remove savepoint name pool,meaningless

Change-Id: I84aa9924fc54612005a81c83d66fdf8968ee56ad
Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>

---------

Signed-off-by: 王柳洋 <wangliuyang.520@bytedance.com>
Co-authored-by: 王柳洋 <wangliuyang.520@bytedance.com>
2023-05-26 10:24:28 +08:00
Saeid
8197c00def
refactor: error translator test (#6350)
Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-05-25 11:10:00 +08:00
Muhammad Amir Ejaz
001738be49
Added support of "Violates Foreign Key Constraint" (#6329)
* Added support of "Violates Foreign Key Constraint"

Updated the translator and added the support of "foreign key constraint violation". For this, this error type is needed here.

* changed the description of ErrForeignKeyViolated
2023-05-21 21:27:22 +08:00
Avinaba Bhattacharjee
6698ba709e
renamed License to LICENSE (#6336) 2023-05-21 21:24:00 +08:00
201430098137
f5837deef3
fix:clickhouse error not capture(#6277) (#6321)
Co-authored-by: zhuangg <zhuangg@mingyuanyun.com>
2023-05-17 10:15:41 +08:00
Jinzhu
c3d7d08b9a Clear SET clause after build SQL 2023-05-15 15:43:44 +08:00
aclich
63534145fd
fix: 🐛 embedded struct test failed with custom datatypes (#6311)
* fix: 🐛 embedded struct test failed with custom datatypes

Fix the pointer embedded struct within custom datatypes and *time.time
should be nil issue.

* fix: 🐛 change test case to avoid mssql driver issue

change test cases from bytes to string to avoid mssql driver issue
2023-05-15 09:59:26 +08:00
John Mai
e61b98d696
feat: migrator support table comment (#6225)
* feat: migrator support table comment

* feat: migrator support tableType.It like ColumnTypes

* Avoid updating the go.mod file.

* Update tests_all.sh

* Update migrator.go

* remove Catalog() & Engine() methods.

* remove CatalogValue & EngineValue.

---------

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-05-05 15:58:27 +08:00
black-06
32045fdd7d
feat: unscoped association (#5899) (#6246)
* feat: unscoped association (#5899)

* modify name because mysql character is latin1

* work only on has association

* format

* Unscoped on belongs_to association
2023-05-04 19:30:45 +08:00
hykuan
67642abfff
fix: 🐛 numeric types in pointer embedded struct test failed (#6293) 2023-05-04 19:29:31 +08:00
hanwn
aeb298635b
debug: use slice Stale sort (#6263)
Co-authored-by: hanwang <hanwang.7721@bytedance.com>
2023-04-26 22:19:46 +08:00
Cr
407bedae0a
fix: nested joins alias (#6265) 2023-04-26 22:19:32 +08:00
yikakia
1f763c81cb
fix typo chainable_api.go (#6266) 2023-04-26 22:19:06 +08:00
Zhiheng Lin
32fc201554
fix: avoid coroutine leaks when the dialecter initialization fails. (#6249)
Co-authored-by: Kevin Lin <kevin.lin@shopee.com>
2023-04-21 22:17:21 +08:00
black-06
ac20d9e222
fix: unit test (#6250)
* fix: unit test

* fix create test

https://github.com/go-gorm/gorm/pull/6127#discussion_r1171214125

* style: rename to adaptorSerializerModel
2023-04-21 22:09:38 +08:00
Jinzhu
e9637024d3 Update README 2023-04-11 13:16:25 +08:00
black-06
828e22b17f
feat: support embedded preload (#6137)
* feat: support embedded preload

* fix lint and test

* fix test...
2023-04-11 13:10:38 +08:00
black-06
4b0da0e97a
fix cond in scopes (#6152)
* fix cond in scopes

* replace quote

* fix execute scopes
2023-04-11 12:01:23 +08:00
bsmith-auth0
ccc3cb758a
fix: many2many association with duplicate belongs to elem (#6206) 2023-04-11 11:06:13 +08:00
jessetang
05bb9d6106
refactor(migrator): non-standard codes (#6180) 2023-04-11 10:32:46 +08:00
dependabot[bot]
1d9f4b0f55
chore(deps): bump actions/stale from 7 to 8 (#6190)
Bumps [actions/stale](https://github.com/actions/stale) from 7 to 8.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v7...v8)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-11 10:27:05 +08:00
hanwn
59ca46db3c
fix: limit(0).offset(0) return all data (#6191)
Co-authored-by: hanwang <hanwang.7721@bytedance.com>
2023-04-11 10:25:47 +08:00
Cr
f0360dccbf
fix: embedded should be nil if not exists (#6219) 2023-04-11 10:13:25 +08:00
Saeid Kanishka
b444011d09
refactor: translatorError flag added for backward compatibility (#6178)
Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-03-24 10:07:05 +08:00
cyhone
5d1cdfef2e
avoid starting a transaction when performing only one insert operation in CreateInBatches function (#6174) 2023-03-23 14:02:35 +08:00
black-06
1a7ea98ac5
fix: count with group (#6157) (#6160)
* fix: count with group (#6157)

* add an easy-to-understand ut
2023-03-23 11:19:53 +08:00
black-06
0c7e575f19
save should be idempotent #6139 (#6149) 2023-03-23 11:18:57 +08:00
dependabot[bot]
d2dd0ce4a7
chore(deps): bump actions/setup-go from 3 to 4 (#6165)
Bumps [actions/setup-go](https://github.com/actions/setup-go) from 3 to 4.
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](https://github.com/actions/setup-go/compare/v3...v4)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-23 11:18:02 +08:00
Jinzhu
cc2d46e5be reuse name for savepoints from nested transaction, close #6060 2023-03-10 17:42:38 +08:00
Cr
8bf1f269cf
feat: support nested join (#6067)
* feat: support nested join

* fix: empty rel value
2023-03-10 17:21:56 +08:00
Jeffry Luqman
654b5f2006
test: pgsql alter column from smallint or string to boolean (#6107)
* test: pgsql alter column from smallint to boolean

* test: pgsql alter column from string to boolean
2023-03-10 17:11:56 +08:00
Cr
b62192456f
fix: diff schema update assign value (#6096) 2023-03-10 17:04:54 +08:00
Saeid Kanishka
707d70a542
refactor: translate error only when it is not nil (#6133)
* refactor: translate error only when it is not nil

* refactor: fix the error flow

* refactor: update the error if checks

* Update gorm.go

---------

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-03-10 16:51:27 +08:00
Truong Nguyen
ed474152b1
Fix: Composite primary key with auto-increment value returns 0 after insert (#6127)
* Fix #4930 workaround for databases that support auto-increment in composite primary key.

* Add test for composite key with auto-increment.

* schema.go: use field.AutoIncrement instead of field.TagSettings["AUTOINCREMENT"], add test to check autoincrement:false

create_test.go: remove unused code: drop table CompositeKeyProduct

---------

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-03-10 16:50:03 +08:00
Jinzhu
1643a36260 Fix possible concurrency problem for serializer 2023-03-10 16:39:57 +08:00
Cr
e9f25c73ee
fix: on confilct with default null (#6129)
* fix: on confilct with default null

* Update create.go

---------

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-03-10 16:35:26 +08:00
Saeid Kanishka
85eaf9eeda
feat: Unique Constraint Violation error translator for different drivers (#6004)
* feat: duplicated key error translator for different drivers

* test: removed the dependency

* test: fixed broken tests

* refactor: added ErrorTransltor interface

* style: applied styler

---------

Co-authored-by: Saeid Saeidee <s.saeidee@sensysgatso.com>
2023-03-06 14:03:31 +08:00
Jinzhu
f3874339ef Fix Save with stress tests 2023-03-02 17:22:51 +08:00
Jiepeng Cao
877cc9148f
Remove redundant code (#6087) 2023-02-27 15:44:35 +08:00
black-06
a80707de9e
Create and drop view (#6097)
* create view

* add comment

* fix test

* check param and add comment
2023-02-27 15:43:10 +08:00
Jiepeng Cao
391c961c7f
quotes on docker-compose.yml ports (#6089) 2023-02-27 15:39:02 +08:00
Cr
04cbd956eb
test: pgsql migrate unique index (#6028) 2023-02-18 09:21:07 +08:00
black-06
e66a059b82
fix: update panic if model is not ptr (#6037)
* fix: update panic if model is not ptr

* fix: update panic if model is not ptr

* fix: update panic if model is not ptr

* fix: raise an error if the value is not addressable

* fix: return
2023-02-18 09:20:29 +08:00
black-06
42fc75cb2c
fix: association concurrently appending (#6044)
* fix: association concurrently appending

* fix: fix unit test

* fix: fix gofumpt
2023-02-18 09:19:24 +08:00
Cr
aa89736db2
fix: miss join type (#6056) 2023-02-18 09:13:36 +08:00
Michael Anstis
532e9cf4cc
Issue 6054: Unscoped not working with PreLoad on Joins (#6058)
* Issue 6054: Unscoped not working with PreLoad on Joins

* Formatting

---------

Co-authored-by: Michael Anstis <manstis@redhat.com>
2023-02-18 09:06:43 +08:00
Cheese
02b7e26f6b
feat: add tidb integration test cases (#6014)
* feat: support tidb integration test

* feat: update the mysql driver version to test
2023-02-08 16:29:09 +08:00
Cr
878ac51e98
fix:throw model value required error (#6031)
* fix:throw model value required error

* chore:ingore typecheck

* chore:ingore errcheck

* refactor: use other error

* chore: gofumpt style
2023-02-08 13:40:41 +08:00
chyroc
e1f46eb802
fix: ignore nil query (#6021) 2023-02-02 17:54:51 +08:00
Jinzhu
4d6b70ec88 Allow modify statement from dest 2023-02-02 17:15:08 +08:00
qiankunli
cfbcedbf03
fix: support zeroValue tag on DeletedAt (#6011)
* fix: support zeroValue tag on DeletedAt

Signed-off-by: qiankunli <qiankun.li@qq.com>

* Update soft_delete_test.go

* Update tests_test.go

* Update soft_delete.go

---------

Signed-off-by: qiankunli <qiankun.li@qq.com>
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-02-01 14:40:55 +08:00
Jinzhu
d834dd60b7 Remove unnecessary code 2023-01-19 15:22:13 +08:00
Jinzhu
3d35ddba55 Fix use table.* as select/omit columns 2023-01-12 16:52:56 +08:00
Haibo
baf1afa1fc
fix(schema): field is only unique when there is one unique index (#5974) 2023-01-11 14:05:39 +08:00
Jinzhu
2bc913787b support implicit table alias, close #5840 #5940 2023-01-02 21:46:27 +08:00
Jinzhu
3d91802b1d Fix unexpected alter table in auto migration, close #5942, #5943 2023-01-02 21:06:04 +08:00
Jinzhu
b0e13d95b4 update github tests action 2023-01-01 22:27:49 +08:00
Jinzhu
4b768c8aff Upgrade tests deps 2023-01-01 22:22:08 +08:00
Haibo
16a272209a
fix(migrator): Tag default:'null' always causes field migration #5953 (#5954)
* fix(migrator): Tag default:'null' always causes field migration #5953

* Update migrate_test.go

* Update migrate_test.go

* Update migrate_test.go

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-01-01 22:14:28 +08:00
Haibo
da2b2861de
fix(migrator): ignore relationships when migrating #5913 (#5946) 2023-01-01 19:54:28 +08:00
dependabot[bot]
7da24d1d52
chore(deps): bump actions/stale from 6 to 7 (#5945)
Bumps [actions/stale](https://github.com/actions/stale) from 6 to 7.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v6...v7)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-12-27 08:47:17 +08:00
Jinzhu
ddd3cc2502 Add ParameterizedQueries option support for logger, close #5288 2022-12-25 11:37:23 +08:00
Cr
794edad60e
test(MigrateColumn): mock alter column to improve field compare (#5499)
* test(MigrateColumn): mock alter column to improve field compare

* Update migrate_test.go

* Update migrate_test.go

* Update migrate_test.go

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-12-24 17:42:16 +08:00
Cr
1935eb0adb
feat: support inner join (#5583)
* feat: support inner join

* test: mixed inner join and left join

* chore: code comment

* Update statement.go

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-12-24 12:27:38 +08:00
Defoo Li
775fa70af5
DryRun for migrator (#5689)
* DryRun for migrator

* Update migrator.go

* Update migrator.go

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-12-24 12:14:23 +08:00
Ning
bbd2bbe521
fix:Issue migrating field with CURRENT_TIMESTAMP (#5906)
Co-authored-by: ningfei <accelerator314@outlook.com>
2022-12-24 11:02:11 +08:00
Nate Armstrong
f3c6fc2533
Update func comments in chainable_api and FirstOr_ (#5935)
Add comments to functions in chainable_api. Depending on the method,
these comments add some additional context or details that are relevant
when reading the function, link to the actual docs at gorm.io/docs, or
provide examples of use. These comments should make GORM much more
pleasant to use with an IDE that provides hoverable comments, and are
minimal examples.

Also add in-code documentation to FirstOrInit and FirstOrCreate.

Almost all examples are directly pulled from the docs, with short
comments explaining the code. Most examples omit the `db.Model(&User{})`
for brevity, and would not actually work.

Co-authored-by: Nate Armstrong <nate.armstrong@eluv.io>
2022-12-23 16:51:01 +08:00
Edward McFarlane
4ec73c9bf4
Add test case for embedded value selects (#5901)
* Add test case for embedded value selects

* Revert recycle struct optimisation to avoid pointer overwrites
2022-12-19 11:49:05 +08:00
Cr
d9525d4da4
fix: skip append relation field to default db value (#5885)
* fix: relation field returning

* chore: gofumpt style
2022-12-01 20:26:59 +08:00
wjw1758548031
f931def33d
clear code syntax (#5889)
* clear code syntax

* clear code syntax
2022-12-01 20:25:53 +08:00
Jinzhu
f91313436a Fix group by with count logic 2022-11-21 11:10:56 +08:00
Cr
342310fba4
fix(FindInBatches): throw err if pk not exists (#5868) 2022-11-21 10:49:27 +08:00
kvii
b6836c2d3e
fix bug in windows (#5844)
* fix bug in windows

* fix file name bug

* test in unix like platform
2022-11-21 10:48:13 +08:00
jessetang
cef3de694d
cleanup(prepare_stmt.go): unnecessary map delete (#5849) 2022-11-13 11:12:09 +08:00
jessetang
1b9cd56c53
doc(README.md): add contributors (#5847) 2022-11-10 16:30:32 +08:00
kvii
871f1de6b9
fix logger path bug (#5836) 2022-11-05 11:52:08 +08:00
jessetang
fb640cf7da
test(utils): add utils unit test (#5834) 2022-11-05 08:38:14 +08:00
jessetang
5c8ecc3a2a
feat: golangci add goimports and whitespace (#5835) 2022-11-05 08:37:37 +08:00
jessetang
f82e9cfdbe
test(clause/joins): add join unit test (#5832) 2022-11-03 21:03:13 +08:00
Cr
b2f42528a4
fix(Joins): args with select and omit (#5790)
* fix(Joins): args with select and omit

* chore: gofumpt style
2022-11-02 10:28:00 +08:00
Cr
9d82aa5673
test: invalid cache plan with prepare stmt (#5778)
* test: invalid cache plan with prepare stmt

* test: more test cases

* test: drop and rename column
2022-10-20 14:10:47 +08:00
Cr
5dd2bb4827
feat(PreparedStmtDB): support reset (#5782)
* feat(PreparedStmtDB): support reset

* fix: close all stmt

* test: fix test

* fix: delete one by one
2022-10-19 14:46:59 +08:00
Jinzhu
3f20a543fa Support use clause.Interface as query params 2022-10-18 18:01:55 +08:00
viatoriche / Maxim Panfilov
62593cfad0 add test: TestAutoMigrateInt8PG: shouldn't execute ALTER COLUMN TYPE smallint, close #5762 2022-10-18 17:28:06 +08:00
Jinzhu
a0f4d3f7d2 Save as empty string for not nullable nil field serialized into json 2022-10-18 16:25:39 +08:00
Jinzhu
ab5f80a8d8 Save as NULL for nil object serialized into json 2022-10-18 15:44:56 +08:00
Cr
186e8a9e14
fix: association without pks (#5779) 2022-10-18 11:58:42 +08:00
Jinzhu
2a788fb20c Upgrade tests go.mod 2022-10-17 17:01:42 +08:00
Jinzhu
aa4312ee74 Don't display any GORM related package path as source 2022-10-17 17:01:42 +08:00
Jinzhu
08aa2f9888 Update README 2022-10-14 20:30:41 +08:00
Jinzhu
2c56954cb1 tests mariadb with returning support 2022-10-08 20:48:22 +08:00
Jinzhu
e93dc3426e Test postgres autoincrement check 2022-10-08 17:16:32 +08:00
Jinzhu
983e96f142 Add tests for alter column type 2022-10-08 16:04:57 +08:00
Jinzhu
34fbe84580 Add TableName with NamingStrategy support, close #5726 2022-10-07 21:18:37 +08:00
robhafner
e8f48b5c15
fix: limit=0 results (#5735) (#5736) 2022-10-07 20:14:14 +08:00
jesse.tang
4b22a55a75
fix: primaryFields are overwritten (#5721) 2022-10-07 18:29:28 +08:00
Wen Sun
9564b82975
Fix OnConstraint builder (#5738) 2022-10-07 13:46:20 +08:00
Cr
0b7113b618
fix: prepare deadlock (#5568)
* fix: prepare deadlock

* chore[ci skip]: code style

* chore[ci skip]: test remove unnecessary params

* fix: prepare deadlock

* fix: double check prepare

* test: more goroutines

* chore[ci skip]: improve code comments

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-09-30 18:13:36 +08:00
Stephano George
a3cc6c6088 Fix: wrong value when Find with Join with same column name, close #5723, #5711 2022-09-30 17:18:42 +08:00
jesse.tang
be440e7512
fix possible nil panic in tests (#5720)
* fix maybe nil panic

* reset code
2022-09-30 11:14:34 +08:00
dependabot[bot]
e1dd0dcbc4
chore(deps): bump actions/stale from 5 to 6 (#5717)
Bumps [actions/stale](https://github.com/actions/stale) from 5 to 6.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-09-30 11:13:01 +08:00
Nguyen Huu Tuan
328f301982
add some test case which related the logic (#5477) 2022-09-22 18:35:21 +08:00
kinggo
12237454ed fix: use preparestmt in trasaction will use new conn, close #5508 2022-09-22 16:47:43 +08:00
Cr
73bc53f061
feat: migrator support type aliases (#5627)
* feat: migrator support type aliases

* perf: check type
2022-09-22 15:56:32 +08:00
Cr
101a7c789f
fix: scan array (#5624)
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-09-22 15:51:47 +08:00
Jinzhu
3a72ba102e Allow shared foreign key for many2many jointable 2022-09-22 15:03:41 +08:00
jesse.tang
1f634c3937
support scan assign slice cap (#5634)
* support scan assign slice cap

* fix
2022-09-22 14:50:35 +08:00
Cr
5ed7b1a65e
fix: same embedded filed name (#5705) 2022-09-22 11:25:03 +08:00
qqxhb
490625981a
fix: update omit (#5699) 2022-09-16 15:02:44 +08:00
Googol Lee
edb00c10ad
AutoMigrate() should always migrate checks, even there is no relationship constraints. (#5644)
* fix: remove uuid autoincrement

* AutoMigrate() should always migrate checks, even there is no relationship constranits.

Co-authored-by: a631807682 <631807682@qq.com>
2022-09-14 10:26:51 +08:00
Bruce MacKenzie
f29afdd329
Rewrite of finisher_api Godocs (#5618) 2022-09-09 11:16:41 +08:00
Jiepeng Cao
b3eb1c8c51
simplified regexp (#5677) 2022-09-05 15:39:19 +08:00
jesse.tang
f78f635fae
Optimize: code logic db.scanIntoStruct() (#5633) 2022-09-05 15:34:33 +08:00
Cr
d71caef7d9
fix: remove uuid autoincrement (#5620) 2022-09-03 20:00:21 +08:00
Shunsuke Otani
8c3018b96a
Replace ioutil.Discard with io.Discard (#5603) 2022-08-15 10:50:06 +08:00
Shunsuke Otani
3f92b9b0df
Refactor: redundant type from composite literal (#5604) 2022-08-15 10:47:26 +08:00
Aoang
ba227e8939
Add Go 1.19 Support (#5608) 2022-08-15 10:46:57 +08:00
enwawerueli
573b9fa536 fix: correct grammar 2022-08-15 10:28:36 +08:00
Bruce MacKenzie
a35883590b
update Delete Godoc to describe soft delete behaviour (#5554) 2022-08-11 11:38:04 +08:00
Cr
f223279384
chore: fix gorm tag (#5577) 2022-08-10 11:03:42 +08:00
hjwblog.com
6e03b97e26
fix: empty serilizer err #5524 (#5525)
* fix: empty serilizer err #5524

* feat: fix UnixSecondSerializer return nil

* feat: split type case

Co-authored-by: huanjiawei <huanjiawei@bytedance.com>
2022-07-27 13:59:47 +08:00
MJrocker
3c6eb14c92 Fixed some typos in the code comment 2022-07-27 10:22:49 +08:00
Cr
06e174e24d
fix: embedded default value (#5540) 2022-07-25 14:10:30 +08:00
Xudong Zhang
bab3cd1724
fix bad logging performance of bulk create (#5520) (#5521) 2022-07-18 20:47:00 +08:00
Jinzhu
75720099b5 Create a new db in FindInBatches 2022-07-18 18:07:05 +08:00
Goxiaoy
2ba599e8b7
fix empty QueryClauses in association (#5502) (#5503)
* fix empty QueryClauses in association (#5502)

* test: empty QueryClauses in association (#5502)

* style: empty QueryClauses in association (#5502)

* style: empty QueryClauses in association (#5502)
2022-07-15 11:15:18 +08:00
alingse
099813bf11
Adjust ToStringKey use unpack params, fix pass []any as any in variadic function (#5500)
* fix pass []any as any in variadic function

* add .vscode to gitignore
2022-07-14 20:05:22 +08:00
Jinzhu
4d40e34734 Update select tests 2022-07-14 14:55:54 +08:00
Jinzhu
3262daf8d4 Fix select with association column 2022-07-13 18:26:35 +08:00
Jinzhu
cae30e9a50 Fix select with association column 2022-07-13 18:02:11 +08:00
Jinzhu
a7063848ef Fix select with uppercase column name 2022-07-13 17:44:14 +08:00
Jinzhu
08f6d06e47 Fix select with quoted column name 2022-07-13 17:21:19 +08:00
Jinzhu
62fdc2bb3b Fix serializer with empty string 2022-07-11 11:51:05 +08:00
Jinzhu
b13d1757fa Refactor Model with slice data 2022-07-07 15:39:29 +08:00
Jinzhu
9fd73ae4f1 Revert "use callback to handle transaction"
This reverts commit 93f28bc116526ba4decdd969a7b2b0b245ad70f1.
2022-07-07 15:06:48 +08:00
Jinzhu
fe01e1b9f4 Fix Model with slice data 2022-07-07 14:43:33 +08:00
Cr
46bce170ca
test: pg array type (#5480) 2022-07-04 16:42:27 +08:00
Jason Lee
5c4016d9a3
Merge pull request #5455 from longbridgeapp/feat-support-transaction-calllback 2022-07-01 23:34:26 +08:00
Cr
c74bc57add
fix: association many2many duplicate elem (#5473)
* fix: association many2many duplicate elem

* chore: gofumpt style
2022-07-01 15:12:15 +08:00
Joe
2cb4088456 ignore AddError return error 2022-07-01 14:37:38 +08:00
Cr
235c093bb9
fix(MigrateColumn):declared different type without length (#5465) 2022-06-29 10:07:42 +08:00
wws
3e6ab99043
fix:serializer contain field panic (#5461) 2022-06-25 16:32:47 +08:00
Joe
93f28bc116 use callback to handle transaction
- make transaction have before and after hooks, so plugin can have hack before
or after transaction
2022-06-24 10:33:39 +08:00
Jinzhu
a70af2a4c0 Fix Select with digits in column name 2022-06-20 15:35:40 +08:00
qqxhb
1305f637f8
feat: add method GetIndexes (#5436)
* feat: add method GetIndexes

* feat: add default impl for Index interface

* feat: fmt
2022-06-17 11:00:57 +08:00
Cr
8d45714628
fix: reset null value in slice (#5417)
* fix: reset null value in slice

* fix: can not set field in-place in join
2022-06-14 13:48:50 +08:00
Bexanderthebex
d01de7232b
enhancement: Avoid calling reflect.New() when passing in slice of values to Scan() (#5388)
* fix: reduce allocations when slice of values

* chore[test]: Add benchmark for scan

* chore[test]: add bench for scan slice

* chore[test]: add bench for slice pointer and improve tests

* chore[test]: make sure database is empty when doing slice tests

* fix[test]: correct sql delete statement

* enhancement: skip new if rows affected = 0
2022-06-01 11:50:57 +08:00
dependabot[bot]
f4e9904b02
chore(deps): bump gorm.io/driver/mysql from 1.3.3 to 1.3.4 in /tests (#5385)
Bumps [gorm.io/driver/mysql](https://github.com/go-gorm/mysql) from 1.3.3 to 1.3.4.
- [Release notes](https://github.com/go-gorm/mysql/releases)
- [Commits](https://github.com/go-gorm/mysql/compare/v1.3.3...v1.3.4)

---
updated-dependencies:
- dependency-name: gorm.io/driver/mysql
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-06-01 10:26:09 +08:00
Cr
93986de8e4
fix: migrate column default value (#5359)
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-05-28 23:09:13 +08:00
t-inagaki@hum_op
dc1ae394f3
fixed FirstOrCreate not handled error when table is not exists (#5367)
* fixed FirstOrCreate not handled error when table is not exists

* delete useless part
2022-05-28 22:18:43 +08:00
Cr
7e13b03bd4
fix: duplicate column scan (#5369)
* fix: duplicate column scan

* fix: dup filed in inconsistent schema and database

* chore[ci skip]: gofumpt style

* chore[ci skip]: fix typo
2022-05-28 22:18:07 +08:00
Cr
7d1a92d60e
test: test for skip prepared when auto migrate (#5350) 2022-05-22 16:12:28 +08:00
Clark McCauley
540fb49bcb
Fixed #5355 - Named variables don't work when followed by Windows CRLF line endings (#5356)
* Fixed #5355.

* Fixed unit test to test both CRLF and CR line endings
2022-05-22 15:16:01 +08:00
Cr
7496c3a56e
fix: trx in hooks clone stmt (#5338)
* fix: trx in hooks

* chore: format by gofumpt
2022-05-17 14:13:41 +08:00
black-06
f5e77aab2f
fix: quote index when creating table (#5331) 2022-05-17 10:59:53 +08:00
Cr
373bcf7aca
fix: many2many auto migrate (#5322)
* fix: many2many auto migrate

* fix: uuid ossp
2022-05-09 10:07:18 +08:00
Cr
19b8d37ae8
fix: preload with skip hooks (#5310) 2022-05-04 18:57:53 +08:00
Cr
b0104943ed
fix: callbcak sort when using multiple plugin (#5304) 2022-04-30 09:57:16 +08:00
Heliner
d3488ae6bc
fix: add judge result of auto_migrate (#5306)
Co-authored-by: fredhan <fredhan@futunn.com>
2022-04-30 09:50:53 +08:00
Cr
bd7e42ec65
fix: AutoMigrate with special table name (#5301)
* fix: AutoMigrate with special table name

* test: migrate with special table name
2022-04-27 21:13:48 +08:00
Jinzhu
6a6dfdae72 Refactor FirstOrCreate, FirstOrInit 2022-04-26 17:16:48 +08:00
Chiung-Ming Huang
0211ac91a2
index: add composite id (#5269)
* index: add composite id

* index: add test cases of composite id

* index: improve the comments for the test cases of composite id
2022-04-25 11:39:23 +08:00
Cr
a0cc631272
test: test for postgrs serial column (#5234)
* test: test for postgrs sercial column

* test: only for postgres

* chore: spelling mistake

* test: for drop sequence
2022-04-24 12:13:27 +08:00
aelmel
3643f856a3
check for pointer to pointer value (#5278)
* check for pointer to pointer value

* revert to Ptr

Co-authored-by: Alexei Melnic <alexei.melnic@meliora.xyz>
2022-04-24 09:10:36 +08:00
Cr
9b80fe9e96
fix: stmt.Changed zero value filed behavior (#5281)
* fix: stmt.Changed zero value filed behavior

* chore: rename var
2022-04-24 09:08:52 +08:00
glebarez
395606ac7c
fix missing error-check in AutoMigrate (#5283) 2022-04-22 11:19:33 +08:00
Jinzhu
88c26b62ee Support Scopes in group conditions 2022-04-20 17:21:38 +08:00
Cr
b49ae84780
fix: FindInBatches with offset limit (#5255)
* fix: FindInBatches with offset limit

* fix: break first

* fix: FindInBatches Limit zero
2022-04-17 09:58:33 +08:00
ZhangShenao
e0ed3ce400
fix spelling mistake (#5256)
Co-authored-by: Shenao Zhang <shenao.zhang@shopee.com>
2022-04-14 20:32:57 +08:00
Jinzhu
d421c67ef5 Remove ErrRecordNotFound error from log when using Save 2022-04-14 10:51:39 +08:00
dependabot[bot]
ce53ea53ee
chore(deps): bump actions/setup-go from 2 to 3 (#5243)
Bumps [actions/setup-go](https://github.com/actions/setup-go) from 2 to 3.
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](https://github.com/actions/setup-go/compare/v2...v3)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-04-13 15:53:12 +08:00
dependabot[bot]
771cbed755
chore(deps): bump actions/stale from 4 to 5 (#5244)
Bumps [actions/stale](https://github.com/actions/stale) from 4 to 5.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v4...v5)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-04-13 15:52:40 +08:00
Cr
a65912c588
fix: FirstOrCreate RowsAffected (#5250) 2022-04-13 15:52:07 +08:00
Filippo Del Moro
6aa6d37fc4
Fix scanIntoStruct (#5241)
* Reproduces error case

* Fix scanIntoStruct

Co-authored-by: Filippo Del Moro <filippo.delmoro@facile.it>
2022-04-13 15:47:04 +08:00
Jinzhu
74e07b049c Serializer unixtime support ptr of int 2022-04-11 22:07:56 +08:00
Jinzhu
41bef26f13 Remove shared sync pool for Scanner compatibility 2022-04-11 21:37:44 +08:00
Naveen
5c9ef9a843
Set permissions for GitHub actions (#5237)
Restrict the GitHub token permissions only to the required ones; this way, even if the attackers will succeed in compromising your workflow, they won’t be able to do much.

- Included permissions for the action. https://github.com/ossf/scorecard/blob/main/docs/checks.md#token-permissions

https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#permissions

https://docs.github.com/en/actions/using-jobs/assigning-permissions-to-jobs

[Keeping your GitHub Actions and workflows secure Part 1: Preventing pwn requests](https://securitylab.github.com/research/github-actions-preventing-pwn-requests/)

Signed-off-by: naveensrinivasan <172697+naveensrinivasan@users.noreply.github.com>
2022-04-10 09:38:43 +08:00
Jinzhu
0729261b62 Support double ptr for Save 2022-04-08 14:23:25 +08:00
Hasan
81c4024232
Offset issue resolved for scanning results back into struct (#5227) 2022-04-07 23:56:41 +08:00
139 changed files with 13096 additions and 1428 deletions

20
.github/release-drafter.yml vendored Normal file
View File

@ -0,0 +1,20 @@
name-template: 'v Release $NEXT_PATCH_VERSION 🌈'
tag-template: 'v$NEXT_PATCH_VERSION'
categories:
- title: '🚀 Features'
labels:
- 'feature'
- 'enhancement'
- title: '🐛 Bug Fixes'
labels:
- 'fix'
- 'bugfix'
- 'bug'
- title: '🧰 Maintenance'
label: 'chore'
change-template: '- $TITLE @$AUTHOR (#$NUMBER)'
change-title-escapes: '\<*_&'
template: |
## Changes
$CHANGES

31
.github/workflows/create-release.yml vendored Normal file
View File

@ -0,0 +1,31 @@
name: Create Release
on:
push:
tags:
- 'v*.*.*'
permissions:
contents: write
pull-requests: read
jobs:
create_release:
name: Create Release
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Generate Release Notes and Publish
id: generate_release_notes
uses: release-drafter/release-drafter@v6
with:
config-name: 'release-drafter.yml'
name: "Release ${{ github.ref_name }}"
tag: ${{ github.ref_name }}
publish: true
prerelease: false
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

26
.github/workflows/golangci-lint.yml vendored Normal file
View File

@ -0,0 +1,26 @@
name: golangci-lint
on:
push:
branches:
- main
- master
pull_request:
permissions:
contents: read
pull-requests: read
jobs:
golangci:
name: lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: stable
- name: golangci-lint
uses: golangci/golangci-lint-action@v7
with:
version: v2.0
only-new-issues: true

View File

@ -3,14 +3,20 @@ on:
schedule: schedule:
- cron: "*/10 * * * *" - cron: "*/10 * * * *"
permissions:
contents: read
jobs: jobs:
stale: stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest runs-on: ubuntu-latest
env: env:
ACTIONS_STEP_DEBUG: true ACTIONS_STEP_DEBUG: true
steps: steps:
- name: Close Stale Issues - name: Close Stale Issues
uses: actions/stale@v4 uses: actions/stale@v8
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"

View File

@ -11,7 +11,7 @@ jobs:
name: Label issues and pull requests name: Label issues and pull requests
steps: steps:
- name: check out - name: check out
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: labeler - name: labeler
uses: jinzhu/super-labeler-action@develop uses: jinzhu/super-labeler-action@develop

View File

@ -3,14 +3,20 @@ on:
schedule: schedule:
- cron: "*/10 * * * *" - cron: "*/10 * * * *"
permissions:
contents: read
jobs: jobs:
stale: stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest runs-on: ubuntu-latest
env: env:
ACTIONS_STEP_DEBUG: true ACTIONS_STEP_DEBUG: true
steps: steps:
- name: Close Stale Issues - name: Close Stale Issues
uses: actions/stale@v4 uses: actions/stale@v8
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"

View File

@ -1,22 +0,0 @@
name: reviewdog
on: [pull_request]
jobs:
golangci-lint:
name: runner / golangci-lint
runs-on: ubuntu-latest
steps:
- name: Check out code into the Go module directory
uses: actions/checkout@v3
- name: golangci-lint
uses: reviewdog/action-golangci-lint@v2
- name: Setup reviewdog
uses: reviewdog/action-setup@v1
- name: gofumpt -s with reviewdog
env:
REVIEWDOG_GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
go install mvdan.cc/gofumpt@v0.2.0
gofumpt -e -d . | \
reviewdog -name="gofumpt" -f=diff -f.diff.strip=0 -reporter=github-pr-review

View File

@ -3,14 +3,20 @@ on:
schedule: schedule:
- cron: "0 2 * * *" - cron: "0 2 * * *"
permissions:
contents: read
jobs: jobs:
stale: stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest runs-on: ubuntu-latest
env: env:
ACTIONS_STEP_DEBUG: true ACTIONS_STEP_DEBUG: true
steps: steps:
- name: Close Stale Issues - name: Close Stale Issues
uses: actions/stale@v4 uses: actions/stale@v8
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days" stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"

View File

@ -8,26 +8,29 @@ on:
branches-ignore: branches-ignore:
- 'gh-pages' - 'gh-pages'
permissions:
contents: read
jobs: jobs:
# Label of the container job # Label of the container job
sqlite: sqlite:
strategy: strategy:
matrix: matrix:
go: ['1.18', '1.17', '1.16'] go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run in windows OS platform: [ubuntu-latest] # can not run in windows OS
runs-on: ${{ matrix.platform }} runs-on: ${{ matrix.platform }}
steps: steps:
- name: Set up Go 1.x - name: Set up Go 1.x
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: go mod package cache - name: go mod package cache
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
@ -38,8 +41,8 @@ jobs:
mysql: mysql:
strategy: strategy:
matrix: matrix:
dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] dbversion: ['mysql:9', 'mysql:8', 'mysql:5.7']
go: ['1.18', '1.17', '1.16'] go: ['1.23', '1.24']
platform: [ubuntu-latest] platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }} runs-on: ${{ matrix.platform }}
@ -62,16 +65,15 @@ jobs:
steps: steps:
- name: Set up Go 1.x - name: Set up Go 1.x
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: go mod package cache - name: go mod package cache
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
@ -79,11 +81,54 @@ jobs:
- name: Tests - name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
mariadb:
strategy:
matrix:
dbversion: [ 'mariadb:latest' ]
go: ['1.23', '1.24']
platform: [ ubuntu-latest ]
runs-on: ${{ matrix.platform }}
services:
mysql:
image: ${{ matrix.dbversion }}
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
ports:
- 9910:3306
options: >-
--health-cmd "mariadb-admin ping -ugorm -pgorm"
--health-interval 10s
--health-start-period 10s
--health-timeout 5s
--health-retries 10
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
postgres: postgres:
strategy: strategy:
matrix: matrix:
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] dbversion: ['postgres:latest', 'postgres:15', 'postgres:14', 'postgres:13']
go: ['1.18', '1.17', '1.16'] go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run in macOS and Windows platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }} runs-on: ${{ matrix.platform }}
@ -106,15 +151,15 @@ jobs:
steps: steps:
- name: Set up Go 1.x - name: Set up Go 1.x
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: go mod package cache - name: go mod package cache
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
@ -125,23 +170,21 @@ jobs:
sqlserver: sqlserver:
strategy: strategy:
matrix: matrix:
go: ['1.18', '1.17', '1.16'] go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run test in macOS and windows platform: [ubuntu-latest] # can not run test in macOS and windows
runs-on: ${{ matrix.platform }} runs-on: ${{ matrix.platform }}
services: services:
mssql: mssql:
image: mcmoe/mssqldocker:latest image: mcr.microsoft.com/mssql/server:2022-latest
env: env:
TZ: Asia/Shanghai
ACCEPT_EULA: Y ACCEPT_EULA: Y
SA_PASSWORD: LoremIpsum86 MSSQL_SA_PASSWORD: LoremIpsum86
MSSQL_DB: gorm
MSSQL_USER: gorm
MSSQL_PASSWORD: LoremIpsum86
ports: ports:
- 9930:1433 - 9930:1433
options: >- options: >-
--health-cmd="/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P LoremIpsum86 -l 30 -Q \"SELECT 1\" || exit 1" --health-cmd="/opt/mssql-tools18/bin/sqlcmd -S localhost -U sa -P ${MSSQL_SA_PASSWORD} -N -C -l 30 -Q \"SELECT 1\" || exit 1"
--health-start-period 10s --health-start-period 10s
--health-interval 10s --health-interval 10s
--health-timeout 5s --health-timeout 5s
@ -149,18 +192,119 @@ jobs:
steps: steps:
- name: Set up Go 1.x - name: Set up Go 1.x
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: go mod package cache - name: go mod package cache
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests - name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://sa:LoremIpsum86@localhost:9930?database=master" ./tests/tests_all.sh
tidb:
strategy:
matrix:
dbversion: [ 'v6.5.0' ]
go: ['1.23', '1.24']
platform: [ ubuntu-latest ]
runs-on: ${{ matrix.platform }}
steps:
- name: Setup TiDB
uses: Icemap/tidb-action@main
with:
port: 9940
version: ${{matrix.dbversion}}
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh
gaussdb:
strategy:
matrix:
dbversion: ['opengauss/opengauss:7.0.0-RC1.B023']
go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }}
services:
gaussdb:
image: ${{ matrix.dbversion }}
env:
# GaussDB has password limitations
GS_PASSWORD: Gaussdb@123
TZ: Asia/Shanghai
ports:
- 9950:5432
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: Waiting for GaussDB to be ready
run: |
container_name=$(docker ps --filter "ancestor=opengauss/opengauss:7.0.0-RC1.B023" --format "{{.Names}}")
if [ -z "$container_name" ]; then
echo "Error: failed to find a container created from the 'opengauss/opengauss:7.0.0-RC1.B023' image."
exit 1
fi
max_retries=12
retry_count=0
if [ -t 0 ]; then
TTY_FLAG="-t"
else
TTY_FLAG=""
fi
while [ $retry_count -lt $max_retries ]; do
if docker exec -i "${container_name}" bash -c "su - omm -c 'gsql -U omm -c \"select 1;\"'"
then
echo "Creating database gorm..."
sql_file='/tmp/create_database.sql'
echo "CREATE DATABASE gorm DBCOMPATIBILITY 'PG';" > ${sql_file}
docker cp "${sql_file}" "${container_name}":"${sql_file}"
docker exec -i ${TTY_FLAG} "${container_name}" bash -c "su - omm -c 'gsql -U omm -f ${sql_file}'"
echo "Database initialization completed."
break
fi
echo "Waiting for database to be ready... (attempt $((retry_count + 1))/$max_retries)"
sleep 10
((++retry_count))
done
exit 0
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=gaussdb GORM_DSN="user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh

3
.gitignore vendored
View File

@ -3,4 +3,5 @@ documents
coverage.txt coverage.txt
_book _book
.idea .idea
vendor vendor
.vscode

View File

@ -1,7 +1,9 @@
version: "2"
linters: linters:
default: standard
enable: enable:
- cyclop - cyclop
- exportloopref
- gocritic - gocritic
- gosec - gosec
- ineffassign - ineffassign
@ -9,3 +11,9 @@ linters:
- prealloc - prealloc
- unconvert - unconvert
- unparam - unparam
- whitespace
formatters:
enable:
- gofumpt
- goimports

128
CODE_OF_CONDUCT.md Normal file
View File

@ -0,0 +1,128 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to participate in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community includes:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period. This
includes avoiding interactions in community spaces and external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any interaction or public
communication with the community for a specified period. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

View File

@ -1,6 +1,6 @@
The MIT License (MIT) The MIT License (MIT)
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com> Copyright (c) 2013-present Jinzhu <wosmvp@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal

View File

@ -4,9 +4,6 @@ The fantastic ORM library for Golang, aims to be developer friendly.
[![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm)
[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) [![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions)
[![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm)
[![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm)
[![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT)
[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc)
@ -30,14 +27,18 @@ The fantastic ORM library for Golang, aims to be developer friendly.
## Getting Started ## Getting Started
* GORM Guides [https://gorm.io](https://gorm.io) * GORM Guides [https://gorm.io](https://gorm.io)
* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen) * Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
## Contributing ## Contributing
[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html) [You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
## Contributors
[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework!
## License ## License
© Jinzhu, 2013~time.Now © Jinzhu, 2013~time.Now
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License) Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)

View File

@ -14,6 +14,7 @@ import (
type Association struct { type Association struct {
DB *DB DB *DB
Relationship *schema.Relationship Relationship *schema.Relationship
Unscope bool
Error error Error error
} }
@ -40,6 +41,15 @@ func (db *DB) Association(column string) *Association {
return association return association
} }
func (association *Association) Unscoped() *Association {
return &Association{
DB: association.DB,
Relationship: association.Relationship,
Error: association.Error,
Unscope: true,
}
}
func (association *Association) Find(out interface{}, conds ...interface{}) error { func (association *Association) Find(out interface{}, conds ...interface{}) error {
if association.Error == nil { if association.Error == nil {
association.Error = association.buildCondition().Find(out, conds...).Error association.Error = association.buildCondition().Find(out, conds...).Error
@ -64,14 +74,30 @@ func (association *Association) Append(values ...interface{}) error {
func (association *Association) Replace(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error {
if association.Error == nil { if association.Error == nil {
reflectValue := association.DB.Statement.ReflectValue
rel := association.Relationship
var oldBelongsToExpr clause.Expression
// we have to record the old BelongsTo value
if association.Unscope && rel.Type == schema.BelongsTo {
var foreignFields []*schema.Field
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
oldBelongsToExpr = clause.IN{Column: column, Values: values}
}
}
// save associations // save associations
if association.saveAssociation( /*clear*/ true, values...); association.Error != nil { if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
return association.Error return association.Error
} }
// set old associations's foreign key to null // set old associations's foreign key to null
reflectValue := association.DB.Statement.ReflectValue
rel := association.Relationship
switch rel.Type { switch rel.Type {
case schema.BelongsTo: case schema.BelongsTo:
if len(values) == 0 { if len(values) == 0 {
@ -91,6 +117,9 @@ func (association *Association) Replace(values ...interface{}) error {
association.Error = association.DB.UpdateColumns(updateMap).Error association.Error = association.DB.UpdateColumns(updateMap).Error
} }
if association.Unscope && oldBelongsToExpr != nil {
association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
}
case schema.HasOne, schema.HasMany: case schema.HasOne, schema.HasMany:
var ( var (
primaryFields []*schema.Field primaryFields []*schema.Field
@ -119,7 +148,11 @@ func (association *Association) Replace(values ...interface{}) error {
if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 { if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error if association.Unscope {
association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error
} else {
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
}
} }
case schema.Many2Many: case schema.Many2Many:
var ( var (
@ -184,7 +217,8 @@ func (association *Association) Delete(values ...interface{}) error {
switch rel.Type { switch rel.Type {
case schema.BelongsTo: case schema.BelongsTo:
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) associationDB := association.DB.Session(&Session{})
tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 { if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
@ -198,8 +232,21 @@ func (association *Association) Delete(values ...interface{}) error {
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
if association.Unscope {
var foreignFields []*schema.Field
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
}
}
case schema.HasOne, schema.HasMany: case schema.HasOne, schema.HasMany:
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) model := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := association.DB.Model(model)
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 { if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
@ -212,7 +259,11 @@ func (association *Association) Delete(values ...interface{}) error {
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error if association.Unscope {
association.Error = tx.Clauses(conds...).Delete(model).Error
} else {
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
}
case schema.Many2Many: case schema.Many2Many:
var ( var (
primaryFields, relPrimaryFields []*schema.Field primaryFields, relPrimaryFields []*schema.Field
@ -345,6 +396,10 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} }
} }
case reflect.Struct: case reflect.Struct:
if !rv.CanAddr() {
association.Error = ErrInvalidValue
return
}
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface()) association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct { if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
@ -353,9 +408,13 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} }
case schema.HasMany, schema.Many2Many: case schema.HasMany, schema.Many2Many:
elemType := association.Relationship.Field.IndirectFieldType.Elem() elemType := association.Relationship.Field.IndirectFieldType.Elem()
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
var fieldValue reflect.Value
if clear { if clear {
fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap())
} else {
fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap())
reflect.Copy(fieldValue, oldFieldValue)
} }
appendToFieldValues := func(ev reflect.Value) { appendToFieldValues := func(ev reflect.Value) {
@ -378,6 +437,10 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
} }
case reflect.Struct: case reflect.Struct:
if !rv.CanAddr() {
association.Error = ErrInvalidValue
return
}
appendToFieldValues(rv.Addr()) appendToFieldValues(rv.Addr())
} }
@ -455,6 +518,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
if association.Error != nil {
return
}
// TODO support save slice data, sql with case? // TODO support save slice data, sql with case?
association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
@ -476,6 +542,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
for idx, value := range values { for idx, value := range values {
rv := reflect.Indirect(reflect.ValueOf(value)) rv := reflect.Indirect(reflect.ValueOf(value))
appendToRelations(reflectValue, rv, clear && idx == 0) appendToRelations(reflectValue, rv, clear && idx == 0)
if association.Error != nil {
return
}
} }
if len(values) > 0 { if len(values) > 0 {
@ -507,7 +576,9 @@ func (association *Association) buildCondition() *DB {
joinStmt.AddClause(queryClause) joinStmt.AddClause(queryClause)
} }
joinStmt.Build("WHERE") joinStmt.Build("WHERE")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) if len(joinStmt.SQL.String()) > 0 {
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
} }
tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{ tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{

View File

@ -75,11 +75,7 @@ func (cs *callbacks) Raw() *processor {
func (p *processor) Execute(db *DB) *DB { func (p *processor) Execute(db *DB) *DB {
// call scopes // call scopes
for len(db.Statement.scopes) > 0 { for len(db.Statement.scopes) > 0 {
scopes := db.Statement.scopes db = db.executeScopes()
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
}
} }
var ( var (
@ -93,6 +89,10 @@ func (p *processor) Execute(db *DB) *DB {
resetBuildClauses = true resetBuildClauses = true
} }
if optimizer, ok := db.Statement.Dest.(StatementModifier); ok {
optimizer.ModifyStatement(stmt)
}
// assign model values // assign model values
if stmt.Model == nil { if stmt.Model == nil {
stmt.Model = stmt.Dest stmt.Model = stmt.Dest
@ -132,7 +132,11 @@ func (p *processor) Execute(db *DB) *DB {
if stmt.SQL.Len() > 0 { if stmt.SQL.Len() > 0 {
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected sql, vars := stmt.SQL.String(), stmt.Vars
if filter, ok := db.Logger.(ParamsFilter); ok {
sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
}
return db.Dialector.Explain(sql, vars...), db.RowsAffected
}, db.Error) }, db.Error)
} }
@ -183,10 +187,18 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
func (p *processor) compile() (err error) { func (p *processor) compile() (err error) {
var callbacks []*callback var callbacks []*callback
removedMap := map[string]bool{}
for _, callback := range p.callbacks { for _, callback := range p.callbacks {
if callback.match == nil || callback.match(p.db) { if callback.match == nil || callback.match(p.db) {
callbacks = append(callbacks, callback) callbacks = append(callbacks, callback)
} }
if callback.remove {
removedMap[callback.name] = true
}
}
if len(removedMap) > 0 {
callbacks = removeCallbacks(callbacks, removedMap)
} }
p.callbacks = callbacks p.callbacks = callbacks
@ -245,8 +257,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
names, sorted []string names, sorted []string
sortCallback func(*callback) error sortCallback func(*callback) error
) )
sort.Slice(cs, func(i, j int) bool { sort.SliceStable(cs, func(i, j int) bool {
return cs[j].before == "*" || cs[j].after == "*" if cs[j].before == "*" && cs[i].before != "*" {
return true
}
if cs[j].after == "*" && cs[i].after != "*" {
return true
}
return false
}) })
for _, c := range cs { for _, c := range cs {
@ -329,3 +347,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
return return
} }
func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
callbacks := make([]*callback, 0, len(cs))
for _, callback := range cs {
if nameMap[callback.name] {
continue
}
callbacks = append(callbacks, callback)
}
return callbacks
}

View File

@ -47,29 +47,44 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
) )
if !isPtr { if !isPtr {
fieldType = reflect.PtrTo(fieldType) fieldType = reflect.PointerTo(fieldType)
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{}
for i := 0; i < rValLen; i++ { for i := 0; i < rValLen; i++ {
obj := db.Statement.ReflectValue.Index(i) obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() != reflect.Struct { if reflect.Indirect(obj).Kind() != reflect.Struct {
break break
} }
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
if !isPtr {
rv = rv.Addr()
}
objs = append(objs, obj) objs = append(objs, obj)
if isPtr { elems = reflect.Append(elems, rv)
elems = reflect.Append(elems, rv)
} else { relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
elems = reflect.Append(elems, rv.Addr()) for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, rv)
} }
} }
} }
if elems.Len() > 0 { if elems.Len() > 0 {
if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil { if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
for i := 0; i < elems.Len(); i++ { for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i)) setupReferences(objs[i], elems.Index(i))
} }
@ -111,7 +126,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
) )
if !isPtr { if !isPtr {
fieldType = reflect.PtrTo(fieldType) fieldType = reflect.PointerTo(fieldType)
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
@ -180,7 +195,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
fieldType := rel.Field.IndirectFieldType.Elem() fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr { if !isPtr {
fieldType = reflect.PtrTo(fieldType) fieldType = reflect.PointerTo(fieldType)
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{} identityMap := map[string]bool{}
@ -206,9 +221,12 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
} }
} }
cacheKey := utils.ToStringKey(relPrimaryValues) cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
identityMap[cacheKey] = true if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
if isPtr { if isPtr {
elems = reflect.Append(elems, elem) elems = reflect.Append(elems, elem)
} else { } else {
@ -250,10 +268,11 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
fieldType := rel.Field.IndirectFieldType.Elem() fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr { if !isPtr {
fieldType = reflect.PtrTo(fieldType) fieldType = reflect.PointerTo(fieldType)
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.JoinTable.ModelType)), 0, 10)
objs := []reflect.Value{} objs := []reflect.Value{}
appendToJoins := func(obj reflect.Value, elem reflect.Value) { appendToJoins := func(obj reflect.Value, elem reflect.Value) {
@ -272,19 +291,34 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
joins = reflect.Append(joins, joinValue) joins = reflect.Append(joins, joinValue)
} }
identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) { appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ { for i := 0; i < f.Len(); i++ {
elem := f.Index(i) elem := f.Index(i)
if !isPtr {
objs = append(objs, v) elem = elem.Addr()
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
} }
objs = append(objs, v)
elems = reflect.Append(elems, elem)
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, elem)
}
} }
} }
} }
@ -304,7 +338,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
// optimize elems of reflect value length // optimize elems of reflect value length
if elemLen := elems.Len(); elemLen > 0 { if elemLen := elems.Len(); elemLen > 0 {
if v, ok := selectColumns[rel.Name+".*"]; !ok || v { if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
saveAssociations(db, rel, elems, selectColumns, restricted, nil) saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
} }
for i := 0; i < elemLen; i++ { for i := 0; i < elemLen; i++ {

View File

@ -13,11 +13,20 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
db.Statement.CurDestIndex = 0 db.Statement.CurDestIndex = 0
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx) if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() {
fc(value.Addr().Interface(), tx)
} else {
db.AddError(gorm.ErrInvalidValue)
return
}
db.Statement.CurDestIndex++ db.Statement.CurDestIndex++
} }
case reflect.Struct: case reflect.Struct:
fc(db.Statement.ReflectValue.Addr().Interface(), tx) if db.Statement.ReflectValue.CanAddr() {
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
} else {
db.AddError(gorm.ErrInvalidValue)
}
} }
} }
} }

View File

@ -3,6 +3,7 @@ package callbacks
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -52,9 +53,13 @@ func Create(config *Config) func(db *gorm.DB) {
if _, ok := db.Statement.Clauses["RETURNING"]; !ok { if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) if field.Readable {
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
}
}
if len(fromColumns) > 0 {
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
} }
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
} }
} }
} }
@ -88,6 +93,10 @@ func Create(config *Config) func(db *gorm.DB) {
db.AddError(rows.Close()) db.AddError(rows.Close())
}() }()
gorm.Scan(rows, db, mode) gorm.Scan(rows, db, mode)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
} }
return return
@ -102,13 +111,70 @@ func Create(config *Config) func(db *gorm.DB) {
} }
db.RowsAffected, _ = result.RowsAffected() db.RowsAffected, _ = result.RowsAffected()
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
db.Statement.Schema.PrioritizedPrimaryField != nil && if db.Statement.Result != nil {
db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { db.Statement.Result.Result = result
insertID, err := result.LastInsertId() db.Statement.Result.RowsAffected = db.RowsAffected
insertOk := err == nil && insertID > 0 }
if !insertOk {
if db.RowsAffected == 0 {
return
}
var (
pkField *schema.Field
pkFieldName = "@id"
)
if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil ||
!db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue ||
!db.Statement.Schema.PrioritizedPrimaryField.Readable {
return
}
pkField = db.Statement.Schema.PrioritizedPrimaryField
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
if !supportReturning {
db.AddError(err) db.AddError(err)
}
return
}
// append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
switch values := db.Statement.Dest.(type) {
case map[string]interface{}:
values[pkFieldName] = insertID
case *map[string]interface{}:
(*values)[pkFieldName] = insertID
case []map[string]interface{}, *[]map[string]interface{}:
mapValues, ok := values.([]map[string]interface{})
if !ok {
if v, ok := values.(*[]map[string]interface{}); ok {
if *v != nil {
mapValues = *v
}
}
}
if config.LastInsertIDReversed {
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
}
for _, mapValue := range mapValues {
if mapValue != nil {
mapValue[pkFieldName] = insertID
}
insertID += schema.DefaultAutoIncrementIncrement
}
default:
if pkField == nil {
return return
} }
@ -121,10 +187,10 @@ func Create(config *Config) func(db *gorm.DB) {
break break
} }
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) _, isZero := pkField.ValueOf(db.Statement.Context, rv)
if isZero { if isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement insertID -= pkField.AutoIncrementIncrement
} }
} }
} else { } else {
@ -134,16 +200,16 @@ func Create(config *Config) func(db *gorm.DB) {
break break
} }
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement insertID += pkField.AutoIncrementIncrement
} }
} }
} }
case reflect.Struct: case reflect.Struct:
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) _, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero { if isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
} }
} }
} }
@ -252,13 +318,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
} }
} }
for field, vs := range defaultValueFieldsHavingValue { for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) if vs, ok := defaultValueFieldsHavingValue[field]; ok {
for idx := range values.Values { values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
if vs[idx] == nil { for idx := range values.Values {
values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) if vs[idx] == nil {
} else { values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
values.Values[idx] = append(values.Values[idx], vs[idx]) } else {
values.Values[idx] = append(values.Values[idx], vs[idx])
}
} }
} }
} }
@ -281,7 +349,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
} }
for _, field := range stmt.Schema.FieldsWithDefaultDBValue { for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], rvOfvalue) values.Values[0] = append(values.Values[0], rvOfvalue)
@ -302,14 +370,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
for _, column := range values.Columns { for _, column := range values.Columns {
if field := stmt.Schema.LookUpField(column.Name); field != nil { if field := stmt.Schema.LookUpField(column.Name); field != nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
if field.AutoUpdateTime > 0 { if field.AutoUpdateTime > 0 {
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime} assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
switch field.AutoUpdateTime { switch field.AutoUpdateTime {
case schema.UnixNanosecond: case schema.UnixNanosecond:
assignment.Value = curTime.UnixNano() assignment.Value = curTime.UnixNano()
case schema.UnixMillisecond: case schema.UnixMillisecond:
assignment.Value = curTime.UnixNano() / 1e6 assignment.Value = curTime.UnixMilli()
case schema.UnixSecond: case schema.UnixSecond:
assignment.Value = curTime.Unix() assignment.Value = curTime.Unix()
} }

71
callbacks/create_test.go Normal file
View File

@ -0,0 +1,71 @@
package callbacks
import (
"reflect"
"sync"
"testing"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
var schemaCache = &sync.Map{}
func TestConvertToCreateValues_DestType_Slice(t *testing.T) {
type user struct {
ID int `gorm:"primaryKey"`
Name string
Email string `gorm:"default:(-)"`
Age int `gorm:"default:(-)"`
}
s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{})
if err != nil {
t.Errorf("parse schema error: %v, is not expected", err)
return
}
dest := []*user{
{
ID: 1,
Name: "alice",
Email: "email",
Age: 18,
},
{
ID: 2,
Name: "bob",
Email: "email",
Age: 19,
},
}
stmt := &gorm.Statement{
DB: &gorm.DB{
Config: &gorm.Config{
NowFunc: func() time.Time { return time.Time{} },
},
Statement: &gorm.Statement{
Settings: sync.Map{},
Schema: s,
},
},
ReflectValue: reflect.ValueOf(dest),
Dest: dest,
}
stmt.Schema = s
values := ConvertToCreateValues(stmt)
expected := clause.Values{
// column has value + defaultValue column has value (which should have a stable order)
Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}},
Values: [][]interface{}{
{"alice", "email", 18, 1},
{"bob", "email", 19, 2},
},
}
if !reflect.DeepEqual(expected, values) {
t.Errorf("expected: %v got %v", expected, values)
}
}

View File

@ -157,8 +157,14 @@ func Delete(config *Config) func(db *gorm.DB) {
ok, mode := hasReturning(db, supportReturning) ok, mode := hasReturning(db, supportReturning)
if !ok { if !ok {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if db.AddError(err) == nil { if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected() db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
} }
return return
@ -166,6 +172,10 @@ func Delete(config *Config) func(db *gorm.DB) {
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
gorm.Scan(rows, db, mode) gorm.Scan(rows, db, mode)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
db.AddError(rows.Close()) db.AddError(rows.Close())
} }
} }

View File

@ -125,7 +125,7 @@ func checkMissingWhereConditions(db *gorm.DB) {
type visitMap = map[reflect.Value]bool type visitMap = map[reflect.Value]bool
// Check if circular values, return true if loaded // Check if circular values, return true if loaded
func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) { func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
v = v.Elem() v = v.Elem()
} }
@ -134,17 +134,17 @@ func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
loaded = true loaded = true
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
if !loadOrStoreVisitMap(vistMap, v.Index(i)) { if !loadOrStoreVisitMap(visitMap, v.Index(i)) {
loaded = false loaded = false
} }
} }
case reflect.Struct, reflect.Interface: case reflect.Struct, reflect.Interface:
if v.CanAddr() { if v.CanAddr() {
p := v.Addr() p := v.Addr()
if _, ok := (*vistMap)[p]; ok { if _, ok := (*visitMap)[p]; ok {
return true return true
} }
(*vistMap)[p] = true (*visitMap)[p] = true
} }
} }

157
callbacks/helper_test.go Normal file
View File

@ -0,0 +1,157 @@
package callbacks
import (
"reflect"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
func TestLoadOrStoreVisitMap(t *testing.T) {
var vm visitMap
var loaded bool
type testM struct {
Name string
}
t1 := testM{Name: "t1"}
t2 := testM{Name: "t2"}
t3 := testM{Name: "t3"}
vm = make(visitMap)
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
t.Fatalf("loaded should be true")
}
// t1 already exist but t2 not
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
t.Fatalf("loaded should be true")
}
}
func TestConvertMapToValuesForCreate(t *testing.T) {
testCase := []struct {
name string
input map[string]interface{}
expect clause.Values
}{
{
name: "Test convert string value",
input: map[string]interface{}{
"name": "my name",
},
expect: clause.Values{
Columns: []clause.Column{{Name: "name"}},
Values: [][]interface{}{{"my name"}},
},
},
{
name: "Test convert int value",
input: map[string]interface{}{
"age": 18,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "age"}},
Values: [][]interface{}{{18}},
},
},
{
name: "Test convert float value",
input: map[string]interface{}{
"score": 99.5,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "score"}},
Values: [][]interface{}{{99.5}},
},
},
{
name: "Test convert bool value",
input: map[string]interface{}{
"active": true,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "active"}},
Values: [][]interface{}{{true}},
},
},
}
for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
actual := ConvertMapToValuesForCreate(&gorm.Statement{}, tc.input)
if !reflect.DeepEqual(actual, tc.expect) {
t.Errorf("expect %v got %v", tc.expect, actual)
}
})
}
}
func TestConvertSliceOfMapToValuesForCreate(t *testing.T) {
testCase := []struct {
name string
input []map[string]interface{}
expect clause.Values
}{
{
name: "Test convert slice of string value",
input: []map[string]interface{}{
{"name": "my name"},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "name"}},
Values: [][]interface{}{{"my name"}},
},
},
{
name: "Test convert slice of int value",
input: []map[string]interface{}{
{"age": 18},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "age"}},
Values: [][]interface{}{{18}},
},
},
{
name: "Test convert slice of float value",
input: []map[string]interface{}{
{"score": 99.5},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "score"}},
Values: [][]interface{}{{99.5}},
},
},
{
name: "Test convert slice of bool value",
input: []map[string]interface{}{
{"active": true},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "active"}},
Values: [][]interface{}{{true}},
},
},
}
for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
actual := ConvertSliceOfMapToValuesForCreate(&gorm.Statement{}, tc.input)
if !reflect.DeepEqual(actual, tc.expect) {
t.Errorf("expected %v but got %v", tc.expect, actual)
}
})
}
}

View File

@ -3,6 +3,8 @@ package callbacks
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"sort"
"strings"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -10,6 +12,176 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
// parsePreloadMap extracts nested preloads. e.g.
//
// // schema has a "k0" relation and a "k7.k8" embedded relation
// parsePreloadMap(schema, map[string][]interface{}{
// clause.Associations: {"arg1"},
// "k1": {"arg2"},
// "k2.k3": {"arg3"},
// "k4.k5.k6": {"arg4"},
// })
// // preloadMap is
// map[string]map[string][]interface{}{
// "k0": {},
// "k7": {
// "k8": {},
// },
// "k1": {},
// "k2": {
// "k3": {"arg3"},
// },
// "k4": {
// "k5.k6": {"arg4"},
// },
// }
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
preloadMap := map[string]map[string][]interface{}{}
setPreloadMap := func(name, value string, args []interface{}) {
if _, ok := preloadMap[name]; !ok {
preloadMap[name] = map[string][]interface{}{}
}
if value != "" {
preloadMap[name][value] = args
}
}
for name, args := range preloads {
preloadFields := strings.Split(name, ".")
value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
if preloadFields[0] == clause.Associations {
for _, relation := range s.Relationships.Relations {
if relation.Schema == s {
setPreloadMap(relation.Name, value, args)
}
}
for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
for _, value := range embeddedValues(embeddedRelations) {
setPreloadMap(embedded, value, args)
}
}
} else {
setPreloadMap(preloadFields[0], value, args)
}
}
return preloadMap
}
func embeddedValues(embeddedRelations *schema.Relationships) []string {
if embeddedRelations == nil {
return nil
}
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
for _, relation := range embeddedRelations.Relations {
// skip first struct name
names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
}
for _, relations := range embeddedRelations.EmbeddedRelations {
names = append(names, embeddedValues(relations)...)
}
return names
}
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
// If the current relationship is embedded or joined, current query will be ignored.
//
//nolint:cyclop
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
// avoid random traversal of the map
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames = append(preloadNames, key)
}
sort.Strings(preloadNames)
isJoined := func(name string) (joined bool, nestedJoins []string) {
for _, join := range joins {
if _, ok := relationships.Relations[join]; ok && name == join {
joined = true
continue
}
join0, join1, cut := strings.Cut(join, ".")
if cut {
if _, ok := relationships.Relations[join0]; ok && name == join0 {
joined = true
nestedJoins = append(nestedJoins, join1)
}
}
}
return joined, nestedJoins
}
for _, name := range preloadNames {
if relations := relationships.EmbeddedRelations[name]; relations != nil {
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
return err
}
} else if rel := relationships.Relations[name]; rel != nil {
if joined, nestedJoins := isJoined(name); joined {
switch rv := db.Statement.ReflectValue; rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() > 0 {
reflectValue := rel.FieldSchema.MakeSlice().Elem()
for i := 0; i < rv.Len(); i++ {
frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
if frv.Kind() != reflect.Ptr {
reflectValue = reflect.Append(reflectValue, frv.Addr())
} else {
if frv.IsNil() {
continue
}
reflectValue = reflect.Append(reflectValue, frv)
}
}
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
}
case reflect.Struct, reflect.Pointer:
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
default:
return gorm.ErrInvalidData
}
} else {
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
tx.Statement.ReflectValue = db.Statement.ReflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
return err
}
}
} else {
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
}
}
return nil
}
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if err := tx.Statement.Parse(dest); err != nil {
tx.AddError(err)
return tx
}
tx.Statement.ReflectValue = reflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
return tx
}
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
var ( var (
reflectValue = tx.Statement.ReflectValue reflectValue = tx.Statement.ReflectValue
@ -103,6 +275,8 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
if len(values) != 0 { if len(values) != 0 {
tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})
for _, cond := range conds { for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx) tx = fc(tx)
@ -111,7 +285,11 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
} }
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil { if len(inlineConds) > 0 {
tx = tx.Where(inlineConds[0], inlineConds[1:]...)
}
if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
return err return err
} }
} }

View File

@ -3,11 +3,12 @@ package callbacks
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"sort"
"strings" "strings"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
func Query(db *gorm.DB) { func Query(db *gorm.DB) {
@ -24,6 +25,10 @@ func Query(db *gorm.DB) {
db.AddError(rows.Close()) db.AddError(rows.Close())
}() }()
gorm.Scan(rows, db, 0) gorm.Scan(rows, db, 0)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
} }
} }
} }
@ -109,78 +114,148 @@ func BuildQuerySQL(db *gorm.DB) {
} }
} }
specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
for _, join := range db.Statement.Joins { for _, join := range db.Statement.Joins {
if db.Statement.Schema == nil { if db.Statement.Schema != nil {
fromClause.Joins = append(fromClause.Joins, clause.Join{ var isRelations bool // is relations or raw sql
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, var relations []*schema.Relationship
}) relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { if ok {
tableAliasName := relation.Name isRelations = true
relations = append(relations, relation)
} else {
// handle nested join like "Manager.Company"
nestedJoinNames := strings.Split(join.Name, ".")
if len(nestedJoinNames) > 1 {
isNestedJoin := true
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
currentRelations := db.Statement.Schema.Relationships.Relations
for _, relname := range nestedJoinNames {
// incomplete match, only treated as raw sql
if relation, ok = currentRelations[relname]; ok {
guessNestedRelations = append(guessNestedRelations, relation)
currentRelations = relation.FieldSchema.Relationships.Relations
} else {
isNestedJoin = false
break
}
}
for _, s := range relation.FieldSchema.DBNames { if isNestedJoin {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ isRelations = true
Table: tableAliasName, relations = guessNestedRelations
Name: s, }
Alias: tableAliasName + "__" + s, }
}
if isRelations {
genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
}
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
for _, s := range relation.FieldSchema.DBNames {
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName,
Name: s,
Alias: utils.NestedRelationName(tableAliasName, s),
})
}
}
if join.Expression != nil {
return clause.Join{
Type: join.JoinType,
Expression: join.Expression,
}
}
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
}
} else {
if ref.PrimaryValue == "" {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
}
}
}
{
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
for _, c := range relation.FieldSchema.QueryClauses {
onStmt.AddClause(c)
}
if join.On != nil {
onStmt.AddClause(join.On)
}
if cs, ok := onStmt.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
where.Build(&onStmt)
if onSQL := onStmt.SQL.String(); onSQL != "" {
vars := onStmt.Vars
for idx, v := range vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
}
}
}
return clause.Join{
Type: joinType,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
}
}
parentTableName := clause.CurrentTable
for idx, rel := range relations {
// joins table alias like "Manager, Company, Manager__Company"
curAliasName := rel.Name
if parentTableName != clause.CurrentTable {
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
}
if _, ok := specifiedRelationsName[curAliasName]; !ok {
aliasName := curAliasName
if idx == len(relations)-1 && join.Alias != "" {
aliasName = join.Alias
}
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
specifiedRelationsName[curAliasName] = aliasName
}
parentTableName = curAliasName
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
}) })
} }
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
}
} else {
if ref.PrimaryValue == "" {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
}
}
}
{
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
for _, c := range relation.FieldSchema.QueryClauses {
onStmt.AddClause(c)
}
if join.On != nil {
onStmt.AddClause(join.On)
}
if cs, ok := onStmt.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
where.Build(&onStmt)
if onSQL := onStmt.SQL.String(); onSQL != "" {
vars := onStmt.Vars
for idx, v := range vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
}
}
}
fromClause.Joins = append(fromClause.Joins, clause.Join{
Type: clause.LeftJoin,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
})
} else { } else {
fromClause.Joins = append(fromClause.Joins, clause.Join{ fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
@ -189,7 +264,6 @@ func BuildQuerySQL(db *gorm.DB) {
} }
db.Statement.AddClause(fromClause) db.Statement.AddClause(fromClause)
db.Statement.Joins = nil
} else { } else {
db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.AddClauseIfNotExists(clause.From{})
} }
@ -207,60 +281,27 @@ func Preload(db *gorm.DB) {
return return
} }
preloadMap := map[string]map[string][]interface{}{} joins := make([]string, 0, len(db.Statement.Joins))
for name := range db.Statement.Preloads { for _, join := range db.Statement.Joins {
preloadFields := strings.Split(name, ".") joins = append(joins, join.Name)
if preloadFields[0] == clause.Associations {
for _, rel := range db.Statement.Schema.Relationships.Relations {
if rel.Schema == db.Statement.Schema {
if _, ok := preloadMap[rel.Name]; !ok {
preloadMap[rel.Name] = map[string][]interface{}{}
}
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[rel.Name][value] = db.Statement.Preloads[name]
}
}
}
} else {
if _, ok := preloadMap[preloadFields[0]]; !ok {
preloadMap[preloadFields[0]] = map[string][]interface{}{}
}
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
}
}
} }
preloadNames := make([]string, 0, len(preloadMap)) tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
for key := range preloadMap { if tx.Error != nil {
preloadNames = append(preloadNames, key)
}
sort.Strings(preloadNames)
preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
db.Statement.Settings.Range(func(k, v interface{}) bool {
preloadDB.Statement.Settings.Store(k, v)
return true
})
if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil {
return return
} }
preloadDB.Statement.ReflectValue = db.Statement.ReflectValue
for _, name := range preloadNames { db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
} else {
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
}
}
} }
} }
func AfterQuery(db *gorm.DB) { func AfterQuery(db *gorm.DB) {
// clear the joins after query because preload need it
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
fromClause := db.Statement.Clauses["FROM"]
fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins
db.Statement.Clauses["FROM"] = fromClause
}
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
callMethod(db, func(value interface{}, tx *gorm.DB) bool { callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterFindInterface); ok { if i, ok := value.(AfterFindInterface); ok {

View File

@ -13,5 +13,10 @@ func RawExec(db *gorm.DB) {
} }
db.RowsAffected, _ = result.RowsAffected() db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
} }
} }

View File

@ -7,7 +7,7 @@ import (
func RowQuery(db *gorm.DB) { func RowQuery(db *gorm.DB) {
if db.Error == nil { if db.Error == nil {
BuildQuerySQL(db) BuildQuerySQL(db)
if db.DryRun { if db.DryRun || db.Error != nil {
return return
} }

View File

@ -70,10 +70,13 @@ func Update(config *Config) func(db *gorm.DB) {
if db.Statement.SQL.Len() == 0 { if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180) db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{}) db.Statement.AddClauseIfNotExists(clause.Update{})
if set := ConvertToAssignments(db.Statement); len(set) != 0 { if _, ok := db.Statement.Clauses["SET"]; !ok {
db.Statement.AddClause(set) if set := ConvertToAssignments(db.Statement); len(set) != 0 {
} else if _, ok := db.Statement.Clauses["SET"]; !ok { defer delete(db.Statement.Clauses, "SET")
return db.Statement.AddClause(set)
} else {
return
}
} }
db.Statement.Build(db.Statement.BuildClauses...) db.Statement.Build(db.Statement.BuildClauses...)
@ -89,6 +92,10 @@ func Update(config *Config) func(db *gorm.DB) {
gorm.Scan(rows, db, mode) gorm.Scan(rows, db, mode)
db.Statement.Dest = dest db.Statement.Dest = dest
db.AddError(rows.Close()) db.AddError(rows.Close())
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
} }
} else { } else {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
@ -96,6 +103,11 @@ func Update(config *Config) func(db *gorm.DB) {
if db.AddError(err) == nil { if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected() db.RowsAffected, _ = result.RowsAffected()
} }
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
} }
} }
} }
@ -135,7 +147,9 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
assignValue = func(field *schema.Field, value interface{}) { assignValue = func(field *schema.Field, value interface{}) {
for i := 0; i < stmt.ReflectValue.Len(); i++ { for i := 0; i < stmt.ReflectValue.Len(); i++ {
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) if stmt.ReflectValue.CanAddr() {
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
}
} }
} }
case reflect.Struct: case reflect.Struct:
@ -158,21 +172,21 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
switch stmt.ReflectValue.Kind() { switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if size := stmt.ReflectValue.Len(); size > 0 { if size := stmt.ReflectValue.Len(); size > 0 {
var primaryKeyExprs []clause.Expression var isZero bool
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields)) for _, field := range stmt.Schema.PrimaryFields {
var notZero bool _, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
for idx, field := range stmt.Schema.PrimaryFields { if !isZero {
value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i)) break
exprs[idx] = clause.Eq{Column: field.DBName, Value: value} }
notZero = notZero || !isZero
}
if notZero {
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
} }
} }
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) if !isZero {
_, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
} }
case reflect.Struct: case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields { for _, field := range stmt.Schema.PrimaryFields {
@ -229,7 +243,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field.AutoUpdateTime == schema.UnixNanosecond { if field.AutoUpdateTime == schema.UnixNanosecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
} else if field.AutoUpdateTime == schema.UnixMillisecond { } else if field.AutoUpdateTime == schema.UnixMillisecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
} else if field.AutoUpdateTime == schema.UnixSecond { } else if field.AutoUpdateTime == schema.UnixSecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
} else { } else {
@ -241,11 +255,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
} }
default: default:
updatingSchema := stmt.Schema updatingSchema := stmt.Schema
var isDiffSchema bool
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
// different schema // different schema
updatingStmt := &gorm.Statement{DB: stmt.DB} updatingStmt := &gorm.Statement{DB: stmt.DB}
if err := updatingStmt.Parse(stmt.Dest); err == nil { if err := updatingStmt.Parse(stmt.Dest); err == nil {
updatingSchema = updatingStmt.Schema updatingSchema = updatingStmt.Schema
isDiffSchema = true
} }
} }
@ -261,7 +277,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field.AutoUpdateTime == schema.UnixNanosecond { if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano() value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond { } else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixNano() / 1e6 value = stmt.DB.NowFunc().UnixMilli()
} else if field.AutoUpdateTime == schema.UnixSecond { } else if field.AutoUpdateTime == schema.UnixSecond {
value = stmt.DB.NowFunc().Unix() value = stmt.DB.NowFunc().Unix()
} else { } else {
@ -272,7 +288,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if (ok || !isZero) && field.Updatable { if (ok || !isZero) && field.Updatable {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
assignValue(field, value) assignField := field
if isDiffSchema {
if originField := stmt.Schema.LookUpField(dbName); originField != nil {
assignField = originField
}
}
assignValue(assignField, value)
} }
} }
} else { } else {

View File

@ -1,36 +0,0 @@
package callbacks
import (
"reflect"
"testing"
)
func TestLoadOrStoreVisitMap(t *testing.T) {
var vm visitMap
var loaded bool
type testM struct {
Name string
}
t1 := testM{Name: "t1"}
t2 := testM{Name: "t2"}
t3 := testM{Name: "t3"}
vm = make(visitMap)
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
t.Fatalf("loaded should be true")
}
// t1 already exist but t2 not
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
t.Fatalf("loaded should be true")
}
}

View File

@ -10,10 +10,11 @@ import (
) )
// Model specify the model you would like to run db operations // Model specify the model you would like to run db operations
// // update all users's name to `hello` //
// db.Model(&User{}).Update("name", "hello") // // update all users's name to `hello`
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` // db.Model(&User{}).Update("name", "hello")
// db.Model(&user).Update("name", "hello") // // if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello`
// db.Model(&user).Update("name", "hello")
func (db *DB) Model(value interface{}) (tx *DB) { func (db *DB) Model(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Model = value tx.Statement.Model = value
@ -21,6 +22,19 @@ func (db *DB) Model(value interface{}) (tx *DB) {
} }
// Clauses Add clauses // Clauses Add clauses
//
// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more
// advanced techniques like specifying lock strength and optimizer hints. See the
// [docs] for more depth.
//
// // add a simple limit clause
// db.Clauses(clause.Limit{Limit: 1}).Find(&User{})
// // tell the optimizer to use the `idx_user_name` index
// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{})
// // specify the lock strength to UPDATE
// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users)
//
// [docs]: https://gorm.io/docs/sql_builder.html#Clauses
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
var whereConds []interface{} var whereConds []interface{}
@ -41,15 +55,22 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
return return
} }
var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`) var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`)
// Table specify the table you would like to run db operations // Table specify the table you would like to run db operations
//
// // Get a user
// db.Table("users").Take(&result)
func (db *DB) Table(name string, args ...interface{}) (tx *DB) { func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args} tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 { if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 {
tx.Statement.Table = results[1] if results[1] != "" {
tx.Statement.Table = results[1]
} else {
tx.Statement.Table = results[2]
}
} }
} else if tables := strings.Split(name, "."); len(tables) == 2 { } else if tables := strings.Split(name, "."); len(tables) == 2 {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
@ -65,6 +86,11 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
} }
// Distinct specify distinct fields that you want querying // Distinct specify distinct fields that you want querying
//
// // Select distinct names of users
// db.Distinct("name").Find(&results)
// // Select distinct name/age pairs from users
// db.Distinct("name", "age").Find(&results)
func (db *DB) Distinct(args ...interface{}) (tx *DB) { func (db *DB) Distinct(args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Distinct = true tx.Statement.Distinct = true
@ -75,6 +101,14 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) {
} }
// Select specify fields that you want when querying, creating, updating // Select specify fields that you want when querying, creating, updating
//
// Use Select when you only want a subset of the fields. By default, GORM will select all fields.
// Select accepts both string arguments and arrays.
//
// // Select name and age of user using multiple arguments
// db.Select("name", "age").Find(&users)
// // Select name and age of user using an array
// db.Select([]string{"name", "age"}).Find(&users)
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
@ -151,7 +185,25 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
return return
} }
// MapColumns modify the column names in the query results to facilitate align to the corresponding structural fields
func (db *DB) MapColumns(m map[string]string) (tx *DB) {
tx = db.getInstance()
tx.Statement.ColumnMapping = m
return
}
// Where add conditions // Where add conditions
//
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
//
// // Find the first user with name jinzhu
// db.Where("name = ?", "jinzhu").First(&user)
// // Find the first user with name jinzhu and age 20
// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
// // Find the first user with name jinzhu and age not equal to 20
// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user)
//
// [docs]: https://gorm.io/docs/query.html#Conditions
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
@ -161,6 +213,11 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
} }
// Not add NOT conditions // Not add NOT conditions
//
// Not works similarly to where, and has the same syntax.
//
// // Find the first user with name not equal to jinzhu
// db.Not("name = ?", "jinzhu").First(&user)
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
@ -170,6 +227,11 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
} }
// Or add OR conditions // Or add OR conditions
//
// Or is used to chain together queries with an OR.
//
// // Find the first user with name equal to jinzhu or john
// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user)
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
@ -179,26 +241,45 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
} }
// Joins specify Joins conditions // Joins specify Joins conditions
// db.Joins("Account").Find(&user) //
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) // db.Joins("Account").Find(&user)
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
return joins(db, clause.LeftJoin, query, args...)
}
// InnerJoins specify inner joins conditions
// db.InnerJoins("Account").Find(&user)
func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) {
return joins(db, clause.InnerJoin, query, args...)
}
func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(args) == 1 { if len(args) == 1 {
if db, ok := args[0].(*DB); ok { if db, ok := args[0].(*DB); ok {
if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { j := join{
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where}) Name: query, Conds: args, Selects: db.Statement.Selects,
return Omits: db.Statement.Omits, JoinType: joinType,
} }
if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
j.On = &where
}
tx.Statement.Joins = append(tx.Statement.Joins, j)
return
} }
} }
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType})
return return
} }
// Group specify the group method on the find // Group specify the group method on the find
//
// // Select the sum age of users with given names
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results)
func (db *DB) Group(name string) (tx *DB) { func (db *DB) Group(name string) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
@ -210,6 +291,9 @@ func (db *DB) Group(name string) (tx *DB) {
} }
// Having specify HAVING conditions for GROUP BY // Having specify HAVING conditions for GROUP BY
//
// // Select the sum age of users with name jinzhu
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result)
func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.GroupBy{ tx.Statement.AddClause(clause.GroupBy{
@ -218,13 +302,20 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
return return
} }
// Order specify order when retrieve records from database // Order specify order when retrieving records from database
// db.Order("name DESC") //
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) // db.Order("name DESC")
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
// db.Order(clause.OrderBy{Columns: []clause.OrderByColumn{
// {Column: clause.Column{Name: "name"}, Desc: true},
// {Column: clause.Column{Name: "age"}, Desc: true},
// }})
func (db *DB) Order(value interface{}) (tx *DB) { func (db *DB) Order(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
switch v := value.(type) { switch v := value.(type) {
case clause.OrderBy:
tx.Statement.AddClause(v)
case clause.OrderByColumn: case clause.OrderByColumn:
tx.Statement.AddClause(clause.OrderBy{ tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{v}, Columns: []clause.OrderByColumn{v},
@ -242,13 +333,27 @@ func (db *DB) Order(value interface{}) (tx *DB) {
} }
// Limit specify the number of records to be retrieved // Limit specify the number of records to be retrieved
//
// Limit conditions can be cancelled by using `Limit(-1)`.
//
// // retrieve 3 users
// db.Limit(3).Find(&users)
// // retrieve 3 users into users1, and all users into users2
// db.Limit(3).Find(&users1).Limit(-1).Find(&users2)
func (db *DB) Limit(limit int) (tx *DB) { func (db *DB) Limit(limit int) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Limit: limit}) tx.Statement.AddClause(clause.Limit{Limit: &limit})
return return
} }
// Offset specify the number of records to skip before starting to return the records // Offset specify the number of records to skip before starting to return the records
//
// Offset conditions can be cancelled by using `Offset(-1)`.
//
// // select the third user
// db.Offset(2).First(&user)
// // select the first user by cancelling an earlier chained offset
// db.Offset(5).Offset(-1).First(&user)
func (db *DB) Offset(offset int) (tx *DB) { func (db *DB) Offset(offset int) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Offset: offset}) tx.Statement.AddClause(clause.Limit{Offset: offset})
@ -256,25 +361,37 @@ func (db *DB) Offset(offset int) (tx *DB) {
} }
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically // Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000)
// }
// //
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { // func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return func (db *gorm.DB) *gorm.DB { // return db.Where("amount > ?", 1000)
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) // }
// }
// }
// //
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) // func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
// return func (db *gorm.DB) *gorm.DB {
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
// }
// }
//
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.scopes = append(tx.Statement.scopes, funcs...) tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
return tx return tx
} }
func (db *DB) executeScopes() (tx *DB) {
scopes := db.Statement.scopes
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
}
return db
}
// Preload preload associations with given conditions // Preload preload associations with given conditions
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) //
// // get all users, and preload all non-cancelled orders
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if tx.Statement.Preloads == nil { if tx.Statement.Preloads == nil {
@ -284,18 +401,57 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
return return
} }
// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Attrs only adds attributes if the record is not found.
//
// // assign an email if the record is not found
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign an email if the record is not found, otherwise ignore provided email
// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.attrs = attrs tx.Statement.attrs = attrs
return return
} }
// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that
// records will be updated even if they are found.
//
// // assign an email regardless of if the record is not found
// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign email regardless of if record is found
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Assign(attrs ...interface{}) (tx *DB) { func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.assigns = attrs tx.Statement.assigns = attrs
return return
} }
// Unscoped disables the global scope of soft deletion in a query.
// By default, GORM uses soft deletion, marking records as "deleted"
// by setting a timestamp on a specific field (e.g., `deleted_at`).
// Unscoped allows queries to include records marked as deleted,
// overriding the soft deletion behavior.
// Example:
//
// var users []User
// db.Unscoped().Find(&users)
// // Retrieves all users, including deleted ones.
func (db *DB) Unscoped() (tx *DB) { func (db *DB) Unscoped() (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Unscoped = true tx.Statement.Unscoped = true

View File

@ -29,6 +29,7 @@ func BenchmarkSelect(b *testing.B) {
func BenchmarkComplexSelect(b *testing.B) { func BenchmarkComplexSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
limit10 := 10
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clauses := []clause.Interface{ clauses := []clause.Interface{
@ -43,7 +44,7 @@ func BenchmarkComplexSelect(b *testing.B) {
clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}), clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}),
}}, }},
clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}}, clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}},
clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Limit: &limit10, Offset: 20},
clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}}, clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}},
} }

View File

@ -20,6 +20,7 @@ type Builder interface {
Writer Writer
WriteQuoted(field interface{}) WriteQuoted(field interface{})
AddVar(Writer, ...interface{}) AddVar(Writer, ...interface{})
AddError(error) error
} }
// Clause // Clause

View File

@ -126,8 +126,8 @@ func (expr NamedExpr) Build(builder Builder) {
for _, v := range []byte(expr.SQL) { for _, v := range []byte(expr.SQL) {
if v == '@' && !inName { if v == '@' && !inName {
inName = true inName = true
name = []byte{} name = name[:0]
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' || v == ';' { } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
if inName { if inName {
if nv, ok := namedMap[string(name)]; ok { if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv) builder.AddVar(builder, nv)
@ -246,15 +246,19 @@ func (eq Eq) Build(builder Builder) {
switch eq.Value.(type) { switch eq.Value.(type) {
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
builder.WriteString(" IN (")
rv := reflect.ValueOf(eq.Value) rv := reflect.ValueOf(eq.Value)
for i := 0; i < rv.Len(); i++ { if rv.Len() == 0 {
if i > 0 { builder.WriteString(" IN (NULL)")
builder.WriteByte(',') } else {
builder.WriteString(" IN (")
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
} }
builder.AddVar(builder, rv.Index(i).Interface()) builder.WriteByte(')')
} }
builder.WriteByte(')')
default: default:
if eqNil(eq.Value) { if eqNil(eq.Value) {
builder.WriteString(" IS NULL") builder.WriteString(" IS NULL")

View File

@ -94,6 +94,16 @@ func TestNamedExpr(t *testing.T) {
Vars: []interface{}{sql.Named("name", "jinzhu")}, Vars: []interface{}{sql.Named("name", "jinzhu")},
Result: "name1 = ? AND name2 = ?;", Result: "name1 = ? AND name2 = ?;",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "name1 = @name1\r\n AND name2 = @name2",
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}},
Result: "name1 = ?\r\n AND name2 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "name1 = @name1\r AND name2 = @name2",
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}},
Result: "name1 = ?\r AND name2 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, { }, {
SQL: "?", SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: "col"}}, Vars: []interface{}{clause.Column{Table: "table", Name: "col"}},
@ -189,6 +199,11 @@ func TestExpression(t *testing.T) {
}, },
ExpectedVars: []interface{}{"a", "b"}, ExpectedVars: []interface{}{"a", "b"},
Result: "`column-name` NOT IN (?,?)", Result: "`column-name` NOT IN (?,?)",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: []string{}},
},
Result: "`column-name` IN (NULL)",
}, { }, {
Expressions: []clause.Expression{ Expressions: []clause.Expression{
clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100}, clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100},

View File

@ -1,5 +1,7 @@
package clause package clause
import "gorm.io/gorm/utils"
type JoinType string type JoinType string
const ( const (
@ -9,7 +11,31 @@ const (
RightJoin JoinType = "RIGHT" RightJoin JoinType = "RIGHT"
) )
// Join join clause for from type JoinTarget struct {
Type JoinType
Association string
Subquery Expression
Table string
}
func Has(name string) JoinTarget {
return JoinTarget{Type: InnerJoin, Association: name}
}
func (jt JoinType) Association(name string) JoinTarget {
return JoinTarget{Type: jt, Association: name}
}
func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
return JoinTarget{Type: jt, Association: name, Subquery: subquery}
}
func (jt JoinTarget) As(name string) JoinTarget {
jt.Table = name
return jt
}
// Join clause for from
type Join struct { type Join struct {
Type JoinType Type JoinType
Table Table Table Table
@ -18,6 +44,12 @@ type Join struct {
Expression Expression Expression Expression
} }
func JoinTable(names ...string) Table {
return Table{
Name: utils.JoinNestedRelationNames(names),
}
}
func (join Join) Build(builder Builder) { func (join Join) Build(builder Builder) {
if join.Expression != nil { if join.Expression != nil {
join.Expression.Build(builder) join.Expression.Build(builder)

101
clause/joins_test.go Normal file
View File

@ -0,0 +1,101 @@
package clause_test
import (
"sync"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
func TestJoin(t *testing.T) {
results := []struct {
name string
join clause.Join
sql string
}{
{
name: "LEFT JOIN",
join: clause.Join{
Type: clause.LeftJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "RIGHT JOIN",
join: clause.Join{
Type: clause.RightJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "INNER JOIN",
join: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "CROSS JOIN",
join: clause.Join{
Type: clause.CrossJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "USING",
join: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
Using: []string{"id"},
},
sql: "INNER JOIN `user` USING (`id`)",
},
{
name: "Expression",
join: clause.Join{
// Invalid
Type: clause.LeftJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
// Valid
Expression: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
Using: []string{"id"},
},
},
sql: "INNER JOIN `user` USING (`id`)",
},
}
for _, result := range results {
t.Run(result.name, func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
result.join.Build(stmt)
if result.sql != stmt.SQL.String() {
t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String())
}
})
}
}

View File

@ -1,10 +1,8 @@
package clause package clause
import "strconv"
// Limit limit clause // Limit limit clause
type Limit struct { type Limit struct {
Limit int Limit *int
Offset int Offset int
} }
@ -15,16 +13,16 @@ func (limit Limit) Name() string {
// Build build where clause // Build build where clause
func (limit Limit) Build(builder Builder) { func (limit Limit) Build(builder Builder) {
if limit.Limit > 0 { if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteString("LIMIT ") builder.WriteString("LIMIT ")
builder.WriteString(strconv.Itoa(limit.Limit)) builder.AddVar(builder, *limit.Limit)
} }
if limit.Offset > 0 { if limit.Offset > 0 {
if limit.Limit > 0 { if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteByte(' ') builder.WriteByte(' ')
} }
builder.WriteString("OFFSET ") builder.WriteString("OFFSET ")
builder.WriteString(strconv.Itoa(limit.Offset)) builder.AddVar(builder, limit.Offset)
} }
} }
@ -33,7 +31,7 @@ func (limit Limit) MergeClause(clause *Clause) {
clause.Name = "" clause.Name = ""
if v, ok := clause.Expression.(Limit); ok { if v, ok := clause.Expression.(Limit); ok {
if limit.Limit == 0 && v.Limit != 0 { if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil {
limit.Limit = v.Limit limit.Limit = v.Limit
} }

View File

@ -8,6 +8,10 @@ import (
) )
func TestLimit(t *testing.T) { func TestLimit(t *testing.T) {
limit0 := 0
limit10 := 10
limit50 := 50
limitNeg10 := -10
results := []struct { results := []struct {
Clauses []clause.Interface Clauses []clause.Interface
Result string Result string
@ -15,38 +19,56 @@ func TestLimit(t *testing.T) {
}{ }{
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{ []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{
Limit: 10, Limit: &limit10,
Offset: 20, Offset: 20,
}}, }},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, "SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit10, 20},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
"SELECT * FROM `users` LIMIT ?",
[]interface{}{limit0},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}},
"SELECT * FROM `users` LIMIT ?",
[]interface{}{limit0},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
"SELECT * FROM `users` OFFSET 20", nil, "SELECT * FROM `users` OFFSET ?",
[]interface{}{20},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` OFFSET 30", nil, "SELECT * FROM `users` OFFSET ?",
[]interface{}{30},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, "SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit10, 20},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` LIMIT 10 OFFSET 30", nil, "SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit10, 30},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
"SELECT * FROM `users` LIMIT 10", nil, "SELECT * FROM `users` LIMIT ?",
[]interface{}{limit10},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}},
"SELECT * FROM `users` OFFSET 30", nil, "SELECT * FROM `users` OFFSET ?",
[]interface{}{30},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}},
"SELECT * FROM `users` LIMIT 50 OFFSET 30", nil, "SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit50, 30},
}, },
} }

View File

@ -1,5 +1,12 @@
package clause package clause
const (
LockingStrengthUpdate = "UPDATE"
LockingStrengthShare = "SHARE"
LockingOptionsSkipLocked = "SKIP LOCKED"
LockingOptionsNoWait = "NOWAIT"
)
type Locking struct { type Locking struct {
Strength string Strength string
Table Table Table Table

View File

@ -14,17 +14,21 @@ func TestLocking(t *testing.T) {
Vars []interface{} Vars []interface{}
}{ }{
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate}},
"SELECT * FROM `users` FOR UPDATE", nil, "SELECT * FROM `users` FOR UPDATE", nil,
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthShare, Table: clause.Table{Name: clause.CurrentTable}}},
"SELECT * FROM `users` FOR SHARE OF `users`", nil, "SELECT * FROM `users` FOR SHARE OF `users`", nil,
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsNoWait}},
"SELECT * FROM `users` FOR UPDATE NOWAIT", nil, "SELECT * FROM `users` FOR UPDATE NOWAIT", nil,
}, },
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsSkipLocked}},
"SELECT * FROM `users` FOR UPDATE SKIP LOCKED", nil,
},
} }
for idx, result := range results { for idx, result := range results {

View File

@ -16,27 +16,27 @@ func (OnConflict) Name() string {
// Build build onConflict clause // Build build onConflict clause
func (onConflict OnConflict) Build(builder Builder) { func (onConflict OnConflict) Build(builder Builder) {
if len(onConflict.Columns) > 0 {
builder.WriteByte('(')
for idx, column := range onConflict.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
builder.WriteString(`) `)
}
if len(onConflict.TargetWhere.Exprs) > 0 {
builder.WriteString(" WHERE ")
onConflict.TargetWhere.Build(builder)
builder.WriteByte(' ')
}
if onConflict.OnConstraint != "" { if onConflict.OnConstraint != "" {
builder.WriteString("ON CONSTRAINT ") builder.WriteString("ON CONSTRAINT ")
builder.WriteString(onConflict.OnConstraint) builder.WriteString(onConflict.OnConstraint)
builder.WriteByte(' ') builder.WriteByte(' ')
} else {
if len(onConflict.Columns) > 0 {
builder.WriteByte('(')
for idx, column := range onConflict.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
builder.WriteString(`) `)
}
if len(onConflict.TargetWhere.Exprs) > 0 {
builder.WriteString(" WHERE ")
onConflict.TargetWhere.Build(builder)
builder.WriteByte(' ')
}
} }
if onConflict.DoNothing { if onConflict.DoNothing {

View File

@ -26,9 +26,12 @@ func (returning Returning) Build(builder Builder) {
// MergeClause merge order by clauses // MergeClause merge order by clauses
func (returning Returning) MergeClause(clause *Clause) { func (returning Returning) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(Returning); ok { if v, ok := clause.Expression.(Returning); ok && len(returning.Columns) > 0 {
returning.Columns = append(v.Columns, returning.Columns...) if v.Columns != nil {
returning.Columns = append(v.Columns, returning.Columns...)
} else {
returning.Columns = nil
}
} }
clause.Expression = returning clause.Expression = returning
} }

View File

@ -26,6 +26,22 @@ func TestReturning(t *testing.T) {
}}, }},
"SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil, "SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil,
}, },
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
[]clause.Column{clause.PrimaryColumn},
}, clause.Returning{}, clause.Returning{
[]clause.Column{{Name: "name"}, {Name: "age"}},
}},
"SELECT * FROM `users` RETURNING *", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
[]clause.Column{clause.PrimaryColumn},
}, clause.Returning{
[]clause.Column{{Name: "name"}, {Name: "age"}},
}, clause.Returning{}},
"SELECT * FROM `users` RETURNING *", nil,
},
} }
for idx, result := range results { for idx, result := range results {

View File

@ -49,16 +49,18 @@ func TestSelect(t *testing.T) {
Exprs: []clause.Expression{ Exprs: []clause.Expression{
clause.Expr{ clause.Expr{
SQL: "? as name", SQL: "? as name",
Vars: []interface{}{clause.Eq{ Vars: []interface{}{
Column: clause.Column{Name: "age"}, clause.Eq{
Value: 18, Column: clause.Column{Name: "age"},
}, Value: 18,
},
}, },
}, },
}, },
}, },
}, clause.From{}}, }, clause.From{}},
"SELECT `age` = ? as name FROM `users`", []interface{}{18}, "SELECT `age` = ? as name FROM `users`",
[]interface{}{18},
}, },
} }

View File

@ -21,6 +21,12 @@ func (where Where) Name() string {
// Build build where clause // Build build where clause
func (where Where) Build(builder Builder) { func (where Where) Build(builder Builder) {
if len(where.Exprs) == 1 {
if andCondition, ok := where.Exprs[0].(AndConditions); ok {
where.Exprs = andCondition.Exprs
}
}
// Switch position if the first query expression is a single Or condition // Switch position if the first query expression is a single Or condition
for idx, expr := range where.Exprs { for idx, expr := range where.Exprs {
if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 { if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
@ -147,6 +153,11 @@ func Not(exprs ...Expression) Expression {
if len(exprs) == 0 { if len(exprs) == 0 {
return nil return nil
} }
if len(exprs) == 1 {
if andCondition, ok := exprs[0].(AndConditions); ok {
exprs = andCondition.Exprs
}
}
return NotConditions{Exprs: exprs} return NotConditions{Exprs: exprs}
} }
@ -155,19 +166,63 @@ type NotConditions struct {
} }
func (not NotConditions) Build(builder Builder) { func (not NotConditions) Build(builder Builder) {
if len(not.Exprs) > 1 { anyNegationBuilder := false
builder.WriteByte('(') for _, c := range not.Exprs {
if _, ok := c.(NegationExpressionBuilder); ok {
anyNegationBuilder = true
break
}
} }
for idx, c := range not.Exprs { if anyNegationBuilder {
if idx > 0 { if len(not.Exprs) > 1 {
builder.WriteString(AndWithSpace) builder.WriteByte('(')
} }
if negationBuilder, ok := c.(NegationExpressionBuilder); ok { for idx, c := range not.Exprs {
negationBuilder.NegationBuild(builder) if idx > 0 {
} else { builder.WriteString(AndWithSpace)
builder.WriteString("NOT ") }
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder)
} else {
builder.WriteString("NOT ")
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToUpper(e.SQL)
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
builder.WriteByte('(')
}
}
c.Build(builder)
if wrapInParentheses {
builder.WriteByte(')')
}
}
}
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
} else {
builder.WriteString("NOT ")
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}
for idx, c := range not.Exprs {
if idx > 0 {
switch c.(type) {
case OrConditions:
builder.WriteString(OrWithSpace)
default:
builder.WriteString(AndWithSpace)
}
}
e, wrapInParentheses := c.(Expr) e, wrapInParentheses := c.(Expr)
if wrapInParentheses { if wrapInParentheses {
sql := strings.ToUpper(e.SQL) sql := strings.ToUpper(e.SQL)
@ -182,9 +237,9 @@ func (not NotConditions) Build(builder Builder) {
builder.WriteByte(')') builder.WriteByte(')')
} }
} }
}
if len(not.Exprs) > 1 { if len(not.Exprs) > 1 {
builder.WriteByte(')') builder.WriteByte(')')
}
} }
} }

View File

@ -63,7 +63,7 @@ func TestWhere(t *testing.T) {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{ []clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))}, Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))},
}}, }},
"SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", "SELECT * FROM `users` WHERE `age` = ? OR `name` <> ?",
[]interface{}{18, "jinzhu"}, []interface{}{18, "jinzhu"},
}, },
{ {
@ -94,7 +94,7 @@ func TestWhere(t *testing.T) {
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})), clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
}, },
}}, }},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)", "SELECT * FROM `users` WHERE `users`.`id` <> ? AND `score` <= ?",
[]interface{}{"1", 100}, []interface{}{"1", 100},
}, },
{ {
@ -105,6 +105,30 @@ func TestWhere(t *testing.T) {
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)", "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
[]interface{}{"1", 100}, []interface{}{"1", 100},
}, },
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}},
clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})},
}},
"SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)",
[]interface{}{100, 60},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{
clause.Not(clause.AndConditions{
Exprs: []clause.Expression{
clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.Gt{Column: "age", Value: 18},
}}, clause.OrConditions{
Exprs: []clause.Expression{
clause.Lt{Column: "score", Value: 100},
},
}),
}}},
"SELECT * FROM `users` WHERE NOT ((`users`.`id` = ? AND `age` > ?) OR `score` < ?)",
[]interface{}{"1", 18, 100},
},
} }
for idx, result := range results { for idx, result := range results {

View File

@ -21,6 +21,10 @@ var (
ErrPrimaryKeyRequired = errors.New("primary key required") ErrPrimaryKeyRequired = errors.New("primary key required")
// ErrModelValueRequired model value required // ErrModelValueRequired model value required
ErrModelValueRequired = errors.New("model value required") ErrModelValueRequired = errors.New("model value required")
// ErrModelAccessibleFieldsRequired model accessible fields required
ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required")
// ErrSubQueryRequired sub query required
ErrSubQueryRequired = errors.New("sub query required")
// ErrInvalidData unsupported data // ErrInvalidData unsupported data
ErrInvalidData = errors.New("unsupported data") ErrInvalidData = errors.New("unsupported data")
// ErrUnsupportedDriver unsupported driver // ErrUnsupportedDriver unsupported driver
@ -41,4 +45,10 @@ var (
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
// ErrPreloadNotAllowed preload is not allowed when count is used // ErrPreloadNotAllowed preload is not allowed when count is used
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used") ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
// ErrDuplicatedKey occurs when there is a unique key constraint violation
ErrDuplicatedKey = errors.New("duplicated key not allowed")
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
// ErrCheckConstraintViolated occurs when there is a check constraint violation
ErrCheckConstraintViolated = errors.New("violates check constraint")
) )

View File

@ -1,9 +1,11 @@
package gorm package gorm
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"hash/maphash"
"reflect" "reflect"
"strings" "strings"
@ -13,7 +15,7 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
// Create insert the value into database // Create inserts value, returning the inserted data's primary key in value's id
func (db *DB) Create(value interface{}) (tx *DB) { func (db *DB) Create(value interface{}) (tx *DB) {
if db.CreateBatchSize > 0 { if db.CreateBatchSize > 0 {
return db.CreateInBatches(value, db.CreateBatchSize) return db.CreateInBatches(value, db.CreateBatchSize)
@ -24,7 +26,7 @@ func (db *DB) Create(value interface{}) (tx *DB) {
return tx.callbacks.Create().Execute(tx) return tx.callbacks.Create().Execute(tx)
} }
// CreateInBatches insert the value in batches into database // CreateInBatches inserts value in batches of batchSize
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
reflectValue := reflect.Indirect(reflect.ValueOf(value)) reflectValue := reflect.Indirect(reflect.ValueOf(value))
@ -33,9 +35,10 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
var rowsAffected int64 var rowsAffected int64
tx = db.getInstance() tx = db.getInstance()
// the reflection length judgment of the optimized value
reflectLen := reflectValue.Len()
callFc := func(tx *DB) error { callFc := func(tx *DB) error {
// the reflection length judgment of the optimized value
reflectLen := reflectValue.Len()
for i := 0; i < reflectLen; i += batchSize { for i := 0; i < reflectLen; i += batchSize {
ends := i + batchSize ends := i + batchSize
if ends > reflectLen { if ends > reflectLen {
@ -53,7 +56,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
return nil return nil
} }
if tx.SkipDefaultTransaction { if tx.SkipDefaultTransaction || reflectLen <= batchSize {
tx.AddError(callFc(tx.Session(&Session{}))) tx.AddError(callFc(tx.Session(&Session{})))
} else { } else {
tx.AddError(tx.Transaction(callFc)) tx.AddError(tx.Transaction(callFc))
@ -68,12 +71,16 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
return return
} }
// Save update value in database, if the value doesn't have primary key, will insert it // Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
func (db *DB) Save(value interface{}) (tx *DB) { func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = value tx.Statement.Dest = value
reflectValue := reflect.Indirect(reflect.ValueOf(value)) reflectValue := reflect.Indirect(reflect.ValueOf(value))
for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface {
reflectValue = reflect.Indirect(reflectValue)
}
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
@ -97,20 +104,19 @@ func (db *DB) Save(value interface{}) (tx *DB) {
tx.Statement.Selects = append(tx.Statement.Selects, "*") tx.Statement.Selects = append(tx.Statement.Selects, "*")
} }
tx = tx.callbacks.Update().Execute(tx) updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true}))
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate {
result := reflect.New(tx.Statement.Schema.ModelType).Interface() return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
if err := tx.Session(&Session{}).Take(result).Error; errors.Is(err, ErrRecordNotFound) {
return tx.Create(value)
}
} }
return updateTx
} }
return return
} }
// First find first record that match given conditions, order by primary key // First finds the first record ordered by primary key, matching given conditions conds
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{ tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@ -125,7 +131,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
} }
// Take return a record that match given conditions, the order will depend on the database implementation // Take finds the first record returned by the database in no specified order, matching given conditions conds
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1) tx = db.Limit(1)
if len(conds) > 0 { if len(conds) > 0 {
@ -138,7 +144,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
} }
// Last find last record that match given conditions, order by primary key // Last finds the last record ordered by primary key, matching given conditions conds
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{ tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@ -154,7 +160,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
} }
// Find find records that match given conditions // Find finds all records matching given conditions conds
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(conds) > 0 { if len(conds) > 0 {
@ -166,7 +172,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
} }
// FindInBatches find records in batches // FindInBatches finds all records in batches of batchSize
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
var ( var (
tx = db.Order(clause.OrderByColumn{ tx = db.Order(clause.OrderByColumn{
@ -177,13 +183,32 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
batch int batch int
) )
// user specified offset or limit
var totalSize int
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
if limit, ok := c.Expression.(clause.Limit); ok {
if limit.Limit != nil {
totalSize = *limit.Limit
}
if totalSize > 0 && batchSize > totalSize {
batchSize = totalSize
}
// reset to offset to 0 in next batch
tx = tx.Offset(-1).Session(&Session{})
}
}
for { for {
result := queryDB.Limit(batchSize).Find(dest) result := queryDB.Limit(batchSize).Find(dest)
rowsAffected += result.RowsAffected rowsAffected += result.RowsAffected
batch++ batch++
if result.Error == nil && result.RowsAffected != 0 { if result.Error == nil && result.RowsAffected != 0 {
tx.AddError(fc(result, batch)) fcTx := result.Session(&Session{NewDB: true})
fcTx.RowsAffected = result.RowsAffected
tx.AddError(fc(fcTx, batch))
} else if result.Error != nil { } else if result.Error != nil {
tx.AddError(result.Error) tx.AddError(result.Error)
} }
@ -192,6 +217,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
break break
} }
if totalSize > 0 {
if totalSize <= int(rowsAffected) {
break
}
if totalSize/batchSize == batch {
batchSize = totalSize % batchSize
}
}
// Optimize for-break // Optimize for-break
resultsValue := reflect.Indirect(reflect.ValueOf(dest)) resultsValue := reflect.Indirect(reflect.ValueOf(dest))
if result.Statement.Schema.PrioritizedPrimaryField == nil { if result.Statement.Schema.PrioritizedPrimaryField == nil {
@ -199,7 +233,11 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
break break
} }
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
if zero {
tx.AddError(ErrPrimaryKeyRequired)
break
}
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
} }
@ -256,13 +294,24 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
} }
} }
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) // FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
// Each conds must be a struct or map.
//
// FirstOrInit never modifies the database. It is often used with Assign and Attrs.
//
// // assign an email if the record is not found
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign email regardless of if record is found
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{ queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}) })
if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
if c, ok := tx.Statement.Clauses["WHERE"]; ok { if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok { if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs) tx.assignInterfacesToValue(where.Exprs)
@ -282,59 +331,82 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
return return
} }
// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) // FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
// Each conds must be a struct or map.
//
// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists.
//
// // assign an email if the record is not found
// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
// // result.RowsAffected -> 1
//
// // assign email regardless of if record is found
// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
// // result.RowsAffected -> 1
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{ tx = db.getInstance()
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}) })
if tx = queryTx.Find(dest, conds...); tx.Error == nil {
if tx.RowsAffected == 0 {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs)
}
}
// initialize with attrs, conds result := queryTx.Find(dest, conds...)
if len(tx.Statement.attrs) > 0 { if result.Error != nil {
tx.assignInterfacesToValue(tx.Statement.attrs...) tx.Error = result.Error
} return tx
// initialize with attrs, conds
if len(tx.Statement.assigns) > 0 {
tx.assignInterfacesToValue(tx.Statement.assigns...)
}
return tx.Create(dest)
} else if len(db.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
assigns := map[string]interface{}{}
for _, expr := range exprs {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
assigns[column] = eq.Value
case clause.Column:
assigns[column.Name] = eq.Value
default:
}
}
}
return tx.Model(dest).Updates(assigns)
}
} }
if result.RowsAffected == 0 {
if c, ok := result.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
result.assignInterfacesToValue(where.Exprs)
}
}
// initialize with attrs, conds
if len(db.Statement.attrs) > 0 {
result.assignInterfacesToValue(db.Statement.attrs...)
}
// initialize with attrs, conds
if len(db.Statement.assigns) > 0 {
result.assignInterfacesToValue(db.Statement.assigns...)
}
return tx.Create(dest)
} else if len(db.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
assigns := map[string]interface{}{}
for i := 0; i < len(exprs); i++ {
expr := exprs[i]
if eq, ok := expr.(clause.AndConditions); ok {
exprs = append(exprs, eq.Exprs...)
} else if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
assigns[column] = eq.Value
case clause.Column:
assigns[column.Name] = eq.Value
}
}
}
return tx.Model(dest).Updates(assigns)
}
return tx return tx
} }
// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields // Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Update(column string, value interface{}) (tx *DB) { func (db *DB) Update(column string, value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value} tx.Statement.Dest = map[string]interface{}{column: value}
return tx.callbacks.Update().Execute(tx) return tx.callbacks.Update().Execute(tx)
} }
// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields // Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Updates(values interface{}) (tx *DB) { func (db *DB) Updates(values interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = values tx.Statement.Dest = values
@ -355,7 +427,9 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
return tx.callbacks.Update().Execute(tx) return tx.callbacks.Update().Execute(tx)
} }
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition // Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
// time if null.
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(conds) > 0 { if len(conds) > 0 {
@ -449,7 +523,7 @@ func (db *DB) Rows() (*sql.Rows, error) {
return rows, tx.Error return rows, tx.Error
} }
// Scan scan value to a struct // Scan scans selected value to the struct dest
func (db *DB) Scan(dest interface{}) (tx *DB) { func (db *DB) Scan(dest interface{}) (tx *DB) {
config := *db.Config config := *db.Config
currentLogger, newLogger := config.Logger, logger.Recorder.New() currentLogger, newLogger := config.Logger, logger.Recorder.New()
@ -463,6 +537,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
tx.ScanRows(rows, dest) tx.ScanRows(rows, dest)
} else { } else {
tx.RowsAffected = 0 tx.RowsAffected = 0
tx.AddError(rows.Err())
} }
tx.AddError(rows.Close()) tx.AddError(rows.Close())
} }
@ -474,9 +549,10 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
return return
} }
// Pluck used to query single column from a model as a map // Pluck queries a single column from a model, returning in the slice dest. E.g.:
// var ages []int64 //
// db.Model(&users).Pluck("age", &ages) // var ages []int64
// db.Model(&users).Pluck("age", &ages)
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if tx.Statement.Model != nil { if tx.Statement.Model != nil {
@ -517,7 +593,8 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
return tx.Error return tx.Error
} }
// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed. // Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is
// returned to the connection pool.
func (db *DB) Connection(fc func(tx *DB) error) (err error) { func (db *DB) Connection(fc func(tx *DB) error) (err error) {
if db.Error != nil { if db.Error != nil {
return db.Error return db.Error
@ -539,27 +616,28 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) {
return fc(tx) return fc(tx)
} }
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. // Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
// they are rolled back.
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true panicked := true
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction // nested transaction
if !db.DisableNestedTransaction { if !db.DisableNestedTransaction {
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error spID := new(maphash.Hash).Sum64()
err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error
if err != nil { if err != nil {
return return
} }
defer func() { defer func() {
// Make sure to rollback when panic, Block error or Commit error // Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil { if panicked || err != nil {
db.RollbackTo(fmt.Sprintf("sp%p", fc)) db.RollbackTo(fmt.Sprintf("sp%d", spID))
} }
}() }()
} }
err = fc(db.Session(&Session{NewDB: db.clone == 1}))
err = fc(db.Session(&Session{}))
} else { } else {
tx := db.Begin(opts...) tx := db.Begin(opts...)
if tx.Error != nil { if tx.Error != nil {
@ -583,7 +661,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
return return
} }
// Begin begins a transaction // Begin begins a transaction with any transaction options opts
func (db *DB) Begin(opts ...*sql.TxOptions) *DB { func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
var ( var (
// clone statement // clone statement
@ -596,11 +674,18 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
opt = opts[0] opt = opts[0]
} }
ctx := tx.Statement.Context
if _, ok := ctx.Deadline(); !ok {
if db.Config.DefaultTransactionTimeout > 0 {
ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout)
}
}
switch beginner := tx.Statement.ConnPool.(type) { switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner: case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
case ConnPoolBeginner: case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
default: default:
err = ErrInvalidTransaction err = ErrInvalidTransaction
} }
@ -612,7 +697,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
return tx return tx
} }
// Commit commit a transaction // Commit commits the changes in a transaction
func (db *DB) Commit() *DB { func (db *DB) Commit() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Commit()) db.AddError(committer.Commit())
@ -622,7 +707,7 @@ func (db *DB) Commit() *DB {
return db return db
} }
// Rollback rollback a transaction // Rollback rollbacks the changes in a transaction
func (db *DB) Rollback() *DB { func (db *DB) Rollback() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
if !reflect.ValueOf(committer).IsNil() { if !reflect.ValueOf(committer).IsNil() {
@ -636,7 +721,21 @@ func (db *DB) Rollback() *DB {
func (db *DB) SavePoint(name string) *DB { func (db *DB) SavePoint(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
// close prepared statement, because SavePoint not support prepared statement.
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
var (
preparedStmtTx *PreparedStmtTX
isPreparedStmtTx bool
)
// close prepared statement, because SavePoint not support prepared statement.
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx.Tx
}
db.AddError(savePointer.SavePoint(db, name)) db.AddError(savePointer.SavePoint(db, name))
// restore prepared statement
if isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx
}
} else { } else {
db.AddError(ErrUnsupportedDriver) db.AddError(ErrUnsupportedDriver)
} }
@ -645,14 +744,28 @@ func (db *DB) SavePoint(name string) *DB {
func (db *DB) RollbackTo(name string) *DB { func (db *DB) RollbackTo(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
// close prepared statement, because RollbackTo not support prepared statement.
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
var (
preparedStmtTx *PreparedStmtTX
isPreparedStmtTx bool
)
// close prepared statement, because SavePoint not support prepared statement.
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx.Tx
}
db.AddError(savePointer.RollbackTo(db, name)) db.AddError(savePointer.RollbackTo(db, name))
// restore prepared statement
if isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx
}
} else { } else {
db.AddError(ErrUnsupportedDriver) db.AddError(ErrUnsupportedDriver)
} }
return db return db
} }
// Exec execute raw sql // Exec executes raw sql
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.SQL = strings.Builder{} tx.Statement.SQL = strings.Builder{}

605
generics.go Normal file
View File

@ -0,0 +1,605 @@
package gorm
import (
"context"
"database/sql"
"fmt"
"sort"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
)
type result struct {
Result sql.Result
RowsAffected int64
}
func (info *result) ModifyStatement(stmt *Statement) {
stmt.Result = info
}
// Build implements clause.Expression interface
func (result) Build(clause.Builder) {
}
func WithResult() *result {
return &result{}
}
type Interface[T any] interface {
Raw(sql string, values ...interface{}) ExecInterface[T]
Exec(ctx context.Context, sql string, values ...interface{}) error
CreateInterface[T]
}
type CreateInterface[T any] interface {
ChainInterface[T]
Table(name string, args ...interface{}) CreateInterface[T]
Create(ctx context.Context, r *T) error
CreateInBatches(ctx context.Context, r *[]T, batchSize int) error
}
type ChainInterface[T any] interface {
ExecInterface[T]
Scopes(scopes ...func(db *Statement)) ChainInterface[T]
Where(query interface{}, args ...interface{}) ChainInterface[T]
Not(query interface{}, args ...interface{}) ChainInterface[T]
Or(query interface{}, args ...interface{}) ChainInterface[T]
Limit(offset int) ChainInterface[T]
Offset(offset int) ChainInterface[T]
Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T]
Select(query string, args ...interface{}) ChainInterface[T]
Omit(columns ...string) ChainInterface[T]
MapColumns(m map[string]string) ChainInterface[T]
Distinct(args ...interface{}) ChainInterface[T]
Group(name string) ChainInterface[T]
Having(query interface{}, args ...interface{}) ChainInterface[T]
Order(value interface{}) ChainInterface[T]
Build(builder clause.Builder)
Delete(ctx context.Context) (rowsAffected int, err error)
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
Updates(ctx context.Context, t T) (rowsAffected int, err error)
Count(ctx context.Context, column string) (result int64, err error)
}
type ExecInterface[T any] interface {
Scan(ctx context.Context, r interface{}) error
First(context.Context) (T, error)
Last(ctx context.Context) (T, error)
Take(context.Context) (T, error)
Find(ctx context.Context) ([]T, error)
FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error
Row(ctx context.Context) *sql.Row
Rows(ctx context.Context) (*sql.Rows, error)
}
type JoinBuilder interface {
Select(...string) JoinBuilder
Omit(...string) JoinBuilder
Where(query interface{}, args ...interface{}) JoinBuilder
Not(query interface{}, args ...interface{}) JoinBuilder
Or(query interface{}, args ...interface{}) JoinBuilder
}
type PreloadBuilder interface {
Select(...string) PreloadBuilder
Omit(...string) PreloadBuilder
Where(query interface{}, args ...interface{}) PreloadBuilder
Not(query interface{}, args ...interface{}) PreloadBuilder
Or(query interface{}, args ...interface{}) PreloadBuilder
Limit(offset int) PreloadBuilder
Offset(offset int) PreloadBuilder
Order(value interface{}) PreloadBuilder
LimitPerRecord(num int) PreloadBuilder
}
type op func(*DB) *DB
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
v := &g[T]{
db: db,
ops: make([]op, 0, 5),
}
if len(opts) > 0 {
v.ops = append(v.ops, func(db *DB) *DB {
return db.Clauses(opts...)
})
}
v.createG = &createG[T]{
chainG: chainG[T]{
execG: execG[T]{g: v},
},
}
return v
}
type g[T any] struct {
*createG[T]
db *DB
ops []op
}
func (g *g[T]) apply(ctx context.Context) *DB {
db := g.db
if !db.DryRun {
db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
}
for _, op := range g.ops {
db = op(db)
}
return db
}
func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] {
return execG[T]{g: &g[T]{
db: c.db,
ops: append(c.ops, func(db *DB) *DB {
return db.Raw(sql, values...)
}),
}}
}
func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error {
return c.apply(ctx).Exec(sql, values...).Error
}
type createG[T any] struct {
chainG[T]
}
func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] {
return createG[T]{c.with(func(db *DB) *DB {
return db.Table(name, args...)
})}
}
func (c createG[T]) Create(ctx context.Context, r *T) error {
return c.g.apply(ctx).Create(r).Error
}
func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error {
return c.g.apply(ctx).CreateInBatches(r, batchSize).Error
}
type chainG[T any] struct {
execG[T]
}
func (c chainG[T]) getInstance() *DB {
var r T
return c.g.apply(context.Background()).Model(r).getInstance()
}
func (c chainG[T]) with(v op) chainG[T] {
return chainG[T]{
execG: execG[T]{g: &g[T]{
db: c.g.db,
ops: append(append([]op(nil), c.g.ops...), v),
}},
}
}
func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] {
return c.with(func(db *DB) *DB {
for _, fc := range scopes {
fc(db.Statement)
}
return db
})
}
func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Table(name, args...)
})
}
func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Where(query, args...)
})
}
func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Not(query, args...)
})
}
func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Or(query, args...)
})
}
func (c chainG[T]) Limit(offset int) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Limit(offset)
})
}
func (c chainG[T]) Offset(offset int) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Offset(offset)
})
}
type joinBuilder struct {
db *DB
}
func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Select(columns ...string) JoinBuilder {
q.db.Select(columns)
return q
}
func (q *joinBuilder) Omit(columns ...string) JoinBuilder {
q.db.Omit(columns...)
return q
}
type preloadBuilder struct {
limitPerRecord int
db *DB
}
func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Select(columns ...string) PreloadBuilder {
q.db.Select(columns)
return q
}
func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder {
q.db.Omit(columns...)
return q
}
func (q *preloadBuilder) Limit(limit int) PreloadBuilder {
q.db.Limit(limit)
return q
}
func (q *preloadBuilder) Offset(offset int) PreloadBuilder {
q.db.Offset(offset)
return q
}
func (q *preloadBuilder) Order(value interface{}) PreloadBuilder {
q.db.Order(value)
return q
}
func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder {
q.limitPerRecord = num
return q
}
func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
return c.with(func(db *DB) *DB {
if jt.Table == "" {
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
}
q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
if on != nil {
if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
db.AddError(err)
}
}
j := join{
Name: jt.Association,
Alias: jt.Table,
Selects: q.db.Statement.Selects,
Omits: q.db.Statement.Omits,
JoinType: jt.Type,
}
if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
j.On = &where
}
if jt.Subquery != nil {
joinType := j.JoinType
if joinType == "" {
joinType = clause.LeftJoin
}
if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok {
stmt := db.getInstance().Statement
if len(j.Selects) == 0 {
j.Selects = stmt.Selects
}
if len(j.Omits) == 0 {
j.Omits = stmt.Omits
}
}
expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}
if j.On != nil {
expr.SQL += " ON ?"
expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
}
j.Expression = expr
}
db.Statement.Joins = append(db.Statement.Joins, j)
sort.Slice(db.Statement.Joins, func(i, j int) bool {
return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name
})
return db
})
}
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Select(query, args...)
})
}
func (c chainG[T]) Omit(columns ...string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Omit(columns...)
})
}
func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.MapColumns(m)
})
}
func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Distinct(args...)
})
}
func (c chainG[T]) Group(name string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Group(name)
})
}
func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Having(query, args...)
})
}
func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Order(value)
})
}
func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Preload(association, func(tx *DB) *DB {
q := preloadBuilder{db: tx.getInstance()}
if query != nil {
if err := query(&q); err != nil {
db.AddError(err)
}
}
relation, ok := db.Statement.Schema.Relationships.Relations[association]
if !ok {
if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 {
relationships := db.Statement.Schema.Relationships
for _, field := range preloadFields {
var ok bool
relation, ok = relationships.Relations[field]
if ok {
relationships = relation.FieldSchema.Relationships
} else {
db.AddError(fmt.Errorf("relation %s not found", association))
return nil
}
}
} else {
db.AddError(fmt.Errorf("relation %s not found", association))
return nil
}
}
if q.limitPerRecord > 0 {
if relation.JoinTable != nil {
tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association))
return tx
}
refColumns := []clause.Column{}
for _, rel := range relation.References {
if rel.OwnPrimaryKey {
refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName})
}
}
if len(refColumns) != 0 {
selectExpr := clause.CommaExpression{}
for _, column := range q.db.Statement.Selects {
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}})
}
if len(selectExpr.Exprs) == 0 {
selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}}
}
partitionBy := clause.CommaExpression{}
for _, column := range refColumns {
partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}})
}
rnnColumn := clause.Column{Name: "gorm_preload_rnn"}
sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)"
vars := []interface{}{partitionBy}
if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok {
vars = append(vars, orderBy)
} else {
vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
}})
}
vars = append(vars, rnnColumn)
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars})
q.db.Clauses(clause.Select{Expression: selectExpr})
return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord)
}
}
return q.db
})
})
}
func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) {
r := new(T)
res := c.g.apply(ctx).Delete(r)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) {
var r T
res := c.g.apply(ctx).Model(r).Update(name, value)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) {
res := c.g.apply(ctx).Updates(t)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) {
var r T
err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error
return
}
func (c chainG[T]) Build(builder clause.Builder) {
subdb := c.getInstance()
subdb.Logger = logger.Discard
subdb.DryRun = true
if stmt, ok := builder.(*Statement); ok {
if subdb.Statement.SQL.Len() > 0 {
var (
vars = subdb.Statement.Vars
sql = subdb.Statement.SQL.String()
)
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
subdb.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
subdb.Statement.SQL.Reset()
subdb.Statement.Vars = stmt.Vars
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
} else {
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
}
} else {
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
subdb.callbacks.Query().Execute(subdb)
}
builder.WriteString(subdb.Statement.SQL.String())
stmt.Vars = subdb.Statement.Vars
}
}
type execG[T any] struct {
g *g[T]
}
func (g execG[T]) First(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).First(&r).Error
return r, err
}
func (g execG[T]) Scan(ctx context.Context, result interface{}) error {
var r T
err := g.g.apply(ctx).Model(r).Find(result).Error
return err
}
func (g execG[T]) Last(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).Last(&r).Error
return r, err
}
func (g execG[T]) Take(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).Take(&r).Error
return r, err
}
func (g execG[T]) Find(ctx context.Context) ([]T, error) {
var r []T
err := g.g.apply(ctx).Find(&r).Error
return r, err
}
func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error {
var data []T
return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error {
return fc(data, batch)
}).Error
}
func (g execG[T]) Row(ctx context.Context) *sql.Row {
return g.g.apply(ctx).Row()
}
func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) {
return g.g.apply(ctx).Rows()
}

5
go.mod
View File

@ -1,8 +1,9 @@
module gorm.io/gorm module gorm.io/gorm
go 1.14 go 1.18
require ( require (
github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.1.4 github.com/jinzhu/now v1.1.5
golang.org/x/text v0.20.0
) )

6
go.sum
View File

@ -1,4 +1,6 @@
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=

143
gorm.go
View File

@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"reflect"
"sort" "sort"
"sync" "sync"
"time" "time"
@ -20,7 +21,9 @@ const preparedStmtDBKey = "preparedStmt"
type Config struct { type Config struct {
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can disable it by setting `SkipDefaultTransaction` to true // You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool SkipDefaultTransaction bool
DefaultTransactionTimeout time.Duration
// NamingStrategy tables, columns naming strategy // NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer NamingStrategy schema.Namer
// FullSaveAssociations full save associations // FullSaveAssociations full save associations
@ -33,10 +36,17 @@ type Config struct {
DryRun bool DryRun bool
// PrepareStmt executes the given query in cached statement // PrepareStmt executes the given query in cached statement
PrepareStmt bool PrepareStmt bool
// PrepareStmt cache support LRU expired,
// default maxsize=int64 Max value and ttl=1h
PrepareStmtMaxSize int
PrepareStmtTTL time.Duration
// DisableAutomaticPing // DisableAutomaticPing
DisableAutomaticPing bool DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating // DisableForeignKeyConstraintWhenMigrating
DisableForeignKeyConstraintWhenMigrating bool DisableForeignKeyConstraintWhenMigrating bool
// IgnoreRelationshipsWhenMigrating
IgnoreRelationshipsWhenMigrating bool
// DisableNestedTransaction disable nested transaction // DisableNestedTransaction disable nested transaction
DisableNestedTransaction bool DisableNestedTransaction bool
// AllowGlobalUpdate allow global update // AllowGlobalUpdate allow global update
@ -45,6 +55,10 @@ type Config struct {
QueryFields bool QueryFields bool
// CreateBatchSize default create batch size // CreateBatchSize default create batch size
CreateBatchSize int CreateBatchSize int
// TranslateError enabling error translation
TranslateError bool
// PropagateUnscoped propagate Unscoped to every other nested statement
PropagateUnscoped bool
// ClauseBuilders clause builder // ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder ClauseBuilders map[string]clause.ClauseBuilder
@ -105,6 +119,7 @@ type Session struct {
DisableNestedTransaction bool DisableNestedTransaction bool
AllowGlobalUpdate bool AllowGlobalUpdate bool
FullSaveAssociations bool FullSaveAssociations bool
PropagateUnscoped bool
QueryFields bool QueryFields bool
Context context.Context Context context.Context
Logger logger.Interface Logger logger.Interface
@ -122,12 +137,24 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
return isConfig && !isConfig2 return isConfig && !isConfig2
}) })
if len(opts) > 0 {
if c, ok := opts[0].(*Config); ok {
config = c
} else {
opts = append([]Option{config}, opts...)
}
}
var skipAfterInitialize bool
for _, opt := range opts { for _, opt := range opts {
if opt != nil { if opt != nil {
if applyErr := opt.Apply(config); applyErr != nil { if applyErr := opt.Apply(config); applyErr != nil {
return nil, applyErr return nil, applyErr
} }
defer func(opt Option) { defer func(opt Option) {
if skipAfterInitialize {
return
}
if errr := opt.AfterInitialize(db); errr != nil { if errr := opt.AfterInitialize(db); errr != nil {
err = errr err = errr
} }
@ -142,7 +169,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
} }
if config.NamingStrategy == nil { if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{} config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
} }
if config.Logger == nil { if config.Logger == nil {
@ -175,17 +202,26 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
if config.Dialector != nil { if config.Dialector != nil {
err = config.Dialector.Initialize(db) err = config.Dialector.Initialize(db)
} if err != nil {
if db, _ := db.DB(); db != nil {
_ = db.Close()
}
preparedStmt := &PreparedStmtDB{ // DB is not initialized, so we skip AfterInitialize
ConnPool: db.ConnPool, skipAfterInitialize = true
Stmts: map[string]Stmt{}, return
Mux: &sync.RWMutex{}, }
PreparedSQL: make([]string, 0, 100),
if config.TranslateError {
if _, ok := db.Dialector.(ErrorTranslator); !ok {
config.Logger.Warn(context.Background(), "The TranslateError option is enabled, but the Dialector %s does not implement ErrorTranslator.", db.Dialector.Name())
}
}
} }
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
if config.PrepareStmt { if config.PrepareStmt {
preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
db.ConnPool = preparedStmt db.ConnPool = preparedStmt
} }
@ -236,6 +272,10 @@ func (db *DB) Session(config *Session) *DB {
txConfig.FullSaveAssociations = true txConfig.FullSaveAssociations = true
} }
if config.PropagateUnscoped {
txConfig.PropagateUnscoped = true
}
if config.Context != nil || config.PrepareStmt || config.SkipHooks { if config.Context != nil || config.PrepareStmt || config.SkipHooks {
tx.Statement = tx.Statement.clone() tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx tx.Statement.DB = tx
@ -246,16 +286,30 @@ func (db *DB) Session(config *Session) *DB {
} }
if config.PrepareStmt { if config.PrepareStmt {
var preparedStmt *PreparedStmtDB
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
preparedStmt := v.(*PreparedStmtDB) preparedStmt = v.(*PreparedStmtDB)
} else {
preparedStmt = NewPreparedStmtDB(db.ConnPool, db.PrepareStmtMaxSize, db.PrepareStmtTTL)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
}
switch t := tx.Statement.ConnPool.(type) {
case Tx:
tx.Statement.ConnPool = &PreparedStmtTX{
Tx: t,
PreparedStmtDB: preparedStmt,
}
default:
tx.Statement.ConnPool = &PreparedStmtDB{ tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool, ConnPool: db.Config.ConnPool,
Mux: preparedStmt.Mux, Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts, Stmts: preparedStmt.Stmts,
} }
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
} }
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
} }
if config.SkipHooks { if config.SkipHooks {
@ -300,7 +354,8 @@ func (db *DB) WithContext(ctx context.Context) *DB {
// Debug start debug mode // Debug start debug mode
func (db *DB) Debug() (tx *DB) { func (db *DB) Debug() (tx *DB) {
return db.Session(&Session{ tx = db.getInstance()
return tx.Session(&Session{
Logger: db.Logger.LogMode(logger.Info), Logger: db.Logger.LogMode(logger.Info),
}) })
} }
@ -336,10 +391,18 @@ func (db *DB) Callback() *callbacks {
// AddError add error to db // AddError add error to db
func (db *DB) AddError(err error) error { func (db *DB) AddError(err error) error {
if db.Error == nil { if err != nil {
db.Error = err if db.Config.TranslateError {
} else if err != nil { if errTranslator, ok := db.Dialector.(ErrorTranslator); ok {
db.Error = fmt.Errorf("%v; %w", db.Error, err) err = errTranslator.Translate(err)
}
}
if db.Error == nil {
db.Error = err
} else {
db.Error = fmt.Errorf("%v; %w", db.Error, err)
}
} }
return db.Error return db.Error
} }
@ -347,12 +410,20 @@ func (db *DB) AddError(err error) error {
// DB returns `*sql.DB` // DB returns `*sql.DB`
func (db *DB) DB() (*sql.DB, error) { func (db *DB) DB() (*sql.DB, error) {
connPool := db.ConnPool connPool := db.ConnPool
if db.Statement != nil && db.Statement.ConnPool != nil {
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { connPool = db.Statement.ConnPool
return dbConnector.GetDBConn() }
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
} }
if sqldb, ok := connPool.(*sql.DB); ok { if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
return sqldb, err
}
}
if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
return sqldb, nil return sqldb, nil
} }
@ -366,11 +437,15 @@ func (db *DB) getInstance() *DB {
if db.clone == 1 { if db.clone == 1 {
// clone with new statement // clone with new statement
tx.Statement = &Statement{ tx.Statement = &Statement{
DB: tx, DB: tx,
ConnPool: db.Statement.ConnPool, ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context, Context: db.Statement.Context,
Clauses: map[string]clause.Clause{}, Clauses: map[string]clause.Clause{},
Vars: make([]interface{}, 0, 8), Vars: make([]interface{}, 0, 8),
SkipHooks: db.Statement.SkipHooks,
}
if db.Config.PropagateUnscoped {
tx.Statement.Unscoped = db.Statement.Unscoped
} }
} else { } else {
// with clone statement // with clone statement
@ -412,7 +487,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
relation, ok := modelSchema.Relationships.Relations[field] relation, ok := modelSchema.Relationships.Relations[field]
isRelation := ok && relation.JoinTable != nil isRelation := ok && relation.JoinTable != nil
if !isRelation { if !isRelation {
return fmt.Errorf("failed to found relation: %s", field) return fmt.Errorf("failed to find relation: %s", field)
} }
for _, ref := range relation.References { for _, ref := range relation.References {
@ -455,14 +530,14 @@ func (db *DB) Use(plugin Plugin) error {
// ToSQL for generate SQL string. // ToSQL for generate SQL string.
// //
// db.ToSQL(func(tx *gorm.DB) *gorm.DB { // db.ToSQL(func(tx *gorm.DB) *gorm.DB {
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}) // return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
// .Limit(10).Offset(5) // .Limit(10).Offset(5)
// .Order("name ASC") // .Order("name ASC")
// .First(&User{}) // .First(&User{})
// }) // })
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance())
stmt := tx.Statement stmt := tx.Statement
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)

View File

@ -26,6 +26,10 @@ type Plugin interface {
Initialize(*DB) error Initialize(*DB) error
} }
type ParamsFilter interface {
ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{})
}
// ConnPool db conns pool interface // ConnPool db conns pool interface
type ConnPool interface { type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
@ -82,3 +86,7 @@ type Rows interface {
Err() error Err() error
Close() error Close() error
} }
type ErrorTranslator interface {
Translate(err error) error
}

493
internal/lru/lru.go Normal file
View File

@ -0,0 +1,493 @@
package lru
// golang -lru
// https://github.com/hashicorp/golang-lru
import (
"sync"
"time"
)
// EvictCallback is used to get a callback when a cache entry is evicted
type EvictCallback[K comparable, V any] func(key K, value V)
// LRU implements a thread-safe LRU with expirable entries.
type LRU[K comparable, V any] struct {
size int
evictList *LruList[K, V]
items map[K]*Entry[K, V]
onEvict EvictCallback[K, V]
// expirable options
mu sync.Mutex
ttl time.Duration
done chan struct{}
// buckets for expiration
buckets []bucket[K, V]
// uint8 because it's number between 0 and numBuckets
nextCleanupBucket uint8
}
// bucket is a container for holding entries to be expired
type bucket[K comparable, V any] struct {
entries map[K]*Entry[K, V]
newestEntry time.Time
}
// noEvictionTTL - very long ttl to prevent eviction
const noEvictionTTL = time.Hour * 24 * 365 * 10
// because of uint8 usage for nextCleanupBucket, should not exceed 256.
// casting it as uint8 explicitly requires type conversions in multiple places
const numBuckets = 100
// NewLRU returns a new thread-safe cache with expirable entries.
//
// Size parameter set to 0 makes cache of unlimited size, e.g. turns LRU mechanism off.
//
// Providing 0 TTL turns expiring off.
//
// Delete expired entries every 1/100th of ttl value. Goroutine which deletes expired entries runs indefinitely.
func NewLRU[K comparable, V any](size int, onEvict EvictCallback[K, V], ttl time.Duration) *LRU[K, V] {
if size < 0 {
size = 0
}
if ttl <= 0 {
ttl = noEvictionTTL
}
res := LRU[K, V]{
ttl: ttl,
size: size,
evictList: NewList[K, V](),
items: make(map[K]*Entry[K, V]),
onEvict: onEvict,
done: make(chan struct{}),
}
// initialize the buckets
res.buckets = make([]bucket[K, V], numBuckets)
for i := 0; i < numBuckets; i++ {
res.buckets[i] = bucket[K, V]{entries: make(map[K]*Entry[K, V])}
}
// enable deleteExpired() running in separate goroutine for cache with non-zero TTL
//
// Important: done channel is never closed, so deleteExpired() goroutine will never exit,
// it's decided to add functionality to close it in the version later than v2.
if res.ttl != noEvictionTTL {
go func(done <-chan struct{}) {
ticker := time.NewTicker(res.ttl / numBuckets)
defer ticker.Stop()
for {
select {
case <-done:
return
case <-ticker.C:
res.deleteExpired()
}
}
}(res.done)
}
return &res
}
// Purge clears the cache completely.
// onEvict is called for each evicted key.
func (c *LRU[K, V]) Purge() {
c.mu.Lock()
defer c.mu.Unlock()
for k, v := range c.items {
if c.onEvict != nil {
c.onEvict(k, v.Value)
}
delete(c.items, k)
}
for _, b := range c.buckets {
for _, ent := range b.entries {
delete(b.entries, ent.Key)
}
}
c.evictList.Init()
}
// Add adds a value to the cache. Returns true if an eviction occurred.
// Returns false if there was no eviction: the item was already in the cache,
// or the size was not exceeded.
func (c *LRU[K, V]) Add(key K, value V) (evicted bool) {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
// Check for existing item
if ent, ok := c.items[key]; ok {
c.evictList.MoveToFront(ent)
c.removeFromBucket(ent) // remove the entry from its current bucket as expiresAt is renewed
ent.Value = value
ent.ExpiresAt = now.Add(c.ttl)
c.addToBucket(ent)
return false
}
// Add new item
ent := c.evictList.PushFrontExpirable(key, value, now.Add(c.ttl))
c.items[key] = ent
c.addToBucket(ent) // adds the entry to the appropriate bucket and sets entry.expireBucket
evict := c.size > 0 && c.evictList.Length() > c.size
// Verify size not exceeded
if evict {
c.removeOldest()
}
return evict
}
// Get looks up a key's value from the cache.
func (c *LRU[K, V]) Get(key K) (value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
var ent *Entry[K, V]
if ent, ok = c.items[key]; ok {
// Expired item check
if time.Now().After(ent.ExpiresAt) {
return value, false
}
c.evictList.MoveToFront(ent)
return ent.Value, true
}
return
}
// Contains checks if a key is in the cache, without updating the recent-ness
// or deleting it for being stale.
func (c *LRU[K, V]) Contains(key K) (ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
_, ok = c.items[key]
return ok
}
// Peek returns the key value (or undefined if not found) without updating
// the "recently used"-ness of the key.
func (c *LRU[K, V]) Peek(key K) (value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
var ent *Entry[K, V]
if ent, ok = c.items[key]; ok {
// Expired item check
if time.Now().After(ent.ExpiresAt) {
return value, false
}
return ent.Value, true
}
return
}
// Remove removes the provided key from the cache, returning if the
// key was contained.
func (c *LRU[K, V]) Remove(key K) bool {
c.mu.Lock()
defer c.mu.Unlock()
if ent, ok := c.items[key]; ok {
c.removeElement(ent)
return true
}
return false
}
// RemoveOldest removes the oldest item from the cache.
func (c *LRU[K, V]) RemoveOldest() (key K, value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.evictList.Back(); ent != nil {
c.removeElement(ent)
return ent.Key, ent.Value, true
}
return
}
// GetOldest returns the oldest entry
func (c *LRU[K, V]) GetOldest() (key K, value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.evictList.Back(); ent != nil {
return ent.Key, ent.Value, true
}
return
}
func (c *LRU[K, V]) KeyValues() map[K]V {
c.mu.Lock()
defer c.mu.Unlock()
maps := make(map[K]V)
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
maps[ent.Key] = ent.Value
// keys = append(keys, ent.Key)
}
return maps
}
// Keys returns a slice of the keys in the cache, from oldest to newest.
// Expired entries are filtered out.
func (c *LRU[K, V]) Keys() []K {
c.mu.Lock()
defer c.mu.Unlock()
keys := make([]K, 0, len(c.items))
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
keys = append(keys, ent.Key)
}
return keys
}
// Values returns a slice of the values in the cache, from oldest to newest.
// Expired entries are filtered out.
func (c *LRU[K, V]) Values() []V {
c.mu.Lock()
defer c.mu.Unlock()
values := make([]V, 0, len(c.items))
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
values = append(values, ent.Value)
}
return values
}
// Len returns the number of items in the cache.
func (c *LRU[K, V]) Len() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.evictList.Length()
}
// Resize changes the cache size. Size of 0 means unlimited.
func (c *LRU[K, V]) Resize(size int) (evicted int) {
c.mu.Lock()
defer c.mu.Unlock()
if size <= 0 {
c.size = 0
return 0
}
diff := c.evictList.Length() - size
if diff < 0 {
diff = 0
}
for i := 0; i < diff; i++ {
c.removeOldest()
}
c.size = size
return diff
}
// Close destroys cleanup goroutine. To clean up the cache, run Purge() before Close().
// func (c *LRU[K, V]) Close() {
// c.mu.Lock()
// defer c.mu.Unlock()
// select {
// case <-c.done:
// return
// default:
// }
// close(c.done)
// }
// removeOldest removes the oldest item from the cache. Has to be called with lock!
func (c *LRU[K, V]) removeOldest() {
if ent := c.evictList.Back(); ent != nil {
c.removeElement(ent)
}
}
// removeElement is used to remove a given list element from the cache. Has to be called with lock!
func (c *LRU[K, V]) removeElement(e *Entry[K, V]) {
c.evictList.Remove(e)
delete(c.items, e.Key)
c.removeFromBucket(e)
if c.onEvict != nil {
c.onEvict(e.Key, e.Value)
}
}
// deleteExpired deletes expired records from the oldest bucket, waiting for the newest entry
// in it to expire first.
func (c *LRU[K, V]) deleteExpired() {
c.mu.Lock()
bucketIdx := c.nextCleanupBucket
timeToExpire := time.Until(c.buckets[bucketIdx].newestEntry)
// wait for newest entry to expire before cleanup without holding lock
if timeToExpire > 0 {
c.mu.Unlock()
time.Sleep(timeToExpire)
c.mu.Lock()
}
for _, ent := range c.buckets[bucketIdx].entries {
c.removeElement(ent)
}
c.nextCleanupBucket = (c.nextCleanupBucket + 1) % numBuckets
c.mu.Unlock()
}
// addToBucket adds entry to expire bucket so that it will be cleaned up when the time comes. Has to be called with lock!
func (c *LRU[K, V]) addToBucket(e *Entry[K, V]) {
bucketID := (numBuckets + c.nextCleanupBucket - 1) % numBuckets
e.ExpireBucket = bucketID
c.buckets[bucketID].entries[e.Key] = e
if c.buckets[bucketID].newestEntry.Before(e.ExpiresAt) {
c.buckets[bucketID].newestEntry = e.ExpiresAt
}
}
// removeFromBucket removes the entry from its corresponding bucket. Has to be called with lock!
func (c *LRU[K, V]) removeFromBucket(e *Entry[K, V]) {
delete(c.buckets[e.ExpireBucket].entries, e.Key)
}
// Cap returns the capacity of the cache
func (c *LRU[K, V]) Cap() int {
return c.size
}
// Entry is an LRU Entry
type Entry[K comparable, V any] struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *Entry[K, V]
// The list to which this element belongs.
list *LruList[K, V]
// The LRU Key of this element.
Key K
// The Value stored with this element.
Value V
// The time this element would be cleaned up, optional
ExpiresAt time.Time
// The expiry bucket item was put in, optional
ExpireBucket uint8
}
// PrevEntry returns the previous list element or nil.
func (e *Entry[K, V]) PrevEntry() *Entry[K, V] {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// LruList represents a doubly linked list.
// The zero Value for LruList is an empty list ready to use.
type LruList[K comparable, V any] struct {
root Entry[K, V] // sentinel list element, only &root, root.prev, and root.next are used
len int // current list Length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *LruList[K, V]) Init() *LruList[K, V] {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// NewList returns an initialized list.
func NewList[K comparable, V any]() *LruList[K, V] { return new(LruList[K, V]).Init() }
// Length returns the number of elements of list l.
// The complexity is O(1).
func (l *LruList[K, V]) Length() int { return l.len }
// Back returns the last element of list l or nil if the list is empty.
func (l *LruList[K, V]) Back() *Entry[K, V] {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List Value.
func (l *LruList[K, V]) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *LruList[K, V]) insert(e, at *Entry[K, V]) *Entry[K, V] {
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Entry{Value: v, ExpiresAt: ExpiresAt}, at).
func (l *LruList[K, V]) insertValue(k K, v V, expiresAt time.Time, at *Entry[K, V]) *Entry[K, V] {
return l.insert(&Entry[K, V]{Value: v, Key: k, ExpiresAt: expiresAt}, at)
}
// Remove removes e from its list, decrements l.len
func (l *LruList[K, V]) Remove(e *Entry[K, V]) V {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
return e.Value
}
// move moves e to next to at.
func (l *LruList[K, V]) move(e, at *Entry[K, V]) {
if e == at {
return
}
e.prev.next = e.next
e.next.prev = e.prev
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *LruList[K, V]) PushFront(k K, v V) *Entry[K, V] {
l.lazyInit()
return l.insertValue(k, v, time.Time{}, &l.root)
}
// PushFrontExpirable inserts a new expirable element e with Value v at the front of list l and returns e.
func (l *LruList[K, V]) PushFrontExpirable(k K, v V, expiresAt time.Time) *Entry[K, V] {
l.lazyInit()
return l.insertValue(k, v, expiresAt, &l.root)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *LruList[K, V]) MoveToFront(e *Entry[K, V]) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.move(e, &l.root)
}

View File

@ -0,0 +1,183 @@
package stmt_store
import (
"context"
"database/sql"
"math"
"sync"
"time"
"gorm.io/gorm/internal/lru"
)
type Stmt struct {
*sql.Stmt
Transaction bool
prepared chan struct{}
prepareErr error
}
func (stmt *Stmt) Error() error {
return stmt.prepareErr
}
func (stmt *Stmt) Close() error {
<-stmt.prepared
if stmt.Stmt != nil {
return stmt.Stmt.Close()
}
return nil
}
// Store defines an interface for managing the caching operations of SQL statements (Stmt).
// This interface provides methods for creating new statements, retrieving all cache keys,
// getting cached statements, setting cached statements, and deleting cached statements.
type Store interface {
// New creates a new Stmt object and caches it.
// Parameters:
// ctx: The context for the request, which can carry deadlines, cancellation signals, etc.
// key: The key representing the SQL query, used for caching and preparing the statement.
// isTransaction: Indicates whether this operation is part of a transaction, which may affect the caching strategy.
// connPool: A connection pool that provides database connections.
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
// Returns:
// *Stmt: A newly created statement object for executing SQL operations.
// error: An error if the statement preparation fails.
New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error)
// Keys returns a slice of all cache keys in the store.
Keys() []string
// Get retrieves a Stmt object from the store based on the given key.
// Parameters:
// key: The key used to look up the Stmt object.
// Returns:
// *Stmt: The found Stmt object, or nil if not found.
// bool: Indicates whether the corresponding Stmt object was successfully found.
Get(key string) (*Stmt, bool)
// Set stores the given Stmt object in the store and associates it with the specified key.
// Parameters:
// key: The key used to associate the Stmt object.
// value: The Stmt object to be stored.
Set(key string, value *Stmt)
// Delete removes the Stmt object corresponding to the specified key from the store.
// Parameters:
// key: The key associated with the Stmt object to be deleted.
Delete(key string)
}
// defaultMaxSize defines the default maximum capacity of the cache.
// Its value is the maximum value of the int64 type, which means that when the cache size is not specified,
// the cache can theoretically store as many elements as possible.
// (1 << 63) - 1 is the maximum value that an int64 type can represent.
const (
defaultMaxSize = math.MaxInt
// defaultTTL defines the default time-to-live (TTL) for each cache entry.
// When the TTL for cache entries is not specified, each cache entry will expire after 24 hours.
defaultTTL = time.Hour * 24
)
// New creates and returns a new Store instance.
//
// Parameters:
// - size: The maximum capacity of the cache. If the provided size is less than or equal to 0,
// it defaults to defaultMaxSize.
// - ttl: The time-to-live duration for each cache entry. If the provided ttl is less than or equal to 0,
// it defaults to defaultTTL.
//
// This function defines an onEvicted callback that is invoked when a cache entry is evicted.
// The callback ensures that if the evicted value (v) is not nil, its Close method is called asynchronously
// to release associated resources.
//
// Returns:
// - A Store instance implemented by lruStore, which internally uses an LRU cache with the specified size,
// eviction callback, and TTL.
func New(size int, ttl time.Duration) Store {
if size <= 0 {
size = defaultMaxSize
}
if ttl <= 0 {
ttl = defaultTTL
}
onEvicted := func(k string, v *Stmt) {
if v != nil {
go v.Close()
}
}
return &lruStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
}
type lruStore struct {
lru *lru.LRU[string, *Stmt]
}
func (s *lruStore) Keys() []string {
return s.lru.Keys()
}
func (s *lruStore) Get(key string) (*Stmt, bool) {
stmt, ok := s.lru.Get(key)
if ok && stmt != nil {
<-stmt.prepared
}
return stmt, ok
}
func (s *lruStore) Set(key string, value *Stmt) {
s.lru.Add(key, value)
}
func (s *lruStore) Delete(key string) {
s.lru.Remove(key)
}
type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
// New creates a new Stmt object for executing SQL queries.
// It caches the Stmt object for future use and handles preparation and error states.
// Parameters:
//
// ctx: Context for the request, used to carry deadlines, cancellation signals, etc.
// key: The key representing the SQL query, used for caching and preparing the statement.
// isTransaction: Indicates whether this operation is part of a transaction, affecting cache strategy.
// conn: A connection pool that provides database connections.
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
//
// Returns:
//
// *Stmt: A newly created statement object for executing SQL operations.
// error: An error if the statement preparation fails.
func (s *lruStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) {
// Create a Stmt object and set its Transaction property.
// The prepared channel is used to synchronize the statement preparation state.
cacheStmt := &Stmt{
Transaction: isTransaction,
prepared: make(chan struct{}),
}
// Cache the Stmt object with the associated key.
s.Set(key, cacheStmt)
// Unlock after completing initialization to prevent deadlocks.
locker.Unlock()
// Ensure the prepared channel is closed after the function execution completes.
defer close(cacheStmt.prepared)
// Prepare the SQL statement using the provided connection.
cacheStmt.Stmt, err = conn.PrepareContext(ctx, key)
if err != nil {
// If statement preparation fails, record the error and remove the invalid Stmt object from the cache.
cacheStmt.prepareErr = err
s.Delete(key)
return &Stmt{}, err
}
// Return the successfully prepared Stmt object.
return cacheStmt, nil
}

View File

@ -4,7 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io"
"log" "log"
"os" "os"
"time" "time"
@ -55,6 +55,7 @@ type Config struct {
SlowThreshold time.Duration SlowThreshold time.Duration
Colorful bool Colorful bool
IgnoreRecordNotFoundError bool IgnoreRecordNotFoundError bool
ParameterizedQueries bool
LogLevel LogLevel LogLevel LogLevel
} }
@ -68,8 +69,8 @@ type Interface interface {
} }
var ( var (
// Discard Discard logger will print any log to ioutil.Discard // Discard logger will print any log to io.Discard
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
// Default Default logger // Default Default logger
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
SlowThreshold: 200 * time.Millisecond, SlowThreshold: 200 * time.Millisecond,
@ -77,8 +78,13 @@ var (
IgnoreRecordNotFoundError: false, IgnoreRecordNotFoundError: false,
Colorful: true, Colorful: true,
}) })
// Recorder Recorder logger records running SQL into a recorder instance // Recorder logger records running SQL into a recorder instance
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
// RecorderParamsFilter defaults to no-op, allows to be run-over by a different implementation
RecorderParamsFilter = func(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
return sql, params
}
) )
// New initialize logger // New initialize logger
@ -128,28 +134,30 @@ func (l *logger) LogMode(level LogLevel) Interface {
} }
// Info print info // Info print info
func (l logger) Info(ctx context.Context, msg string, data ...interface{}) { func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Info { if l.LogLevel >= Info {
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
} }
} }
// Warn print warn messages // Warn print warn messages
func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) { func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Warn { if l.LogLevel >= Warn {
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
} }
} }
// Error print error messages // Error print error messages
func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Error { if l.LogLevel >= Error {
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
} }
} }
// Trace print sql message // Trace print sql message
func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { //
//nolint:cyclop
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.LogLevel <= Silent { if l.LogLevel <= Silent {
return return
} }
@ -181,6 +189,14 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i
} }
} }
// ParamsFilter filter params
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
if l.Config.ParameterizedQueries {
return sql, nil
}
return sql, params
}
type traceRecorder struct { type traceRecorder struct {
Interface Interface
BeginAt time.Time BeginAt time.Time
@ -189,8 +205,8 @@ type traceRecorder struct {
Err error Err error
} }
// New new trace recorder // New trace recorder
func (l traceRecorder) New() *traceRecorder { func (l *traceRecorder) New() *traceRecorder {
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
} }
@ -200,3 +216,10 @@ func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (s
l.SQL, l.RowsAffected = fc() l.SQL, l.RowsAffected = fc()
l.Err = err l.Err = err
} }
func (l *traceRecorder) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
if RecorderParamsFilter == nil {
return sql, params
}
return RecorderParamsFilter(ctx, sql, params...)
}

View File

@ -28,8 +28,25 @@ func isPrintable(s string) bool {
return true return true
} }
// A list of Go types that should be converted to SQL primitives
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
// RegEx matches only numeric values
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
func isNumeric(k reflect.Kind) bool {
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
case reflect.Float32, reflect.Float64:
return true
default:
return false
}
}
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var ( var (
@ -75,26 +92,28 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
case reflect.Bool: case reflect.Bool:
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
case reflect.String: case reflect.String:
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
default: default:
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
} else { } else {
vars[idx] = nullStr vars[idx] = nullStr
} }
} }
case []byte: case []byte:
if s := string(v); isPrintable(s) { if s := string(v); isPrintable(s) {
vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
} else { } else {
vars[idx] = escaper + "<binary>" + escaper vars[idx] = escaper + "<binary>" + escaper
} }
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
vars[idx] = utils.ToString(v) vars[idx] = utils.ToString(v)
case float64, float32: case float32:
vars[idx] = fmt.Sprintf("%.6f", v) vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32)
case float64:
vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
case string: case string:
vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper
default: default:
rv := reflect.ValueOf(v) rv := reflect.ValueOf(v)
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
@ -104,6 +123,12 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
convertParams(v, idx) convertParams(v, idx)
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() { } else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
convertParams(reflect.Indirect(rv).Interface(), idx) convertParams(reflect.Indirect(rv).Interface(), idx)
} else if isNumeric(rv.Kind()) {
if rv.CanInt() || rv.CanUint() {
vars[idx] = fmt.Sprintf("%d", rv.Interface())
} else {
vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
}
} else { } else {
for _, t := range convertibleTypes { for _, t := range convertibleTypes {
if rv.Type().ConvertibleTo(t) { if rv.Type().ConvertibleTo(t) {
@ -111,7 +136,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
return return
} }
} }
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
} }
} }
} }
@ -138,9 +163,18 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
sql = newSQL.String() sql = newSQL.String()
} else { } else {
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
for idx, v := range vars {
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1) sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string {
} num := v[1 : len(v)-1]
n, _ := strconv.Atoi(num)
// position var start from 1 ($1, $2)
n -= 1
if n >= 0 && n <= len(vars)-1 {
return vars[n]
}
return v
})
} }
return sql return sql

View File

@ -31,20 +31,24 @@ func (s ExampleStruct) Value() (driver.Value, error) {
} }
func format(v []byte, escaper string) string { func format(v []byte, escaper string) string {
return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper return escaper + strings.ReplaceAll(string(v), escaper, escaper+escaper) + escaper
} }
func TestExplainSQL(t *testing.T) { func TestExplainSQL(t *testing.T) {
type role string type role string
type password []byte type password []byte
type intType int
type floatType float64
var ( var (
tt = now.MustParse("2020-02-23 11:10:10") tt = now.MustParse("2020-02-23 11:10:10")
myrole = role("admin") myrole = role("admin")
pwd = password([]byte("pass")) pwd = password("pass")
jsVal = []byte(`{"Name":"test","Val":"test"}`) jsVal = []byte(`{"Name":"test","Val":"test"}`)
js = JSON(jsVal) js = JSON(jsVal)
esVal = []byte(`{"Name":"test","Val":"test"}`) esVal = []byte(`{"Name":"test","Val":"test"}`)
es = ExampleStruct{Name: "test", Val: "test"} es = ExampleStruct{Name: "test", Val: "test"}
intVal intType = 1
floatVal floatType = 1.23
) )
results := []struct { results := []struct {
@ -57,43 +61,67 @@ func TestExplainSQL(t *testing.T) {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil, NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`,
}, },
{ {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil, NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`,
}, },
{ {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
NumericRegexp: regexp.MustCompile(`@p(\d+)`), NumericRegexp: regexp.MustCompile(`@p(\d+)`),
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
}, },
{ {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)",
NumericRegexp: regexp.MustCompile(`\$(\d+)`), NumericRegexp: regexp.MustCompile(`\$(\d+)`),
Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd}, Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
}, },
{ {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)",
NumericRegexp: regexp.MustCompile(`@p(\d+)`), NumericRegexp: regexp.MustCompile(`@p(\d+)`),
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
}, },
{ {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil, NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es}, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
}, },
{ {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil, NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`,
}, },
} }

View File

@ -13,11 +13,7 @@ func (db *DB) Migrator() Migrator {
// apply scopes to migrator // apply scopes to migrator
for len(tx.Statement.scopes) > 0 { for len(tx.Statement.scopes) > 0 {
scopes := tx.Statement.scopes tx = tx.executeScopes()
tx.Statement.scopes = nil
for _, scope := range scopes {
tx = scope(tx)
}
} }
return tx.Dialector.Migrator(tx.Session(&Session{})) return tx.Dialector.Migrator(tx.Session(&Session{}))
@ -30,9 +26,9 @@ func (db *DB) AutoMigrate(dst ...interface{}) error {
// ViewOption view option // ViewOption view option
type ViewOption struct { type ViewOption struct {
Replace bool Replace bool // If true, exec `CREATE`. If false, exec `CREATE OR REPLACE`
CheckOption string CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION`
Query *DB Query *DB // required subquery.
} }
// ColumnType column type interface // ColumnType column type interface
@ -51,6 +47,23 @@ type ColumnType interface {
DefaultValue() (value string, ok bool) DefaultValue() (value string, ok bool)
} }
type Index interface {
Table() string
Name() string
Columns() []string
PrimaryKey() (isPrimaryKey bool, ok bool)
Unique() (unique bool, ok bool)
Option() string
}
// TableType table type interface
type TableType interface {
Schema() string
Name() string
Type() string
Comment() (comment string, ok bool)
}
// Migrator migrator interface // Migrator migrator interface
type Migrator interface { type Migrator interface {
// AutoMigrate // AutoMigrate
@ -59,6 +72,7 @@ type Migrator interface {
// Database // Database
CurrentDatabase() string CurrentDatabase() string
FullDataTypeOf(*schema.Field) clause.Expr FullDataTypeOf(*schema.Field) clause.Expr
GetTypeAliases(databaseTypeName string) []string
// Tables // Tables
CreateTable(dst ...interface{}) error CreateTable(dst ...interface{}) error
@ -66,12 +80,15 @@ type Migrator interface {
HasTable(dst interface{}) bool HasTable(dst interface{}) bool
RenameTable(oldName, newName interface{}) error RenameTable(oldName, newName interface{}) error
GetTables() (tableList []string, err error) GetTables() (tableList []string, err error)
TableType(dst interface{}) (TableType, error)
// Columns // Columns
AddColumn(dst interface{}, field string) error AddColumn(dst interface{}, field string) error
DropColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error
AlterColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error
MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
HasColumn(dst interface{}, field string) bool HasColumn(dst interface{}, field string) bool
RenameColumn(dst interface{}, oldName, field string) error RenameColumn(dst interface{}, oldName, field string) error
ColumnTypes(dst interface{}) ([]ColumnType, error) ColumnTypes(dst interface{}) ([]ColumnType, error)
@ -90,4 +107,5 @@ type Migrator interface {
DropIndex(dst interface{}, name string) error DropIndex(dst interface{}, name string) error
HasIndex(dst interface{}, name string) bool HasIndex(dst interface{}, name string) bool
RenameIndex(dst interface{}, oldName, newName string) error RenameIndex(dst interface{}, oldName, newName string) error
GetIndexes(dst interface{}) ([]Index, error)
} }

43
migrator/index.go Normal file
View File

@ -0,0 +1,43 @@
package migrator
import "database/sql"
// Index implements gorm.Index interface
type Index struct {
TableName string
NameValue string
ColumnList []string
PrimaryKeyValue sql.NullBool
UniqueValue sql.NullBool
OptionValue string
}
// Table return the table name of the index.
func (idx Index) Table() string {
return idx.TableName
}
// Name return the name of the index.
func (idx Index) Name() string {
return idx.NameValue
}
// Columns return the columns of the index
func (idx Index) Columns() []string {
return idx.ColumnList
}
// PrimaryKey returns the index is primary key or not.
func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) {
return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid
}
// Unique returns whether the index is unique or not.
func (idx Index) Unique() (unique bool, ok bool) {
return idx.UniqueValue.Bool, idx.UniqueValue.Valid
}
// Option return the optional attribute of the index
func (idx Index) Option() string {
return idx.OptionValue
}

View File

@ -3,20 +3,32 @@ package migrator
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
"strconv"
"strings" "strings"
"time"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
var ( // This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*),
regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`) // with a possible trailing non-digit character (\D?).
regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`)
) // For example, values that can pass this regular expression are:
// - "123"
// - "abc456"
// -"%$#@789"
var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
// TODO:? Create const vars for raw sql queries ?
var _ gorm.Migrator = (*Migrator)(nil)
// Migrator m struct // Migrator m struct
type Migrator struct { type Migrator struct {
@ -30,6 +42,16 @@ type Config struct {
gorm.Dialector gorm.Dialector
} }
type printSQLLogger struct {
logger.Interface
}
func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
fmt.Println(sql + ";")
l.Interface.Trace(ctx, begin, fc, err)
}
// GormDataTypeInterface gorm data type interface // GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface { type GormDataTypeInterface interface {
GormDBDataType(*gorm.DB, *schema.Field) string GormDBDataType(*gorm.DB, *schema.Field) string
@ -72,10 +94,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
expr.SQL += " NOT NULL" expr.SQL += " NOT NULL"
} }
if field.Unique {
expr.SQL += " UNIQUE"
}
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
if field.DefaultValueInterface != nil { if field.DefaultValueInterface != nil {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
@ -89,20 +107,40 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
return return
} }
func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) {
queryTx = m.DB.Session(&gorm.Session{})
execTx = queryTx
if m.DB.DryRun {
queryTx.DryRun = false
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
}
return queryTx, execTx
}
// AutoMigrate auto migrate values // AutoMigrate auto migrate values
func (m Migrator) AutoMigrate(values ...interface{}) error { func (m Migrator) AutoMigrate(values ...interface{}) error {
for _, value := range m.ReorderModels(values, true) { for _, value := range m.ReorderModels(values, true) {
tx := m.DB.Session(&gorm.Session{}) queryTx, execTx := m.GetQueryAndExecTx()
if !tx.Migrator().HasTable(value) { if !queryTx.Migrator().HasTable(value) {
if err := tx.Migrator().CreateTable(value); err != nil { if err := execTx.Migrator().CreateTable(value); err != nil {
return err return err
} }
} else { } else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
if err != nil {
return err
}
var (
parseIndexes = stmt.Schema.ParseIndexes()
parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
)
for _, dbName := range stmt.Schema.DBNames { for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName]
var foundColumn gorm.ColumnType var foundColumn gorm.ColumnType
for _, columnType := range columnTypes { for _, columnType := range columnTypes {
@ -114,37 +152,43 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
if foundColumn == nil { if foundColumn == nil {
// not found, add column // not found, add column
if err := tx.Migrator().AddColumn(value, dbName); err != nil { if err = execTx.Migrator().AddColumn(value, dbName); err != nil {
return err
}
} else {
// found, smartly migrate
field := stmt.Schema.FieldsByDBName[dbName]
if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
return err return err
} }
} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
// found, smart migrate
return err
} }
} }
for _, rel := range stmt.Schema.Relationships.Relations { if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { for _, rel := range stmt.Schema.Relationships.Relations {
if rel.Field.IgnoreMigration {
continue
}
if constraint := rel.ParseConstraint(); constraint != nil && if constraint := rel.ParseConstraint(); constraint != nil &&
constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) { constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) {
if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
return err
}
}
}
for _, chk := range stmt.Schema.ParseCheckConstraints() {
if !tx.Migrator().HasConstraint(value, chk.Name) {
if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil {
return err return err
} }
} }
} }
} }
for _, idx := range stmt.Schema.ParseIndexes() { for _, chk := range parseCheckConstraints {
if !tx.Migrator().HasIndex(value, idx.Name) { if !queryTx.Migrator().HasConstraint(value, chk.Name) {
if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil { if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil {
return err
}
}
}
for _, idx := range parseIndexes {
if !queryTx.Migrator().HasIndex(value, idx.Name) {
if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil {
return err return err
} }
} }
@ -171,7 +215,12 @@ func (m Migrator) GetTables() (tableList []string, err error) {
func (m Migrator) CreateTable(values ...interface{}) error { func (m Migrator) CreateTable(values ...interface{}) error {
for _, value := range m.ReorderModels(values, false) { for _, value := range m.ReorderModels(values, false) {
tx := m.DB.Session(&gorm.Session{}) tx := m.DB.Session(&gorm.Session{})
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
var ( var (
createTableSQL = "CREATE TABLE ? (" createTableSQL = "CREATE TABLE ? ("
values = []interface{}{m.CurrentTable(stmt)} values = []interface{}{m.CurrentTable(stmt)}
@ -182,7 +231,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
field := stmt.Schema.FieldsByDBName[dbName] field := stmt.Schema.FieldsByDBName[dbName]
if !field.IgnoreMigration { if !field.IgnoreMigration {
createTableSQL += "? ?" createTableSQL += "? ?"
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY")
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
createTableSQL += "," createTableSQL += ","
} }
@ -190,7 +239,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
createTableSQL += "PRIMARY KEY ?," createTableSQL += "PRIMARY KEY ?,"
primaryKeys := []interface{}{} primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields))
for _, field := range stmt.Schema.PrimaryFields { for _, field := range stmt.Schema.PrimaryFields {
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
} }
@ -201,8 +250,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
for _, idx := range stmt.Schema.ParseIndexes() { for _, idx := range stmt.Schema.ParseIndexes() {
if m.CreateIndexAfterCreateTable { if m.CreateIndexAfterCreateTable {
defer func(value interface{}, name string) { defer func(value interface{}, name string) {
if errr == nil { if err == nil {
errr = tx.Migrator().CreateIndex(value, name) err = tx.Migrator().CreateIndex(value, name)
} }
}(value, idx.Name) }(value, idx.Name)
} else { } else {
@ -220,15 +269,18 @@ func (m Migrator) CreateTable(values ...interface{}) error {
} }
createTableSQL += "," createTableSQL += ","
values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) values = append(values, clause.Column{Name: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
} }
} }
for _, rel := range stmt.Schema.Relationships.Relations { if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
if !m.DB.DisableForeignKeyConstraintWhenMigrating { for _, rel := range stmt.Schema.Relationships.Relations {
if rel.Field.IgnoreMigration {
continue
}
if constraint := rel.ParseConstraint(); constraint != nil { if constraint := rel.ParseConstraint(); constraint != nil {
if constraint.Schema == stmt.Schema { if constraint.Schema == stmt.Schema {
sql, vars := buildConstraint(constraint) sql, vars := constraint.Build()
createTableSQL += sql + "," createTableSQL += sql + ","
values = append(values, vars...) values = append(values, vars...)
} }
@ -236,6 +288,11 @@ func (m Migrator) CreateTable(values ...interface{}) error {
} }
} }
for _, uni := range stmt.Schema.ParseUniqueConstraints() {
createTableSQL += "CONSTRAINT ? UNIQUE (?),"
values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)})
}
for _, chk := range stmt.Schema.ParseCheckConstraints() { for _, chk := range stmt.Schema.ParseCheckConstraints() {
createTableSQL += "CONSTRAINT ? CHECK (?)," createTableSQL += "CONSTRAINT ? CHECK (?),"
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
@ -249,8 +306,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
createTableSQL += fmt.Sprint(tableOption) createTableSQL += fmt.Sprint(tableOption)
} }
errr = tx.Exec(createTableSQL, values...).Error err = tx.Exec(createTableSQL, values...).Error
return errr return err
}); err != nil { }); err != nil {
return err return err
} }
@ -316,6 +373,9 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
func (m Migrator) AddColumn(value interface{}, name string) error { func (m Migrator) AddColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// avoid using the same name field // avoid using the same name field
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
f := stmt.Schema.LookUpField(name) f := stmt.Schema.LookUpField(name)
if f == nil { if f == nil {
return fmt.Errorf("failed to look up field with name: %s", name) return fmt.Errorf("failed to look up field with name: %s", name)
@ -335,8 +395,10 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
// DropColumn drop value's `name` column // DropColumn drop value's `name` column
func (m Migrator) DropColumn(value interface{}, name string) error { func (m Migrator) DropColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(name); field != nil { if stmt.Schema != nil {
name = field.DBName if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
} }
return m.DB.Exec( return m.DB.Exec(
@ -348,13 +410,15 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
// AlterColumn alter value's `field` column' type based on schema definition // AlterColumn alter value's `field` column' type based on schema definition
func (m Migrator) AlterColumn(value interface{}, field string) error { func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil { if stmt.Schema != nil {
fileType := m.FullDataTypeOf(field) if field := stmt.Schema.LookUpField(field); field != nil {
return m.DB.Exec( fileType := m.FullDataTypeOf(field)
"ALTER TABLE ? ALTER COLUMN ? TYPE ?", return m.DB.Exec(
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, "ALTER TABLE ? ALTER COLUMN ? TYPE ?",
).Error m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
).Error
}
} }
return fmt.Errorf("failed to look up field with name: %s", field) return fmt.Errorf("failed to look up field with name: %s", field)
}) })
@ -366,8 +430,10 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase() currentDatabase := m.DB.Migrator().CurrentDatabase()
name := field name := field
if field := stmt.Schema.LookUpField(field); field != nil { if stmt.Schema != nil {
name = field.DBName if field := stmt.Schema.LookUpField(field); field != nil {
name = field.DBName
}
} }
return m.DB.Raw( return m.DB.Raw(
@ -382,12 +448,14 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
// RenameColumn rename value's field name from oldName to newName // RenameColumn rename value's field name from oldName to newName
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(oldName); field != nil { if stmt.Schema != nil {
oldName = field.DBName if field := stmt.Schema.LookUpField(oldName); field != nil {
} oldName = field.DBName
}
if field := stmt.Schema.LookUpField(newName); field != nil { if field := stmt.Schema.LookUpField(newName); field != nil {
newName = field.DBName newName = field.DBName
}
} }
return m.DB.Exec( return m.DB.Exec(
@ -399,56 +467,102 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
// MigrateColumn migrate column // MigrateColumn migrate column
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
if field.IgnoreMigration {
return nil
}
// found, smart migrate // found, smart migrate
fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
realDataType := strings.ToLower(columnType.DatabaseTypeName()) realDataType := strings.ToLower(columnType.DatabaseTypeName())
var (
alterColumn bool
isSameType = fullDataType == realDataType
)
alterColumn := false if !field.PrimaryKey {
// check type
if !strings.HasPrefix(fullDataType, realDataType) {
// check type aliases
aliases := m.DB.Migrator().GetTypeAliases(realDataType)
for _, alias := range aliases {
if strings.HasPrefix(fullDataType, alias) {
isSameType = true
break
}
}
// check size if !isSameType {
if length, ok := columnType.Length(); length != int64(field.Size) {
if length > 0 && field.Size > 0 {
alterColumn = true
} else {
// has size in data type and not equal
// Since the following code is frequently called in the for loop, reg optimization is needed here
matches := regRealDataType.FindAllStringSubmatch(realDataType, -1)
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) &&
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
alterColumn = true alterColumn = true
} }
} }
} }
if !isSameType {
// check size
if length, ok := columnType.Length(); length != int64(field.Size) {
if length > 0 && field.Size > 0 {
alterColumn = true
} else {
// has size in data type and not equal
// Since the following code is frequently called in the for loop, reg optimization is needed here
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
if !field.PrimaryKey &&
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
alterColumn = true
}
}
}
}
// check precision // check precision
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { if realDataType == "decimal" || realDataType == "numeric" &&
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { regexp.MustCompile(realDataType+`\(.*\)`).FindString(fullDataType) != "" { // if realDataType has no precision,ignore
alterColumn = true precision, scale, ok := columnType.DecimalSize()
if ok {
if !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d,%d)", realDataType, precision, scale)) &&
!strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d)", realDataType, precision)) {
alterColumn = true
}
}
} else {
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
alterColumn = true
}
} }
} }
// check nullable // check nullable
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
// not primary key & database is nullable // not primary key & current database is non-nullable(to be nullable)
if !field.PrimaryKey && nullable { if !field.PrimaryKey && !nullable {
alterColumn = true
}
}
// check unique
if unique, ok := columnType.Unique(); ok && unique != field.Unique {
// not primary key
if !field.PrimaryKey {
alterColumn = true alterColumn = true
} }
} }
// check default value // check default value
if v, ok := columnType.DefaultValue(); ok && v != field.DefaultValue { if !field.PrimaryKey {
// not primary key currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
if !field.PrimaryKey { dv, dvNotNull := columnType.DefaultValue()
if dvNotNull && !currentDefaultNotNull {
// default value -> null
alterColumn = true alterColumn = true
} else if !dvNotNull && currentDefaultNotNull {
// null -> default value
alterColumn = true
} else if currentDefaultNotNull || dvNotNull {
switch field.GORMDataType {
case schema.Time:
if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) {
alterColumn = true
}
case schema.Bool:
v1, _ := strconv.ParseBool(dv)
v2, _ := strconv.ParseBool(field.DefaultValue)
alterColumn = v1 != v2
default:
alterColumn = dv != field.DefaultValue
}
} }
} }
@ -460,13 +574,39 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
} }
} }
if alterColumn && !field.IgnoreMigration { if alterColumn {
return m.DB.Migrator().AlterColumn(value, field.Name) if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil {
return err
}
}
if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil {
return err
} }
return nil return nil
} }
func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
unique, ok := columnType.Unique()
if !ok || field.PrimaryKey {
return nil // skip primary key
}
// By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex.
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// We're currently only receiving boolean values on `Unique` tag,
// so the UniqueConstraint name is fixed
constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
if unique && !field.Unique {
return m.DB.Migrator().DropConstraint(value, constraint)
}
if !unique && field.Unique {
return m.DB.Migrator().CreateConstraint(value, constraint)
}
return nil
})
}
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error // ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0) columnTypes := make([]gorm.ColumnType, 0)
@ -496,47 +636,76 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
return columnTypes, execErr return columnTypes, execErr
} }
// CreateView create view // CreateView create view from Query in gorm.ViewOption.
// Query in gorm.ViewOption is a [subquery]
//
// // CREATE VIEW `user_view` AS SELECT * FROM `users` WHERE age > 20
// q := DB.Model(&User{}).Where("age > ?", 20)
// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q})
//
// // CREATE OR REPLACE VIEW `users_view` AS SELECT * FROM `users` WITH CHECK OPTION
// q := DB.Model(&User{})
// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q, Replace: true, CheckOption: "WITH CHECK OPTION"})
//
// [subquery]: https://gorm.io/docs/advanced_query.html#SubQuery
func (m Migrator) CreateView(name string, option gorm.ViewOption) error { func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
return gorm.ErrNotImplemented if option.Query == nil {
return gorm.ErrSubQueryRequired
}
sql := new(strings.Builder)
sql.WriteString("CREATE ")
if option.Replace {
sql.WriteString("OR REPLACE ")
}
sql.WriteString("VIEW ")
m.QuoteTo(sql, name)
sql.WriteString(" AS ")
m.DB.Statement.AddVar(sql, option.Query)
if option.CheckOption != "" {
sql.WriteString(" ")
sql.WriteString(option.CheckOption)
}
return m.DB.Exec(m.Explain(sql.String(), m.DB.Statement.Vars...)).Error
} }
// DropView drop view // DropView drop view
func (m Migrator) DropView(name string) error { func (m Migrator) DropView(name string) error {
return gorm.ErrNotImplemented return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error
}
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
var foreignKeys, references []interface{}
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
} }
// GuessConstraintAndTable guess statement's constraint and it's table based on name // GuessConstraintAndTable guess statement's constraint and it's table based on name
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { //
// Deprecated: use GuessConstraintInterfaceAndTable instead.
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
switch c := constraint.(type) {
case *schema.Constraint:
return c, nil, table
case *schema.CheckConstraint:
return nil, c, table
default:
return nil, nil, table
}
}
// GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name
// nolint:cyclop
func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) {
if stmt.Schema == nil { if stmt.Schema == nil {
return nil, nil, stmt.Table return nil, stmt.Table
} }
checkConstraints := stmt.Schema.ParseCheckConstraints() checkConstraints := stmt.Schema.ParseCheckConstraints()
if chk, ok := checkConstraints[name]; ok { if chk, ok := checkConstraints[name]; ok {
return nil, &chk, stmt.Table return &chk, stmt.Table
}
uniqueConstraints := stmt.Schema.ParseUniqueConstraints()
if uni, ok := uniqueConstraints[name]; ok {
return &uni, stmt.Table
} }
getTable := func(rel *schema.Relationship) string { getTable := func(rel *schema.Relationship) string {
@ -551,7 +720,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_
for _, rel := range stmt.Schema.Relationships.Relations { for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
return constraint, nil, getTable(rel) return constraint, getTable(rel)
} }
} }
@ -559,40 +728,39 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_
for k := range checkConstraints { for k := range checkConstraints {
if checkConstraints[k].Field == field { if checkConstraints[k].Field == field {
v := checkConstraints[k] v := checkConstraints[k]
return nil, &v, stmt.Table return &v, stmt.Table
}
}
for k := range uniqueConstraints {
if uniqueConstraints[k].Field == field {
v := uniqueConstraints[k]
return &v, stmt.Table
} }
} }
for _, rel := range stmt.Schema.Relationships.Relations { for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field { if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field {
return constraint, nil, getTable(rel) return constraint, getTable(rel)
} }
} }
} }
return nil, nil, stmt.Schema.Table return nil, stmt.Schema.Table
} }
// CreateConstraint create constraint // CreateConstraint create constraint
func (m Migrator) CreateConstraint(value interface{}, name string) error { func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name) constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
if chk != nil {
return m.DB.Exec(
"ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)",
m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
).Error
}
if constraint != nil { if constraint != nil {
vars := []interface{}{clause.Table{Name: table}} vars := []interface{}{clause.Table{Name: table}}
if stmt.TableExpr != nil { if stmt.TableExpr != nil {
vars[0] = stmt.TableExpr vars[0] = stmt.TableExpr
} }
sql, values := buildConstraint(constraint) sql, values := constraint.Build()
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
} }
return nil return nil
}) })
} }
@ -600,11 +768,9 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
// DropConstraint drop constraint // DropConstraint drop constraint
func (m Migrator) DropConstraint(value interface{}, name string) error { func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name) constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
if constraint != nil { if constraint != nil {
name = constraint.Name name = constraint.GetName()
} else if chk != nil {
name = chk.Name
} }
return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
}) })
@ -615,11 +781,9 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase() currentDatabase := m.DB.Migrator().CurrentDatabase()
constraint, chk, table := m.GuessConstraintAndTable(stmt, name) constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
if constraint != nil { if constraint != nil {
name = constraint.Name name = constraint.GetName()
} else if chk != nil {
name = chk.Name
} }
return m.DB.Raw( return m.DB.Raw(
@ -661,6 +825,9 @@ type BuildIndexOptionsInterface interface {
// CreateIndex create index `name` // CreateIndex create index `name`
func (m Migrator) CreateIndex(value interface{}, name string) error { func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
if idx := stmt.Schema.LookIndex(name); idx != nil { if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
@ -693,8 +860,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
// DropIndex drop index `name` // DropIndex drop index `name`
func (m Migrator) DropIndex(value interface{}, name string) error { func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil { if stmt.Schema != nil {
name = idx.Name if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
} }
return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
@ -706,8 +875,10 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase() currentDatabase := m.DB.Migrator().CurrentDatabase()
if idx := stmt.Schema.LookIndex(name); idx != nil { if stmt.Schema != nil {
name = idx.Name if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
} }
return m.DB.Raw( return m.DB.Raw(
@ -756,7 +927,8 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
Statement: &gorm.Statement{DB: m.DB, Dest: value}, Statement: &gorm.Statement{DB: m.DB, Dest: value},
} }
beDependedOn := map[*schema.Schema]bool{} beDependedOn := map[*schema.Schema]bool{}
if err := dep.Parse(value); err != nil { // support for special table name
if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil {
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
} }
if _, ok := parsedSchemas[dep.Statement.Schema]; ok { if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
@ -764,26 +936,31 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
} }
parsedSchemas[dep.Statement.Schema] = true parsedSchemas[dep.Statement.Schema] = true
for _, rel := range dep.Schema.Relationships.Relations { if !m.DB.IgnoreRelationshipsWhenMigrating {
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { for _, rel := range dep.Schema.Relationships.Relations {
dep.Depends = append(dep.Depends, c.ReferenceSchema) if rel.Field.IgnoreMigration {
} continue
}
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
dep.Depends = append(dep.Depends, c.ReferenceSchema)
}
if rel.Type == schema.HasOne || rel.Type == schema.HasMany { if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
beDependedOn[rel.FieldSchema] = true beDependedOn[rel.FieldSchema] = true
} }
if rel.JoinTable != nil { if rel.JoinTable != nil {
// append join value // append join value
defer func(rel *schema.Relationship, joinValue interface{}) { defer func(rel *schema.Relationship, joinValue interface{}) {
if !beDependedOn[rel.FieldSchema] { if !beDependedOn[rel.FieldSchema] {
dep.Depends = append(dep.Depends, rel.FieldSchema) dep.Depends = append(dep.Depends, rel.FieldSchema)
} else { } else {
fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface() fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
parseDependence(fieldValue, autoAdd) parseDependence(fieldValue, autoAdd)
} }
parseDependence(joinValue, autoAdd) parseDependence(joinValue, autoAdd)
}(rel, reflect.New(rel.JoinTable.ModelType).Interface()) }(rel, reflect.New(rel.JoinTable.ModelType).Interface())
}
} }
} }
@ -840,3 +1017,18 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
} }
return clause.Table{Name: stmt.Table} return clause.Table{Name: stmt.Table}
} }
// GetIndexes return Indexes []gorm.Index and execErr error
func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) {
return nil, errors.New("not support")
}
// GetTypeAliases return database type aliases
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
return nil
}
// TableType return tableType gorm.TableType and execErr error
func (m Migrator) TableType(dst interface{}) (gorm.TableType, error) {
return nil, errors.New("not support")
}

33
migrator/table_type.go Normal file
View File

@ -0,0 +1,33 @@
package migrator
import (
"database/sql"
)
// TableType table type implements TableType interface
type TableType struct {
SchemaValue string
NameValue string
TypeValue string
CommentValue sql.NullString
}
// Schema returns the schema of the table.
func (ct TableType) Schema() string {
return ct.SchemaValue
}
// Name returns the name of the table.
func (ct TableType) Name() string {
return ct.NameValue
}
// Type returns the type of the table.
func (ct TableType) Type() string {
return ct.TypeValue
}
// Comment returns the comment of current table.
func (ct TableType) Comment() (comment string, ok bool) {
return ct.CommentValue.String, ct.CommentValue.Valid
}

View File

@ -4,9 +4,10 @@ import "time"
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
// It may be embedded into your model or you may build your own model without it // It may be embedded into your model or you may build your own model without it
// type User struct { //
// gorm.Model // type User struct {
// } // gorm.Model
// }
type Model struct { type Model struct {
ID uint `gorm:"primarykey"` ID uint `gorm:"primarykey"`
CreatedAt time.Time CreatedAt time.Time

View File

@ -3,70 +3,86 @@ package gorm
import ( import (
"context" "context"
"database/sql" "database/sql"
"database/sql/driver"
"errors"
"reflect"
"sync" "sync"
"time"
"gorm.io/gorm/internal/stmt_store"
) )
type Stmt struct {
*sql.Stmt
Transaction bool
}
type PreparedStmtDB struct { type PreparedStmtDB struct {
Stmts map[string]Stmt Stmts stmt_store.Store
PreparedSQL []string Mux *sync.RWMutex
Mux *sync.RWMutex
ConnPool ConnPool
} }
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { // NewPreparedStmtDB creates and initializes a new instance of PreparedStmtDB.
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { //
return dbConnector.GetDBConn() // Parameters:
// - connPool: A connection pool that implements the ConnPool interface, used for managing database connections.
// - maxSize: The maximum number of prepared statements that can be stored in the statement store.
// - ttl: The time-to-live duration for each prepared statement in the store. Statements older than this duration will be automatically removed.
//
// Returns:
// - A pointer to a PreparedStmtDB instance, which manages prepared statements using the provided connection pool and configuration.
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
return &PreparedStmtDB{
ConnPool: connPool, // Assigns the provided connection pool to manage database connections.
Stmts: stmt_store.New(maxSize, ttl), // Initializes a new statement store with the specified maximum size and TTL.
Mux: &sync.RWMutex{}, // Sets up a read-write mutex for synchronizing access to the statement store.
} }
}
// GetDBConn returns the underlying *sql.DB connection
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
if sqldb, ok := db.ConnPool.(*sql.DB); ok { if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil return sqldb, nil
} }
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}
return nil, ErrInvalidDB return nil, ErrInvalidDB
} }
// Close closes all prepared statements in the store
func (db *PreparedStmtDB) Close() { func (db *PreparedStmtDB) Close() {
db.Mux.Lock() db.Mux.Lock()
defer db.Mux.Unlock() defer db.Mux.Unlock()
for _, query := range db.PreparedSQL { for _, key := range db.Stmts.Keys() {
if stmt, ok := db.Stmts[query]; ok { db.Stmts.Delete(key)
delete(db.Stmts, query)
go stmt.Close()
}
} }
} }
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { // Reset Deprecated use Close instead
func (db *PreparedStmtDB) Reset() {
db.Close()
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) {
db.Mux.RLock() db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { if db.Stmts != nil {
db.Mux.RUnlock() if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
return stmt, nil db.Mux.RUnlock()
return stmt, stmt.Error()
}
} }
db.Mux.RUnlock() db.Mux.RUnlock()
// retry
db.Mux.Lock() db.Mux.Lock()
defer db.Mux.Unlock() if db.Stmts != nil {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
// double check db.Mux.Unlock()
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { return stmt, stmt.Error()
return stmt, nil }
} else if ok {
go stmt.Close()
} }
stmt, err := conn.PrepareContext(ctx, query) return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux)
if err == nil {
db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction}
db.PreparedSQL = append(db.PreparedSQL, query)
}
return db.Stmts[query], err
} }
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
@ -74,6 +90,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
tx, err := beginner.BeginTx(ctx, opt) tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
} }
beginner, ok := db.ConnPool.(ConnPoolBeginner)
if !ok {
return nil, ErrInvalidTransaction
}
connPool, err := beginner.BeginTx(ctx, opt)
if err != nil {
return nil, err
}
if tx, ok := connPool.(Tx); ok {
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
}
return nil, ErrInvalidTransaction return nil, ErrInvalidTransaction
} }
@ -81,11 +110,8 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
stmt, err := db.prepare(ctx, db.ConnPool, false, query) stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil { if err == nil {
result, err = stmt.ExecContext(ctx, args...) result, err = stmt.ExecContext(ctx, args...)
if err != nil { if errors.Is(err, driver.ErrBadConn) {
db.Mux.Lock() db.Stmts.Delete(query)
defer db.Mux.Unlock()
go stmt.Close()
delete(db.Stmts, query)
} }
} }
return result, err return result, err
@ -95,12 +121,8 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
stmt, err := db.prepare(ctx, db.ConnPool, false, query) stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil { if err == nil {
rows, err = stmt.QueryContext(ctx, args...) rows, err = stmt.QueryContext(ctx, args...)
if err != nil { if errors.Is(err, driver.ErrBadConn) {
db.Mux.Lock() db.Stmts.Delete(query)
defer db.Mux.Unlock()
go stmt.Close()
delete(db.Stmts, query)
} }
} }
return rows, err return rows, err
@ -114,20 +136,32 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
return &sql.Row{} return &sql.Row{}
} }
func (db *PreparedStmtDB) Ping() error {
conn, err := db.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}
type PreparedStmtTX struct { type PreparedStmtTX struct {
Tx Tx
PreparedStmtDB *PreparedStmtDB PreparedStmtDB *PreparedStmtDB
} }
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
return db.PreparedStmtDB.GetDBConn()
}
func (tx *PreparedStmtTX) Commit() error { func (tx *PreparedStmtTX) Commit() error {
if tx.Tx != nil { if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Commit() return tx.Tx.Commit()
} }
return ErrInvalidTransaction return ErrInvalidTransaction
} }
func (tx *PreparedStmtTX) Rollback() error { func (tx *PreparedStmtTX) Rollback() error {
if tx.Tx != nil { if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Rollback() return tx.Tx.Rollback()
} }
return ErrInvalidTransaction return ErrInvalidTransaction
@ -137,12 +171,8 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil { if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if err != nil { if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Mux.Lock() tx.PreparedStmtDB.Stmts.Delete(query)
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
} }
} }
return result, err return result, err
@ -152,12 +182,8 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil { if err == nil {
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if err != nil { if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Mux.Lock() tx.PreparedStmtDB.Stmts.Delete(query)
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
} }
} }
return rows, err return rows, err
@ -170,3 +196,11 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg
} }
return &sql.Row{} return &sql.Row{}
} }
func (tx *PreparedStmtTX) Ping() error {
conn, err := tx.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}

159
scan.go
View File

@ -8,6 +8,7 @@ import (
"time" "time"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
// prepareValues prepare values slice // prepareValues prepare values slice
@ -15,7 +16,7 @@ func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType,
if db.Statement.Schema != nil { if db.Statement.Schema != nil {
for idx, name := range columns { for idx, name := range columns {
if field := db.Statement.Schema.LookUpField(name); field != nil { if field := db.Statement.Schema.LookUpField(name); field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() values[idx] = reflect.New(reflect.PointerTo(field.FieldType)).Interface()
continue continue
} }
values[idx] = new(interface{}) values[idx] = new(interface{})
@ -23,7 +24,7 @@ func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType,
} else if len(columnTypes) > 0 { } else if len(columnTypes) > 0 {
for idx, columnType := range columnTypes { for idx, columnType := range columnTypes {
if columnType.ScanType() != nil { if columnType.ScanType() != nil {
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface() values[idx] = reflect.New(reflect.PointerTo(columnType.ScanType())).Interface()
} else { } else {
values[idx] = new(interface{}) values[idx] = new(interface{})
} }
@ -50,7 +51,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
} }
} }
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) { func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
for idx, field := range fields { for idx, field := range fields {
if field != nil { if field != nil {
values[idx] = field.NewValuePool.Get() values[idx] = field.NewValuePool.Get()
@ -65,26 +66,49 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
joinedNestedSchemaMap := make(map[string]interface{})
for idx, field := range fields { for idx, field := range fields {
if field != nil { if field == nil {
if len(joinFields) == 0 || joinFields[idx][0] == nil { continue
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) }
} else {
relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
return
}
relValue.Set(reflect.New(relValue.Type().Elem())) if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
} else { // joinFields count is larger than 2 when using join
var isNilPtrValue bool
var relValue reflect.Value
// does not contain raw dbname
nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
// current reflect value
currentReflectValue := reflectValue
fullRels := make([]string, 0, len(nestedJoinSchemas))
for _, joinSchema := range nestedJoinSchemas {
fullRels = append(fullRels, joinSchema.Name)
relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
if relValue.Kind() == reflect.Ptr {
fullRelsName := utils.JoinNestedRelationNames(fullRels)
// same nested structure
if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
isNilPtrValue = true
break
}
relValue.Set(reflect.New(relValue.Type().Elem()))
joinedNestedSchemaMap[fullRelsName] = nil
}
} }
db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) currentReflectValue = relValue
} }
// release data to pool if !isNilPtrValue { // ignore if value is nil
field.NewValuePool.Put(values[idx]) f := joinFields[idx][len(joinFields[idx])-1]
db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
}
} }
// release data to pool
field.NewValuePool.Put(values[idx])
} }
} }
@ -108,6 +132,15 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
onConflictDonothing = mode&ScanOnConflictDoNothing != 0 onConflictDonothing = mode&ScanOnConflictDoNothing != 0
) )
if len(db.Statement.ColumnMapping) > 0 {
for i, column := range columns {
v, ok := db.Statement.ColumnMapping[column]
if ok {
columns[i] = v
}
}
}
db.RowsAffected = 0 db.RowsAffected = 0
switch dest := db.Statement.Dest.(type) { switch dest := db.Statement.Dest.(type) {
@ -156,11 +189,10 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
} }
default: default:
var ( var (
fields = make([]*schema.Field, len(columns)) fields = make([]*schema.Field, len(columns))
selectedColumnsMap = make(map[string]int, len(columns)) joinFields [][]*schema.Field
joinFields [][2]*schema.Field sch = db.Statement.Schema
sch = db.Statement.Schema reflectValue = db.Statement.ReflectValue
reflectValue = db.Statement.ReflectValue
) )
if reflectValue.Kind() == reflect.Interface { if reflectValue.Kind() == reflect.Interface {
@ -193,35 +225,61 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
// Not Pluck // Not Pluck
if sch != nil { if sch != nil {
matchedFieldCount := make(map[string]int, len(columns))
for idx, column := range columns { for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable { if field := sch.LookUpField(column); field != nil && field.Readable {
if curIndex, ok := selectedColumnsMap[column]; ok { fields[idx] = field
for fieldIndex, selectField := range sch.Fields[curIndex:] { if count, ok := matchedFieldCount[column]; ok {
// handle duplicate fields
for _, selectField := range sch.Fields {
if selectField.DBName == column && selectField.Readable { if selectField.DBName == column && selectField.Readable {
selectedColumnsMap[column] = curIndex + fieldIndex + 1 if count == 0 {
fields[idx] = selectField matchedFieldCount[column]++
break fields[idx] = selectField
break
}
count--
} }
} }
} else { } else {
fields[idx] = field matchedFieldCount[column] = 1
selectedColumnsMap[column] = idx
} }
} else if names := strings.Split(column, "__"); len(names) > 1 { } else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1])
for _, join := range db.Statement.Joins {
if join.Alias == aliasName {
names = append(strings.Split(join.Name, "."), names[len(names)-1])
break
}
}
if rel, ok := sch.Relationships.Relations[names[0]]; ok { if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { subNameCount := len(names)
// nested relation fields
relFields := make([]*schema.Field, 0, subNameCount-1)
relFields = append(relFields, rel.Field)
for _, name := range names[1 : subNameCount-1] {
rel = rel.FieldSchema.Relationships.Relations[name]
relFields = append(relFields, rel.Field)
}
// latest name is raw dbname
dbName := names[subNameCount-1]
if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
fields[idx] = field fields[idx] = field
if len(joinFields) == 0 { if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns)) joinFields = make([][]*schema.Field, len(columns))
} }
joinFields[idx] = [2]*schema.Field{rel.Field, field} relFields = append(relFields, field)
joinFields[idx] = relFields
continue continue
} }
} }
values[idx] = &sql.RawBytes{} var val interface{}
values[idx] = &val
} else { } else {
values[idx] = &sql.RawBytes{} var val interface{}
values[idx] = &val
} }
} }
} }
@ -229,11 +287,24 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
var elem reflect.Value var (
elem reflect.Value
isArrayKind = reflectValue.Kind() == reflect.Array
)
if !update || reflectValue.Len() == 0 { if !update || reflectValue.Len() == 0 {
update = false update = false
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) if isArrayKind {
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
} else {
// if the slice cap is externally initialized, the externally initialized slice is directly used here
if reflectValue.Cap() == 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
} else {
reflectValue.SetLen(0)
db.Statement.ReflectValue.Set(reflectValue)
}
}
} }
for initialized || rows.Next() { for initialized || rows.Next() {
@ -260,10 +331,15 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
db.scanIntoStruct(rows, elem, values, fields, joinFields) db.scanIntoStruct(rows, elem, values, fields, joinFields)
if !update { if !update {
if isPtr { if !isPtr {
reflectValue = reflect.Append(reflectValue, elem) elem = elem.Elem()
}
if isArrayKind {
if reflectValue.Len() >= int(db.RowsAffected) {
reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
}
} else { } else {
reflectValue = reflect.Append(reflectValue, elem.Elem()) reflectValue = reflect.Append(reflectValue, elem)
} }
} }
} }
@ -273,6 +349,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
} }
case reflect.Struct, reflect.Ptr: case reflect.Struct, reflect.Ptr:
if initialized || rows.Next() { if initialized || rows.Next() {
if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct {
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
}
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields) db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
} }
default: default:

View File

@ -1,35 +0,0 @@
package schema
import (
"regexp"
"strings"
)
// reg match english letters and midline
var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$")
type Check struct {
Name string
Constraint string // length(phone) >= 10
*Field
}
// ParseCheckConstraints parse schema check constraints
func (schema *Schema) ParseCheckConstraints() map[string]Check {
checks := map[string]Check{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = Check{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}

66
schema/constraint.go Normal file
View File

@ -0,0 +1,66 @@
package schema
import (
"regexp"
"strings"
"gorm.io/gorm/clause"
)
// reg match english letters and midline
var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`)
type CheckConstraint struct {
Name string
Constraint string // length(phone) >= 10
*Field
}
func (chk *CheckConstraint) GetName() string { return chk.Name }
func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
}
// ParseCheckConstraints parse schema check constraints
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
checks := map[string]CheckConstraint{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}
type UniqueConstraint struct {
Name string
Field *Field
}
func (uni *UniqueConstraint) GetName() string { return uni.Name }
func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
}
// ParseUniqueConstraints parse schema unique constraints
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
uniques := make(map[string]UniqueConstraint)
for _, field := range schema.Fields {
if field.Unique {
name := schema.namer.UniqueName(schema.Table, field.DBName)
uniques[name] = UniqueConstraint{Name: name, Field: field}
}
}
return uniques
}

View File

@ -6,6 +6,7 @@ import (
"testing" "testing"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
) )
type UserCheck struct { type UserCheck struct {
@ -20,7 +21,7 @@ func TestParseCheck(t *testing.T) {
t.Fatalf("failed to parse user check, got error %v", err) t.Fatalf("failed to parse user check, got error %v", err)
} }
results := map[string]schema.Check{ results := map[string]schema.CheckConstraint{
"name_checker": { "name_checker": {
Name: "name_checker", Name: "name_checker",
Constraint: "name <> 'jinzhu'", Constraint: "name <> 'jinzhu'",
@ -53,3 +54,31 @@ func TestParseCheck(t *testing.T) {
} }
} }
} }
func TestParseUniqueConstraints(t *testing.T) {
type UserUnique struct {
Name1 string `gorm:"unique"`
Name2 string `gorm:"uniqueIndex"`
}
user, err := schema.Parse(&UserUnique{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user unique, got error %v", err)
}
constraints := user.ParseUniqueConstraints()
results := map[string]schema.UniqueConstraint{
"uni_user_uniques_name1": {
Name: "uni_user_uniques_name1",
Field: &schema.Field{Name: "Name1", Unique: true},
},
}
for k, result := range results {
v, ok := constraints[k]
if !ok {
t.Errorf("Failed to found unique constraint %v from parsed constraints %+v", k, constraints)
}
tests.AssertObjEqual(t, result, v, "Name")
tests.AssertObjEqual(t, result.Field, v.Field, "Name", "Unique", "UniqueIndex")
}
}

View File

@ -49,11 +49,14 @@ const (
Bytes DataType = "bytes" Bytes DataType = "bytes"
) )
const DefaultAutoIncrementIncrement int64 = 1
// Field is the representation of model schema's field // Field is the representation of model schema's field
type Field struct { type Field struct {
Name string Name string
DBName string DBName string
BindNames []string BindNames []string
EmbeddedBindNames []string
DataType DataType DataType DataType
GORMDataType DataType GORMDataType DataType
PrimaryKey bool PrimaryKey bool
@ -87,6 +90,16 @@ type Field struct {
Set func(context.Context, reflect.Value, interface{}) error Set func(context.Context, reflect.Value, interface{}) error
Serializer SerializerInterface Serializer SerializerInterface
NewValuePool FieldNewValuePool NewValuePool FieldNewValuePool
// In some db (e.g. MySQL), Unique and UniqueIndex are indistinguishable.
// When a column has a (not Mul) UniqueIndex, Migrator always reports its gorm.ColumnType is Unique.
// It causes field unnecessarily migration.
// Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique.
UniqueIndex string
}
func (field *Field) BindName() string {
return strings.Join(field.BindNames, ".")
} }
// ParseField parses reflect.StructField to Field // ParseField parses reflect.StructField to Field
@ -100,6 +113,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
Name: fieldStruct.Name, Name: fieldStruct.Name,
DBName: tagSetting["COLUMN"], DBName: tagSetting["COLUMN"],
BindNames: []string{fieldStruct.Name}, BindNames: []string{fieldStruct.Name},
EmbeddedBindNames: []string{fieldStruct.Name},
FieldType: fieldStruct.Type, FieldType: fieldStruct.Type,
IndirectFieldType: fieldStruct.Type, IndirectFieldType: fieldStruct.Type,
StructField: fieldStruct, StructField: fieldStruct,
@ -115,7 +129,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
Unique: utils.CheckTruth(tagSetting["UNIQUE"]), Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
Comment: tagSetting["COMMENT"], Comment: tagSetting["COMMENT"],
AutoIncrementIncrement: 1, AutoIncrementIncrement: DefaultAutoIncrementIncrement,
} }
for field.IndirectFieldType.Kind() == reflect.Ptr { for field.IndirectFieldType.Kind() == reflect.Ptr {
@ -174,7 +188,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.DataType = String field.DataType = String
field.Serializer = v field.Serializer = v
} else { } else {
var serializerName = field.TagSettings["JSON"] serializerName := field.TagSettings["JSON"]
if serializerName == "" { if serializerName == "" {
serializerName = field.TagSettings["SERIALIZER"] serializerName = field.TagSettings["SERIALIZER"]
} }
@ -304,9 +318,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
if val, ok := field.TagSettings["TYPE"]; ok { if val, ok := field.TagSettings["TYPE"]; ok {
switch DataType(strings.ToLower(val)) { lowerVal := DataType(strings.ToLower(val))
switch lowerVal {
case Bool, Int, Uint, Float, String, Time, Bytes: case Bool, Int, Uint, Float, String, Time, Bytes:
field.DataType = DataType(strings.ToLower(val)) field.DataType = lowerVal
default: default:
field.DataType = DataType(val) field.DataType = DataType(val)
} }
@ -391,6 +406,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
ef.Schema = schema ef.Schema = schema
ef.OwnerSchema = field.EmbeddedSchema ef.OwnerSchema = field.EmbeddedSchema
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
if _, ok := field.TagSettings["EMBEDDED"]; ok || !fieldStruct.Anonymous {
ef.EmbeddedBindNames = append([]string{fieldStruct.Name}, ef.EmbeddedBindNames...)
}
// index is negative means is pointer // index is negative means is pointer
if field.FieldType.Kind() == reflect.Struct { if field.FieldType.Kind() == reflect.Struct {
ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
@ -403,18 +421,14 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
if ef.PrimaryKey { if ef.PrimaryKey {
if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { if !utils.CheckTruth(ef.TagSettings["PRIMARYKEY"], ef.TagSettings["PRIMARY_KEY"]) {
ef.PrimaryKey = true
} else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
ef.PrimaryKey = true
} else {
ef.PrimaryKey = false ef.PrimaryKey = false
if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) {
ef.AutoIncrement = false ef.AutoIncrement = false
} }
if ef.DefaultValue == "" { if !ef.AutoIncrement && ef.DefaultValue == "" {
ef.HasDefaultValue = false ef.HasDefaultValue = false
} }
} }
@ -434,21 +448,30 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
// create valuer, setter when parse struct // create valuer, setter when parse struct
func (field *Field) setupValuerAndSetter() { func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
// Setup NewValuePool // Setup NewValuePool
field.setupNewValuePool() field.setupNewValuePool()
// ValueOf returns field's value and if it is zero // ValueOf returns field's value and if it is zero
fieldIndex := field.StructField.Index[0] fieldIndex := field.StructField.Index[0]
switch { switch {
case len(field.StructField.Index) == 1 && fieldIndex > 0: case len(field.StructField.Index) == 1 && fieldIndex >= 0:
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(fieldIndex) v = reflect.Indirect(v)
if v.Type() != modelType {
fieldValue := v.FieldByName(field.Name)
return fieldValue.Interface(), fieldValue.IsZero()
}
fieldValue := v.Field(fieldIndex)
return fieldValue.Interface(), fieldValue.IsZero() return fieldValue.Interface(), fieldValue.IsZero()
} }
default: default:
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
v = reflect.Indirect(v) v = reflect.Indirect(v)
if v.Type() != modelType {
fieldValue := v.FieldByName(field.Name)
return fieldValue.Interface(), fieldValue.IsZero()
}
for _, fieldIdx := range field.StructField.Index { for _, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 { if fieldIdx >= 0 {
v = v.Field(fieldIdx) v = v.Field(fieldIdx)
@ -472,9 +495,6 @@ func (field *Field) setupValuerAndSetter() {
oldValuerOf := field.ValueOf oldValuerOf := field.ValueOf
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
value, zero := oldValuerOf(ctx, v) value, zero := oldValuerOf(ctx, v)
if zero {
return value, zero
}
s, ok := value.(SerializerValuerInterface) s, ok := value.(SerializerValuerInterface)
if !ok { if !ok {
@ -487,19 +507,26 @@ func (field *Field) setupValuerAndSetter() {
Destination: v, Destination: v,
Context: ctx, Context: ctx,
fieldValue: value, fieldValue: value,
}, false }, zero
} }
} }
// ReflectValueOf returns field's reflect value // ReflectValueOf returns field's reflect value
switch { switch {
case len(field.StructField.Index) == 1 && fieldIndex > 0: case len(field.StructField.Index) == 1 && fieldIndex >= 0:
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(fieldIndex) v = reflect.Indirect(v)
if v.Type() != modelType {
return v.FieldByName(field.Name)
}
return v.Field(fieldIndex)
} }
default: default:
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
v = reflect.Indirect(v) v = reflect.Indirect(v)
if v.Type() != modelType {
return v.FieldByName(field.Name)
}
for idx, fieldIdx := range field.StructField.Index { for idx, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 { if fieldIdx >= 0 {
v = v.Field(fieldIdx) v = v.Field(fieldIdx)
@ -528,6 +555,9 @@ func (field *Field) setupValuerAndSetter() {
reflectValType := reflectV.Type() reflectValType := reflectV.Type()
if reflectValType.AssignableTo(field.FieldType) { if reflectValType.AssignableTo(field.FieldType) {
if reflectV.Kind() == reflect.Ptr && reflectV.Elem().Kind() == reflect.Ptr {
reflectV = reflect.Indirect(reflectV)
}
field.ReflectValueOf(ctx, value).Set(reflectV) field.ReflectValueOf(ctx, value).Set(reflectV)
return return
} else if reflectValType.ConvertibleTo(field.FieldType) { } else if reflectValType.ConvertibleTo(field.FieldType) {
@ -604,6 +634,22 @@ func (field *Field) setupValuerAndSetter() {
if data != nil && *data != nil { if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(**data) field.ReflectValueOf(ctx, value).SetInt(**data)
} }
case **int:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
}
case **int8:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
}
case **int16:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
}
case **int32:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
}
case int64: case int64:
field.ReflectValueOf(ctx, value).SetInt(data) field.ReflectValueOf(ctx, value).SetInt(data)
case int: case int:
@ -640,7 +686,7 @@ func (field *Field) setupValuerAndSetter() {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
} else { } else {
field.ReflectValueOf(ctx, value).SetInt(data.Unix()) field.ReflectValueOf(ctx, value).SetInt(data.Unix())
} }
@ -649,7 +695,7 @@ func (field *Field) setupValuerAndSetter() {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
} else { } else {
field.ReflectValueOf(ctx, value).SetInt(data.Unix()) field.ReflectValueOf(ctx, value).SetInt(data.Unix())
} }
@ -668,6 +714,22 @@ func (field *Field) setupValuerAndSetter() {
if data != nil && *data != nil { if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(**data) field.ReflectValueOf(ctx, value).SetUint(**data)
} }
case **uint:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
}
case **uint8:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
}
case **uint16:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
}
case **uint32:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
}
case uint64: case uint64:
field.ReflectValueOf(ctx, value).SetUint(data) field.ReflectValueOf(ctx, value).SetUint(data)
case uint: case uint:
@ -698,7 +760,7 @@ func (field *Field) setupValuerAndSetter() {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano())) field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6)) field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixMilli()))
} else { } else {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix())) field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
} }
@ -720,6 +782,10 @@ func (field *Field) setupValuerAndSetter() {
if data != nil && *data != nil { if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetFloat(**data) field.ReflectValueOf(ctx, value).SetFloat(**data)
} }
case **float32:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetFloat(float64(**data))
}
case float64: case float64:
field.ReflectValueOf(ctx, value).SetFloat(data) field.ReflectValueOf(ctx, value).SetFloat(data)
case float32: case float32:
@ -810,7 +876,7 @@ func (field *Field) setupValuerAndSetter() {
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
case **time.Time: case **time.Time:
if data != nil { if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data)) field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
} }
case time.Time: case time.Time:
@ -846,14 +912,12 @@ func (field *Field) setupValuerAndSetter() {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() { if !reflectV.IsValid() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
return
} else if reflectV.Type().AssignableTo(field.FieldType) { } else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV) field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr { } else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() || !reflectV.IsValid() { return field.Set(ctx, value, reflectV.Elem().Interface())
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(ctx, value, reflectV.Elem().Interface())
}
} else { } else {
fieldValue := field.ReflectValueOf(ctx, value) fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() { if fieldValue.IsNil() {
@ -874,14 +938,12 @@ func (field *Field) setupValuerAndSetter() {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() { if !reflectV.IsValid() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
return
} else if reflectV.Type().AssignableTo(field.FieldType) { } else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV) field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr { } else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() || !reflectV.IsValid() { return field.Set(ctx, value, reflectV.Elem().Interface())
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(ctx, value, reflectV.Elem().Interface())
}
} else { } else {
if valuer, ok := v.(driver.Valuer); ok { if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value() v, _ = valuer.Value()
@ -910,6 +972,8 @@ func (field *Field) setupValuerAndSetter() {
sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem() sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem()
} }
serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
serializerType := serializerValue.Type()
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
if s, ok := v.(*serializer); ok { if s, ok := v.(*serializer); ok {
if s.fieldValue != nil { if s.fieldValue != nil {
@ -917,11 +981,12 @@ func (field *Field) setupValuerAndSetter() {
} else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { } else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil {
if sameElemType { if sameElemType {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem())
s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
} else if sameType { } else if sameType {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer)) field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer))
s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
} }
si := reflect.New(serializerType)
si.Elem().Set(serializerValue)
s.Serializer = si.Interface().(SerializerInterface)
} }
} else { } else {
err = oldFieldSetter(ctx, value, v) err = oldFieldSetter(ctx, value, v)
@ -932,41 +997,22 @@ func (field *Field) setupValuerAndSetter() {
} }
func (field *Field) setupNewValuePool() { func (field *Field) setupNewValuePool() {
var fieldValue = reflect.New(field.FieldType).Interface()
if field.Serializer != nil { if field.Serializer != nil {
serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
serializerType := serializerValue.Type()
field.NewValuePool = &sync.Pool{ field.NewValuePool = &sync.Pool{
New: func() interface{} { New: func() interface{} {
si := reflect.New(serializerType)
si.Elem().Set(serializerValue)
return &serializer{ return &serializer{
Field: field, Field: field,
Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), Serializer: si.Interface().(SerializerInterface),
} }
}, },
} }
} else if _, ok := fieldValue.(sql.Scanner); !ok {
field.setupDefaultNewValuePool()
} }
if field.NewValuePool == nil { if field.NewValuePool == nil {
field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) field.NewValuePool = poolInitializer(reflect.PointerTo(field.IndirectFieldType))
}
}
func (field *Field) setupDefaultNewValuePool() {
// set default NewValuePool
switch field.IndirectFieldType.Kind() {
case reflect.String:
field.NewValuePool = stringPool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.NewValuePool = intPool
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.NewValuePool = uintPool
case reflect.Float32, reflect.Float64:
field.NewValuePool = floatPool
case reflect.Bool:
field.NewValuePool = boolPool
default:
if field.IndirectFieldType == TimeReflectType {
field.NewValuePool = timePool
}
} }
} }

View File

@ -1,6 +1,7 @@
package schema package schema
import ( import (
"fmt"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -12,8 +13,8 @@ type Index struct {
Type string // btree, hash, gist, spgist, gin, and brin Type string // btree, hash, gist, spgist, gin, and brin
Where string Where string
Comment string Comment string
Option string // WITH PARSER parser_name Option string // WITH PARSER parser_name
Fields []IndexOption Fields []IndexOption // Note: IndexOption's Field maybe the same
} }
type IndexOption struct { type IndexOption struct {
@ -22,17 +23,28 @@ type IndexOption struct {
Sort string // DESC, ASC Sort string // DESC, ASC
Collate string Collate string
Length int Length int
priority int Priority int
} }
// ParseIndexes parse schema indexes // ParseIndexes parse schema indexes
func (schema *Schema) ParseIndexes() map[string]Index { func (schema *Schema) ParseIndexes() []*Index {
indexes := map[string]Index{} indexesByName := map[string]*Index{}
indexes := []*Index{}
for _, field := range schema.Fields { for _, field := range schema.Fields {
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" { if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
for _, index := range parseFieldIndexes(field) { fieldIndexes, err := parseFieldIndexes(field)
idx := indexes[index.Name] if err != nil {
schema.err = err
break
}
for _, index := range fieldIndexes {
idx := indexesByName[index.Name]
if idx == nil {
idx = &Index{Name: index.Name}
indexesByName[index.Name] = idx
indexes = append(indexes, idx)
}
idx.Name = index.Name idx.Name = index.Name
if idx.Class == "" { if idx.Class == "" {
idx.Class = index.Class idx.Class = index.Class
@ -52,14 +64,16 @@ func (schema *Schema) ParseIndexes() map[string]Index {
idx.Fields = append(idx.Fields, index.Fields...) idx.Fields = append(idx.Fields, index.Fields...)
sort.Slice(idx.Fields, func(i, j int) bool { sort.Slice(idx.Fields, func(i, j int) bool {
return idx.Fields[i].priority < idx.Fields[j].priority return idx.Fields[i].Priority < idx.Fields[j].Priority
}) })
indexes[index.Name] = idx
} }
} }
} }
for _, index := range indexes {
if index.Class == "UNIQUE" && len(index.Fields) == 1 {
index.Fields[0].Field.UniqueIndex = index.Name
}
}
return indexes return indexes
} }
@ -68,12 +82,12 @@ func (schema *Schema) LookIndex(name string) *Index {
indexes := schema.ParseIndexes() indexes := schema.ParseIndexes()
for _, index := range indexes { for _, index := range indexes {
if index.Name == name { if index.Name == name {
return &index return index
} }
for _, field := range index.Fields { for _, field := range index.Fields {
if field.Name == name { if field.Name == name {
return &index return index
} }
} }
} }
@ -82,7 +96,7 @@ func (schema *Schema) LookIndex(name string) *Index {
return nil return nil
} }
func parseFieldIndexes(field *Field) (indexes []Index) { func parseFieldIndexes(field *Field) (indexes []Index, err error) {
for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
if value != "" { if value != "" {
v := strings.Split(value, ":") v := strings.Split(value, ":")
@ -91,7 +105,7 @@ func parseFieldIndexes(field *Field) (indexes []Index) {
var ( var (
name string name string
tag = strings.Join(v[1:], ":") tag = strings.Join(v[1:], ":")
idx = strings.Index(tag, ",") idx = strings.IndexByte(tag, ',')
tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",") tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",")
settings = ParseTagSetting(tagSetting, ",") settings = ParseTagSetting(tagSetting, ",")
length, _ = strconv.Atoi(settings["LENGTH"]) length, _ = strconv.Atoi(settings["LENGTH"])
@ -101,12 +115,22 @@ func parseFieldIndexes(field *Field) (indexes []Index) {
idx = len(tag) idx = len(tag)
} }
if idx != -1 { name = tag[0:idx]
name = tag[0:idx]
}
if name == "" { if name == "" {
name = field.Schema.namer.IndexName(field.Schema.Table, field.Name) subName := field.Name
const key = "COMPOSITE"
if composite, found := settings[key]; found {
if len(composite) == 0 || composite == key {
err = fmt.Errorf(
"the composite tag of %s.%s cannot be empty",
field.Schema.Name,
field.Name)
return
}
subName = composite
}
name = field.Schema.namer.IndexName(
field.Schema.Table, subName)
} }
if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" { if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" {
@ -131,12 +155,13 @@ func parseFieldIndexes(field *Field) (indexes []Index) {
Sort: settings["SORT"], Sort: settings["SORT"],
Collate: settings["COLLATE"], Collate: settings["COLLATE"],
Length: length, Length: length,
priority: priority, Priority: priority,
}}, }},
}) })
} }
} }
} }
err = nil
return return
} }

View File

@ -1,11 +1,11 @@
package schema_test package schema_test
import ( import (
"reflect"
"sync" "sync"
"testing" "testing"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
) )
type UserIndex struct { type UserIndex struct {
@ -19,6 +19,40 @@ type UserIndex struct {
OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` OID int64 `gorm:"index:idx_id;index:idx_oid,unique"`
MemberNumber string `gorm:"index:idx_id,priority:1"` MemberNumber string `gorm:"index:idx_id,priority:1"`
Name7 string `gorm:"index:type"` Name7 string `gorm:"index:type"`
Name8 string `gorm:"index:,length:10;index:,collate:utf8"`
CompName1 string `gorm:"index:,unique,composite:idx_compname_1,option:NULLS NOT DISTINCT;not null"`
CompName2 string `gorm:"index:,composite:idx_compname_1"`
// Composite Index: Flattened structure.
Data0A string `gorm:"index:,composite:comp_id0"`
Data0B string `gorm:"index:,composite:comp_id0"`
// Composite Index: Nested structure.
Data1A string `gorm:"index:,composite:comp_id1"`
CompIdxLevel1C
// Composite Index: Unique and priority.
Data2A string `gorm:"index:,unique,composite:comp_id2,priority:2"`
CompIdxLevel2C
}
type CompIdxLevel1C struct {
CompIdxLevel1B
Data1C string `gorm:"index:,composite:comp_id1"`
}
type CompIdxLevel1B struct {
Data1B string `gorm:"index:,composite:comp_id1"`
}
type CompIdxLevel2C struct {
CompIdxLevel2B
Data2C string `gorm:"index:,unique,composite:comp_id2,priority:1"`
}
type CompIdxLevel2B struct {
Data2B string `gorm:"index:,unique,composite:comp_id2,priority:3"`
} }
func TestParseIndex(t *testing.T) { func TestParseIndex(t *testing.T) {
@ -27,17 +61,17 @@ func TestParseIndex(t *testing.T) {
t.Fatalf("failed to parse user index, got error %v", err) t.Fatalf("failed to parse user index, got error %v", err)
} }
results := map[string]schema.Index{ results := []*schema.Index{
"idx_user_indices_name": { {
Name: "idx_user_indices_name", Name: "idx_user_indices_name",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name"}}},
}, },
"idx_name": { {
Name: "idx_name", Name: "idx_name",
Class: "UNIQUE", Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: "idx_name"}}},
}, },
"idx_user_indices_name3": { {
Name: "idx_user_indices_name3", Name: "idx_user_indices_name3",
Type: "btree", Type: "btree",
Where: "name3 != 'jinzhu'", Where: "name3 != 'jinzhu'",
@ -48,19 +82,19 @@ func TestParseIndex(t *testing.T) {
Length: 10, Length: 10,
}}, }},
}, },
"idx_user_indices_name4": { {
Name: "idx_user_indices_name4", Name: "idx_user_indices_name4",
Class: "UNIQUE", Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: "idx_user_indices_name4"}}},
}, },
"idx_user_indices_name5": { {
Name: "idx_user_indices_name5", Name: "idx_user_indices_name5",
Class: "FULLTEXT", Class: "FULLTEXT",
Comment: "hello , world", Comment: "hello , world",
Where: "age > 10", Where: "age > 10",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name5"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name5"}}},
}, },
"profile": { {
Name: "profile", Name: "profile",
Comment: "hello , world", Comment: "hello , world",
Where: "age > 10", Where: "age > 10",
@ -70,53 +104,172 @@ func TestParseIndex(t *testing.T) {
Expression: "ABS(age)", Expression: "ABS(age)",
}}, }},
}, },
"idx_id": { {
Name: "idx_id", Name: "idx_id",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
}, },
"idx_oid": { {
Name: "idx_oid", Name: "idx_oid",
Class: "UNIQUE", Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
}, },
"type": { {
Name: "type", Name: "type",
Type: "", Type: "",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}},
}, },
{
Name: "idx_user_indices_name8",
Type: "",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "Name8"}, Length: 10},
// Note: Duplicate Columns
{Field: &schema.Field{Name: "Name8"}, Collate: "utf8"},
},
},
{
Class: "UNIQUE",
Name: "idx_user_indices_idx_compname_1",
Option: "NULLS NOT DISTINCT",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "CompName1", NotNull: true}},
{Field: &schema.Field{Name: "CompName2"}},
},
},
{
Name: "idx_user_indices_comp_id0",
Type: "",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data0A"},
}, {
Field: &schema.Field{Name: "Data0B"},
}},
},
{
Name: "idx_user_indices_comp_id1",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data1A"},
}, {
Field: &schema.Field{Name: "Data1B"},
}, {
Field: &schema.Field{Name: "Data1C"},
}},
},
{
Name: "idx_user_indices_comp_id2",
Class: "UNIQUE",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data2C"},
}, {
Field: &schema.Field{Name: "Data2A"},
}, {
Field: &schema.Field{Name: "Data2B"},
}},
},
} }
indices := user.ParseIndexes() CheckIndices(t, results, user.ParseIndexes())
}
for k, result := range results { func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) {
v, ok := indices[k] type IndexTest struct {
if !ok { FieldA string `gorm:"unique;index"` // unique and index
t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices) FieldB string `gorm:"unique"` // unique
}
for _, name := range []string{"Name", "Class", "Type", "Where", "Comment", "Option"} { FieldC string `gorm:"index:,unique"` // uniqueIndex
if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() { FieldD string `gorm:"uniqueIndex;index"` // uniqueIndex and index
t.Errorf(
"index %v %v should equal, expects %v, got %v", FieldE1 string `gorm:"uniqueIndex:uniq_field_e1_e2"` // mul uniqueIndex
k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(), FieldE2 string `gorm:"uniqueIndex:uniq_field_e1_e2"`
)
FieldF1 string `gorm:"uniqueIndex:uniq_field_f1_f2;index"` // mul uniqueIndex and index
FieldF2 string `gorm:"uniqueIndex:uniq_field_f1_f2;"`
FieldG string `gorm:"unique;uniqueIndex"` // unique and uniqueIndex
FieldH1 string `gorm:"unique;uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
FieldH2 string `gorm:"uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
}
indexSchema, err := schema.Parse(&IndexTest{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user index, got error %v", err)
}
indices := indexSchema.ParseIndexes()
expectedIndices := []*schema.Index{
{
Name: "idx_index_tests_field_a",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldA", Unique: true}}},
},
{
Name: "idx_index_tests_field_c",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: "idx_index_tests_field_c"}}},
},
{
Name: "idx_index_tests_field_d",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldD"}},
// Note: Duplicate Columns
{Field: &schema.Field{Name: "FieldD"}},
},
},
{
Name: "uniq_field_e1_e2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldE1"}},
{Field: &schema.Field{Name: "FieldE2"}},
},
},
{
Name: "uniq_field_f1_f2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldF1"}},
{Field: &schema.Field{Name: "FieldF2"}},
},
},
{
Name: "idx_index_tests_field_f1",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldF1"}}},
},
{
Name: "idx_index_tests_field_g",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: "idx_index_tests_field_g"}}},
},
{
Name: "uniq_field_h1_h2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldH1", Unique: true}},
{Field: &schema.Field{Name: "FieldH2"}},
},
},
}
CheckIndices(t, expectedIndices, indices)
}
func CheckIndices(t *testing.T, expected, actual []*schema.Index) {
if len(expected) != len(actual) {
t.Errorf("expected %d indices, but got %d", len(expected), len(actual))
return
}
for i, ei := range expected {
t.Run(ei.Name, func(t *testing.T) {
ai := actual[i]
tests.AssertObjEqual(t, ai, ei, "Name", "Class", "Type", "Where", "Comment", "Option")
if len(ei.Fields) != len(ai.Fields) {
t.Errorf("expected index %q field length is %d but actual %d", ei.Name, len(ei.Fields), len(ai.Fields))
return
} }
} for i, ef := range ei.Fields {
af := ai.Fields[i]
for idx, ef := range result.Fields { tests.AssertObjEqual(t, af, ef, "Name", "Unique", "UniqueIndex", "Expression", "Sort", "Collate", "Length", "NotNull")
rf := v.Fields[idx]
if rf.Field.Name != ef.Field.Name {
t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name)
} }
})
for _, name := range []string{"Expression", "Sort", "Collate", "Length"} {
if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() {
t.Errorf(
"index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name,
reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(),
)
}
}
}
} }
} }

View File

@ -4,6 +4,12 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
// ConstraintInterface database constraint interface
type ConstraintInterface interface {
GetName() string
Build() (sql string, vars []interface{})
}
// GormDataTypeInterface gorm data type interface // GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface { type GormDataTypeInterface interface {
GormDataType() string GormDataType() string

View File

@ -8,6 +8,8 @@ import (
"unicode/utf8" "unicode/utf8"
"github.com/jinzhu/inflection" "github.com/jinzhu/inflection"
"golang.org/x/text/cases"
"golang.org/x/text/language"
) )
// Namer namer interface // Namer namer interface
@ -19,6 +21,7 @@ type Namer interface {
RelationshipFKName(Relationship) string RelationshipFKName(Relationship) string
CheckerName(table, column string) string CheckerName(table, column string) string
IndexName(table, column string) string IndexName(table, column string) string
UniqueName(table, column string) string
} }
// Replacer replacer interface like strings.Replacer // Replacer replacer interface like strings.Replacer
@ -26,12 +29,15 @@ type Replacer interface {
Replace(name string) string Replace(name string) string
} }
var _ Namer = (*NamingStrategy)(nil)
// NamingStrategy tables, columns naming strategy // NamingStrategy tables, columns naming strategy
type NamingStrategy struct { type NamingStrategy struct {
TablePrefix string TablePrefix string
SingularTable bool SingularTable bool
NameReplacer Replacer NameReplacer Replacer
NoLowerCase bool NoLowerCase bool
IdentifierMaxLength int
} }
// TableName convert string to table name // TableName convert string to table name
@ -84,17 +90,26 @@ func (ns NamingStrategy) IndexName(table, column string) string {
return ns.formatName("idx", table, ns.toDBName(column)) return ns.formatName("idx", table, ns.toDBName(column))
} }
// UniqueName generate unique constraint name
func (ns NamingStrategy) UniqueName(table, column string) string {
return ns.formatName("uni", table, ns.toDBName(column))
}
func (ns NamingStrategy) formatName(prefix, table, name string) string { func (ns NamingStrategy) formatName(prefix, table, name string) string {
formattedName := strings.ReplaceAll(strings.Join([]string{ formattedName := strings.ReplaceAll(strings.Join([]string{
prefix, table, name, prefix, table, name,
}, "_"), ".", "_") }, "_"), ".", "_")
if utf8.RuneCountInString(formattedName) > 64 { if ns.IdentifierMaxLength == 0 {
ns.IdentifierMaxLength = 64
}
if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength {
h := sha1.New() h := sha1.New()
h.Write([]byte(formattedName)) h.Write([]byte(formattedName))
bs := h.Sum(nil) bs := h.Sum(nil)
formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8] formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8]
} }
return formattedName return formattedName
} }
@ -108,7 +123,7 @@ var (
func init() { func init() {
commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms)) commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
for _, initialism := range commonInitialisms { for _, initialism := range commonInitialisms {
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, cases.Title(language.Und).String(initialism))
} }
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
} }
@ -173,9 +188,9 @@ func (ns NamingStrategy) toDBName(name string) string {
} }
func (ns NamingStrategy) toSchemaName(name string) string { func (ns NamingStrategy) toSchemaName(name string) string {
result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "") result := strings.ReplaceAll(cases.Title(language.Und, cases.NoLower).String(strings.ReplaceAll(name, "_", " ")), " ", "")
for _, initialism := range commonInitialisms { for _, initialism := range commonInitialisms {
result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") result = regexp.MustCompile(cases.Title(language.Und, cases.NoLower).String(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
} }
return result return result
} }

View File

@ -189,8 +189,17 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) {
} }
} }
func TestFormatNameWithStringLongerThan63Characters(t *testing.T) {
ns := NamingStrategy{IdentifierMaxLength: 63}
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVer180f2c67" {
t.Errorf("invalid formatted name generated, got %v", formattedName)
}
}
func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { func TestFormatNameWithStringLongerThan64Characters(t *testing.T) {
ns := NamingStrategy{} ns := NamingStrategy{IdentifierMaxLength: 64}
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" { if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" {

View File

@ -3,54 +3,11 @@ package schema
import ( import (
"reflect" "reflect"
"sync" "sync"
"time"
) )
// sync pools // sync pools
var ( var (
normalPool sync.Map normalPool sync.Map
stringPool = &sync.Pool{
New: func() interface{} {
var v string
ptrV := &v
return &ptrV
},
}
intPool = &sync.Pool{
New: func() interface{} {
var v int64
ptrV := &v
return &ptrV
},
}
uintPool = &sync.Pool{
New: func() interface{} {
var v uint64
ptrV := &v
return &ptrV
},
}
floatPool = &sync.Pool{
New: func() interface{} {
var v float64
ptrV := &v
return &ptrV
},
}
boolPool = &sync.Pool{
New: func() interface{} {
var v bool
ptrV := &v
return &ptrV
},
}
timePool = &sync.Pool{
New: func() interface{} {
var v time.Time
ptrV := &v
return &ptrV
},
}
poolInitializer = func(reflectType reflect.Type) FieldNewValuePool { poolInitializer = func(reflectType reflect.Type) FieldNewValuePool {
v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{ v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{
New: func() interface{} { New: func() interface{} {

View File

@ -5,8 +5,12 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"sync"
"github.com/jinzhu/inflection" "github.com/jinzhu/inflection"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
@ -27,6 +31,10 @@ type Relationships struct {
HasMany []*Relationship HasMany []*Relationship
Many2Many []*Relationship Many2Many []*Relationship
Relations map[string]*Relationship Relations map[string]*Relationship
EmbeddedRelations map[string]*Relationships
Mux sync.RWMutex
} }
type Relationship struct { type Relationship struct {
@ -70,12 +78,12 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
cacheStore := schema.cacheStore cacheStore := schema.cacheStore
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
schema.err = err schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err)
return nil return nil
} }
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { if hasPolymorphicRelation(field.TagSettings) {
schema.buildPolymorphicRelation(relation, field, polymorphic) schema.buildPolymorphicRelation(relation, field)
} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
schema.buildMany2ManyRelation(relation, field, many2many) schema.buildMany2ManyRelation(relation, field, many2many)
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" { } else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
@ -87,14 +95,16 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
case reflect.Slice: case reflect.Slice:
schema.guessRelation(relation, field, guessHas) schema.guessRelation(relation, field, guessHas)
default: default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name) schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema,
field.Name)
} }
} }
if relation.Type == has { if relation.Type == has {
// don't add relations to embedded schema, which might be shared
if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil { if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil {
relation.FieldSchema.Relationships.Mux.Lock()
relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation
relation.FieldSchema.Relationships.Mux.Unlock()
} }
switch field.IndirectFieldType.Kind() { switch field.IndirectFieldType.Kind() {
@ -106,7 +116,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
} }
if schema.err == nil { if schema.err == nil {
schema.Relationships.Relations[relation.Name] = relation schema.setRelation(relation)
switch relation.Type { switch relation.Type {
case HasOne: case HasOne:
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
@ -122,34 +132,100 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
return relation return relation
} }
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` // hasPolymorphicRelation check if has polymorphic relation
// type User struct { // 1. `POLYMORPHIC` tag
// Toys []Toy `gorm:"polymorphic:Owner;"` // 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag
// } func hasPolymorphicRelation(tagSettings map[string]string) bool {
// type Pet struct { if _, ok := tagSettings["POLYMORPHIC"]; ok {
// Toy Toy `gorm:"polymorphic:Owner;"` return true
// }
// type Toy struct {
// OwnerID int
// OwnerType string
// }
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
relation.Polymorphic = &Polymorphic{
Value: schema.Table,
PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"],
} }
_, hasType := tagSettings["POLYMORPHICTYPE"]
_, hasId := tagSettings["POLYMORPHICID"]
return hasType && hasId
}
func (schema *Schema) setRelation(relation *Relationship) {
// set non-embedded relation
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
if len(rel.Field.BindNames) > 1 {
schema.Relationships.Relations[relation.Name] = relation
}
} else {
schema.Relationships.Relations[relation.Name] = relation
}
// set embedded relation
if len(relation.Field.EmbeddedBindNames) <= 1 {
return
}
relationships := &schema.Relationships
for i, name := range relation.Field.EmbeddedBindNames {
if i < len(relation.Field.EmbeddedBindNames)-1 {
if relationships.EmbeddedRelations == nil {
relationships.EmbeddedRelations = map[string]*Relationships{}
}
if r := relationships.EmbeddedRelations[name]; r == nil {
relationships.EmbeddedRelations[name] = &Relationships{}
}
relationships = relationships.EmbeddedRelations[name]
} else {
if relationships.Relations == nil {
relationships.Relations = map[string]*Relationship{}
}
relationships.Relations[relation.Name] = relation
}
}
}
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
//
// type User struct {
// Toys []Toy `gorm:"polymorphic:Owner;"`
// }
// type Pet struct {
// Toy Toy `gorm:"polymorphic:Owner;"`
// }
// type Toy struct {
// OwnerID int
// OwnerType string
// }
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
polymorphic := field.TagSettings["POLYMORPHIC"]
relation.Polymorphic = &Polymorphic{
Value: schema.Table,
}
var (
typeName = polymorphic + "Type"
typeId = polymorphic + "ID"
)
if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
typeName = strings.TrimSpace(value)
}
if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
typeId = strings.TrimSpace(value)
}
relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]
if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok { if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
relation.Polymorphic.Value = strings.TrimSpace(value) relation.Polymorphic.Value = strings.TrimSpace(value)
} }
if relation.Polymorphic.PolymorphicType == nil { if relation.Polymorphic.PolymorphicType == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type") schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
relation.FieldSchema, schema, field.Name, polymorphic+"Type")
} }
if relation.Polymorphic.PolymorphicID == nil { if relation.Polymorphic.PolymorphicID == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID") schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
relation.FieldSchema, schema, field.Name, polymorphic+"ID")
} }
if schema.err == nil { if schema.err == nil {
@ -161,10 +237,17 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
primaryKeyField := schema.PrioritizedPrimaryField primaryKeyField := schema.PrioritizedPrimaryField
if len(relation.foreignKeys) > 0 { if len(relation.foreignKeys) > 0 {
if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name) schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys,
schema, field.Name)
} }
} }
if primaryKeyField == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
relation.FieldSchema, schema, field.Name)
return
}
// use same data type for foreign keys // use same data type for foreign keys
if copyableDataType(primaryKeyField.DataType) { if copyableDataType(primaryKeyField.DataType) {
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
@ -191,7 +274,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
err error err error
joinTableFields []reflect.StructField joinTableFields []reflect.StructField
fieldsMap = map[string]*Field{} fieldsMap = map[string]*Field{}
ownFieldsMap = map[string]bool{} // fix self join many2many ownFieldsMap = map[string]*Field{} // fix self join many2many
referFieldsMap = map[string]*Field{}
joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) joinReferences = toColumns(field.TagSettings["JOINREFERENCES"])
) )
@ -224,26 +308,24 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
} }
for idx, ownField := range ownForeignFields { for idx, ownField := range ownForeignFields {
joinFieldName := strings.Title(schema.Name) + ownField.Name joinFieldName := cases.Title(language.Und, cases.NoLower).String(schema.Name) + ownField.Name
if len(joinForeignKeys) > idx { if len(joinForeignKeys) > idx {
joinFieldName = strings.Title(joinForeignKeys[idx]) joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinForeignKeys[idx])
} }
ownFieldsMap[joinFieldName] = true ownFieldsMap[joinFieldName] = ownField
fieldsMap[joinFieldName] = ownField fieldsMap[joinFieldName] = ownField
joinTableFields = append(joinTableFields, reflect.StructField{ joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName, Name: joinFieldName,
PkgPath: ownField.StructField.PkgPath, PkgPath: ownField.StructField.PkgPath,
Type: ownField.StructField.Type, Type: ownField.StructField.Type,
Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
}) })
} }
for idx, relField := range refForeignFields { for idx, relField := range refForeignFields {
joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name joinFieldName := cases.Title(language.Und, cases.NoLower).String(relation.FieldSchema.Name) + relField.Name
if len(joinReferences) > idx {
joinFieldName = strings.Title(joinReferences[idx])
}
if _, ok := ownFieldsMap[joinFieldName]; ok { if _, ok := ownFieldsMap[joinFieldName]; ok {
if field.Name != relation.FieldSchema.Name { if field.Name != relation.FieldSchema.Name {
@ -253,22 +335,32 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
} }
} }
fieldsMap[joinFieldName] = relField if len(joinReferences) > idx {
joinTableFields = append(joinTableFields, reflect.StructField{ joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinReferences[idx])
Name: joinFieldName, }
PkgPath: relField.StructField.PkgPath,
Type: relField.StructField.Type, referFieldsMap[joinFieldName] = relField
Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
}) if _, ok := fieldsMap[joinFieldName]; !ok {
fieldsMap[joinFieldName] = relField
joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName,
PkgPath: relField.StructField.PkgPath,
Type: relField.StructField.Type,
Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
})
}
} }
joinTableFields = append(joinTableFields, reflect.StructField{ joinTableFields = append(joinTableFields, reflect.StructField{
Name: strings.Title(schema.Name) + field.Name, Name: cases.Title(language.Und, cases.NoLower).String(schema.Name) + field.Name,
Type: schema.ModelType, Type: schema.ModelType,
Tag: `gorm:"-"`, Tag: `gorm:"-"`,
}) })
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
schema.namer); err != nil {
schema.err = err schema.err = err
} }
relation.JoinTable.Name = many2many relation.JoinTable.Name = many2many
@ -315,31 +407,37 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
f.Size = fieldsMap[f.Name].Size f.Size = fieldsMap[f.Name].Size
} }
relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]
if ownPrimaryField { if of, ok := ownFieldsMap[f.Name]; ok {
joinRel := relation.JoinTable.Relationships.Relations[relName] joinRel := relation.JoinTable.Relationships.Relations[relName]
joinRel.Field = relation.Field joinRel.Field = relation.Field
joinRel.References = append(joinRel.References, &Reference{ joinRel.References = append(joinRel.References, &Reference{
PrimaryKey: fieldsMap[f.Name], PrimaryKey: of,
ForeignKey: f, ForeignKey: f,
}) })
} else {
relation.References = append(relation.References, &Reference{
PrimaryKey: of,
ForeignKey: f,
OwnPrimaryKey: true,
})
}
if rf, ok := referFieldsMap[f.Name]; ok {
joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
if joinRefRel.Field == nil { if joinRefRel.Field == nil {
joinRefRel.Field = relation.Field joinRefRel.Field = relation.Field
} }
joinRefRel.References = append(joinRefRel.References, &Reference{ joinRefRel.References = append(joinRefRel.References, &Reference{
PrimaryKey: fieldsMap[f.Name], PrimaryKey: rf,
ForeignKey: f,
})
relation.References = append(relation.References, &Reference{
PrimaryKey: rf,
ForeignKey: f, ForeignKey: f,
}) })
} }
relation.References = append(relation.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
ForeignKey: f,
OwnPrimaryKey: ownPrimaryField,
})
} }
} }
} }
@ -381,7 +479,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
schema.guessRelation(relation, field, guessEmbeddedHas) schema.guessRelation(relation, field, guessEmbeddedHas)
// case guessEmbeddedHas: // case guessEmbeddedHas:
default: default:
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name) schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
schema, field.Name)
} }
} }
@ -389,34 +488,31 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
case guessBelongs: case guessBelongs:
primarySchema, foreignSchema = relation.FieldSchema, schema primarySchema, foreignSchema = relation.FieldSchema, schema
case guessEmbeddedBelongs: case guessEmbeddedBelongs:
if field.OwnerSchema != nil { if field.OwnerSchema == nil {
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
} else {
reguessOrErr() reguessOrErr()
return return
} }
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
case guessHas: case guessHas:
case guessEmbeddedHas: case guessEmbeddedHas:
if field.OwnerSchema != nil { if field.OwnerSchema == nil {
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} else {
reguessOrErr() reguessOrErr()
return return
} }
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} }
if len(relation.foreignKeys) > 0 { if len(relation.foreignKeys) > 0 {
for _, foreignKey := range relation.foreignKeys { for _, foreignKey := range relation.foreignKeys {
if f := foreignSchema.LookUpField(foreignKey); f != nil { f := foreignSchema.LookUpField(foreignKey)
foreignFields = append(foreignFields, f) if f == nil {
} else {
reguessOrErr() reguessOrErr()
return return
} }
foreignFields = append(foreignFields, f)
} }
} else { } else {
var primaryFields []*Field primarySchemaName := primarySchema.Name
var primarySchemaName = primarySchema.Name
if primarySchemaName == "" { if primarySchemaName == "" {
primarySchemaName = relation.FieldSchema.Name primarySchemaName = relation.FieldSchema.Name
} }
@ -431,6 +527,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
primaryFields = primarySchema.PrimaryFields primaryFields = primarySchema.PrimaryFields
} }
primaryFieldLoop:
for _, primaryField := range primaryFields { for _, primaryField := range primaryFields {
lookUpName := primarySchemaName + primaryField.Name lookUpName := primarySchemaName + primaryField.Name
if gl == guessBelongs { if gl == guessBelongs {
@ -439,23 +536,33 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
lookUpNames := []string{lookUpName} lookUpNames := []string{lookUpName}
if len(primaryFields) == 1 { if len(primaryFields) == 1 {
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
} }
for _, name := range lookUpNames {
if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField)
continue primaryFieldLoop
}
}
for _, name := range lookUpNames { for _, name := range lookUpNames {
if f := foreignSchema.LookUpField(name); f != nil { if f := foreignSchema.LookUpField(name); f != nil {
foreignFields = append(foreignFields, f) foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField) primaryFields = append(primaryFields, primaryField)
break continue primaryFieldLoop
} }
} }
} }
} }
if len(foreignFields) == 0 { switch {
case len(foreignFields) == 0:
reguessOrErr() reguessOrErr()
return return
} else if len(relation.primaryKeys) > 0 { case len(relation.primaryKeys) > 0:
for idx, primaryKey := range relation.primaryKeys { for idx, primaryKey := range relation.primaryKeys {
if f := primarySchema.LookUpField(primaryKey); f != nil { if f := primarySchema.LookUpField(primaryKey); f != nil {
if len(primaryFields) < idx+1 { if len(primaryFields) < idx+1 {
@ -469,7 +576,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
return return
} }
} }
} else if len(primaryFields) == 0 { case len(primaryFields) == 0:
if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil { if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
} else if len(primarySchema.PrimaryFields) == len(foreignFields) { } else if len(primarySchema.PrimaryFields) == len(foreignFields) {
@ -505,6 +612,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
} }
} }
// Constraint is ForeignKey Constraint
type Constraint struct { type Constraint struct {
Name string Name string
Field *Field Field *Field
@ -516,6 +624,31 @@ type Constraint struct {
OnUpdate string OnUpdate string
} }
func (constraint *Constraint) GetName() string { return constraint.Name }
func (constraint *Constraint) Build() (sql string, vars []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys))
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
references := make([]interface{}, 0, len(constraint.References))
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}
func (rel *Relationship) ParseConstraint() *Constraint { func (rel *Relationship) ParseConstraint() *Constraint {
str := rel.Field.TagSettings["CONSTRAINT"] str := rel.Field.TagSettings["CONSTRAINT"]
if str == "-" { if str == "-" {
@ -530,6 +663,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey && if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey &&
rel.References[idx].PrimaryValue == ref.PrimaryValue) { rel.References[idx].PrimaryValue == ref.PrimaryValue) {
matched = false matched = false
break
} }
} }
@ -542,7 +676,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
var ( var (
name string name string
idx = strings.Index(str, ",") idx = strings.IndexByte(str, ',')
settings = ParseTagSetting(str, ",") settings = ParseTagSetting(str, ",")
) )
@ -629,8 +763,9 @@ func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue ref
} }
func copyableDataType(str DataType) bool { func copyableDataType(str DataType) bool {
lowerStr := strings.ToLower(string(str))
for _, s := range []string{"auto_increment", "primary key"} { for _, s := range []string{"auto_increment", "primary key"} {
if strings.Contains(strings.ToLower(string(str)), s) { if strings.Contains(lowerStr, s) {
return false return false
} }
} }

View File

@ -10,7 +10,7 @@ import (
func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) { func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) {
if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil { if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil {
t.Errorf("Failed to parse schema") t.Errorf("Failed to parse schema, got error %v", err)
} else { } else {
for _, rel := range relations { for _, rel := range relations {
checkSchemaRelation(t, s, rel) checkSchemaRelation(t, s, rel)
@ -121,6 +121,29 @@ func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
}) })
} }
func TestBelongsToWithMixin(t *testing.T) {
type Profile struct {
gorm.Model
Refer string
Name string
}
type ProfileMixin struct {
Profile Profile `gorm:"References:Refer"`
ProfileRefer int
}
type User struct {
gorm.Model
ProfileMixin
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}},
})
}
func TestHasOneOverrideForeignKey(t *testing.T) { func TestHasOneOverrideForeignKey(t *testing.T) {
type Profile struct { type Profile struct {
gorm.Model gorm.Model
@ -305,6 +328,33 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) {
}) })
} }
func TestMany2ManySharedForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
Name string
Kind string
ProfileRefer uint
}
type User struct {
gorm.Model
Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"`
Kind string
Refer uint
}
checkStructRelation(t, &User{}, Relation{
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
References: []Reference{
{"Refer", "User", "UserRefer", "user_profiles", "", true},
{"Kind", "User", "Kind", "user_profiles", "", true},
{"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false},
{"Kind", "Profile", "Kind", "user_profiles", "", false},
},
})
}
func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
type Profile struct { type Profile struct {
gorm.Model gorm.Model
@ -491,6 +541,320 @@ func TestEmbeddedRelation(t *testing.T) {
} }
} }
func TestEmbeddedHas(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type User struct {
ID int
Cat struct {
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
Toys []Toy `gorm:"polymorphic:Owner;"`
} `gorm:"embedded;embeddedPrefix:cat_"`
Dog struct {
ID int
Name string
UserID int
Toy Toy `gorm:"polymorphic:Owner;"`
Toys []Toy `gorm:"polymorphic:Owner;"`
}
Toys []Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
}
func TestPolymorphic(t *testing.T) {
t.Run("has one", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has one with custom polymorphic type and id", func(t *testing.T) {
type Toy struct {
ID int
Name string
RefId int
Type string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has one with only polymorphic type", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
Type string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has many", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type Cat struct {
ID int
Name string
Toys []Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has many with custom polymorphic type and id", func(t *testing.T) {
type Toy struct {
ID int
Name string
RefId int
Type string
}
type Cat struct {
ID int
Name string
Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
}
func TestEmbeddedBelongsTo(t *testing.T) {
type Country struct {
ID int `gorm:"primaryKey"`
Name string
}
type Address struct {
CountryID int
Country Country
}
type NestedAddress struct {
Address
}
type CountryMixin struct {
CountryID int
Country Country
}
type Org struct {
ID int
PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"`
VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"`
AddressID int
Address struct {
ID int
Address
}
NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
CountryMixin
}
s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Errorf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"PostalAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
"VisitingAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
"NestedAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
})
}
func TestVariableRelation(t *testing.T) { func TestVariableRelation(t *testing.T) {
var result struct { var result struct {
User User
@ -615,7 +979,7 @@ func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) {
s, err := schema.Parse( s, err := schema.Parse(
&Book{}, &Book{},
&sync.Map{}, &sync.Map{},
schema.NamingStrategy{}, schema.NamingStrategy{IdentifierMaxLength: 64},
) )
if err != nil { if err != nil {
t.Fatalf("Failed to parse schema") t.Fatalf("Failed to parse schema")

View File

@ -5,13 +5,29 @@ import (
"errors" "errors"
"fmt" "fmt"
"go/ast" "go/ast"
"path"
"reflect" "reflect"
"strings"
"sync" "sync"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
) )
type callbackType string
const (
callbackTypeBeforeCreate callbackType = "BeforeCreate"
callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
callbackTypeAfterCreate callbackType = "AfterCreate"
callbackTypeAfterUpdate callbackType = "AfterUpdate"
callbackTypeBeforeSave callbackType = "BeforeSave"
callbackTypeAfterSave callbackType = "AfterSave"
callbackTypeBeforeDelete callbackType = "BeforeDelete"
callbackTypeAfterDelete callbackType = "AfterDelete"
callbackTypeAfterFind callbackType = "AfterFind"
)
// ErrUnsupportedDataType unsupported data type // ErrUnsupportedDataType unsupported data type
var ErrUnsupportedDataType = errors.New("unsupported data type") var ErrUnsupportedDataType = errors.New("unsupported data type")
@ -25,6 +41,7 @@ type Schema struct {
PrimaryFieldDBNames []string PrimaryFieldDBNames []string
Fields []*Field Fields []*Field
FieldsByName map[string]*Field FieldsByName map[string]*Field
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
FieldsByDBName map[string]*Field FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
Relationships Relationships Relationships Relationships
@ -51,9 +68,10 @@ func (schema Schema) String() string {
} }
func (schema Schema) MakeSlice() reflect.Value { func (schema Schema) MakeSlice() reflect.Value {
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20) slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20)
results := reflect.New(slice.Type()) results := reflect.New(slice.Type())
results.Elem().Set(slice) results.Elem().Set(slice)
return results return results
} }
@ -67,10 +85,35 @@ func (schema Schema) LookUpField(name string) *Field {
return nil return nil
} }
// LookUpFieldByBindName looks for the closest field in the embedded struct.
//
// type Struct struct {
// Embedded struct {
// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
// }
// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
// }
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
if len(bindNames) == 0 {
return nil
}
for i := len(bindNames) - 1; i >= 0; i-- {
find := strings.Join(bindNames[:i], ".") + "." + name
if field, ok := schema.FieldsByBindName[find]; ok {
return field
}
}
return nil
}
type Tabler interface { type Tabler interface {
TableName() string TableName() string
} }
type TablerWithNamer interface {
TableName(Namer) string
}
// Parse get data type from dialector // Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, "") return ParseWithSpecialTableName(dest, cacheStore, namer, "")
@ -112,7 +155,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
schemaCacheKey = modelType schemaCacheKey = modelType
} }
// Load exist schmema cache, return if exists // Load exist schema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok { if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema) s := v.(*Schema)
// Wait for the initialization of other goroutines to complete // Wait for the initialization of other goroutines to complete
@ -125,6 +168,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
if tabler, ok := modelValue.Interface().(Tabler); ok { if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName() tableName = tabler.TableName()
} }
if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
tableName = tabler.TableName(namer)
}
if en, ok := namer.(embeddedNamer); ok { if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table tableName = en.Table
} }
@ -133,20 +179,21 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
} }
schema := &Schema{ schema := &Schema{
Name: modelType.Name(), Name: modelType.Name(),
ModelType: modelType, ModelType: modelType,
Table: tableName, Table: tableName,
FieldsByName: map[string]*Field{}, FieldsByName: map[string]*Field{},
FieldsByDBName: map[string]*Field{}, FieldsByBindName: map[string]*Field{},
Relationships: Relationships{Relations: map[string]*Relationship{}}, FieldsByDBName: map[string]*Field{},
cacheStore: cacheStore, Relationships: Relationships{Relations: map[string]*Relationship{}},
namer: namer, cacheStore: cacheStore,
initialized: make(chan struct{}), namer: namer,
initialized: make(chan struct{}),
} }
// When the schema initialization is completed, the channel will be closed // When the schema initialization is completed, the channel will be closed
defer close(schema.initialized) defer close(schema.initialized)
// Load exist schmema cache, return if exists // Load exist schema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok { if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema) s := v.(*Schema)
// Wait for the initialization of other goroutines to complete // Wait for the initialization of other goroutines to complete
@ -169,6 +216,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
field.DBName = namer.ColumnName(schema.Table, field.Name) field.DBName = namer.ColumnName(schema.Table, field.Name)
} }
bindName := field.BindName()
if field.DBName != "" { if field.DBName != "" {
// nonexistence or shortest path or first appear prioritized if has permission // nonexistence or shortest path or first appear prioritized if has permission
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
@ -177,6 +225,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
} }
schema.FieldsByDBName[field.DBName] = field schema.FieldsByDBName[field.DBName] = field
schema.FieldsByName[field.Name] = field schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[bindName] = field
if v != nil && v.PrimaryKey { if v != nil && v.PrimaryKey {
for idx, f := range schema.PrimaryFields { for idx, f := range schema.PrimaryFields {
@ -195,8 +244,11 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByName[field.Name] = field schema.FieldsByName[field.Name] = field
} }
if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByBindName[bindName] = field
}
field.setupValuerAndSetter() field.setupValuerAndSetter(modelType)
} }
prioritizedPrimaryField := schema.LookUpField("id") prioritizedPrimaryField := schema.LookUpField("id")
@ -214,8 +266,18 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
} }
} }
if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 { if schema.PrioritizedPrimaryField == nil {
schema.PrioritizedPrimaryField = schema.PrimaryFields[0] if len(schema.PrimaryFields) == 1 {
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
} else if len(schema.PrimaryFields) > 1 {
// If there are multiple primary keys, the AUTOINCREMENT field is prioritized
for _, field := range schema.PrimaryFields {
if field.AutoIncrement {
schema.PrioritizedPrimaryField = field
break
}
}
}
} }
for _, field := range schema.PrimaryFields { for _, field := range schema.PrimaryFields {
@ -223,7 +285,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
} }
for _, field := range schema.Fields { for _, field := range schema.Fields {
if field.HasDefaultValue && field.DefaultValueInterface == nil { if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil {
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
} }
} }
@ -242,14 +304,26 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
} }
} }
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} callbackTypes := []callbackType{
for _, name := range callbacks { callbackTypeBeforeCreate, callbackTypeAfterCreate,
if methodValue := modelValue.MethodByName(name); methodValue.IsValid() { callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
callbackTypeBeforeSave, callbackTypeAfterSave,
callbackTypeBeforeDelete, callbackTypeAfterDelete,
callbackTypeAfterFind,
}
for _, cbName := range callbackTypes {
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
switch methodValue.Type().String() { switch methodValue.Type().String() {
case "func(*gorm.DB) error": // TODO hack case "func(*gorm.DB) error":
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath())
if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath {
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
} else {
logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg)
// PASS
}
default: default:
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name) logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
} }
} }
} }
@ -271,11 +345,12 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
for _, field := range schema.Fields { for _, field := range schema.Fields {
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) {
if schema.parseRelation(field); schema.err != nil { if schema.parseRelation(field); schema.err != nil {
return schema, schema.err return schema, schema.err
} else { } else {
schema.FieldsByName[field.Name] = field schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[field.BindName()] = field
} }
} }
@ -302,6 +377,39 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
return schema, schema.err return schema, schema.err
} }
// This unrolling is needed to show to the compiler the exact set of methods
// that can be used on the modelType.
// Prior to go1.22 any use of MethodByName would cause the linker to
// abandon dead code elimination for the entire binary.
// As of go1.22 the compiler supports one special case of a string constant
// being passed to MethodByName. For enterprise customers or those building
// large binaries, this gives a significant reduction in binary size.
// https://github.com/golang/go/issues/62257
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
switch cbType {
case callbackTypeBeforeCreate:
return modelType.MethodByName(string(callbackTypeBeforeCreate))
case callbackTypeAfterCreate:
return modelType.MethodByName(string(callbackTypeAfterCreate))
case callbackTypeBeforeUpdate:
return modelType.MethodByName(string(callbackTypeBeforeUpdate))
case callbackTypeAfterUpdate:
return modelType.MethodByName(string(callbackTypeAfterUpdate))
case callbackTypeBeforeSave:
return modelType.MethodByName(string(callbackTypeBeforeSave))
case callbackTypeAfterSave:
return modelType.MethodByName(string(callbackTypeAfterSave))
case callbackTypeBeforeDelete:
return modelType.MethodByName(string(callbackTypeBeforeDelete))
case callbackTypeAfterDelete:
return modelType.MethodByName(string(callbackTypeAfterDelete))
case callbackTypeAfterFind:
return modelType.MethodByName(string(callbackTypeAfterFind))
default:
return reflect.ValueOf(nil)
}
}
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
modelType := reflect.ValueOf(dest).Type() modelType := reflect.ValueOf(dest).Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {

View File

@ -163,8 +163,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table) t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table)
} }
for _, f := range relation.JoinTable.Fields { for i := range relation.JoinTable.Fields {
checkSchemaField(t, r.JoinTable, &f, nil) checkSchemaField(t, r.JoinTable, &relation.JoinTable.Fields[i], nil)
} }
} }
@ -201,6 +201,37 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
}) })
} }
type EmbeddedRelations struct {
Relations map[string]Relation
EmbeddedRelations map[string]EmbeddedRelations
}
func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) {
for name, relations := range actual {
rs := expected[name]
t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) {
if len(relations.Relations) != len(rs.Relations) {
t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations))
}
if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) {
t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), len(relations.EmbeddedRelations))
}
for n, rel := range relations.Relations {
if r, ok := rs.Relations[n]; !ok {
t.Errorf("failed to find relation by name %s", n)
} else {
checkSchemaRelation(t, &schema.Schema{
Relationships: schema.Relationships{
Relations: map[string]*schema.Relationship{n: rel},
},
}, r)
}
}
checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations)
})
}
}
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
for k, v := range values { for k, v := range values {
t.Run("CheckField/"+k, func(t *testing.T) { t.Run("CheckField/"+k, func(t *testing.T) {

View File

@ -19,6 +19,22 @@ func TestParseSchema(t *testing.T) {
checkUserSchema(t, user) checkUserSchema(t, user)
} }
func TestParseSchemaWithMap(t *testing.T) {
type User struct {
tests.User
Attrs map[string]string `gorm:"type:Map(String,String);"`
}
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user with map, got error %v", err)
}
if field := user.FieldsByName["Attrs"]; field.DataType != "Map(String,String)" {
t.Errorf("failed to parse user field Attrs")
}
}
func TestParseSchemaWithPointerFields(t *testing.T) { func TestParseSchemaWithPointerFields(t *testing.T) {
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
@ -46,8 +62,8 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
} }
for _, f := range fields { for i := range fields {
checkSchemaField(t, user, &f, func(f *schema.Field) { checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
f.Creatable = true f.Creatable = true
f.Updatable = true f.Updatable = true
f.Readable = true f.Readable = true
@ -136,8 +152,8 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) {
{Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool}, {Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool},
} }
for _, f := range fields { for i := range fields {
checkSchemaField(t, user, &f, func(f *schema.Field) { checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
f.Creatable = true f.Creatable = true
f.Updatable = true f.Updatable = true
f.Readable = true f.Readable = true
@ -293,3 +309,44 @@ func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) {
}) })
} }
} }
func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) {
type Product struct {
ProductID uint `gorm:"primaryKey;autoIncrement"`
LanguageCode uint `gorm:"primaryKey"`
Code string
Name string
}
type ProductNonAutoIncrement struct {
ProductID uint `gorm:"primaryKey;autoIncrement:false"`
LanguageCode uint `gorm:"primaryKey"`
Code string
Name string
}
product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse product struct with composite primary key, got error %v", err)
}
prioritizedPrimaryField := schema.Field{
Name: "ProductID", DBName: "product_id", BindNames: []string{"ProductID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY", "AUTOINCREMENT": "AUTOINCREMENT"},
}
product.Fields = []*schema.Field{product.PrioritizedPrimaryField}
checkSchemaField(t, product, &prioritizedPrimaryField, func(f *schema.Field) {
f.Creatable = true
f.Updatable = true
f.Readable = true
})
productNonAutoIncrement, err := schema.Parse(&ProductNonAutoIncrement{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse productNonAutoIncrement struct with composite primary key, got error %v", err)
}
if productNonAutoIncrement.PrioritizedPrimaryField != nil {
t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil")
}
}

View File

@ -70,8 +70,7 @@ type SerializerValuerInterface interface {
} }
// JSONSerializer json serializer // JSONSerializer json serializer
type JSONSerializer struct { type JSONSerializer struct{}
}
// Scan implements serializer interface // Scan implements serializer interface
func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
@ -85,10 +84,15 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
case string: case string:
bytes = []byte(v) bytes = []byte(v)
default: default:
return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue) bytes, err = json.Marshal(v)
if err != nil {
return err
}
} }
err = json.Unmarshal(bytes, fieldValue.Interface()) if len(bytes) > 0 {
err = json.Unmarshal(bytes, fieldValue.Interface())
}
} }
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
@ -98,18 +102,23 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
// Value implements serializer interface // Value implements serializer interface
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
result, err := json.Marshal(fieldValue) result, err := json.Marshal(fieldValue)
if string(result) == "null" {
if field.TagSettings["NOT NULL"] != "" {
return "", nil
}
return nil, err
}
return string(result), err return string(result), err
} }
// UnixSecondSerializer json serializer // UnixSecondSerializer json serializer
type UnixSecondSerializer struct { type UnixSecondSerializer struct{}
}
// Scan implements serializer interface // Scan implements serializer interface
func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
t := sql.NullTime{} t := sql.NullTime{}
if err = t.Scan(dbValue); err == nil { if err = t.Scan(dbValue); err == nil && t.Valid {
err = field.Set(ctx, dst, t.Time) err = field.Set(ctx, dst, t.Time.Unix())
} }
return return
@ -117,9 +126,15 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.
// Value implements serializer interface // Value implements serializer interface
func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
rv := reflect.ValueOf(fieldValue)
switch v := fieldValue.(type) { switch v := fieldValue.(type) {
case int64, int, uint, uint64, int32, uint32, int16, uint16: case int64, int, uint, uint64, int32, uint32, int16, uint16:
result = time.Unix(reflect.ValueOf(v).Int(), 0) result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
if rv.IsZero() {
return nil, nil
}
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
default: default:
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
} }
@ -127,8 +142,7 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect
} }
// GobSerializer gob serializer // GobSerializer gob serializer
type GobSerializer struct { type GobSerializer struct{}
}
// Scan implements serializer interface // Scan implements serializer interface
func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
@ -142,8 +156,10 @@ func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
default: default:
return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue) return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
} }
decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue)) if len(bytesValue) > 0 {
err = decoder.Decode(fieldValue.Interface()) decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
err = decoder.Decode(fieldValue.Interface())
}
} }
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
return return

View File

@ -2,6 +2,7 @@ package schema
import ( import (
"context" "context"
"fmt"
"reflect" "reflect"
"regexp" "regexp"
"strings" "strings"
@ -59,10 +60,18 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct
return tag return tag
} }
func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag {
t := tag.Get("gorm")
if strings.Contains(t, value) {
return tag
}
return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t))
}
// GetRelationsValues get relations's values from a reflect value // GetRelationsValues get relations's values from a reflect value
func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
for _, rel := range rels { for _, rel := range rels {
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1) reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.FieldSchema.ModelType)), 0, 1)
appendToResults := func(value reflect.Value) { appendToResults := func(value reflect.Value) {
if _, isZero := rel.Field.ValueOf(ctx, value); !isZero { if _, isZero := rel.Field.ValueOf(ctx, value); !isZero {
@ -106,6 +115,11 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
notZero, zero bool notZero, zero bool
) )
if reflectValue.Kind() == reflect.Ptr ||
reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem()
}
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Struct: case reflect.Struct:
results = [][]interface{}{make([]interface{}, len(fields))} results = [][]interface{}{make([]interface{}, len(fields))}
@ -124,7 +138,7 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
elem := reflectValue.Index(i) elem := reflectValue.Index(i)
elemKey := elem.Interface() elemKey := elem.Interface()
if elem.Kind() != reflect.Ptr { if elem.Kind() != reflect.Ptr && elem.CanAddr() {
elemKey = elem.Addr().Interface() elemKey = elem.Addr().Interface()
} }

View File

@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"reflect" "reflect"
"github.com/jinzhu/now"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
@ -45,11 +46,21 @@ func (n *DeletedAt) UnmarshalJSON(b []byte) error {
} }
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteQueryClause{Field: f}} return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
}
func parseZeroValueTag(f *schema.Field) sql.NullString {
if v, ok := f.TagSettings["ZEROVALUE"]; ok {
if _, err := now.Parse(v); err == nil {
return sql.NullString{String: v, Valid: true}
}
}
return sql.NullString{Valid: false}
} }
type SoftDeleteQueryClause struct { type SoftDeleteQueryClause struct {
Field *schema.Field ZeroValue sql.NullString
Field *schema.Field
} }
func (sd SoftDeleteQueryClause) Name() string { func (sd SoftDeleteQueryClause) Name() string {
@ -78,18 +89,19 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
} }
stmt.AddClause(clause.Where{Exprs: []clause.Expression{ stmt.AddClause(clause.Where{Exprs: []clause.Expression{
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue},
}}) }})
stmt.Clauses["soft_delete_enabled"] = clause.Clause{} stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
} }
} }
func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface { func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteUpdateClause{Field: f}} return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
} }
type SoftDeleteUpdateClause struct { type SoftDeleteUpdateClause struct {
Field *schema.Field ZeroValue sql.NullString
Field *schema.Field
} }
func (sd SoftDeleteUpdateClause) Name() string { func (sd SoftDeleteUpdateClause) Name() string {
@ -109,11 +121,12 @@ func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
} }
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteDeleteClause{Field: f}} return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
} }
type SoftDeleteDeleteClause struct { type SoftDeleteDeleteClause struct {
Field *schema.Field ZeroValue sql.NullString
Field *schema.Field
} }
func (sd SoftDeleteDeleteClause) Name() string { func (sd SoftDeleteDeleteClause) Name() string {

View File

@ -30,8 +30,9 @@ type Statement struct {
Clauses map[string]clause.Clause Clauses map[string]clause.Clause
BuildClauses []string BuildClauses []string
Distinct bool Distinct bool
Selects []string // selected columns Selects []string // selected columns
Omits []string // omit columns Omits []string // omit columns
ColumnMapping map[string]string // map columns
Joins []join Joins []join
Preloads map[string][]interface{} Preloads map[string][]interface{}
Settings sync.Map Settings sync.Map
@ -46,12 +47,18 @@ type Statement struct {
attrs []interface{} attrs []interface{}
assigns []interface{} assigns []interface{}
scopes []func(*DB) *DB scopes []func(*DB) *DB
Result *result
} }
type join struct { type join struct {
Name string Name string
Conds []interface{} Alias string
On *clause.Where Conds []interface{}
On *clause.Where
Selects []string
Omits []string
Expression clause.Expression
JoinType clause.JoinType
} }
// StatementModifier statement modifier interface // StatementModifier statement modifier interface
@ -117,6 +124,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
} else if len(stmt.Schema.DBNames) > 0 { } else if len(stmt.Schema.DBNames) > 0 {
write(v.Raw, stmt.Schema.DBNames[0]) write(v.Raw, stmt.Schema.DBNames[0])
} else {
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
} }
} else { } else {
write(v.Raw, v.Name) write(v.Raw, v.Name)
@ -179,6 +188,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
} else { } else {
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
} }
case clause.Interface:
c := clause.Clause{Name: v.Name()}
v.MergeClause(&c)
c.Build(stmt)
case clause.Expression: case clause.Expression:
v.Build(stmt) v.Build(stmt)
case driver.Valuer: case driver.Valuer:
@ -195,19 +208,21 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
} else { } else {
writer.WriteString("(NULL)") writer.WriteString("(NULL)")
} }
case *DB: case interface{ getInstance() *DB }:
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() cv := v.getInstance()
if v.Statement.SQL.Len() > 0 {
subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
if cv.Statement.SQL.Len() > 0 {
var ( var (
vars = subdb.Statement.Vars vars = subdb.Statement.Vars
sql = v.Statement.SQL.String() sql = cv.Statement.SQL.String()
) )
subdb.Statement.Vars = make([]interface{}, 0, len(vars)) subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars { for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv) subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{} bindvar := strings.Builder{}
v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv) cv.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1) sql = strings.Replace(sql, bindvar.String(), "?", 1)
} }
@ -304,19 +319,31 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
conds := make([]clause.Expression, 0, 4) conds := make([]clause.Expression, 0, 4)
args = append([]interface{}{query}, args...) args = append([]interface{}{query}, args...)
for idx, arg := range args { for idx, arg := range args {
if arg == nil {
continue
}
if valuer, ok := arg.(driver.Valuer); ok { if valuer, ok := arg.(driver.Valuer); ok {
arg, _ = valuer.Value() arg, _ = valuer.Value()
} }
curTable := stmt.Table
if curTable == "" {
curTable = clause.CurrentTable
}
switch v := arg.(type) { switch v := arg.(type) {
case clause.Expression: case clause.Expression:
conds = append(conds, v) conds = append(conds, v)
case *DB: case *DB:
v.executeScopes()
if cs, ok := v.Statement.Clauses["WHERE"]; ok { if cs, ok := v.Statement.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok { if where, ok := cs.Expression.(clause.Where); ok {
if len(where.Exprs) == 1 { if len(where.Exprs) == 1 {
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
where.Exprs[0] = clause.AndConditions(orConds) if len(orConds.Exprs) == 1 {
where.Exprs[0] = clause.AndConditions(orConds)
}
} }
} }
conds = append(conds, clause.And(where.Exprs...)) conds = append(conds, clause.And(where.Exprs...))
@ -336,7 +363,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
sort.Strings(keys) sort.Strings(keys)
for _, key := range keys { for _, key := range keys {
conds = append(conds, clause.Eq{Column: key, Value: v[key]}) column := clause.Column{Name: key, Table: curTable}
if strings.Contains(key, ".") {
column = clause.Column{Name: key}
}
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
} }
case map[string]interface{}: case map[string]interface{}:
keys := make([]string, 0, len(v)) keys := make([]string, 0, len(v))
@ -347,12 +378,16 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, key := range keys { for _, key := range keys {
reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
column := clause.Column{Name: key, Table: curTable}
if strings.Contains(key, ".") {
column = clause.Column{Name: key}
}
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if _, ok := v[key].(driver.Valuer); ok { if _, ok := v[key].(driver.Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]}) conds = append(conds, clause.Eq{Column: column, Value: v[key]})
} else if _, ok := v[key].(Valuer); ok { } else if _, ok := v[key].(Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]}) conds = append(conds, clause.Eq{Column: column, Value: v[key]})
} else { } else {
// optimize reflect value length // optimize reflect value length
valueLen := reflectValue.Len() valueLen := reflectValue.Len()
@ -361,10 +396,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
values[i] = reflectValue.Index(i).Interface() values[i] = reflectValue.Index(i).Interface()
} }
conds = append(conds, clause.IN{Column: key, Values: values}) conds = append(conds, clause.IN{Column: column, Values: values})
} }
default: default:
conds = append(conds, clause.Eq{Column: key, Value: v[key]}) conds = append(conds, clause.Eq{Column: column, Value: v[key]})
} }
} }
default: default:
@ -391,9 +426,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if selected || (!restricted && field.Readable) { if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
if field.DBName != "" { if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" { } else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
} }
} }
} }
@ -405,9 +440,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if selected || (!restricted && field.Readable) { if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
if field.DBName != "" { if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" { } else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
} }
} }
} }
@ -432,18 +467,22 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
} }
if len(values) > 0 { if len(values) > 0 {
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values})
return []clause.Expression{clause.And(conds...)}
} }
return conds return nil
} }
} }
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args})
} }
} }
} }
return conds if len(conds) > 0 {
return []clause.Expression{clause.And(conds...)}
}
return nil
} }
// Build build sql with clauses names // Build build sql with clauses names
@ -495,12 +534,14 @@ func (stmt *Statement) clone() *Statement {
Distinct: stmt.Distinct, Distinct: stmt.Distinct,
Selects: stmt.Selects, Selects: stmt.Selects,
Omits: stmt.Omits, Omits: stmt.Omits,
ColumnMapping: stmt.ColumnMapping,
Preloads: map[string][]interface{}{}, Preloads: map[string][]interface{}{},
ConnPool: stmt.ConnPool, ConnPool: stmt.ConnPool,
Schema: stmt.Schema, Schema: stmt.Schema,
Context: stmt.Context, Context: stmt.Context,
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
SkipHooks: stmt.SkipHooks, SkipHooks: stmt.SkipHooks,
Result: stmt.Result,
} }
if stmt.SQL.Len() > 0 { if stmt.SQL.Len() > 0 {
@ -536,8 +577,9 @@ func (stmt *Statement) clone() *Statement {
} }
// SetColumn set column's value // SetColumn set column's value
// stmt.SetColumn("Name", "jinzhu") // Hooks Method //
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method // stmt.SetColumn("Name", "jinzhu") // Hooks Method
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
if v, ok := stmt.Dest.(map[string]interface{}); ok { if v, ok := stmt.Dest.(map[string]interface{}); ok {
v[name] = value v[name] = value
@ -605,10 +647,10 @@ func (stmt *Statement) Changed(fields ...string) bool {
changed := func(field *schema.Field) bool { changed := func(field *schema.Field) bool {
fieldValue, _ := field.ValueOf(stmt.Context, modelValue) fieldValue, _ := field.ValueOf(stmt.Context, modelValue)
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, ok := stmt.Dest.(map[string]interface{}); ok { if mv, mok := stmt.Dest.(map[string]interface{}); mok {
if fv, ok := v[field.Name]; ok { if fv, ok := mv[field.Name]; ok {
return !utils.AssertEqual(fv, fieldValue) return !utils.AssertEqual(fv, fieldValue)
} else if fv, ok := v[field.DBName]; ok { } else if fv, ok := mv[field.DBName]; ok {
return !utils.AssertEqual(fv, fieldValue) return !utils.AssertEqual(fv, fieldValue)
} }
} else { } else {
@ -618,6 +660,9 @@ func (stmt *Statement) Changed(fields ...string) bool {
} }
changedValue, zero := field.ValueOf(stmt.Context, destValue) changedValue, zero := field.ValueOf(stmt.Context, destValue)
if v {
return !utils.AssertEqual(changedValue, fieldValue)
}
return !zero && !utils.AssertEqual(changedValue, fieldValue) return !zero && !utils.AssertEqual(changedValue, fieldValue)
} }
} }
@ -643,54 +688,62 @@ func (stmt *Statement) Changed(fields ...string) bool {
return false return false
} }
var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_]+?)[\W]?\.[\W]?([a-z_]+?)[\W]?$`) var matchName = func() func(tableColumn string) (table, column string) {
nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
return func(tableColumn string) (table, column string) {
if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 {
table = matches[1]
star := matches[2]
columnName := matches[3]
if star != "" {
return table, star
}
return table, columnName
}
return "", ""
}
}()
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
results := map[string]bool{} results := map[string]bool{}
notRestricted := false notRestricted := false
// select columns processColumn := func(column string, result bool) {
for _, column := range stmt.Selects {
if stmt.Schema == nil { if stmt.Schema == nil {
results[column] = true results[column] = result
} else if column == "*" { } else if column == "*" {
notRestricted = true notRestricted = result
for _, dbName := range stmt.Schema.DBNames { for _, dbName := range stmt.Schema.DBNames {
results[dbName] = true results[dbName] = result
} }
} else if column == clause.Associations { } else if column == clause.Associations {
for _, rel := range stmt.Schema.Relationships.Relations { for _, rel := range stmt.Schema.Relationships.Relations {
results[rel.Name] = true results[rel.Name] = result
} }
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = true results[field.DBName] = result
} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 { } else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
results[matches[1]] = true if col == "*" {
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = result
}
} else {
results[col] = result
}
} else { } else {
results[column] = true results[column] = result
} }
} }
// select columns
for _, column := range stmt.Selects {
processColumn(column, true)
}
// omit columns // omit columns
for _, omit := range stmt.Omits { for _, column := range stmt.Omits {
if stmt.Schema == nil { processColumn(column, false)
results[omit] = false
} else if omit == "*" {
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = false
}
} else if omit == clause.Associations {
for _, rel := range stmt.Schema.Relationships.Relations {
results[rel.Name] = false
}
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
results[field.DBName] = false
} else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 {
results[matches[1]] = false
} else {
results[omit] = false
}
} }
if stmt.Schema != nil { if stmt.Schema != nil {

View File

@ -35,15 +35,36 @@ func TestWhereCloneCorruption(t *testing.T) {
} }
} }
func TestNilCondition(t *testing.T) {
s := new(Statement)
if len(s.BuildCondition(nil)) != 0 {
t.Errorf("Nil condition should be empty")
}
}
func TestNameMatcher(t *testing.T) { func TestNameMatcher(t *testing.T) {
for k, v := range map[string]string{ for k, v := range map[string][]string{
"table.name": "name", "table.name": {"table", "name"},
"`table`.`name`": "name", "`table`.`name`": {"table", "name"},
"'table'.'name'": "name", "'table'.'name'": {"table", "name"},
"'table'.name": "name", "'table'.name": {"table", "name"},
"table1.name_23": {"table1", "name_23"},
"`table_1`.`name23`": {"table_1", "name23"},
"'table23'.'name_1'": {"table23", "name_1"},
"'table23'.name1": {"table23", "name1"},
"'name1'": {"", "name1"},
"`name_1`": {"", "name_1"},
"`Name_1`": {"", "Name_1"},
"`Table`.`nAme`": {"Table", "nAme"},
"my_table.*": {"my_table", "*"},
"`my_table`.*": {"my_table", "*"},
"User__Company.*": {"User__Company", "*"},
"`User__Company`.*": {"User__Company", "*"},
`"User__Company".*`: {"User__Company", "*"},
`"table"."*"`: {"", ""},
} { } {
if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v { if table, column := matchName(k); table != v[0] || column != v[1] {
t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) t.Errorf("failed to match value: %v, got %v, expect: %v", k, []string{table, column}, v)
} }
} }
} }

View File

@ -3,6 +3,7 @@ package tests_test
import ( import (
"testing" "testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -137,6 +138,7 @@ func TestBelongsToAssociation(t *testing.T) {
unexistCompanyID := company.ID + 9999999 unexistCompanyID := company.ID + 9999999
user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID} user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID}
if err := DB.Create(&user).Error; err == nil { if err := DB.Create(&user).Error; err == nil {
tidbSkip(t, "not support the foreign key feature")
t.Errorf("should have gotten foreign key violation error") t.Errorf("should have gotten foreign key violation error")
} }
} }
@ -224,3 +226,81 @@ func TestBelongsToAssociationForSlice(t *testing.T) {
AssertAssociationCount(t, users[0], "Company", 0, "After Delete") AssertAssociationCount(t, users[0], "Company", 0, "After Delete")
AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete") AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete")
} }
func TestBelongsToDefaultValue(t *testing.T) {
type Org struct {
ID string
}
type BelongsToUser struct {
OrgID string
Org Org `gorm:"default:NULL"`
}
tx := DB.Session(&gorm.Session{})
tx.Config.DisableForeignKeyConstraintWhenMigrating = true
AssertEqual(t, DB.Config.DisableForeignKeyConstraintWhenMigrating, false)
tx.Migrator().DropTable(&BelongsToUser{}, &Org{})
tx.AutoMigrate(&BelongsToUser{}, &Org{})
user := &BelongsToUser{
Org: Org{
ID: "BelongsToUser_Org_1",
},
}
err := DB.Create(&user).Error
AssertEqual(t, err, nil)
}
func TestBelongsToAssociationUnscoped(t *testing.T) {
type ItemParent struct {
gorm.Model
Logo string `gorm:"not null;type:varchar(50)"`
}
type ItemChild struct {
gorm.Model
Name string `gorm:"type:varchar(50)"`
ItemParentID uint
ItemParent ItemParent
}
tx := DB.Session(&gorm.Session{})
tx.Migrator().DropTable(&ItemParent{}, &ItemChild{})
tx.AutoMigrate(&ItemParent{}, &ItemChild{})
item := ItemChild{
Name: "name",
ItemParent: ItemParent{
Logo: "logo",
},
}
if err := tx.Create(&item).Error; err != nil {
t.Fatalf("failed to create items, got error: %v", err)
}
// test replace
if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{
Logo: "updated logo",
}); err != nil {
t.Errorf("failed to replace item parent, got error: %v", err)
}
var parents []ItemParent
if err := tx.Find(&parents).Error; err != nil {
t.Errorf("failed to find item parent, got error: %v", err)
}
if len(parents) != 1 {
t.Errorf("expected %d parents, got %d", 1, len(parents))
}
// test delete
if err := tx.Model(&item).Association("ItemParent").Unscoped().Delete(&parents); err != nil {
t.Errorf("failed to delete item parent, got error: %v", err)
}
if err := tx.Find(&parents).Error; err != nil {
t.Errorf("failed to find item parent, got error: %v", err)
}
if len(parents) != 0 {
t.Errorf("expected %d parents, got %d", 0, len(parents))
}
}

View File

@ -3,6 +3,7 @@ package tests_test
import ( import (
"testing" "testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -421,7 +422,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) {
func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
users := []User{ users := []User{
*GetUser("slice-hasmany-1", Config{Toys: 2}), *GetUser("slice-hasmany-1", Config{Toys: 2}),
*GetUser("slice-hasmany-2", Config{Toys: 0}), *GetUser("slice-hasmany-2", Config{Toys: 0, Tools: 2}),
*GetUser("slice-hasmany-3", Config{Toys: 4}), *GetUser("slice-hasmany-3", Config{Toys: 4}),
} }
@ -429,6 +430,7 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
// Count // Count
AssertAssociationCount(t, users, "Toys", 6, "") AssertAssociationCount(t, users, "Toys", 6, "")
AssertAssociationCount(t, users, "Tools", 2, "")
// Find // Find
var toys []Toy var toys []Toy
@ -436,6 +438,14 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
t.Errorf("toys count should be %v, but got %v", 6, len(toys)) t.Errorf("toys count should be %v, but got %v", 6, len(toys))
} }
// Find Tools (polymorphic with custom type and id)
var tools []Tools
DB.Model(&users).Association("Tools").Find(&tools)
if len(tools) != 2 {
t.Errorf("tools count should be %v, but got %v", 2, len(tools))
}
// Append // Append
DB.Model(&users).Association("Toys").Append( DB.Model(&users).Association("Toys").Append(
&Toy{Name: "toy-slice-append-1"}, &Toy{Name: "toy-slice-append-1"},
@ -471,3 +481,88 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
DB.Model(&users).Association("Toys").Clear() DB.Model(&users).Association("Toys").Clear()
AssertAssociationCount(t, users, "Toys", 0, "After Clear") AssertAssociationCount(t, users, "Toys", 0, "After Clear")
} }
func TestHasManyAssociationUnscoped(t *testing.T) {
type ItemContent struct {
gorm.Model
ItemID uint `gorm:"not null"`
Name string `gorm:"not null;type:varchar(50)"`
LanguageCode string `gorm:"not null;type:varchar(2)"`
}
type Item struct {
gorm.Model
Logo string `gorm:"not null;type:varchar(50)"`
Contents []ItemContent `gorm:"foreignKey:ItemID"`
}
tx := DB.Session(&gorm.Session{})
tx.Migrator().DropTable(&ItemContent{}, &Item{})
tx.AutoMigrate(&ItemContent{}, &Item{})
item := Item{
Logo: "logo",
Contents: []ItemContent{
{Name: "name", LanguageCode: "en"},
{Name: "ar name", LanguageCode: "ar"},
},
}
if err := tx.Create(&item).Error; err != nil {
t.Fatalf("failed to create items, got error: %v", err)
}
// test Replace
if err := tx.Model(&item).Association("Contents").Unscoped().Replace([]ItemContent{
{Name: "updated name", LanguageCode: "en"},
{Name: "ar updated name", LanguageCode: "ar"},
{Name: "le nom", LanguageCode: "fr"},
}); err != nil {
t.Errorf("failed to replace item content, got error: %v", err)
}
if count := tx.Model(&item).Association("Contents").Count(); count != 3 {
t.Errorf("expected %d contents, got %d", 3, count)
}
var contents []ItemContent
if err := tx.Find(&contents).Error; err != nil {
t.Errorf("failed to find contents, got error: %v", err)
}
if len(contents) != 3 {
t.Errorf("expected %d contents, got %d", 3, len(contents))
}
// test delete
if err := tx.Model(&item).Association("Contents").Unscoped().Delete(&contents[0]); err != nil {
t.Errorf("failed to delete Contents, got error: %v", err)
}
if count := tx.Model(&item).Association("Contents").Count(); count != 2 {
t.Errorf("expected %d contents, got %d", 2, count)
}
// test clear
if err := tx.Model(&item).Association("Contents").Unscoped().Clear(); err != nil {
t.Errorf("failed to clear contents association, got error: %v", err)
}
if count := tx.Model(&item).Association("Contents").Count(); count != 0 {
t.Errorf("expected %d contents, got %d", 0, count)
}
if err := tx.Find(&contents).Error; err != nil {
t.Errorf("failed to find contents, got error: %v", err)
}
if len(contents) != 0 {
t.Errorf("expected %d contents, got %d", 0, len(contents))
}
}
func TestHasManyAssociationReplaceWithNonValidValue(t *testing.T) {
user := User{Name: "jinzhu", Languages: []Language{{Name: "EN"}}}
if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
if err := DB.Model(&user).Association("Languages").Replace(Language{Name: "DE"}, Language{Name: "FR"}); err == nil {
t.Error("expected association error to be not nil")
}
}

View File

@ -255,3 +255,15 @@ func TestPolymorphicHasOneAssociationForSlice(t *testing.T) {
DB.Model(&pets).Association("Toy").Clear() DB.Model(&pets).Association("Toy").Clear()
AssertAssociationCount(t, pets, "Toy", 0, "After Clear") AssertAssociationCount(t, pets, "Toy", 0, "After Clear")
} }
func TestHasOneAssociationReplaceWithNonValidValue(t *testing.T) {
user := User{Name: "jinzhu", Account: Account{Number: "1"}}
if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
if err := DB.Model(&user).Association("Languages").Replace(Account{Number: "2"}); err == nil {
t.Error("expected association error to be not nil")
}
}

View File

@ -1,8 +1,12 @@
package tests_test package tests_test
import ( import (
"fmt"
"sync"
"testing" "testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -94,6 +98,8 @@ func TestMany2ManyAssociation(t *testing.T) {
} }
func TestMany2ManyOmitAssociations(t *testing.T) { func TestMany2ManyOmitAssociations(t *testing.T) {
tidbSkip(t, "not support the foreign key feature")
user := *GetUser("many2many_omit_associations", Config{Languages: 2}) user := *GetUser("many2many_omit_associations", Config{Languages: 2})
if err := DB.Omit("Languages.*").Create(&user).Error; err == nil { if err := DB.Omit("Languages.*").Create(&user).Error; err == nil {
@ -324,3 +330,96 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) {
DB.Model(&users).Association("Team").Clear() DB.Model(&users).Association("Team").Clear()
AssertAssociationCount(t, users, "Team", 0, "After Clear") AssertAssociationCount(t, users, "Team", 0, "After Clear")
} }
func TestDuplicateMany2ManyAssociation(t *testing.T) {
user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
{Code: "TestDuplicateMany2ManyAssociation-language-2"},
}}
user2 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
{Code: "TestDuplicateMany2ManyAssociation-language-3"},
}}
users := []*User{&user1, &user2}
var err error
err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error
AssertEqual(t, nil, err)
var findUser1 User
err = DB.Preload("Languages").Where("id = ?", user1.ID).First(&findUser1).Error
AssertEqual(t, nil, err)
AssertEqual(t, user1, findUser1)
var findUser2 User
err = DB.Preload("Languages").Where("id = ?", user2.ID).First(&findUser2).Error
AssertEqual(t, nil, err)
AssertEqual(t, user2, findUser2)
}
func TestConcurrentMany2ManyAssociation(t *testing.T) {
db, err := OpenTestConnection(&gorm.Config{})
if err != nil {
t.Fatalf("open test connection failed, err: %+v", err)
}
count := 3
var languages []Language
for i := 0; i < count; i++ {
language := Language{Code: fmt.Sprintf("consurrent %d", i)}
db.Create(&language)
languages = append(languages, language)
}
user := User{}
db.Create(&user)
db.Preload("Languages").FirstOrCreate(&user)
var wg sync.WaitGroup
for i := 0; i < count; i++ {
wg.Add(1)
go func(user User, language Language) {
err := db.Model(&user).Association("Languages").Append(&language)
AssertEqual(t, err, nil)
wg.Done()
}(user, languages[i])
}
wg.Wait()
var find User
err = db.Preload(clause.Associations).Where("id = ?", user.ID).First(&find).Error
AssertEqual(t, err, nil)
AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append")
}
func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) {
user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{
ID: 1,
Name: "Test-company-1",
}},
}}
user2 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-2", Friends: []*User{
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-2", Company: Company{
ID: 1,
Name: "Test-company-1",
}},
}}
users := []*User{&user1, &user2}
var err error
err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error
AssertEqual(t, nil, err)
var findUser1 User
err = DB.Preload("Friends.Company").Where("id = ?", user1.ID).First(&findUser1).Error
AssertEqual(t, nil, err)
AssertEqual(t, user1, findUser1)
var findUser2 User
err = DB.Preload("Friends.Company").Where("id = ?", user2.ID).First(&findUser2).Error
AssertEqual(t, nil, err)
AssertEqual(t, user2, findUser2)
}

View File

@ -4,6 +4,8 @@ import (
"testing" "testing"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -69,6 +71,8 @@ func TestAssociationNotNullClear(t *testing.T) {
} }
func TestForeignKeyConstraints(t *testing.T) { func TestForeignKeyConstraints(t *testing.T) {
tidbSkip(t, "not support the foreign key feature")
type Profile struct { type Profile struct {
ID uint ID uint
Name string Name string
@ -124,6 +128,8 @@ func TestForeignKeyConstraints(t *testing.T) {
} }
func TestForeignKeyConstraintsBelongsTo(t *testing.T) { func TestForeignKeyConstraintsBelongsTo(t *testing.T) {
tidbSkip(t, "not support the foreign key feature")
type Profile struct { type Profile struct {
ID uint ID uint
Name string Name string
@ -284,3 +290,107 @@ func TestAssociationError(t *testing.T) {
err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages) err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages)
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
} }
type (
myType string
emptyQueryClause struct {
Field *schema.Field
}
)
func (myType) QueryClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{emptyQueryClause{Field: f}}
}
func (sd emptyQueryClause) Name() string {
return "empty"
}
func (sd emptyQueryClause) Build(clause.Builder) {
}
func (sd emptyQueryClause) MergeClause(*clause.Clause) {
}
func (sd emptyQueryClause) ModifyStatement(stmt *gorm.Statement) {
// do nothing
}
func TestAssociationEmptyQueryClause(t *testing.T) {
type Organization struct {
gorm.Model
Name string
}
type Region struct {
gorm.Model
Name string
Organizations []Organization `gorm:"many2many:region_orgs;"`
}
type RegionOrg struct {
RegionId uint
OrganizationId uint
Empty myType
}
if err := DB.SetupJoinTable(&Region{}, "Organizations", &RegionOrg{}); err != nil {
t.Fatalf("Failed to set up join table, got error: %s", err)
}
if err := DB.Migrator().DropTable(&Organization{}, &Region{}); err != nil {
t.Fatalf("Failed to migrate, got error: %s", err)
}
if err := DB.AutoMigrate(&Organization{}, &Region{}); err != nil {
t.Fatalf("Failed to migrate, got error: %v", err)
}
region := &Region{Name: "Region1"}
if err := DB.Create(region).Error; err != nil {
t.Fatalf("fail to create region %v", err)
}
var orgs []Organization
if err := DB.Model(&Region{}).Association("Organizations").Find(&orgs); err != nil {
t.Fatalf("fail to find region organizations %v", err)
} else {
AssertEqual(t, len(orgs), 0)
}
}
type AssociationEmptyUser struct {
ID uint
Name string
Pets []AssociationEmptyPet
}
type AssociationEmptyPet struct {
AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"`
Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"`
}
func TestAssociationEmptyPrimaryKey(t *testing.T) {
if DB.Dialector.Name() != "mysql" {
t.Skip()
}
DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{})
DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{})
id := uint(100)
user := AssociationEmptyUser{
ID: id,
Name: "jinzhu",
Pets: []AssociationEmptyPet{
{AssociationEmptyUserID: &id, Name: "bar"},
{AssociationEmptyUserID: &id, Name: "foo"},
},
}
err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error
if err != nil {
t.Fatalf("Failed to create, got error: %v", err)
}
var result AssociationEmptyUser
err = DB.Preload("Pets").First(&result, &id).Error
if err != nil {
t.Fatalf("Failed to find, got error: %v", err)
}
AssertEqual(t, result, user)
}

View File

@ -1,6 +1,7 @@
package tests_test package tests_test
import ( import (
"fmt"
"testing" "testing"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
@ -24,6 +25,45 @@ func BenchmarkFind(b *testing.B) {
} }
} }
func BenchmarkScan(b *testing.B) {
user := *GetUser("scan", Config{})
DB.Create(&user)
var u User
b.ResetTimer()
for x := 0; x < b.N; x++ {
DB.Raw("select * from users where id = ?", user.ID).Scan(&u)
}
}
func BenchmarkScanSlice(b *testing.B) {
DB.Exec("delete from users")
for i := 0; i < 10_000; i++ {
user := *GetUser(fmt.Sprintf("scan-%d", i), Config{})
DB.Create(&user)
}
var u []User
b.ResetTimer()
for x := 0; x < b.N; x++ {
DB.Raw("select * from users").Scan(&u)
}
}
func BenchmarkScanSlicePointer(b *testing.B) {
DB.Exec("delete from users")
for i := 0; i < 10_000; i++ {
user := *GetUser(fmt.Sprintf("scan-%d", i), Config{})
DB.Create(&user)
}
var u []*User
b.ResetTimer()
for x := 0; x < b.N; x++ {
DB.Raw("select * from users").Scan(&u)
}
}
func BenchmarkUpdate(b *testing.B) { func BenchmarkUpdate(b *testing.B) {
user := *GetUser("find", Config{}) user := *GetUser("find", Config{})
DB.Create(&user) DB.Create(&user)

View File

@ -38,6 +38,7 @@ func c2(*gorm.DB) {}
func c3(*gorm.DB) {} func c3(*gorm.DB) {}
func c4(*gorm.DB) {} func c4(*gorm.DB) {}
func c5(*gorm.DB) {} func c5(*gorm.DB) {}
func c6(*gorm.DB) {}
func TestCallbacks(t *testing.T) { func TestCallbacks(t *testing.T) {
type callback struct { type callback struct {
@ -90,7 +91,7 @@ func TestCallbacks(t *testing.T) {
}, },
{ {
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
results: []string{"c1", "c5", "c3", "c4"}, results: []string{"c1", "c3", "c4", "c5"},
}, },
{ {
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
@ -112,6 +113,9 @@ func TestCallbacks(t *testing.T) {
for idx, data := range datas { for idx, data := range datas {
db, err := gorm.Open(nil, nil) db, err := gorm.Open(nil, nil)
if err != nil {
t.Fatal(err)
}
callbacks := db.Callback() callbacks := db.Callback()
for _, c := range data.callbacks { for _, c := range data.callbacks {
@ -168,3 +172,83 @@ func TestCallbacks(t *testing.T) {
} }
} }
} }
func TestPluginCallbacks(t *testing.T) {
db, _ := gorm.Open(nil, nil)
createCallback := db.Callback().Create()
createCallback.Before("*").Register("plugin_1_fn1", c1)
createCallback.After("*").Register("plugin_1_fn2", c2)
if ok, msg := assertCallbacks(createCallback, []string{"c1", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
// plugin 2
createCallback.Before("*").Register("plugin_2_fn1", c3)
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.After("*").Register("plugin_2_fn2", c4)
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2", "c4"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
// plugin 3
createCallback.Before("*").Register("plugin_3_fn1", c5)
if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.After("*").Register("plugin_3_fn2", c6)
if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4", "c6"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
}
func TestCallbacksGet(t *testing.T) {
db, _ := gorm.Open(nil, nil)
createCallback := db.Callback().Create()
createCallback.Before("*").Register("c1", c1)
if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) {
t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1)
}
createCallback.Remove("c1")
if cb := createCallback.Get("c2"); cb != nil {
t.Errorf("callbacks test failed. got: %p, want: nil", cb)
}
}
func TestCallbacksRemove(t *testing.T) {
db, _ := gorm.Open(nil, nil)
createCallback := db.Callback().Create()
createCallback.Before("*").Register("c1", c1)
createCallback.After("*").Register("c2", c2)
createCallback.Before("c4").Register("c3", c3)
createCallback.After("c2").Register("c4", c4)
// callbacks: []string{"c1", "c3", "c4", "c2"}
createCallback.Remove("c1")
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.Remove("c4")
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.Remove("c2")
if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.Remove("c3")
if ok, msg := assertCallbacks(createCallback, []string{}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
}

View File

@ -0,0 +1,88 @@
package tests_test
import (
"fmt"
"strings"
"testing"
"gorm.io/gorm"
)
type Man struct {
ID int
Age int
Name string
Detail string
}
// Panic-safe BeforeUpdate hook that checks for Changed("age")
func (m *Man) BeforeUpdate(tx *gorm.DB) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in BeforeUpdate: %v", r)
}
}()
if !tx.Statement.Changed("age") {
return nil
}
return nil
}
func (m *Man) update(data interface{}) error {
return DB.Set("data", data).Model(m).Where("id = ?", m.ID).Updates(data).Error
}
func TestBeforeUpdateStatementChanged(t *testing.T) {
DB.AutoMigrate(&Man{})
type TestCase struct {
BaseObjects Man
change interface{}
expectError bool
}
testCases := []TestCase{
{
BaseObjects: Man{ID: 1, Age: 18, Name: "random-name"},
change: struct {
Age int
}{Age: 20},
expectError: false,
},
{
BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"},
change: struct {
Name string
}{Name: "name-only"},
expectError: true,
},
{
BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"},
change: struct {
Name string
Age int
}{Name: "name-only", Age: 20},
expectError: false,
},
}
for _, test := range testCases {
DB.Create(&test.BaseObjects)
// below comment is stored for future reference
// err := DB.Set("data", test.change).Model(&test.BaseObjects).Where("id = ?", test.BaseObjects.ID).Updates(test.change).Error
err := test.BaseObjects.update(test.change)
if strings.Contains(fmt.Sprint(err), "panic in BeforeUpdate") {
if !test.expectError {
t.Errorf("unexpected panic in BeforeUpdate for input: %+v\nerror: %v", test.change, err)
}
} else {
if test.expectError {
t.Errorf("expected panic did not occur for input: %+v", test.change)
}
if err != nil {
t.Errorf("unexpected GORM error: %v", err)
}
}
}
}

View File

@ -1,10 +1,8 @@
version: '3'
services: services:
mysql: mysql:
image: 'mysql/mysql-server:latest' image: 'mysql:latest'
ports: ports:
- 9910:3306 - "127.0.0.1:9910:3306"
environment: environment:
- MYSQL_DATABASE=gorm - MYSQL_DATABASE=gorm
- MYSQL_USER=gorm - MYSQL_USER=gorm
@ -13,19 +11,22 @@ services:
postgres: postgres:
image: 'postgres:latest' image: 'postgres:latest'
ports: ports:
- 9920:5432 - "127.0.0.1:9920:5432"
environment: environment:
- TZ=Asia/Shanghai - TZ=Asia/Shanghai
- POSTGRES_DB=gorm - POSTGRES_DB=gorm
- POSTGRES_USER=gorm - POSTGRES_USER=gorm
- POSTGRES_PASSWORD=gorm - POSTGRES_PASSWORD=gorm
mssql: mssql:
image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest' image: '${MSSQL_IMAGE}:latest'
ports: ports:
- 9930:1433 - "127.0.0.1:9930:1433"
environment: environment:
- TZ=Asia/Shanghai
- ACCEPT_EULA=Y - ACCEPT_EULA=Y
- SA_PASSWORD=LoremIpsum86 - MSSQL_SA_PASSWORD=LoremIpsum86
- MSSQL_DB=gorm tidb:
- MSSQL_USER=gorm image: 'pingcap/tidb:v6.5.0'
- MSSQL_PASSWORD=LoremIpsum86 ports:
- "127.0.0.1:9940:4000"
command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 &

View File

@ -48,9 +48,11 @@ func (c *wrapperConnPool) Ping() error {
} }
// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction. // If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { //
// return c.db.BeginTx(ctx, opts) // func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
// } // return c.db.BeginTx(ctx, opts)
// }
//
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries. // You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) { func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
tx, err := c.db.BeginTx(ctx, opts) tx, err := c.db.BeginTx(ctx, opts)
@ -100,13 +102,13 @@ func TestConnPoolWrapper(t *testing.T) {
expect: []string{ expect: []string{
"SELECT VERSION()", "SELECT VERSION()",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
}, },
} }
@ -116,7 +118,8 @@ func TestConnPoolWrapper(t *testing.T) {
} }
}() }()
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn})) db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true}))
db.Logger = DB.Logger
if err != nil { if err != nil {
t.Fatalf("Should open db success, but got %v", err) t.Fatalf("Should open db success, but got %v", err)
} }

View File

@ -11,6 +11,32 @@ import (
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
func TestCountWithGroup(t *testing.T) {
DB.Create([]Company{
{Name: "company_count_group_a"},
{Name: "company_count_group_a"},
{Name: "company_count_group_a"},
{Name: "company_count_group_b"},
{Name: "company_count_group_c"},
})
var count1 int64
if err := DB.Model(&Company{}).Where("name = ?", "company_count_group_a").Group("name").Count(&count1).Error; err != nil {
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
}
if count1 != 1 {
t.Errorf("Count with group should be 1, but got count: %v", count1)
}
var count2 int64
if err := DB.Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil {
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
}
if count2 != 2 {
t.Errorf("Count with group should be 2, but got count: %v", count2)
}
}
func TestCount(t *testing.T) { func TestCount(t *testing.T) {
var ( var (
user1 = *GetUser("count-1", Config{}) user1 = *GetUser("count-1", Config{})
@ -142,7 +168,7 @@ func TestCount(t *testing.T) {
DB.Create(sameUsers) DB.Create(sameUsers)
if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 {
t.Fatalf("Count should be 3, but got count: %v err %v", count11, err) t.Fatalf("Count should be 1, but got count: %v err %v", count11, err)
} }
var count12 int64 var count12 int64

View File

@ -2,6 +2,7 @@ package tests_test
import ( import (
"errors" "errors"
"fmt"
"regexp" "regexp"
"testing" "testing"
"time" "time"
@ -13,31 +14,48 @@ import (
) )
func TestCreate(t *testing.T) { func TestCreate(t *testing.T) {
user := *GetUser("create", Config{}) u1 := *GetUser("create", Config{})
if results := DB.Create(&user); results.Error != nil { if results := DB.Create(&u1); results.Error != nil {
t.Fatalf("errors happened when create: %v", results.Error) t.Fatalf("errors happened when create: %v", results.Error)
} else if results.RowsAffected != 1 { } else if results.RowsAffected != 1 {
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
} }
if user.ID == 0 { if u1.ID == 0 {
t.Errorf("user's primary key should has value after create, got : %v", user.ID) t.Errorf("user's primary key should has value after create, got : %v", u1.ID)
} }
if user.CreatedAt.IsZero() { if u1.CreatedAt.IsZero() {
t.Errorf("user's created at should be not zero") t.Errorf("user's created at should be not zero")
} }
if user.UpdatedAt.IsZero() { if u1.UpdatedAt.IsZero() {
t.Errorf("user's updated at should be not zero") t.Errorf("user's updated at should be not zero")
} }
var newUser User var newUser User
if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil { if err := DB.Where("id = ?", u1.ID).First(&newUser).Error; err != nil {
t.Fatalf("errors happened when query: %v", err) t.Fatalf("errors happened when query: %v", err)
} else { } else {
CheckUser(t, newUser, user) CheckUser(t, newUser, u1)
}
type user struct {
ID int `gorm:"primaryKey;->:false"`
Name string
Age int
}
var u2 user
if results := DB.Create(&u2); results.Error != nil {
t.Fatalf("errors happened when create: %v", results.Error)
} else if results.RowsAffected != 1 {
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
}
if u2.ID != 0 {
t.Errorf("don't have the permission to read primary key from db, but got %v", u2.ID)
} }
} }
@ -476,6 +494,13 @@ func TestOmitWithCreate(t *testing.T) {
CheckUser(t, result2, user2) CheckUser(t, result2, user2)
} }
func TestFirstOrCreateNotExistsTable(t *testing.T) {
company := Company{Name: "first_or_create_if_not_exists_table"}
if err := DB.Table("not_exists").FirstOrCreate(&company).Error; err == nil {
t.Errorf("not exists table, but err is nil")
}
}
func TestFirstOrCreateWithPrimaryKey(t *testing.T) { func TestFirstOrCreateWithPrimaryKey(t *testing.T) {
company := Company{ID: 100, Name: "company100_with_primarykey"} company := Company{ID: 100, Name: "company100_with_primarykey"}
DB.FirstOrCreate(&company) DB.FirstOrCreate(&company)
@ -526,3 +551,260 @@ func TestCreateNilPointer(t *testing.T) {
t.Fatalf("it is not ErrInvalidValue") t.Fatalf("it is not ErrInvalidValue")
} }
} }
func TestFirstOrCreateRowsAffected(t *testing.T) {
user := User{Name: "TestFirstOrCreateRowsAffected"}
res := DB.FirstOrCreate(&user, "name = ?", user.Name)
if res.Error != nil || res.RowsAffected != 1 {
t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected)
}
res = DB.FirstOrCreate(&user, "name = ?", user.Name)
if res.Error != nil || res.RowsAffected != 0 {
t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected)
}
}
func TestCreateWithAutoIncrementCompositeKey(t *testing.T) {
type CompositeKeyProduct struct {
ProductID int `gorm:"primaryKey;autoIncrement:true;"` // primary key
LanguageCode int `gorm:"primaryKey;"` // primary key
Code string
Name string
}
if err := DB.Migrator().DropTable(&CompositeKeyProduct{}); err != nil {
t.Fatalf("failed to migrate, got error %v", err)
}
if err := DB.AutoMigrate(&CompositeKeyProduct{}); err != nil {
t.Fatalf("failed to migrate, got error %v", err)
}
prod := &CompositeKeyProduct{
LanguageCode: 56,
Code: "Code56",
Name: "ProductName56",
}
if err := DB.Create(&prod).Error; err != nil {
t.Fatalf("failed to create, got error %v", err)
}
newProd := &CompositeKeyProduct{}
if err := DB.First(&newProd).Error; err != nil {
t.Fatalf("errors happened when query: %v", err)
} else {
AssertObjEqual(t, newProd, prod, "ProductID", "LanguageCode", "Code", "Name")
}
}
func TestCreateOnConflictWithDefaultNull(t *testing.T) {
type OnConflictUser struct {
ID string
Name string `gorm:"default:null"`
Email string
Mobile string `gorm:"default:'133xxxx'"`
}
err := DB.Migrator().DropTable(&OnConflictUser{})
AssertEqual(t, err, nil)
err = DB.AutoMigrate(&OnConflictUser{})
AssertEqual(t, err, nil)
u := OnConflictUser{
ID: "on-conflict-user-id",
Name: "on-conflict-user-name",
Email: "on-conflict-user-email",
Mobile: "on-conflict-user-mobile",
}
err = DB.Create(&u).Error
AssertEqual(t, err, nil)
u.Name = "on-conflict-user-name-2"
u.Email = "on-conflict-user-email-2"
u.Mobile = ""
err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&u).Error
AssertEqual(t, err, nil)
var u2 OnConflictUser
err = DB.Where("id = ?", u.ID).First(&u2).Error
AssertEqual(t, err, nil)
AssertEqual(t, u2.Name, "on-conflict-user-name-2")
AssertEqual(t, u2.Email, "on-conflict-user-email-2")
AssertEqual(t, u2.Mobile, "133xxxx")
}
func TestCreateFromMapWithoutPK(t *testing.T) {
if !isMysql() {
t.Skipf("This test case skipped, because of only supporting for mysql")
}
// case 1: one record, create from map[string]interface{}
mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1}
if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
if _, ok := mapValue1["id"]; !ok {
t.Fatal("failed to create data from map with table, returning map has no primary key")
}
var result1 User
if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 {
t.Fatalf("failed to create from map, got error %v", err)
}
var idVal int64
_, ok := mapValue1["id"].(uint)
if ok {
t.Skipf("This test case skipped, because the db supports returning")
}
idVal, ok = mapValue1["id"].(int64)
if !ok {
t.Fatal("ret result missing id")
}
if int64(result1.ID) != idVal {
t.Fatal("failed to create data from map with table, @id != id")
}
// case2: one record, create from *map[string]interface{}
mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1}
if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
if _, ok := mapValue2["id"]; !ok {
t.Fatal("failed to create data from map with table, returning map has no primary key")
}
var result2 User
if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 {
t.Fatalf("failed to create from map, got error %v", err)
}
_, ok = mapValue2["id"].(uint)
if ok {
t.Skipf("This test case skipped, because the db supports returning")
}
idVal, ok = mapValue2["id"].(int64)
if !ok {
t.Fatal("ret result missing id")
}
if int64(result2.ID) != idVal {
t.Fatal("failed to create data from map with table, @id != id")
}
// case 3: records
values := []map[string]interface{}{
{"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1},
}
beforeLen := len(values)
if err := DB.Model(&User{}).Create(&values).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
// mariadb with returning, values will be appended with id map
if len(values) == beforeLen*2 {
t.Skipf("This test case skipped, because the db supports returning")
}
for i := range values {
v, ok := values[i]["id"]
if !ok {
t.Fatal("failed to create data from map with table, returning map has no primary key")
}
var result User
if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 {
t.Fatalf("failed to create from map, got error %v", err)
}
if int64(result.ID) != v.(int64) {
t.Fatal("failed to create data from map with table, @id != id")
}
}
}
func TestCreateFromMapWithTable(t *testing.T) {
tableDB := DB.Table("users")
supportLastInsertID := isMysql() || isSqlite()
// case 1: create from map[string]interface{}
record := map[string]interface{}{"name": "create_from_map_with_table", "age": 18}
if err := tableDB.Create(record).Error; err != nil {
t.Fatalf("failed to create data from map with table, got error: %v", err)
}
if _, ok := record["@id"]; !ok && supportLastInsertID {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}
var res map[string]interface{}
if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) {
t.Fatalf("failed to create from map, got error %v", err)
}
if _, ok := record["@id"]; ok && fmt.Sprint(res["id"]) != fmt.Sprint(record["@id"]) {
t.Fatalf("failed to create data from map with table, @id != id, got %v, expect %v", res["id"], record["@id"])
}
// case 2: create from *map[string]interface{}
record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18}
tableDB2 := DB.Table("users")
if err := tableDB2.Create(&record1).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
if _, ok := record1["@id"]; !ok && supportLastInsertID {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}
var res1 map[string]interface{}
if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) {
t.Fatalf("failed to create from map, got error %v", err)
}
if _, ok := record1["@id"]; ok && fmt.Sprint(res1["id"]) != fmt.Sprint(record1["@id"]) {
t.Fatal("failed to create data from map with table, @id != id")
}
// case 3: create from []map[string]interface{}
records := []map[string]interface{}{
{"name": "create_from_map_with_table_2", "age": 19},
{"name": "create_from_map_with_table_3", "age": 20},
}
tableDB = DB.Table("users")
if err := tableDB.Create(&records).Error; err != nil {
t.Fatalf("failed to create data from slice of map, got error: %v", err)
}
if _, ok := records[0]["@id"]; !ok && supportLastInsertID {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}
if _, ok := records[1]["@id"]; !ok && supportLastInsertID {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}
var res2 map[string]interface{}
if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) {
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
}
var res3 map[string]interface{}
if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) {
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
}
if _, ok := records[0]["@id"]; ok && fmt.Sprint(res2["id"]) != fmt.Sprint(records[0]["@id"]) {
t.Errorf("failed to create data from map with table, @id != id, got %v, expect %v", res2["id"], records[0]["@id"])
}
if _, ok := records[1]["id"]; ok && fmt.Sprint(res3["id"]) != fmt.Sprint(records[1]["@id"]) {
t.Errorf("failed to create data from map with table, @id != id")
}
}

View File

@ -38,4 +38,22 @@ func TestDefaultValue(t *testing.T) {
} else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" { } else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" {
t.Fatalf("Failed to find created data with default data, got %+v", result) t.Fatalf("Failed to find created data with default data, got %+v", result)
} }
type Harumph2 struct {
ID int `gorm:"default:0"`
Email string `gorm:"not null;index:,unique"`
Name string `gorm:"notNull;default:foo"`
Name2 string `gorm:"size:233;not null;default:'foo'"`
Name3 string `gorm:"size:233;notNull;default:''"`
Age int `gorm:"default:18"`
Created time.Time `gorm:"default:2000-01-02"`
Enabled bool `gorm:"default:true"`
}
harumph2 := Harumph2{ID: 2, Email: "hello2@gorm.io"}
if err := DB.Table("harumphs").Create(&harumph2).Error; err != nil {
t.Fatalf("Failed to create data with default value, got error: %v", err)
} else if harumph2.ID != 2 || harumph2.Name != "foo" || harumph2.Name2 != "foo" || harumph2.Name3 != "" || harumph2.Age != 18 || !harumph2.Enabled || harumph2.Created.Format("20060102") != "20000102" {
t.Fatalf("Failed to create data with default value, got: %+v", harumph2)
}
} }

View File

@ -206,9 +206,9 @@ func TestDeleteSliceWithAssociations(t *testing.T) {
} }
} }
// only sqlite, postgres support returning // only sqlite, postgres, gaussdb, sqlserver support returning
func TestSoftDeleteReturning(t *testing.T) { func TestSoftDeleteReturning(t *testing.T) {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlserver" {
return return
} }
@ -233,7 +233,7 @@ func TestSoftDeleteReturning(t *testing.T) {
} }
func TestDeleteReturning(t *testing.T) { func TestDeleteReturning(t *testing.T) {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlserver" {
return return
} }

View File

@ -4,7 +4,9 @@ import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "errors"
"reflect"
"testing" "testing"
"time"
"gorm.io/gorm" "gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
@ -36,7 +38,7 @@ func TestEmbeddedStruct(t *testing.T) {
type EngadgetPost struct { type EngadgetPost struct {
BasePost BasePost `gorm:"Embedded"` BasePost BasePost `gorm:"Embedded"`
Author Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct Author *Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct
ImageUrl string ImageUrl string
} }
@ -74,13 +76,26 @@ func TestEmbeddedStruct(t *testing.T) {
t.Errorf("embedded struct's value should be scanned correctly") t.Errorf("embedded struct's value should be scanned correctly")
} }
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}, Author: &Author{Name: "Edward"}})
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_article"}, Author: &Author{Name: "George"}})
var egNews EngadgetPost var egNews EngadgetPost
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err) t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if egNews.BasePost.Title != "engadget_news" { } else if egNews.BasePost.Title != "engadget_news" {
t.Errorf("embedded struct's value should be scanned correctly") t.Errorf("embedded struct's value should be scanned correctly")
} }
var egPosts []EngadgetPost
if err := DB.Order("author_name asc").Find(&egPosts).Error; err != nil {
t.Fatalf("no error should happen when query with embedded struct, but got %v", err)
}
expectAuthors := []string{"Edward", "George"}
for i, post := range egPosts {
t.Log(i, post.Author)
if want := expectAuthors[i]; post.Author.Name != want {
t.Errorf("expected author %s got %s", want, post.Author.Name)
}
}
} }
func TestEmbeddedPointerTypeStruct(t *testing.T) { func TestEmbeddedPointerTypeStruct(t *testing.T) {
@ -90,9 +105,21 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
URL string URL string
} }
type Author struct {
ID string
Name string
Email string
Age int
Content Content
ContentPtr *Content
Birthday time.Time
BirthdayPtr *time.Time
}
type HNPost struct { type HNPost struct {
*BasePost *BasePost
Upvotes int32 Upvotes int32
*Author `gorm:"EmbeddedPrefix:user_"` // Embedded struct
} }
DB.Migrator().DropTable(&HNPost{}) DB.Migrator().DropTable(&HNPost{})
@ -110,6 +137,52 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
if hnPost.Title != "embedded_pointer_type" { if hnPost.Title != "embedded_pointer_type" {
t.Errorf("Should find correct value for embedded pointer type") t.Errorf("Should find correct value for embedded pointer type")
} }
if hnPost.Author != nil {
t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author)
}
now := time.Now().Round(time.Second)
NewPost := HNPost{
BasePost: &BasePost{Title: "embedded_pointer_type2"},
Author: &Author{
Name: "test",
Content: Content{"test"},
ContentPtr: nil,
Birthday: now,
BirthdayPtr: nil,
},
}
DB.Create(&NewPost)
hnPost = HNPost{}
if err := DB.First(&hnPost, "title = ?", NewPost.Title).Error; err != nil {
t.Errorf("No error should happen when find embedded pointer type, but got %v", err)
}
if hnPost.Title != NewPost.Title {
t.Errorf("Should find correct value for embedded pointer type")
}
if hnPost.Author.Name != NewPost.Author.Name {
t.Errorf("Expected to get Author name %v but got: %v", NewPost.Author.Name, hnPost.Author.Name)
}
if !reflect.DeepEqual(NewPost.Author.Content, hnPost.Author.Content) {
t.Errorf("Expected to get Author content %v but got: %v", NewPost.Author.Content, hnPost.Author.Content)
}
if hnPost.Author.ContentPtr != nil {
t.Errorf("Expected to get nil Author contentPtr but got: %v", hnPost.Author.ContentPtr)
}
if NewPost.Author.Birthday.UnixMilli() != hnPost.Author.Birthday.UnixMilli() {
t.Errorf("Expected to get Author birthday with %+v but got: %+v", NewPost.Author.Birthday, hnPost.Author.Birthday)
}
if hnPost.Author.BirthdayPtr != nil {
t.Errorf("Expected to get nil Author birthdayPtr but got: %+v", hnPost.Author.BirthdayPtr)
}
} }
type Content struct { type Content struct {
@ -117,18 +190,26 @@ type Content struct {
} }
func (c Content) Value() (driver.Value, error) { func (c Content) Value() (driver.Value, error) {
return json.Marshal(c) // mssql driver with issue on handling null bytes https://github.com/denisenkom/go-mssqldb/issues/530,
b, err := json.Marshal(c)
return string(b[:]), err
} }
func (c *Content) Scan(src interface{}) error { func (c *Content) Scan(src interface{}) error {
b, ok := src.([]byte)
if !ok {
return errors.New("Embedded.Scan byte assertion failed")
}
var value Content var value Content
if err := json.Unmarshal(b, &value); err != nil { str, ok := src.(string)
return err if !ok {
byt, ok := src.([]byte)
if !ok {
return errors.New("Embedded.Scan byte assertion failed")
}
if err := json.Unmarshal(byt, &value); err != nil {
return err
}
} else {
if err := json.Unmarshal([]byte(str), &value); err != nil {
return err
}
} }
*c = value *c = value
@ -155,8 +236,15 @@ func TestEmbeddedScanValuer(t *testing.T) {
} }
func TestEmbeddedRelations(t *testing.T) { func TestEmbeddedRelations(t *testing.T) {
type EmbUser struct {
gorm.Model
Name string
Age uint
Languages []Language `gorm:"many2many:EmbUserSpeak;"`
}
type AdvancedUser struct { type AdvancedUser struct {
User `gorm:"embedded"` EmbUser `gorm:"embedded"`
Advanced bool Advanced bool
} }
@ -168,3 +256,29 @@ func TestEmbeddedRelations(t *testing.T) {
} }
} }
} }
func TestEmbeddedTagSetting(t *testing.T) {
type Tag1 struct {
Id int64 `gorm:"autoIncrement"`
}
type Tag2 struct {
Id int64
}
type EmbeddedTag struct {
Tag1 Tag1 `gorm:"Embedded;"`
Tag2 Tag2 `gorm:"Embedded;EmbeddedPrefix:t2_"`
Name string
}
DB.Migrator().DropTable(&EmbeddedTag{})
err := DB.Migrator().AutoMigrate(&EmbeddedTag{})
AssertEqual(t, err, nil)
t1 := EmbeddedTag{Name: "embedded_tag"}
err = DB.Save(&t1).Error
AssertEqual(t, err, nil)
if t1.Tag1.Id == 0 {
t.Errorf("embedded struct's primary field should be rewritten")
}
}

Some files were not shown because too many files have changed in this diff Show More