Compare commits
	
		
			2 Commits
		
	
	
		
			master
			...
			feat_rm_co
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					35d2ba9f08 | ||
| 
						 | 
					a219acca4b | 
							
								
								
									
										20
									
								
								.github/release-drafter.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										20
									
								
								.github/release-drafter.yml
									
									
									
									
										vendored
									
									
								
							@ -1,20 +0,0 @@
 | 
				
			|||||||
name-template: 'v Release $NEXT_PATCH_VERSION 🌈'
 | 
					 | 
				
			||||||
tag-template: 'v$NEXT_PATCH_VERSION'
 | 
					 | 
				
			||||||
categories:
 | 
					 | 
				
			||||||
  - title: '🚀 Features'
 | 
					 | 
				
			||||||
    labels:
 | 
					 | 
				
			||||||
      - 'feature'
 | 
					 | 
				
			||||||
      - 'enhancement'
 | 
					 | 
				
			||||||
  - title: '🐛 Bug Fixes'
 | 
					 | 
				
			||||||
    labels:
 | 
					 | 
				
			||||||
      - 'fix'
 | 
					 | 
				
			||||||
      - 'bugfix'
 | 
					 | 
				
			||||||
      - 'bug'
 | 
					 | 
				
			||||||
  - title: '🧰 Maintenance'
 | 
					 | 
				
			||||||
    label: 'chore'
 | 
					 | 
				
			||||||
change-template: '- $TITLE @$AUTHOR (#$NUMBER)'
 | 
					 | 
				
			||||||
change-title-escapes: '\<*_&'
 | 
					 | 
				
			||||||
template: |
 | 
					 | 
				
			||||||
  ## Changes
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  $CHANGES
 | 
					 | 
				
			||||||
							
								
								
									
										31
									
								
								.github/workflows/create-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										31
									
								
								.github/workflows/create-release.yml
									
									
									
									
										vendored
									
									
								
							@ -1,31 +0,0 @@
 | 
				
			|||||||
name: Create Release
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
on:
 | 
					 | 
				
			||||||
  push:
 | 
					 | 
				
			||||||
    tags:
 | 
					 | 
				
			||||||
      - 'v*.*.*'
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
permissions:
 | 
					 | 
				
			||||||
  contents: write
 | 
					 | 
				
			||||||
  pull-requests: read
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
jobs:
 | 
					 | 
				
			||||||
  create_release:
 | 
					 | 
				
			||||||
    name: Create Release
 | 
					 | 
				
			||||||
    runs-on: ubuntu-latest
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    steps:
 | 
					 | 
				
			||||||
      - name: Checkout code
 | 
					 | 
				
			||||||
        uses: actions/checkout@v4
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      - name: Generate Release Notes and Publish
 | 
					 | 
				
			||||||
        id: generate_release_notes
 | 
					 | 
				
			||||||
        uses: release-drafter/release-drafter@v6
 | 
					 | 
				
			||||||
        with:
 | 
					 | 
				
			||||||
          config-name: 'release-drafter.yml'
 | 
					 | 
				
			||||||
          name: "Release ${{ github.ref_name }}"
 | 
					 | 
				
			||||||
          tag: ${{ github.ref_name }}
 | 
					 | 
				
			||||||
          publish: true
 | 
					 | 
				
			||||||
          prerelease: false
 | 
					 | 
				
			||||||
        env:
 | 
					 | 
				
			||||||
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
 | 
					 | 
				
			||||||
							
								
								
									
										26
									
								
								.github/workflows/golangci-lint.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										26
									
								
								.github/workflows/golangci-lint.yml
									
									
									
									
										vendored
									
									
								
							@ -1,26 +0,0 @@
 | 
				
			|||||||
name: golangci-lint
 | 
					 | 
				
			||||||
on:
 | 
					 | 
				
			||||||
  push:
 | 
					 | 
				
			||||||
    branches:
 | 
					 | 
				
			||||||
      - main
 | 
					 | 
				
			||||||
      - master
 | 
					 | 
				
			||||||
  pull_request:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
permissions:
 | 
					 | 
				
			||||||
  contents: read
 | 
					 | 
				
			||||||
  pull-requests: read
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
jobs:
 | 
					 | 
				
			||||||
  golangci:
 | 
					 | 
				
			||||||
    name: lint
 | 
					 | 
				
			||||||
    runs-on: ubuntu-latest
 | 
					 | 
				
			||||||
    steps:
 | 
					 | 
				
			||||||
      - uses: actions/checkout@v4
 | 
					 | 
				
			||||||
      - uses: actions/setup-go@v5
 | 
					 | 
				
			||||||
        with:
 | 
					 | 
				
			||||||
          go-version: stable
 | 
					 | 
				
			||||||
      - name: golangci-lint
 | 
					 | 
				
			||||||
        uses: golangci/golangci-lint-action@v7
 | 
					 | 
				
			||||||
        with:
 | 
					 | 
				
			||||||
          version: v2.0
 | 
					 | 
				
			||||||
          only-new-issues: true
 | 
					 | 
				
			||||||
							
								
								
									
										2
									
								
								.github/workflows/labeler.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/labeler.yml
									
									
									
									
										vendored
									
									
								
							@ -11,7 +11,7 @@ jobs:
 | 
				
			|||||||
    name: Label issues and pull requests
 | 
					    name: Label issues and pull requests
 | 
				
			||||||
    steps:
 | 
					    steps:
 | 
				
			||||||
      - name: check out
 | 
					      - name: check out
 | 
				
			||||||
        uses: actions/checkout@v4
 | 
					        uses: actions/checkout@v3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      - name: labeler
 | 
					      - name: labeler
 | 
				
			||||||
        uses: jinzhu/super-labeler-action@develop
 | 
					        uses: jinzhu/super-labeler-action@develop
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										22
									
								
								.github/workflows/reviewdog.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								.github/workflows/reviewdog.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,22 @@
 | 
				
			|||||||
 | 
					name: reviewdog
 | 
				
			||||||
 | 
					on: [pull_request]
 | 
				
			||||||
 | 
					jobs:
 | 
				
			||||||
 | 
					  golangci-lint:
 | 
				
			||||||
 | 
					    name: runner / golangci-lint
 | 
				
			||||||
 | 
					    runs-on: ubuntu-latest
 | 
				
			||||||
 | 
					    steps:
 | 
				
			||||||
 | 
					      - name: Check out code into the Go module directory
 | 
				
			||||||
 | 
					        uses: actions/checkout@v3
 | 
				
			||||||
 | 
					      - name: golangci-lint
 | 
				
			||||||
 | 
					        uses: reviewdog/action-golangci-lint@v2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      - name: Setup reviewdog
 | 
				
			||||||
 | 
					        uses: reviewdog/action-setup@v1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      - name: gofumpt -s with reviewdog
 | 
				
			||||||
 | 
					        env:
 | 
				
			||||||
 | 
					          REVIEWDOG_GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }}
 | 
				
			||||||
 | 
					        run: |
 | 
				
			||||||
 | 
					          go install mvdan.cc/gofumpt@v0.2.0
 | 
				
			||||||
 | 
					          gofumpt -e -d . | \
 | 
				
			||||||
 | 
					          reviewdog -name="gofumpt" -f=diff -f.diff.strip=0 -reporter=github-pr-review
 | 
				
			||||||
							
								
								
									
										120
									
								
								.github/workflows/tests.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										120
									
								
								.github/workflows/tests.yml
									
									
									
									
										vendored
									
									
								
							@ -16,7 +16,7 @@ jobs:
 | 
				
			|||||||
  sqlite:
 | 
					  sqlite:
 | 
				
			||||||
    strategy:
 | 
					    strategy:
 | 
				
			||||||
      matrix:
 | 
					      matrix:
 | 
				
			||||||
        go: ['1.23', '1.24']
 | 
					        go: ['1.19', '1.18']
 | 
				
			||||||
        platform: [ubuntu-latest] # can not run in windows OS
 | 
					        platform: [ubuntu-latest] # can not run in windows OS
 | 
				
			||||||
    runs-on: ${{ matrix.platform }}
 | 
					    runs-on: ${{ matrix.platform }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -27,10 +27,10 @@ jobs:
 | 
				
			|||||||
        go-version: ${{ matrix.go }}
 | 
					        go-version: ${{ matrix.go }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    - name: Check out code into the Go module directory
 | 
					    - name: Check out code into the Go module directory
 | 
				
			||||||
      uses: actions/checkout@v4
 | 
					      uses: actions/checkout@v3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    - name: go mod package cache
 | 
					    - name: go mod package cache
 | 
				
			||||||
      uses: actions/cache@v4
 | 
					      uses: actions/cache@v3
 | 
				
			||||||
      with:
 | 
					      with:
 | 
				
			||||||
        path: ~/go/pkg/mod
 | 
					        path: ~/go/pkg/mod
 | 
				
			||||||
        key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
 | 
					        key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
 | 
				
			||||||
@ -41,8 +41,8 @@ jobs:
 | 
				
			|||||||
  mysql:
 | 
					  mysql:
 | 
				
			||||||
    strategy:
 | 
					    strategy:
 | 
				
			||||||
      matrix:
 | 
					      matrix:
 | 
				
			||||||
        dbversion: ['mysql:9', 'mysql:8', 'mysql:5.7']
 | 
					        dbversion: ['mysql:latest', 'mysql:5.7']
 | 
				
			||||||
        go: ['1.23', '1.24']
 | 
					        go: ['1.19', '1.18']
 | 
				
			||||||
        platform: [ubuntu-latest]
 | 
					        platform: [ubuntu-latest]
 | 
				
			||||||
    runs-on: ${{ matrix.platform }}
 | 
					    runs-on: ${{ matrix.platform }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -70,10 +70,10 @@ jobs:
 | 
				
			|||||||
        go-version: ${{ matrix.go }}
 | 
					        go-version: ${{ matrix.go }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    - name: Check out code into the Go module directory
 | 
					    - name: Check out code into the Go module directory
 | 
				
			||||||
      uses: actions/checkout@v4
 | 
					      uses: actions/checkout@v3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    - name: go mod package cache
 | 
					    - name: go mod package cache
 | 
				
			||||||
      uses: actions/cache@v4
 | 
					      uses: actions/cache@v3
 | 
				
			||||||
      with:
 | 
					      with:
 | 
				
			||||||
        path: ~/go/pkg/mod
 | 
					        path: ~/go/pkg/mod
 | 
				
			||||||
        key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
 | 
					        key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
 | 
				
			||||||
@ -85,7 +85,7 @@ jobs:
 | 
				
			|||||||
    strategy:
 | 
					    strategy:
 | 
				
			||||||
      matrix:
 | 
					      matrix:
 | 
				
			||||||
        dbversion: [ 'mariadb:latest' ]
 | 
					        dbversion: [ 'mariadb:latest' ]
 | 
				
			||||||
        go: ['1.23', '1.24']
 | 
					        go: [ '1.19', '1.18' ]
 | 
				
			||||||
        platform: [ ubuntu-latest ]
 | 
					        platform: [ ubuntu-latest ]
 | 
				
			||||||
    runs-on: ${{ matrix.platform }}
 | 
					    runs-on: ${{ matrix.platform }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -113,10 +113,10 @@ jobs:
 | 
				
			|||||||
          go-version: ${{ matrix.go }}
 | 
					          go-version: ${{ matrix.go }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      - name: Check out code into the Go module directory
 | 
					      - name: Check out code into the Go module directory
 | 
				
			||||||
        uses: actions/checkout@v4
 | 
					        uses: actions/checkout@v3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      - name: go mod package cache
 | 
					      - name: go mod package cache
 | 
				
			||||||
        uses: actions/cache@v4
 | 
					        uses: actions/cache@v3
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          path: ~/go/pkg/mod
 | 
					          path: ~/go/pkg/mod
 | 
				
			||||||
          key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
 | 
					          key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
 | 
				
			||||||
@ -127,8 +127,8 @@ jobs:
 | 
				
			|||||||
  postgres:
 | 
					  postgres:
 | 
				
			||||||
    strategy:
 | 
					    strategy:
 | 
				
			||||||
      matrix:
 | 
					      matrix:
 | 
				
			||||||
        dbversion: ['postgres:latest', 'postgres:15', 'postgres:14', 'postgres:13']
 | 
					        dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
 | 
				
			||||||
        go: ['1.23', '1.24']
 | 
					        go: ['1.19', '1.18']
 | 
				
			||||||
        platform: [ubuntu-latest] # can not run in macOS and Windows
 | 
					        platform: [ubuntu-latest] # can not run in macOS and Windows
 | 
				
			||||||
    runs-on: ${{ matrix.platform }}
 | 
					    runs-on: ${{ matrix.platform }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -156,10 +156,10 @@ jobs:
 | 
				
			|||||||
        go-version: ${{ matrix.go }}
 | 
					        go-version: ${{ matrix.go }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    - name: Check out code into the Go module directory
 | 
					    - name: Check out code into the Go module directory
 | 
				
			||||||
      uses: actions/checkout@v4
 | 
					      uses: actions/checkout@v3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    - name: go mod package cache
 | 
					    - name: go mod package cache
 | 
				
			||||||
      uses: actions/cache@v4
 | 
					      uses: actions/cache@v3
 | 
				
			||||||
      with:
 | 
					      with:
 | 
				
			||||||
        path: ~/go/pkg/mod
 | 
					        path: ~/go/pkg/mod
 | 
				
			||||||
        key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
 | 
					        key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
 | 
				
			||||||
@ -170,21 +170,23 @@ jobs:
 | 
				
			|||||||
  sqlserver:
 | 
					  sqlserver:
 | 
				
			||||||
    strategy:
 | 
					    strategy:
 | 
				
			||||||
      matrix:
 | 
					      matrix:
 | 
				
			||||||
        go: ['1.23', '1.24']
 | 
					        go: ['1.19', '1.18']
 | 
				
			||||||
        platform: [ubuntu-latest] # can not run test in macOS and windows
 | 
					        platform: [ubuntu-latest] # can not run test in macOS and windows
 | 
				
			||||||
    runs-on: ${{ matrix.platform }}
 | 
					    runs-on: ${{ matrix.platform }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    services:
 | 
					    services:
 | 
				
			||||||
      mssql:
 | 
					      mssql:
 | 
				
			||||||
        image: mcr.microsoft.com/mssql/server:2022-latest
 | 
					        image: mcmoe/mssqldocker:latest
 | 
				
			||||||
        env:
 | 
					        env:
 | 
				
			||||||
          TZ: Asia/Shanghai
 | 
					 | 
				
			||||||
          ACCEPT_EULA: Y
 | 
					          ACCEPT_EULA: Y
 | 
				
			||||||
          MSSQL_SA_PASSWORD: LoremIpsum86
 | 
					          SA_PASSWORD: LoremIpsum86
 | 
				
			||||||
 | 
					          MSSQL_DB: gorm
 | 
				
			||||||
 | 
					          MSSQL_USER: gorm
 | 
				
			||||||
 | 
					          MSSQL_PASSWORD: LoremIpsum86
 | 
				
			||||||
        ports:
 | 
					        ports:
 | 
				
			||||||
          - 9930:1433
 | 
					          - 9930:1433
 | 
				
			||||||
        options: >-
 | 
					        options: >-
 | 
				
			||||||
          --health-cmd="/opt/mssql-tools18/bin/sqlcmd -S localhost -U sa -P ${MSSQL_SA_PASSWORD} -N -C -l 30 -Q \"SELECT 1\" || exit 1"
 | 
					          --health-cmd="/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P LoremIpsum86 -l 30 -Q \"SELECT 1\" || exit 1"
 | 
				
			||||||
          --health-start-period 10s
 | 
					          --health-start-period 10s
 | 
				
			||||||
          --health-interval 10s
 | 
					          --health-interval 10s
 | 
				
			||||||
          --health-timeout 5s
 | 
					          --health-timeout 5s
 | 
				
			||||||
@ -197,22 +199,22 @@ jobs:
 | 
				
			|||||||
        go-version: ${{ matrix.go }}
 | 
					        go-version: ${{ matrix.go }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    - name: Check out code into the Go module directory
 | 
					    - name: Check out code into the Go module directory
 | 
				
			||||||
      uses: actions/checkout@v4
 | 
					      uses: actions/checkout@v3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    - name: go mod package cache
 | 
					    - name: go mod package cache
 | 
				
			||||||
      uses: actions/cache@v4
 | 
					      uses: actions/cache@v3
 | 
				
			||||||
      with:
 | 
					      with:
 | 
				
			||||||
        path: ~/go/pkg/mod
 | 
					        path: ~/go/pkg/mod
 | 
				
			||||||
        key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
 | 
					        key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    - name: Tests
 | 
					    - name: Tests
 | 
				
			||||||
      run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://sa:LoremIpsum86@localhost:9930?database=master" ./tests/tests_all.sh
 | 
					      run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  tidb:
 | 
					  tidb:
 | 
				
			||||||
    strategy:
 | 
					    strategy:
 | 
				
			||||||
      matrix:
 | 
					      matrix:
 | 
				
			||||||
        dbversion: [ 'v6.5.0' ]
 | 
					        dbversion: [ 'v6.5.0' ]
 | 
				
			||||||
        go: ['1.23', '1.24']
 | 
					        go: [ '1.19', '1.18' ]
 | 
				
			||||||
        platform: [ ubuntu-latest ]
 | 
					        platform: [ ubuntu-latest ]
 | 
				
			||||||
    runs-on: ${{ matrix.platform }}
 | 
					    runs-on: ${{ matrix.platform }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -229,82 +231,14 @@ jobs:
 | 
				
			|||||||
          go-version: ${{ matrix.go }}
 | 
					          go-version: ${{ matrix.go }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      - name: Check out code into the Go module directory
 | 
					      - name: Check out code into the Go module directory
 | 
				
			||||||
        uses: actions/checkout@v4
 | 
					        uses: actions/checkout@v3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      - name: go mod package cache
 | 
					      - name: go mod package cache
 | 
				
			||||||
        uses: actions/cache@v4
 | 
					        uses: actions/cache@v3
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          path: ~/go/pkg/mod
 | 
					          path: ~/go/pkg/mod
 | 
				
			||||||
          key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
 | 
					          key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      - name: Tests
 | 
					      - name: Tests
 | 
				
			||||||
        run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh
 | 
					        run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh
 | 
				
			||||||
 | 
					 | 
				
			||||||
  gaussdb:
 | 
					 | 
				
			||||||
    strategy:
 | 
					 | 
				
			||||||
      matrix:
 | 
					 | 
				
			||||||
        dbversion: ['opengauss/opengauss:7.0.0-RC1.B023']
 | 
					 | 
				
			||||||
        go: ['1.23', '1.24']
 | 
					 | 
				
			||||||
        platform: [ubuntu-latest] # can not run in macOS and Windows
 | 
					 | 
				
			||||||
    runs-on: ${{ matrix.platform }}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    services:
 | 
					 | 
				
			||||||
      gaussdb:
 | 
					 | 
				
			||||||
        image: ${{ matrix.dbversion }}
 | 
					 | 
				
			||||||
        env:
 | 
					 | 
				
			||||||
          # GaussDB has password limitations
 | 
					 | 
				
			||||||
          GS_PASSWORD: Gaussdb@123
 | 
					 | 
				
			||||||
          TZ: Asia/Shanghai
 | 
					 | 
				
			||||||
        ports:
 | 
					 | 
				
			||||||
          - 9950:5432
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    steps:
 | 
					 | 
				
			||||||
      - name: Set up Go 1.x
 | 
					 | 
				
			||||||
        uses: actions/setup-go@v4
 | 
					 | 
				
			||||||
        with:
 | 
					 | 
				
			||||||
          go-version: ${{ matrix.go }}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      - name: Check out code into the Go module directory
 | 
					 | 
				
			||||||
        uses: actions/checkout@v4
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      - name: Waiting for GaussDB to be ready
 | 
					 | 
				
			||||||
        run: |
 | 
					 | 
				
			||||||
          container_name=$(docker ps --filter "ancestor=opengauss/opengauss:7.0.0-RC1.B023" --format "{{.Names}}")
 | 
					 | 
				
			||||||
          if [ -z "$container_name" ]; then
 | 
					 | 
				
			||||||
            echo "Error: failed to find a container created from the 'opengauss/opengauss:7.0.0-RC1.B023' image."
 | 
					 | 
				
			||||||
            exit 1
 | 
					 | 
				
			||||||
          fi
 | 
					 | 
				
			||||||
          max_retries=12
 | 
					 | 
				
			||||||
          retry_count=0
 | 
					 | 
				
			||||||
          if [ -t 0 ]; then
 | 
					 | 
				
			||||||
            TTY_FLAG="-t"
 | 
					 | 
				
			||||||
          else
 | 
					 | 
				
			||||||
            TTY_FLAG=""
 | 
					 | 
				
			||||||
          fi
 | 
					 | 
				
			||||||
          while [ $retry_count -lt $max_retries ]; do
 | 
					 | 
				
			||||||
            if docker exec -i "${container_name}" bash -c "su - omm -c 'gsql -U omm -c \"select 1;\"'" 
 | 
					 | 
				
			||||||
            then
 | 
					 | 
				
			||||||
              echo "Creating database gorm..."
 | 
					 | 
				
			||||||
              sql_file='/tmp/create_database.sql'
 | 
					 | 
				
			||||||
              echo "CREATE DATABASE gorm DBCOMPATIBILITY 'PG';" > ${sql_file}
 | 
					 | 
				
			||||||
              docker cp "${sql_file}" "${container_name}":"${sql_file}"
 | 
					 | 
				
			||||||
              docker exec -i ${TTY_FLAG} "${container_name}" bash -c "su - omm -c 'gsql -U omm -f ${sql_file}'"
 | 
					 | 
				
			||||||
              echo "Database initialization completed."
 | 
					 | 
				
			||||||
              break
 | 
					 | 
				
			||||||
            fi
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            echo "Waiting for database to be ready... (attempt $((retry_count + 1))/$max_retries)"
 | 
					 | 
				
			||||||
            sleep 10
 | 
					 | 
				
			||||||
            ((++retry_count))
 | 
					 | 
				
			||||||
          done
 | 
					 | 
				
			||||||
          exit 0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      - name: go mod package cache
 | 
					 | 
				
			||||||
        uses: actions/cache@v4
 | 
					 | 
				
			||||||
        with:
 | 
					 | 
				
			||||||
          path: ~/go/pkg/mod
 | 
					 | 
				
			||||||
          key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      - name: Tests
 | 
					 | 
				
			||||||
        run: GITHUB_ACTION=true GORM_DIALECT=gaussdb GORM_DSN="user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
 | 
					 | 
				
			||||||
@ -1,9 +1,7 @@
 | 
				
			|||||||
version: "2"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
linters:
 | 
					linters:
 | 
				
			||||||
  default: standard
 | 
					 | 
				
			||||||
  enable:
 | 
					  enable:
 | 
				
			||||||
    - cyclop
 | 
					    - cyclop
 | 
				
			||||||
 | 
					    - exportloopref
 | 
				
			||||||
    - gocritic
 | 
					    - gocritic
 | 
				
			||||||
    - gosec
 | 
					    - gosec
 | 
				
			||||||
    - ineffassign
 | 
					    - ineffassign
 | 
				
			||||||
@ -11,9 +9,12 @@ linters:
 | 
				
			|||||||
    - prealloc
 | 
					    - prealloc
 | 
				
			||||||
    - unconvert
 | 
					    - unconvert
 | 
				
			||||||
    - unparam
 | 
					    - unparam
 | 
				
			||||||
 | 
					    - goimports
 | 
				
			||||||
    - whitespace
 | 
					    - whitespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
formatters:
 | 
					linters-settings:
 | 
				
			||||||
  enable:
 | 
					  whitespace:
 | 
				
			||||||
    - gofumpt
 | 
					    multi-func: true
 | 
				
			||||||
    - goimports
 | 
					  goimports:
 | 
				
			||||||
 | 
					    local-prefixes: gorm.io/gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -1,128 +0,0 @@
 | 
				
			|||||||
# Contributor Covenant Code of Conduct
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
## Our Pledge
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
We as members, contributors, and leaders pledge to participate in our
 | 
					 | 
				
			||||||
community a harassment-free experience for everyone, regardless of age, body
 | 
					 | 
				
			||||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
 | 
					 | 
				
			||||||
identity and expression, level of experience, education, socio-economic status,
 | 
					 | 
				
			||||||
nationality, personal appearance, race, religion, or sexual identity
 | 
					 | 
				
			||||||
and orientation.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
We pledge to act and interact in ways that contribute to an open, welcoming,
 | 
					 | 
				
			||||||
diverse, inclusive, and healthy community.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
## Our Standards
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Examples of behavior that contributes to a positive environment for our
 | 
					 | 
				
			||||||
community includes:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
* Demonstrating empathy and kindness toward other people
 | 
					 | 
				
			||||||
* Being respectful of differing opinions, viewpoints, and experiences
 | 
					 | 
				
			||||||
* Giving and gracefully accepting constructive feedback
 | 
					 | 
				
			||||||
* Accepting responsibility and apologizing to those affected by our mistakes,
 | 
					 | 
				
			||||||
  and learning from the experience
 | 
					 | 
				
			||||||
* Focusing on what is best not just for us as individuals, but for the
 | 
					 | 
				
			||||||
  overall community
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Examples of unacceptable behavior include:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
* The use of sexualized language or imagery, and sexual attention or
 | 
					 | 
				
			||||||
  advances of any kind
 | 
					 | 
				
			||||||
* Trolling, insulting or derogatory comments, and personal or political attacks
 | 
					 | 
				
			||||||
* Public or private harassment
 | 
					 | 
				
			||||||
* Publishing others' private information, such as a physical or email
 | 
					 | 
				
			||||||
  address, without their explicit permission
 | 
					 | 
				
			||||||
* Other conduct which could reasonably be considered inappropriate in a
 | 
					 | 
				
			||||||
  professional setting
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
## Enforcement Responsibilities
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Community leaders are responsible for clarifying and enforcing our standards of
 | 
					 | 
				
			||||||
acceptable behavior and will take appropriate and fair corrective action in
 | 
					 | 
				
			||||||
response to any behavior that they deem inappropriate, threatening, offensive,
 | 
					 | 
				
			||||||
or harmful.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Community leaders have the right and responsibility to remove, edit, or reject
 | 
					 | 
				
			||||||
comments, commits, code, wiki edits, issues, and other contributions that are
 | 
					 | 
				
			||||||
not aligned to this Code of Conduct, and will communicate reasons for moderation
 | 
					 | 
				
			||||||
decisions when appropriate.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
## Scope
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
This Code of Conduct applies within all community spaces and also applies when
 | 
					 | 
				
			||||||
an individual is officially representing the community in public spaces.
 | 
					 | 
				
			||||||
Examples of representing our community include using an official e-mail address,
 | 
					 | 
				
			||||||
posting via an official social media account, or acting as an appointed
 | 
					 | 
				
			||||||
representative at an online or offline event.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
## Enforcement
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
 | 
					 | 
				
			||||||
reported to the community leaders responsible for enforcement at
 | 
					 | 
				
			||||||
.
 | 
					 | 
				
			||||||
All complaints will be reviewed and investigated promptly and fairly.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
All community leaders are obligated to respect the privacy and security of the
 | 
					 | 
				
			||||||
reporter of any incident.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
## Enforcement Guidelines
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Community leaders will follow these Community Impact Guidelines in determining
 | 
					 | 
				
			||||||
the consequences for any action they deem in violation of this Code of Conduct:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
### 1. Correction
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
**Community Impact**: Use of inappropriate language or other behavior deemed
 | 
					 | 
				
			||||||
unprofessional or unwelcome in the community.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
**Consequence**: A private, written warning from community leaders, providing
 | 
					 | 
				
			||||||
clarity around the nature of the violation and an explanation of why the
 | 
					 | 
				
			||||||
behavior was inappropriate. A public apology may be requested.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
### 2. Warning
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
**Community Impact**: A violation through a single incident or series
 | 
					 | 
				
			||||||
of actions.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
**Consequence**: A warning with consequences for continued behavior. No
 | 
					 | 
				
			||||||
interaction with the people involved, including unsolicited interaction with
 | 
					 | 
				
			||||||
those enforcing the Code of Conduct, for a specified period. This
 | 
					 | 
				
			||||||
includes avoiding interactions in community spaces and external channels
 | 
					 | 
				
			||||||
like social media. Violating these terms may lead to a temporary or
 | 
					 | 
				
			||||||
permanent ban.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
### 3. Temporary Ban
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
**Community Impact**: A serious violation of community standards, including
 | 
					 | 
				
			||||||
sustained inappropriate behavior.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
**Consequence**: A temporary ban from any interaction or public
 | 
					 | 
				
			||||||
communication with the community for a specified period. No public or
 | 
					 | 
				
			||||||
private interaction with the people involved, including unsolicited interaction
 | 
					 | 
				
			||||||
with those enforcing the Code of Conduct, is allowed during this period.
 | 
					 | 
				
			||||||
Violating these terms may lead to a permanent ban.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
### 4. Permanent Ban
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
**Community Impact**: Demonstrating a pattern of violation of community
 | 
					 | 
				
			||||||
standards, including sustained inappropriate behavior,  harassment of an
 | 
					 | 
				
			||||||
individual, or aggression toward or disparagement of classes of individuals.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
**Consequence**: A permanent ban from any sort of public interaction within
 | 
					 | 
				
			||||||
the community.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
## Attribution
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
 | 
					 | 
				
			||||||
version 2.0, available at
 | 
					 | 
				
			||||||
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Community Impact Guidelines were inspired by [Mozilla's code of conduct
 | 
					 | 
				
			||||||
enforcement ladder](https://github.com/mozilla/diversity).
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
[homepage]: https://www.contributor-covenant.org
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
For answers to common questions about this code of conduct, see the FAQ at
 | 
					 | 
				
			||||||
https://www.contributor-covenant.org/faq. Translations are available at
 | 
					 | 
				
			||||||
https://www.contributor-covenant.org/translations.
 | 
					 | 
				
			||||||
							
								
								
									
										2
									
								
								LICENSE
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								LICENSE
									
									
									
									
									
								
							@ -1,6 +1,6 @@
 | 
				
			|||||||
The MIT License (MIT)
 | 
					The MIT License (MIT)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Copyright (c) 2013-present  Jinzhu <wosmvp@gmail.com>
 | 
					Copyright (c) 2013-NOW  Jinzhu <wosmvp@gmail.com>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
 | 
					Permission is hereby granted, free of charge, to any person obtaining a copy
 | 
				
			||||||
of this software and associated documentation files (the "Software"), to deal
 | 
					of this software and associated documentation files (the "Software"), to deal
 | 
				
			||||||
 | 
				
			|||||||
@ -41,4 +41,4 @@ The fantastic ORM library for Golang, aims to be developer friendly.
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
© Jinzhu, 2013~time.Now
 | 
					© Jinzhu, 2013~time.Now
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)
 | 
					Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)
 | 
				
			||||||
 | 
				
			|||||||
@ -396,10 +396,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
 | 
				
			|||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			case reflect.Struct:
 | 
								case reflect.Struct:
 | 
				
			||||||
				if !rv.CanAddr() {
 | 
					 | 
				
			||||||
					association.Error = ErrInvalidValue
 | 
					 | 
				
			||||||
					return
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
 | 
									association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
 | 
									if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
 | 
				
			||||||
@ -437,10 +433,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
 | 
				
			|||||||
					appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
 | 
										appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			case reflect.Struct:
 | 
								case reflect.Struct:
 | 
				
			||||||
				if !rv.CanAddr() {
 | 
					 | 
				
			||||||
					association.Error = ErrInvalidValue
 | 
					 | 
				
			||||||
					return
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				appendToFieldValues(rv.Addr())
 | 
									appendToFieldValues(rv.Addr())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -518,9 +510,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		for i := 0; i < reflectValue.Len(); i++ {
 | 
							for i := 0; i < reflectValue.Len(); i++ {
 | 
				
			||||||
			appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
 | 
								appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
 | 
				
			||||||
			if association.Error != nil {
 | 
					 | 
				
			||||||
				return
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// TODO support save slice data, sql with case?
 | 
								// TODO support save slice data, sql with case?
 | 
				
			||||||
			association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
 | 
								association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
 | 
				
			||||||
@ -542,9 +531,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
 | 
				
			|||||||
		for idx, value := range values {
 | 
							for idx, value := range values {
 | 
				
			||||||
			rv := reflect.Indirect(reflect.ValueOf(value))
 | 
								rv := reflect.Indirect(reflect.ValueOf(value))
 | 
				
			||||||
			appendToRelations(reflectValue, rv, clear && idx == 0)
 | 
								appendToRelations(reflectValue, rv, clear && idx == 0)
 | 
				
			||||||
			if association.Error != nil {
 | 
					 | 
				
			||||||
				return
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if len(values) > 0 {
 | 
							if len(values) > 0 {
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										19
									
								
								callbacks.go
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								callbacks.go
									
									
									
									
									
								
							@ -187,18 +187,10 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func (p *processor) compile() (err error) {
 | 
					func (p *processor) compile() (err error) {
 | 
				
			||||||
	var callbacks []*callback
 | 
						var callbacks []*callback
 | 
				
			||||||
	removedMap := map[string]bool{}
 | 
					 | 
				
			||||||
	for _, callback := range p.callbacks {
 | 
						for _, callback := range p.callbacks {
 | 
				
			||||||
		if callback.match == nil || callback.match(p.db) {
 | 
							if callback.match == nil || callback.match(p.db) {
 | 
				
			||||||
			callbacks = append(callbacks, callback)
 | 
								callbacks = append(callbacks, callback)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if callback.remove {
 | 
					 | 
				
			||||||
			removedMap[callback.name] = true
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(removedMap) > 0 {
 | 
					 | 
				
			||||||
		callbacks = removeCallbacks(callbacks, removedMap)
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	p.callbacks = callbacks
 | 
						p.callbacks = callbacks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -347,14 +339,3 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
 | 
					 | 
				
			||||||
	callbacks := make([]*callback, 0, len(cs))
 | 
					 | 
				
			||||||
	for _, callback := range cs {
 | 
					 | 
				
			||||||
		if nameMap[callback.name] {
 | 
					 | 
				
			||||||
			continue
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		callbacks = append(callbacks, callback)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return callbacks
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -47,7 +47,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
 | 
				
			|||||||
					)
 | 
										)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					if !isPtr {
 | 
										if !isPtr {
 | 
				
			||||||
						fieldType = reflect.PointerTo(fieldType)
 | 
											fieldType = reflect.PtrTo(fieldType)
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | 
										elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | 
				
			||||||
@ -126,7 +126,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
 | 
				
			|||||||
					)
 | 
										)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					if !isPtr {
 | 
										if !isPtr {
 | 
				
			||||||
						fieldType = reflect.PointerTo(fieldType)
 | 
											fieldType = reflect.PtrTo(fieldType)
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | 
										elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | 
				
			||||||
@ -195,7 +195,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
 | 
				
			|||||||
				fieldType := rel.Field.IndirectFieldType.Elem()
 | 
									fieldType := rel.Field.IndirectFieldType.Elem()
 | 
				
			||||||
				isPtr := fieldType.Kind() == reflect.Ptr
 | 
									isPtr := fieldType.Kind() == reflect.Ptr
 | 
				
			||||||
				if !isPtr {
 | 
									if !isPtr {
 | 
				
			||||||
					fieldType = reflect.PointerTo(fieldType)
 | 
										fieldType = reflect.PtrTo(fieldType)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | 
									elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | 
				
			||||||
				identityMap := map[string]bool{}
 | 
									identityMap := map[string]bool{}
 | 
				
			||||||
@ -268,11 +268,11 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
 | 
				
			|||||||
				fieldType := rel.Field.IndirectFieldType.Elem()
 | 
									fieldType := rel.Field.IndirectFieldType.Elem()
 | 
				
			||||||
				isPtr := fieldType.Kind() == reflect.Ptr
 | 
									isPtr := fieldType.Kind() == reflect.Ptr
 | 
				
			||||||
				if !isPtr {
 | 
									if !isPtr {
 | 
				
			||||||
					fieldType = reflect.PointerTo(fieldType)
 | 
										fieldType = reflect.PtrTo(fieldType)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | 
									elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | 
				
			||||||
				distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | 
									distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | 
				
			||||||
				joins := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.JoinTable.ModelType)), 0, 10)
 | 
									joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
 | 
				
			||||||
				objs := []reflect.Value{}
 | 
									objs := []reflect.Value{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				appendToJoins := func(obj reflect.Value, elem reflect.Value) {
 | 
									appendToJoins := func(obj reflect.Value, elem reflect.Value) {
 | 
				
			||||||
 | 
				
			|||||||
@ -53,13 +53,9 @@ func Create(config *Config) func(db *gorm.DB) {
 | 
				
			|||||||
				if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
 | 
									if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
 | 
				
			||||||
					fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
 | 
										fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
 | 
				
			||||||
					for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
 | 
										for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
 | 
				
			||||||
						if field.Readable {
 | 
											fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
 | 
				
			||||||
							fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
 | 
					 | 
				
			||||||
						}
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
					if len(fromColumns) > 0 {
 | 
					 | 
				
			||||||
						db.Statement.AddClause(clause.Returning{Columns: fromColumns})
 | 
					 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
 | 
										db.Statement.AddClause(clause.Returning{Columns: fromColumns})
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -93,10 +89,6 @@ func Create(config *Config) func(db *gorm.DB) {
 | 
				
			|||||||
					db.AddError(rows.Close())
 | 
										db.AddError(rows.Close())
 | 
				
			||||||
				}()
 | 
									}()
 | 
				
			||||||
				gorm.Scan(rows, db, mode)
 | 
									gorm.Scan(rows, db, mode)
 | 
				
			||||||
 | 
					 | 
				
			||||||
				if db.Statement.Result != nil {
 | 
					 | 
				
			||||||
					db.Statement.Result.RowsAffected = db.RowsAffected
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
@ -111,70 +103,13 @@ func Create(config *Config) func(db *gorm.DB) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		db.RowsAffected, _ = result.RowsAffected()
 | 
							db.RowsAffected, _ = result.RowsAffected()
 | 
				
			||||||
 | 
							if db.RowsAffected != 0 && db.Statement.Schema != nil &&
 | 
				
			||||||
		if db.Statement.Result != nil {
 | 
								db.Statement.Schema.PrioritizedPrimaryField != nil &&
 | 
				
			||||||
			db.Statement.Result.Result = result
 | 
								db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
 | 
				
			||||||
			db.Statement.Result.RowsAffected = db.RowsAffected
 | 
								insertID, err := result.LastInsertId()
 | 
				
			||||||
		}
 | 
								insertOk := err == nil && insertID > 0
 | 
				
			||||||
 | 
								if !insertOk {
 | 
				
			||||||
		if db.RowsAffected == 0 {
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		var (
 | 
					 | 
				
			||||||
			pkField     *schema.Field
 | 
					 | 
				
			||||||
			pkFieldName = "@id"
 | 
					 | 
				
			||||||
		)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if db.Statement.Schema != nil {
 | 
					 | 
				
			||||||
			if db.Statement.Schema.PrioritizedPrimaryField == nil ||
 | 
					 | 
				
			||||||
				!db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue ||
 | 
					 | 
				
			||||||
				!db.Statement.Schema.PrioritizedPrimaryField.Readable {
 | 
					 | 
				
			||||||
				return
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			pkField = db.Statement.Schema.PrioritizedPrimaryField
 | 
					 | 
				
			||||||
			pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		insertID, err := result.LastInsertId()
 | 
					 | 
				
			||||||
		insertOk := err == nil && insertID > 0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if !insertOk {
 | 
					 | 
				
			||||||
			if !supportReturning {
 | 
					 | 
				
			||||||
				db.AddError(err)
 | 
									db.AddError(err)
 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		// append @id column with value for auto-increment primary key
 | 
					 | 
				
			||||||
		// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
 | 
					 | 
				
			||||||
		switch values := db.Statement.Dest.(type) {
 | 
					 | 
				
			||||||
		case map[string]interface{}:
 | 
					 | 
				
			||||||
			values[pkFieldName] = insertID
 | 
					 | 
				
			||||||
		case *map[string]interface{}:
 | 
					 | 
				
			||||||
			(*values)[pkFieldName] = insertID
 | 
					 | 
				
			||||||
		case []map[string]interface{}, *[]map[string]interface{}:
 | 
					 | 
				
			||||||
			mapValues, ok := values.([]map[string]interface{})
 | 
					 | 
				
			||||||
			if !ok {
 | 
					 | 
				
			||||||
				if v, ok := values.(*[]map[string]interface{}); ok {
 | 
					 | 
				
			||||||
					if *v != nil {
 | 
					 | 
				
			||||||
						mapValues = *v
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if config.LastInsertIDReversed {
 | 
					 | 
				
			||||||
				insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			for _, mapValue := range mapValues {
 | 
					 | 
				
			||||||
				if mapValue != nil {
 | 
					 | 
				
			||||||
					mapValue[pkFieldName] = insertID
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				insertID += schema.DefaultAutoIncrementIncrement
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		default:
 | 
					 | 
				
			||||||
			if pkField == nil {
 | 
					 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -187,10 +122,10 @@ func Create(config *Config) func(db *gorm.DB) {
 | 
				
			|||||||
							break
 | 
												break
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
						_, isZero := pkField.ValueOf(db.Statement.Context, rv)
 | 
											_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
 | 
				
			||||||
						if isZero {
 | 
											if isZero {
 | 
				
			||||||
							db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
 | 
												db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
 | 
				
			||||||
							insertID -= pkField.AutoIncrementIncrement
 | 
												insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
@ -200,16 +135,16 @@ func Create(config *Config) func(db *gorm.DB) {
 | 
				
			|||||||
							break
 | 
												break
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
						if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
 | 
											if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
 | 
				
			||||||
							db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
 | 
												db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
 | 
				
			||||||
							insertID += pkField.AutoIncrementIncrement
 | 
												insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			case reflect.Struct:
 | 
								case reflect.Struct:
 | 
				
			||||||
				_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
 | 
									_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
 | 
				
			||||||
				if isZero {
 | 
									if isZero {
 | 
				
			||||||
					db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
 | 
										db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -318,15 +253,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
 | 
				
			|||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
 | 
								for field, vs := range defaultValueFieldsHavingValue {
 | 
				
			||||||
				if vs, ok := defaultValueFieldsHavingValue[field]; ok {
 | 
									values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
 | 
				
			||||||
					values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
 | 
									for idx := range values.Values {
 | 
				
			||||||
					for idx := range values.Values {
 | 
										if vs[idx] == nil {
 | 
				
			||||||
						if vs[idx] == nil {
 | 
											values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
 | 
				
			||||||
							values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
 | 
										} else {
 | 
				
			||||||
						} else {
 | 
											values.Values[idx] = append(values.Values[idx], vs[idx])
 | 
				
			||||||
							values.Values[idx] = append(values.Values[idx], vs[idx])
 | 
					 | 
				
			||||||
						}
 | 
					 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@ -349,7 +282,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
 | 
								for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
 | 
				
			||||||
				if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
 | 
									if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
 | 
				
			||||||
					if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
 | 
										if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
 | 
				
			||||||
						values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
 | 
											values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
 | 
				
			||||||
						values.Values[0] = append(values.Values[0], rvOfvalue)
 | 
											values.Values[0] = append(values.Values[0], rvOfvalue)
 | 
				
			||||||
@ -378,7 +311,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
 | 
				
			|||||||
									case schema.UnixNanosecond:
 | 
														case schema.UnixNanosecond:
 | 
				
			||||||
										assignment.Value = curTime.UnixNano()
 | 
															assignment.Value = curTime.UnixNano()
 | 
				
			||||||
									case schema.UnixMillisecond:
 | 
														case schema.UnixMillisecond:
 | 
				
			||||||
										assignment.Value = curTime.UnixMilli()
 | 
															assignment.Value = curTime.UnixNano() / 1e6
 | 
				
			||||||
									case schema.UnixSecond:
 | 
														case schema.UnixSecond:
 | 
				
			||||||
										assignment.Value = curTime.Unix()
 | 
															assignment.Value = curTime.Unix()
 | 
				
			||||||
									}
 | 
														}
 | 
				
			||||||
 | 
				
			|||||||
@ -1,71 +0,0 @@
 | 
				
			|||||||
package callbacks
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"reflect"
 | 
					 | 
				
			||||||
	"sync"
 | 
					 | 
				
			||||||
	"testing"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
					 | 
				
			||||||
	"gorm.io/gorm/clause"
 | 
					 | 
				
			||||||
	"gorm.io/gorm/schema"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
var schemaCache = &sync.Map{}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestConvertToCreateValues_DestType_Slice(t *testing.T) {
 | 
					 | 
				
			||||||
	type user struct {
 | 
					 | 
				
			||||||
		ID    int `gorm:"primaryKey"`
 | 
					 | 
				
			||||||
		Name  string
 | 
					 | 
				
			||||||
		Email string `gorm:"default:(-)"`
 | 
					 | 
				
			||||||
		Age   int    `gorm:"default:(-)"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("parse schema error: %v, is not expected", err)
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	dest := []*user{
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			ID:    1,
 | 
					 | 
				
			||||||
			Name:  "alice",
 | 
					 | 
				
			||||||
			Email: "email",
 | 
					 | 
				
			||||||
			Age:   18,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			ID:    2,
 | 
					 | 
				
			||||||
			Name:  "bob",
 | 
					 | 
				
			||||||
			Email: "email",
 | 
					 | 
				
			||||||
			Age:   19,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	stmt := &gorm.Statement{
 | 
					 | 
				
			||||||
		DB: &gorm.DB{
 | 
					 | 
				
			||||||
			Config: &gorm.Config{
 | 
					 | 
				
			||||||
				NowFunc: func() time.Time { return time.Time{} },
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
			Statement: &gorm.Statement{
 | 
					 | 
				
			||||||
				Settings: sync.Map{},
 | 
					 | 
				
			||||||
				Schema:   s,
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		ReflectValue: reflect.ValueOf(dest),
 | 
					 | 
				
			||||||
		Dest:         dest,
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	stmt.Schema = s
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	values := ConvertToCreateValues(stmt)
 | 
					 | 
				
			||||||
	expected := clause.Values{
 | 
					 | 
				
			||||||
		// column has value + defaultValue column has value (which should have a stable order)
 | 
					 | 
				
			||||||
		Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}},
 | 
					 | 
				
			||||||
		Values: [][]interface{}{
 | 
					 | 
				
			||||||
			{"alice", "email", 18, 1},
 | 
					 | 
				
			||||||
			{"bob", "email", 19, 2},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if !reflect.DeepEqual(expected, values) {
 | 
					 | 
				
			||||||
		t.Errorf("expected: %v got %v", expected, values)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@ -157,14 +157,8 @@ func Delete(config *Config) func(db *gorm.DB) {
 | 
				
			|||||||
			ok, mode := hasReturning(db, supportReturning)
 | 
								ok, mode := hasReturning(db, supportReturning)
 | 
				
			||||||
			if !ok {
 | 
								if !ok {
 | 
				
			||||||
				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | 
									result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | 
				
			||||||
 | 
					 | 
				
			||||||
				if db.AddError(err) == nil {
 | 
									if db.AddError(err) == nil {
 | 
				
			||||||
					db.RowsAffected, _ = result.RowsAffected()
 | 
										db.RowsAffected, _ = result.RowsAffected()
 | 
				
			||||||
 | 
					 | 
				
			||||||
					if db.Statement.Result != nil {
 | 
					 | 
				
			||||||
						db.Statement.Result.Result = result
 | 
					 | 
				
			||||||
						db.Statement.Result.RowsAffected = db.RowsAffected
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
@ -172,10 +166,6 @@ func Delete(config *Config) func(db *gorm.DB) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
 | 
								if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
 | 
				
			||||||
				gorm.Scan(rows, db, mode)
 | 
									gorm.Scan(rows, db, mode)
 | 
				
			||||||
 | 
					 | 
				
			||||||
				if db.Statement.Result != nil {
 | 
					 | 
				
			||||||
					db.Statement.Result.RowsAffected = db.RowsAffected
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				db.AddError(rows.Close())
 | 
									db.AddError(rows.Close())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
				
			|||||||
@ -1,157 +0,0 @@
 | 
				
			|||||||
package callbacks
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"reflect"
 | 
					 | 
				
			||||||
	"testing"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
					 | 
				
			||||||
	"gorm.io/gorm/clause"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestLoadOrStoreVisitMap(t *testing.T) {
 | 
					 | 
				
			||||||
	var vm visitMap
 | 
					 | 
				
			||||||
	var loaded bool
 | 
					 | 
				
			||||||
	type testM struct {
 | 
					 | 
				
			||||||
		Name string
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	t1 := testM{Name: "t1"}
 | 
					 | 
				
			||||||
	t2 := testM{Name: "t2"}
 | 
					 | 
				
			||||||
	t3 := testM{Name: "t3"}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	vm = make(visitMap)
 | 
					 | 
				
			||||||
	if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
 | 
					 | 
				
			||||||
		t.Fatalf("loaded should be false")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
 | 
					 | 
				
			||||||
		t.Fatalf("loaded should be true")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// t1 already exist but t2 not
 | 
					 | 
				
			||||||
	if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
 | 
					 | 
				
			||||||
		t.Fatalf("loaded should be false")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
 | 
					 | 
				
			||||||
		t.Fatalf("loaded should be true")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestConvertMapToValuesForCreate(t *testing.T) {
 | 
					 | 
				
			||||||
	testCase := []struct {
 | 
					 | 
				
			||||||
		name   string
 | 
					 | 
				
			||||||
		input  map[string]interface{}
 | 
					 | 
				
			||||||
		expect clause.Values
 | 
					 | 
				
			||||||
	}{
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name: "Test convert string value",
 | 
					 | 
				
			||||||
			input: map[string]interface{}{
 | 
					 | 
				
			||||||
				"name": "my name",
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
			expect: clause.Values{
 | 
					 | 
				
			||||||
				Columns: []clause.Column{{Name: "name"}},
 | 
					 | 
				
			||||||
				Values:  [][]interface{}{{"my name"}},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name: "Test convert int value",
 | 
					 | 
				
			||||||
			input: map[string]interface{}{
 | 
					 | 
				
			||||||
				"age": 18,
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
			expect: clause.Values{
 | 
					 | 
				
			||||||
				Columns: []clause.Column{{Name: "age"}},
 | 
					 | 
				
			||||||
				Values:  [][]interface{}{{18}},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name: "Test convert float value",
 | 
					 | 
				
			||||||
			input: map[string]interface{}{
 | 
					 | 
				
			||||||
				"score": 99.5,
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
			expect: clause.Values{
 | 
					 | 
				
			||||||
				Columns: []clause.Column{{Name: "score"}},
 | 
					 | 
				
			||||||
				Values:  [][]interface{}{{99.5}},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name: "Test convert bool value",
 | 
					 | 
				
			||||||
			input: map[string]interface{}{
 | 
					 | 
				
			||||||
				"active": true,
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
			expect: clause.Values{
 | 
					 | 
				
			||||||
				Columns: []clause.Column{{Name: "active"}},
 | 
					 | 
				
			||||||
				Values:  [][]interface{}{{true}},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, tc := range testCase {
 | 
					 | 
				
			||||||
		t.Run(tc.name, func(t *testing.T) {
 | 
					 | 
				
			||||||
			actual := ConvertMapToValuesForCreate(&gorm.Statement{}, tc.input)
 | 
					 | 
				
			||||||
			if !reflect.DeepEqual(actual, tc.expect) {
 | 
					 | 
				
			||||||
				t.Errorf("expect %v got %v", tc.expect, actual)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestConvertSliceOfMapToValuesForCreate(t *testing.T) {
 | 
					 | 
				
			||||||
	testCase := []struct {
 | 
					 | 
				
			||||||
		name   string
 | 
					 | 
				
			||||||
		input  []map[string]interface{}
 | 
					 | 
				
			||||||
		expect clause.Values
 | 
					 | 
				
			||||||
	}{
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name: "Test convert slice of string value",
 | 
					 | 
				
			||||||
			input: []map[string]interface{}{
 | 
					 | 
				
			||||||
				{"name": "my name"},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
			expect: clause.Values{
 | 
					 | 
				
			||||||
				Columns: []clause.Column{{Name: "name"}},
 | 
					 | 
				
			||||||
				Values:  [][]interface{}{{"my name"}},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name: "Test convert slice of int value",
 | 
					 | 
				
			||||||
			input: []map[string]interface{}{
 | 
					 | 
				
			||||||
				{"age": 18},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
			expect: clause.Values{
 | 
					 | 
				
			||||||
				Columns: []clause.Column{{Name: "age"}},
 | 
					 | 
				
			||||||
				Values:  [][]interface{}{{18}},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name: "Test convert slice of float value",
 | 
					 | 
				
			||||||
			input: []map[string]interface{}{
 | 
					 | 
				
			||||||
				{"score": 99.5},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
			expect: clause.Values{
 | 
					 | 
				
			||||||
				Columns: []clause.Column{{Name: "score"}},
 | 
					 | 
				
			||||||
				Values:  [][]interface{}{{99.5}},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name: "Test convert slice of bool value",
 | 
					 | 
				
			||||||
			input: []map[string]interface{}{
 | 
					 | 
				
			||||||
				{"active": true},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
			expect: clause.Values{
 | 
					 | 
				
			||||||
				Columns: []clause.Column{{Name: "active"}},
 | 
					 | 
				
			||||||
				Values:  [][]interface{}{{true}},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, tc := range testCase {
 | 
					 | 
				
			||||||
		t.Run(tc.name, func(t *testing.T) {
 | 
					 | 
				
			||||||
			actual := ConvertSliceOfMapToValuesForCreate(&gorm.Statement{}, tc.input)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if !reflect.DeepEqual(actual, tc.expect) {
 | 
					 | 
				
			||||||
				t.Errorf("expected %v but got %v", tc.expect, actual)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@ -3,7 +3,6 @@ package callbacks
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"sort"
 | 
					 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
@ -75,7 +74,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
 | 
				
			|||||||
	names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
 | 
						names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
 | 
				
			||||||
	for _, relation := range embeddedRelations.Relations {
 | 
						for _, relation := range embeddedRelations.Relations {
 | 
				
			||||||
		// skip first struct name
 | 
							// skip first struct name
 | 
				
			||||||
		names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
 | 
							names = append(names, strings.Join(relation.Field.BindNames[1:], "."))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	for _, relations := range embeddedRelations.EmbeddedRelations {
 | 
						for _, relations := range embeddedRelations.EmbeddedRelations {
 | 
				
			||||||
		names = append(names, embeddedValues(relations)...)
 | 
							names = append(names, embeddedValues(relations)...)
 | 
				
			||||||
@ -83,105 +82,27 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
 | 
				
			|||||||
	return names
 | 
						return names
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
 | 
					func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error {
 | 
				
			||||||
// If the current relationship is embedded or joined, current query will be ignored.
 | 
						if relationships == nil {
 | 
				
			||||||
//
 | 
							return nil
 | 
				
			||||||
//nolint:cyclop
 | 
					 | 
				
			||||||
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
 | 
					 | 
				
			||||||
	preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// avoid random traversal of the map
 | 
					 | 
				
			||||||
	preloadNames := make([]string, 0, len(preloadMap))
 | 
					 | 
				
			||||||
	for key := range preloadMap {
 | 
					 | 
				
			||||||
		preloadNames = append(preloadNames, key)
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	sort.Strings(preloadNames)
 | 
						preloadMap := parsePreloadMap(s, preloads)
 | 
				
			||||||
 | 
						for name := range preloadMap {
 | 
				
			||||||
	isJoined := func(name string) (joined bool, nestedJoins []string) {
 | 
							if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil {
 | 
				
			||||||
		for _, join := range joins {
 | 
								if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil {
 | 
				
			||||||
			if _, ok := relationships.Relations[join]; ok && name == join {
 | 
					 | 
				
			||||||
				joined = true
 | 
					 | 
				
			||||||
				continue
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			join0, join1, cut := strings.Cut(join, ".")
 | 
					 | 
				
			||||||
			if cut {
 | 
					 | 
				
			||||||
				if _, ok := relationships.Relations[join0]; ok && name == join0 {
 | 
					 | 
				
			||||||
					joined = true
 | 
					 | 
				
			||||||
					nestedJoins = append(nestedJoins, join1)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return joined, nestedJoins
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, name := range preloadNames {
 | 
					 | 
				
			||||||
		if relations := relationships.EmbeddedRelations[name]; relations != nil {
 | 
					 | 
				
			||||||
			if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
 | 
					 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		} else if rel := relationships.Relations[name]; rel != nil {
 | 
							} else if rel := relationships.Relations[name]; rel != nil {
 | 
				
			||||||
			if joined, nestedJoins := isJoined(name); joined {
 | 
								if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil {
 | 
				
			||||||
				switch rv := db.Statement.ReflectValue; rv.Kind() {
 | 
									return err
 | 
				
			||||||
				case reflect.Slice, reflect.Array:
 | 
					 | 
				
			||||||
					if rv.Len() > 0 {
 | 
					 | 
				
			||||||
						reflectValue := rel.FieldSchema.MakeSlice().Elem()
 | 
					 | 
				
			||||||
						for i := 0; i < rv.Len(); i++ {
 | 
					 | 
				
			||||||
							frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
 | 
					 | 
				
			||||||
							if frv.Kind() != reflect.Ptr {
 | 
					 | 
				
			||||||
								reflectValue = reflect.Append(reflectValue, frv.Addr())
 | 
					 | 
				
			||||||
							} else {
 | 
					 | 
				
			||||||
								if frv.IsNil() {
 | 
					 | 
				
			||||||
									continue
 | 
					 | 
				
			||||||
								}
 | 
					 | 
				
			||||||
								reflectValue = reflect.Append(reflectValue, frv)
 | 
					 | 
				
			||||||
							}
 | 
					 | 
				
			||||||
						}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
						tx := preloadDB(db, reflectValue, reflectValue.Interface())
 | 
					 | 
				
			||||||
						if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
 | 
					 | 
				
			||||||
							return err
 | 
					 | 
				
			||||||
						}
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				case reflect.Struct, reflect.Pointer:
 | 
					 | 
				
			||||||
					reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
 | 
					 | 
				
			||||||
					tx := preloadDB(db, reflectValue, reflectValue.Interface())
 | 
					 | 
				
			||||||
					if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
 | 
					 | 
				
			||||||
						return err
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				default:
 | 
					 | 
				
			||||||
					return gorm.ErrInvalidData
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
 | 
					 | 
				
			||||||
				tx.Statement.ReflectValue = db.Statement.ReflectValue
 | 
					 | 
				
			||||||
				tx.Statement.Unscoped = db.Statement.Unscoped
 | 
					 | 
				
			||||||
				if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
 | 
					 | 
				
			||||||
					return err
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
 | 
								return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
 | 
					 | 
				
			||||||
	tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
 | 
					 | 
				
			||||||
	db.Statement.Settings.Range(func(k, v interface{}) bool {
 | 
					 | 
				
			||||||
		tx.Statement.Settings.Store(k, v)
 | 
					 | 
				
			||||||
		return true
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := tx.Statement.Parse(dest); err != nil {
 | 
					 | 
				
			||||||
		tx.AddError(err)
 | 
					 | 
				
			||||||
		return tx
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	tx.Statement.ReflectValue = reflectValue
 | 
					 | 
				
			||||||
	tx.Statement.Unscoped = db.Statement.Unscoped
 | 
					 | 
				
			||||||
	return tx
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
 | 
					func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		reflectValue     = tx.Statement.ReflectValue
 | 
							reflectValue     = tx.Statement.ReflectValue
 | 
				
			||||||
@ -275,8 +196,6 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
 | 
				
			|||||||
	column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
 | 
						column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(values) != 0 {
 | 
						if len(values) != 0 {
 | 
				
			||||||
		tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		for _, cond := range conds {
 | 
							for _, cond := range conds {
 | 
				
			||||||
			if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
 | 
								if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
 | 
				
			||||||
				tx = fc(tx)
 | 
									tx = fc(tx)
 | 
				
			||||||
@ -285,11 +204,7 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if len(inlineConds) > 0 {
 | 
							if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
 | 
				
			||||||
			tx = tx.Where(inlineConds[0], inlineConds[1:]...)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
 | 
					 | 
				
			||||||
			return err
 | 
								return err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -3,6 +3,7 @@ package callbacks
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
 | 
						"sort"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
@ -25,10 +26,6 @@ func Query(db *gorm.DB) {
 | 
				
			|||||||
				db.AddError(rows.Close())
 | 
									db.AddError(rows.Close())
 | 
				
			||||||
			}()
 | 
								}()
 | 
				
			||||||
			gorm.Scan(rows, db, 0)
 | 
								gorm.Scan(rows, db, 0)
 | 
				
			||||||
 | 
					 | 
				
			||||||
			if db.Statement.Result != nil {
 | 
					 | 
				
			||||||
				db.Statement.Result.RowsAffected = db.RowsAffected
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -114,7 +111,7 @@ func BuildQuerySQL(db *gorm.DB) {
 | 
				
			|||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
 | 
								specifiedRelationsName := make(map[string]interface{})
 | 
				
			||||||
			for _, join := range db.Statement.Joins {
 | 
								for _, join := range db.Statement.Joins {
 | 
				
			||||||
				if db.Statement.Schema != nil {
 | 
									if db.Statement.Schema != nil {
 | 
				
			||||||
					var isRelations bool // is relations or raw sql
 | 
										var isRelations bool // is relations or raw sql
 | 
				
			||||||
@ -128,12 +125,12 @@ func BuildQuerySQL(db *gorm.DB) {
 | 
				
			|||||||
						nestedJoinNames := strings.Split(join.Name, ".")
 | 
											nestedJoinNames := strings.Split(join.Name, ".")
 | 
				
			||||||
						if len(nestedJoinNames) > 1 {
 | 
											if len(nestedJoinNames) > 1 {
 | 
				
			||||||
							isNestedJoin := true
 | 
												isNestedJoin := true
 | 
				
			||||||
							guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
 | 
												gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
 | 
				
			||||||
							currentRelations := db.Statement.Schema.Relationships.Relations
 | 
												currentRelations := db.Statement.Schema.Relationships.Relations
 | 
				
			||||||
							for _, relname := range nestedJoinNames {
 | 
												for _, relname := range nestedJoinNames {
 | 
				
			||||||
								// incomplete match, only treated as raw sql
 | 
													// incomplete match, only treated as raw sql
 | 
				
			||||||
								if relation, ok = currentRelations[relname]; ok {
 | 
													if relation, ok = currentRelations[relname]; ok {
 | 
				
			||||||
									guessNestedRelations = append(guessNestedRelations, relation)
 | 
														gussNestedRelations = append(gussNestedRelations, relation)
 | 
				
			||||||
									currentRelations = relation.FieldSchema.Relationships.Relations
 | 
														currentRelations = relation.FieldSchema.Relationships.Relations
 | 
				
			||||||
								} else {
 | 
													} else {
 | 
				
			||||||
									isNestedJoin = false
 | 
														isNestedJoin = false
 | 
				
			||||||
@ -143,13 +140,18 @@ func BuildQuerySQL(db *gorm.DB) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
							if isNestedJoin {
 | 
												if isNestedJoin {
 | 
				
			||||||
								isRelations = true
 | 
													isRelations = true
 | 
				
			||||||
								relations = guessNestedRelations
 | 
													relations = gussNestedRelations
 | 
				
			||||||
							}
 | 
												}
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					if isRelations {
 | 
										if isRelations {
 | 
				
			||||||
						genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
 | 
											genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
 | 
				
			||||||
 | 
												tableAliasName := relation.Name
 | 
				
			||||||
 | 
												if parentTableName != clause.CurrentTable {
 | 
				
			||||||
 | 
													tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							columnStmt := gorm.Statement{
 | 
												columnStmt := gorm.Statement{
 | 
				
			||||||
								Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
 | 
													Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
 | 
				
			||||||
								Selects: join.Selects, Omits: join.Omits,
 | 
													Selects: join.Selects, Omits: join.Omits,
 | 
				
			||||||
@ -166,13 +168,6 @@ func BuildQuerySQL(db *gorm.DB) {
 | 
				
			|||||||
								}
 | 
													}
 | 
				
			||||||
							}
 | 
												}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							if join.Expression != nil {
 | 
					 | 
				
			||||||
								return clause.Join{
 | 
					 | 
				
			||||||
									Type:       join.JoinType,
 | 
					 | 
				
			||||||
									Expression: join.Expression,
 | 
					 | 
				
			||||||
								}
 | 
					 | 
				
			||||||
							}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
							exprs := make([]clause.Expression, len(relation.References))
 | 
												exprs := make([]clause.Expression, len(relation.References))
 | 
				
			||||||
							for idx, ref := range relation.References {
 | 
												for idx, ref := range relation.References {
 | 
				
			||||||
								if ref.OwnPrimaryKey {
 | 
													if ref.OwnPrimaryKey {
 | 
				
			||||||
@ -232,24 +227,19 @@ func BuildQuerySQL(db *gorm.DB) {
 | 
				
			|||||||
						}
 | 
											}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
						parentTableName := clause.CurrentTable
 | 
											parentTableName := clause.CurrentTable
 | 
				
			||||||
						for idx, rel := range relations {
 | 
											for _, rel := range relations {
 | 
				
			||||||
							// joins table alias like "Manager, Company, Manager__Company"
 | 
												// joins table alias like "Manager, Company, Manager__Company"
 | 
				
			||||||
							curAliasName := rel.Name
 | 
												nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
 | 
				
			||||||
 | 
												if _, ok := specifiedRelationsName[nestedAlias]; !ok {
 | 
				
			||||||
 | 
													fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
 | 
				
			||||||
 | 
													specifiedRelationsName[nestedAlias] = nil
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							if parentTableName != clause.CurrentTable {
 | 
												if parentTableName != clause.CurrentTable {
 | 
				
			||||||
								curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
 | 
													parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
 | 
				
			||||||
 | 
												} else {
 | 
				
			||||||
 | 
													parentTableName = rel.Name
 | 
				
			||||||
							}
 | 
												}
 | 
				
			||||||
 | 
					 | 
				
			||||||
							if _, ok := specifiedRelationsName[curAliasName]; !ok {
 | 
					 | 
				
			||||||
								aliasName := curAliasName
 | 
					 | 
				
			||||||
								if idx == len(relations)-1 && join.Alias != "" {
 | 
					 | 
				
			||||||
									aliasName = join.Alias
 | 
					 | 
				
			||||||
								}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
								fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
 | 
					 | 
				
			||||||
								specifiedRelationsName[curAliasName] = aliasName
 | 
					 | 
				
			||||||
							}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
							parentTableName = curAliasName
 | 
					 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					} else {
 | 
										} else {
 | 
				
			||||||
						fromClause.Joins = append(fromClause.Joins, clause.Join{
 | 
											fromClause.Joins = append(fromClause.Joins, clause.Join{
 | 
				
			||||||
@ -264,6 +254,7 @@ func BuildQuerySQL(db *gorm.DB) {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			db.Statement.AddClause(fromClause)
 | 
								db.Statement.AddClause(fromClause)
 | 
				
			||||||
 | 
								db.Statement.Joins = nil
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			db.Statement.AddClauseIfNotExists(clause.From{})
 | 
								db.Statement.AddClauseIfNotExists(clause.From{})
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -281,27 +272,38 @@ func Preload(db *gorm.DB) {
 | 
				
			|||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		joins := make([]string, 0, len(db.Statement.Joins))
 | 
							preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads)
 | 
				
			||||||
		for _, join := range db.Statement.Joins {
 | 
							preloadNames := make([]string, 0, len(preloadMap))
 | 
				
			||||||
			joins = append(joins, join.Name)
 | 
							for key := range preloadMap {
 | 
				
			||||||
 | 
								preloadNames = append(preloadNames, key)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							sort.Strings(preloadNames)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
 | 
							preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
 | 
				
			||||||
		if tx.Error != nil {
 | 
							db.Statement.Settings.Range(func(k, v interface{}) bool {
 | 
				
			||||||
 | 
								preloadDB.Statement.Settings.Store(k, v)
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil {
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							preloadDB.Statement.ReflectValue = db.Statement.ReflectValue
 | 
				
			||||||
 | 
							preloadDB.Statement.Unscoped = db.Statement.Unscoped
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
 | 
							for _, name := range preloadNames {
 | 
				
			||||||
 | 
								if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil {
 | 
				
			||||||
 | 
									db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations]))
 | 
				
			||||||
 | 
								} else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
 | 
				
			||||||
 | 
									db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func AfterQuery(db *gorm.DB) {
 | 
					func AfterQuery(db *gorm.DB) {
 | 
				
			||||||
	// clear the joins after query because preload need it
 | 
					 | 
				
			||||||
	if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
 | 
					 | 
				
			||||||
		fromClause := db.Statement.Clauses["FROM"]
 | 
					 | 
				
			||||||
		fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins
 | 
					 | 
				
			||||||
		db.Statement.Clauses["FROM"] = fromClause
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
 | 
						if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
 | 
				
			||||||
		callMethod(db, func(value interface{}, tx *gorm.DB) bool {
 | 
							callMethod(db, func(value interface{}, tx *gorm.DB) bool {
 | 
				
			||||||
			if i, ok := value.(AfterFindInterface); ok {
 | 
								if i, ok := value.(AfterFindInterface); ok {
 | 
				
			||||||
 | 
				
			|||||||
@ -13,10 +13,5 @@ func RawExec(db *gorm.DB) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		db.RowsAffected, _ = result.RowsAffected()
 | 
							db.RowsAffected, _ = result.RowsAffected()
 | 
				
			||||||
 | 
					 | 
				
			||||||
		if db.Statement.Result != nil {
 | 
					 | 
				
			||||||
			db.Statement.Result.Result = result
 | 
					 | 
				
			||||||
			db.Statement.Result.RowsAffected = db.RowsAffected
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -92,10 +92,6 @@ func Update(config *Config) func(db *gorm.DB) {
 | 
				
			|||||||
					gorm.Scan(rows, db, mode)
 | 
										gorm.Scan(rows, db, mode)
 | 
				
			||||||
					db.Statement.Dest = dest
 | 
										db.Statement.Dest = dest
 | 
				
			||||||
					db.AddError(rows.Close())
 | 
										db.AddError(rows.Close())
 | 
				
			||||||
 | 
					 | 
				
			||||||
					if db.Statement.Result != nil {
 | 
					 | 
				
			||||||
						db.Statement.Result.RowsAffected = db.RowsAffected
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | 
									result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | 
				
			||||||
@ -103,11 +99,6 @@ func Update(config *Config) func(db *gorm.DB) {
 | 
				
			|||||||
				if db.AddError(err) == nil {
 | 
									if db.AddError(err) == nil {
 | 
				
			||||||
					db.RowsAffected, _ = result.RowsAffected()
 | 
										db.RowsAffected, _ = result.RowsAffected()
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					 | 
				
			||||||
				if db.Statement.Result != nil {
 | 
					 | 
				
			||||||
					db.Statement.Result.Result = result
 | 
					 | 
				
			||||||
					db.Statement.Result.RowsAffected = db.RowsAffected
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -243,7 +234,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
 | 
				
			|||||||
						if field.AutoUpdateTime == schema.UnixNanosecond {
 | 
											if field.AutoUpdateTime == schema.UnixNanosecond {
 | 
				
			||||||
							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
 | 
												set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
 | 
				
			||||||
						} else if field.AutoUpdateTime == schema.UnixMillisecond {
 | 
											} else if field.AutoUpdateTime == schema.UnixMillisecond {
 | 
				
			||||||
							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
 | 
												set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
 | 
				
			||||||
						} else if field.AutoUpdateTime == schema.UnixSecond {
 | 
											} else if field.AutoUpdateTime == schema.UnixSecond {
 | 
				
			||||||
							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
 | 
												set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
 | 
				
			||||||
						} else {
 | 
											} else {
 | 
				
			||||||
@ -277,7 +268,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
 | 
				
			|||||||
								if field.AutoUpdateTime == schema.UnixNanosecond {
 | 
													if field.AutoUpdateTime == schema.UnixNanosecond {
 | 
				
			||||||
									value = stmt.DB.NowFunc().UnixNano()
 | 
														value = stmt.DB.NowFunc().UnixNano()
 | 
				
			||||||
								} else if field.AutoUpdateTime == schema.UnixMillisecond {
 | 
													} else if field.AutoUpdateTime == schema.UnixMillisecond {
 | 
				
			||||||
									value = stmt.DB.NowFunc().UnixMilli()
 | 
														value = stmt.DB.NowFunc().UnixNano() / 1e6
 | 
				
			||||||
								} else if field.AutoUpdateTime == schema.UnixSecond {
 | 
													} else if field.AutoUpdateTime == schema.UnixSecond {
 | 
				
			||||||
									value = stmt.DB.NowFunc().Unix()
 | 
														value = stmt.DB.NowFunc().Unix()
 | 
				
			||||||
								} else {
 | 
													} else {
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										36
									
								
								callbacks/visit_map_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								callbacks/visit_map_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,36 @@
 | 
				
			|||||||
 | 
					package callbacks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestLoadOrStoreVisitMap(t *testing.T) {
 | 
				
			||||||
 | 
						var vm visitMap
 | 
				
			||||||
 | 
						var loaded bool
 | 
				
			||||||
 | 
						type testM struct {
 | 
				
			||||||
 | 
							Name string
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t1 := testM{Name: "t1"}
 | 
				
			||||||
 | 
						t2 := testM{Name: "t2"}
 | 
				
			||||||
 | 
						t3 := testM{Name: "t3"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						vm = make(visitMap)
 | 
				
			||||||
 | 
						if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
 | 
				
			||||||
 | 
							t.Fatalf("loaded should be false")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
 | 
				
			||||||
 | 
							t.Fatalf("loaded should be true")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// t1 already exist but t2 not
 | 
				
			||||||
 | 
						if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
 | 
				
			||||||
 | 
							t.Fatalf("loaded should be false")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
 | 
				
			||||||
 | 
							t.Fatalf("loaded should be true")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -185,13 +185,6 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// MapColumns modify the column names in the query results to facilitate align to the corresponding structural fields
 | 
					 | 
				
			||||||
func (db *DB) MapColumns(m map[string]string) (tx *DB) {
 | 
					 | 
				
			||||||
	tx = db.getInstance()
 | 
					 | 
				
			||||||
	tx.Statement.ColumnMapping = m
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Where add conditions
 | 
					// Where add conditions
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
 | 
					// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
 | 
				
			||||||
@ -306,16 +299,10 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
 | 
				
			|||||||
//
 | 
					//
 | 
				
			||||||
//	db.Order("name DESC")
 | 
					//	db.Order("name DESC")
 | 
				
			||||||
//	db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
 | 
					//	db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
 | 
				
			||||||
//	db.Order(clause.OrderBy{Columns: []clause.OrderByColumn{
 | 
					 | 
				
			||||||
//		{Column: clause.Column{Name: "name"}, Desc: true},
 | 
					 | 
				
			||||||
//		{Column: clause.Column{Name: "age"}, Desc: true},
 | 
					 | 
				
			||||||
//	}})
 | 
					 | 
				
			||||||
func (db *DB) Order(value interface{}) (tx *DB) {
 | 
					func (db *DB) Order(value interface{}) (tx *DB) {
 | 
				
			||||||
	tx = db.getInstance()
 | 
						tx = db.getInstance()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	switch v := value.(type) {
 | 
						switch v := value.(type) {
 | 
				
			||||||
	case clause.OrderBy:
 | 
					 | 
				
			||||||
		tx.Statement.AddClause(v)
 | 
					 | 
				
			||||||
	case clause.OrderByColumn:
 | 
						case clause.OrderByColumn:
 | 
				
			||||||
		tx.Statement.AddClause(clause.OrderBy{
 | 
							tx.Statement.AddClause(clause.OrderBy{
 | 
				
			||||||
			Columns: []clause.OrderByColumn{v},
 | 
								Columns: []clause.OrderByColumn{v},
 | 
				
			||||||
@ -380,12 +367,33 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (db *DB) executeScopes() (tx *DB) {
 | 
					func (db *DB) executeScopes() (tx *DB) {
 | 
				
			||||||
 | 
						tx = db.getInstance()
 | 
				
			||||||
	scopes := db.Statement.scopes
 | 
						scopes := db.Statement.scopes
 | 
				
			||||||
	db.Statement.scopes = nil
 | 
						if len(scopes) == 0 {
 | 
				
			||||||
	for _, scope := range scopes {
 | 
							return tx
 | 
				
			||||||
		db = scope(db)
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return db
 | 
						tx.Statement.scopes = nil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						conditions := make([]clause.Interface, 0, 4)
 | 
				
			||||||
 | 
						if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
 | 
				
			||||||
 | 
							conditions = append(conditions, cs.Expression.(clause.Interface))
 | 
				
			||||||
 | 
							cs.Expression = nil
 | 
				
			||||||
 | 
							tx.Statement.Clauses["WHERE"] = cs
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, scope := range scopes {
 | 
				
			||||||
 | 
							tx = scope(tx)
 | 
				
			||||||
 | 
							if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
 | 
				
			||||||
 | 
								conditions = append(conditions, cs.Expression.(clause.Interface))
 | 
				
			||||||
 | 
								cs.Expression = nil
 | 
				
			||||||
 | 
								tx.Statement.Clauses["WHERE"] = cs
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, condition := range conditions {
 | 
				
			||||||
 | 
							tx.Statement.AddClause(condition)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return tx
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Preload preload associations with given conditions
 | 
					// Preload preload associations with given conditions
 | 
				
			||||||
@ -442,16 +450,6 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Unscoped disables the global scope of soft deletion in a query.
 | 
					 | 
				
			||||||
// By default, GORM uses soft deletion, marking records as "deleted"
 | 
					 | 
				
			||||||
// by setting a timestamp on a specific field (e.g., `deleted_at`).
 | 
					 | 
				
			||||||
// Unscoped allows queries to include records marked as deleted,
 | 
					 | 
				
			||||||
// overriding the soft deletion behavior.
 | 
					 | 
				
			||||||
// Example:
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
//	var users []User
 | 
					 | 
				
			||||||
//	db.Unscoped().Find(&users)
 | 
					 | 
				
			||||||
//	// Retrieves all users, including deleted ones.
 | 
					 | 
				
			||||||
func (db *DB) Unscoped() (tx *DB) {
 | 
					func (db *DB) Unscoped() (tx *DB) {
 | 
				
			||||||
	tx = db.getInstance()
 | 
						tx = db.getInstance()
 | 
				
			||||||
	tx.Statement.Unscoped = true
 | 
						tx.Statement.Unscoped = true
 | 
				
			||||||
 | 
				
			|||||||
@ -126,7 +126,7 @@ func (expr NamedExpr) Build(builder Builder) {
 | 
				
			|||||||
	for _, v := range []byte(expr.SQL) {
 | 
						for _, v := range []byte(expr.SQL) {
 | 
				
			||||||
		if v == '@' && !inName {
 | 
							if v == '@' && !inName {
 | 
				
			||||||
			inName = true
 | 
								inName = true
 | 
				
			||||||
			name = name[:0]
 | 
								name = []byte{}
 | 
				
			||||||
		} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
 | 
							} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
 | 
				
			||||||
			if inName {
 | 
								if inName {
 | 
				
			||||||
				if nv, ok := namedMap[string(name)]; ok {
 | 
									if nv, ok := namedMap[string(name)]; ok {
 | 
				
			||||||
 | 
				
			|||||||
@ -1,7 +1,5 @@
 | 
				
			|||||||
package clause
 | 
					package clause
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import "gorm.io/gorm/utils"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type JoinType string
 | 
					type JoinType string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
@ -11,30 +9,6 @@ const (
 | 
				
			|||||||
	RightJoin JoinType = "RIGHT"
 | 
						RightJoin JoinType = "RIGHT"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type JoinTarget struct {
 | 
					 | 
				
			||||||
	Type        JoinType
 | 
					 | 
				
			||||||
	Association string
 | 
					 | 
				
			||||||
	Subquery    Expression
 | 
					 | 
				
			||||||
	Table       string
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func Has(name string) JoinTarget {
 | 
					 | 
				
			||||||
	return JoinTarget{Type: InnerJoin, Association: name}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (jt JoinType) Association(name string) JoinTarget {
 | 
					 | 
				
			||||||
	return JoinTarget{Type: jt, Association: name}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
 | 
					 | 
				
			||||||
	return JoinTarget{Type: jt, Association: name, Subquery: subquery}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (jt JoinTarget) As(name string) JoinTarget {
 | 
					 | 
				
			||||||
	jt.Table = name
 | 
					 | 
				
			||||||
	return jt
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Join clause for from
 | 
					// Join clause for from
 | 
				
			||||||
type Join struct {
 | 
					type Join struct {
 | 
				
			||||||
	Type       JoinType
 | 
						Type       JoinType
 | 
				
			||||||
@ -44,12 +18,6 @@ type Join struct {
 | 
				
			|||||||
	Expression Expression
 | 
						Expression Expression
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func JoinTable(names ...string) Table {
 | 
					 | 
				
			||||||
	return Table{
 | 
					 | 
				
			||||||
		Name: utils.JoinNestedRelationNames(names),
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (join Join) Build(builder Builder) {
 | 
					func (join Join) Build(builder Builder) {
 | 
				
			||||||
	if join.Expression != nil {
 | 
						if join.Expression != nil {
 | 
				
			||||||
		join.Expression.Build(builder)
 | 
							join.Expression.Build(builder)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,5 +1,7 @@
 | 
				
			|||||||
package clause
 | 
					package clause
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "strconv"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Limit limit clause
 | 
					// Limit limit clause
 | 
				
			||||||
type Limit struct {
 | 
					type Limit struct {
 | 
				
			||||||
	Limit  *int
 | 
						Limit  *int
 | 
				
			||||||
@ -15,14 +17,14 @@ func (limit Limit) Name() string {
 | 
				
			|||||||
func (limit Limit) Build(builder Builder) {
 | 
					func (limit Limit) Build(builder Builder) {
 | 
				
			||||||
	if limit.Limit != nil && *limit.Limit >= 0 {
 | 
						if limit.Limit != nil && *limit.Limit >= 0 {
 | 
				
			||||||
		builder.WriteString("LIMIT ")
 | 
							builder.WriteString("LIMIT ")
 | 
				
			||||||
		builder.AddVar(builder, *limit.Limit)
 | 
							builder.WriteString(strconv.Itoa(*limit.Limit))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if limit.Offset > 0 {
 | 
						if limit.Offset > 0 {
 | 
				
			||||||
		if limit.Limit != nil && *limit.Limit >= 0 {
 | 
							if limit.Limit != nil && *limit.Limit >= 0 {
 | 
				
			||||||
			builder.WriteByte(' ')
 | 
								builder.WriteByte(' ')
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		builder.WriteString("OFFSET ")
 | 
							builder.WriteString("OFFSET ")
 | 
				
			||||||
		builder.AddVar(builder, limit.Offset)
 | 
							builder.WriteString(strconv.Itoa(limit.Offset))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -22,53 +22,43 @@ func TestLimit(t *testing.T) {
 | 
				
			|||||||
				Limit:  &limit10,
 | 
									Limit:  &limit10,
 | 
				
			||||||
				Offset: 20,
 | 
									Offset: 20,
 | 
				
			||||||
			}},
 | 
								}},
 | 
				
			||||||
			"SELECT * FROM `users` LIMIT ? OFFSET ?",
 | 
								"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
 | 
				
			||||||
			[]interface{}{limit10, 20},
 | 
					 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
 | 
								[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
 | 
				
			||||||
			"SELECT * FROM `users` LIMIT ?",
 | 
								"SELECT * FROM `users` LIMIT 0", nil,
 | 
				
			||||||
			[]interface{}{limit0},
 | 
					 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}},
 | 
								[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}},
 | 
				
			||||||
			"SELECT * FROM `users` LIMIT ?",
 | 
								"SELECT * FROM `users` LIMIT 0", nil,
 | 
				
			||||||
			[]interface{}{limit0},
 | 
					 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
 | 
								[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
 | 
				
			||||||
			"SELECT * FROM `users` OFFSET ?",
 | 
								"SELECT * FROM `users` OFFSET 20", nil,
 | 
				
			||||||
			[]interface{}{20},
 | 
					 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}},
 | 
								[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}},
 | 
				
			||||||
			"SELECT * FROM `users` OFFSET ?",
 | 
								"SELECT * FROM `users` OFFSET 30", nil,
 | 
				
			||||||
			[]interface{}{30},
 | 
					 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}},
 | 
								[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}},
 | 
				
			||||||
			"SELECT * FROM `users` LIMIT ? OFFSET ?",
 | 
								"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
 | 
				
			||||||
			[]interface{}{limit10, 20},
 | 
					 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}},
 | 
								[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}},
 | 
				
			||||||
			"SELECT * FROM `users` LIMIT ? OFFSET ?",
 | 
								"SELECT * FROM `users` LIMIT 10 OFFSET 30", nil,
 | 
				
			||||||
			[]interface{}{limit10, 30},
 | 
					 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
 | 
								[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
 | 
				
			||||||
			"SELECT * FROM `users` LIMIT ?",
 | 
								"SELECT * FROM `users` LIMIT 10", nil,
 | 
				
			||||||
			[]interface{}{limit10},
 | 
					 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}},
 | 
								[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}},
 | 
				
			||||||
			"SELECT * FROM `users` OFFSET ?",
 | 
								"SELECT * FROM `users` OFFSET 30", nil,
 | 
				
			||||||
			[]interface{}{30},
 | 
					 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}},
 | 
								[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}},
 | 
				
			||||||
			"SELECT * FROM `users` LIMIT ? OFFSET ?",
 | 
								"SELECT * FROM `users` LIMIT 50 OFFSET 30", nil,
 | 
				
			||||||
			[]interface{}{limit50, 30},
 | 
					 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -1,12 +1,5 @@
 | 
				
			|||||||
package clause
 | 
					package clause
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					 | 
				
			||||||
	LockingStrengthUpdate    = "UPDATE"
 | 
					 | 
				
			||||||
	LockingStrengthShare     = "SHARE"
 | 
					 | 
				
			||||||
	LockingOptionsSkipLocked = "SKIP LOCKED"
 | 
					 | 
				
			||||||
	LockingOptionsNoWait     = "NOWAIT"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type Locking struct {
 | 
					type Locking struct {
 | 
				
			||||||
	Strength string
 | 
						Strength string
 | 
				
			||||||
	Table    Table
 | 
						Table    Table
 | 
				
			||||||
 | 
				
			|||||||
@ -14,21 +14,17 @@ func TestLocking(t *testing.T) {
 | 
				
			|||||||
		Vars    []interface{}
 | 
							Vars    []interface{}
 | 
				
			||||||
	}{
 | 
						}{
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate}},
 | 
								[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}},
 | 
				
			||||||
			"SELECT * FROM `users` FOR UPDATE", nil,
 | 
								"SELECT * FROM `users` FOR UPDATE", nil,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthShare, Table: clause.Table{Name: clause.CurrentTable}}},
 | 
								[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}},
 | 
				
			||||||
			"SELECT * FROM `users` FOR SHARE OF `users`", nil,
 | 
								"SELECT * FROM `users` FOR SHARE OF `users`", nil,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsNoWait}},
 | 
								[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}},
 | 
				
			||||||
			"SELECT * FROM `users` FOR UPDATE NOWAIT", nil,
 | 
								"SELECT * FROM `users` FOR UPDATE NOWAIT", nil,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsSkipLocked}},
 | 
					 | 
				
			||||||
			"SELECT * FROM `users` FOR UPDATE SKIP LOCKED", nil,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for idx, result := range results {
 | 
						for idx, result := range results {
 | 
				
			||||||
 | 
				
			|||||||
@ -26,12 +26,9 @@ func (returning Returning) Build(builder Builder) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// MergeClause merge order by clauses
 | 
					// MergeClause merge order by clauses
 | 
				
			||||||
func (returning Returning) MergeClause(clause *Clause) {
 | 
					func (returning Returning) MergeClause(clause *Clause) {
 | 
				
			||||||
	if v, ok := clause.Expression.(Returning); ok && len(returning.Columns) > 0 {
 | 
						if v, ok := clause.Expression.(Returning); ok {
 | 
				
			||||||
		if v.Columns != nil {
 | 
							returning.Columns = append(v.Columns, returning.Columns...)
 | 
				
			||||||
			returning.Columns = append(v.Columns, returning.Columns...)
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			returning.Columns = nil
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	clause.Expression = returning
 | 
						clause.Expression = returning
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -26,22 +26,6 @@ func TestReturning(t *testing.T) {
 | 
				
			|||||||
			}},
 | 
								}},
 | 
				
			||||||
			"SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil,
 | 
								"SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
 | 
					 | 
				
			||||||
				[]clause.Column{clause.PrimaryColumn},
 | 
					 | 
				
			||||||
			}, clause.Returning{}, clause.Returning{
 | 
					 | 
				
			||||||
				[]clause.Column{{Name: "name"}, {Name: "age"}},
 | 
					 | 
				
			||||||
			}},
 | 
					 | 
				
			||||||
			"SELECT * FROM `users` RETURNING *", nil,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
 | 
					 | 
				
			||||||
				[]clause.Column{clause.PrimaryColumn},
 | 
					 | 
				
			||||||
			}, clause.Returning{
 | 
					 | 
				
			||||||
				[]clause.Column{{Name: "name"}, {Name: "age"}},
 | 
					 | 
				
			||||||
			}, clause.Returning{}},
 | 
					 | 
				
			||||||
			"SELECT * FROM `users` RETURNING *", nil,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for idx, result := range results {
 | 
						for idx, result := range results {
 | 
				
			||||||
 | 
				
			|||||||
@ -21,12 +21,6 @@ func (where Where) Name() string {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// Build build where clause
 | 
					// Build build where clause
 | 
				
			||||||
func (where Where) Build(builder Builder) {
 | 
					func (where Where) Build(builder Builder) {
 | 
				
			||||||
	if len(where.Exprs) == 1 {
 | 
					 | 
				
			||||||
		if andCondition, ok := where.Exprs[0].(AndConditions); ok {
 | 
					 | 
				
			||||||
			where.Exprs = andCondition.Exprs
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Switch position if the first query expression is a single Or condition
 | 
						// Switch position if the first query expression is a single Or condition
 | 
				
			||||||
	for idx, expr := range where.Exprs {
 | 
						for idx, expr := range where.Exprs {
 | 
				
			||||||
		if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
 | 
							if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
 | 
				
			||||||
@ -153,11 +147,6 @@ func Not(exprs ...Expression) Expression {
 | 
				
			|||||||
	if len(exprs) == 0 {
 | 
						if len(exprs) == 0 {
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if len(exprs) == 1 {
 | 
					 | 
				
			||||||
		if andCondition, ok := exprs[0].(AndConditions); ok {
 | 
					 | 
				
			||||||
			exprs = andCondition.Exprs
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return NotConditions{Exprs: exprs}
 | 
						return NotConditions{Exprs: exprs}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -166,63 +155,19 @@ type NotConditions struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (not NotConditions) Build(builder Builder) {
 | 
					func (not NotConditions) Build(builder Builder) {
 | 
				
			||||||
	anyNegationBuilder := false
 | 
						if len(not.Exprs) > 1 {
 | 
				
			||||||
	for _, c := range not.Exprs {
 | 
							builder.WriteByte('(')
 | 
				
			||||||
		if _, ok := c.(NegationExpressionBuilder); ok {
 | 
					 | 
				
			||||||
			anyNegationBuilder = true
 | 
					 | 
				
			||||||
			break
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if anyNegationBuilder {
 | 
						for idx, c := range not.Exprs {
 | 
				
			||||||
		if len(not.Exprs) > 1 {
 | 
							if idx > 0 {
 | 
				
			||||||
			builder.WriteByte('(')
 | 
								builder.WriteString(AndWithSpace)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		for idx, c := range not.Exprs {
 | 
							if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
 | 
				
			||||||
			if idx > 0 {
 | 
								negationBuilder.NegationBuild(builder)
 | 
				
			||||||
				builder.WriteString(AndWithSpace)
 | 
							} else {
 | 
				
			||||||
			}
 | 
								builder.WriteString("NOT ")
 | 
				
			||||||
 | 
					 | 
				
			||||||
			if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
 | 
					 | 
				
			||||||
				negationBuilder.NegationBuild(builder)
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				builder.WriteString("NOT ")
 | 
					 | 
				
			||||||
				e, wrapInParentheses := c.(Expr)
 | 
					 | 
				
			||||||
				if wrapInParentheses {
 | 
					 | 
				
			||||||
					sql := strings.ToUpper(e.SQL)
 | 
					 | 
				
			||||||
					if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
 | 
					 | 
				
			||||||
						builder.WriteByte('(')
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				c.Build(builder)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				if wrapInParentheses {
 | 
					 | 
				
			||||||
					builder.WriteByte(')')
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if len(not.Exprs) > 1 {
 | 
					 | 
				
			||||||
			builder.WriteByte(')')
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		builder.WriteString("NOT ")
 | 
					 | 
				
			||||||
		if len(not.Exprs) > 1 {
 | 
					 | 
				
			||||||
			builder.WriteByte('(')
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		for idx, c := range not.Exprs {
 | 
					 | 
				
			||||||
			if idx > 0 {
 | 
					 | 
				
			||||||
				switch c.(type) {
 | 
					 | 
				
			||||||
				case OrConditions:
 | 
					 | 
				
			||||||
					builder.WriteString(OrWithSpace)
 | 
					 | 
				
			||||||
				default:
 | 
					 | 
				
			||||||
					builder.WriteString(AndWithSpace)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			e, wrapInParentheses := c.(Expr)
 | 
								e, wrapInParentheses := c.(Expr)
 | 
				
			||||||
			if wrapInParentheses {
 | 
								if wrapInParentheses {
 | 
				
			||||||
				sql := strings.ToUpper(e.SQL)
 | 
									sql := strings.ToUpper(e.SQL)
 | 
				
			||||||
@ -237,9 +182,9 @@ func (not NotConditions) Build(builder Builder) {
 | 
				
			|||||||
				builder.WriteByte(')')
 | 
									builder.WriteByte(')')
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if len(not.Exprs) > 1 {
 | 
						if len(not.Exprs) > 1 {
 | 
				
			||||||
			builder.WriteByte(')')
 | 
							builder.WriteByte(')')
 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -63,7 +63,7 @@ func TestWhere(t *testing.T) {
 | 
				
			|||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
 | 
								[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
 | 
				
			||||||
				Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))},
 | 
									Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))},
 | 
				
			||||||
			}},
 | 
								}},
 | 
				
			||||||
			"SELECT * FROM `users` WHERE `age` = ? OR `name` <> ?",
 | 
								"SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)",
 | 
				
			||||||
			[]interface{}{18, "jinzhu"},
 | 
								[]interface{}{18, "jinzhu"},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
@ -94,7 +94,7 @@ func TestWhere(t *testing.T) {
 | 
				
			|||||||
						clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
 | 
											clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
			}},
 | 
								}},
 | 
				
			||||||
			"SELECT * FROM `users` WHERE `users`.`id` <> ? AND `score` <= ?",
 | 
								"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)",
 | 
				
			||||||
			[]interface{}{"1", 100},
 | 
								[]interface{}{"1", 100},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
@ -105,30 +105,6 @@ func TestWhere(t *testing.T) {
 | 
				
			|||||||
			"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
 | 
								"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
 | 
				
			||||||
			[]interface{}{"1", 100},
 | 
								[]interface{}{"1", 100},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
 | 
					 | 
				
			||||||
				Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}},
 | 
					 | 
				
			||||||
					clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})},
 | 
					 | 
				
			||||||
			}},
 | 
					 | 
				
			||||||
			"SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)",
 | 
					 | 
				
			||||||
			[]interface{}{100, 60},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
 | 
					 | 
				
			||||||
				Exprs: []clause.Expression{
 | 
					 | 
				
			||||||
					clause.Not(clause.AndConditions{
 | 
					 | 
				
			||||||
						Exprs: []clause.Expression{
 | 
					 | 
				
			||||||
							clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
 | 
					 | 
				
			||||||
							clause.Gt{Column: "age", Value: 18},
 | 
					 | 
				
			||||||
						}}, clause.OrConditions{
 | 
					 | 
				
			||||||
						Exprs: []clause.Expression{
 | 
					 | 
				
			||||||
							clause.Lt{Column: "score", Value: 100},
 | 
					 | 
				
			||||||
						},
 | 
					 | 
				
			||||||
					}),
 | 
					 | 
				
			||||||
				}}},
 | 
					 | 
				
			||||||
			"SELECT * FROM `users` WHERE NOT ((`users`.`id` = ? AND `age` > ?) OR `score` < ?)",
 | 
					 | 
				
			||||||
			[]interface{}{"1", 18, 100},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for idx, result := range results {
 | 
						for idx, result := range results {
 | 
				
			||||||
 | 
				
			|||||||
@ -49,6 +49,4 @@ var (
 | 
				
			|||||||
	ErrDuplicatedKey = errors.New("duplicated key not allowed")
 | 
						ErrDuplicatedKey = errors.New("duplicated key not allowed")
 | 
				
			||||||
	// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
 | 
						// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
 | 
				
			||||||
	ErrForeignKeyViolated = errors.New("violates foreign key constraint")
 | 
						ErrForeignKeyViolated = errors.New("violates foreign key constraint")
 | 
				
			||||||
	// ErrCheckConstraintViolated occurs when there is a check constraint violation
 | 
					 | 
				
			||||||
	ErrCheckConstraintViolated = errors.New("violates check constraint")
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,11 +1,9 @@
 | 
				
			|||||||
package gorm
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
					 | 
				
			||||||
	"database/sql"
 | 
						"database/sql"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"hash/maphash"
 | 
					 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -378,12 +376,8 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
				
			|||||||
	} else if len(db.Statement.assigns) > 0 {
 | 
						} else if len(db.Statement.assigns) > 0 {
 | 
				
			||||||
		exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
 | 
							exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
 | 
				
			||||||
		assigns := map[string]interface{}{}
 | 
							assigns := map[string]interface{}{}
 | 
				
			||||||
		for i := 0; i < len(exprs); i++ {
 | 
							for _, expr := range exprs {
 | 
				
			||||||
			expr := exprs[i]
 | 
								if eq, ok := expr.(clause.Eq); ok {
 | 
				
			||||||
 | 
					 | 
				
			||||||
			if eq, ok := expr.(clause.AndConditions); ok {
 | 
					 | 
				
			||||||
				exprs = append(exprs, eq.Exprs...)
 | 
					 | 
				
			||||||
			} else if eq, ok := expr.(clause.Eq); ok {
 | 
					 | 
				
			||||||
				switch column := eq.Column.(type) {
 | 
									switch column := eq.Column.(type) {
 | 
				
			||||||
				case string:
 | 
									case string:
 | 
				
			||||||
					assigns[column] = eq.Value
 | 
										assigns[column] = eq.Value
 | 
				
			||||||
@ -625,15 +619,14 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
 | 
				
			|||||||
	if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
 | 
						if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
 | 
				
			||||||
		// nested transaction
 | 
							// nested transaction
 | 
				
			||||||
		if !db.DisableNestedTransaction {
 | 
							if !db.DisableNestedTransaction {
 | 
				
			||||||
			spID := new(maphash.Hash).Sum64()
 | 
								err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
 | 
				
			||||||
			err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error
 | 
					 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			defer func() {
 | 
								defer func() {
 | 
				
			||||||
				// Make sure to rollback when panic, Block error or Commit error
 | 
									// Make sure to rollback when panic, Block error or Commit error
 | 
				
			||||||
				if panicked || err != nil {
 | 
									if panicked || err != nil {
 | 
				
			||||||
					db.RollbackTo(fmt.Sprintf("sp%d", spID))
 | 
										db.RollbackTo(fmt.Sprintf("sp%p", fc))
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}()
 | 
								}()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -674,18 +667,11 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
 | 
				
			|||||||
		opt = opts[0]
 | 
							opt = opts[0]
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ctx := tx.Statement.Context
 | 
					 | 
				
			||||||
	if _, ok := ctx.Deadline(); !ok {
 | 
					 | 
				
			||||||
		if db.Config.DefaultTransactionTimeout > 0 {
 | 
					 | 
				
			||||||
			ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	switch beginner := tx.Statement.ConnPool.(type) {
 | 
						switch beginner := tx.Statement.ConnPool.(type) {
 | 
				
			||||||
	case TxBeginner:
 | 
						case TxBeginner:
 | 
				
			||||||
		tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
 | 
							tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
 | 
				
			||||||
	case ConnPoolBeginner:
 | 
						case ConnPoolBeginner:
 | 
				
			||||||
		tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
 | 
							tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		err = ErrInvalidTransaction
 | 
							err = ErrInvalidTransaction
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										605
									
								
								generics.go
									
									
									
									
									
								
							
							
						
						
									
										605
									
								
								generics.go
									
									
									
									
									
								
							@ -1,605 +0,0 @@
 | 
				
			|||||||
package gorm
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"context"
 | 
					 | 
				
			||||||
	"database/sql"
 | 
					 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"sort"
 | 
					 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"gorm.io/gorm/clause"
 | 
					 | 
				
			||||||
	"gorm.io/gorm/logger"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type result struct {
 | 
					 | 
				
			||||||
	Result       sql.Result
 | 
					 | 
				
			||||||
	RowsAffected int64
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (info *result) ModifyStatement(stmt *Statement) {
 | 
					 | 
				
			||||||
	stmt.Result = info
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Build implements clause.Expression interface
 | 
					 | 
				
			||||||
func (result) Build(clause.Builder) {
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func WithResult() *result {
 | 
					 | 
				
			||||||
	return &result{}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type Interface[T any] interface {
 | 
					 | 
				
			||||||
	Raw(sql string, values ...interface{}) ExecInterface[T]
 | 
					 | 
				
			||||||
	Exec(ctx context.Context, sql string, values ...interface{}) error
 | 
					 | 
				
			||||||
	CreateInterface[T]
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type CreateInterface[T any] interface {
 | 
					 | 
				
			||||||
	ChainInterface[T]
 | 
					 | 
				
			||||||
	Table(name string, args ...interface{}) CreateInterface[T]
 | 
					 | 
				
			||||||
	Create(ctx context.Context, r *T) error
 | 
					 | 
				
			||||||
	CreateInBatches(ctx context.Context, r *[]T, batchSize int) error
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ChainInterface[T any] interface {
 | 
					 | 
				
			||||||
	ExecInterface[T]
 | 
					 | 
				
			||||||
	Scopes(scopes ...func(db *Statement)) ChainInterface[T]
 | 
					 | 
				
			||||||
	Where(query interface{}, args ...interface{}) ChainInterface[T]
 | 
					 | 
				
			||||||
	Not(query interface{}, args ...interface{}) ChainInterface[T]
 | 
					 | 
				
			||||||
	Or(query interface{}, args ...interface{}) ChainInterface[T]
 | 
					 | 
				
			||||||
	Limit(offset int) ChainInterface[T]
 | 
					 | 
				
			||||||
	Offset(offset int) ChainInterface[T]
 | 
					 | 
				
			||||||
	Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
 | 
					 | 
				
			||||||
	Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T]
 | 
					 | 
				
			||||||
	Select(query string, args ...interface{}) ChainInterface[T]
 | 
					 | 
				
			||||||
	Omit(columns ...string) ChainInterface[T]
 | 
					 | 
				
			||||||
	MapColumns(m map[string]string) ChainInterface[T]
 | 
					 | 
				
			||||||
	Distinct(args ...interface{}) ChainInterface[T]
 | 
					 | 
				
			||||||
	Group(name string) ChainInterface[T]
 | 
					 | 
				
			||||||
	Having(query interface{}, args ...interface{}) ChainInterface[T]
 | 
					 | 
				
			||||||
	Order(value interface{}) ChainInterface[T]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	Build(builder clause.Builder)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	Delete(ctx context.Context) (rowsAffected int, err error)
 | 
					 | 
				
			||||||
	Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
 | 
					 | 
				
			||||||
	Updates(ctx context.Context, t T) (rowsAffected int, err error)
 | 
					 | 
				
			||||||
	Count(ctx context.Context, column string) (result int64, err error)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ExecInterface[T any] interface {
 | 
					 | 
				
			||||||
	Scan(ctx context.Context, r interface{}) error
 | 
					 | 
				
			||||||
	First(context.Context) (T, error)
 | 
					 | 
				
			||||||
	Last(ctx context.Context) (T, error)
 | 
					 | 
				
			||||||
	Take(context.Context) (T, error)
 | 
					 | 
				
			||||||
	Find(ctx context.Context) ([]T, error)
 | 
					 | 
				
			||||||
	FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error
 | 
					 | 
				
			||||||
	Row(ctx context.Context) *sql.Row
 | 
					 | 
				
			||||||
	Rows(ctx context.Context) (*sql.Rows, error)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type JoinBuilder interface {
 | 
					 | 
				
			||||||
	Select(...string) JoinBuilder
 | 
					 | 
				
			||||||
	Omit(...string) JoinBuilder
 | 
					 | 
				
			||||||
	Where(query interface{}, args ...interface{}) JoinBuilder
 | 
					 | 
				
			||||||
	Not(query interface{}, args ...interface{}) JoinBuilder
 | 
					 | 
				
			||||||
	Or(query interface{}, args ...interface{}) JoinBuilder
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type PreloadBuilder interface {
 | 
					 | 
				
			||||||
	Select(...string) PreloadBuilder
 | 
					 | 
				
			||||||
	Omit(...string) PreloadBuilder
 | 
					 | 
				
			||||||
	Where(query interface{}, args ...interface{}) PreloadBuilder
 | 
					 | 
				
			||||||
	Not(query interface{}, args ...interface{}) PreloadBuilder
 | 
					 | 
				
			||||||
	Or(query interface{}, args ...interface{}) PreloadBuilder
 | 
					 | 
				
			||||||
	Limit(offset int) PreloadBuilder
 | 
					 | 
				
			||||||
	Offset(offset int) PreloadBuilder
 | 
					 | 
				
			||||||
	Order(value interface{}) PreloadBuilder
 | 
					 | 
				
			||||||
	LimitPerRecord(num int) PreloadBuilder
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type op func(*DB) *DB
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
 | 
					 | 
				
			||||||
	v := &g[T]{
 | 
					 | 
				
			||||||
		db:  db,
 | 
					 | 
				
			||||||
		ops: make([]op, 0, 5),
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(opts) > 0 {
 | 
					 | 
				
			||||||
		v.ops = append(v.ops, func(db *DB) *DB {
 | 
					 | 
				
			||||||
			return db.Clauses(opts...)
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	v.createG = &createG[T]{
 | 
					 | 
				
			||||||
		chainG: chainG[T]{
 | 
					 | 
				
			||||||
			execG: execG[T]{g: v},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return v
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type g[T any] struct {
 | 
					 | 
				
			||||||
	*createG[T]
 | 
					 | 
				
			||||||
	db  *DB
 | 
					 | 
				
			||||||
	ops []op
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (g *g[T]) apply(ctx context.Context) *DB {
 | 
					 | 
				
			||||||
	db := g.db
 | 
					 | 
				
			||||||
	if !db.DryRun {
 | 
					 | 
				
			||||||
		db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, op := range g.ops {
 | 
					 | 
				
			||||||
		db = op(db)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return db
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] {
 | 
					 | 
				
			||||||
	return execG[T]{g: &g[T]{
 | 
					 | 
				
			||||||
		db: c.db,
 | 
					 | 
				
			||||||
		ops: append(c.ops, func(db *DB) *DB {
 | 
					 | 
				
			||||||
			return db.Raw(sql, values...)
 | 
					 | 
				
			||||||
		}),
 | 
					 | 
				
			||||||
	}}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error {
 | 
					 | 
				
			||||||
	return c.apply(ctx).Exec(sql, values...).Error
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type createG[T any] struct {
 | 
					 | 
				
			||||||
	chainG[T]
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] {
 | 
					 | 
				
			||||||
	return createG[T]{c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		return db.Table(name, args...)
 | 
					 | 
				
			||||||
	})}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c createG[T]) Create(ctx context.Context, r *T) error {
 | 
					 | 
				
			||||||
	return c.g.apply(ctx).Create(r).Error
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error {
 | 
					 | 
				
			||||||
	return c.g.apply(ctx).CreateInBatches(r, batchSize).Error
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type chainG[T any] struct {
 | 
					 | 
				
			||||||
	execG[T]
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) getInstance() *DB {
 | 
					 | 
				
			||||||
	var r T
 | 
					 | 
				
			||||||
	return c.g.apply(context.Background()).Model(r).getInstance()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) with(v op) chainG[T] {
 | 
					 | 
				
			||||||
	return chainG[T]{
 | 
					 | 
				
			||||||
		execG: execG[T]{g: &g[T]{
 | 
					 | 
				
			||||||
			db:  c.g.db,
 | 
					 | 
				
			||||||
			ops: append(append([]op(nil), c.g.ops...), v),
 | 
					 | 
				
			||||||
		}},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] {
 | 
					 | 
				
			||||||
	return c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		for _, fc := range scopes {
 | 
					 | 
				
			||||||
			fc(db.Statement)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return db
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] {
 | 
					 | 
				
			||||||
	return c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		return db.Table(name, args...)
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] {
 | 
					 | 
				
			||||||
	return c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		return db.Where(query, args...)
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] {
 | 
					 | 
				
			||||||
	return c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		return db.Not(query, args...)
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] {
 | 
					 | 
				
			||||||
	return c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		return db.Or(query, args...)
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Limit(offset int) ChainInterface[T] {
 | 
					 | 
				
			||||||
	return c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		return db.Limit(offset)
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Offset(offset int) ChainInterface[T] {
 | 
					 | 
				
			||||||
	return c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		return db.Offset(offset)
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type joinBuilder struct {
 | 
					 | 
				
			||||||
	db *DB
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
 | 
					 | 
				
			||||||
	q.db.Where(query, args...)
 | 
					 | 
				
			||||||
	return q
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
 | 
					 | 
				
			||||||
	q.db.Where(query, args...)
 | 
					 | 
				
			||||||
	return q
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
 | 
					 | 
				
			||||||
	q.db.Where(query, args...)
 | 
					 | 
				
			||||||
	return q
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (q *joinBuilder) Select(columns ...string) JoinBuilder {
 | 
					 | 
				
			||||||
	q.db.Select(columns)
 | 
					 | 
				
			||||||
	return q
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (q *joinBuilder) Omit(columns ...string) JoinBuilder {
 | 
					 | 
				
			||||||
	q.db.Omit(columns...)
 | 
					 | 
				
			||||||
	return q
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type preloadBuilder struct {
 | 
					 | 
				
			||||||
	limitPerRecord int
 | 
					 | 
				
			||||||
	db             *DB
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
 | 
					 | 
				
			||||||
	q.db.Where(query, args...)
 | 
					 | 
				
			||||||
	return q
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
 | 
					 | 
				
			||||||
	q.db.Where(query, args...)
 | 
					 | 
				
			||||||
	return q
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
 | 
					 | 
				
			||||||
	q.db.Where(query, args...)
 | 
					 | 
				
			||||||
	return q
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (q *preloadBuilder) Select(columns ...string) PreloadBuilder {
 | 
					 | 
				
			||||||
	q.db.Select(columns)
 | 
					 | 
				
			||||||
	return q
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder {
 | 
					 | 
				
			||||||
	q.db.Omit(columns...)
 | 
					 | 
				
			||||||
	return q
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (q *preloadBuilder) Limit(limit int) PreloadBuilder {
 | 
					 | 
				
			||||||
	q.db.Limit(limit)
 | 
					 | 
				
			||||||
	return q
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (q *preloadBuilder) Offset(offset int) PreloadBuilder {
 | 
					 | 
				
			||||||
	q.db.Offset(offset)
 | 
					 | 
				
			||||||
	return q
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (q *preloadBuilder) Order(value interface{}) PreloadBuilder {
 | 
					 | 
				
			||||||
	q.db.Order(value)
 | 
					 | 
				
			||||||
	return q
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder {
 | 
					 | 
				
			||||||
	q.limitPerRecord = num
 | 
					 | 
				
			||||||
	return q
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
 | 
					 | 
				
			||||||
	return c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		if jt.Table == "" {
 | 
					 | 
				
			||||||
			jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
 | 
					 | 
				
			||||||
		if on != nil {
 | 
					 | 
				
			||||||
			if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
 | 
					 | 
				
			||||||
				db.AddError(err)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		j := join{
 | 
					 | 
				
			||||||
			Name:     jt.Association,
 | 
					 | 
				
			||||||
			Alias:    jt.Table,
 | 
					 | 
				
			||||||
			Selects:  q.db.Statement.Selects,
 | 
					 | 
				
			||||||
			Omits:    q.db.Statement.Omits,
 | 
					 | 
				
			||||||
			JoinType: jt.Type,
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
 | 
					 | 
				
			||||||
			j.On = &where
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if jt.Subquery != nil {
 | 
					 | 
				
			||||||
			joinType := j.JoinType
 | 
					 | 
				
			||||||
			if joinType == "" {
 | 
					 | 
				
			||||||
				joinType = clause.LeftJoin
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok {
 | 
					 | 
				
			||||||
				stmt := db.getInstance().Statement
 | 
					 | 
				
			||||||
				if len(j.Selects) == 0 {
 | 
					 | 
				
			||||||
					j.Selects = stmt.Selects
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				if len(j.Omits) == 0 {
 | 
					 | 
				
			||||||
					j.Omits = stmt.Omits
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if j.On != nil {
 | 
					 | 
				
			||||||
				expr.SQL += " ON ?"
 | 
					 | 
				
			||||||
				expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			j.Expression = expr
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		db.Statement.Joins = append(db.Statement.Joins, j)
 | 
					 | 
				
			||||||
		sort.Slice(db.Statement.Joins, func(i, j int) bool {
 | 
					 | 
				
			||||||
			return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
		return db
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
 | 
					 | 
				
			||||||
	return c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		return db.Select(query, args...)
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Omit(columns ...string) ChainInterface[T] {
 | 
					 | 
				
			||||||
	return c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		return db.Omit(columns...)
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
 | 
					 | 
				
			||||||
	return c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		return db.MapColumns(m)
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
 | 
					 | 
				
			||||||
	return c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		return db.Distinct(args...)
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Group(name string) ChainInterface[T] {
 | 
					 | 
				
			||||||
	return c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		return db.Group(name)
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] {
 | 
					 | 
				
			||||||
	return c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		return db.Having(query, args...)
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
 | 
					 | 
				
			||||||
	return c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		return db.Order(value)
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] {
 | 
					 | 
				
			||||||
	return c.with(func(db *DB) *DB {
 | 
					 | 
				
			||||||
		return db.Preload(association, func(tx *DB) *DB {
 | 
					 | 
				
			||||||
			q := preloadBuilder{db: tx.getInstance()}
 | 
					 | 
				
			||||||
			if query != nil {
 | 
					 | 
				
			||||||
				if err := query(&q); err != nil {
 | 
					 | 
				
			||||||
					db.AddError(err)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			relation, ok := db.Statement.Schema.Relationships.Relations[association]
 | 
					 | 
				
			||||||
			if !ok {
 | 
					 | 
				
			||||||
				if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 {
 | 
					 | 
				
			||||||
					relationships := db.Statement.Schema.Relationships
 | 
					 | 
				
			||||||
					for _, field := range preloadFields {
 | 
					 | 
				
			||||||
						var ok bool
 | 
					 | 
				
			||||||
						relation, ok = relationships.Relations[field]
 | 
					 | 
				
			||||||
						if ok {
 | 
					 | 
				
			||||||
							relationships = relation.FieldSchema.Relationships
 | 
					 | 
				
			||||||
						} else {
 | 
					 | 
				
			||||||
							db.AddError(fmt.Errorf("relation %s not found", association))
 | 
					 | 
				
			||||||
							return nil
 | 
					 | 
				
			||||||
						}
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					db.AddError(fmt.Errorf("relation %s not found", association))
 | 
					 | 
				
			||||||
					return nil
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if q.limitPerRecord > 0 {
 | 
					 | 
				
			||||||
				if relation.JoinTable != nil {
 | 
					 | 
				
			||||||
					tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association))
 | 
					 | 
				
			||||||
					return tx
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				refColumns := []clause.Column{}
 | 
					 | 
				
			||||||
				for _, rel := range relation.References {
 | 
					 | 
				
			||||||
					if rel.OwnPrimaryKey {
 | 
					 | 
				
			||||||
						refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName})
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				if len(refColumns) != 0 {
 | 
					 | 
				
			||||||
					selectExpr := clause.CommaExpression{}
 | 
					 | 
				
			||||||
					for _, column := range q.db.Statement.Selects {
 | 
					 | 
				
			||||||
						selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}})
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
					if len(selectExpr.Exprs) == 0 {
 | 
					 | 
				
			||||||
						selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}}
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
					partitionBy := clause.CommaExpression{}
 | 
					 | 
				
			||||||
					for _, column := range refColumns {
 | 
					 | 
				
			||||||
						partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}})
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
					rnnColumn := clause.Column{Name: "gorm_preload_rnn"}
 | 
					 | 
				
			||||||
					sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)"
 | 
					 | 
				
			||||||
					vars := []interface{}{partitionBy}
 | 
					 | 
				
			||||||
					if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok {
 | 
					 | 
				
			||||||
						vars = append(vars, orderBy)
 | 
					 | 
				
			||||||
					} else {
 | 
					 | 
				
			||||||
						vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{
 | 
					 | 
				
			||||||
							Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
 | 
					 | 
				
			||||||
						}})
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
					vars = append(vars, rnnColumn)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
					selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
					q.db.Clauses(clause.Select{Expression: selectExpr})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
					return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			return q.db
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) {
 | 
					 | 
				
			||||||
	r := new(T)
 | 
					 | 
				
			||||||
	res := c.g.apply(ctx).Delete(r)
 | 
					 | 
				
			||||||
	return int(res.RowsAffected), res.Error
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) {
 | 
					 | 
				
			||||||
	var r T
 | 
					 | 
				
			||||||
	res := c.g.apply(ctx).Model(r).Update(name, value)
 | 
					 | 
				
			||||||
	return int(res.RowsAffected), res.Error
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) {
 | 
					 | 
				
			||||||
	res := c.g.apply(ctx).Updates(t)
 | 
					 | 
				
			||||||
	return int(res.RowsAffected), res.Error
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) {
 | 
					 | 
				
			||||||
	var r T
 | 
					 | 
				
			||||||
	err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c chainG[T]) Build(builder clause.Builder) {
 | 
					 | 
				
			||||||
	subdb := c.getInstance()
 | 
					 | 
				
			||||||
	subdb.Logger = logger.Discard
 | 
					 | 
				
			||||||
	subdb.DryRun = true
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if stmt, ok := builder.(*Statement); ok {
 | 
					 | 
				
			||||||
		if subdb.Statement.SQL.Len() > 0 {
 | 
					 | 
				
			||||||
			var (
 | 
					 | 
				
			||||||
				vars = subdb.Statement.Vars
 | 
					 | 
				
			||||||
				sql  = subdb.Statement.SQL.String()
 | 
					 | 
				
			||||||
			)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			subdb.Statement.Vars = make([]interface{}, 0, len(vars))
 | 
					 | 
				
			||||||
			for _, vv := range vars {
 | 
					 | 
				
			||||||
				subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
 | 
					 | 
				
			||||||
				bindvar := strings.Builder{}
 | 
					 | 
				
			||||||
				subdb.BindVarTo(&bindvar, subdb.Statement, vv)
 | 
					 | 
				
			||||||
				sql = strings.Replace(sql, bindvar.String(), "?", 1)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			subdb.Statement.SQL.Reset()
 | 
					 | 
				
			||||||
			subdb.Statement.Vars = stmt.Vars
 | 
					 | 
				
			||||||
			if strings.Contains(sql, "@") {
 | 
					 | 
				
			||||||
				clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
 | 
					 | 
				
			||||||
			subdb.callbacks.Query().Execute(subdb)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		builder.WriteString(subdb.Statement.SQL.String())
 | 
					 | 
				
			||||||
		stmt.Vars = subdb.Statement.Vars
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type execG[T any] struct {
 | 
					 | 
				
			||||||
	g *g[T]
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (g execG[T]) First(ctx context.Context) (T, error) {
 | 
					 | 
				
			||||||
	var r T
 | 
					 | 
				
			||||||
	err := g.g.apply(ctx).First(&r).Error
 | 
					 | 
				
			||||||
	return r, err
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (g execG[T]) Scan(ctx context.Context, result interface{}) error {
 | 
					 | 
				
			||||||
	var r T
 | 
					 | 
				
			||||||
	err := g.g.apply(ctx).Model(r).Find(result).Error
 | 
					 | 
				
			||||||
	return err
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (g execG[T]) Last(ctx context.Context) (T, error) {
 | 
					 | 
				
			||||||
	var r T
 | 
					 | 
				
			||||||
	err := g.g.apply(ctx).Last(&r).Error
 | 
					 | 
				
			||||||
	return r, err
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (g execG[T]) Take(ctx context.Context) (T, error) {
 | 
					 | 
				
			||||||
	var r T
 | 
					 | 
				
			||||||
	err := g.g.apply(ctx).Take(&r).Error
 | 
					 | 
				
			||||||
	return r, err
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (g execG[T]) Find(ctx context.Context) ([]T, error) {
 | 
					 | 
				
			||||||
	var r []T
 | 
					 | 
				
			||||||
	err := g.g.apply(ctx).Find(&r).Error
 | 
					 | 
				
			||||||
	return r, err
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error {
 | 
					 | 
				
			||||||
	var data []T
 | 
					 | 
				
			||||||
	return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error {
 | 
					 | 
				
			||||||
		return fc(data, batch)
 | 
					 | 
				
			||||||
	}).Error
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (g execG[T]) Row(ctx context.Context) *sql.Row {
 | 
					 | 
				
			||||||
	return g.g.apply(ctx).Row()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) {
 | 
					 | 
				
			||||||
	return g.g.apply(ctx).Rows()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
							
								
								
									
										3
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								go.mod
									
									
									
									
									
								
							@ -1,9 +1,8 @@
 | 
				
			|||||||
module gorm.io/gorm
 | 
					module gorm.io/gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
go 1.18
 | 
					go 1.16
 | 
				
			||||||
 | 
					
 | 
				
			||||||
require (
 | 
					require (
 | 
				
			||||||
	github.com/jinzhu/inflection v1.0.0
 | 
						github.com/jinzhu/inflection v1.0.0
 | 
				
			||||||
	github.com/jinzhu/now v1.1.5
 | 
						github.com/jinzhu/now v1.1.5
 | 
				
			||||||
	golang.org/x/text v0.20.0
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
									
									
									
									
								
							@ -2,5 +2,3 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
 | 
				
			|||||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
 | 
					github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
 | 
				
			||||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
 | 
					github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
 | 
				
			||||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
 | 
					github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
 | 
				
			||||||
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
 | 
					 | 
				
			||||||
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
 | 
					 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										50
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										50
									
								
								gorm.go
									
									
									
									
									
								
							@ -21,9 +21,7 @@ const preparedStmtDBKey = "preparedStmt"
 | 
				
			|||||||
type Config struct {
 | 
					type Config struct {
 | 
				
			||||||
	// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
 | 
						// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
 | 
				
			||||||
	// You can disable it by setting `SkipDefaultTransaction` to true
 | 
						// You can disable it by setting `SkipDefaultTransaction` to true
 | 
				
			||||||
	SkipDefaultTransaction    bool
 | 
						SkipDefaultTransaction bool
 | 
				
			||||||
	DefaultTransactionTimeout time.Duration
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// NamingStrategy tables, columns naming strategy
 | 
						// NamingStrategy tables, columns naming strategy
 | 
				
			||||||
	NamingStrategy schema.Namer
 | 
						NamingStrategy schema.Namer
 | 
				
			||||||
	// FullSaveAssociations full save associations
 | 
						// FullSaveAssociations full save associations
 | 
				
			||||||
@ -36,11 +34,6 @@ type Config struct {
 | 
				
			|||||||
	DryRun bool
 | 
						DryRun bool
 | 
				
			||||||
	// PrepareStmt executes the given query in cached statement
 | 
						// PrepareStmt executes the given query in cached statement
 | 
				
			||||||
	PrepareStmt bool
 | 
						PrepareStmt bool
 | 
				
			||||||
	// PrepareStmt cache support LRU expired,
 | 
					 | 
				
			||||||
	// default maxsize=int64 Max value and ttl=1h
 | 
					 | 
				
			||||||
	PrepareStmtMaxSize int
 | 
					 | 
				
			||||||
	PrepareStmtTTL     time.Duration
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// DisableAutomaticPing
 | 
						// DisableAutomaticPing
 | 
				
			||||||
	DisableAutomaticPing bool
 | 
						DisableAutomaticPing bool
 | 
				
			||||||
	// DisableForeignKeyConstraintWhenMigrating
 | 
						// DisableForeignKeyConstraintWhenMigrating
 | 
				
			||||||
@ -57,8 +50,6 @@ type Config struct {
 | 
				
			|||||||
	CreateBatchSize int
 | 
						CreateBatchSize int
 | 
				
			||||||
	// TranslateError enabling error translation
 | 
						// TranslateError enabling error translation
 | 
				
			||||||
	TranslateError bool
 | 
						TranslateError bool
 | 
				
			||||||
	// PropagateUnscoped propagate Unscoped to every other nested statement
 | 
					 | 
				
			||||||
	PropagateUnscoped bool
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// ClauseBuilders clause builder
 | 
						// ClauseBuilders clause builder
 | 
				
			||||||
	ClauseBuilders map[string]clause.ClauseBuilder
 | 
						ClauseBuilders map[string]clause.ClauseBuilder
 | 
				
			||||||
@ -119,7 +110,6 @@ type Session struct {
 | 
				
			|||||||
	DisableNestedTransaction bool
 | 
						DisableNestedTransaction bool
 | 
				
			||||||
	AllowGlobalUpdate        bool
 | 
						AllowGlobalUpdate        bool
 | 
				
			||||||
	FullSaveAssociations     bool
 | 
						FullSaveAssociations     bool
 | 
				
			||||||
	PropagateUnscoped        bool
 | 
					 | 
				
			||||||
	QueryFields              bool
 | 
						QueryFields              bool
 | 
				
			||||||
	Context                  context.Context
 | 
						Context                  context.Context
 | 
				
			||||||
	Logger                   logger.Interface
 | 
						Logger                   logger.Interface
 | 
				
			||||||
@ -137,24 +127,12 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
 | 
				
			|||||||
		return isConfig && !isConfig2
 | 
							return isConfig && !isConfig2
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(opts) > 0 {
 | 
					 | 
				
			||||||
		if c, ok := opts[0].(*Config); ok {
 | 
					 | 
				
			||||||
			config = c
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			opts = append([]Option{config}, opts...)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var skipAfterInitialize bool
 | 
					 | 
				
			||||||
	for _, opt := range opts {
 | 
						for _, opt := range opts {
 | 
				
			||||||
		if opt != nil {
 | 
							if opt != nil {
 | 
				
			||||||
			if applyErr := opt.Apply(config); applyErr != nil {
 | 
								if applyErr := opt.Apply(config); applyErr != nil {
 | 
				
			||||||
				return nil, applyErr
 | 
									return nil, applyErr
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			defer func(opt Option) {
 | 
								defer func(opt Option) {
 | 
				
			||||||
				if skipAfterInitialize {
 | 
					 | 
				
			||||||
					return
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				if errr := opt.AfterInitialize(db); errr != nil {
 | 
									if errr := opt.AfterInitialize(db); errr != nil {
 | 
				
			||||||
					err = errr
 | 
										err = errr
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
@ -202,25 +180,16 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	if config.Dialector != nil {
 | 
						if config.Dialector != nil {
 | 
				
			||||||
		err = config.Dialector.Initialize(db)
 | 
							err = config.Dialector.Initialize(db)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			if db, _ := db.DB(); db != nil {
 | 
								if db, err := db.DB(); err == nil {
 | 
				
			||||||
				_ = db.Close()
 | 
									_ = db.Close()
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					 | 
				
			||||||
			// DB is not initialized, so we skip AfterInitialize
 | 
					 | 
				
			||||||
			skipAfterInitialize = true
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if config.TranslateError {
 | 
					 | 
				
			||||||
			if _, ok := db.Dialector.(ErrorTranslator); !ok {
 | 
					 | 
				
			||||||
				config.Logger.Warn(context.Background(), "The TranslateError option is enabled, but the Dialector %s does not implement ErrorTranslator.", db.Dialector.Name())
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if config.PrepareStmt {
 | 
						if config.PrepareStmt {
 | 
				
			||||||
		preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL)
 | 
							preparedStmt := NewPreparedStmtDB(db.ConnPool)
 | 
				
			||||||
		db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
 | 
							db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
 | 
				
			||||||
		db.ConnPool = preparedStmt
 | 
							db.ConnPool = preparedStmt
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -272,10 +241,6 @@ func (db *DB) Session(config *Session) *DB {
 | 
				
			|||||||
		txConfig.FullSaveAssociations = true
 | 
							txConfig.FullSaveAssociations = true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if config.PropagateUnscoped {
 | 
					 | 
				
			||||||
		txConfig.PropagateUnscoped = true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if config.Context != nil || config.PrepareStmt || config.SkipHooks {
 | 
						if config.Context != nil || config.PrepareStmt || config.SkipHooks {
 | 
				
			||||||
		tx.Statement = tx.Statement.clone()
 | 
							tx.Statement = tx.Statement.clone()
 | 
				
			||||||
		tx.Statement.DB = tx
 | 
							tx.Statement.DB = tx
 | 
				
			||||||
@ -291,7 +256,7 @@ func (db *DB) Session(config *Session) *DB {
 | 
				
			|||||||
		if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
 | 
							if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
 | 
				
			||||||
			preparedStmt = v.(*PreparedStmtDB)
 | 
								preparedStmt = v.(*PreparedStmtDB)
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			preparedStmt = NewPreparedStmtDB(db.ConnPool, db.PrepareStmtMaxSize, db.PrepareStmtTTL)
 | 
								preparedStmt = NewPreparedStmtDB(db.ConnPool)
 | 
				
			||||||
			db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
 | 
								db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -444,9 +409,6 @@ func (db *DB) getInstance() *DB {
 | 
				
			|||||||
				Vars:      make([]interface{}, 0, 8),
 | 
									Vars:      make([]interface{}, 0, 8),
 | 
				
			||||||
				SkipHooks: db.Statement.SkipHooks,
 | 
									SkipHooks: db.Statement.SkipHooks,
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if db.Config.PropagateUnscoped {
 | 
					 | 
				
			||||||
				tx.Statement.Unscoped = db.Statement.Unscoped
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			// with clone statement
 | 
								// with clone statement
 | 
				
			||||||
			tx.Statement = db.Statement.clone()
 | 
								tx.Statement = db.Statement.clone()
 | 
				
			||||||
@ -537,7 +499,7 @@ func (db *DB) Use(plugin Plugin) error {
 | 
				
			|||||||
//				.First(&User{})
 | 
					//				.First(&User{})
 | 
				
			||||||
//	})
 | 
					//	})
 | 
				
			||||||
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
 | 
					func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
 | 
				
			||||||
	tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance())
 | 
						tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
 | 
				
			||||||
	stmt := tx.Statement
 | 
						stmt := tx.Statement
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
 | 
						return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,493 +0,0 @@
 | 
				
			|||||||
package lru
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// golang -lru
 | 
					 | 
				
			||||||
// https://github.com/hashicorp/golang-lru
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"sync"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// EvictCallback is used to get a callback when a cache entry is evicted
 | 
					 | 
				
			||||||
type EvictCallback[K comparable, V any] func(key K, value V)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// LRU implements a thread-safe LRU with expirable entries.
 | 
					 | 
				
			||||||
type LRU[K comparable, V any] struct {
 | 
					 | 
				
			||||||
	size      int
 | 
					 | 
				
			||||||
	evictList *LruList[K, V]
 | 
					 | 
				
			||||||
	items     map[K]*Entry[K, V]
 | 
					 | 
				
			||||||
	onEvict   EvictCallback[K, V]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// expirable options
 | 
					 | 
				
			||||||
	mu   sync.Mutex
 | 
					 | 
				
			||||||
	ttl  time.Duration
 | 
					 | 
				
			||||||
	done chan struct{}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// buckets for expiration
 | 
					 | 
				
			||||||
	buckets []bucket[K, V]
 | 
					 | 
				
			||||||
	// uint8 because it's number between 0 and numBuckets
 | 
					 | 
				
			||||||
	nextCleanupBucket uint8
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// bucket is a container for holding entries to be expired
 | 
					 | 
				
			||||||
type bucket[K comparable, V any] struct {
 | 
					 | 
				
			||||||
	entries     map[K]*Entry[K, V]
 | 
					 | 
				
			||||||
	newestEntry time.Time
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// noEvictionTTL - very long ttl to prevent eviction
 | 
					 | 
				
			||||||
const noEvictionTTL = time.Hour * 24 * 365 * 10
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// because of uint8 usage for nextCleanupBucket, should not exceed 256.
 | 
					 | 
				
			||||||
// casting it as uint8 explicitly requires type conversions in multiple places
 | 
					 | 
				
			||||||
const numBuckets = 100
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// NewLRU returns a new thread-safe cache with expirable entries.
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// Size parameter set to 0 makes cache of unlimited size, e.g. turns LRU mechanism off.
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// Providing 0 TTL turns expiring off.
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// Delete expired entries every 1/100th of ttl value. Goroutine which deletes expired entries runs indefinitely.
 | 
					 | 
				
			||||||
func NewLRU[K comparable, V any](size int, onEvict EvictCallback[K, V], ttl time.Duration) *LRU[K, V] {
 | 
					 | 
				
			||||||
	if size < 0 {
 | 
					 | 
				
			||||||
		size = 0
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if ttl <= 0 {
 | 
					 | 
				
			||||||
		ttl = noEvictionTTL
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	res := LRU[K, V]{
 | 
					 | 
				
			||||||
		ttl:       ttl,
 | 
					 | 
				
			||||||
		size:      size,
 | 
					 | 
				
			||||||
		evictList: NewList[K, V](),
 | 
					 | 
				
			||||||
		items:     make(map[K]*Entry[K, V]),
 | 
					 | 
				
			||||||
		onEvict:   onEvict,
 | 
					 | 
				
			||||||
		done:      make(chan struct{}),
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// initialize the buckets
 | 
					 | 
				
			||||||
	res.buckets = make([]bucket[K, V], numBuckets)
 | 
					 | 
				
			||||||
	for i := 0; i < numBuckets; i++ {
 | 
					 | 
				
			||||||
		res.buckets[i] = bucket[K, V]{entries: make(map[K]*Entry[K, V])}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// enable deleteExpired() running in separate goroutine for cache with non-zero TTL
 | 
					 | 
				
			||||||
	//
 | 
					 | 
				
			||||||
	// Important: done channel is never closed, so deleteExpired() goroutine will never exit,
 | 
					 | 
				
			||||||
	// it's decided to add functionality to close it in the version later than v2.
 | 
					 | 
				
			||||||
	if res.ttl != noEvictionTTL {
 | 
					 | 
				
			||||||
		go func(done <-chan struct{}) {
 | 
					 | 
				
			||||||
			ticker := time.NewTicker(res.ttl / numBuckets)
 | 
					 | 
				
			||||||
			defer ticker.Stop()
 | 
					 | 
				
			||||||
			for {
 | 
					 | 
				
			||||||
				select {
 | 
					 | 
				
			||||||
				case <-done:
 | 
					 | 
				
			||||||
					return
 | 
					 | 
				
			||||||
				case <-ticker.C:
 | 
					 | 
				
			||||||
					res.deleteExpired()
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}(res.done)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return &res
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Purge clears the cache completely.
 | 
					 | 
				
			||||||
// onEvict is called for each evicted key.
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) Purge() {
 | 
					 | 
				
			||||||
	c.mu.Lock()
 | 
					 | 
				
			||||||
	defer c.mu.Unlock()
 | 
					 | 
				
			||||||
	for k, v := range c.items {
 | 
					 | 
				
			||||||
		if c.onEvict != nil {
 | 
					 | 
				
			||||||
			c.onEvict(k, v.Value)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		delete(c.items, k)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	for _, b := range c.buckets {
 | 
					 | 
				
			||||||
		for _, ent := range b.entries {
 | 
					 | 
				
			||||||
			delete(b.entries, ent.Key)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	c.evictList.Init()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Add adds a value to the cache. Returns true if an eviction occurred.
 | 
					 | 
				
			||||||
// Returns false if there was no eviction: the item was already in the cache,
 | 
					 | 
				
			||||||
// or the size was not exceeded.
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) Add(key K, value V) (evicted bool) {
 | 
					 | 
				
			||||||
	c.mu.Lock()
 | 
					 | 
				
			||||||
	defer c.mu.Unlock()
 | 
					 | 
				
			||||||
	now := time.Now()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Check for existing item
 | 
					 | 
				
			||||||
	if ent, ok := c.items[key]; ok {
 | 
					 | 
				
			||||||
		c.evictList.MoveToFront(ent)
 | 
					 | 
				
			||||||
		c.removeFromBucket(ent) // remove the entry from its current bucket as expiresAt is renewed
 | 
					 | 
				
			||||||
		ent.Value = value
 | 
					 | 
				
			||||||
		ent.ExpiresAt = now.Add(c.ttl)
 | 
					 | 
				
			||||||
		c.addToBucket(ent)
 | 
					 | 
				
			||||||
		return false
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Add new item
 | 
					 | 
				
			||||||
	ent := c.evictList.PushFrontExpirable(key, value, now.Add(c.ttl))
 | 
					 | 
				
			||||||
	c.items[key] = ent
 | 
					 | 
				
			||||||
	c.addToBucket(ent) // adds the entry to the appropriate bucket and sets entry.expireBucket
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	evict := c.size > 0 && c.evictList.Length() > c.size
 | 
					 | 
				
			||||||
	// Verify size not exceeded
 | 
					 | 
				
			||||||
	if evict {
 | 
					 | 
				
			||||||
		c.removeOldest()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return evict
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Get looks up a key's value from the cache.
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) Get(key K) (value V, ok bool) {
 | 
					 | 
				
			||||||
	c.mu.Lock()
 | 
					 | 
				
			||||||
	defer c.mu.Unlock()
 | 
					 | 
				
			||||||
	var ent *Entry[K, V]
 | 
					 | 
				
			||||||
	if ent, ok = c.items[key]; ok {
 | 
					 | 
				
			||||||
		// Expired item check
 | 
					 | 
				
			||||||
		if time.Now().After(ent.ExpiresAt) {
 | 
					 | 
				
			||||||
			return value, false
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		c.evictList.MoveToFront(ent)
 | 
					 | 
				
			||||||
		return ent.Value, true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Contains checks if a key is in the cache, without updating the recent-ness
 | 
					 | 
				
			||||||
// or deleting it for being stale.
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) Contains(key K) (ok bool) {
 | 
					 | 
				
			||||||
	c.mu.Lock()
 | 
					 | 
				
			||||||
	defer c.mu.Unlock()
 | 
					 | 
				
			||||||
	_, ok = c.items[key]
 | 
					 | 
				
			||||||
	return ok
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Peek returns the key value (or undefined if not found) without updating
 | 
					 | 
				
			||||||
// the "recently used"-ness of the key.
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) Peek(key K) (value V, ok bool) {
 | 
					 | 
				
			||||||
	c.mu.Lock()
 | 
					 | 
				
			||||||
	defer c.mu.Unlock()
 | 
					 | 
				
			||||||
	var ent *Entry[K, V]
 | 
					 | 
				
			||||||
	if ent, ok = c.items[key]; ok {
 | 
					 | 
				
			||||||
		// Expired item check
 | 
					 | 
				
			||||||
		if time.Now().After(ent.ExpiresAt) {
 | 
					 | 
				
			||||||
			return value, false
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return ent.Value, true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Remove removes the provided key from the cache, returning if the
 | 
					 | 
				
			||||||
// key was contained.
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) Remove(key K) bool {
 | 
					 | 
				
			||||||
	c.mu.Lock()
 | 
					 | 
				
			||||||
	defer c.mu.Unlock()
 | 
					 | 
				
			||||||
	if ent, ok := c.items[key]; ok {
 | 
					 | 
				
			||||||
		c.removeElement(ent)
 | 
					 | 
				
			||||||
		return true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return false
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// RemoveOldest removes the oldest item from the cache.
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) RemoveOldest() (key K, value V, ok bool) {
 | 
					 | 
				
			||||||
	c.mu.Lock()
 | 
					 | 
				
			||||||
	defer c.mu.Unlock()
 | 
					 | 
				
			||||||
	if ent := c.evictList.Back(); ent != nil {
 | 
					 | 
				
			||||||
		c.removeElement(ent)
 | 
					 | 
				
			||||||
		return ent.Key, ent.Value, true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// GetOldest returns the oldest entry
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) GetOldest() (key K, value V, ok bool) {
 | 
					 | 
				
			||||||
	c.mu.Lock()
 | 
					 | 
				
			||||||
	defer c.mu.Unlock()
 | 
					 | 
				
			||||||
	if ent := c.evictList.Back(); ent != nil {
 | 
					 | 
				
			||||||
		return ent.Key, ent.Value, true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) KeyValues() map[K]V {
 | 
					 | 
				
			||||||
	c.mu.Lock()
 | 
					 | 
				
			||||||
	defer c.mu.Unlock()
 | 
					 | 
				
			||||||
	maps := make(map[K]V)
 | 
					 | 
				
			||||||
	now := time.Now()
 | 
					 | 
				
			||||||
	for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
 | 
					 | 
				
			||||||
		if now.After(ent.ExpiresAt) {
 | 
					 | 
				
			||||||
			continue
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		maps[ent.Key] = ent.Value
 | 
					 | 
				
			||||||
		// keys = append(keys, ent.Key)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return maps
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Keys returns a slice of the keys in the cache, from oldest to newest.
 | 
					 | 
				
			||||||
// Expired entries are filtered out.
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) Keys() []K {
 | 
					 | 
				
			||||||
	c.mu.Lock()
 | 
					 | 
				
			||||||
	defer c.mu.Unlock()
 | 
					 | 
				
			||||||
	keys := make([]K, 0, len(c.items))
 | 
					 | 
				
			||||||
	now := time.Now()
 | 
					 | 
				
			||||||
	for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
 | 
					 | 
				
			||||||
		if now.After(ent.ExpiresAt) {
 | 
					 | 
				
			||||||
			continue
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		keys = append(keys, ent.Key)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return keys
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Values returns a slice of the values in the cache, from oldest to newest.
 | 
					 | 
				
			||||||
// Expired entries are filtered out.
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) Values() []V {
 | 
					 | 
				
			||||||
	c.mu.Lock()
 | 
					 | 
				
			||||||
	defer c.mu.Unlock()
 | 
					 | 
				
			||||||
	values := make([]V, 0, len(c.items))
 | 
					 | 
				
			||||||
	now := time.Now()
 | 
					 | 
				
			||||||
	for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
 | 
					 | 
				
			||||||
		if now.After(ent.ExpiresAt) {
 | 
					 | 
				
			||||||
			continue
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		values = append(values, ent.Value)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return values
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Len returns the number of items in the cache.
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) Len() int {
 | 
					 | 
				
			||||||
	c.mu.Lock()
 | 
					 | 
				
			||||||
	defer c.mu.Unlock()
 | 
					 | 
				
			||||||
	return c.evictList.Length()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Resize changes the cache size. Size of 0 means unlimited.
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) Resize(size int) (evicted int) {
 | 
					 | 
				
			||||||
	c.mu.Lock()
 | 
					 | 
				
			||||||
	defer c.mu.Unlock()
 | 
					 | 
				
			||||||
	if size <= 0 {
 | 
					 | 
				
			||||||
		c.size = 0
 | 
					 | 
				
			||||||
		return 0
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	diff := c.evictList.Length() - size
 | 
					 | 
				
			||||||
	if diff < 0 {
 | 
					 | 
				
			||||||
		diff = 0
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	for i := 0; i < diff; i++ {
 | 
					 | 
				
			||||||
		c.removeOldest()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	c.size = size
 | 
					 | 
				
			||||||
	return diff
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Close destroys cleanup goroutine. To clean up the cache, run Purge() before Close().
 | 
					 | 
				
			||||||
// func (c *LRU[K, V]) Close() {
 | 
					 | 
				
			||||||
//	c.mu.Lock()
 | 
					 | 
				
			||||||
//	defer c.mu.Unlock()
 | 
					 | 
				
			||||||
//	select {
 | 
					 | 
				
			||||||
//	case <-c.done:
 | 
					 | 
				
			||||||
//		return
 | 
					 | 
				
			||||||
//	default:
 | 
					 | 
				
			||||||
//	}
 | 
					 | 
				
			||||||
//	close(c.done)
 | 
					 | 
				
			||||||
// }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// removeOldest removes the oldest item from the cache. Has to be called with lock!
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) removeOldest() {
 | 
					 | 
				
			||||||
	if ent := c.evictList.Back(); ent != nil {
 | 
					 | 
				
			||||||
		c.removeElement(ent)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// removeElement is used to remove a given list element from the cache. Has to be called with lock!
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) removeElement(e *Entry[K, V]) {
 | 
					 | 
				
			||||||
	c.evictList.Remove(e)
 | 
					 | 
				
			||||||
	delete(c.items, e.Key)
 | 
					 | 
				
			||||||
	c.removeFromBucket(e)
 | 
					 | 
				
			||||||
	if c.onEvict != nil {
 | 
					 | 
				
			||||||
		c.onEvict(e.Key, e.Value)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// deleteExpired deletes expired records from the oldest bucket, waiting for the newest entry
 | 
					 | 
				
			||||||
// in it to expire first.
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) deleteExpired() {
 | 
					 | 
				
			||||||
	c.mu.Lock()
 | 
					 | 
				
			||||||
	bucketIdx := c.nextCleanupBucket
 | 
					 | 
				
			||||||
	timeToExpire := time.Until(c.buckets[bucketIdx].newestEntry)
 | 
					 | 
				
			||||||
	// wait for newest entry to expire before cleanup without holding lock
 | 
					 | 
				
			||||||
	if timeToExpire > 0 {
 | 
					 | 
				
			||||||
		c.mu.Unlock()
 | 
					 | 
				
			||||||
		time.Sleep(timeToExpire)
 | 
					 | 
				
			||||||
		c.mu.Lock()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	for _, ent := range c.buckets[bucketIdx].entries {
 | 
					 | 
				
			||||||
		c.removeElement(ent)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	c.nextCleanupBucket = (c.nextCleanupBucket + 1) % numBuckets
 | 
					 | 
				
			||||||
	c.mu.Unlock()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// addToBucket adds entry to expire bucket so that it will be cleaned up when the time comes. Has to be called with lock!
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) addToBucket(e *Entry[K, V]) {
 | 
					 | 
				
			||||||
	bucketID := (numBuckets + c.nextCleanupBucket - 1) % numBuckets
 | 
					 | 
				
			||||||
	e.ExpireBucket = bucketID
 | 
					 | 
				
			||||||
	c.buckets[bucketID].entries[e.Key] = e
 | 
					 | 
				
			||||||
	if c.buckets[bucketID].newestEntry.Before(e.ExpiresAt) {
 | 
					 | 
				
			||||||
		c.buckets[bucketID].newestEntry = e.ExpiresAt
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// removeFromBucket removes the entry from its corresponding bucket. Has to be called with lock!
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) removeFromBucket(e *Entry[K, V]) {
 | 
					 | 
				
			||||||
	delete(c.buckets[e.ExpireBucket].entries, e.Key)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Cap returns the capacity of the cache
 | 
					 | 
				
			||||||
func (c *LRU[K, V]) Cap() int {
 | 
					 | 
				
			||||||
	return c.size
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Entry is an LRU Entry
 | 
					 | 
				
			||||||
type Entry[K comparable, V any] struct {
 | 
					 | 
				
			||||||
	// Next and previous pointers in the doubly-linked list of elements.
 | 
					 | 
				
			||||||
	// To simplify the implementation, internally a list l is implemented
 | 
					 | 
				
			||||||
	// as a ring, such that &l.root is both the next element of the last
 | 
					 | 
				
			||||||
	// list element (l.Back()) and the previous element of the first list
 | 
					 | 
				
			||||||
	// element (l.Front()).
 | 
					 | 
				
			||||||
	next, prev *Entry[K, V]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// The list to which this element belongs.
 | 
					 | 
				
			||||||
	list *LruList[K, V]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// The LRU Key of this element.
 | 
					 | 
				
			||||||
	Key K
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// The Value stored with this element.
 | 
					 | 
				
			||||||
	Value V
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// The time this element would be cleaned up, optional
 | 
					 | 
				
			||||||
	ExpiresAt time.Time
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// The expiry bucket item was put in, optional
 | 
					 | 
				
			||||||
	ExpireBucket uint8
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// PrevEntry returns the previous list element or nil.
 | 
					 | 
				
			||||||
func (e *Entry[K, V]) PrevEntry() *Entry[K, V] {
 | 
					 | 
				
			||||||
	if p := e.prev; e.list != nil && p != &e.list.root {
 | 
					 | 
				
			||||||
		return p
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// LruList represents a doubly linked list.
 | 
					 | 
				
			||||||
// The zero Value for LruList is an empty list ready to use.
 | 
					 | 
				
			||||||
type LruList[K comparable, V any] struct {
 | 
					 | 
				
			||||||
	root Entry[K, V] // sentinel list element, only &root, root.prev, and root.next are used
 | 
					 | 
				
			||||||
	len  int         // current list Length excluding (this) sentinel element
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Init initializes or clears list l.
 | 
					 | 
				
			||||||
func (l *LruList[K, V]) Init() *LruList[K, V] {
 | 
					 | 
				
			||||||
	l.root.next = &l.root
 | 
					 | 
				
			||||||
	l.root.prev = &l.root
 | 
					 | 
				
			||||||
	l.len = 0
 | 
					 | 
				
			||||||
	return l
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// NewList returns an initialized list.
 | 
					 | 
				
			||||||
func NewList[K comparable, V any]() *LruList[K, V] { return new(LruList[K, V]).Init() }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Length returns the number of elements of list l.
 | 
					 | 
				
			||||||
// The complexity is O(1).
 | 
					 | 
				
			||||||
func (l *LruList[K, V]) Length() int { return l.len }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Back returns the last element of list l or nil if the list is empty.
 | 
					 | 
				
			||||||
func (l *LruList[K, V]) Back() *Entry[K, V] {
 | 
					 | 
				
			||||||
	if l.len == 0 {
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return l.root.prev
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// lazyInit lazily initializes a zero List Value.
 | 
					 | 
				
			||||||
func (l *LruList[K, V]) lazyInit() {
 | 
					 | 
				
			||||||
	if l.root.next == nil {
 | 
					 | 
				
			||||||
		l.Init()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// insert inserts e after at, increments l.len, and returns e.
 | 
					 | 
				
			||||||
func (l *LruList[K, V]) insert(e, at *Entry[K, V]) *Entry[K, V] {
 | 
					 | 
				
			||||||
	e.prev = at
 | 
					 | 
				
			||||||
	e.next = at.next
 | 
					 | 
				
			||||||
	e.prev.next = e
 | 
					 | 
				
			||||||
	e.next.prev = e
 | 
					 | 
				
			||||||
	e.list = l
 | 
					 | 
				
			||||||
	l.len++
 | 
					 | 
				
			||||||
	return e
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// insertValue is a convenience wrapper for insert(&Entry{Value: v, ExpiresAt: ExpiresAt}, at).
 | 
					 | 
				
			||||||
func (l *LruList[K, V]) insertValue(k K, v V, expiresAt time.Time, at *Entry[K, V]) *Entry[K, V] {
 | 
					 | 
				
			||||||
	return l.insert(&Entry[K, V]{Value: v, Key: k, ExpiresAt: expiresAt}, at)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Remove removes e from its list, decrements l.len
 | 
					 | 
				
			||||||
func (l *LruList[K, V]) Remove(e *Entry[K, V]) V {
 | 
					 | 
				
			||||||
	e.prev.next = e.next
 | 
					 | 
				
			||||||
	e.next.prev = e.prev
 | 
					 | 
				
			||||||
	e.next = nil // avoid memory leaks
 | 
					 | 
				
			||||||
	e.prev = nil // avoid memory leaks
 | 
					 | 
				
			||||||
	e.list = nil
 | 
					 | 
				
			||||||
	l.len--
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return e.Value
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// move moves e to next to at.
 | 
					 | 
				
			||||||
func (l *LruList[K, V]) move(e, at *Entry[K, V]) {
 | 
					 | 
				
			||||||
	if e == at {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	e.prev.next = e.next
 | 
					 | 
				
			||||||
	e.next.prev = e.prev
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	e.prev = at
 | 
					 | 
				
			||||||
	e.next = at.next
 | 
					 | 
				
			||||||
	e.prev.next = e
 | 
					 | 
				
			||||||
	e.next.prev = e
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// PushFront inserts a new element e with value v at the front of list l and returns e.
 | 
					 | 
				
			||||||
func (l *LruList[K, V]) PushFront(k K, v V) *Entry[K, V] {
 | 
					 | 
				
			||||||
	l.lazyInit()
 | 
					 | 
				
			||||||
	return l.insertValue(k, v, time.Time{}, &l.root)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// PushFrontExpirable inserts a new expirable element e with Value v at the front of list l and returns e.
 | 
					 | 
				
			||||||
func (l *LruList[K, V]) PushFrontExpirable(k K, v V, expiresAt time.Time) *Entry[K, V] {
 | 
					 | 
				
			||||||
	l.lazyInit()
 | 
					 | 
				
			||||||
	return l.insertValue(k, v, expiresAt, &l.root)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// MoveToFront moves element e to the front of list l.
 | 
					 | 
				
			||||||
// If e is not an element of l, the list is not modified.
 | 
					 | 
				
			||||||
// The element must not be nil.
 | 
					 | 
				
			||||||
func (l *LruList[K, V]) MoveToFront(e *Entry[K, V]) {
 | 
					 | 
				
			||||||
	if e.list != l || l.root.next == e {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	// see comment in List.Remove about initialization of l
 | 
					 | 
				
			||||||
	l.move(e, &l.root)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@ -1,183 +0,0 @@
 | 
				
			|||||||
package stmt_store
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"context"
 | 
					 | 
				
			||||||
	"database/sql"
 | 
					 | 
				
			||||||
	"math"
 | 
					 | 
				
			||||||
	"sync"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"gorm.io/gorm/internal/lru"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type Stmt struct {
 | 
					 | 
				
			||||||
	*sql.Stmt
 | 
					 | 
				
			||||||
	Transaction bool
 | 
					 | 
				
			||||||
	prepared    chan struct{}
 | 
					 | 
				
			||||||
	prepareErr  error
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (stmt *Stmt) Error() error {
 | 
					 | 
				
			||||||
	return stmt.prepareErr
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (stmt *Stmt) Close() error {
 | 
					 | 
				
			||||||
	<-stmt.prepared
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if stmt.Stmt != nil {
 | 
					 | 
				
			||||||
		return stmt.Stmt.Close()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Store defines an interface for managing the caching operations of SQL statements (Stmt).
 | 
					 | 
				
			||||||
// This interface provides methods for creating new statements, retrieving all cache keys,
 | 
					 | 
				
			||||||
// getting cached statements, setting cached statements, and deleting cached statements.
 | 
					 | 
				
			||||||
type Store interface {
 | 
					 | 
				
			||||||
	// New creates a new Stmt object and caches it.
 | 
					 | 
				
			||||||
	// Parameters:
 | 
					 | 
				
			||||||
	//   ctx: The context for the request, which can carry deadlines, cancellation signals, etc.
 | 
					 | 
				
			||||||
	//   key: The key representing the SQL query, used for caching and preparing the statement.
 | 
					 | 
				
			||||||
	//   isTransaction: Indicates whether this operation is part of a transaction, which may affect the caching strategy.
 | 
					 | 
				
			||||||
	//   connPool: A connection pool that provides database connections.
 | 
					 | 
				
			||||||
	//   locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
 | 
					 | 
				
			||||||
	// Returns:
 | 
					 | 
				
			||||||
	//   *Stmt: A newly created statement object for executing SQL operations.
 | 
					 | 
				
			||||||
	//   error: An error if the statement preparation fails.
 | 
					 | 
				
			||||||
	New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Keys returns a slice of all cache keys in the store.
 | 
					 | 
				
			||||||
	Keys() []string
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Get retrieves a Stmt object from the store based on the given key.
 | 
					 | 
				
			||||||
	// Parameters:
 | 
					 | 
				
			||||||
	//   key: The key used to look up the Stmt object.
 | 
					 | 
				
			||||||
	// Returns:
 | 
					 | 
				
			||||||
	//   *Stmt: The found Stmt object, or nil if not found.
 | 
					 | 
				
			||||||
	//   bool: Indicates whether the corresponding Stmt object was successfully found.
 | 
					 | 
				
			||||||
	Get(key string) (*Stmt, bool)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Set stores the given Stmt object in the store and associates it with the specified key.
 | 
					 | 
				
			||||||
	// Parameters:
 | 
					 | 
				
			||||||
	//   key: The key used to associate the Stmt object.
 | 
					 | 
				
			||||||
	//   value: The Stmt object to be stored.
 | 
					 | 
				
			||||||
	Set(key string, value *Stmt)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Delete removes the Stmt object corresponding to the specified key from the store.
 | 
					 | 
				
			||||||
	// Parameters:
 | 
					 | 
				
			||||||
	//   key: The key associated with the Stmt object to be deleted.
 | 
					 | 
				
			||||||
	Delete(key string)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// defaultMaxSize defines the default maximum capacity of the cache.
 | 
					 | 
				
			||||||
// Its value is the maximum value of the int64 type, which means that when the cache size is not specified,
 | 
					 | 
				
			||||||
// the cache can theoretically store as many elements as possible.
 | 
					 | 
				
			||||||
// (1 << 63) - 1 is the maximum value that an int64 type can represent.
 | 
					 | 
				
			||||||
const (
 | 
					 | 
				
			||||||
	defaultMaxSize = math.MaxInt
 | 
					 | 
				
			||||||
	// defaultTTL defines the default time-to-live (TTL) for each cache entry.
 | 
					 | 
				
			||||||
	// When the TTL for cache entries is not specified, each cache entry will expire after 24 hours.
 | 
					 | 
				
			||||||
	defaultTTL = time.Hour * 24
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// New creates and returns a new Store instance.
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// Parameters:
 | 
					 | 
				
			||||||
//   - size: The maximum capacity of the cache. If the provided size is less than or equal to 0,
 | 
					 | 
				
			||||||
//     it defaults to defaultMaxSize.
 | 
					 | 
				
			||||||
//   - ttl: The time-to-live duration for each cache entry. If the provided ttl is less than or equal to 0,
 | 
					 | 
				
			||||||
//     it defaults to defaultTTL.
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// This function defines an onEvicted callback that is invoked when a cache entry is evicted.
 | 
					 | 
				
			||||||
// The callback ensures that if the evicted value (v) is not nil, its Close method is called asynchronously
 | 
					 | 
				
			||||||
// to release associated resources.
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// Returns:
 | 
					 | 
				
			||||||
//   - A Store instance implemented by lruStore, which internally uses an LRU cache with the specified size,
 | 
					 | 
				
			||||||
//     eviction callback, and TTL.
 | 
					 | 
				
			||||||
func New(size int, ttl time.Duration) Store {
 | 
					 | 
				
			||||||
	if size <= 0 {
 | 
					 | 
				
			||||||
		size = defaultMaxSize
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if ttl <= 0 {
 | 
					 | 
				
			||||||
		ttl = defaultTTL
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	onEvicted := func(k string, v *Stmt) {
 | 
					 | 
				
			||||||
		if v != nil {
 | 
					 | 
				
			||||||
			go v.Close()
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return &lruStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type lruStore struct {
 | 
					 | 
				
			||||||
	lru *lru.LRU[string, *Stmt]
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s *lruStore) Keys() []string {
 | 
					 | 
				
			||||||
	return s.lru.Keys()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s *lruStore) Get(key string) (*Stmt, bool) {
 | 
					 | 
				
			||||||
	stmt, ok := s.lru.Get(key)
 | 
					 | 
				
			||||||
	if ok && stmt != nil {
 | 
					 | 
				
			||||||
		<-stmt.prepared
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return stmt, ok
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s *lruStore) Set(key string, value *Stmt) {
 | 
					 | 
				
			||||||
	s.lru.Add(key, value)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s *lruStore) Delete(key string) {
 | 
					 | 
				
			||||||
	s.lru.Remove(key)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ConnPool interface {
 | 
					 | 
				
			||||||
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// New creates a new Stmt object for executing SQL queries.
 | 
					 | 
				
			||||||
// It caches the Stmt object for future use and handles preparation and error states.
 | 
					 | 
				
			||||||
// Parameters:
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
//	ctx: Context for the request, used to carry deadlines, cancellation signals, etc.
 | 
					 | 
				
			||||||
//	key: The key representing the SQL query, used for caching and preparing the statement.
 | 
					 | 
				
			||||||
//	isTransaction: Indicates whether this operation is part of a transaction, affecting cache strategy.
 | 
					 | 
				
			||||||
//	conn: A connection pool that provides database connections.
 | 
					 | 
				
			||||||
//	locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// Returns:
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
//	*Stmt: A newly created statement object for executing SQL operations.
 | 
					 | 
				
			||||||
//	error: An error if the statement preparation fails.
 | 
					 | 
				
			||||||
func (s *lruStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) {
 | 
					 | 
				
			||||||
	// Create a Stmt object and set its Transaction property.
 | 
					 | 
				
			||||||
	// The prepared channel is used to synchronize the statement preparation state.
 | 
					 | 
				
			||||||
	cacheStmt := &Stmt{
 | 
					 | 
				
			||||||
		Transaction: isTransaction,
 | 
					 | 
				
			||||||
		prepared:    make(chan struct{}),
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	// Cache the Stmt object with the associated key.
 | 
					 | 
				
			||||||
	s.Set(key, cacheStmt)
 | 
					 | 
				
			||||||
	// Unlock after completing initialization to prevent deadlocks.
 | 
					 | 
				
			||||||
	locker.Unlock()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Ensure the prepared channel is closed after the function execution completes.
 | 
					 | 
				
			||||||
	defer close(cacheStmt.prepared)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Prepare the SQL statement using the provided connection.
 | 
					 | 
				
			||||||
	cacheStmt.Stmt, err = conn.PrepareContext(ctx, key)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		// If statement preparation fails, record the error and remove the invalid Stmt object from the cache.
 | 
					 | 
				
			||||||
		cacheStmt.prepareErr = err
 | 
					 | 
				
			||||||
		s.Delete(key)
 | 
					 | 
				
			||||||
		return &Stmt{}, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Return the successfully prepared Stmt object.
 | 
					 | 
				
			||||||
	return cacheStmt, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@ -69,7 +69,7 @@ type Interface interface {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
	// Discard logger will print any log to io.Discard
 | 
						// Discard Discard logger will print any log to io.Discard
 | 
				
			||||||
	Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
 | 
						Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
 | 
				
			||||||
	// Default Default logger
 | 
						// Default Default logger
 | 
				
			||||||
	Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
 | 
						Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
 | 
				
			||||||
@ -78,13 +78,8 @@ var (
 | 
				
			|||||||
		IgnoreRecordNotFoundError: false,
 | 
							IgnoreRecordNotFoundError: false,
 | 
				
			||||||
		Colorful:                  true,
 | 
							Colorful:                  true,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	// Recorder logger records running SQL into a recorder instance
 | 
						// Recorder Recorder logger records running SQL into a recorder instance
 | 
				
			||||||
	Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
 | 
						Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
 | 
				
			||||||
 | 
					 | 
				
			||||||
	// RecorderParamsFilter defaults to no-op, allows to be run-over by a different implementation
 | 
					 | 
				
			||||||
	RecorderParamsFilter = func(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
 | 
					 | 
				
			||||||
		return sql, params
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// New initialize logger
 | 
					// New initialize logger
 | 
				
			||||||
@ -134,30 +129,28 @@ func (l *logger) LogMode(level LogLevel) Interface {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Info print info
 | 
					// Info print info
 | 
				
			||||||
func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) {
 | 
					func (l logger) Info(ctx context.Context, msg string, data ...interface{}) {
 | 
				
			||||||
	if l.LogLevel >= Info {
 | 
						if l.LogLevel >= Info {
 | 
				
			||||||
		l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
 | 
							l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Warn print warn messages
 | 
					// Warn print warn messages
 | 
				
			||||||
func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) {
 | 
					func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) {
 | 
				
			||||||
	if l.LogLevel >= Warn {
 | 
						if l.LogLevel >= Warn {
 | 
				
			||||||
		l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
 | 
							l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Error print error messages
 | 
					// Error print error messages
 | 
				
			||||||
func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) {
 | 
					func (l logger) Error(ctx context.Context, msg string, data ...interface{}) {
 | 
				
			||||||
	if l.LogLevel >= Error {
 | 
						if l.LogLevel >= Error {
 | 
				
			||||||
		l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
 | 
							l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Trace print sql message
 | 
					// Trace print sql message
 | 
				
			||||||
//
 | 
					func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
 | 
				
			||||||
//nolint:cyclop
 | 
					 | 
				
			||||||
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
 | 
					 | 
				
			||||||
	if l.LogLevel <= Silent {
 | 
						if l.LogLevel <= Silent {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -189,8 +182,8 @@ func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string,
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ParamsFilter filter params
 | 
					// Trace print sql message
 | 
				
			||||||
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
 | 
					func (l logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
 | 
				
			||||||
	if l.Config.ParameterizedQueries {
 | 
						if l.Config.ParameterizedQueries {
 | 
				
			||||||
		return sql, nil
 | 
							return sql, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -205,8 +198,8 @@ type traceRecorder struct {
 | 
				
			|||||||
	Err          error
 | 
						Err          error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// New trace recorder
 | 
					// New new trace recorder
 | 
				
			||||||
func (l *traceRecorder) New() *traceRecorder {
 | 
					func (l traceRecorder) New() *traceRecorder {
 | 
				
			||||||
	return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
 | 
						return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -216,10 +209,3 @@ func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (s
 | 
				
			|||||||
	l.SQL, l.RowsAffected = fc()
 | 
						l.SQL, l.RowsAffected = fc()
 | 
				
			||||||
	l.Err = err
 | 
						l.Err = err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func (l *traceRecorder) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
 | 
					 | 
				
			||||||
	if RecorderParamsFilter == nil {
 | 
					 | 
				
			||||||
		return sql, params
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return RecorderParamsFilter(ctx, sql, params...)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -34,19 +34,6 @@ var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO
 | 
				
			|||||||
// RegEx matches only numeric values
 | 
					// RegEx matches only numeric values
 | 
				
			||||||
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
 | 
					var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func isNumeric(k reflect.Kind) bool {
 | 
					 | 
				
			||||||
	switch k {
 | 
					 | 
				
			||||||
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
 | 
					 | 
				
			||||||
		return true
 | 
					 | 
				
			||||||
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
 | 
					 | 
				
			||||||
		return true
 | 
					 | 
				
			||||||
	case reflect.Float32, reflect.Float64:
 | 
					 | 
				
			||||||
		return true
 | 
					 | 
				
			||||||
	default:
 | 
					 | 
				
			||||||
		return false
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
 | 
					// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
 | 
				
			||||||
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
 | 
					func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
@ -92,17 +79,17 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
 | 
				
			|||||||
			case reflect.Bool:
 | 
								case reflect.Bool:
 | 
				
			||||||
				vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
 | 
									vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
 | 
				
			||||||
			case reflect.String:
 | 
								case reflect.String:
 | 
				
			||||||
				vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
 | 
									vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
 | 
				
			||||||
			default:
 | 
								default:
 | 
				
			||||||
				if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
 | 
									if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
 | 
				
			||||||
					vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
 | 
										vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
					vars[idx] = nullStr
 | 
										vars[idx] = nullStr
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		case []byte:
 | 
							case []byte:
 | 
				
			||||||
			if s := string(v); isPrintable(s) {
 | 
								if s := string(v); isPrintable(s) {
 | 
				
			||||||
				vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
 | 
									vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				vars[idx] = escaper + "<binary>" + escaper
 | 
									vars[idx] = escaper + "<binary>" + escaper
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@ -113,7 +100,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
 | 
				
			|||||||
		case float64:
 | 
							case float64:
 | 
				
			||||||
			vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
 | 
								vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
 | 
				
			||||||
		case string:
 | 
							case string:
 | 
				
			||||||
			vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper
 | 
								vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
			rv := reflect.ValueOf(v)
 | 
								rv := reflect.ValueOf(v)
 | 
				
			||||||
			if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
 | 
								if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
 | 
				
			||||||
@ -123,12 +110,6 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
 | 
				
			|||||||
				convertParams(v, idx)
 | 
									convertParams(v, idx)
 | 
				
			||||||
			} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
 | 
								} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
 | 
				
			||||||
				convertParams(reflect.Indirect(rv).Interface(), idx)
 | 
									convertParams(reflect.Indirect(rv).Interface(), idx)
 | 
				
			||||||
			} else if isNumeric(rv.Kind()) {
 | 
					 | 
				
			||||||
				if rv.CanInt() || rv.CanUint() {
 | 
					 | 
				
			||||||
					vars[idx] = fmt.Sprintf("%d", rv.Interface())
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				for _, t := range convertibleTypes {
 | 
									for _, t := range convertibleTypes {
 | 
				
			||||||
					if rv.Type().ConvertibleTo(t) {
 | 
										if rv.Type().ConvertibleTo(t) {
 | 
				
			||||||
@ -136,7 +117,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
 | 
				
			|||||||
						return
 | 
											return
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
 | 
									vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -31,24 +31,20 @@ func (s ExampleStruct) Value() (driver.Value, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func format(v []byte, escaper string) string {
 | 
					func format(v []byte, escaper string) string {
 | 
				
			||||||
	return escaper + strings.ReplaceAll(string(v), escaper, escaper+escaper) + escaper
 | 
						return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestExplainSQL(t *testing.T) {
 | 
					func TestExplainSQL(t *testing.T) {
 | 
				
			||||||
	type role string
 | 
						type role string
 | 
				
			||||||
	type password []byte
 | 
						type password []byte
 | 
				
			||||||
	type intType int
 | 
					 | 
				
			||||||
	type floatType float64
 | 
					 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		tt                 = now.MustParse("2020-02-23 11:10:10")
 | 
							tt     = now.MustParse("2020-02-23 11:10:10")
 | 
				
			||||||
		myrole             = role("admin")
 | 
							myrole = role("admin")
 | 
				
			||||||
		pwd                = password("pass")
 | 
							pwd    = password([]byte("pass"))
 | 
				
			||||||
		jsVal              = []byte(`{"Name":"test","Val":"test"}`)
 | 
							jsVal  = []byte(`{"Name":"test","Val":"test"}`)
 | 
				
			||||||
		js                 = JSON(jsVal)
 | 
							js     = JSON(jsVal)
 | 
				
			||||||
		esVal              = []byte(`{"Name":"test","Val":"test"}`)
 | 
							esVal  = []byte(`{"Name":"test","Val":"test"}`)
 | 
				
			||||||
		es                 = ExampleStruct{Name: "test", Val: "test"}
 | 
							es     = ExampleStruct{Name: "test", Val: "test"}
 | 
				
			||||||
		intVal   intType   = 1
 | 
					 | 
				
			||||||
		floatVal floatType = 1.23
 | 
					 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	results := []struct {
 | 
						results := []struct {
 | 
				
			||||||
@ -61,13 +57,13 @@ func TestExplainSQL(t *testing.T) {
 | 
				
			|||||||
			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
 | 
								SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
 | 
				
			||||||
			NumericRegexp: nil,
 | 
								NumericRegexp: nil,
 | 
				
			||||||
			Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
 | 
								Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
 | 
				
			||||||
			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`,
 | 
								Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
 | 
								SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
 | 
				
			||||||
			NumericRegexp: nil,
 | 
								NumericRegexp: nil,
 | 
				
			||||||
			Vars:          []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
 | 
								Vars:          []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
 | 
				
			||||||
			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`,
 | 
								Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
 | 
								SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
 | 
				
			||||||
@ -91,37 +87,19 @@ func TestExplainSQL(t *testing.T) {
 | 
				
			|||||||
			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
 | 
								SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
 | 
				
			||||||
			NumericRegexp: nil,
 | 
								NumericRegexp: nil,
 | 
				
			||||||
			Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
 | 
								Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
 | 
				
			||||||
			Result:        fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
 | 
								Result:        fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
 | 
								SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
 | 
				
			||||||
			NumericRegexp: nil,
 | 
								NumericRegexp: nil,
 | 
				
			||||||
			Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
 | 
								Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
 | 
				
			||||||
			Result:        fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
 | 
								Result:        fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
 | 
								SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
 | 
				
			||||||
			NumericRegexp: nil,
 | 
								NumericRegexp: nil,
 | 
				
			||||||
			Vars:          []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
 | 
								Vars:          []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
 | 
				
			||||||
			Result:        fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
 | 
								Result:        fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
 | 
					 | 
				
			||||||
			NumericRegexp: nil,
 | 
					 | 
				
			||||||
			Vars:          []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
 | 
					 | 
				
			||||||
			Result:        fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
 | 
					 | 
				
			||||||
			NumericRegexp: nil,
 | 
					 | 
				
			||||||
			Vars:          []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal},
 | 
					 | 
				
			||||||
			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
 | 
					 | 
				
			||||||
			NumericRegexp: nil,
 | 
					 | 
				
			||||||
			Vars:          []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal},
 | 
					 | 
				
			||||||
			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`,
 | 
					 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -87,8 +87,6 @@ type Migrator interface {
 | 
				
			|||||||
	DropColumn(dst interface{}, field string) error
 | 
						DropColumn(dst interface{}, field string) error
 | 
				
			||||||
	AlterColumn(dst interface{}, field string) error
 | 
						AlterColumn(dst interface{}, field string) error
 | 
				
			||||||
	MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
 | 
						MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
 | 
				
			||||||
	// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
 | 
					 | 
				
			||||||
	MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
 | 
					 | 
				
			||||||
	HasColumn(dst interface{}, field string) bool
 | 
						HasColumn(dst interface{}, field string) bool
 | 
				
			||||||
	RenameColumn(dst interface{}, oldName, field string) error
 | 
						RenameColumn(dst interface{}, oldName, field string) error
 | 
				
			||||||
	ColumnTypes(dst interface{}) ([]ColumnType, error)
 | 
						ColumnTypes(dst interface{}) ([]ColumnType, error)
 | 
				
			||||||
 | 
				
			|||||||
@ -7,7 +7,6 @@ import (
 | 
				
			|||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
	"strconv"
 | 
					 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -28,8 +27,6 @@ var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// TODO:? Create const vars for raw sql queries ?
 | 
					// TODO:? Create const vars for raw sql queries ?
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var _ gorm.Migrator = (*Migrator)(nil)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Migrator m struct
 | 
					// Migrator m struct
 | 
				
			||||||
type Migrator struct {
 | 
					type Migrator struct {
 | 
				
			||||||
	Config
 | 
						Config
 | 
				
			||||||
@ -94,6 +91,10 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
 | 
				
			|||||||
		expr.SQL += " NOT NULL"
 | 
							expr.SQL += " NOT NULL"
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if field.Unique {
 | 
				
			||||||
 | 
							expr.SQL += " UNIQUE"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
 | 
						if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
 | 
				
			||||||
		if field.DefaultValueInterface != nil {
 | 
							if field.DefaultValueInterface != nil {
 | 
				
			||||||
			defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
 | 
								defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
 | 
				
			||||||
@ -107,31 +108,21 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) {
 | 
					 | 
				
			||||||
	queryTx = m.DB.Session(&gorm.Session{})
 | 
					 | 
				
			||||||
	execTx = queryTx
 | 
					 | 
				
			||||||
	if m.DB.DryRun {
 | 
					 | 
				
			||||||
		queryTx.DryRun = false
 | 
					 | 
				
			||||||
		execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return queryTx, execTx
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// AutoMigrate auto migrate values
 | 
					// AutoMigrate auto migrate values
 | 
				
			||||||
func (m Migrator) AutoMigrate(values ...interface{}) error {
 | 
					func (m Migrator) AutoMigrate(values ...interface{}) error {
 | 
				
			||||||
	for _, value := range m.ReorderModels(values, true) {
 | 
						for _, value := range m.ReorderModels(values, true) {
 | 
				
			||||||
		queryTx, execTx := m.GetQueryAndExecTx()
 | 
							queryTx := m.DB.Session(&gorm.Session{})
 | 
				
			||||||
 | 
							execTx := queryTx
 | 
				
			||||||
 | 
							if m.DB.DryRun {
 | 
				
			||||||
 | 
								queryTx.DryRun = false
 | 
				
			||||||
 | 
								execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
		if !queryTx.Migrator().HasTable(value) {
 | 
							if !queryTx.Migrator().HasTable(value) {
 | 
				
			||||||
			if err := execTx.Migrator().CreateTable(value); err != nil {
 | 
								if err := execTx.Migrator().CreateTable(value); err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
								if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
 | 
					 | 
				
			||||||
				if stmt.Schema == nil {
 | 
					 | 
				
			||||||
					return errors.New("failed to get schema")
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				columnTypes, err := queryTx.Migrator().ColumnTypes(value)
 | 
									columnTypes, err := queryTx.Migrator().ColumnTypes(value)
 | 
				
			||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					return err
 | 
										return err
 | 
				
			||||||
@ -216,11 +207,6 @@ func (m Migrator) CreateTable(values ...interface{}) error {
 | 
				
			|||||||
	for _, value := range m.ReorderModels(values, false) {
 | 
						for _, value := range m.ReorderModels(values, false) {
 | 
				
			||||||
		tx := m.DB.Session(&gorm.Session{})
 | 
							tx := m.DB.Session(&gorm.Session{})
 | 
				
			||||||
		if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
 | 
							if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
 | 
				
			||||||
 | 
					 | 
				
			||||||
			if stmt.Schema == nil {
 | 
					 | 
				
			||||||
				return errors.New("failed to get schema")
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			var (
 | 
								var (
 | 
				
			||||||
				createTableSQL          = "CREATE TABLE ? ("
 | 
									createTableSQL          = "CREATE TABLE ? ("
 | 
				
			||||||
				values                  = []interface{}{m.CurrentTable(stmt)}
 | 
									values                  = []interface{}{m.CurrentTable(stmt)}
 | 
				
			||||||
@ -231,7 +217,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
 | 
				
			|||||||
				field := stmt.Schema.FieldsByDBName[dbName]
 | 
									field := stmt.Schema.FieldsByDBName[dbName]
 | 
				
			||||||
				if !field.IgnoreMigration {
 | 
									if !field.IgnoreMigration {
 | 
				
			||||||
					createTableSQL += "? ?"
 | 
										createTableSQL += "? ?"
 | 
				
			||||||
					hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY")
 | 
										hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
 | 
				
			||||||
					values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
 | 
										values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
 | 
				
			||||||
					createTableSQL += ","
 | 
										createTableSQL += ","
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
@ -280,7 +266,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
 | 
				
			|||||||
					}
 | 
										}
 | 
				
			||||||
					if constraint := rel.ParseConstraint(); constraint != nil {
 | 
										if constraint := rel.ParseConstraint(); constraint != nil {
 | 
				
			||||||
						if constraint.Schema == stmt.Schema {
 | 
											if constraint.Schema == stmt.Schema {
 | 
				
			||||||
							sql, vars := constraint.Build()
 | 
												sql, vars := buildConstraint(constraint)
 | 
				
			||||||
							createTableSQL += sql + ","
 | 
												createTableSQL += sql + ","
 | 
				
			||||||
							values = append(values, vars...)
 | 
												values = append(values, vars...)
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
@ -288,11 +274,6 @@ func (m Migrator) CreateTable(values ...interface{}) error {
 | 
				
			|||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			for _, uni := range stmt.Schema.ParseUniqueConstraints() {
 | 
					 | 
				
			||||||
				createTableSQL += "CONSTRAINT ? UNIQUE (?),"
 | 
					 | 
				
			||||||
				values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)})
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			for _, chk := range stmt.Schema.ParseCheckConstraints() {
 | 
								for _, chk := range stmt.Schema.ParseCheckConstraints() {
 | 
				
			||||||
				createTableSQL += "CONSTRAINT ? CHECK (?),"
 | 
									createTableSQL += "CONSTRAINT ? CHECK (?),"
 | 
				
			||||||
				values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
 | 
									values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
 | 
				
			||||||
@ -373,9 +354,6 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
 | 
				
			|||||||
func (m Migrator) AddColumn(value interface{}, name string) error {
 | 
					func (m Migrator) AddColumn(value interface{}, name string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		// avoid using the same name field
 | 
							// avoid using the same name field
 | 
				
			||||||
		if stmt.Schema == nil {
 | 
					 | 
				
			||||||
			return errors.New("failed to get schema")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		f := stmt.Schema.LookUpField(name)
 | 
							f := stmt.Schema.LookUpField(name)
 | 
				
			||||||
		if f == nil {
 | 
							if f == nil {
 | 
				
			||||||
			return fmt.Errorf("failed to look up field with name: %s", name)
 | 
								return fmt.Errorf("failed to look up field with name: %s", name)
 | 
				
			||||||
@ -395,10 +373,8 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
 | 
				
			|||||||
// DropColumn drop value's `name` column
 | 
					// DropColumn drop value's `name` column
 | 
				
			||||||
func (m Migrator) DropColumn(value interface{}, name string) error {
 | 
					func (m Migrator) DropColumn(value interface{}, name string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		if stmt.Schema != nil {
 | 
							if field := stmt.Schema.LookUpField(name); field != nil {
 | 
				
			||||||
			if field := stmt.Schema.LookUpField(name); field != nil {
 | 
								name = field.DBName
 | 
				
			||||||
				name = field.DBName
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return m.DB.Exec(
 | 
							return m.DB.Exec(
 | 
				
			||||||
@ -410,15 +386,13 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
 | 
				
			|||||||
// AlterColumn alter value's `field` column' type based on schema definition
 | 
					// AlterColumn alter value's `field` column' type based on schema definition
 | 
				
			||||||
func (m Migrator) AlterColumn(value interface{}, field string) error {
 | 
					func (m Migrator) AlterColumn(value interface{}, field string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		if stmt.Schema != nil {
 | 
							if field := stmt.Schema.LookUpField(field); field != nil {
 | 
				
			||||||
			if field := stmt.Schema.LookUpField(field); field != nil {
 | 
								fileType := m.FullDataTypeOf(field)
 | 
				
			||||||
				fileType := m.FullDataTypeOf(field)
 | 
								return m.DB.Exec(
 | 
				
			||||||
				return m.DB.Exec(
 | 
									"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
 | 
				
			||||||
					"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
 | 
									m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
 | 
				
			||||||
					m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
 | 
								).Error
 | 
				
			||||||
				).Error
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return fmt.Errorf("failed to look up field with name: %s", field)
 | 
							return fmt.Errorf("failed to look up field with name: %s", field)
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
@ -430,10 +404,8 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
 | 
				
			|||||||
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		currentDatabase := m.DB.Migrator().CurrentDatabase()
 | 
							currentDatabase := m.DB.Migrator().CurrentDatabase()
 | 
				
			||||||
		name := field
 | 
							name := field
 | 
				
			||||||
		if stmt.Schema != nil {
 | 
							if field := stmt.Schema.LookUpField(field); field != nil {
 | 
				
			||||||
			if field := stmt.Schema.LookUpField(field); field != nil {
 | 
								name = field.DBName
 | 
				
			||||||
				name = field.DBName
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return m.DB.Raw(
 | 
							return m.DB.Raw(
 | 
				
			||||||
@ -448,14 +420,12 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
 | 
				
			|||||||
// RenameColumn rename value's field name from oldName to newName
 | 
					// RenameColumn rename value's field name from oldName to newName
 | 
				
			||||||
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
 | 
					func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		if stmt.Schema != nil {
 | 
							if field := stmt.Schema.LookUpField(oldName); field != nil {
 | 
				
			||||||
			if field := stmt.Schema.LookUpField(oldName); field != nil {
 | 
								oldName = field.DBName
 | 
				
			||||||
				oldName = field.DBName
 | 
							}
 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if field := stmt.Schema.LookUpField(newName); field != nil {
 | 
							if field := stmt.Schema.LookUpField(newName); field != nil {
 | 
				
			||||||
				newName = field.DBName
 | 
								newName = field.DBName
 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return m.DB.Exec(
 | 
							return m.DB.Exec(
 | 
				
			||||||
@ -467,13 +437,10 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// MigrateColumn migrate column
 | 
					// MigrateColumn migrate column
 | 
				
			||||||
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
 | 
					func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
 | 
				
			||||||
	if field.IgnoreMigration {
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// found, smart migrate
 | 
						// found, smart migrate
 | 
				
			||||||
	fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
 | 
						fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
 | 
				
			||||||
	realDataType := strings.ToLower(columnType.DatabaseTypeName())
 | 
						realDataType := strings.ToLower(columnType.DatabaseTypeName())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		alterColumn bool
 | 
							alterColumn bool
 | 
				
			||||||
		isSameType  = fullDataType == realDataType
 | 
							isSameType  = fullDataType == realDataType
 | 
				
			||||||
@ -512,19 +479,8 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
 | 
				
			|||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// check precision
 | 
							// check precision
 | 
				
			||||||
	if realDataType == "decimal" || realDataType == "numeric" &&
 | 
					 | 
				
			||||||
		regexp.MustCompile(realDataType+`\(.*\)`).FindString(fullDataType) != "" { // if realDataType has no precision,ignore
 | 
					 | 
				
			||||||
		precision, scale, ok := columnType.DecimalSize()
 | 
					 | 
				
			||||||
		if ok {
 | 
					 | 
				
			||||||
			if !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d,%d)", realDataType, precision, scale)) &&
 | 
					 | 
				
			||||||
				!strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d)", realDataType, precision)) {
 | 
					 | 
				
			||||||
				alterColumn = true
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
 | 
							if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
 | 
				
			||||||
			if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
 | 
								if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
 | 
				
			||||||
				alterColumn = true
 | 
									alterColumn = true
 | 
				
			||||||
@ -534,8 +490,16 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// check nullable
 | 
						// check nullable
 | 
				
			||||||
	if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
 | 
						if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
 | 
				
			||||||
		// not primary key & current database is non-nullable(to be nullable)
 | 
							// not primary key & database is nullable
 | 
				
			||||||
		if !field.PrimaryKey && !nullable {
 | 
							if !field.PrimaryKey && nullable {
 | 
				
			||||||
 | 
								alterColumn = true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// check unique
 | 
				
			||||||
 | 
						if unique, ok := columnType.Unique(); ok && unique != field.Unique {
 | 
				
			||||||
 | 
							// not primary key
 | 
				
			||||||
 | 
							if !field.PrimaryKey {
 | 
				
			||||||
			alterColumn = true
 | 
								alterColumn = true
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -550,18 +514,12 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
 | 
				
			|||||||
		} else if !dvNotNull && currentDefaultNotNull {
 | 
							} else if !dvNotNull && currentDefaultNotNull {
 | 
				
			||||||
			// null -> default value
 | 
								// null -> default value
 | 
				
			||||||
			alterColumn = true
 | 
								alterColumn = true
 | 
				
			||||||
		} else if currentDefaultNotNull || dvNotNull {
 | 
							} else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) ||
 | 
				
			||||||
			switch field.GORMDataType {
 | 
								(field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) {
 | 
				
			||||||
			case schema.Time:
 | 
								// default value not equal
 | 
				
			||||||
				if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) {
 | 
								// not both null
 | 
				
			||||||
					alterColumn = true
 | 
								if currentDefaultNotNull || dvNotNull {
 | 
				
			||||||
				}
 | 
									alterColumn = true
 | 
				
			||||||
			case schema.Bool:
 | 
					 | 
				
			||||||
				v1, _ := strconv.ParseBool(dv)
 | 
					 | 
				
			||||||
				v2, _ := strconv.ParseBool(field.DefaultValue)
 | 
					 | 
				
			||||||
				alterColumn = v1 != v2
 | 
					 | 
				
			||||||
			default:
 | 
					 | 
				
			||||||
				alterColumn = dv != field.DefaultValue
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -574,39 +532,13 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if alterColumn {
 | 
						if alterColumn && !field.IgnoreMigration {
 | 
				
			||||||
		if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil {
 | 
							return m.DB.Migrator().AlterColumn(value, field.DBName)
 | 
				
			||||||
			return err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
 | 
					 | 
				
			||||||
	unique, ok := columnType.Unique()
 | 
					 | 
				
			||||||
	if !ok || field.PrimaryKey {
 | 
					 | 
				
			||||||
		return nil // skip primary key
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	// By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex.
 | 
					 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
					 | 
				
			||||||
		// We're currently only receiving boolean values on `Unique` tag,
 | 
					 | 
				
			||||||
		// so the UniqueConstraint name is fixed
 | 
					 | 
				
			||||||
		constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
 | 
					 | 
				
			||||||
		if unique && !field.Unique {
 | 
					 | 
				
			||||||
			return m.DB.Migrator().DropConstraint(value, constraint)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if !unique && field.Unique {
 | 
					 | 
				
			||||||
			return m.DB.Migrator().CreateConstraint(value, constraint)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
 | 
					// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
 | 
				
			||||||
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
 | 
					func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
 | 
				
			||||||
	columnTypes := make([]gorm.ColumnType, 0)
 | 
						columnTypes := make([]gorm.ColumnType, 0)
 | 
				
			||||||
@ -676,36 +608,37 @@ func (m Migrator) DropView(name string) error {
 | 
				
			|||||||
	return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error
 | 
						return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// GuessConstraintAndTable guess statement's constraint and it's table based on name
 | 
					func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
 | 
				
			||||||
//
 | 
						sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
 | 
				
			||||||
// Deprecated: use GuessConstraintInterfaceAndTable instead.
 | 
						if constraint.OnDelete != "" {
 | 
				
			||||||
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) {
 | 
							sql += " ON DELETE " + constraint.OnDelete
 | 
				
			||||||
	constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
 | 
					 | 
				
			||||||
	switch c := constraint.(type) {
 | 
					 | 
				
			||||||
	case *schema.Constraint:
 | 
					 | 
				
			||||||
		return c, nil, table
 | 
					 | 
				
			||||||
	case *schema.CheckConstraint:
 | 
					 | 
				
			||||||
		return nil, c, table
 | 
					 | 
				
			||||||
	default:
 | 
					 | 
				
			||||||
		return nil, nil, table
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if constraint.OnUpdate != "" {
 | 
				
			||||||
 | 
							sql += " ON UPDATE " + constraint.OnUpdate
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var foreignKeys, references []interface{}
 | 
				
			||||||
 | 
						for _, field := range constraint.ForeignKeys {
 | 
				
			||||||
 | 
							foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, field := range constraint.References {
 | 
				
			||||||
 | 
							references = append(references, clause.Column{Name: field.DBName})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name
 | 
					// GuessConstraintAndTable guess statement's constraint and it's table based on name
 | 
				
			||||||
// nolint:cyclop
 | 
					func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
 | 
				
			||||||
func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) {
 | 
					 | 
				
			||||||
	if stmt.Schema == nil {
 | 
						if stmt.Schema == nil {
 | 
				
			||||||
		return nil, stmt.Table
 | 
							return nil, nil, stmt.Table
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	checkConstraints := stmt.Schema.ParseCheckConstraints()
 | 
						checkConstraints := stmt.Schema.ParseCheckConstraints()
 | 
				
			||||||
	if chk, ok := checkConstraints[name]; ok {
 | 
						if chk, ok := checkConstraints[name]; ok {
 | 
				
			||||||
		return &chk, stmt.Table
 | 
							return nil, &chk, stmt.Table
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	uniqueConstraints := stmt.Schema.ParseUniqueConstraints()
 | 
					 | 
				
			||||||
	if uni, ok := uniqueConstraints[name]; ok {
 | 
					 | 
				
			||||||
		return &uni, stmt.Table
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	getTable := func(rel *schema.Relationship) string {
 | 
						getTable := func(rel *schema.Relationship) string {
 | 
				
			||||||
@ -720,7 +653,7 @@ func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name st
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	for _, rel := range stmt.Schema.Relationships.Relations {
 | 
						for _, rel := range stmt.Schema.Relationships.Relations {
 | 
				
			||||||
		if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
 | 
							if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
 | 
				
			||||||
			return constraint, getTable(rel)
 | 
								return constraint, nil, getTable(rel)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -728,39 +661,40 @@ func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name st
 | 
				
			|||||||
		for k := range checkConstraints {
 | 
							for k := range checkConstraints {
 | 
				
			||||||
			if checkConstraints[k].Field == field {
 | 
								if checkConstraints[k].Field == field {
 | 
				
			||||||
				v := checkConstraints[k]
 | 
									v := checkConstraints[k]
 | 
				
			||||||
				return &v, stmt.Table
 | 
									return nil, &v, stmt.Table
 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		for k := range uniqueConstraints {
 | 
					 | 
				
			||||||
			if uniqueConstraints[k].Field == field {
 | 
					 | 
				
			||||||
				v := uniqueConstraints[k]
 | 
					 | 
				
			||||||
				return &v, stmt.Table
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		for _, rel := range stmt.Schema.Relationships.Relations {
 | 
							for _, rel := range stmt.Schema.Relationships.Relations {
 | 
				
			||||||
			if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field {
 | 
								if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field {
 | 
				
			||||||
				return constraint, getTable(rel)
 | 
									return constraint, nil, getTable(rel)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil, stmt.Schema.Table
 | 
						return nil, nil, stmt.Schema.Table
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CreateConstraint create constraint
 | 
					// CreateConstraint create constraint
 | 
				
			||||||
func (m Migrator) CreateConstraint(value interface{}, name string) error {
 | 
					func (m Migrator) CreateConstraint(value interface{}, name string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
 | 
							constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
 | 
				
			||||||
 | 
							if chk != nil {
 | 
				
			||||||
 | 
								return m.DB.Exec(
 | 
				
			||||||
 | 
									"ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)",
 | 
				
			||||||
 | 
									m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
 | 
				
			||||||
 | 
								).Error
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if constraint != nil {
 | 
							if constraint != nil {
 | 
				
			||||||
			vars := []interface{}{clause.Table{Name: table}}
 | 
								vars := []interface{}{clause.Table{Name: table}}
 | 
				
			||||||
			if stmt.TableExpr != nil {
 | 
								if stmt.TableExpr != nil {
 | 
				
			||||||
				vars[0] = stmt.TableExpr
 | 
									vars[0] = stmt.TableExpr
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			sql, values := constraint.Build()
 | 
								sql, values := buildConstraint(constraint)
 | 
				
			||||||
			return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
 | 
								return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -768,9 +702,11 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
 | 
				
			|||||||
// DropConstraint drop constraint
 | 
					// DropConstraint drop constraint
 | 
				
			||||||
func (m Migrator) DropConstraint(value interface{}, name string) error {
 | 
					func (m Migrator) DropConstraint(value interface{}, name string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
 | 
							constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
 | 
				
			||||||
		if constraint != nil {
 | 
							if constraint != nil {
 | 
				
			||||||
			name = constraint.GetName()
 | 
								name = constraint.Name
 | 
				
			||||||
 | 
							} else if chk != nil {
 | 
				
			||||||
 | 
								name = chk.Name
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
 | 
							return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
@ -781,9 +717,11 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
 | 
				
			|||||||
	var count int64
 | 
						var count int64
 | 
				
			||||||
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		currentDatabase := m.DB.Migrator().CurrentDatabase()
 | 
							currentDatabase := m.DB.Migrator().CurrentDatabase()
 | 
				
			||||||
		constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
 | 
							constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
 | 
				
			||||||
		if constraint != nil {
 | 
							if constraint != nil {
 | 
				
			||||||
			name = constraint.GetName()
 | 
								name = constraint.Name
 | 
				
			||||||
 | 
							} else if chk != nil {
 | 
				
			||||||
 | 
								name = chk.Name
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return m.DB.Raw(
 | 
							return m.DB.Raw(
 | 
				
			||||||
@ -825,9 +763,6 @@ type BuildIndexOptionsInterface interface {
 | 
				
			|||||||
// CreateIndex create index `name`
 | 
					// CreateIndex create index `name`
 | 
				
			||||||
func (m Migrator) CreateIndex(value interface{}, name string) error {
 | 
					func (m Migrator) CreateIndex(value interface{}, name string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		if stmt.Schema == nil {
 | 
					 | 
				
			||||||
			return errors.New("failed to get schema")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if idx := stmt.Schema.LookIndex(name); idx != nil {
 | 
							if idx := stmt.Schema.LookIndex(name); idx != nil {
 | 
				
			||||||
			opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
 | 
								opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
 | 
				
			||||||
			values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
 | 
								values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
 | 
				
			||||||
@ -860,10 +795,8 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
 | 
				
			|||||||
// DropIndex drop index `name`
 | 
					// DropIndex drop index `name`
 | 
				
			||||||
func (m Migrator) DropIndex(value interface{}, name string) error {
 | 
					func (m Migrator) DropIndex(value interface{}, name string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		if stmt.Schema != nil {
 | 
							if idx := stmt.Schema.LookIndex(name); idx != nil {
 | 
				
			||||||
			if idx := stmt.Schema.LookIndex(name); idx != nil {
 | 
								name = idx.Name
 | 
				
			||||||
				name = idx.Name
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
 | 
							return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
 | 
				
			||||||
@ -875,10 +808,8 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
 | 
				
			|||||||
	var count int64
 | 
						var count int64
 | 
				
			||||||
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		currentDatabase := m.DB.Migrator().CurrentDatabase()
 | 
							currentDatabase := m.DB.Migrator().CurrentDatabase()
 | 
				
			||||||
		if stmt.Schema != nil {
 | 
							if idx := stmt.Schema.LookIndex(name); idx != nil {
 | 
				
			||||||
			if idx := stmt.Schema.LookIndex(name); idx != nil {
 | 
								name = idx.Name
 | 
				
			||||||
				name = idx.Name
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return m.DB.Raw(
 | 
							return m.DB.Raw(
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										160
									
								
								prepare_stmt.go
									
									
									
									
									
								
							
							
						
						
									
										160
									
								
								prepare_stmt.go
									
									
									
									
									
								
							@ -3,39 +3,33 @@ package gorm
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"database/sql"
 | 
						"database/sql"
 | 
				
			||||||
	"database/sql/driver"
 | 
					 | 
				
			||||||
	"errors"
 | 
					 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"gorm.io/gorm/internal/stmt_store"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Stmt struct {
 | 
				
			||||||
 | 
						*sql.Stmt
 | 
				
			||||||
 | 
						Transaction bool
 | 
				
			||||||
 | 
						prepared    chan struct{}
 | 
				
			||||||
 | 
						prepareErr  error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type PreparedStmtDB struct {
 | 
					type PreparedStmtDB struct {
 | 
				
			||||||
	Stmts stmt_store.Store
 | 
						Stmts       map[string]*Stmt
 | 
				
			||||||
	Mux   *sync.RWMutex
 | 
						PreparedSQL []string
 | 
				
			||||||
 | 
						Mux         *sync.RWMutex
 | 
				
			||||||
	ConnPool
 | 
						ConnPool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewPreparedStmtDB creates and initializes a new instance of PreparedStmtDB.
 | 
					func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// Parameters:
 | 
					 | 
				
			||||||
// - connPool: A connection pool that implements the ConnPool interface, used for managing database connections.
 | 
					 | 
				
			||||||
// - maxSize: The maximum number of prepared statements that can be stored in the statement store.
 | 
					 | 
				
			||||||
// - ttl: The time-to-live duration for each prepared statement in the store. Statements older than this duration will be automatically removed.
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// Returns:
 | 
					 | 
				
			||||||
// - A pointer to a PreparedStmtDB instance, which manages prepared statements using the provided connection pool and configuration.
 | 
					 | 
				
			||||||
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
 | 
					 | 
				
			||||||
	return &PreparedStmtDB{
 | 
						return &PreparedStmtDB{
 | 
				
			||||||
		ConnPool: connPool,                     // Assigns the provided connection pool to manage database connections.
 | 
							ConnPool:    connPool,
 | 
				
			||||||
		Stmts:    stmt_store.New(maxSize, ttl), // Initializes a new statement store with the specified maximum size and TTL.
 | 
							Stmts:       make(map[string]*Stmt),
 | 
				
			||||||
		Mux:      &sync.RWMutex{},              // Sets up a read-write mutex for synchronizing access to the statement store.
 | 
							Mux:         &sync.RWMutex{},
 | 
				
			||||||
 | 
							PreparedSQL: make([]string, 0, 100),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// GetDBConn returns the underlying *sql.DB connection
 | 
					 | 
				
			||||||
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
 | 
					func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
 | 
				
			||||||
	if sqldb, ok := db.ConnPool.(*sql.DB); ok {
 | 
						if sqldb, ok := db.ConnPool.(*sql.DB); ok {
 | 
				
			||||||
		return sqldb, nil
 | 
							return sqldb, nil
 | 
				
			||||||
@ -48,41 +42,84 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
 | 
				
			|||||||
	return nil, ErrInvalidDB
 | 
						return nil, ErrInvalidDB
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Close closes all prepared statements in the store
 | 
					 | 
				
			||||||
func (db *PreparedStmtDB) Close() {
 | 
					func (db *PreparedStmtDB) Close() {
 | 
				
			||||||
	db.Mux.Lock()
 | 
						db.Mux.Lock()
 | 
				
			||||||
	defer db.Mux.Unlock()
 | 
						defer db.Mux.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, key := range db.Stmts.Keys() {
 | 
						for _, query := range db.PreparedSQL {
 | 
				
			||||||
		db.Stmts.Delete(key)
 | 
							if stmt, ok := db.Stmts[query]; ok {
 | 
				
			||||||
 | 
								delete(db.Stmts, query)
 | 
				
			||||||
 | 
								go stmt.Close()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Reset Deprecated use Close instead
 | 
					func (sdb *PreparedStmtDB) Reset() {
 | 
				
			||||||
func (db *PreparedStmtDB) Reset() {
 | 
						sdb.Mux.Lock()
 | 
				
			||||||
	db.Close()
 | 
						defer sdb.Mux.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, stmt := range sdb.Stmts {
 | 
				
			||||||
 | 
							go stmt.Close()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						sdb.PreparedSQL = make([]string, 0, 100)
 | 
				
			||||||
 | 
						sdb.Stmts = make(map[string]*Stmt)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) {
 | 
					func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
 | 
				
			||||||
	db.Mux.RLock()
 | 
						db.Mux.RLock()
 | 
				
			||||||
	if db.Stmts != nil {
 | 
						if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
 | 
				
			||||||
		if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
 | 
							db.Mux.RUnlock()
 | 
				
			||||||
			db.Mux.RUnlock()
 | 
							// wait for other goroutines prepared
 | 
				
			||||||
			return stmt, stmt.Error()
 | 
							<-stmt.prepared
 | 
				
			||||||
 | 
							if stmt.prepareErr != nil {
 | 
				
			||||||
 | 
								return Stmt{}, stmt.prepareErr
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return *stmt, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	db.Mux.RUnlock()
 | 
						db.Mux.RUnlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// retry
 | 
					 | 
				
			||||||
	db.Mux.Lock()
 | 
						db.Mux.Lock()
 | 
				
			||||||
	if db.Stmts != nil {
 | 
						// double check
 | 
				
			||||||
		if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
 | 
						if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
 | 
				
			||||||
			db.Mux.Unlock()
 | 
							db.Mux.Unlock()
 | 
				
			||||||
			return stmt, stmt.Error()
 | 
							// wait for other goroutines prepared
 | 
				
			||||||
 | 
							<-stmt.prepared
 | 
				
			||||||
 | 
							if stmt.prepareErr != nil {
 | 
				
			||||||
 | 
								return Stmt{}, stmt.prepareErr
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return *stmt, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux)
 | 
						// cache preparing stmt first
 | 
				
			||||||
 | 
						cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
 | 
				
			||||||
 | 
						db.Stmts[query] = &cacheStmt
 | 
				
			||||||
 | 
						db.Mux.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// prepare completed
 | 
				
			||||||
 | 
						defer close(cacheStmt.prepared)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Reason why cannot lock conn.PrepareContext
 | 
				
			||||||
 | 
						// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
 | 
				
			||||||
 | 
						// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
 | 
				
			||||||
 | 
						// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
 | 
				
			||||||
 | 
						// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
 | 
				
			||||||
 | 
						stmt, err := conn.PrepareContext(ctx, query)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							cacheStmt.prepareErr = err
 | 
				
			||||||
 | 
							db.Mux.Lock()
 | 
				
			||||||
 | 
							delete(db.Stmts, query)
 | 
				
			||||||
 | 
							db.Mux.Unlock()
 | 
				
			||||||
 | 
							return Stmt{}, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						db.Mux.Lock()
 | 
				
			||||||
 | 
						cacheStmt.Stmt = stmt
 | 
				
			||||||
 | 
						db.PreparedSQL = append(db.PreparedSQL, query)
 | 
				
			||||||
 | 
						db.Mux.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return cacheStmt, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
 | 
					func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
 | 
				
			||||||
@ -110,8 +147,11 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
 | 
				
			|||||||
	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
 | 
						stmt, err := db.prepare(ctx, db.ConnPool, false, query)
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
		result, err = stmt.ExecContext(ctx, args...)
 | 
							result, err = stmt.ExecContext(ctx, args...)
 | 
				
			||||||
		if errors.Is(err, driver.ErrBadConn) {
 | 
							if err != nil {
 | 
				
			||||||
			db.Stmts.Delete(query)
 | 
								db.Mux.Lock()
 | 
				
			||||||
 | 
								defer db.Mux.Unlock()
 | 
				
			||||||
 | 
								go stmt.Close()
 | 
				
			||||||
 | 
								delete(db.Stmts, query)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return result, err
 | 
						return result, err
 | 
				
			||||||
@ -121,8 +161,12 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
 | 
				
			|||||||
	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
 | 
						stmt, err := db.prepare(ctx, db.ConnPool, false, query)
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
		rows, err = stmt.QueryContext(ctx, args...)
 | 
							rows, err = stmt.QueryContext(ctx, args...)
 | 
				
			||||||
		if errors.Is(err, driver.ErrBadConn) {
 | 
							if err != nil {
 | 
				
			||||||
			db.Stmts.Delete(query)
 | 
								db.Mux.Lock()
 | 
				
			||||||
 | 
								defer db.Mux.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								go stmt.Close()
 | 
				
			||||||
 | 
								delete(db.Stmts, query)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return rows, err
 | 
						return rows, err
 | 
				
			||||||
@ -136,14 +180,6 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
 | 
				
			|||||||
	return &sql.Row{}
 | 
						return &sql.Row{}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (db *PreparedStmtDB) Ping() error {
 | 
					 | 
				
			||||||
	conn, err := db.GetDBConn()
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return conn.Ping()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type PreparedStmtTX struct {
 | 
					type PreparedStmtTX struct {
 | 
				
			||||||
	Tx
 | 
						Tx
 | 
				
			||||||
	PreparedStmtDB *PreparedStmtDB
 | 
						PreparedStmtDB *PreparedStmtDB
 | 
				
			||||||
@ -171,8 +207,12 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
 | 
				
			|||||||
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
 | 
						stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
		result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
 | 
							result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
 | 
				
			||||||
		if errors.Is(err, driver.ErrBadConn) {
 | 
							if err != nil {
 | 
				
			||||||
			tx.PreparedStmtDB.Stmts.Delete(query)
 | 
								tx.PreparedStmtDB.Mux.Lock()
 | 
				
			||||||
 | 
								defer tx.PreparedStmtDB.Mux.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								go stmt.Close()
 | 
				
			||||||
 | 
								delete(tx.PreparedStmtDB.Stmts, query)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return result, err
 | 
						return result, err
 | 
				
			||||||
@ -182,8 +222,12 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
 | 
				
			|||||||
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
 | 
						stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
		rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
 | 
							rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
 | 
				
			||||||
		if errors.Is(err, driver.ErrBadConn) {
 | 
							if err != nil {
 | 
				
			||||||
			tx.PreparedStmtDB.Stmts.Delete(query)
 | 
								tx.PreparedStmtDB.Mux.Lock()
 | 
				
			||||||
 | 
								defer tx.PreparedStmtDB.Mux.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								go stmt.Close()
 | 
				
			||||||
 | 
								delete(tx.PreparedStmtDB.Stmts, query)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return rows, err
 | 
						return rows, err
 | 
				
			||||||
@ -196,11 +240,3 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	return &sql.Row{}
 | 
						return &sql.Row{}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func (tx *PreparedStmtTX) Ping() error {
 | 
					 | 
				
			||||||
	conn, err := tx.GetDBConn()
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return conn.Ping()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										49
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										49
									
								
								scan.go
									
									
									
									
									
								
							@ -4,7 +4,6 @@ import (
 | 
				
			|||||||
	"database/sql"
 | 
						"database/sql"
 | 
				
			||||||
	"database/sql/driver"
 | 
						"database/sql/driver"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gorm.io/gorm/schema"
 | 
						"gorm.io/gorm/schema"
 | 
				
			||||||
@ -16,7 +15,7 @@ func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType,
 | 
				
			|||||||
	if db.Statement.Schema != nil {
 | 
						if db.Statement.Schema != nil {
 | 
				
			||||||
		for idx, name := range columns {
 | 
							for idx, name := range columns {
 | 
				
			||||||
			if field := db.Statement.Schema.LookUpField(name); field != nil {
 | 
								if field := db.Statement.Schema.LookUpField(name); field != nil {
 | 
				
			||||||
				values[idx] = reflect.New(reflect.PointerTo(field.FieldType)).Interface()
 | 
									values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			values[idx] = new(interface{})
 | 
								values[idx] = new(interface{})
 | 
				
			||||||
@ -24,7 +23,7 @@ func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType,
 | 
				
			|||||||
	} else if len(columnTypes) > 0 {
 | 
						} else if len(columnTypes) > 0 {
 | 
				
			||||||
		for idx, columnType := range columnTypes {
 | 
							for idx, columnType := range columnTypes {
 | 
				
			||||||
			if columnType.ScanType() != nil {
 | 
								if columnType.ScanType() != nil {
 | 
				
			||||||
				values[idx] = reflect.New(reflect.PointerTo(columnType.ScanType())).Interface()
 | 
									values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				values[idx] = new(interface{})
 | 
									values[idx] = new(interface{})
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@ -132,15 +131,6 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
 | 
				
			|||||||
		onConflictDonothing = mode&ScanOnConflictDoNothing != 0
 | 
							onConflictDonothing = mode&ScanOnConflictDoNothing != 0
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(db.Statement.ColumnMapping) > 0 {
 | 
					 | 
				
			||||||
		for i, column := range columns {
 | 
					 | 
				
			||||||
			v, ok := db.Statement.ColumnMapping[column]
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				columns[i] = v
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	db.RowsAffected = 0
 | 
						db.RowsAffected = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	switch dest := db.Statement.Dest.(type) {
 | 
						switch dest := db.Statement.Dest.(type) {
 | 
				
			||||||
@ -245,14 +235,6 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
 | 
				
			|||||||
							matchedFieldCount[column] = 1
 | 
												matchedFieldCount[column] = 1
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
 | 
										} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
 | 
				
			||||||
						aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1])
 | 
					 | 
				
			||||||
						for _, join := range db.Statement.Joins {
 | 
					 | 
				
			||||||
							if join.Alias == aliasName {
 | 
					 | 
				
			||||||
								names = append(strings.Split(join.Name, "."), names[len(names)-1])
 | 
					 | 
				
			||||||
								break
 | 
					 | 
				
			||||||
							}
 | 
					 | 
				
			||||||
						}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
						if rel, ok := sch.Relationships.Relations[names[0]]; ok {
 | 
											if rel, ok := sch.Relationships.Relations[names[0]]; ok {
 | 
				
			||||||
							subNameCount := len(names)
 | 
												subNameCount := len(names)
 | 
				
			||||||
							// nested relation fields
 | 
												// nested relation fields
 | 
				
			||||||
@ -262,7 +244,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
 | 
				
			|||||||
								rel = rel.FieldSchema.Relationships.Relations[name]
 | 
													rel = rel.FieldSchema.Relationships.Relations[name]
 | 
				
			||||||
								relFields = append(relFields, rel.Field)
 | 
													relFields = append(relFields, rel.Field)
 | 
				
			||||||
							}
 | 
												}
 | 
				
			||||||
							// latest name is raw dbname
 | 
												// lastest name is raw dbname
 | 
				
			||||||
							dbName := names[subNameCount-1]
 | 
												dbName := names[subNameCount-1]
 | 
				
			||||||
							if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
 | 
												if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
 | 
				
			||||||
								fields[idx] = field
 | 
													fields[idx] = field
 | 
				
			||||||
@ -275,11 +257,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
 | 
				
			|||||||
								continue
 | 
													continue
 | 
				
			||||||
							}
 | 
												}
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
						var val interface{}
 | 
											values[idx] = &sql.RawBytes{}
 | 
				
			||||||
						values[idx] = &val
 | 
					 | 
				
			||||||
					} else {
 | 
										} else {
 | 
				
			||||||
						var val interface{}
 | 
											values[idx] = &sql.RawBytes{}
 | 
				
			||||||
						values[idx] = &val
 | 
					 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@ -294,16 +274,12 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			if !update || reflectValue.Len() == 0 {
 | 
								if !update || reflectValue.Len() == 0 {
 | 
				
			||||||
				update = false
 | 
									update = false
 | 
				
			||||||
				if isArrayKind {
 | 
									// if the slice cap is externally initialized, the externally initialized slice is directly used here
 | 
				
			||||||
					db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
 | 
									if reflectValue.Cap() == 0 {
 | 
				
			||||||
				} else {
 | 
										db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
 | 
				
			||||||
					// if the slice cap is externally initialized, the externally initialized slice is directly used here
 | 
									} else if !isArrayKind {
 | 
				
			||||||
					if reflectValue.Cap() == 0 {
 | 
										reflectValue.SetLen(0)
 | 
				
			||||||
						db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
 | 
										db.Statement.ReflectValue.Set(reflectValue)
 | 
				
			||||||
					} else {
 | 
					 | 
				
			||||||
						reflectValue.SetLen(0)
 | 
					 | 
				
			||||||
						db.Statement.ReflectValue.Set(reflectValue)
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -349,9 +325,6 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		case reflect.Struct, reflect.Ptr:
 | 
							case reflect.Struct, reflect.Ptr:
 | 
				
			||||||
			if initialized || rows.Next() {
 | 
								if initialized || rows.Next() {
 | 
				
			||||||
				if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct {
 | 
					 | 
				
			||||||
					db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
 | 
									db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										35
									
								
								schema/check.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								schema/check.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,35 @@
 | 
				
			|||||||
 | 
					package schema
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"regexp"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// reg match english letters and midline
 | 
				
			||||||
 | 
					var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Check struct {
 | 
				
			||||||
 | 
						Name       string
 | 
				
			||||||
 | 
						Constraint string // length(phone) >= 10
 | 
				
			||||||
 | 
						*Field
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ParseCheckConstraints parse schema check constraints
 | 
				
			||||||
 | 
					func (schema *Schema) ParseCheckConstraints() map[string]Check {
 | 
				
			||||||
 | 
						checks := map[string]Check{}
 | 
				
			||||||
 | 
						for _, field := range schema.FieldsByDBName {
 | 
				
			||||||
 | 
							if chk := field.TagSettings["CHECK"]; chk != "" {
 | 
				
			||||||
 | 
								names := strings.Split(chk, ",")
 | 
				
			||||||
 | 
								if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
 | 
				
			||||||
 | 
									checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									if names[0] == "" {
 | 
				
			||||||
 | 
										chk = strings.Join(names[1:], ",")
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									name := schema.namer.CheckerName(schema.Table, field.DBName)
 | 
				
			||||||
 | 
									checks[name] = Check{Name: name, Constraint: chk, Field: field}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return checks
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -6,7 +6,6 @@ import (
 | 
				
			|||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gorm.io/gorm/schema"
 | 
						"gorm.io/gorm/schema"
 | 
				
			||||||
	"gorm.io/gorm/utils/tests"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type UserCheck struct {
 | 
					type UserCheck struct {
 | 
				
			||||||
@ -21,7 +20,7 @@ func TestParseCheck(t *testing.T) {
 | 
				
			|||||||
		t.Fatalf("failed to parse user check, got error %v", err)
 | 
							t.Fatalf("failed to parse user check, got error %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	results := map[string]schema.CheckConstraint{
 | 
						results := map[string]schema.Check{
 | 
				
			||||||
		"name_checker": {
 | 
							"name_checker": {
 | 
				
			||||||
			Name:       "name_checker",
 | 
								Name:       "name_checker",
 | 
				
			||||||
			Constraint: "name <> 'jinzhu'",
 | 
								Constraint: "name <> 'jinzhu'",
 | 
				
			||||||
@ -54,31 +53,3 @@ func TestParseCheck(t *testing.T) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestParseUniqueConstraints(t *testing.T) {
 | 
					 | 
				
			||||||
	type UserUnique struct {
 | 
					 | 
				
			||||||
		Name1 string `gorm:"unique"`
 | 
					 | 
				
			||||||
		Name2 string `gorm:"uniqueIndex"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	user, err := schema.Parse(&UserUnique{}, &sync.Map{}, schema.NamingStrategy{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to parse user unique, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	constraints := user.ParseUniqueConstraints()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	results := map[string]schema.UniqueConstraint{
 | 
					 | 
				
			||||||
		"uni_user_uniques_name1": {
 | 
					 | 
				
			||||||
			Name:  "uni_user_uniques_name1",
 | 
					 | 
				
			||||||
			Field: &schema.Field{Name: "Name1", Unique: true},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	for k, result := range results {
 | 
					 | 
				
			||||||
		v, ok := constraints[k]
 | 
					 | 
				
			||||||
		if !ok {
 | 
					 | 
				
			||||||
			t.Errorf("Failed to found unique constraint %v from parsed constraints %+v", k, constraints)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		tests.AssertObjEqual(t, result, v, "Name")
 | 
					 | 
				
			||||||
		tests.AssertObjEqual(t, result.Field, v.Field, "Name", "Unique", "UniqueIndex")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@ -1,66 +0,0 @@
 | 
				
			|||||||
package schema
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"regexp"
 | 
					 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"gorm.io/gorm/clause"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// reg match english letters and midline
 | 
					 | 
				
			||||||
var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type CheckConstraint struct {
 | 
					 | 
				
			||||||
	Name       string
 | 
					 | 
				
			||||||
	Constraint string // length(phone) >= 10
 | 
					 | 
				
			||||||
	*Field
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (chk *CheckConstraint) GetName() string { return chk.Name }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
 | 
					 | 
				
			||||||
	return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// ParseCheckConstraints parse schema check constraints
 | 
					 | 
				
			||||||
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
 | 
					 | 
				
			||||||
	checks := map[string]CheckConstraint{}
 | 
					 | 
				
			||||||
	for _, field := range schema.FieldsByDBName {
 | 
					 | 
				
			||||||
		if chk := field.TagSettings["CHECK"]; chk != "" {
 | 
					 | 
				
			||||||
			names := strings.Split(chk, ",")
 | 
					 | 
				
			||||||
			if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
 | 
					 | 
				
			||||||
				checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				if names[0] == "" {
 | 
					 | 
				
			||||||
					chk = strings.Join(names[1:], ",")
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				name := schema.namer.CheckerName(schema.Table, field.DBName)
 | 
					 | 
				
			||||||
				checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return checks
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type UniqueConstraint struct {
 | 
					 | 
				
			||||||
	Name  string
 | 
					 | 
				
			||||||
	Field *Field
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (uni *UniqueConstraint) GetName() string { return uni.Name }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
 | 
					 | 
				
			||||||
	return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// ParseUniqueConstraints parse schema unique constraints
 | 
					 | 
				
			||||||
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
 | 
					 | 
				
			||||||
	uniques := make(map[string]UniqueConstraint)
 | 
					 | 
				
			||||||
	for _, field := range schema.Fields {
 | 
					 | 
				
			||||||
		if field.Unique {
 | 
					 | 
				
			||||||
			name := schema.namer.UniqueName(schema.Table, field.DBName)
 | 
					 | 
				
			||||||
			uniques[name] = UniqueConstraint{Name: name, Field: field}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return uniques
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@ -49,14 +49,11 @@ const (
 | 
				
			|||||||
	Bytes  DataType = "bytes"
 | 
						Bytes  DataType = "bytes"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const DefaultAutoIncrementIncrement int64 = 1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Field is the representation of model schema's field
 | 
					// Field is the representation of model schema's field
 | 
				
			||||||
type Field struct {
 | 
					type Field struct {
 | 
				
			||||||
	Name                   string
 | 
						Name                   string
 | 
				
			||||||
	DBName                 string
 | 
						DBName                 string
 | 
				
			||||||
	BindNames              []string
 | 
						BindNames              []string
 | 
				
			||||||
	EmbeddedBindNames      []string
 | 
					 | 
				
			||||||
	DataType               DataType
 | 
						DataType               DataType
 | 
				
			||||||
	GORMDataType           DataType
 | 
						GORMDataType           DataType
 | 
				
			||||||
	PrimaryKey             bool
 | 
						PrimaryKey             bool
 | 
				
			||||||
@ -90,12 +87,6 @@ type Field struct {
 | 
				
			|||||||
	Set                    func(context.Context, reflect.Value, interface{}) error
 | 
						Set                    func(context.Context, reflect.Value, interface{}) error
 | 
				
			||||||
	Serializer             SerializerInterface
 | 
						Serializer             SerializerInterface
 | 
				
			||||||
	NewValuePool           FieldNewValuePool
 | 
						NewValuePool           FieldNewValuePool
 | 
				
			||||||
 | 
					 | 
				
			||||||
	// In some db (e.g. MySQL), Unique and UniqueIndex are indistinguishable.
 | 
					 | 
				
			||||||
	// When a column has a (not Mul) UniqueIndex, Migrator always reports its gorm.ColumnType is Unique.
 | 
					 | 
				
			||||||
	// It causes field unnecessarily migration.
 | 
					 | 
				
			||||||
	// Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique.
 | 
					 | 
				
			||||||
	UniqueIndex string
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (field *Field) BindName() string {
 | 
					func (field *Field) BindName() string {
 | 
				
			||||||
@ -113,7 +104,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 | 
				
			|||||||
		Name:                   fieldStruct.Name,
 | 
							Name:                   fieldStruct.Name,
 | 
				
			||||||
		DBName:                 tagSetting["COLUMN"],
 | 
							DBName:                 tagSetting["COLUMN"],
 | 
				
			||||||
		BindNames:              []string{fieldStruct.Name},
 | 
							BindNames:              []string{fieldStruct.Name},
 | 
				
			||||||
		EmbeddedBindNames:      []string{fieldStruct.Name},
 | 
					 | 
				
			||||||
		FieldType:              fieldStruct.Type,
 | 
							FieldType:              fieldStruct.Type,
 | 
				
			||||||
		IndirectFieldType:      fieldStruct.Type,
 | 
							IndirectFieldType:      fieldStruct.Type,
 | 
				
			||||||
		StructField:            fieldStruct,
 | 
							StructField:            fieldStruct,
 | 
				
			||||||
@ -129,7 +119,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 | 
				
			|||||||
		NotNull:                utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
 | 
							NotNull:                utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
 | 
				
			||||||
		Unique:                 utils.CheckTruth(tagSetting["UNIQUE"]),
 | 
							Unique:                 utils.CheckTruth(tagSetting["UNIQUE"]),
 | 
				
			||||||
		Comment:                tagSetting["COMMENT"],
 | 
							Comment:                tagSetting["COMMENT"],
 | 
				
			||||||
		AutoIncrementIncrement: DefaultAutoIncrementIncrement,
 | 
							AutoIncrementIncrement: 1,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for field.IndirectFieldType.Kind() == reflect.Ptr {
 | 
						for field.IndirectFieldType.Kind() == reflect.Ptr {
 | 
				
			||||||
@ -318,10 +308,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if val, ok := field.TagSettings["TYPE"]; ok {
 | 
						if val, ok := field.TagSettings["TYPE"]; ok {
 | 
				
			||||||
		lowerVal := DataType(strings.ToLower(val))
 | 
							switch DataType(strings.ToLower(val)) {
 | 
				
			||||||
		switch lowerVal {
 | 
					 | 
				
			||||||
		case Bool, Int, Uint, Float, String, Time, Bytes:
 | 
							case Bool, Int, Uint, Float, String, Time, Bytes:
 | 
				
			||||||
			field.DataType = lowerVal
 | 
								field.DataType = DataType(strings.ToLower(val))
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
			field.DataType = DataType(val)
 | 
								field.DataType = DataType(val)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -406,9 +395,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 | 
				
			|||||||
				ef.Schema = schema
 | 
									ef.Schema = schema
 | 
				
			||||||
				ef.OwnerSchema = field.EmbeddedSchema
 | 
									ef.OwnerSchema = field.EmbeddedSchema
 | 
				
			||||||
				ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
 | 
									ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
 | 
				
			||||||
				if _, ok := field.TagSettings["EMBEDDED"]; ok || !fieldStruct.Anonymous {
 | 
					 | 
				
			||||||
					ef.EmbeddedBindNames = append([]string{fieldStruct.Name}, ef.EmbeddedBindNames...)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				// index is negative means is pointer
 | 
									// index is negative means is pointer
 | 
				
			||||||
				if field.FieldType.Kind() == reflect.Struct {
 | 
									if field.FieldType.Kind() == reflect.Struct {
 | 
				
			||||||
					ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
 | 
										ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
 | 
				
			||||||
@ -448,30 +434,21 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// create valuer, setter when parse struct
 | 
					// create valuer, setter when parse struct
 | 
				
			||||||
func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
 | 
					func (field *Field) setupValuerAndSetter() {
 | 
				
			||||||
	// Setup NewValuePool
 | 
						// Setup NewValuePool
 | 
				
			||||||
	field.setupNewValuePool()
 | 
						field.setupNewValuePool()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// ValueOf returns field's value and if it is zero
 | 
						// ValueOf returns field's value and if it is zero
 | 
				
			||||||
	fieldIndex := field.StructField.Index[0]
 | 
						fieldIndex := field.StructField.Index[0]
 | 
				
			||||||
	switch {
 | 
						switch {
 | 
				
			||||||
	case len(field.StructField.Index) == 1 && fieldIndex >= 0:
 | 
						case len(field.StructField.Index) == 1 && fieldIndex > 0:
 | 
				
			||||||
		field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
 | 
							field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
 | 
				
			||||||
			v = reflect.Indirect(v)
 | 
								fieldValue := reflect.Indirect(value).Field(fieldIndex)
 | 
				
			||||||
			if v.Type() != modelType {
 | 
					 | 
				
			||||||
				fieldValue := v.FieldByName(field.Name)
 | 
					 | 
				
			||||||
				return fieldValue.Interface(), fieldValue.IsZero()
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			fieldValue := v.Field(fieldIndex)
 | 
					 | 
				
			||||||
			return fieldValue.Interface(), fieldValue.IsZero()
 | 
								return fieldValue.Interface(), fieldValue.IsZero()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
 | 
							field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
 | 
				
			||||||
			v = reflect.Indirect(v)
 | 
								v = reflect.Indirect(v)
 | 
				
			||||||
			if v.Type() != modelType {
 | 
					 | 
				
			||||||
				fieldValue := v.FieldByName(field.Name)
 | 
					 | 
				
			||||||
				return fieldValue.Interface(), fieldValue.IsZero()
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			for _, fieldIdx := range field.StructField.Index {
 | 
								for _, fieldIdx := range field.StructField.Index {
 | 
				
			||||||
				if fieldIdx >= 0 {
 | 
									if fieldIdx >= 0 {
 | 
				
			||||||
					v = v.Field(fieldIdx)
 | 
										v = v.Field(fieldIdx)
 | 
				
			||||||
@ -513,20 +490,13 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// ReflectValueOf returns field's reflect value
 | 
						// ReflectValueOf returns field's reflect value
 | 
				
			||||||
	switch {
 | 
						switch {
 | 
				
			||||||
	case len(field.StructField.Index) == 1 && fieldIndex >= 0:
 | 
						case len(field.StructField.Index) == 1 && fieldIndex > 0:
 | 
				
			||||||
		field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
 | 
							field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
 | 
				
			||||||
			v = reflect.Indirect(v)
 | 
								return reflect.Indirect(value).Field(fieldIndex)
 | 
				
			||||||
			if v.Type() != modelType {
 | 
					 | 
				
			||||||
				return v.FieldByName(field.Name)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			return v.Field(fieldIndex)
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
 | 
							field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
 | 
				
			||||||
			v = reflect.Indirect(v)
 | 
								v = reflect.Indirect(v)
 | 
				
			||||||
			if v.Type() != modelType {
 | 
					 | 
				
			||||||
				return v.FieldByName(field.Name)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			for idx, fieldIdx := range field.StructField.Index {
 | 
								for idx, fieldIdx := range field.StructField.Index {
 | 
				
			||||||
				if fieldIdx >= 0 {
 | 
									if fieldIdx >= 0 {
 | 
				
			||||||
					v = v.Field(fieldIdx)
 | 
										v = v.Field(fieldIdx)
 | 
				
			||||||
@ -686,7 +656,7 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
 | 
				
			|||||||
				if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
 | 
									if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
 | 
				
			||||||
					field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
 | 
										field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
 | 
				
			||||||
				} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
 | 
									} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
 | 
				
			||||||
					field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
 | 
										field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
					field.ReflectValueOf(ctx, value).SetInt(data.Unix())
 | 
										field.ReflectValueOf(ctx, value).SetInt(data.Unix())
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
@ -695,7 +665,7 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
 | 
				
			|||||||
					if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
 | 
										if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
 | 
				
			||||||
						field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
 | 
											field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
 | 
				
			||||||
					} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
 | 
										} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
 | 
				
			||||||
						field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
 | 
											field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
 | 
				
			||||||
					} else {
 | 
										} else {
 | 
				
			||||||
						field.ReflectValueOf(ctx, value).SetInt(data.Unix())
 | 
											field.ReflectValueOf(ctx, value).SetInt(data.Unix())
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
@ -760,7 +730,7 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
 | 
				
			|||||||
				if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
 | 
									if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
 | 
				
			||||||
					field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
 | 
										field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
 | 
				
			||||||
				} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
 | 
									} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
 | 
				
			||||||
					field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixMilli()))
 | 
										field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6))
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
					field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
 | 
										field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
@ -1013,6 +983,6 @@ func (field *Field) setupNewValuePool() {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if field.NewValuePool == nil {
 | 
						if field.NewValuePool == nil {
 | 
				
			||||||
		field.NewValuePool = poolInitializer(reflect.PointerTo(field.IndirectFieldType))
 | 
							field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -13,8 +13,8 @@ type Index struct {
 | 
				
			|||||||
	Type    string // btree, hash, gist, spgist, gin, and brin
 | 
						Type    string // btree, hash, gist, spgist, gin, and brin
 | 
				
			||||||
	Where   string
 | 
						Where   string
 | 
				
			||||||
	Comment string
 | 
						Comment string
 | 
				
			||||||
	Option  string        // WITH PARSER parser_name
 | 
						Option  string // WITH PARSER parser_name
 | 
				
			||||||
	Fields  []IndexOption // Note: IndexOption's Field maybe the same
 | 
						Fields  []IndexOption
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type IndexOption struct {
 | 
					type IndexOption struct {
 | 
				
			||||||
@ -23,13 +23,12 @@ type IndexOption struct {
 | 
				
			|||||||
	Sort       string // DESC, ASC
 | 
						Sort       string // DESC, ASC
 | 
				
			||||||
	Collate    string
 | 
						Collate    string
 | 
				
			||||||
	Length     int
 | 
						Length     int
 | 
				
			||||||
	Priority   int
 | 
						priority   int
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ParseIndexes parse schema indexes
 | 
					// ParseIndexes parse schema indexes
 | 
				
			||||||
func (schema *Schema) ParseIndexes() []*Index {
 | 
					func (schema *Schema) ParseIndexes() map[string]Index {
 | 
				
			||||||
	indexesByName := map[string]*Index{}
 | 
						indexes := map[string]Index{}
 | 
				
			||||||
	indexes := []*Index{}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, field := range schema.Fields {
 | 
						for _, field := range schema.Fields {
 | 
				
			||||||
		if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
 | 
							if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
 | 
				
			||||||
@ -39,12 +38,7 @@ func (schema *Schema) ParseIndexes() []*Index {
 | 
				
			|||||||
				break
 | 
									break
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			for _, index := range fieldIndexes {
 | 
								for _, index := range fieldIndexes {
 | 
				
			||||||
				idx := indexesByName[index.Name]
 | 
									idx := indexes[index.Name]
 | 
				
			||||||
				if idx == nil {
 | 
					 | 
				
			||||||
					idx = &Index{Name: index.Name}
 | 
					 | 
				
			||||||
					indexesByName[index.Name] = idx
 | 
					 | 
				
			||||||
					indexes = append(indexes, idx)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				idx.Name = index.Name
 | 
									idx.Name = index.Name
 | 
				
			||||||
				if idx.Class == "" {
 | 
									if idx.Class == "" {
 | 
				
			||||||
					idx.Class = index.Class
 | 
										idx.Class = index.Class
 | 
				
			||||||
@ -64,14 +58,16 @@ func (schema *Schema) ParseIndexes() []*Index {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
				idx.Fields = append(idx.Fields, index.Fields...)
 | 
									idx.Fields = append(idx.Fields, index.Fields...)
 | 
				
			||||||
				sort.Slice(idx.Fields, func(i, j int) bool {
 | 
									sort.Slice(idx.Fields, func(i, j int) bool {
 | 
				
			||||||
					return idx.Fields[i].Priority < idx.Fields[j].Priority
 | 
										return idx.Fields[i].priority < idx.Fields[j].priority
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									indexes[index.Name] = idx
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	for _, index := range indexes {
 | 
						for _, index := range indexes {
 | 
				
			||||||
		if index.Class == "UNIQUE" && len(index.Fields) == 1 {
 | 
							if index.Class == "UNIQUE" && len(index.Fields) == 1 {
 | 
				
			||||||
			index.Fields[0].Field.UniqueIndex = index.Name
 | 
								index.Fields[0].Field.Unique = true
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return indexes
 | 
						return indexes
 | 
				
			||||||
@ -82,12 +78,12 @@ func (schema *Schema) LookIndex(name string) *Index {
 | 
				
			|||||||
		indexes := schema.ParseIndexes()
 | 
							indexes := schema.ParseIndexes()
 | 
				
			||||||
		for _, index := range indexes {
 | 
							for _, index := range indexes {
 | 
				
			||||||
			if index.Name == name {
 | 
								if index.Name == name {
 | 
				
			||||||
				return index
 | 
									return &index
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			for _, field := range index.Fields {
 | 
								for _, field := range index.Fields {
 | 
				
			||||||
				if field.Name == name {
 | 
									if field.Name == name {
 | 
				
			||||||
					return index
 | 
										return &index
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -105,7 +101,7 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) {
 | 
				
			|||||||
				var (
 | 
									var (
 | 
				
			||||||
					name       string
 | 
										name       string
 | 
				
			||||||
					tag        = strings.Join(v[1:], ":")
 | 
										tag        = strings.Join(v[1:], ":")
 | 
				
			||||||
					idx        = strings.IndexByte(tag, ',')
 | 
										idx        = strings.Index(tag, ",")
 | 
				
			||||||
					tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",")
 | 
										tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",")
 | 
				
			||||||
					settings   = ParseTagSetting(tagSetting, ",")
 | 
										settings   = ParseTagSetting(tagSetting, ",")
 | 
				
			||||||
					length, _  = strconv.Atoi(settings["LENGTH"])
 | 
										length, _  = strconv.Atoi(settings["LENGTH"])
 | 
				
			||||||
@ -115,14 +111,17 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) {
 | 
				
			|||||||
					idx = len(tag)
 | 
										idx = len(tag)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				name = tag[0:idx]
 | 
									if idx != -1 {
 | 
				
			||||||
 | 
										name = tag[0:idx]
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				if name == "" {
 | 
									if name == "" {
 | 
				
			||||||
					subName := field.Name
 | 
										subName := field.Name
 | 
				
			||||||
					const key = "COMPOSITE"
 | 
										const key = "COMPOSITE"
 | 
				
			||||||
					if composite, found := settings[key]; found {
 | 
										if composite, found := settings[key]; found {
 | 
				
			||||||
						if len(composite) == 0 || composite == key {
 | 
											if len(composite) == 0 || composite == key {
 | 
				
			||||||
							err = fmt.Errorf(
 | 
												err = fmt.Errorf(
 | 
				
			||||||
								"the composite tag of %s.%s cannot be empty",
 | 
													"The composite tag of %s.%s cannot be empty",
 | 
				
			||||||
								field.Schema.Name,
 | 
													field.Schema.Name,
 | 
				
			||||||
								field.Name)
 | 
													field.Name)
 | 
				
			||||||
							return
 | 
												return
 | 
				
			||||||
@ -155,7 +154,7 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) {
 | 
				
			|||||||
						Sort:       settings["SORT"],
 | 
											Sort:       settings["SORT"],
 | 
				
			||||||
						Collate:    settings["COLLATE"],
 | 
											Collate:    settings["COLLATE"],
 | 
				
			||||||
						Length:     length,
 | 
											Length:     length,
 | 
				
			||||||
						Priority:   priority,
 | 
											priority:   priority,
 | 
				
			||||||
					}},
 | 
										}},
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
				
			|||||||
@ -1,11 +1,11 @@
 | 
				
			|||||||
package schema_test
 | 
					package schema_test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gorm.io/gorm/schema"
 | 
						"gorm.io/gorm/schema"
 | 
				
			||||||
	"gorm.io/gorm/utils/tests"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type UserIndex struct {
 | 
					type UserIndex struct {
 | 
				
			||||||
@ -19,10 +19,6 @@ type UserIndex struct {
 | 
				
			|||||||
	OID          int64  `gorm:"index:idx_id;index:idx_oid,unique"`
 | 
						OID          int64  `gorm:"index:idx_id;index:idx_oid,unique"`
 | 
				
			||||||
	MemberNumber string `gorm:"index:idx_id,priority:1"`
 | 
						MemberNumber string `gorm:"index:idx_id,priority:1"`
 | 
				
			||||||
	Name7        string `gorm:"index:type"`
 | 
						Name7        string `gorm:"index:type"`
 | 
				
			||||||
	Name8        string `gorm:"index:,length:10;index:,collate:utf8"`
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	CompName1 string `gorm:"index:,unique,composite:idx_compname_1,option:NULLS NOT DISTINCT;not null"`
 | 
					 | 
				
			||||||
	CompName2 string `gorm:"index:,composite:idx_compname_1"`
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Composite Index: Flattened structure.
 | 
						// Composite Index: Flattened structure.
 | 
				
			||||||
	Data0A string `gorm:"index:,composite:comp_id0"`
 | 
						Data0A string `gorm:"index:,composite:comp_id0"`
 | 
				
			||||||
@ -61,17 +57,17 @@ func TestParseIndex(t *testing.T) {
 | 
				
			|||||||
		t.Fatalf("failed to parse user index, got error %v", err)
 | 
							t.Fatalf("failed to parse user index, got error %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	results := []*schema.Index{
 | 
						results := map[string]schema.Index{
 | 
				
			||||||
		{
 | 
							"idx_user_indices_name": {
 | 
				
			||||||
			Name:   "idx_user_indices_name",
 | 
								Name:   "idx_user_indices_name",
 | 
				
			||||||
			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name"}}},
 | 
								Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name"}}},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							"idx_name": {
 | 
				
			||||||
			Name:   "idx_name",
 | 
								Name:   "idx_name",
 | 
				
			||||||
			Class:  "UNIQUE",
 | 
								Class:  "UNIQUE",
 | 
				
			||||||
			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: "idx_name"}}},
 | 
								Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", Unique: true}}},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							"idx_user_indices_name3": {
 | 
				
			||||||
			Name:  "idx_user_indices_name3",
 | 
								Name:  "idx_user_indices_name3",
 | 
				
			||||||
			Type:  "btree",
 | 
								Type:  "btree",
 | 
				
			||||||
			Where: "name3 != 'jinzhu'",
 | 
								Where: "name3 != 'jinzhu'",
 | 
				
			||||||
@ -82,19 +78,19 @@ func TestParseIndex(t *testing.T) {
 | 
				
			|||||||
				Length:  10,
 | 
									Length:  10,
 | 
				
			||||||
			}},
 | 
								}},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							"idx_user_indices_name4": {
 | 
				
			||||||
			Name:   "idx_user_indices_name4",
 | 
								Name:   "idx_user_indices_name4",
 | 
				
			||||||
			Class:  "UNIQUE",
 | 
								Class:  "UNIQUE",
 | 
				
			||||||
			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: "idx_user_indices_name4"}}},
 | 
								Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", Unique: true}}},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							"idx_user_indices_name5": {
 | 
				
			||||||
			Name:    "idx_user_indices_name5",
 | 
								Name:    "idx_user_indices_name5",
 | 
				
			||||||
			Class:   "FULLTEXT",
 | 
								Class:   "FULLTEXT",
 | 
				
			||||||
			Comment: "hello , world",
 | 
								Comment: "hello , world",
 | 
				
			||||||
			Where:   "age > 10",
 | 
								Where:   "age > 10",
 | 
				
			||||||
			Fields:  []schema.IndexOption{{Field: &schema.Field{Name: "Name5"}}},
 | 
								Fields:  []schema.IndexOption{{Field: &schema.Field{Name: "Name5"}}},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							"profile": {
 | 
				
			||||||
			Name:    "profile",
 | 
								Name:    "profile",
 | 
				
			||||||
			Comment: "hello , world",
 | 
								Comment: "hello , world",
 | 
				
			||||||
			Where:   "age > 10",
 | 
								Where:   "age > 10",
 | 
				
			||||||
@ -104,39 +100,21 @@ func TestParseIndex(t *testing.T) {
 | 
				
			|||||||
				Expression: "ABS(age)",
 | 
									Expression: "ABS(age)",
 | 
				
			||||||
			}},
 | 
								}},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							"idx_id": {
 | 
				
			||||||
			Name:   "idx_id",
 | 
								Name:   "idx_id",
 | 
				
			||||||
			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
 | 
								Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", Unique: true}}},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							"idx_oid": {
 | 
				
			||||||
			Name:   "idx_oid",
 | 
								Name:   "idx_oid",
 | 
				
			||||||
			Class:  "UNIQUE",
 | 
								Class:  "UNIQUE",
 | 
				
			||||||
			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
 | 
								Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", Unique: true}}},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							"type": {
 | 
				
			||||||
			Name:   "type",
 | 
								Name:   "type",
 | 
				
			||||||
			Type:   "",
 | 
								Type:   "",
 | 
				
			||||||
			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}},
 | 
								Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							"idx_user_indices_comp_id0": {
 | 
				
			||||||
			Name: "idx_user_indices_name8",
 | 
					 | 
				
			||||||
			Type: "",
 | 
					 | 
				
			||||||
			Fields: []schema.IndexOption{
 | 
					 | 
				
			||||||
				{Field: &schema.Field{Name: "Name8"}, Length: 10},
 | 
					 | 
				
			||||||
				// Note: Duplicate Columns
 | 
					 | 
				
			||||||
				{Field: &schema.Field{Name: "Name8"}, Collate: "utf8"},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			Class:  "UNIQUE",
 | 
					 | 
				
			||||||
			Name:   "idx_user_indices_idx_compname_1",
 | 
					 | 
				
			||||||
			Option: "NULLS NOT DISTINCT",
 | 
					 | 
				
			||||||
			Fields: []schema.IndexOption{
 | 
					 | 
				
			||||||
				{Field: &schema.Field{Name: "CompName1", NotNull: true}},
 | 
					 | 
				
			||||||
				{Field: &schema.Field{Name: "CompName2"}},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			Name: "idx_user_indices_comp_id0",
 | 
								Name: "idx_user_indices_comp_id0",
 | 
				
			||||||
			Type: "",
 | 
								Type: "",
 | 
				
			||||||
			Fields: []schema.IndexOption{{
 | 
								Fields: []schema.IndexOption{{
 | 
				
			||||||
@ -145,7 +123,7 @@ func TestParseIndex(t *testing.T) {
 | 
				
			|||||||
				Field: &schema.Field{Name: "Data0B"},
 | 
									Field: &schema.Field{Name: "Data0B"},
 | 
				
			||||||
			}},
 | 
								}},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							"idx_user_indices_comp_id1": {
 | 
				
			||||||
			Name: "idx_user_indices_comp_id1",
 | 
								Name: "idx_user_indices_comp_id1",
 | 
				
			||||||
			Fields: []schema.IndexOption{{
 | 
								Fields: []schema.IndexOption{{
 | 
				
			||||||
				Field: &schema.Field{Name: "Data1A"},
 | 
									Field: &schema.Field{Name: "Data1A"},
 | 
				
			||||||
@ -155,7 +133,7 @@ func TestParseIndex(t *testing.T) {
 | 
				
			|||||||
				Field: &schema.Field{Name: "Data1C"},
 | 
									Field: &schema.Field{Name: "Data1C"},
 | 
				
			||||||
			}},
 | 
								}},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							"idx_user_indices_comp_id2": {
 | 
				
			||||||
			Name:  "idx_user_indices_comp_id2",
 | 
								Name:  "idx_user_indices_comp_id2",
 | 
				
			||||||
			Class: "UNIQUE",
 | 
								Class: "UNIQUE",
 | 
				
			||||||
			Fields: []schema.IndexOption{{
 | 
								Fields: []schema.IndexOption{{
 | 
				
			||||||
@ -168,108 +146,40 @@ func TestParseIndex(t *testing.T) {
 | 
				
			|||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	CheckIndices(t, results, user.ParseIndexes())
 | 
						indices := user.ParseIndexes()
 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) {
 | 
						for k, result := range results {
 | 
				
			||||||
	type IndexTest struct {
 | 
							v, ok := indices[k]
 | 
				
			||||||
		FieldA string `gorm:"unique;index"` // unique and index
 | 
							if !ok {
 | 
				
			||||||
		FieldB string `gorm:"unique"`       // unique
 | 
								t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		FieldC string `gorm:"index:,unique"`     // uniqueIndex
 | 
							for _, name := range []string{"Name", "Class", "Type", "Where", "Comment", "Option"} {
 | 
				
			||||||
		FieldD string `gorm:"uniqueIndex;index"` // uniqueIndex and index
 | 
								if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() {
 | 
				
			||||||
 | 
									t.Errorf(
 | 
				
			||||||
		FieldE1 string `gorm:"uniqueIndex:uniq_field_e1_e2"` // mul uniqueIndex
 | 
										"index %v %v should equal, expects %v, got %v",
 | 
				
			||||||
		FieldE2 string `gorm:"uniqueIndex:uniq_field_e1_e2"`
 | 
										k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(),
 | 
				
			||||||
 | 
									)
 | 
				
			||||||
		FieldF1 string `gorm:"uniqueIndex:uniq_field_f1_f2;index"` // mul uniqueIndex and index
 | 
					 | 
				
			||||||
		FieldF2 string `gorm:"uniqueIndex:uniq_field_f1_f2;"`
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		FieldG string `gorm:"unique;uniqueIndex"` // unique and uniqueIndex
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		FieldH1 string `gorm:"unique;uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
 | 
					 | 
				
			||||||
		FieldH2 string `gorm:"uniqueIndex:uniq_field_h1_h2"`        // unique and mul uniqueIndex
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	indexSchema, err := schema.Parse(&IndexTest{}, &sync.Map{}, schema.NamingStrategy{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to parse user index, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	indices := indexSchema.ParseIndexes()
 | 
					 | 
				
			||||||
	expectedIndices := []*schema.Index{
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			Name:   "idx_index_tests_field_a",
 | 
					 | 
				
			||||||
			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldA", Unique: true}}},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			Name:   "idx_index_tests_field_c",
 | 
					 | 
				
			||||||
			Class:  "UNIQUE",
 | 
					 | 
				
			||||||
			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: "idx_index_tests_field_c"}}},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			Name:  "idx_index_tests_field_d",
 | 
					 | 
				
			||||||
			Class: "UNIQUE",
 | 
					 | 
				
			||||||
			Fields: []schema.IndexOption{
 | 
					 | 
				
			||||||
				{Field: &schema.Field{Name: "FieldD"}},
 | 
					 | 
				
			||||||
				// Note: Duplicate Columns
 | 
					 | 
				
			||||||
				{Field: &schema.Field{Name: "FieldD"}},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			Name:  "uniq_field_e1_e2",
 | 
					 | 
				
			||||||
			Class: "UNIQUE",
 | 
					 | 
				
			||||||
			Fields: []schema.IndexOption{
 | 
					 | 
				
			||||||
				{Field: &schema.Field{Name: "FieldE1"}},
 | 
					 | 
				
			||||||
				{Field: &schema.Field{Name: "FieldE2"}},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			Name:  "uniq_field_f1_f2",
 | 
					 | 
				
			||||||
			Class: "UNIQUE",
 | 
					 | 
				
			||||||
			Fields: []schema.IndexOption{
 | 
					 | 
				
			||||||
				{Field: &schema.Field{Name: "FieldF1"}},
 | 
					 | 
				
			||||||
				{Field: &schema.Field{Name: "FieldF2"}},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			Name:   "idx_index_tests_field_f1",
 | 
					 | 
				
			||||||
			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldF1"}}},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			Name:   "idx_index_tests_field_g",
 | 
					 | 
				
			||||||
			Class:  "UNIQUE",
 | 
					 | 
				
			||||||
			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: "idx_index_tests_field_g"}}},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			Name:  "uniq_field_h1_h2",
 | 
					 | 
				
			||||||
			Class: "UNIQUE",
 | 
					 | 
				
			||||||
			Fields: []schema.IndexOption{
 | 
					 | 
				
			||||||
				{Field: &schema.Field{Name: "FieldH1", Unique: true}},
 | 
					 | 
				
			||||||
				{Field: &schema.Field{Name: "FieldH2"}},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	CheckIndices(t, expectedIndices, indices)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func CheckIndices(t *testing.T, expected, actual []*schema.Index) {
 | 
					 | 
				
			||||||
	if len(expected) != len(actual) {
 | 
					 | 
				
			||||||
		t.Errorf("expected %d indices, but got %d", len(expected), len(actual))
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for i, ei := range expected {
 | 
					 | 
				
			||||||
		t.Run(ei.Name, func(t *testing.T) {
 | 
					 | 
				
			||||||
			ai := actual[i]
 | 
					 | 
				
			||||||
			tests.AssertObjEqual(t, ai, ei, "Name", "Class", "Type", "Where", "Comment", "Option")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if len(ei.Fields) != len(ai.Fields) {
 | 
					 | 
				
			||||||
				t.Errorf("expected index %q field length is %d but actual %d", ei.Name, len(ei.Fields), len(ai.Fields))
 | 
					 | 
				
			||||||
				return
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			for i, ef := range ei.Fields {
 | 
							}
 | 
				
			||||||
				af := ai.Fields[i]
 | 
					
 | 
				
			||||||
				tests.AssertObjEqual(t, af, ef, "Name", "Unique", "UniqueIndex", "Expression", "Sort", "Collate", "Length", "NotNull")
 | 
							for idx, ef := range result.Fields {
 | 
				
			||||||
 | 
								rf := v.Fields[idx]
 | 
				
			||||||
 | 
								if rf.Field.Name != ef.Field.Name {
 | 
				
			||||||
 | 
									t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		})
 | 
								if rf.Field.Unique != ef.Field.Unique {
 | 
				
			||||||
 | 
									t.Fatalf("index field '%s' should equal, expects %v, got %v", rf.Field.Name, rf.Field.Unique, ef.Field.Unique)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								for _, name := range []string{"Expression", "Sort", "Collate", "Length"} {
 | 
				
			||||||
 | 
									if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() {
 | 
				
			||||||
 | 
										t.Errorf(
 | 
				
			||||||
 | 
											"index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name,
 | 
				
			||||||
 | 
											reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(),
 | 
				
			||||||
 | 
										)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -4,12 +4,6 @@ import (
 | 
				
			|||||||
	"gorm.io/gorm/clause"
 | 
						"gorm.io/gorm/clause"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ConstraintInterface database constraint interface
 | 
					 | 
				
			||||||
type ConstraintInterface interface {
 | 
					 | 
				
			||||||
	GetName() string
 | 
					 | 
				
			||||||
	Build() (sql string, vars []interface{})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// GormDataTypeInterface gorm data type interface
 | 
					// GormDataTypeInterface gorm data type interface
 | 
				
			||||||
type GormDataTypeInterface interface {
 | 
					type GormDataTypeInterface interface {
 | 
				
			||||||
	GormDataType() string
 | 
						GormDataType() string
 | 
				
			||||||
 | 
				
			|||||||
@ -8,8 +8,6 @@ import (
 | 
				
			|||||||
	"unicode/utf8"
 | 
						"unicode/utf8"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/jinzhu/inflection"
 | 
						"github.com/jinzhu/inflection"
 | 
				
			||||||
	"golang.org/x/text/cases"
 | 
					 | 
				
			||||||
	"golang.org/x/text/language"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Namer namer interface
 | 
					// Namer namer interface
 | 
				
			||||||
@ -21,7 +19,6 @@ type Namer interface {
 | 
				
			|||||||
	RelationshipFKName(Relationship) string
 | 
						RelationshipFKName(Relationship) string
 | 
				
			||||||
	CheckerName(table, column string) string
 | 
						CheckerName(table, column string) string
 | 
				
			||||||
	IndexName(table, column string) string
 | 
						IndexName(table, column string) string
 | 
				
			||||||
	UniqueName(table, column string) string
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Replacer replacer interface like strings.Replacer
 | 
					// Replacer replacer interface like strings.Replacer
 | 
				
			||||||
@ -29,8 +26,6 @@ type Replacer interface {
 | 
				
			|||||||
	Replace(name string) string
 | 
						Replace(name string) string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var _ Namer = (*NamingStrategy)(nil)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// NamingStrategy tables, columns naming strategy
 | 
					// NamingStrategy tables, columns naming strategy
 | 
				
			||||||
type NamingStrategy struct {
 | 
					type NamingStrategy struct {
 | 
				
			||||||
	TablePrefix         string
 | 
						TablePrefix         string
 | 
				
			||||||
@ -90,11 +85,6 @@ func (ns NamingStrategy) IndexName(table, column string) string {
 | 
				
			|||||||
	return ns.formatName("idx", table, ns.toDBName(column))
 | 
						return ns.formatName("idx", table, ns.toDBName(column))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// UniqueName generate unique constraint name
 | 
					 | 
				
			||||||
func (ns NamingStrategy) UniqueName(table, column string) string {
 | 
					 | 
				
			||||||
	return ns.formatName("uni", table, ns.toDBName(column))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (ns NamingStrategy) formatName(prefix, table, name string) string {
 | 
					func (ns NamingStrategy) formatName(prefix, table, name string) string {
 | 
				
			||||||
	formattedName := strings.ReplaceAll(strings.Join([]string{
 | 
						formattedName := strings.ReplaceAll(strings.Join([]string{
 | 
				
			||||||
		prefix, table, name,
 | 
							prefix, table, name,
 | 
				
			||||||
@ -123,7 +113,7 @@ var (
 | 
				
			|||||||
func init() {
 | 
					func init() {
 | 
				
			||||||
	commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
 | 
						commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
 | 
				
			||||||
	for _, initialism := range commonInitialisms {
 | 
						for _, initialism := range commonInitialisms {
 | 
				
			||||||
		commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, cases.Title(language.Und).String(initialism))
 | 
							commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
 | 
						commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -188,9 +178,9 @@ func (ns NamingStrategy) toDBName(name string) string {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (ns NamingStrategy) toSchemaName(name string) string {
 | 
					func (ns NamingStrategy) toSchemaName(name string) string {
 | 
				
			||||||
	result := strings.ReplaceAll(cases.Title(language.Und, cases.NoLower).String(strings.ReplaceAll(name, "_", " ")), " ", "")
 | 
						result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "")
 | 
				
			||||||
	for _, initialism := range commonInitialisms {
 | 
						for _, initialism := range commonInitialisms {
 | 
				
			||||||
		result = regexp.MustCompile(cases.Title(language.Und, cases.NoLower).String(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
 | 
							result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return result
 | 
						return result
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -5,12 +5,8 @@ import (
 | 
				
			|||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"sync"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/jinzhu/inflection"
 | 
						"github.com/jinzhu/inflection"
 | 
				
			||||||
	"golang.org/x/text/cases"
 | 
					 | 
				
			||||||
	"golang.org/x/text/language"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"gorm.io/gorm/clause"
 | 
						"gorm.io/gorm/clause"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -33,8 +29,6 @@ type Relationships struct {
 | 
				
			|||||||
	Relations map[string]*Relationship
 | 
						Relations map[string]*Relationship
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	EmbeddedRelations map[string]*Relationships
 | 
						EmbeddedRelations map[string]*Relationships
 | 
				
			||||||
 | 
					 | 
				
			||||||
	Mux sync.RWMutex
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Relationship struct {
 | 
					type Relationship struct {
 | 
				
			||||||
@ -78,12 +72,12 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
 | 
				
			|||||||
	cacheStore := schema.cacheStore
 | 
						cacheStore := schema.cacheStore
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
 | 
						if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
 | 
				
			||||||
		schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err)
 | 
							schema.err = err
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if hasPolymorphicRelation(field.TagSettings) {
 | 
						if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
 | 
				
			||||||
		schema.buildPolymorphicRelation(relation, field)
 | 
							schema.buildPolymorphicRelation(relation, field, polymorphic)
 | 
				
			||||||
	} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
 | 
						} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
 | 
				
			||||||
		schema.buildMany2ManyRelation(relation, field, many2many)
 | 
							schema.buildMany2ManyRelation(relation, field, many2many)
 | 
				
			||||||
	} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
 | 
						} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
 | 
				
			||||||
@ -95,16 +89,14 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
 | 
				
			|||||||
		case reflect.Slice:
 | 
							case reflect.Slice:
 | 
				
			||||||
			schema.guessRelation(relation, field, guessHas)
 | 
								schema.guessRelation(relation, field, guessHas)
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
			schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema,
 | 
								schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name)
 | 
				
			||||||
				field.Name)
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if relation.Type == has {
 | 
						if relation.Type == has {
 | 
				
			||||||
 | 
							// don't add relations to embedded schema, which might be shared
 | 
				
			||||||
		if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil {
 | 
							if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil {
 | 
				
			||||||
			relation.FieldSchema.Relationships.Mux.Lock()
 | 
					 | 
				
			||||||
			relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation
 | 
								relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation
 | 
				
			||||||
			relation.FieldSchema.Relationships.Mux.Unlock()
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		switch field.IndirectFieldType.Kind() {
 | 
							switch field.IndirectFieldType.Kind() {
 | 
				
			||||||
@ -132,20 +124,6 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
 | 
				
			|||||||
	return relation
 | 
						return relation
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// hasPolymorphicRelation check if has polymorphic relation
 | 
					 | 
				
			||||||
// 1. `POLYMORPHIC` tag
 | 
					 | 
				
			||||||
// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag
 | 
					 | 
				
			||||||
func hasPolymorphicRelation(tagSettings map[string]string) bool {
 | 
					 | 
				
			||||||
	if _, ok := tagSettings["POLYMORPHIC"]; ok {
 | 
					 | 
				
			||||||
		return true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	_, hasType := tagSettings["POLYMORPHICTYPE"]
 | 
					 | 
				
			||||||
	_, hasId := tagSettings["POLYMORPHICID"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return hasType && hasId
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (schema *Schema) setRelation(relation *Relationship) {
 | 
					func (schema *Schema) setRelation(relation *Relationship) {
 | 
				
			||||||
	// set non-embedded relation
 | 
						// set non-embedded relation
 | 
				
			||||||
	if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
 | 
						if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
 | 
				
			||||||
@ -157,12 +135,12 @@ func (schema *Schema) setRelation(relation *Relationship) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// set embedded relation
 | 
						// set embedded relation
 | 
				
			||||||
	if len(relation.Field.EmbeddedBindNames) <= 1 {
 | 
						if len(relation.Field.BindNames) <= 1 {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	relationships := &schema.Relationships
 | 
						relationships := &schema.Relationships
 | 
				
			||||||
	for i, name := range relation.Field.EmbeddedBindNames {
 | 
						for i, name := range relation.Field.BindNames {
 | 
				
			||||||
		if i < len(relation.Field.EmbeddedBindNames)-1 {
 | 
							if i < len(relation.Field.BindNames)-1 {
 | 
				
			||||||
			if relationships.EmbeddedRelations == nil {
 | 
								if relationships.EmbeddedRelations == nil {
 | 
				
			||||||
				relationships.EmbeddedRelations = map[string]*Relationships{}
 | 
									relationships.EmbeddedRelations = map[string]*Relationships{}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@ -191,41 +169,23 @@ func (schema *Schema) setRelation(relation *Relationship) {
 | 
				
			|||||||
//	  OwnerID   int
 | 
					//	  OwnerID   int
 | 
				
			||||||
//	  OwnerType string
 | 
					//	  OwnerType string
 | 
				
			||||||
//	}
 | 
					//	}
 | 
				
			||||||
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
 | 
					func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
 | 
				
			||||||
	polymorphic := field.TagSettings["POLYMORPHIC"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	relation.Polymorphic = &Polymorphic{
 | 
						relation.Polymorphic = &Polymorphic{
 | 
				
			||||||
		Value: schema.Table,
 | 
							Value:           schema.Table,
 | 
				
			||||||
 | 
							PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
 | 
				
			||||||
 | 
							PolymorphicID:   relation.FieldSchema.FieldsByName[polymorphic+"ID"],
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var (
 | 
					 | 
				
			||||||
		typeName = polymorphic + "Type"
 | 
					 | 
				
			||||||
		typeId   = polymorphic + "ID"
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
 | 
					 | 
				
			||||||
		typeName = strings.TrimSpace(value)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
 | 
					 | 
				
			||||||
		typeId = strings.TrimSpace(value)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
 | 
					 | 
				
			||||||
	relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
 | 
						if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
 | 
				
			||||||
		relation.Polymorphic.Value = strings.TrimSpace(value)
 | 
							relation.Polymorphic.Value = strings.TrimSpace(value)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if relation.Polymorphic.PolymorphicType == nil {
 | 
						if relation.Polymorphic.PolymorphicType == nil {
 | 
				
			||||||
		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
 | 
							schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
 | 
				
			||||||
			relation.FieldSchema, schema, field.Name, polymorphic+"Type")
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if relation.Polymorphic.PolymorphicID == nil {
 | 
						if relation.Polymorphic.PolymorphicID == nil {
 | 
				
			||||||
		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
 | 
							schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
 | 
				
			||||||
			relation.FieldSchema, schema, field.Name, polymorphic+"ID")
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if schema.err == nil {
 | 
						if schema.err == nil {
 | 
				
			||||||
@ -237,14 +197,12 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
 | 
				
			|||||||
		primaryKeyField := schema.PrioritizedPrimaryField
 | 
							primaryKeyField := schema.PrioritizedPrimaryField
 | 
				
			||||||
		if len(relation.foreignKeys) > 0 {
 | 
							if len(relation.foreignKeys) > 0 {
 | 
				
			||||||
			if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
 | 
								if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
 | 
				
			||||||
				schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys,
 | 
									schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name)
 | 
				
			||||||
					schema, field.Name)
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if primaryKeyField == nil {
 | 
							if primaryKeyField == nil {
 | 
				
			||||||
			schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
 | 
								schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name)
 | 
				
			||||||
				relation.FieldSchema, schema, field.Name)
 | 
					 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -308,9 +266,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for idx, ownField := range ownForeignFields {
 | 
						for idx, ownField := range ownForeignFields {
 | 
				
			||||||
		joinFieldName := cases.Title(language.Und, cases.NoLower).String(schema.Name) + ownField.Name
 | 
							joinFieldName := strings.Title(schema.Name) + ownField.Name
 | 
				
			||||||
		if len(joinForeignKeys) > idx {
 | 
							if len(joinForeignKeys) > idx {
 | 
				
			||||||
			joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinForeignKeys[idx])
 | 
								joinFieldName = strings.Title(joinForeignKeys[idx])
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		ownFieldsMap[joinFieldName] = ownField
 | 
							ownFieldsMap[joinFieldName] = ownField
 | 
				
			||||||
@ -325,7 +283,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for idx, relField := range refForeignFields {
 | 
						for idx, relField := range refForeignFields {
 | 
				
			||||||
		joinFieldName := cases.Title(language.Und, cases.NoLower).String(relation.FieldSchema.Name) + relField.Name
 | 
							joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if _, ok := ownFieldsMap[joinFieldName]; ok {
 | 
							if _, ok := ownFieldsMap[joinFieldName]; ok {
 | 
				
			||||||
			if field.Name != relation.FieldSchema.Name {
 | 
								if field.Name != relation.FieldSchema.Name {
 | 
				
			||||||
@ -336,7 +294,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if len(joinReferences) > idx {
 | 
							if len(joinReferences) > idx {
 | 
				
			||||||
			joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinReferences[idx])
 | 
								joinFieldName = strings.Title(joinReferences[idx])
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		referFieldsMap[joinFieldName] = relField
 | 
							referFieldsMap[joinFieldName] = relField
 | 
				
			||||||
@ -354,13 +312,12 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	joinTableFields = append(joinTableFields, reflect.StructField{
 | 
						joinTableFields = append(joinTableFields, reflect.StructField{
 | 
				
			||||||
		Name: cases.Title(language.Und, cases.NoLower).String(schema.Name) + field.Name,
 | 
							Name: strings.Title(schema.Name) + field.Name,
 | 
				
			||||||
		Type: schema.ModelType,
 | 
							Type: schema.ModelType,
 | 
				
			||||||
		Tag:  `gorm:"-"`,
 | 
							Tag:  `gorm:"-"`,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
 | 
						if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
 | 
				
			||||||
		schema.namer); err != nil {
 | 
					 | 
				
			||||||
		schema.err = err
 | 
							schema.err = err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	relation.JoinTable.Name = many2many
 | 
						relation.JoinTable.Name = many2many
 | 
				
			||||||
@ -479,8 +436,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
 | 
				
			|||||||
			schema.guessRelation(relation, field, guessEmbeddedHas)
 | 
								schema.guessRelation(relation, field, guessEmbeddedHas)
 | 
				
			||||||
		// case guessEmbeddedHas:
 | 
							// case guessEmbeddedHas:
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
			schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
 | 
								schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name)
 | 
				
			||||||
				schema, field.Name)
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -536,9 +492,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			lookUpNames := []string{lookUpName}
 | 
								lookUpNames := []string{lookUpName}
 | 
				
			||||||
			if len(primaryFields) == 1 {
 | 
								if len(primaryFields) == 1 {
 | 
				
			||||||
				lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
 | 
									lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
 | 
				
			||||||
					strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
 | 
					 | 
				
			||||||
						strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			for _, name := range lookUpNames {
 | 
								for _, name := range lookUpNames {
 | 
				
			||||||
@ -612,7 +566,6 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Constraint is ForeignKey Constraint
 | 
					 | 
				
			||||||
type Constraint struct {
 | 
					type Constraint struct {
 | 
				
			||||||
	Name            string
 | 
						Name            string
 | 
				
			||||||
	Field           *Field
 | 
						Field           *Field
 | 
				
			||||||
@ -624,31 +577,6 @@ type Constraint struct {
 | 
				
			|||||||
	OnUpdate        string
 | 
						OnUpdate        string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (constraint *Constraint) GetName() string { return constraint.Name }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (constraint *Constraint) Build() (sql string, vars []interface{}) {
 | 
					 | 
				
			||||||
	sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
 | 
					 | 
				
			||||||
	if constraint.OnDelete != "" {
 | 
					 | 
				
			||||||
		sql += " ON DELETE " + constraint.OnDelete
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if constraint.OnUpdate != "" {
 | 
					 | 
				
			||||||
		sql += " ON UPDATE " + constraint.OnUpdate
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys))
 | 
					 | 
				
			||||||
	for _, field := range constraint.ForeignKeys {
 | 
					 | 
				
			||||||
		foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	references := make([]interface{}, 0, len(constraint.References))
 | 
					 | 
				
			||||||
	for _, field := range constraint.References {
 | 
					 | 
				
			||||||
		references = append(references, clause.Column{Name: field.DBName})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (rel *Relationship) ParseConstraint() *Constraint {
 | 
					func (rel *Relationship) ParseConstraint() *Constraint {
 | 
				
			||||||
	str := rel.Field.TagSettings["CONSTRAINT"]
 | 
						str := rel.Field.TagSettings["CONSTRAINT"]
 | 
				
			||||||
	if str == "-" {
 | 
						if str == "-" {
 | 
				
			||||||
@ -663,7 +591,6 @@ func (rel *Relationship) ParseConstraint() *Constraint {
 | 
				
			|||||||
					if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey &&
 | 
										if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey &&
 | 
				
			||||||
						rel.References[idx].PrimaryValue == ref.PrimaryValue) {
 | 
											rel.References[idx].PrimaryValue == ref.PrimaryValue) {
 | 
				
			||||||
						matched = false
 | 
											matched = false
 | 
				
			||||||
						break
 | 
					 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -676,7 +603,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		name     string
 | 
							name     string
 | 
				
			||||||
		idx      = strings.IndexByte(str, ',')
 | 
							idx      = strings.Index(str, ",")
 | 
				
			||||||
		settings = ParseTagSetting(str, ",")
 | 
							settings = ParseTagSetting(str, ",")
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -763,9 +690,8 @@ func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue ref
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func copyableDataType(str DataType) bool {
 | 
					func copyableDataType(str DataType) bool {
 | 
				
			||||||
	lowerStr := strings.ToLower(string(str))
 | 
					 | 
				
			||||||
	for _, s := range []string{"auto_increment", "primary key"} {
 | 
						for _, s := range []string{"auto_increment", "primary key"} {
 | 
				
			||||||
		if strings.Contains(lowerStr, s) {
 | 
							if strings.Contains(strings.ToLower(string(str)), s) {
 | 
				
			||||||
			return false
 | 
								return false
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -121,29 +121,6 @@ func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestBelongsToWithMixin(t *testing.T) {
 | 
					 | 
				
			||||||
	type Profile struct {
 | 
					 | 
				
			||||||
		gorm.Model
 | 
					 | 
				
			||||||
		Refer string
 | 
					 | 
				
			||||||
		Name  string
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type ProfileMixin struct {
 | 
					 | 
				
			||||||
		Profile      Profile `gorm:"References:Refer"`
 | 
					 | 
				
			||||||
		ProfileRefer int
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type User struct {
 | 
					 | 
				
			||||||
		gorm.Model
 | 
					 | 
				
			||||||
		ProfileMixin
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	checkStructRelation(t, &User{}, Relation{
 | 
					 | 
				
			||||||
		Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
 | 
					 | 
				
			||||||
		References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}},
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestHasOneOverrideForeignKey(t *testing.T) {
 | 
					func TestHasOneOverrideForeignKey(t *testing.T) {
 | 
				
			||||||
	type Profile struct {
 | 
						type Profile struct {
 | 
				
			||||||
		gorm.Model
 | 
							gorm.Model
 | 
				
			||||||
@ -600,193 +577,6 @@ func TestEmbeddedHas(t *testing.T) {
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestPolymorphic(t *testing.T) {
 | 
					 | 
				
			||||||
	t.Run("has one", func(t *testing.T) {
 | 
					 | 
				
			||||||
		type Toy struct {
 | 
					 | 
				
			||||||
			ID        int
 | 
					 | 
				
			||||||
			Name      string
 | 
					 | 
				
			||||||
			OwnerID   int
 | 
					 | 
				
			||||||
			OwnerType string
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		type Cat struct {
 | 
					 | 
				
			||||||
			ID   int
 | 
					 | 
				
			||||||
			Name string
 | 
					 | 
				
			||||||
			Toy  Toy `gorm:"polymorphic:Owner;"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("Failed to parse schema, got error %v", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
 | 
					 | 
				
			||||||
			"Cat": {
 | 
					 | 
				
			||||||
				Relations: map[string]Relation{
 | 
					 | 
				
			||||||
					"Toy": {
 | 
					 | 
				
			||||||
						Name:        "Toy",
 | 
					 | 
				
			||||||
						Type:        schema.HasOne,
 | 
					 | 
				
			||||||
						Schema:      "User",
 | 
					 | 
				
			||||||
						FieldSchema: "Toy",
 | 
					 | 
				
			||||||
						Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
 | 
					 | 
				
			||||||
						References: []Reference{
 | 
					 | 
				
			||||||
							{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
 | 
					 | 
				
			||||||
						},
 | 
					 | 
				
			||||||
					},
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	t.Run("has one with custom polymorphic type and id", func(t *testing.T) {
 | 
					 | 
				
			||||||
		type Toy struct {
 | 
					 | 
				
			||||||
			ID    int
 | 
					 | 
				
			||||||
			Name  string
 | 
					 | 
				
			||||||
			RefId int
 | 
					 | 
				
			||||||
			Type  string
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		type Cat struct {
 | 
					 | 
				
			||||||
			ID   int
 | 
					 | 
				
			||||||
			Name string
 | 
					 | 
				
			||||||
			Toy  Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("Failed to parse schema, got error %v", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
 | 
					 | 
				
			||||||
			"Cat": {
 | 
					 | 
				
			||||||
				Relations: map[string]Relation{
 | 
					 | 
				
			||||||
					"Toy": {
 | 
					 | 
				
			||||||
						Name:        "Toy",
 | 
					 | 
				
			||||||
						Type:        schema.HasOne,
 | 
					 | 
				
			||||||
						Schema:      "User",
 | 
					 | 
				
			||||||
						FieldSchema: "Toy",
 | 
					 | 
				
			||||||
						Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
 | 
					 | 
				
			||||||
						References: []Reference{
 | 
					 | 
				
			||||||
							{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
 | 
					 | 
				
			||||||
						},
 | 
					 | 
				
			||||||
					},
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	t.Run("has one with only polymorphic type", func(t *testing.T) {
 | 
					 | 
				
			||||||
		type Toy struct {
 | 
					 | 
				
			||||||
			ID      int
 | 
					 | 
				
			||||||
			Name    string
 | 
					 | 
				
			||||||
			OwnerID int
 | 
					 | 
				
			||||||
			Type    string
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		type Cat struct {
 | 
					 | 
				
			||||||
			ID   int
 | 
					 | 
				
			||||||
			Name string
 | 
					 | 
				
			||||||
			Toy  Toy `gorm:"polymorphic:Owner;polymorphicType:Type"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("Failed to parse schema, got error %v", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
 | 
					 | 
				
			||||||
			"Cat": {
 | 
					 | 
				
			||||||
				Relations: map[string]Relation{
 | 
					 | 
				
			||||||
					"Toy": {
 | 
					 | 
				
			||||||
						Name:        "Toy",
 | 
					 | 
				
			||||||
						Type:        schema.HasOne,
 | 
					 | 
				
			||||||
						Schema:      "User",
 | 
					 | 
				
			||||||
						FieldSchema: "Toy",
 | 
					 | 
				
			||||||
						Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"},
 | 
					 | 
				
			||||||
						References: []Reference{
 | 
					 | 
				
			||||||
							{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
 | 
					 | 
				
			||||||
						},
 | 
					 | 
				
			||||||
					},
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	t.Run("has many", func(t *testing.T) {
 | 
					 | 
				
			||||||
		type Toy struct {
 | 
					 | 
				
			||||||
			ID        int
 | 
					 | 
				
			||||||
			Name      string
 | 
					 | 
				
			||||||
			OwnerID   int
 | 
					 | 
				
			||||||
			OwnerType string
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		type Cat struct {
 | 
					 | 
				
			||||||
			ID   int
 | 
					 | 
				
			||||||
			Name string
 | 
					 | 
				
			||||||
			Toys []Toy `gorm:"polymorphic:Owner;"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("Failed to parse schema, got error %v", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
 | 
					 | 
				
			||||||
			"Cat": {
 | 
					 | 
				
			||||||
				Relations: map[string]Relation{
 | 
					 | 
				
			||||||
					"Toys": {
 | 
					 | 
				
			||||||
						Name:        "Toys",
 | 
					 | 
				
			||||||
						Type:        schema.HasMany,
 | 
					 | 
				
			||||||
						Schema:      "User",
 | 
					 | 
				
			||||||
						FieldSchema: "Toy",
 | 
					 | 
				
			||||||
						Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
 | 
					 | 
				
			||||||
						References: []Reference{
 | 
					 | 
				
			||||||
							{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
 | 
					 | 
				
			||||||
						},
 | 
					 | 
				
			||||||
					},
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	t.Run("has many with custom polymorphic type and id", func(t *testing.T) {
 | 
					 | 
				
			||||||
		type Toy struct {
 | 
					 | 
				
			||||||
			ID    int
 | 
					 | 
				
			||||||
			Name  string
 | 
					 | 
				
			||||||
			RefId int
 | 
					 | 
				
			||||||
			Type  string
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		type Cat struct {
 | 
					 | 
				
			||||||
			ID   int
 | 
					 | 
				
			||||||
			Name string
 | 
					 | 
				
			||||||
			Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("Failed to parse schema, got error %v", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
 | 
					 | 
				
			||||||
			"Cat": {
 | 
					 | 
				
			||||||
				Relations: map[string]Relation{
 | 
					 | 
				
			||||||
					"Toys": {
 | 
					 | 
				
			||||||
						Name:        "Toys",
 | 
					 | 
				
			||||||
						Type:        schema.HasMany,
 | 
					 | 
				
			||||||
						Schema:      "User",
 | 
					 | 
				
			||||||
						FieldSchema: "Toy",
 | 
					 | 
				
			||||||
						Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
 | 
					 | 
				
			||||||
						References: []Reference{
 | 
					 | 
				
			||||||
							{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
 | 
					 | 
				
			||||||
						},
 | 
					 | 
				
			||||||
					},
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestEmbeddedBelongsTo(t *testing.T) {
 | 
					func TestEmbeddedBelongsTo(t *testing.T) {
 | 
				
			||||||
	type Country struct {
 | 
						type Country struct {
 | 
				
			||||||
		ID   int `gorm:"primaryKey"`
 | 
							ID   int `gorm:"primaryKey"`
 | 
				
			||||||
@ -799,10 +589,6 @@ func TestEmbeddedBelongsTo(t *testing.T) {
 | 
				
			|||||||
	type NestedAddress struct {
 | 
						type NestedAddress struct {
 | 
				
			||||||
		Address
 | 
							Address
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	type CountryMixin struct {
 | 
					 | 
				
			||||||
		CountryID int
 | 
					 | 
				
			||||||
		Country   Country
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	type Org struct {
 | 
						type Org struct {
 | 
				
			||||||
		ID              int
 | 
							ID              int
 | 
				
			||||||
		PostalAddress   Address `gorm:"embedded;embeddedPrefix:postal_address_"`
 | 
							PostalAddress   Address `gorm:"embedded;embeddedPrefix:postal_address_"`
 | 
				
			||||||
@ -813,7 +599,6 @@ func TestEmbeddedBelongsTo(t *testing.T) {
 | 
				
			|||||||
			Address
 | 
								Address
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
 | 
							NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
 | 
				
			||||||
		CountryMixin
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
 | 
						s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
 | 
				
			||||||
@ -843,11 +628,15 @@ func TestEmbeddedBelongsTo(t *testing.T) {
 | 
				
			|||||||
			},
 | 
								},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		"NestedAddress": {
 | 
							"NestedAddress": {
 | 
				
			||||||
			Relations: map[string]Relation{
 | 
								EmbeddedRelations: map[string]EmbeddedRelations{
 | 
				
			||||||
				"Country": {
 | 
									"Address": {
 | 
				
			||||||
					Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
 | 
										Relations: map[string]Relation{
 | 
				
			||||||
					References: []Reference{
 | 
											"Country": {
 | 
				
			||||||
						{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
 | 
												Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
 | 
				
			||||||
 | 
												References: []Reference{
 | 
				
			||||||
 | 
													{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
 | 
				
			||||||
 | 
												},
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
 | 
				
			|||||||
@ -5,7 +5,6 @@ import (
 | 
				
			|||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"go/ast"
 | 
						"go/ast"
 | 
				
			||||||
	"path"
 | 
					 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
@ -14,20 +13,6 @@ import (
 | 
				
			|||||||
	"gorm.io/gorm/logger"
 | 
						"gorm.io/gorm/logger"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type callbackType string
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const (
 | 
					 | 
				
			||||||
	callbackTypeBeforeCreate callbackType = "BeforeCreate"
 | 
					 | 
				
			||||||
	callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
 | 
					 | 
				
			||||||
	callbackTypeAfterCreate  callbackType = "AfterCreate"
 | 
					 | 
				
			||||||
	callbackTypeAfterUpdate  callbackType = "AfterUpdate"
 | 
					 | 
				
			||||||
	callbackTypeBeforeSave   callbackType = "BeforeSave"
 | 
					 | 
				
			||||||
	callbackTypeAfterSave    callbackType = "AfterSave"
 | 
					 | 
				
			||||||
	callbackTypeBeforeDelete callbackType = "BeforeDelete"
 | 
					 | 
				
			||||||
	callbackTypeAfterDelete  callbackType = "AfterDelete"
 | 
					 | 
				
			||||||
	callbackTypeAfterFind    callbackType = "AfterFind"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// ErrUnsupportedDataType unsupported data type
 | 
					// ErrUnsupportedDataType unsupported data type
 | 
				
			||||||
var ErrUnsupportedDataType = errors.New("unsupported data type")
 | 
					var ErrUnsupportedDataType = errors.New("unsupported data type")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -68,10 +53,9 @@ func (schema Schema) String() string {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (schema Schema) MakeSlice() reflect.Value {
 | 
					func (schema Schema) MakeSlice() reflect.Value {
 | 
				
			||||||
	slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20)
 | 
						slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20)
 | 
				
			||||||
	results := reflect.New(slice.Type())
 | 
						results := reflect.New(slice.Type())
 | 
				
			||||||
	results.Elem().Set(slice)
 | 
						results.Elem().Set(slice)
 | 
				
			||||||
 | 
					 | 
				
			||||||
	return results
 | 
						return results
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -248,7 +232,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
 | 
				
			|||||||
			schema.FieldsByBindName[bindName] = field
 | 
								schema.FieldsByBindName[bindName] = field
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		field.setupValuerAndSetter(modelType)
 | 
							field.setupValuerAndSetter()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	prioritizedPrimaryField := schema.LookUpField("id")
 | 
						prioritizedPrimaryField := schema.LookUpField("id")
 | 
				
			||||||
@ -304,26 +288,14 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	callbackTypes := []callbackType{
 | 
						callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
 | 
				
			||||||
		callbackTypeBeforeCreate, callbackTypeAfterCreate,
 | 
						for _, name := range callbacks {
 | 
				
			||||||
		callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
 | 
							if methodValue := modelValue.MethodByName(name); methodValue.IsValid() {
 | 
				
			||||||
		callbackTypeBeforeSave, callbackTypeAfterSave,
 | 
					 | 
				
			||||||
		callbackTypeBeforeDelete, callbackTypeAfterDelete,
 | 
					 | 
				
			||||||
		callbackTypeAfterFind,
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	for _, cbName := range callbackTypes {
 | 
					 | 
				
			||||||
		if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
 | 
					 | 
				
			||||||
			switch methodValue.Type().String() {
 | 
								switch methodValue.Type().String() {
 | 
				
			||||||
			case "func(*gorm.DB) error":
 | 
								case "func(*gorm.DB) error": // TODO hack
 | 
				
			||||||
				expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath())
 | 
									reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
 | 
				
			||||||
				if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath {
 | 
					 | 
				
			||||||
					reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg)
 | 
					 | 
				
			||||||
					// PASS
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			default:
 | 
								default:
 | 
				
			||||||
				logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
 | 
									logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -345,7 +317,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
 | 
						if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
 | 
				
			||||||
		for _, field := range schema.Fields {
 | 
							for _, field := range schema.Fields {
 | 
				
			||||||
			if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) {
 | 
								if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
 | 
				
			||||||
				if schema.parseRelation(field); schema.err != nil {
 | 
									if schema.parseRelation(field); schema.err != nil {
 | 
				
			||||||
					return schema, schema.err
 | 
										return schema, schema.err
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
@ -377,39 +349,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
 | 
				
			|||||||
	return schema, schema.err
 | 
						return schema, schema.err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// This unrolling is needed to show to the compiler the exact set of methods
 | 
					 | 
				
			||||||
// that can be used on the modelType.
 | 
					 | 
				
			||||||
// Prior to go1.22 any use of MethodByName would cause the linker to
 | 
					 | 
				
			||||||
// abandon dead code elimination for the entire binary.
 | 
					 | 
				
			||||||
// As of go1.22 the compiler supports one special case of a string constant
 | 
					 | 
				
			||||||
// being passed to MethodByName. For enterprise customers or those building
 | 
					 | 
				
			||||||
// large binaries, this gives a significant reduction in binary size.
 | 
					 | 
				
			||||||
// https://github.com/golang/go/issues/62257
 | 
					 | 
				
			||||||
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
 | 
					 | 
				
			||||||
	switch cbType {
 | 
					 | 
				
			||||||
	case callbackTypeBeforeCreate:
 | 
					 | 
				
			||||||
		return modelType.MethodByName(string(callbackTypeBeforeCreate))
 | 
					 | 
				
			||||||
	case callbackTypeAfterCreate:
 | 
					 | 
				
			||||||
		return modelType.MethodByName(string(callbackTypeAfterCreate))
 | 
					 | 
				
			||||||
	case callbackTypeBeforeUpdate:
 | 
					 | 
				
			||||||
		return modelType.MethodByName(string(callbackTypeBeforeUpdate))
 | 
					 | 
				
			||||||
	case callbackTypeAfterUpdate:
 | 
					 | 
				
			||||||
		return modelType.MethodByName(string(callbackTypeAfterUpdate))
 | 
					 | 
				
			||||||
	case callbackTypeBeforeSave:
 | 
					 | 
				
			||||||
		return modelType.MethodByName(string(callbackTypeBeforeSave))
 | 
					 | 
				
			||||||
	case callbackTypeAfterSave:
 | 
					 | 
				
			||||||
		return modelType.MethodByName(string(callbackTypeAfterSave))
 | 
					 | 
				
			||||||
	case callbackTypeBeforeDelete:
 | 
					 | 
				
			||||||
		return modelType.MethodByName(string(callbackTypeBeforeDelete))
 | 
					 | 
				
			||||||
	case callbackTypeAfterDelete:
 | 
					 | 
				
			||||||
		return modelType.MethodByName(string(callbackTypeAfterDelete))
 | 
					 | 
				
			||||||
	case callbackTypeAfterFind:
 | 
					 | 
				
			||||||
		return modelType.MethodByName(string(callbackTypeAfterFind))
 | 
					 | 
				
			||||||
	default:
 | 
					 | 
				
			||||||
		return reflect.ValueOf(nil)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
 | 
					func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
 | 
				
			||||||
	modelType := reflect.ValueOf(dest).Type()
 | 
						modelType := reflect.ValueOf(dest).Type()
 | 
				
			||||||
	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
 | 
						for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
 | 
				
			||||||
 | 
				
			|||||||
@ -163,8 +163,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
 | 
				
			|||||||
					t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table)
 | 
										t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				for i := range relation.JoinTable.Fields {
 | 
									for _, f := range relation.JoinTable.Fields {
 | 
				
			||||||
					checkSchemaField(t, r.JoinTable, &relation.JoinTable.Fields[i], nil)
 | 
										checkSchemaField(t, r.JoinTable, &f, nil)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -19,22 +19,6 @@ func TestParseSchema(t *testing.T) {
 | 
				
			|||||||
	checkUserSchema(t, user)
 | 
						checkUserSchema(t, user)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestParseSchemaWithMap(t *testing.T) {
 | 
					 | 
				
			||||||
	type User struct {
 | 
					 | 
				
			||||||
		tests.User
 | 
					 | 
				
			||||||
		Attrs map[string]string `gorm:"type:Map(String,String);"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to parse user with map, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if field := user.FieldsByName["Attrs"]; field.DataType != "Map(String,String)" {
 | 
					 | 
				
			||||||
		t.Errorf("failed to parse user field Attrs")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestParseSchemaWithPointerFields(t *testing.T) {
 | 
					func TestParseSchemaWithPointerFields(t *testing.T) {
 | 
				
			||||||
	user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
 | 
						user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@ -62,8 +46,8 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
 | 
				
			|||||||
		{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
 | 
							{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for i := range fields {
 | 
						for _, f := range fields {
 | 
				
			||||||
		checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
 | 
							checkSchemaField(t, user, &f, func(f *schema.Field) {
 | 
				
			||||||
			f.Creatable = true
 | 
								f.Creatable = true
 | 
				
			||||||
			f.Updatable = true
 | 
								f.Updatable = true
 | 
				
			||||||
			f.Readable = true
 | 
								f.Readable = true
 | 
				
			||||||
@ -152,8 +136,8 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) {
 | 
				
			|||||||
		{Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool},
 | 
							{Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for i := range fields {
 | 
						for _, f := range fields {
 | 
				
			||||||
		checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
 | 
							checkSchemaField(t, user, &f, func(f *schema.Field) {
 | 
				
			||||||
			f.Creatable = true
 | 
								f.Creatable = true
 | 
				
			||||||
			f.Updatable = true
 | 
								f.Updatable = true
 | 
				
			||||||
			f.Readable = true
 | 
								f.Readable = true
 | 
				
			||||||
 | 
				
			|||||||
@ -84,10 +84,7 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
 | 
				
			|||||||
		case string:
 | 
							case string:
 | 
				
			||||||
			bytes = []byte(v)
 | 
								bytes = []byte(v)
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
			bytes, err = json.Marshal(v)
 | 
								return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				return err
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if len(bytes) > 0 {
 | 
							if len(bytes) > 0 {
 | 
				
			||||||
@ -129,12 +126,12 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect
 | 
				
			|||||||
	rv := reflect.ValueOf(fieldValue)
 | 
						rv := reflect.ValueOf(fieldValue)
 | 
				
			||||||
	switch v := fieldValue.(type) {
 | 
						switch v := fieldValue.(type) {
 | 
				
			||||||
	case int64, int, uint, uint64, int32, uint32, int16, uint16:
 | 
						case int64, int, uint, uint64, int32, uint32, int16, uint16:
 | 
				
			||||||
		result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
 | 
							result = time.Unix(reflect.Indirect(rv).Int(), 0)
 | 
				
			||||||
	case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
 | 
						case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
 | 
				
			||||||
		if rv.IsZero() {
 | 
							if rv.IsZero() {
 | 
				
			||||||
			return nil, nil
 | 
								return nil, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
 | 
							result = time.Unix(reflect.Indirect(rv).Int(), 0)
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
 | 
							err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -71,7 +71,7 @@ func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag
 | 
				
			|||||||
// GetRelationsValues get relations's values from a reflect value
 | 
					// GetRelationsValues get relations's values from a reflect value
 | 
				
			||||||
func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
 | 
					func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
 | 
				
			||||||
	for _, rel := range rels {
 | 
						for _, rel := range rels {
 | 
				
			||||||
		reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.FieldSchema.ModelType)), 0, 1)
 | 
							reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		appendToResults := func(value reflect.Value) {
 | 
							appendToResults := func(value reflect.Value) {
 | 
				
			||||||
			if _, isZero := rel.Field.ValueOf(ctx, value); !isZero {
 | 
								if _, isZero := rel.Field.ValueOf(ctx, value); !isZero {
 | 
				
			||||||
@ -115,11 +115,6 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
 | 
				
			|||||||
		notZero, zero bool
 | 
							notZero, zero bool
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if reflectValue.Kind() == reflect.Ptr ||
 | 
					 | 
				
			||||||
		reflectValue.Kind() == reflect.Interface {
 | 
					 | 
				
			||||||
		reflectValue = reflectValue.Elem()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	switch reflectValue.Kind() {
 | 
						switch reflectValue.Kind() {
 | 
				
			||||||
	case reflect.Struct:
 | 
						case reflect.Struct:
 | 
				
			||||||
		results = [][]interface{}{make([]interface{}, len(fields))}
 | 
							results = [][]interface{}{make([]interface{}, len(fields))}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										111
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										111
									
								
								statement.go
									
									
									
									
									
								
							@ -30,9 +30,8 @@ type Statement struct {
 | 
				
			|||||||
	Clauses              map[string]clause.Clause
 | 
						Clauses              map[string]clause.Clause
 | 
				
			||||||
	BuildClauses         []string
 | 
						BuildClauses         []string
 | 
				
			||||||
	Distinct             bool
 | 
						Distinct             bool
 | 
				
			||||||
	Selects              []string          // selected columns
 | 
						Selects              []string // selected columns
 | 
				
			||||||
	Omits                []string          // omit columns
 | 
						Omits                []string // omit columns
 | 
				
			||||||
	ColumnMapping        map[string]string // map columns
 | 
					 | 
				
			||||||
	Joins                []join
 | 
						Joins                []join
 | 
				
			||||||
	Preloads             map[string][]interface{}
 | 
						Preloads             map[string][]interface{}
 | 
				
			||||||
	Settings             sync.Map
 | 
						Settings             sync.Map
 | 
				
			||||||
@ -47,18 +46,15 @@ type Statement struct {
 | 
				
			|||||||
	attrs                []interface{}
 | 
						attrs                []interface{}
 | 
				
			||||||
	assigns              []interface{}
 | 
						assigns              []interface{}
 | 
				
			||||||
	scopes               []func(*DB) *DB
 | 
						scopes               []func(*DB) *DB
 | 
				
			||||||
	Result               *result
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type join struct {
 | 
					type join struct {
 | 
				
			||||||
	Name       string
 | 
						Name     string
 | 
				
			||||||
	Alias      string
 | 
						Conds    []interface{}
 | 
				
			||||||
	Conds      []interface{}
 | 
						On       *clause.Where
 | 
				
			||||||
	On         *clause.Where
 | 
						Selects  []string
 | 
				
			||||||
	Selects    []string
 | 
						Omits    []string
 | 
				
			||||||
	Omits      []string
 | 
						JoinType clause.JoinType
 | 
				
			||||||
	Expression clause.Expression
 | 
					 | 
				
			||||||
	JoinType   clause.JoinType
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// StatementModifier statement modifier interface
 | 
					// StatementModifier statement modifier interface
 | 
				
			||||||
@ -208,21 +204,19 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
 | 
				
			|||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				writer.WriteString("(NULL)")
 | 
									writer.WriteString("(NULL)")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		case interface{ getInstance() *DB }:
 | 
							case *DB:
 | 
				
			||||||
			cv := v.getInstance()
 | 
								subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
 | 
				
			||||||
 | 
								if v.Statement.SQL.Len() > 0 {
 | 
				
			||||||
			subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
 | 
					 | 
				
			||||||
			if cv.Statement.SQL.Len() > 0 {
 | 
					 | 
				
			||||||
				var (
 | 
									var (
 | 
				
			||||||
					vars = subdb.Statement.Vars
 | 
										vars = subdb.Statement.Vars
 | 
				
			||||||
					sql  = cv.Statement.SQL.String()
 | 
										sql  = v.Statement.SQL.String()
 | 
				
			||||||
				)
 | 
									)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				subdb.Statement.Vars = make([]interface{}, 0, len(vars))
 | 
									subdb.Statement.Vars = make([]interface{}, 0, len(vars))
 | 
				
			||||||
				for _, vv := range vars {
 | 
									for _, vv := range vars {
 | 
				
			||||||
					subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
 | 
										subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
 | 
				
			||||||
					bindvar := strings.Builder{}
 | 
										bindvar := strings.Builder{}
 | 
				
			||||||
					cv.BindVarTo(&bindvar, subdb.Statement, vv)
 | 
										v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
 | 
				
			||||||
					sql = strings.Replace(sql, bindvar.String(), "?", 1)
 | 
										sql = strings.Replace(sql, bindvar.String(), "?", 1)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -326,30 +320,27 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 | 
				
			|||||||
			arg, _ = valuer.Value()
 | 
								arg, _ = valuer.Value()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		curTable := stmt.Table
 | 
					 | 
				
			||||||
		if curTable == "" {
 | 
					 | 
				
			||||||
			curTable = clause.CurrentTable
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		switch v := arg.(type) {
 | 
							switch v := arg.(type) {
 | 
				
			||||||
		case clause.Expression:
 | 
							case clause.Expression:
 | 
				
			||||||
			conds = append(conds, v)
 | 
								conds = append(conds, v)
 | 
				
			||||||
		case *DB:
 | 
							case *DB:
 | 
				
			||||||
			v.executeScopes()
 | 
								v.executeScopes()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if cs, ok := v.Statement.Clauses["WHERE"]; ok {
 | 
								if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
 | 
				
			||||||
				if where, ok := cs.Expression.(clause.Where); ok {
 | 
									if where, ok := cs.Expression.(clause.Where); ok {
 | 
				
			||||||
					if len(where.Exprs) == 1 {
 | 
										if len(where.Exprs) == 1 {
 | 
				
			||||||
						if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
 | 
											if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
 | 
				
			||||||
							if len(orConds.Exprs) == 1 {
 | 
												where.Exprs[0] = clause.AndConditions(orConds)
 | 
				
			||||||
								where.Exprs[0] = clause.AndConditions(orConds)
 | 
					 | 
				
			||||||
							}
 | 
					 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
					conds = append(conds, clause.And(where.Exprs...))
 | 
										conds = append(conds, clause.And(where.Exprs...))
 | 
				
			||||||
				} else if cs.Expression != nil {
 | 
									} else {
 | 
				
			||||||
					conds = append(conds, cs.Expression)
 | 
										conds = append(conds, cs.Expression)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
									if v.Statement == stmt {
 | 
				
			||||||
 | 
										cs.Expression = nil
 | 
				
			||||||
 | 
										stmt.Statement.Clauses["WHERE"] = cs
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		case map[interface{}]interface{}:
 | 
							case map[interface{}]interface{}:
 | 
				
			||||||
			for i, j := range v {
 | 
								for i, j := range v {
 | 
				
			||||||
@ -363,11 +354,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 | 
				
			|||||||
			sort.Strings(keys)
 | 
								sort.Strings(keys)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			for _, key := range keys {
 | 
								for _, key := range keys {
 | 
				
			||||||
				column := clause.Column{Name: key, Table: curTable}
 | 
									conds = append(conds, clause.Eq{Column: key, Value: v[key]})
 | 
				
			||||||
				if strings.Contains(key, ".") {
 | 
					 | 
				
			||||||
					column = clause.Column{Name: key}
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				conds = append(conds, clause.Eq{Column: column, Value: v[key]})
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		case map[string]interface{}:
 | 
							case map[string]interface{}:
 | 
				
			||||||
			keys := make([]string, 0, len(v))
 | 
								keys := make([]string, 0, len(v))
 | 
				
			||||||
@ -378,16 +365,12 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			for _, key := range keys {
 | 
								for _, key := range keys {
 | 
				
			||||||
				reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
 | 
									reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
 | 
				
			||||||
				column := clause.Column{Name: key, Table: curTable}
 | 
					 | 
				
			||||||
				if strings.Contains(key, ".") {
 | 
					 | 
				
			||||||
					column = clause.Column{Name: key}
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				switch reflectValue.Kind() {
 | 
									switch reflectValue.Kind() {
 | 
				
			||||||
				case reflect.Slice, reflect.Array:
 | 
									case reflect.Slice, reflect.Array:
 | 
				
			||||||
					if _, ok := v[key].(driver.Valuer); ok {
 | 
										if _, ok := v[key].(driver.Valuer); ok {
 | 
				
			||||||
						conds = append(conds, clause.Eq{Column: column, Value: v[key]})
 | 
											conds = append(conds, clause.Eq{Column: key, Value: v[key]})
 | 
				
			||||||
					} else if _, ok := v[key].(Valuer); ok {
 | 
										} else if _, ok := v[key].(Valuer); ok {
 | 
				
			||||||
						conds = append(conds, clause.Eq{Column: column, Value: v[key]})
 | 
											conds = append(conds, clause.Eq{Column: key, Value: v[key]})
 | 
				
			||||||
					} else {
 | 
										} else {
 | 
				
			||||||
						// optimize reflect value length
 | 
											// optimize reflect value length
 | 
				
			||||||
						valueLen := reflectValue.Len()
 | 
											valueLen := reflectValue.Len()
 | 
				
			||||||
@ -396,10 +379,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 | 
				
			|||||||
							values[i] = reflectValue.Index(i).Interface()
 | 
												values[i] = reflectValue.Index(i).Interface()
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
						conds = append(conds, clause.IN{Column: column, Values: values})
 | 
											conds = append(conds, clause.IN{Column: key, Values: values})
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				default:
 | 
									default:
 | 
				
			||||||
					conds = append(conds, clause.Eq{Column: column, Value: v[key]})
 | 
										conds = append(conds, clause.Eq{Column: key, Value: v[key]})
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
@ -426,9 +409,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 | 
				
			|||||||
						if selected || (!restricted && field.Readable) {
 | 
											if selected || (!restricted && field.Readable) {
 | 
				
			||||||
							if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
 | 
												if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
 | 
				
			||||||
								if field.DBName != "" {
 | 
													if field.DBName != "" {
 | 
				
			||||||
									conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
 | 
														conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
 | 
				
			||||||
								} else if field.DataType != "" {
 | 
													} else if field.DataType != "" {
 | 
				
			||||||
									conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
 | 
														conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
 | 
				
			||||||
								}
 | 
													}
 | 
				
			||||||
							}
 | 
												}
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
@ -440,9 +423,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 | 
				
			|||||||
							if selected || (!restricted && field.Readable) {
 | 
												if selected || (!restricted && field.Readable) {
 | 
				
			||||||
								if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
 | 
													if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
 | 
				
			||||||
									if field.DBName != "" {
 | 
														if field.DBName != "" {
 | 
				
			||||||
										conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
 | 
															conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
 | 
				
			||||||
									} else if field.DataType != "" {
 | 
														} else if field.DataType != "" {
 | 
				
			||||||
										conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
 | 
															conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
 | 
				
			||||||
									}
 | 
														}
 | 
				
			||||||
								}
 | 
													}
 | 
				
			||||||
							}
 | 
												}
 | 
				
			||||||
@ -467,22 +450,18 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 | 
				
			|||||||
						}
 | 
											}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
						if len(values) > 0 {
 | 
											if len(values) > 0 {
 | 
				
			||||||
							conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values})
 | 
												conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
 | 
				
			||||||
							return []clause.Expression{clause.And(conds...)}
 | 
					 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
						return nil
 | 
											return conds
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args})
 | 
									conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(conds) > 0 {
 | 
						return conds
 | 
				
			||||||
		return []clause.Expression{clause.And(conds...)}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Build build sql with clauses names
 | 
					// Build build sql with clauses names
 | 
				
			||||||
@ -534,14 +513,12 @@ func (stmt *Statement) clone() *Statement {
 | 
				
			|||||||
		Distinct:             stmt.Distinct,
 | 
							Distinct:             stmt.Distinct,
 | 
				
			||||||
		Selects:              stmt.Selects,
 | 
							Selects:              stmt.Selects,
 | 
				
			||||||
		Omits:                stmt.Omits,
 | 
							Omits:                stmt.Omits,
 | 
				
			||||||
		ColumnMapping:        stmt.ColumnMapping,
 | 
					 | 
				
			||||||
		Preloads:             map[string][]interface{}{},
 | 
							Preloads:             map[string][]interface{}{},
 | 
				
			||||||
		ConnPool:             stmt.ConnPool,
 | 
							ConnPool:             stmt.ConnPool,
 | 
				
			||||||
		Schema:               stmt.Schema,
 | 
							Schema:               stmt.Schema,
 | 
				
			||||||
		Context:              stmt.Context,
 | 
							Context:              stmt.Context,
 | 
				
			||||||
		RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
 | 
							RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
 | 
				
			||||||
		SkipHooks:            stmt.SkipHooks,
 | 
							SkipHooks:            stmt.SkipHooks,
 | 
				
			||||||
		Result:               stmt.Result,
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if stmt.SQL.Len() > 0 {
 | 
						if stmt.SQL.Len() > 0 {
 | 
				
			||||||
@ -688,21 +665,7 @@ func (stmt *Statement) Changed(fields ...string) bool {
 | 
				
			|||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var matchName = func() func(tableColumn string) (table, column string) {
 | 
					var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`)
 | 
				
			||||||
	nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
 | 
					 | 
				
			||||||
	return func(tableColumn string) (table, column string) {
 | 
					 | 
				
			||||||
		if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 {
 | 
					 | 
				
			||||||
			table = matches[1]
 | 
					 | 
				
			||||||
			star := matches[2]
 | 
					 | 
				
			||||||
			columnName := matches[3]
 | 
					 | 
				
			||||||
			if star != "" {
 | 
					 | 
				
			||||||
				return table, star
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			return table, columnName
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "", ""
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
 | 
					// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
 | 
				
			||||||
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
 | 
					func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
 | 
				
			||||||
@ -723,13 +686,13 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
 | 
							} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
 | 
				
			||||||
			results[field.DBName] = result
 | 
								results[field.DBName] = result
 | 
				
			||||||
		} else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
 | 
							} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") {
 | 
				
			||||||
			if col == "*" {
 | 
								if matches[2] == "*" {
 | 
				
			||||||
				for _, dbName := range stmt.Schema.DBNames {
 | 
									for _, dbName := range stmt.Schema.DBNames {
 | 
				
			||||||
					results[dbName] = result
 | 
										results[dbName] = result
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				results[col] = result
 | 
									results[matches[2]] = result
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			results[column] = result
 | 
								results[column] = result
 | 
				
			||||||
 | 
				
			|||||||
@ -56,15 +56,9 @@ func TestNameMatcher(t *testing.T) {
 | 
				
			|||||||
		"`name_1`":           {"", "name_1"},
 | 
							"`name_1`":           {"", "name_1"},
 | 
				
			||||||
		"`Name_1`":           {"", "Name_1"},
 | 
							"`Name_1`":           {"", "Name_1"},
 | 
				
			||||||
		"`Table`.`nAme`":     {"Table", "nAme"},
 | 
							"`Table`.`nAme`":     {"Table", "nAme"},
 | 
				
			||||||
		"my_table.*":         {"my_table", "*"},
 | 
					 | 
				
			||||||
		"`my_table`.*":       {"my_table", "*"},
 | 
					 | 
				
			||||||
		"User__Company.*":    {"User__Company", "*"},
 | 
					 | 
				
			||||||
		"`User__Company`.*":  {"User__Company", "*"},
 | 
					 | 
				
			||||||
		`"User__Company".*`:  {"User__Company", "*"},
 | 
					 | 
				
			||||||
		`"table"."*"`:        {"", ""},
 | 
					 | 
				
			||||||
	} {
 | 
						} {
 | 
				
			||||||
		if table, column := matchName(k); table != v[0] || column != v[1] {
 | 
							if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 3 || matches[1] != v[0] || matches[2] != v[1] {
 | 
				
			||||||
			t.Errorf("failed to match value: %v, got %v, expect: %v", k, []string{table, column}, v)
 | 
								t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -278,6 +278,8 @@ func TestBelongsToAssociationUnscoped(t *testing.T) {
 | 
				
			|||||||
		t.Fatalf("failed to create items, got error: %v", err)
 | 
							t.Fatalf("failed to create items, got error: %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tx = tx.Debug()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// test replace
 | 
						// test replace
 | 
				
			||||||
	if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{
 | 
						if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{
 | 
				
			||||||
		Logo: "updated logo",
 | 
							Logo: "updated logo",
 | 
				
			||||||
 | 
				
			|||||||
@ -422,7 +422,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) {
 | 
				
			|||||||
func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
 | 
					func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
 | 
				
			||||||
	users := []User{
 | 
						users := []User{
 | 
				
			||||||
		*GetUser("slice-hasmany-1", Config{Toys: 2}),
 | 
							*GetUser("slice-hasmany-1", Config{Toys: 2}),
 | 
				
			||||||
		*GetUser("slice-hasmany-2", Config{Toys: 0, Tools: 2}),
 | 
							*GetUser("slice-hasmany-2", Config{Toys: 0}),
 | 
				
			||||||
		*GetUser("slice-hasmany-3", Config{Toys: 4}),
 | 
							*GetUser("slice-hasmany-3", Config{Toys: 4}),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -430,7 +430,6 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Count
 | 
						// Count
 | 
				
			||||||
	AssertAssociationCount(t, users, "Toys", 6, "")
 | 
						AssertAssociationCount(t, users, "Toys", 6, "")
 | 
				
			||||||
	AssertAssociationCount(t, users, "Tools", 2, "")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Find
 | 
						// Find
 | 
				
			||||||
	var toys []Toy
 | 
						var toys []Toy
 | 
				
			||||||
@ -438,14 +437,6 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("toys count should be %v, but got %v", 6, len(toys))
 | 
							t.Errorf("toys count should be %v, but got %v", 6, len(toys))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Find Tools (polymorphic with custom type and id)
 | 
					 | 
				
			||||||
	var tools []Tools
 | 
					 | 
				
			||||||
	DB.Model(&users).Association("Tools").Find(&tools)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(tools) != 2 {
 | 
					 | 
				
			||||||
		t.Errorf("tools count should be %v, but got %v", 2, len(tools))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Append
 | 
						// Append
 | 
				
			||||||
	DB.Model(&users).Association("Toys").Append(
 | 
						DB.Model(&users).Association("Toys").Append(
 | 
				
			||||||
		&Toy{Name: "toy-slice-append-1"},
 | 
							&Toy{Name: "toy-slice-append-1"},
 | 
				
			||||||
@ -554,15 +545,3 @@ func TestHasManyAssociationUnscoped(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("expected %d contents, got %d", 0, len(contents))
 | 
							t.Errorf("expected %d contents, got %d", 0, len(contents))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestHasManyAssociationReplaceWithNonValidValue(t *testing.T) {
 | 
					 | 
				
			||||||
	user := User{Name: "jinzhu", Languages: []Language{{Name: "EN"}}}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.Create(&user).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("errors happened when create: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.Model(&user).Association("Languages").Replace(Language{Name: "DE"}, Language{Name: "FR"}); err == nil {
 | 
					 | 
				
			||||||
		t.Error("expected association error to be not nil")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -255,15 +255,3 @@ func TestPolymorphicHasOneAssociationForSlice(t *testing.T) {
 | 
				
			|||||||
	DB.Model(&pets).Association("Toy").Clear()
 | 
						DB.Model(&pets).Association("Toy").Clear()
 | 
				
			||||||
	AssertAssociationCount(t, pets, "Toy", 0, "After Clear")
 | 
						AssertAssociationCount(t, pets, "Toy", 0, "After Clear")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestHasOneAssociationReplaceWithNonValidValue(t *testing.T) {
 | 
					 | 
				
			||||||
	user := User{Name: "jinzhu", Account: Account{Number: "1"}}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.Create(&user).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("errors happened when create: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.Model(&user).Association("Languages").Replace(Account{Number: "2"}); err == nil {
 | 
					 | 
				
			||||||
		t.Error("expected association error to be not nil")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -91,7 +91,7 @@ func TestCallbacks(t *testing.T) {
 | 
				
			|||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
 | 
								callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
 | 
				
			||||||
			results:   []string{"c1", "c3", "c4", "c5"},
 | 
								results:   []string{"c1", "c5", "c3", "c4"},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
 | 
								callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
 | 
				
			||||||
@ -206,49 +206,3 @@ func TestPluginCallbacks(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("callbacks tests failed, got %v", msg)
 | 
							t.Errorf("callbacks tests failed, got %v", msg)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestCallbacksGet(t *testing.T) {
 | 
					 | 
				
			||||||
	db, _ := gorm.Open(nil, nil)
 | 
					 | 
				
			||||||
	createCallback := db.Callback().Create()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	createCallback.Before("*").Register("c1", c1)
 | 
					 | 
				
			||||||
	if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) {
 | 
					 | 
				
			||||||
		t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	createCallback.Remove("c1")
 | 
					 | 
				
			||||||
	if cb := createCallback.Get("c2"); cb != nil {
 | 
					 | 
				
			||||||
		t.Errorf("callbacks test failed. got: %p, want: nil", cb)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestCallbacksRemove(t *testing.T) {
 | 
					 | 
				
			||||||
	db, _ := gorm.Open(nil, nil)
 | 
					 | 
				
			||||||
	createCallback := db.Callback().Create()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	createCallback.Before("*").Register("c1", c1)
 | 
					 | 
				
			||||||
	createCallback.After("*").Register("c2", c2)
 | 
					 | 
				
			||||||
	createCallback.Before("c4").Register("c3", c3)
 | 
					 | 
				
			||||||
	createCallback.After("c2").Register("c4", c4)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// callbacks: []string{"c1", "c3", "c4", "c2"}
 | 
					 | 
				
			||||||
	createCallback.Remove("c1")
 | 
					 | 
				
			||||||
	if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok {
 | 
					 | 
				
			||||||
		t.Errorf("callbacks tests failed, got %v", msg)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	createCallback.Remove("c4")
 | 
					 | 
				
			||||||
	if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok {
 | 
					 | 
				
			||||||
		t.Errorf("callbacks tests failed, got %v", msg)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	createCallback.Remove("c2")
 | 
					 | 
				
			||||||
	if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok {
 | 
					 | 
				
			||||||
		t.Errorf("callbacks tests failed, got %v", msg)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	createCallback.Remove("c3")
 | 
					 | 
				
			||||||
	if ok, msg := assertCallbacks(createCallback, []string{}); !ok {
 | 
					 | 
				
			||||||
		t.Errorf("callbacks tests failed, got %v", msg)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -1,88 +0,0 @@
 | 
				
			|||||||
package tests_test
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
	"testing"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type Man struct {
 | 
					 | 
				
			||||||
	ID     int
 | 
					 | 
				
			||||||
	Age    int
 | 
					 | 
				
			||||||
	Name   string
 | 
					 | 
				
			||||||
	Detail string
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Panic-safe BeforeUpdate hook that checks for Changed("age")
 | 
					 | 
				
			||||||
func (m *Man) BeforeUpdate(tx *gorm.DB) (err error) {
 | 
					 | 
				
			||||||
	defer func() {
 | 
					 | 
				
			||||||
		if r := recover(); r != nil {
 | 
					 | 
				
			||||||
			err = fmt.Errorf("panic in BeforeUpdate: %v", r)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !tx.Statement.Changed("age") {
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (m *Man) update(data interface{}) error {
 | 
					 | 
				
			||||||
	return DB.Set("data", data).Model(m).Where("id = ?", m.ID).Updates(data).Error
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestBeforeUpdateStatementChanged(t *testing.T) {
 | 
					 | 
				
			||||||
	DB.AutoMigrate(&Man{})
 | 
					 | 
				
			||||||
	type TestCase struct {
 | 
					 | 
				
			||||||
		BaseObjects Man
 | 
					 | 
				
			||||||
		change      interface{}
 | 
					 | 
				
			||||||
		expectError bool
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	testCases := []TestCase{
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			BaseObjects: Man{ID: 1, Age: 18, Name: "random-name"},
 | 
					 | 
				
			||||||
			change: struct {
 | 
					 | 
				
			||||||
				Age int
 | 
					 | 
				
			||||||
			}{Age: 20},
 | 
					 | 
				
			||||||
			expectError: false,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"},
 | 
					 | 
				
			||||||
			change: struct {
 | 
					 | 
				
			||||||
				Name string
 | 
					 | 
				
			||||||
			}{Name: "name-only"},
 | 
					 | 
				
			||||||
			expectError: true,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"},
 | 
					 | 
				
			||||||
			change: struct {
 | 
					 | 
				
			||||||
				Name string
 | 
					 | 
				
			||||||
				Age int
 | 
					 | 
				
			||||||
			}{Name: "name-only", Age: 20},
 | 
					 | 
				
			||||||
			expectError: false,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, test := range testCases {
 | 
					 | 
				
			||||||
		DB.Create(&test.BaseObjects)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		// below comment is stored for future reference
 | 
					 | 
				
			||||||
		// err := DB.Set("data", test.change).Model(&test.BaseObjects).Where("id = ?", test.BaseObjects.ID).Updates(test.change).Error
 | 
					 | 
				
			||||||
		err := test.BaseObjects.update(test.change)
 | 
					 | 
				
			||||||
		if strings.Contains(fmt.Sprint(err), "panic in BeforeUpdate") {
 | 
					 | 
				
			||||||
			if !test.expectError {
 | 
					 | 
				
			||||||
				t.Errorf("unexpected panic in BeforeUpdate for input: %+v\nerror: %v", test.change, err)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			if test.expectError {
 | 
					 | 
				
			||||||
				t.Errorf("expected panic did not occur for input: %+v", test.change)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				t.Errorf("unexpected GORM error: %v", err)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@ -102,13 +102,13 @@ func TestConnPoolWrapper(t *testing.T) {
 | 
				
			|||||||
		expect: []string{
 | 
							expect: []string{
 | 
				
			||||||
			"SELECT VERSION()",
 | 
								"SELECT VERSION()",
 | 
				
			||||||
			"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
								"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
				
			||||||
			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
 | 
								"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
				
			||||||
			"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
								"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
				
			||||||
			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
 | 
								"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
				
			||||||
			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
 | 
								"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
				
			||||||
			"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
								"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
				
			||||||
			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
 | 
								"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
				
			||||||
			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
 | 
								"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -119,7 +119,6 @@ func TestConnPoolWrapper(t *testing.T) {
 | 
				
			|||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true}))
 | 
						db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true}))
 | 
				
			||||||
	db.Logger = DB.Logger
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatalf("Should open db success, but got %v", err)
 | 
							t.Fatalf("Should open db success, but got %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -29,7 +29,7 @@ func TestCountWithGroup(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var count2 int64
 | 
						var count2 int64
 | 
				
			||||||
	if err := DB.Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil {
 | 
						if err := DB.Debug().Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil {
 | 
				
			||||||
		t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
 | 
							t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if count2 != 2 {
 | 
						if count2 != 2 {
 | 
				
			||||||
 | 
				
			|||||||
@ -2,7 +2,6 @@ package tests_test
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
@ -14,48 +13,31 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestCreate(t *testing.T) {
 | 
					func TestCreate(t *testing.T) {
 | 
				
			||||||
	u1 := *GetUser("create", Config{})
 | 
						user := *GetUser("create", Config{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if results := DB.Create(&u1); results.Error != nil {
 | 
						if results := DB.Create(&user); results.Error != nil {
 | 
				
			||||||
		t.Fatalf("errors happened when create: %v", results.Error)
 | 
							t.Fatalf("errors happened when create: %v", results.Error)
 | 
				
			||||||
	} else if results.RowsAffected != 1 {
 | 
						} else if results.RowsAffected != 1 {
 | 
				
			||||||
		t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
 | 
							t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if u1.ID == 0 {
 | 
						if user.ID == 0 {
 | 
				
			||||||
		t.Errorf("user's primary key should has value after create, got : %v", u1.ID)
 | 
							t.Errorf("user's primary key should has value after create, got : %v", user.ID)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if u1.CreatedAt.IsZero() {
 | 
						if user.CreatedAt.IsZero() {
 | 
				
			||||||
		t.Errorf("user's created at should be not zero")
 | 
							t.Errorf("user's created at should be not zero")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if u1.UpdatedAt.IsZero() {
 | 
						if user.UpdatedAt.IsZero() {
 | 
				
			||||||
		t.Errorf("user's updated at should be not zero")
 | 
							t.Errorf("user's updated at should be not zero")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var newUser User
 | 
						var newUser User
 | 
				
			||||||
	if err := DB.Where("id = ?", u1.ID).First(&newUser).Error; err != nil {
 | 
						if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
 | 
				
			||||||
		t.Fatalf("errors happened when query: %v", err)
 | 
							t.Fatalf("errors happened when query: %v", err)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		CheckUser(t, newUser, u1)
 | 
							CheckUser(t, newUser, user)
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type user struct {
 | 
					 | 
				
			||||||
		ID   int `gorm:"primaryKey;->:false"`
 | 
					 | 
				
			||||||
		Name string
 | 
					 | 
				
			||||||
		Age  int
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var u2 user
 | 
					 | 
				
			||||||
	if results := DB.Create(&u2); results.Error != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("errors happened when create: %v", results.Error)
 | 
					 | 
				
			||||||
	} else if results.RowsAffected != 1 {
 | 
					 | 
				
			||||||
		t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if u2.ID != 0 {
 | 
					 | 
				
			||||||
		t.Errorf("don't have the permission to read primary key from db, but got %v", u2.ID)
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -598,213 +580,38 @@ func TestCreateWithAutoIncrementCompositeKey(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestCreateOnConflictWithDefaultNull(t *testing.T) {
 | 
					func TestCreateOnConfilctWithDefalutNull(t *testing.T) {
 | 
				
			||||||
	type OnConflictUser struct {
 | 
						type OnConfilctUser struct {
 | 
				
			||||||
		ID     string
 | 
							ID     string
 | 
				
			||||||
		Name   string `gorm:"default:null"`
 | 
							Name   string `gorm:"default:null"`
 | 
				
			||||||
		Email  string
 | 
							Email  string
 | 
				
			||||||
		Mobile string `gorm:"default:'133xxxx'"`
 | 
							Mobile string `gorm:"default:'133xxxx'"`
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := DB.Migrator().DropTable(&OnConflictUser{})
 | 
						err := DB.Migrator().DropTable(&OnConfilctUser{})
 | 
				
			||||||
	AssertEqual(t, err, nil)
 | 
						AssertEqual(t, err, nil)
 | 
				
			||||||
	err = DB.AutoMigrate(&OnConflictUser{})
 | 
						err = DB.AutoMigrate(&OnConfilctUser{})
 | 
				
			||||||
	AssertEqual(t, err, nil)
 | 
						AssertEqual(t, err, nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	u := OnConflictUser{
 | 
						u := OnConfilctUser{
 | 
				
			||||||
		ID:     "on-conflict-user-id",
 | 
							ID:     "on-confilct-user-id",
 | 
				
			||||||
		Name:   "on-conflict-user-name",
 | 
							Name:   "on-confilct-user-name",
 | 
				
			||||||
		Email:  "on-conflict-user-email",
 | 
							Email:  "on-confilct-user-email",
 | 
				
			||||||
		Mobile: "on-conflict-user-mobile",
 | 
							Mobile: "on-confilct-user-mobile",
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = DB.Create(&u).Error
 | 
						err = DB.Create(&u).Error
 | 
				
			||||||
	AssertEqual(t, err, nil)
 | 
						AssertEqual(t, err, nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	u.Name = "on-conflict-user-name-2"
 | 
						u.Name = "on-confilct-user-name-2"
 | 
				
			||||||
	u.Email = "on-conflict-user-email-2"
 | 
						u.Email = "on-confilct-user-email-2"
 | 
				
			||||||
	u.Mobile = ""
 | 
						u.Mobile = ""
 | 
				
			||||||
	err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&u).Error
 | 
						err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&u).Error
 | 
				
			||||||
	AssertEqual(t, err, nil)
 | 
						AssertEqual(t, err, nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var u2 OnConflictUser
 | 
						var u2 OnConfilctUser
 | 
				
			||||||
	err = DB.Where("id = ?", u.ID).First(&u2).Error
 | 
						err = DB.Where("id = ?", u.ID).First(&u2).Error
 | 
				
			||||||
	AssertEqual(t, err, nil)
 | 
						AssertEqual(t, err, nil)
 | 
				
			||||||
	AssertEqual(t, u2.Name, "on-conflict-user-name-2")
 | 
						AssertEqual(t, u2.Name, "on-confilct-user-name-2")
 | 
				
			||||||
	AssertEqual(t, u2.Email, "on-conflict-user-email-2")
 | 
						AssertEqual(t, u2.Email, "on-confilct-user-email-2")
 | 
				
			||||||
	AssertEqual(t, u2.Mobile, "133xxxx")
 | 
						AssertEqual(t, u2.Mobile, "133xxxx")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestCreateFromMapWithoutPK(t *testing.T) {
 | 
					 | 
				
			||||||
	if !isMysql() {
 | 
					 | 
				
			||||||
		t.Skipf("This test case skipped, because of only supporting for mysql")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// case 1: one record, create from map[string]interface{}
 | 
					 | 
				
			||||||
	mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1}
 | 
					 | 
				
			||||||
	if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to create data from map, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if _, ok := mapValue1["id"]; !ok {
 | 
					 | 
				
			||||||
		t.Fatal("failed to create data from map with table, returning map has no primary key")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var result1 User
 | 
					 | 
				
			||||||
	if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to create from map, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var idVal int64
 | 
					 | 
				
			||||||
	_, ok := mapValue1["id"].(uint)
 | 
					 | 
				
			||||||
	if ok {
 | 
					 | 
				
			||||||
		t.Skipf("This test case skipped, because the db supports returning")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	idVal, ok = mapValue1["id"].(int64)
 | 
					 | 
				
			||||||
	if !ok {
 | 
					 | 
				
			||||||
		t.Fatal("ret result missing id")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if int64(result1.ID) != idVal {
 | 
					 | 
				
			||||||
		t.Fatal("failed to create data from map with table, @id != id")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// case2: one record, create from *map[string]interface{}
 | 
					 | 
				
			||||||
	mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1}
 | 
					 | 
				
			||||||
	if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to create data from map, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if _, ok := mapValue2["id"]; !ok {
 | 
					 | 
				
			||||||
		t.Fatal("failed to create data from map with table, returning map has no primary key")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var result2 User
 | 
					 | 
				
			||||||
	if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to create from map, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	_, ok = mapValue2["id"].(uint)
 | 
					 | 
				
			||||||
	if ok {
 | 
					 | 
				
			||||||
		t.Skipf("This test case skipped, because the db supports returning")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	idVal, ok = mapValue2["id"].(int64)
 | 
					 | 
				
			||||||
	if !ok {
 | 
					 | 
				
			||||||
		t.Fatal("ret result missing id")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if int64(result2.ID) != idVal {
 | 
					 | 
				
			||||||
		t.Fatal("failed to create data from map with table, @id != id")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// case 3: records
 | 
					 | 
				
			||||||
	values := []map[string]interface{}{
 | 
					 | 
				
			||||||
		{"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	beforeLen := len(values)
 | 
					 | 
				
			||||||
	if err := DB.Model(&User{}).Create(&values).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to create data from map, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// mariadb with returning, values will be appended with id map
 | 
					 | 
				
			||||||
	if len(values) == beforeLen*2 {
 | 
					 | 
				
			||||||
		t.Skipf("This test case skipped, because the db supports returning")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for i := range values {
 | 
					 | 
				
			||||||
		v, ok := values[i]["id"]
 | 
					 | 
				
			||||||
		if !ok {
 | 
					 | 
				
			||||||
			t.Fatal("failed to create data from map with table, returning map has no primary key")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		var result User
 | 
					 | 
				
			||||||
		if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 {
 | 
					 | 
				
			||||||
			t.Fatalf("failed to create from map, got error %v", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if int64(result.ID) != v.(int64) {
 | 
					 | 
				
			||||||
			t.Fatal("failed to create data from map with table, @id != id")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestCreateFromMapWithTable(t *testing.T) {
 | 
					 | 
				
			||||||
	tableDB := DB.Table("users")
 | 
					 | 
				
			||||||
	supportLastInsertID := isMysql() || isSqlite()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// case 1: create from map[string]interface{}
 | 
					 | 
				
			||||||
	record := map[string]interface{}{"name": "create_from_map_with_table", "age": 18}
 | 
					 | 
				
			||||||
	if err := tableDB.Create(record).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to create data from map with table, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if _, ok := record["@id"]; !ok && supportLastInsertID {
 | 
					 | 
				
			||||||
		t.Fatal("failed to create data from map with table, returning map has no key '@id'")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var res map[string]interface{}
 | 
					 | 
				
			||||||
	if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to create from map, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if _, ok := record["@id"]; ok && fmt.Sprint(res["id"]) != fmt.Sprint(record["@id"]) {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to create data from map with table, @id != id, got %v, expect %v", res["id"], record["@id"])
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// case 2: create from *map[string]interface{}
 | 
					 | 
				
			||||||
	record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18}
 | 
					 | 
				
			||||||
	tableDB2 := DB.Table("users")
 | 
					 | 
				
			||||||
	if err := tableDB2.Create(&record1).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to create data from map, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if _, ok := record1["@id"]; !ok && supportLastInsertID {
 | 
					 | 
				
			||||||
		t.Fatal("failed to create data from map with table, returning map has no key '@id'")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var res1 map[string]interface{}
 | 
					 | 
				
			||||||
	if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to create from map, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if _, ok := record1["@id"]; ok && fmt.Sprint(res1["id"]) != fmt.Sprint(record1["@id"]) {
 | 
					 | 
				
			||||||
		t.Fatal("failed to create data from map with table, @id != id")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// case 3: create from []map[string]interface{}
 | 
					 | 
				
			||||||
	records := []map[string]interface{}{
 | 
					 | 
				
			||||||
		{"name": "create_from_map_with_table_2", "age": 19},
 | 
					 | 
				
			||||||
		{"name": "create_from_map_with_table_3", "age": 20},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	tableDB = DB.Table("users")
 | 
					 | 
				
			||||||
	if err := tableDB.Create(&records).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to create data from slice of map, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if _, ok := records[0]["@id"]; !ok && supportLastInsertID {
 | 
					 | 
				
			||||||
		t.Fatal("failed to create data from map with table, returning map has no key '@id'")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if _, ok := records[1]["@id"]; !ok && supportLastInsertID {
 | 
					 | 
				
			||||||
		t.Fatal("failed to create data from map with table, returning map has no key '@id'")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var res2 map[string]interface{}
 | 
					 | 
				
			||||||
	if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to query data after create from slice of map, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var res3 map[string]interface{}
 | 
					 | 
				
			||||||
	if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to query data after create from slice of map, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if _, ok := records[0]["@id"]; ok && fmt.Sprint(res2["id"]) != fmt.Sprint(records[0]["@id"]) {
 | 
					 | 
				
			||||||
		t.Errorf("failed to create data from map with table, @id != id, got %v, expect %v", res2["id"], records[0]["@id"])
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if _, ok := records[1]["id"]; ok && fmt.Sprint(res3["id"]) != fmt.Sprint(records[1]["@id"]) {
 | 
					 | 
				
			||||||
		t.Errorf("failed to create data from map with table, @id != id")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -38,22 +38,4 @@ func TestDefaultValue(t *testing.T) {
 | 
				
			|||||||
	} else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" {
 | 
						} else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" {
 | 
				
			||||||
		t.Fatalf("Failed to find created data with default data, got %+v", result)
 | 
							t.Fatalf("Failed to find created data with default data, got %+v", result)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					 | 
				
			||||||
	type Harumph2 struct {
 | 
					 | 
				
			||||||
		ID      int       `gorm:"default:0"`
 | 
					 | 
				
			||||||
		Email   string    `gorm:"not null;index:,unique"`
 | 
					 | 
				
			||||||
		Name    string    `gorm:"notNull;default:foo"`
 | 
					 | 
				
			||||||
		Name2   string    `gorm:"size:233;not null;default:'foo'"`
 | 
					 | 
				
			||||||
		Name3   string    `gorm:"size:233;notNull;default:''"`
 | 
					 | 
				
			||||||
		Age     int       `gorm:"default:18"`
 | 
					 | 
				
			||||||
		Created time.Time `gorm:"default:2000-01-02"`
 | 
					 | 
				
			||||||
		Enabled bool      `gorm:"default:true"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	harumph2 := Harumph2{ID: 2, Email: "hello2@gorm.io"}
 | 
					 | 
				
			||||||
	if err := DB.Table("harumphs").Create(&harumph2).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Failed to create data with default value, got error: %v", err)
 | 
					 | 
				
			||||||
	} else if harumph2.ID != 2 || harumph2.Name != "foo" || harumph2.Name2 != "foo" || harumph2.Name3 != "" || harumph2.Age != 18 || !harumph2.Enabled || harumph2.Created.Format("20060102") != "20000102" {
 | 
					 | 
				
			||||||
		t.Fatalf("Failed to create data with default value, got: %+v", harumph2)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -206,9 +206,9 @@ func TestDeleteSliceWithAssociations(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// only sqlite, postgres, gaussdb, sqlserver support returning
 | 
					// only sqlite, postgres support returning
 | 
				
			||||||
func TestSoftDeleteReturning(t *testing.T) {
 | 
					func TestSoftDeleteReturning(t *testing.T) {
 | 
				
			||||||
	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlserver" {
 | 
						if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -233,7 +233,7 @@ func TestSoftDeleteReturning(t *testing.T) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestDeleteReturning(t *testing.T) {
 | 
					func TestDeleteReturning(t *testing.T) {
 | 
				
			||||||
	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlserver" {
 | 
						if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -1,8 +1,10 @@
 | 
				
			|||||||
 | 
					version: '3'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
services:
 | 
					services:
 | 
				
			||||||
  mysql:
 | 
					  mysql:
 | 
				
			||||||
    image: 'mysql:latest'
 | 
					    image: 'mysql/mysql-server:latest'
 | 
				
			||||||
    ports:
 | 
					    ports:
 | 
				
			||||||
      - "127.0.0.1:9910:3306"
 | 
					      - "9910:3306"
 | 
				
			||||||
    environment:
 | 
					    environment:
 | 
				
			||||||
      - MYSQL_DATABASE=gorm
 | 
					      - MYSQL_DATABASE=gorm
 | 
				
			||||||
      - MYSQL_USER=gorm
 | 
					      - MYSQL_USER=gorm
 | 
				
			||||||
@ -11,22 +13,24 @@ services:
 | 
				
			|||||||
  postgres:
 | 
					  postgres:
 | 
				
			||||||
    image: 'postgres:latest'
 | 
					    image: 'postgres:latest'
 | 
				
			||||||
    ports:
 | 
					    ports:
 | 
				
			||||||
      - "127.0.0.1:9920:5432"
 | 
					      - "9920:5432"
 | 
				
			||||||
    environment:
 | 
					    environment:
 | 
				
			||||||
      - TZ=Asia/Shanghai
 | 
					      - TZ=Asia/Shanghai
 | 
				
			||||||
      - POSTGRES_DB=gorm
 | 
					      - POSTGRES_DB=gorm
 | 
				
			||||||
      - POSTGRES_USER=gorm
 | 
					      - POSTGRES_USER=gorm
 | 
				
			||||||
      - POSTGRES_PASSWORD=gorm
 | 
					      - POSTGRES_PASSWORD=gorm
 | 
				
			||||||
  mssql:
 | 
					  mssql:
 | 
				
			||||||
    image: '${MSSQL_IMAGE}:latest'
 | 
					    image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest'
 | 
				
			||||||
    ports:
 | 
					    ports:
 | 
				
			||||||
      - "127.0.0.1:9930:1433"
 | 
					      - "9930:1433"
 | 
				
			||||||
    environment:
 | 
					    environment:
 | 
				
			||||||
      - TZ=Asia/Shanghai
 | 
					 | 
				
			||||||
      - ACCEPT_EULA=Y
 | 
					      - ACCEPT_EULA=Y
 | 
				
			||||||
      - MSSQL_SA_PASSWORD=LoremIpsum86
 | 
					      - SA_PASSWORD=LoremIpsum86
 | 
				
			||||||
 | 
					      - MSSQL_DB=gorm
 | 
				
			||||||
 | 
					      - MSSQL_USER=gorm
 | 
				
			||||||
 | 
					      - MSSQL_PASSWORD=LoremIpsum86
 | 
				
			||||||
  tidb:
 | 
					  tidb:
 | 
				
			||||||
    image: 'pingcap/tidb:v6.5.0'
 | 
					    image: 'pingcap/tidb:v6.5.0'
 | 
				
			||||||
    ports:
 | 
					    ports:
 | 
				
			||||||
      - "127.0.0.1:9940:4000"
 | 
					      - "9940:4000"
 | 
				
			||||||
    command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 &
 | 
					    command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 &
 | 
				
			||||||
@ -236,15 +236,8 @@ func TestEmbeddedScanValuer(t *testing.T) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestEmbeddedRelations(t *testing.T) {
 | 
					func TestEmbeddedRelations(t *testing.T) {
 | 
				
			||||||
	type EmbUser struct {
 | 
					 | 
				
			||||||
		gorm.Model
 | 
					 | 
				
			||||||
		Name      string
 | 
					 | 
				
			||||||
		Age       uint
 | 
					 | 
				
			||||||
		Languages []Language `gorm:"many2many:EmbUserSpeak;"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type AdvancedUser struct {
 | 
						type AdvancedUser struct {
 | 
				
			||||||
		EmbUser  `gorm:"embedded"`
 | 
							User     `gorm:"embedded"`
 | 
				
			||||||
		Advanced bool
 | 
							Advanced bool
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -279,6 +272,6 @@ func TestEmbeddedTagSetting(t *testing.T) {
 | 
				
			|||||||
	err = DB.Save(&t1).Error
 | 
						err = DB.Save(&t1).Error
 | 
				
			||||||
	AssertEqual(t, err, nil)
 | 
						AssertEqual(t, err, nil)
 | 
				
			||||||
	if t1.Tag1.Id == 0 {
 | 
						if t1.Tag1.Id == 0 {
 | 
				
			||||||
		t.Errorf("embedded struct's primary field should be rewritten")
 | 
							t.Errorf("embedded struct's primary field should be rewrited")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -39,7 +39,7 @@ func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) {
 | 
				
			|||||||
		t.Fatalf("failed to connect database, got error %v", err)
 | 
							t.Fatalf("failed to connect database, got error %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	dialectors := map[string]bool{"sqlite": true, "postgres": true, "gaussdb": true, "mysql": true, "sqlserver": true}
 | 
						dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
 | 
				
			||||||
	if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
 | 
						if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -81,7 +81,7 @@ func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) {
 | 
				
			|||||||
		t.Fatalf("failed to connect database, got error %v", err)
 | 
							t.Fatalf("failed to connect database, got error %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	dialectors := map[string]bool{"sqlite": true, "postgres": true, "gaussdb": true, "mysql": true, "sqlserver": true}
 | 
						dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
 | 
				
			||||||
	if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
 | 
						if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -1,248 +0,0 @@
 | 
				
			|||||||
package tests_test
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"testing"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/google/uuid"
 | 
					 | 
				
			||||||
	"github.com/lib/pq"
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
					 | 
				
			||||||
	"gorm.io/gorm/clause"
 | 
					 | 
				
			||||||
	. "gorm.io/gorm/utils/tests"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGaussDBReturningIDWhichHasStringType(t *testing.T) {
 | 
					 | 
				
			||||||
	t.Skipf("This test case skipped, because of gaussdb not support pgcrypto extension and gen_random_uuid() function")
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() != "gaussdb" {
 | 
					 | 
				
			||||||
		t.Skip()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type Yasuo struct {
 | 
					 | 
				
			||||||
		// TODO: function gen_random_uuid() does not exist
 | 
					 | 
				
			||||||
		ID        string `gorm:"default:gen_random_uuid()"`
 | 
					 | 
				
			||||||
		Name      string
 | 
					 | 
				
			||||||
		CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
 | 
					 | 
				
			||||||
		UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("Failed to create extension pgcrypto, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.Migrator().DropTable(&Yasuo{})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.AutoMigrate(&Yasuo{}); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Failed to migrate for uuid default value, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	yasuo := Yasuo{Name: "jinzhu"}
 | 
					 | 
				
			||||||
	if err := DB.Create(&yasuo).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("should be able to create data, but got %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if yasuo.ID == "" {
 | 
					 | 
				
			||||||
		t.Fatal("should be able to has ID, but got zero value")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var result Yasuo
 | 
					 | 
				
			||||||
	if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" {
 | 
					 | 
				
			||||||
		t.Errorf("No error should happen, but got %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" {
 | 
					 | 
				
			||||||
		t.Errorf("No error should happen, but got %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	yasuo.Name = "jinzhu1"
 | 
					 | 
				
			||||||
	if err := DB.Save(&yasuo).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("Failed to update date, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" {
 | 
					 | 
				
			||||||
		t.Errorf("No error should happen, but got %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGaussDB(t *testing.T) {
 | 
					 | 
				
			||||||
	t.Skipf("This test case skipped, because of gaussdb not support pgcrypto extension and gen_random_uuid() function")
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() != "gaussdb" {
 | 
					 | 
				
			||||||
		t.Skip()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type Harumph struct {
 | 
					 | 
				
			||||||
		gorm.Model
 | 
					 | 
				
			||||||
		Name string `gorm:"check:name_checker,name <> ''"`
 | 
					 | 
				
			||||||
		// TODO: function gen_random_uuid() does not exist
 | 
					 | 
				
			||||||
		Test      uuid.UUID      `gorm:"type:uuid;not null;default:gen_random_uuid()"`
 | 
					 | 
				
			||||||
		CreatedAt time.Time      `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
 | 
					 | 
				
			||||||
		UpdatedAt time.Time      `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
 | 
					 | 
				
			||||||
		Things    pq.StringArray `gorm:"type:text[]"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("Failed to create extension pgcrypto, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.Migrator().DropTable(&Harumph{})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.AutoMigrate(&Harumph{}); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Failed to migrate for uuid default value, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	harumph := Harumph{}
 | 
					 | 
				
			||||||
	if err := DB.Create(&harumph).Error; err == nil {
 | 
					 | 
				
			||||||
		t.Fatalf("should failed to create data, name can't be blank")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	harumph = Harumph{Name: "jinzhu"}
 | 
					 | 
				
			||||||
	if err := DB.Create(&harumph).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("should be able to create data, but got %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var result Harumph
 | 
					 | 
				
			||||||
	if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu" {
 | 
					 | 
				
			||||||
		t.Errorf("No error should happen, but got %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" {
 | 
					 | 
				
			||||||
		t.Errorf("No error should happen, but got %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	harumph.Name = "jinzhu1"
 | 
					 | 
				
			||||||
	if err := DB.Save(&harumph).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("Failed to update date, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
 | 
					 | 
				
			||||||
		t.Errorf("No error should happen, but got %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.Migrator().DropTable("log_usage")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.Exec(`
 | 
					 | 
				
			||||||
CREATE TABLE public.log_usage (
 | 
					 | 
				
			||||||
    log_id bigint NOT NULL
 | 
					 | 
				
			||||||
);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY (
 | 
					 | 
				
			||||||
    SEQUENCE NAME public.log_usage_log_id_seq
 | 
					 | 
				
			||||||
    START WITH 1
 | 
					 | 
				
			||||||
    INCREMENT BY 1
 | 
					 | 
				
			||||||
    NO MINVALUE
 | 
					 | 
				
			||||||
    NO MAXVALUE
 | 
					 | 
				
			||||||
    CACHE 1
 | 
					 | 
				
			||||||
);
 | 
					 | 
				
			||||||
	`).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to create table, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	columns, err := DB.Migrator().ColumnTypes("log_usage")
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to get columns, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	hasLogID := false
 | 
					 | 
				
			||||||
	for _, column := range columns {
 | 
					 | 
				
			||||||
		if column.Name() == "log_id" {
 | 
					 | 
				
			||||||
			hasLogID = true
 | 
					 | 
				
			||||||
			autoIncrement, ok := column.AutoIncrement()
 | 
					 | 
				
			||||||
			if !ok || !autoIncrement {
 | 
					 | 
				
			||||||
				t.Fatalf("column log_id should be auto incrementment")
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !hasLogID {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to found column log_id")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGaussDBMany2ManyWithDefaultValueUUID(t *testing.T) {
 | 
					 | 
				
			||||||
	t.Skipf("This test case skipped, because of gaussdb does not have 'uuid-ossp' extension")
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() != "gaussdb" {
 | 
					 | 
				
			||||||
		t.Skip()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.Exec(`create extension if not exists "uuid-ossp"`).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Failed to create 'uuid-ossp' extension, but got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.Migrator().DropTable(&Post{}, &Category{}, "post_categories")
 | 
					 | 
				
			||||||
	DB.AutoMigrate(&Post{}, &Category{})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	post := Post{
 | 
					 | 
				
			||||||
		Title: "Hello World",
 | 
					 | 
				
			||||||
		Categories: []*Category{
 | 
					 | 
				
			||||||
			{Title: "Coding"},
 | 
					 | 
				
			||||||
			{Title: "Golang"},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.Create(&post).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("Failed, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGaussDBOnConstraint(t *testing.T) {
 | 
					 | 
				
			||||||
	t.Skipf("This test case skipped, because of gaussdb not support 'ON CONSTRAINT' statement")
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() != "gaussdb" {
 | 
					 | 
				
			||||||
		t.Skip()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type Thing struct {
 | 
					 | 
				
			||||||
		gorm.Model
 | 
					 | 
				
			||||||
		SomeID  string
 | 
					 | 
				
			||||||
		OtherID string
 | 
					 | 
				
			||||||
		Data    string
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.Migrator().DropTable(&Thing{})
 | 
					 | 
				
			||||||
	DB.Migrator().CreateTable(&Thing{})
 | 
					 | 
				
			||||||
	if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil {
 | 
					 | 
				
			||||||
		t.Error(err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	thing := Thing{
 | 
					 | 
				
			||||||
		SomeID:  "1234",
 | 
					 | 
				
			||||||
		OtherID: "1234",
 | 
					 | 
				
			||||||
		Data:    "something",
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.Create(&thing)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	thing2 := Thing{
 | 
					 | 
				
			||||||
		SomeID:  "1234",
 | 
					 | 
				
			||||||
		OtherID: "1234",
 | 
					 | 
				
			||||||
		Data:    "something else",
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	result := DB.Clauses(clause.OnConflict{
 | 
					 | 
				
			||||||
		OnConstraint: "some_id_other_id_unique",
 | 
					 | 
				
			||||||
		UpdateAll:    true,
 | 
					 | 
				
			||||||
	}).Create(&thing2)
 | 
					 | 
				
			||||||
	if result.Error != nil {
 | 
					 | 
				
			||||||
		t.Errorf("creating second thing: %v", result.Error)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var things []Thing
 | 
					 | 
				
			||||||
	if err := DB.Find(&things).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("Failed, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(things) > 1 {
 | 
					 | 
				
			||||||
		t.Errorf("expected 1 thing got more")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGaussDBAlterColumnDataType(t *testing.T) {
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() != "gaussdb" {
 | 
					 | 
				
			||||||
		t.Skip()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	DB.Migrator().DropTable(&Company{})
 | 
					 | 
				
			||||||
	DB.AutoMigrate(Company{})
 | 
					 | 
				
			||||||
	if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to alter column from string to int, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.AutoMigrate(Company{})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@ -1,875 +0,0 @@
 | 
				
			|||||||
package tests_test
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"context"
 | 
					 | 
				
			||||||
	"errors"
 | 
					 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"reflect"
 | 
					 | 
				
			||||||
	"regexp"
 | 
					 | 
				
			||||||
	"sort"
 | 
					 | 
				
			||||||
	"strconv"
 | 
					 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
	"sync"
 | 
					 | 
				
			||||||
	"testing"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/google/uuid"
 | 
					 | 
				
			||||||
	"gorm.io/driver/mysql"
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
					 | 
				
			||||||
	"gorm.io/gorm/clause"
 | 
					 | 
				
			||||||
	. "gorm.io/gorm/utils/tests"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsCreate(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	user := User{Name: "TestGenericsCreate", Age: 18}
 | 
					 | 
				
			||||||
	err := gorm.G[User](DB).Create(ctx, &user)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Create failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if user.ID == 0 {
 | 
					 | 
				
			||||||
		t.Fatalf("no primary key found for %v", user)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if u, err := gorm.G[User](DB).Where("name = ?", user.Name).First(ctx); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to find user, got error: %v", err)
 | 
					 | 
				
			||||||
	} else if u.Name != user.Name || u.ID != user.ID {
 | 
					 | 
				
			||||||
		t.Errorf("found invalid user, got %v, expect %v", u, user)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if u, err := gorm.G[User](DB).Where("name = ?", user.Name).Take(ctx); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to find user, got error: %v", err)
 | 
					 | 
				
			||||||
	} else if u.Name != user.Name || u.ID != user.ID {
 | 
					 | 
				
			||||||
		t.Errorf("found invalid user, got %v, expect %v", u, user)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if u, err := gorm.G[User](DB).Select("name").Where("name = ?", user.Name).First(ctx); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to find user, got error: %v", err)
 | 
					 | 
				
			||||||
	} else if u.Name != user.Name || u.Age != 0 {
 | 
					 | 
				
			||||||
		t.Errorf("found invalid user, got %v, expect %v", u, user)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to find user, got error: %v", err)
 | 
					 | 
				
			||||||
	} else if u.Name != "" || u.Age != user.Age {
 | 
					 | 
				
			||||||
		t.Errorf("found invalid user, got %v, expect %v", u, user)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	result := struct {
 | 
					 | 
				
			||||||
		ID   int
 | 
					 | 
				
			||||||
		Name string
 | 
					 | 
				
			||||||
	}{}
 | 
					 | 
				
			||||||
	if err := gorm.G[User](DB).Where("name = ?", user.Name).Scan(ctx, &result); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to scan user, got error: %v", err)
 | 
					 | 
				
			||||||
	} else if result.Name != user.Name || uint(result.ID) != user.ID {
 | 
					 | 
				
			||||||
		t.Errorf("found invalid user, got %v, expect %v", result, user)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).Take(ctx)
 | 
					 | 
				
			||||||
	if v := mapResult["user_name"]; fmt.Sprint(v) != user.Name {
 | 
					 | 
				
			||||||
		t.Errorf("failed to find map results, got %v, err %v", mapResult, err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsCreateInBatches(t *testing.T) {
 | 
					 | 
				
			||||||
	batch := []User{
 | 
					 | 
				
			||||||
		{Name: "GenericsCreateInBatches1"},
 | 
					 | 
				
			||||||
		{Name: "GenericsCreateInBatches2"},
 | 
					 | 
				
			||||||
		{Name: "GenericsCreateInBatches3"},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, 2); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("CreateInBatches failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, u := range batch {
 | 
					 | 
				
			||||||
		if u.ID == 0 {
 | 
					 | 
				
			||||||
			t.Fatalf("no primary key found for %v", u)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	count, err := gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Count(ctx, "*")
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Count failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if count != 3 {
 | 
					 | 
				
			||||||
		t.Errorf("expected 3 records, got %d", count)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx)
 | 
					 | 
				
			||||||
	if len(found) != len(batch) {
 | 
					 | 
				
			||||||
		t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Limit(2).Find(ctx)
 | 
					 | 
				
			||||||
	if len(found) != 2 {
 | 
					 | 
				
			||||||
		t.Errorf("expected %d from Raw Find, got %d", 2, len(found))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Offset(2).Limit(2).Find(ctx)
 | 
					 | 
				
			||||||
	if len(found) != 1 {
 | 
					 | 
				
			||||||
		t.Errorf("expected %d from Raw Find, got %d", 1, len(found))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsExecAndUpdate(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	name := "GenericsExec"
 | 
					 | 
				
			||||||
	if err := gorm.G[User](DB).Exec(ctx, "INSERT INTO users(name) VALUES(?)", name); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Exec insert failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	u, err := gorm.G[User](DB).Table("users as u").Where("u.name = ?", name).First(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to find user, got error: %v", err)
 | 
					 | 
				
			||||||
	} else if u.Name != name || u.ID == 0 {
 | 
					 | 
				
			||||||
		t.Errorf("found invalid user, got %v", u)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	name += "Update"
 | 
					 | 
				
			||||||
	rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Update(ctx, "name", name)
 | 
					 | 
				
			||||||
	if rows != 1 {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	nu, err := gorm.G[User](DB).Where("name = ?", name).First(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to find user, got error: %v", err)
 | 
					 | 
				
			||||||
	} else if nu.Name != name || u.ID != nu.ID {
 | 
					 | 
				
			||||||
		t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	rows, err = gorm.G[User](DB).Where("id = ?", u.ID).Updates(ctx, User{Name: "GenericsExecUpdates", Age: 18})
 | 
					 | 
				
			||||||
	if rows != 1 {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	nu, err = gorm.G[User](DB).Where("id = ?", u.ID).Last(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to find user, got error: %v", err)
 | 
					 | 
				
			||||||
	} else if nu.Name != "GenericsExecUpdates" || nu.Age != 18 || u.ID != nu.ID {
 | 
					 | 
				
			||||||
		t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsRow(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	user := User{Name: "GenericsRow"}
 | 
					 | 
				
			||||||
	if err := gorm.G[User](DB).Create(ctx, &user); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Create failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	row := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id = ?", user.ID).Row(ctx)
 | 
					 | 
				
			||||||
	var name string
 | 
					 | 
				
			||||||
	if err := row.Scan(&name); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Row scan failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if name != user.Name {
 | 
					 | 
				
			||||||
		t.Errorf("expected %s, got %s", user.Name, name)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	user2 := User{Name: "GenericsRow2"}
 | 
					 | 
				
			||||||
	if err := gorm.G[User](DB).Create(ctx, &user2); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Create failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	rows, err := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id IN ?", []uint{user.ID, user2.ID}).Rows(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Rows failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	count := 0
 | 
					 | 
				
			||||||
	for rows.Next() {
 | 
					 | 
				
			||||||
		var name string
 | 
					 | 
				
			||||||
		if err := rows.Scan(&name); err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("rows.Scan failed: %v", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		count++
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if count != 2 {
 | 
					 | 
				
			||||||
		t.Errorf("expected 2 rows, got %d", count)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsDelete(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	u := User{Name: "GenericsDelete"}
 | 
					 | 
				
			||||||
	if err := gorm.G[User](DB).Create(ctx, &u); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Create failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Delete(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Delete failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if rows != 1 {
 | 
					 | 
				
			||||||
		t.Errorf("expected 1 row deleted, got %d", rows)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	_, err = gorm.G[User](DB).Where("id = ?", u.ID).First(ctx)
 | 
					 | 
				
			||||||
	if err != gorm.ErrRecordNotFound {
 | 
					 | 
				
			||||||
		t.Fatalf("User after delete failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsFindInBatches(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	users := []User{
 | 
					 | 
				
			||||||
		{Name: "GenericsFindBatchA"},
 | 
					 | 
				
			||||||
		{Name: "GenericsFindBatchB"},
 | 
					 | 
				
			||||||
		{Name: "GenericsFindBatchC"},
 | 
					 | 
				
			||||||
		{Name: "GenericsFindBatchD"},
 | 
					 | 
				
			||||||
		{Name: "GenericsFindBatchE"},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("CreateInBatches failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	total := 0
 | 
					 | 
				
			||||||
	err := gorm.G[User](DB).Where("name like ?", "GenericsFindBatch%").FindInBatches(ctx, 2, func(chunk []User, batch int) error {
 | 
					 | 
				
			||||||
		if len(chunk) > 2 {
 | 
					 | 
				
			||||||
			t.Errorf("batch size exceed 2: got %d", len(chunk))
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		total += len(chunk)
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("FindInBatches failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if total != len(users) {
 | 
					 | 
				
			||||||
		t.Errorf("expected total %d, got %d", len(users), total)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsScopes(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	users := []User{{Name: "GenericsScopes1"}, {Name: "GenericsScopes2"}, {Name: "GenericsScopes3"}}
 | 
					 | 
				
			||||||
	err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users))
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("CreateInBatches failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	filterName1 := func(stmt *gorm.Statement) {
 | 
					 | 
				
			||||||
		stmt.Where("name = ?", "GenericsScopes1")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	results, err := gorm.G[User](DB).Scopes(filterName1).Find(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Scopes failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if len(results) != 1 || results[0].Name != "GenericsScopes1" {
 | 
					 | 
				
			||||||
		t.Fatalf("Scopes expected 1, got %d", len(results))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	notResult, err := gorm.G[User](DB).Where("name like ?", "GenericsScopes%").Not("name = ?", "GenericsScopes1").Order("name").Find(ctx)
 | 
					 | 
				
			||||||
	if len(notResult) != 2 {
 | 
					 | 
				
			||||||
		t.Fatalf("expected 2 results, got %d", len(notResult))
 | 
					 | 
				
			||||||
	} else if notResult[0].Name != "GenericsScopes2" || notResult[1].Name != "GenericsScopes3" {
 | 
					 | 
				
			||||||
		t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", notResult[0].Name, notResult[1].Name)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	orResult, err := gorm.G[User](DB).Or("name = ?", "GenericsScopes1").Or("name = ?", "GenericsScopes2").Order("name").Find(ctx)
 | 
					 | 
				
			||||||
	if len(orResult) != 2 {
 | 
					 | 
				
			||||||
		t.Fatalf("expected 2 results, got %d", len(notResult))
 | 
					 | 
				
			||||||
	} else if orResult[0].Name != "GenericsScopes1" || orResult[1].Name != "GenericsScopes2" {
 | 
					 | 
				
			||||||
		t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", orResult[0].Name, orResult[1].Name)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsJoins(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
	db := gorm.G[User](DB)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}}
 | 
					 | 
				
			||||||
	u2 := User{Name: "GenericsJoins_2", Company: Company{Name: "GenericsCompany_2"}}
 | 
					 | 
				
			||||||
	u3 := User{Name: "GenericsJoins_3", Company: Company{Name: "GenericsCompany_3"}}
 | 
					 | 
				
			||||||
	db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Inner JOIN + WHERE
 | 
					 | 
				
			||||||
	result, err := db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
 | 
					 | 
				
			||||||
		db.Where("?.name = ?", joinTable, u.Company.Name)
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}).First(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Joins failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if result.Name != u.Name || result.Company.Name != u.Company.Name {
 | 
					 | 
				
			||||||
		t.Fatalf("Joins expected %s, got %+v", u.Name, result)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Inner JOIN + WHERE with map
 | 
					 | 
				
			||||||
	result, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
 | 
					 | 
				
			||||||
		db.Where(map[string]any{"name": u.Company.Name})
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}).First(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Joins failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if result.Name != u.Name || result.Company.Name != u.Company.Name {
 | 
					 | 
				
			||||||
		t.Fatalf("Joins expected %s, got %+v", u.Name, result)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Left JOIN w/o WHERE
 | 
					 | 
				
			||||||
	result, err = db.Joins(clause.LeftJoin.Association("Company"), nil).Where(map[string]any{"name": u.Name}).First(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Joins failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if result.Name != u.Name || result.Company.Name != u.Company.Name {
 | 
					 | 
				
			||||||
		t.Fatalf("Joins expected %s, got %+v", u.Name, result)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Left JOIN + Alias WHERE
 | 
					 | 
				
			||||||
	result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
 | 
					 | 
				
			||||||
		if joinTable.Name != "t" {
 | 
					 | 
				
			||||||
			t.Fatalf("Join table should be t, but got %v", joinTable.Name)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		db.Where("?.name = ?", joinTable, u.Company.Name)
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}).Where(map[string]any{"name": u.Name}).First(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Joins failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if result.Name != u.Name || result.Company.Name != u.Company.Name {
 | 
					 | 
				
			||||||
		t.Fatalf("Joins expected %s, got %+v", u.Name, result)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Raw Subquery JOIN + WHERE
 | 
					 | 
				
			||||||
	result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"),
 | 
					 | 
				
			||||||
		func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
 | 
					 | 
				
			||||||
			if joinTable.Name != "t" {
 | 
					 | 
				
			||||||
				t.Fatalf("Join table should be t, but got %v", joinTable.Name)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			db.Where("?.name = ?", joinTable, u.Company.Name)
 | 
					 | 
				
			||||||
			return nil
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	).Where(map[string]any{"name": u2.Name}).First(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Raw subquery join failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID == 0 {
 | 
					 | 
				
			||||||
		t.Fatalf("Joins expected %s, got %+v", u.Name, result)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Raw Subquery JOIN + WHERE + Select
 | 
					 | 
				
			||||||
	result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"),
 | 
					 | 
				
			||||||
		func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
 | 
					 | 
				
			||||||
			if joinTable.Name != "t" {
 | 
					 | 
				
			||||||
				t.Fatalf("Join table should be t, but got %v", joinTable.Name)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			db.Where("?.name = ?", joinTable, u.Company.Name)
 | 
					 | 
				
			||||||
			return nil
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	).Where(map[string]any{"name": u2.Name}).First(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Raw subquery join failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID != 0 {
 | 
					 | 
				
			||||||
		t.Fatalf("Joins expected %s, got %+v", u.Name, result)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	_, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
 | 
					 | 
				
			||||||
		return errors.New("join error")
 | 
					 | 
				
			||||||
	}).First(ctx)
 | 
					 | 
				
			||||||
	if err == nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Joins should got error, but got nil")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsNestedJoins(t *testing.T) {
 | 
					 | 
				
			||||||
	users := []User{
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			Name: "generics-nested-joins-1",
 | 
					 | 
				
			||||||
			Manager: &User{
 | 
					 | 
				
			||||||
				Name: "generics-nested-joins-manager-1",
 | 
					 | 
				
			||||||
				Company: Company{
 | 
					 | 
				
			||||||
					Name: "generics-nested-joins-manager-company-1",
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
				NamedPet: &Pet{
 | 
					 | 
				
			||||||
					Name: "generics-nested-joins-manager-namepet-1",
 | 
					 | 
				
			||||||
					Toy: Toy{
 | 
					 | 
				
			||||||
						Name: "generics-nested-joins-manager-namepet-toy-1",
 | 
					 | 
				
			||||||
					},
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
			NamedPet: &Pet{Name: "generics-nested-joins-namepet-1", Toy: Toy{Name: "generics-nested-joins-namepet-toy-1"}},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			Name:     "generics-nested-joins-2",
 | 
					 | 
				
			||||||
			Manager:  GetUser("generics-nested-joins-manager-2", Config{Company: true, NamedPet: true}),
 | 
					 | 
				
			||||||
			NamedPet: &Pet{Name: "generics-nested-joins-namepet-2", Toy: Toy{Name: "generics-nested-joins-namepet-toy-2"}},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
	db := gorm.G[User](DB)
 | 
					 | 
				
			||||||
	db.CreateInBatches(ctx, &users, 100)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var userIDs []uint
 | 
					 | 
				
			||||||
	for _, user := range users {
 | 
					 | 
				
			||||||
		userIDs = append(userIDs, user.ID)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	users2, err := db.Joins(clause.LeftJoin.Association("Manager"), nil).
 | 
					 | 
				
			||||||
		Joins(clause.LeftJoin.Association("Manager.Company"), nil).
 | 
					 | 
				
			||||||
		Joins(clause.LeftJoin.Association("Manager.NamedPet.Toy"), nil).
 | 
					 | 
				
			||||||
		Joins(clause.LeftJoin.Association("NamedPet.Toy"), nil).
 | 
					 | 
				
			||||||
		Joins(clause.LeftJoin.Association("NamedPet").As("t"), nil).
 | 
					 | 
				
			||||||
		Where(map[string]any{"id": userIDs}).Find(ctx)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Failed to load with joins, got error: %v", err)
 | 
					 | 
				
			||||||
	} else if len(users2) != len(users) {
 | 
					 | 
				
			||||||
		t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	sort.Slice(users2, func(i, j int) bool {
 | 
					 | 
				
			||||||
		return users2[i].ID > users2[j].ID
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	sort.Slice(users, func(i, j int) bool {
 | 
					 | 
				
			||||||
		return users[i].ID > users[j].ID
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for idx, user := range users {
 | 
					 | 
				
			||||||
		// user
 | 
					 | 
				
			||||||
		CheckUser(t, user, users2[idx])
 | 
					 | 
				
			||||||
		if users2[idx].Manager == nil {
 | 
					 | 
				
			||||||
			t.Fatalf("Failed to load Manager")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		// manager
 | 
					 | 
				
			||||||
		CheckUser(t, *user.Manager, *users2[idx].Manager)
 | 
					 | 
				
			||||||
		// user pet
 | 
					 | 
				
			||||||
		if users2[idx].NamedPet == nil {
 | 
					 | 
				
			||||||
			t.Fatalf("Failed to load NamedPet")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		CheckPet(t, *user.NamedPet, *users2[idx].NamedPet)
 | 
					 | 
				
			||||||
		// manager pet
 | 
					 | 
				
			||||||
		if users2[idx].Manager.NamedPet == nil {
 | 
					 | 
				
			||||||
			t.Fatalf("Failed to load NamedPet")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsPreloads(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
	db := gorm.G[User](DB)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	u := *GetUser("GenericsPreloads_1", Config{Company: true, Pets: 3, Friends: 7})
 | 
					 | 
				
			||||||
	u2 := *GetUser("GenericsPreloads_2", Config{Company: true, Pets: 5, Friends: 5})
 | 
					 | 
				
			||||||
	u3 := *GetUser("GenericsPreloads_3", Config{Company: true, Pets: 7, Friends: 3})
 | 
					 | 
				
			||||||
	names := []string{u.Name, u2.Name, u3.Name}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	result, err := db.Preload("Company", nil).Preload("Pets", nil).Where("name = ?", u.Name).First(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Preload failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if result.Name != u.Name || result.Company.Name != u.Company.Name || len(result.Pets) != len(u.Pets) {
 | 
					 | 
				
			||||||
		t.Fatalf("Preload expected %s, got %+v", u.Name, result)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error {
 | 
					 | 
				
			||||||
		db.Where("name = ?", u.Company.Name)
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}).Where("name in ?", names).Find(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Preload failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	for _, result := range results {
 | 
					 | 
				
			||||||
		if result.Name == u.Name {
 | 
					 | 
				
			||||||
			if result.Company.Name != u.Company.Name {
 | 
					 | 
				
			||||||
				t.Fatalf("Preload user %v company should be %v, but got %+v", u.Name, u.Company.Name, result.Company.Name)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else if result.Company.Name != "" {
 | 
					 | 
				
			||||||
			t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	_, err = db.Preload("Company", func(db gorm.PreloadBuilder) error {
 | 
					 | 
				
			||||||
		return errors.New("preload error")
 | 
					 | 
				
			||||||
	}).Where("name in ?", names).Find(ctx)
 | 
					 | 
				
			||||||
	if err == nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Preload should failed, but got nil")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() == "mysql" {
 | 
					 | 
				
			||||||
		// mysql 5.7 doesn't support row_number()
 | 
					 | 
				
			||||||
		if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
 | 
					 | 
				
			||||||
		db.LimitPerRecord(5)
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}).Where("name in ?", names).Find(ctx)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, result := range results {
 | 
					 | 
				
			||||||
		if result.Name == u.Name {
 | 
					 | 
				
			||||||
			if len(result.Pets) != len(u.Pets) {
 | 
					 | 
				
			||||||
				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else if len(result.Pets) != 5 {
 | 
					 | 
				
			||||||
			t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() == "sqlserver" {
 | 
					 | 
				
			||||||
		// sqlserver doesn't support order by in subquery
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
 | 
					 | 
				
			||||||
		db.Order("name desc").LimitPerRecord(5)
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}).Where("name in ?", names).Find(ctx)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, result := range results {
 | 
					 | 
				
			||||||
		if result.Name == u.Name {
 | 
					 | 
				
			||||||
			if len(result.Pets) != len(u.Pets) {
 | 
					 | 
				
			||||||
				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else if len(result.Pets) != 5 {
 | 
					 | 
				
			||||||
			t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		for i := 1; i < len(result.Pets); i++ {
 | 
					 | 
				
			||||||
			if result.Pets[i-1].Name < result.Pets[i].Name {
 | 
					 | 
				
			||||||
				t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
 | 
					 | 
				
			||||||
		db.Order("name").LimitPerRecord(5)
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}).Preload("Friends", func(db gorm.PreloadBuilder) error {
 | 
					 | 
				
			||||||
		db.Order("name")
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}).Where("name in ?", names).Find(ctx)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, result := range results {
 | 
					 | 
				
			||||||
		if result.Name == u.Name {
 | 
					 | 
				
			||||||
			if len(result.Pets) != len(u.Pets) {
 | 
					 | 
				
			||||||
				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			if len(result.Friends) != len(u.Friends) {
 | 
					 | 
				
			||||||
				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else if len(result.Pets) != 5 || len(result.Friends) == 0 {
 | 
					 | 
				
			||||||
			t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		for i := 1; i < len(result.Pets); i++ {
 | 
					 | 
				
			||||||
			if result.Pets[i-1].Name > result.Pets[i].Name {
 | 
					 | 
				
			||||||
				t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		for i := 1; i < len(result.Pets); i++ {
 | 
					 | 
				
			||||||
			if result.Pets[i-1].Name > result.Pets[i].Name {
 | 
					 | 
				
			||||||
				t.Fatalf("Preload user %v friends not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsNestedPreloads(t *testing.T) {
 | 
					 | 
				
			||||||
	user := *GetUser("generics_nested_preload", Config{Pets: 2})
 | 
					 | 
				
			||||||
	user.Friends = []*User{GetUser("generics_nested_preload", Config{Pets: 5})}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
	db := gorm.G[User](DB)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for idx, pet := range user.Pets {
 | 
					 | 
				
			||||||
		pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := db.Create(ctx, &user); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("errors happened when create: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}).Where(user.ID).Take(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("failed to nested preload user")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	CheckUser(t, user2, user)
 | 
					 | 
				
			||||||
	if len(user.Pets) == 0 || len(user.Friends) == 0 || len(user.Friends[0].Pets) == 0 {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to nested preload")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() == "mysql" {
 | 
					 | 
				
			||||||
		// mysql 5.7 doesn't support row_number()
 | 
					 | 
				
			||||||
		if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() == "sqlserver" {
 | 
					 | 
				
			||||||
		// sqlserver doesn't support order by in subquery
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	user3, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
 | 
					 | 
				
			||||||
		db.LimitPerRecord(3)
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}).Where(user.ID).Take(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("failed to nested preload user")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	CheckUser(t, user3, user)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(user3.Friends) != 1 || len(user3.Friends[0].Pets) != 3 {
 | 
					 | 
				
			||||||
		t.Errorf("failed to nested preload with limit per record")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsDistinct(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	batch := []User{
 | 
					 | 
				
			||||||
		{Name: "GenericsDistinctDup"},
 | 
					 | 
				
			||||||
		{Name: "GenericsDistinctDup"},
 | 
					 | 
				
			||||||
		{Name: "GenericsDistinctUnique"},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("CreateInBatches failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	results, err := gorm.G[User](DB).Where("name like ?", "GenericsDistinct%").Distinct("name").Find(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Distinct Find failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(results) != 2 {
 | 
					 | 
				
			||||||
		t.Errorf("expected 2 distinct names, got %d", len(results))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var names []string
 | 
					 | 
				
			||||||
	for _, u := range results {
 | 
					 | 
				
			||||||
		names = append(names, u.Name)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	sort.Strings(names)
 | 
					 | 
				
			||||||
	expected := []string{"GenericsDistinctDup", "GenericsDistinctUnique"}
 | 
					 | 
				
			||||||
	if !reflect.DeepEqual(names, expected) {
 | 
					 | 
				
			||||||
		t.Errorf("expected names %v, got %v", expected, names)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsGroupHaving(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	batch := []User{
 | 
					 | 
				
			||||||
		{Name: "GenericsGroupHavingMulti"},
 | 
					 | 
				
			||||||
		{Name: "GenericsGroupHavingMulti"},
 | 
					 | 
				
			||||||
		{Name: "GenericsGroupHavingSingle"},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("CreateInBatches failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	grouped, err := gorm.G[User](DB).Select("name").Where("name like ?", "GenericsGroupHaving%").Group("name").Having("COUNT(id) > ?", 1).Find(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Group+Having Find failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(grouped) != 1 {
 | 
					 | 
				
			||||||
		t.Errorf("expected 1 group with count>1, got %d", len(grouped))
 | 
					 | 
				
			||||||
	} else if grouped[0].Name != "GenericsGroupHavingMulti" {
 | 
					 | 
				
			||||||
		t.Errorf("expected group name 'GenericsGroupHavingMulti', got '%s'", grouped[0].Name)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsSubQuery(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
	users := []User{
 | 
					 | 
				
			||||||
		{Name: "GenericsSubquery_1", Age: 10},
 | 
					 | 
				
			||||||
		{Name: "GenericsSubquery_2", Age: 20},
 | 
					 | 
				
			||||||
		{Name: "GenericsSubquery_3", Age: 30},
 | 
					 | 
				
			||||||
		{Name: "GenericsSubquery_4", Age: 40},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("CreateInBatches failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	results, err := gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name LIKE ?", "GenericsSubquery%")).Find(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(results) != 4 {
 | 
					 | 
				
			||||||
		t.Errorf("Four users should be found, instead found %d", len(results))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	results, err = gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name IN ?", []string{"GenericsSubquery_1", "GenericsSubquery_2"}).Or("name = ?", "GenericsSubquery_3")).Find(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(results) != 3 {
 | 
					 | 
				
			||||||
		t.Errorf("Three users should be found, instead found %d", len(results))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsUpsert(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
	lang := Language{Code: "upsert", Name: "Upsert"}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to upsert, got %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	lang2 := Language{Code: "upsert", Name: "Upsert"}
 | 
					 | 
				
			||||||
	if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang2); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to upsert, got %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("no error should happen when find languages with code, but got %v", err)
 | 
					 | 
				
			||||||
	} else if len(langs) != 1 {
 | 
					 | 
				
			||||||
		t.Errorf("should only find only 1 languages, but got %+v", langs)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	lang3 := Language{Code: "upsert", Name: "Upsert"}
 | 
					 | 
				
			||||||
	if err := gorm.G[Language](DB, clause.OnConflict{
 | 
					 | 
				
			||||||
		Columns:   []clause.Column{{Name: "code"}},
 | 
					 | 
				
			||||||
		DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}),
 | 
					 | 
				
			||||||
	}).Create(ctx, &lang3); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to upsert, got %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx); err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("no error should happen when find languages with code, but got %v", err)
 | 
					 | 
				
			||||||
	} else if len(langs) != 1 {
 | 
					 | 
				
			||||||
		t.Errorf("should only find only 1 languages, but got %+v", langs)
 | 
					 | 
				
			||||||
	} else if langs[0].Name != "upsert-new" {
 | 
					 | 
				
			||||||
		t.Errorf("should update name on conflict, but got name %+v", langs[0].Name)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsWithResult(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
	users := []User{{Name: "TestGenericsWithResult", Age: 18}, {Name: "TestGenericsWithResult2", Age: 18}}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	result := gorm.WithResult()
 | 
					 | 
				
			||||||
	err := gorm.G[User](DB, result).CreateInBatches(ctx, &users, 2)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("failed to create users WithResult")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if result.RowsAffected != 2 {
 | 
					 | 
				
			||||||
		t.Errorf("failed to get affected rows, got %d, should be %d", result.RowsAffected, 2)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsReuse(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
	users := []User{{Name: "TestGenericsReuse1", Age: 18}, {Name: "TestGenericsReuse2", Age: 18}}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("failed to create users")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	reusedb := gorm.G[User](DB).Where("name like ?", "TestGenericsReuse%")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	sg := sync.WaitGroup{}
 | 
					 | 
				
			||||||
	for i := 0; i < 5; i++ {
 | 
					 | 
				
			||||||
		sg.Add(1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		go func() {
 | 
					 | 
				
			||||||
			if u1, err := reusedb.Where("id = ?", users[0].ID).First(ctx); err != nil {
 | 
					 | 
				
			||||||
				t.Errorf("failed to find user, got error: %v", err)
 | 
					 | 
				
			||||||
			} else if u1.Name != users[0].Name || u1.ID != users[0].ID {
 | 
					 | 
				
			||||||
				t.Errorf("found invalid user, got %v, expect %v", u1, users[0])
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if u2, err := reusedb.Where("id = ?", users[1].ID).First(ctx); err != nil {
 | 
					 | 
				
			||||||
				t.Errorf("failed to find user, got error: %v", err)
 | 
					 | 
				
			||||||
			} else if u2.Name != users[1].Name || u2.ID != users[1].ID {
 | 
					 | 
				
			||||||
				t.Errorf("found invalid user, got %v, expect %v", u2, users[1])
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if users, err := reusedb.Where("id IN ?", []uint{users[0].ID, users[1].ID}).Find(ctx); err != nil {
 | 
					 | 
				
			||||||
				t.Errorf("failed to find user, got error: %v", err)
 | 
					 | 
				
			||||||
			} else if len(users) != 2 {
 | 
					 | 
				
			||||||
				t.Errorf("should find 2 users, but got %d", len(users))
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			sg.Done()
 | 
					 | 
				
			||||||
		}()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	sg.Wait()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsWithTransaction(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
	tx := DB.Begin()
 | 
					 | 
				
			||||||
	if tx.Error != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to begin transaction: %v", tx.Error)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	users := []User{{Name: "TestGenericsTransaction", Age: 18}, {Name: "TestGenericsTransaction2", Age: 18}}
 | 
					 | 
				
			||||||
	err := gorm.G[User](tx).CreateInBatches(ctx, &users, 2)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	count, err := gorm.G[User](tx).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Count failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if count != 2 {
 | 
					 | 
				
			||||||
		t.Errorf("expected 2 records, got %d", count)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := tx.Rollback().Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to rollback transaction: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	count2, err := gorm.G[User](DB).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Count failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if count2 != 0 {
 | 
					 | 
				
			||||||
		t.Errorf("expected 0 records after rollback, got %d", count2)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsToSQL(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
	sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
 | 
					 | 
				
			||||||
		gorm.G[User](tx).Limit(10).Find(ctx)
 | 
					 | 
				
			||||||
		return tx
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !regexp.MustCompile("SELECT \\* FROM .users..* 10").MatchString(sql) {
 | 
					 | 
				
			||||||
		t.Errorf("ToSQL: got wrong sql with Generics API %v", sql)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGenericsScanUUID(t *testing.T) {
 | 
					 | 
				
			||||||
	ctx := context.Background()
 | 
					 | 
				
			||||||
	users := []User{
 | 
					 | 
				
			||||||
		{Name: uuid.NewString(), Age: 21},
 | 
					 | 
				
			||||||
		{Name: uuid.NewString(), Age: 22},
 | 
					 | 
				
			||||||
		{Name: uuid.NewString(), Age: 23},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("CreateInBatches failed: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	userIds := []uuid.UUID{}
 | 
					 | 
				
			||||||
	if err := gorm.G[User](DB).Select("name").Where("id in ?", []uint{users[0].ID, users[1].ID, users[2].ID}).Order("age").Scan(ctx, &userIds); err != nil || len(users) != 3 {
 | 
					 | 
				
			||||||
		t.Fatalf("Scan failed: %v, userids %v", err, userIds)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if userIds[0].String() != users[0].Name || userIds[1].String() != users[1].Name || userIds[2].String() != users[2].Name {
 | 
					 | 
				
			||||||
		t.Fatalf("wrong uuid scanned")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
							
								
								
									
										38
									
								
								tests/go.mod
									
									
									
									
									
								
							
							
						
						
									
										38
									
								
								tests/go.mod
									
									
									
									
									
								
							@ -1,40 +1,30 @@
 | 
				
			|||||||
module gorm.io/gorm/tests
 | 
					module gorm.io/gorm/tests
 | 
				
			||||||
 | 
					
 | 
				
			||||||
go 1.23.0
 | 
					go 1.18
 | 
				
			||||||
 | 
					
 | 
				
			||||||
require (
 | 
					require (
 | 
				
			||||||
	github.com/google/uuid v1.6.0
 | 
						github.com/google/uuid v1.3.0
 | 
				
			||||||
	github.com/jinzhu/now v1.1.5
 | 
						github.com/jinzhu/now v1.1.5
 | 
				
			||||||
	github.com/lib/pq v1.10.9
 | 
						github.com/lib/pq v1.10.9
 | 
				
			||||||
	github.com/stretchr/testify v1.10.0
 | 
						gorm.io/driver/mysql v1.5.2-0.20230612053416-48b6526a21f0
 | 
				
			||||||
	gorm.io/driver/gaussdb v0.1.0
 | 
						gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196
 | 
				
			||||||
	gorm.io/driver/mysql v1.6.0
 | 
						gorm.io/driver/sqlite v1.5.2
 | 
				
			||||||
	gorm.io/driver/postgres v1.6.0
 | 
						gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a
 | 
				
			||||||
	gorm.io/driver/sqlite v1.6.0
 | 
						gorm.io/gorm v1.25.2
 | 
				
			||||||
	gorm.io/driver/sqlserver v1.6.1
 | 
					 | 
				
			||||||
	gorm.io/gorm v1.30.0
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
require (
 | 
					require (
 | 
				
			||||||
	filippo.io/edwards25519 v1.1.0 // indirect
 | 
						github.com/go-sql-driver/mysql v1.7.1 // indirect
 | 
				
			||||||
	github.com/HuaweiCloudDeveloper/gaussdb-go v1.0.0-rc1 // indirect
 | 
					 | 
				
			||||||
	github.com/davecgh/go-spew v1.1.1 // indirect
 | 
					 | 
				
			||||||
	github.com/go-sql-driver/mysql v1.9.3 // indirect
 | 
					 | 
				
			||||||
	github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
 | 
						github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
 | 
				
			||||||
	github.com/golang-sql/sqlexp v0.1.0 // indirect
 | 
						github.com/golang-sql/sqlexp v0.1.0 // indirect
 | 
				
			||||||
	github.com/jackc/pgpassfile v1.0.0 // indirect
 | 
						github.com/jackc/pgpassfile v1.0.0 // indirect
 | 
				
			||||||
	github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
 | 
						github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
 | 
				
			||||||
	github.com/jackc/pgx/v5 v5.7.5 // indirect
 | 
						github.com/jackc/pgx/v5 v5.4.2 // indirect
 | 
				
			||||||
	github.com/jackc/puddle/v2 v2.2.2 // indirect
 | 
					 | 
				
			||||||
	github.com/jinzhu/inflection v1.0.0 // indirect
 | 
						github.com/jinzhu/inflection v1.0.0 // indirect
 | 
				
			||||||
	github.com/mattn/go-sqlite3 v1.14.28 // indirect
 | 
						github.com/mattn/go-sqlite3 v1.14.17 // indirect
 | 
				
			||||||
	github.com/microsoft/go-mssqldb v1.9.2 // indirect
 | 
						github.com/microsoft/go-mssqldb v1.5.0 // indirect
 | 
				
			||||||
	github.com/pmezard/go-difflib v1.0.0 // indirect
 | 
						golang.org/x/crypto v0.12.0 // indirect
 | 
				
			||||||
	github.com/tjfoc/gmsm v1.4.1 // indirect
 | 
						golang.org/x/text v0.12.0 // indirect
 | 
				
			||||||
	golang.org/x/crypto v0.40.0 // indirect
 | 
					 | 
				
			||||||
	golang.org/x/sync v0.16.0 // indirect
 | 
					 | 
				
			||||||
	golang.org/x/text v0.27.0 // indirect
 | 
					 | 
				
			||||||
	gopkg.in/yaml.v3 v3.0.1 // indirect
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
replace gorm.io/gorm => ../
 | 
					replace gorm.io/gorm => ../
 | 
				
			||||||
 | 
				
			|||||||
@ -23,7 +23,6 @@ type Config struct {
 | 
				
			|||||||
	Languages int
 | 
						Languages int
 | 
				
			||||||
	Friends   int
 | 
						Friends   int
 | 
				
			||||||
	NamedPet  bool
 | 
						NamedPet  bool
 | 
				
			||||||
	Tools     int
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetUser(name string, config Config) *User {
 | 
					func GetUser(name string, config Config) *User {
 | 
				
			||||||
@ -48,10 +47,6 @@ func GetUser(name string, config Config) *User {
 | 
				
			|||||||
		user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)})
 | 
							user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for i := 0; i < config.Tools; i++ {
 | 
					 | 
				
			||||||
		user.Tools = append(user.Tools, Tools{Name: name + "_tool_" + strconv.Itoa(i+1)})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if config.Company {
 | 
						if config.Company {
 | 
				
			||||||
		user.Company = Company{Name: "company-" + name}
 | 
							user.Company = Company{Name: "company-" + name}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -123,13 +118,11 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
 | 
				
			|||||||
		if err := db(unscoped).Where("id = ?", user.ID).First(&newUser).Error; err != nil {
 | 
							if err := db(unscoped).Where("id = ?", user.ID).First(&newUser).Error; err != nil {
 | 
				
			||||||
			t.Fatalf("errors happened when query: %v", err)
 | 
								t.Fatalf("errors happened when query: %v", err)
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday",
 | 
								AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
 | 
				
			||||||
				"CompanyID", "ManagerID", "Active")
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID",
 | 
						AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
 | 
				
			||||||
		"ManagerID", "Active")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	t.Run("Account", func(t *testing.T) {
 | 
						t.Run("Account", func(t *testing.T) {
 | 
				
			||||||
		AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number")
 | 
							AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number")
 | 
				
			||||||
@ -140,8 +133,7 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
 | 
				
			|||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				var account Account
 | 
									var account Account
 | 
				
			||||||
				db(unscoped).First(&account, "user_id = ?", user.ID)
 | 
									db(unscoped).First(&account, "user_id = ?", user.ID)
 | 
				
			||||||
				AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID",
 | 
									AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number")
 | 
				
			||||||
					"Number")
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
@ -201,10 +193,8 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
 | 
				
			|||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				var manager User
 | 
									var manager User
 | 
				
			||||||
				db(unscoped).First(&manager, "id = ?", *user.ManagerID)
 | 
									db(unscoped).First(&manager, "id = ?", *user.ManagerID)
 | 
				
			||||||
				AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
 | 
									AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
 | 
				
			||||||
					"Birthday", "CompanyID", "ManagerID", "Active")
 | 
									AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
 | 
				
			||||||
				AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
 | 
					 | 
				
			||||||
					"Birthday", "CompanyID", "ManagerID", "Active")
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		} else if user.ManagerID != nil {
 | 
							} else if user.ManagerID != nil {
 | 
				
			||||||
			t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID)
 | 
								t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID)
 | 
				
			||||||
@ -225,8 +215,7 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
 | 
				
			|||||||
		})
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		for idx, team := range user.Team {
 | 
							for idx, team := range user.Team {
 | 
				
			||||||
			AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
 | 
								AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
 | 
				
			||||||
				"Birthday", "CompanyID", "ManagerID", "Active")
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -261,8 +250,7 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
 | 
				
			|||||||
		})
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		for idx, friend := range user.Friends {
 | 
							for idx, friend := range user.Friends {
 | 
				
			||||||
			AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
 | 
								AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
 | 
				
			||||||
				"Birthday", "CompanyID", "ManagerID", "Active")
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -281,10 +269,6 @@ func isMysql() bool {
 | 
				
			|||||||
	return os.Getenv("GORM_DIALECT") == "mysql"
 | 
						return os.Getenv("GORM_DIALECT") == "mysql"
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func isSqlite() bool {
 | 
					 | 
				
			||||||
	return os.Getenv("GORM_DIALECT") == "sqlite"
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func db(unscoped bool) *gorm.DB {
 | 
					func db(unscoped bool) *gorm.DB {
 | 
				
			||||||
	if unscoped {
 | 
						if unscoped {
 | 
				
			||||||
		return DB.Unscoped()
 | 
							return DB.Unscoped()
 | 
				
			||||||
 | 
				
			|||||||
@ -2,8 +2,6 @@ package tests_test
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"log"
 | 
					 | 
				
			||||||
	"os"
 | 
					 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
@ -568,44 +566,3 @@ func TestUpdateCallbacks(t *testing.T) {
 | 
				
			|||||||
		t.Fatalf("before update should not be called")
 | 
							t.Fatalf("before update should not be called")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
type Product6 struct {
 | 
					 | 
				
			||||||
	gorm.Model
 | 
					 | 
				
			||||||
	Name string
 | 
					 | 
				
			||||||
	Item *ProductItem2
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ProductItem2 struct {
 | 
					 | 
				
			||||||
	gorm.Model
 | 
					 | 
				
			||||||
	Product6ID uint
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (p *Product6) BeforeDelete(tx *gorm.DB) error {
 | 
					 | 
				
			||||||
	if err := tx.Delete(&p.Item).Error; err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestPropagateUnscoped(t *testing.T) {
 | 
					 | 
				
			||||||
	_DB, err := OpenTestConnection(&gorm.Config{
 | 
					 | 
				
			||||||
		PropagateUnscoped: true,
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		log.Printf("failed to connect database, got error %v", err)
 | 
					 | 
				
			||||||
		os.Exit(1)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	_DB.Migrator().DropTable(&Product6{}, &ProductItem2{})
 | 
					 | 
				
			||||||
	_DB.AutoMigrate(&Product6{}, &ProductItem2{})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	p := Product6{
 | 
					 | 
				
			||||||
		Name: "unique_code",
 | 
					 | 
				
			||||||
		Item: &ProductItem2{},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	_DB.Model(&Product6{}).Create(&p)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := _DB.Unscoped().Delete(&p).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("unscoped did not propagate")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -1,12 +1,10 @@
 | 
				
			|||||||
package tests_test
 | 
					package tests_test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
	"sort"
 | 
						"sort"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/stretchr/testify/assert"
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	. "gorm.io/gorm/utils/tests"
 | 
						. "gorm.io/gorm/utils/tests"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@ -186,12 +184,14 @@ func TestJoinCount(t *testing.T) {
 | 
				
			|||||||
	DB.Create(&user)
 | 
						DB.Create(&user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	query := DB.Model(&User{}).Joins("Company")
 | 
						query := DB.Model(&User{}).Joins("Company")
 | 
				
			||||||
 | 
						// Bug happens when .Count is called on a query.
 | 
				
			||||||
 | 
						// Removing the below two lines or downgrading to gorm v1.20.12 will make this test pass.
 | 
				
			||||||
	var total int64
 | 
						var total int64
 | 
				
			||||||
	query.Count(&total)
 | 
						query.Count(&total)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var result User
 | 
						var result User
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Incorrectly generates a 'SELECT *' query which causes companies.id to overwrite users.id
 | 
				
			||||||
	if err := query.First(&result, user.ID).Error; err != nil {
 | 
						if err := query.First(&result, user.ID).Error; err != nil {
 | 
				
			||||||
		t.Fatalf("Failed, got error: %v", err)
 | 
							t.Fatalf("Failed, got error: %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -199,10 +199,6 @@ func TestJoinCount(t *testing.T) {
 | 
				
			|||||||
	if result.ID != user.ID {
 | 
						if result.ID != user.ID {
 | 
				
			||||||
		t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID)
 | 
							t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	// should find company
 | 
					 | 
				
			||||||
	if result.Company.ID != *user.CompanyID {
 | 
					 | 
				
			||||||
		t.Fatalf("result's id, %d, doesn't match user's company id, %d", result.Company.ID, *user.CompanyID)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestJoinWithSoftDeleted(t *testing.T) {
 | 
					func TestJoinWithSoftDeleted(t *testing.T) {
 | 
				
			||||||
@ -404,75 +400,3 @@ func TestNestedJoins(t *testing.T) {
 | 
				
			|||||||
		CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
 | 
							CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestJoinsPreload_Issue7013(t *testing.T) {
 | 
					 | 
				
			||||||
	manager := &User{Name: "Manager"}
 | 
					 | 
				
			||||||
	DB.Create(manager)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var userIDs []uint
 | 
					 | 
				
			||||||
	for i := 0; i < 21; i++ {
 | 
					 | 
				
			||||||
		user := &User{Name: fmt.Sprintf("User%d", i), ManagerID: &manager.ID}
 | 
					 | 
				
			||||||
		DB.Create(user)
 | 
					 | 
				
			||||||
		userIDs = append(userIDs, user.ID)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var entries []User
 | 
					 | 
				
			||||||
	assert.NotPanics(t, func() {
 | 
					 | 
				
			||||||
		assert.NoError(t,
 | 
					 | 
				
			||||||
			DB.Preload("Manager.Team").
 | 
					 | 
				
			||||||
				Joins("Manager.Company").
 | 
					 | 
				
			||||||
				Find(&entries).Error)
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestJoinsPreload_Issue7013_RelationEmpty(t *testing.T) {
 | 
					 | 
				
			||||||
	type (
 | 
					 | 
				
			||||||
		Furniture struct {
 | 
					 | 
				
			||||||
			gorm.Model
 | 
					 | 
				
			||||||
			OwnerID *uint
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		Owner struct {
 | 
					 | 
				
			||||||
			gorm.Model
 | 
					 | 
				
			||||||
			Furnitures []Furniture
 | 
					 | 
				
			||||||
			CompanyID  *uint
 | 
					 | 
				
			||||||
			Company    Company
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		Building struct {
 | 
					 | 
				
			||||||
			gorm.Model
 | 
					 | 
				
			||||||
			Name    string
 | 
					 | 
				
			||||||
			OwnerID *uint
 | 
					 | 
				
			||||||
			Owner   Owner
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.Migrator().DropTable(&Building{}, &Owner{}, &Furniture{})
 | 
					 | 
				
			||||||
	DB.Migrator().AutoMigrate(&Building{}, &Owner{}, &Furniture{})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	home := &Building{Name: "relation_empty"}
 | 
					 | 
				
			||||||
	DB.Create(home)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var entries []Building
 | 
					 | 
				
			||||||
	assert.NotPanics(t, func() {
 | 
					 | 
				
			||||||
		assert.NoError(t,
 | 
					 | 
				
			||||||
			DB.Preload("Owner.Furnitures").
 | 
					 | 
				
			||||||
				Joins("Owner.Company").
 | 
					 | 
				
			||||||
				Find(&entries).Error)
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	AssertEqual(t, entries, []Building{{Model: home.Model, Name: "relation_empty", Owner: Owner{Company: Company{}}}})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestJoinsPreload_Issue7013_NoEntries(t *testing.T) {
 | 
					 | 
				
			||||||
	var entries []User
 | 
					 | 
				
			||||||
	assert.NotPanics(t, func() {
 | 
					 | 
				
			||||||
		assert.NoError(t,
 | 
					 | 
				
			||||||
			DB.Preload("Manager.Team").
 | 
					 | 
				
			||||||
				Joins("Manager.Company").
 | 
					 | 
				
			||||||
				Where("1 <> 1").
 | 
					 | 
				
			||||||
				Find(&entries).Error)
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	AssertEqual(t, len(entries), 0)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -1,529 +0,0 @@
 | 
				
			|||||||
package tests_test
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"crypto/rand"
 | 
					 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"gorm.io/gorm/internal/lru"
 | 
					 | 
				
			||||||
	"math"
 | 
					 | 
				
			||||||
	"math/big"
 | 
					 | 
				
			||||||
	"reflect"
 | 
					 | 
				
			||||||
	"sync"
 | 
					 | 
				
			||||||
	"testing"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestLRU_Add_ExistingKey_UpdatesValueAndExpiresAt(t *testing.T) {
 | 
					 | 
				
			||||||
	lru := lru.NewLRU[string, int](10, nil, time.Hour)
 | 
					 | 
				
			||||||
	lru.Add("key1", 1)
 | 
					 | 
				
			||||||
	lru.Add("key1", 2)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if value, ok := lru.Get("key1"); !ok || value != 2 {
 | 
					 | 
				
			||||||
		t.Errorf("Expected value to be updated to 2, got %v", value)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestLRU_Add_NewKey_AddsEntry(t *testing.T) {
 | 
					 | 
				
			||||||
	lru := lru.NewLRU[string, int](10, nil, time.Hour)
 | 
					 | 
				
			||||||
	lru.Add("key1", 1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if value, ok := lru.Get("key1"); !ok || value != 1 {
 | 
					 | 
				
			||||||
		t.Errorf("Expected key1 to be added with value 1, got %v", value)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestLRU_Add_ExceedsSize_RemovesOldest(t *testing.T) {
 | 
					 | 
				
			||||||
	lru := lru.NewLRU[string, int](2, nil, time.Hour)
 | 
					 | 
				
			||||||
	lru.Add("key1", 1)
 | 
					 | 
				
			||||||
	lru.Add("key2", 2)
 | 
					 | 
				
			||||||
	lru.Add("key3", 3)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if _, ok := lru.Get("key1"); ok {
 | 
					 | 
				
			||||||
		t.Errorf("Expected key1 to be removed, but it still exists")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestLRU_Add_UnlimitedSize_NoEviction(t *testing.T) {
 | 
					 | 
				
			||||||
	lru := lru.NewLRU[string, int](0, nil, time.Hour)
 | 
					 | 
				
			||||||
	lru.Add("key1", 1)
 | 
					 | 
				
			||||||
	lru.Add("key2", 2)
 | 
					 | 
				
			||||||
	lru.Add("key3", 3)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if _, ok := lru.Get("key1"); !ok {
 | 
					 | 
				
			||||||
		t.Errorf("Expected key1 to exist, but it was evicted")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestLRU_Add_Eviction(t *testing.T) {
 | 
					 | 
				
			||||||
	lru := lru.NewLRU[string, int](0, nil, time.Second*2)
 | 
					 | 
				
			||||||
	lru.Add("key1", 1)
 | 
					 | 
				
			||||||
	lru.Add("key2", 2)
 | 
					 | 
				
			||||||
	lru.Add("key3", 3)
 | 
					 | 
				
			||||||
	time.Sleep(time.Second * 3)
 | 
					 | 
				
			||||||
	if lru.Cap() != 0 {
 | 
					 | 
				
			||||||
		t.Errorf("Expected lru to be empty, but it was not")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func BenchmarkLRU_Rand_NoExpire(b *testing.B) {
 | 
					 | 
				
			||||||
	l := lru.NewLRU[int64, int64](8192, nil, 0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	trace := make([]int64, b.N*2)
 | 
					 | 
				
			||||||
	for i := 0; i < b.N*2; i++ {
 | 
					 | 
				
			||||||
		trace[i] = getRand(b) % 32768
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	b.ResetTimer()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var hit, miss int
 | 
					 | 
				
			||||||
	for i := 0; i < 2*b.N; i++ {
 | 
					 | 
				
			||||||
		if i%2 == 0 {
 | 
					 | 
				
			||||||
			l.Add(trace[i], trace[i])
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			if _, ok := l.Get(trace[i]); ok {
 | 
					 | 
				
			||||||
				hit++
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				miss++
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func BenchmarkLRU_Freq_NoExpire(b *testing.B) {
 | 
					 | 
				
			||||||
	l := lru.NewLRU[int64, int64](8192, nil, 0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	trace := make([]int64, b.N*2)
 | 
					 | 
				
			||||||
	for i := 0; i < b.N*2; i++ {
 | 
					 | 
				
			||||||
		if i%2 == 0 {
 | 
					 | 
				
			||||||
			trace[i] = getRand(b) % 16384
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			trace[i] = getRand(b) % 32768
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	b.ResetTimer()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for i := 0; i < b.N; i++ {
 | 
					 | 
				
			||||||
		l.Add(trace[i], trace[i])
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	var hit, miss int
 | 
					 | 
				
			||||||
	for i := 0; i < b.N; i++ {
 | 
					 | 
				
			||||||
		if _, ok := l.Get(trace[i]); ok {
 | 
					 | 
				
			||||||
			hit++
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			miss++
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func BenchmarkLRU_Rand_WithExpire(b *testing.B) {
 | 
					 | 
				
			||||||
	l := lru.NewLRU[int64, int64](8192, nil, time.Millisecond*10)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	trace := make([]int64, b.N*2)
 | 
					 | 
				
			||||||
	for i := 0; i < b.N*2; i++ {
 | 
					 | 
				
			||||||
		trace[i] = getRand(b) % 32768
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	b.ResetTimer()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var hit, miss int
 | 
					 | 
				
			||||||
	for i := 0; i < 2*b.N; i++ {
 | 
					 | 
				
			||||||
		if i%2 == 0 {
 | 
					 | 
				
			||||||
			l.Add(trace[i], trace[i])
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			if _, ok := l.Get(trace[i]); ok {
 | 
					 | 
				
			||||||
				hit++
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				miss++
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func BenchmarkLRU_Freq_WithExpire(b *testing.B) {
 | 
					 | 
				
			||||||
	l := lru.NewLRU[int64, int64](8192, nil, time.Millisecond*10)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	trace := make([]int64, b.N*2)
 | 
					 | 
				
			||||||
	for i := 0; i < b.N*2; i++ {
 | 
					 | 
				
			||||||
		if i%2 == 0 {
 | 
					 | 
				
			||||||
			trace[i] = getRand(b) % 16384
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			trace[i] = getRand(b) % 32768
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	b.ResetTimer()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for i := 0; i < b.N; i++ {
 | 
					 | 
				
			||||||
		l.Add(trace[i], trace[i])
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	var hit, miss int
 | 
					 | 
				
			||||||
	for i := 0; i < b.N; i++ {
 | 
					 | 
				
			||||||
		if _, ok := l.Get(trace[i]); ok {
 | 
					 | 
				
			||||||
			hit++
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			miss++
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestLRUNoPurge(t *testing.T) {
 | 
					 | 
				
			||||||
	lc := lru.NewLRU[string, string](10, nil, 0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	lc.Add("key1", "val1")
 | 
					 | 
				
			||||||
	if lc.Len() != 1 {
 | 
					 | 
				
			||||||
		t.Fatalf("length differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	v, ok := lc.Peek("key1")
 | 
					 | 
				
			||||||
	if v != "val1" {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if !ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should be true")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !lc.Contains("key1") {
 | 
					 | 
				
			||||||
		t.Fatalf("should contain key1")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if lc.Contains("key2") {
 | 
					 | 
				
			||||||
		t.Fatalf("should not contain key2")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	v, ok = lc.Peek("key2")
 | 
					 | 
				
			||||||
	if v != "" {
 | 
					 | 
				
			||||||
		t.Fatalf("should be empty")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should be false")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !reflect.DeepEqual(lc.Keys(), []string{"key1"}) {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if lc.Resize(0) != 0 {
 | 
					 | 
				
			||||||
		t.Fatalf("evicted count differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if lc.Resize(2) != 0 {
 | 
					 | 
				
			||||||
		t.Fatalf("evicted count differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	lc.Add("key2", "val2")
 | 
					 | 
				
			||||||
	if lc.Resize(1) != 1 {
 | 
					 | 
				
			||||||
		t.Fatalf("evicted count differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestLRUEdgeCases(t *testing.T) {
 | 
					 | 
				
			||||||
	lc := lru.NewLRU[string, *string](2, nil, 0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Adding a nil value
 | 
					 | 
				
			||||||
	lc.Add("key1", nil)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	value, exists := lc.Get("key1")
 | 
					 | 
				
			||||||
	if value != nil || !exists {
 | 
					 | 
				
			||||||
		t.Fatalf("unexpected value or existence flag for key1: value=%v, exists=%v", value, exists)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Adding an entry with the same key but different value
 | 
					 | 
				
			||||||
	newVal := "val1"
 | 
					 | 
				
			||||||
	lc.Add("key1", &newVal)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	value, exists = lc.Get("key1")
 | 
					 | 
				
			||||||
	if value != &newVal || !exists {
 | 
					 | 
				
			||||||
		t.Fatalf("unexpected value or existence flag for key1: value=%v, exists=%v", value, exists)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestLRU_Values(t *testing.T) {
 | 
					 | 
				
			||||||
	lc := lru.NewLRU[string, string](3, nil, 0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	lc.Add("key1", "val1")
 | 
					 | 
				
			||||||
	lc.Add("key2", "val2")
 | 
					 | 
				
			||||||
	lc.Add("key3", "val3")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	values := lc.Values()
 | 
					 | 
				
			||||||
	if !reflect.DeepEqual(values, []string{"val1", "val2", "val3"}) {
 | 
					 | 
				
			||||||
		t.Fatalf("values differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// func TestExpirableMultipleClose(_ *testing.T) {
 | 
					 | 
				
			||||||
//	lc :=lru.NewLRU[string, string](10, nil, 0)
 | 
					 | 
				
			||||||
//	lc.Close()
 | 
					 | 
				
			||||||
//	// should not panic
 | 
					 | 
				
			||||||
//	lc.Close()
 | 
					 | 
				
			||||||
// }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestLRUWithPurge(t *testing.T) {
 | 
					 | 
				
			||||||
	var evicted []string
 | 
					 | 
				
			||||||
	lc := lru.NewLRU(10, func(key string, value string) { evicted = append(evicted, key, value) }, 150*time.Millisecond)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	k, v, ok := lc.GetOldest()
 | 
					 | 
				
			||||||
	if k != "" {
 | 
					 | 
				
			||||||
		t.Fatalf("should be empty")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if v != "" {
 | 
					 | 
				
			||||||
		t.Fatalf("should be empty")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should be false")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	lc.Add("key1", "val1")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	time.Sleep(100 * time.Millisecond) // not enough to expire
 | 
					 | 
				
			||||||
	if lc.Len() != 1 {
 | 
					 | 
				
			||||||
		t.Fatalf("length differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	v, ok = lc.Get("key1")
 | 
					 | 
				
			||||||
	if v != "val1" {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if !ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should be true")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	time.Sleep(200 * time.Millisecond) // expire
 | 
					 | 
				
			||||||
	v, ok = lc.Get("key1")
 | 
					 | 
				
			||||||
	if ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should be false")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if v != "" {
 | 
					 | 
				
			||||||
		t.Fatalf("should be nil")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if lc.Len() != 0 {
 | 
					 | 
				
			||||||
		t.Fatalf("length differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if !reflect.DeepEqual(evicted, []string{"key1", "val1"}) {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// add new entry
 | 
					 | 
				
			||||||
	lc.Add("key2", "val2")
 | 
					 | 
				
			||||||
	if lc.Len() != 1 {
 | 
					 | 
				
			||||||
		t.Fatalf("length differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	k, v, ok = lc.GetOldest()
 | 
					 | 
				
			||||||
	if k != "key2" {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if v != "val2" {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if !ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should be true")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestLRUWithPurgeEnforcedBySize(t *testing.T) {
 | 
					 | 
				
			||||||
	lc := lru.NewLRU[string, string](10, nil, time.Hour)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for i := 0; i < 100; i++ {
 | 
					 | 
				
			||||||
		i := i
 | 
					 | 
				
			||||||
		lc.Add(fmt.Sprintf("key%d", i), fmt.Sprintf("val%d", i))
 | 
					 | 
				
			||||||
		v, ok := lc.Get(fmt.Sprintf("key%d", i))
 | 
					 | 
				
			||||||
		if v != fmt.Sprintf("val%d", i) {
 | 
					 | 
				
			||||||
			t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if !ok {
 | 
					 | 
				
			||||||
			t.Fatalf("should be true")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if lc.Len() > 20 {
 | 
					 | 
				
			||||||
			t.Fatalf("length should be less than 20")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if lc.Len() != 10 {
 | 
					 | 
				
			||||||
		t.Fatalf("length differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestLRUConcurrency(t *testing.T) {
 | 
					 | 
				
			||||||
	lc := lru.NewLRU[string, string](0, nil, 0)
 | 
					 | 
				
			||||||
	wg := sync.WaitGroup{}
 | 
					 | 
				
			||||||
	wg.Add(1000)
 | 
					 | 
				
			||||||
	for i := 0; i < 1000; i++ {
 | 
					 | 
				
			||||||
		go func(i int) {
 | 
					 | 
				
			||||||
			lc.Add(fmt.Sprintf("key-%d", i/10), fmt.Sprintf("val-%d", i/10))
 | 
					 | 
				
			||||||
			wg.Done()
 | 
					 | 
				
			||||||
		}(i)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	wg.Wait()
 | 
					 | 
				
			||||||
	if lc.Len() != 100 {
 | 
					 | 
				
			||||||
		t.Fatalf("length differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestLRUInvalidateAndEvict(t *testing.T) {
 | 
					 | 
				
			||||||
	var evicted int
 | 
					 | 
				
			||||||
	lc := lru.NewLRU(-1, func(_, _ string) { evicted++ }, 0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	lc.Add("key1", "val1")
 | 
					 | 
				
			||||||
	lc.Add("key2", "val2")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	val, ok := lc.Get("key1")
 | 
					 | 
				
			||||||
	if !ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should be true")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if val != "val1" {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if evicted != 0 {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	lc.Remove("key1")
 | 
					 | 
				
			||||||
	if evicted != 1 {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	val, ok = lc.Get("key1")
 | 
					 | 
				
			||||||
	if val != "" {
 | 
					 | 
				
			||||||
		t.Fatalf("should be empty")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should be false")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestLoadingExpired(t *testing.T) {
 | 
					 | 
				
			||||||
	lc := lru.NewLRU[string, string](0, nil, time.Millisecond*5)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	lc.Add("key1", "val1")
 | 
					 | 
				
			||||||
	if lc.Len() != 1 {
 | 
					 | 
				
			||||||
		t.Fatalf("length differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	v, ok := lc.Peek("key1")
 | 
					 | 
				
			||||||
	if v != "val1" {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if !ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should be true")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	v, ok = lc.Get("key1")
 | 
					 | 
				
			||||||
	if v != "val1" {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if !ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should be true")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for {
 | 
					 | 
				
			||||||
		result, ok := lc.Get("key1")
 | 
					 | 
				
			||||||
		if ok && result == "" {
 | 
					 | 
				
			||||||
			t.Fatalf("ok should return a result")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if !ok {
 | 
					 | 
				
			||||||
			break
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	time.Sleep(time.Millisecond * 100) // wait for expiration reaper
 | 
					 | 
				
			||||||
	if lc.Len() != 0 {
 | 
					 | 
				
			||||||
		t.Fatalf("length differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	v, ok = lc.Peek("key1")
 | 
					 | 
				
			||||||
	if v != "" {
 | 
					 | 
				
			||||||
		t.Fatalf("should be empty")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should be false")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	v, ok = lc.Get("key1")
 | 
					 | 
				
			||||||
	if v != "" {
 | 
					 | 
				
			||||||
		t.Fatalf("should be empty")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should be false")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestLRURemoveOldest(t *testing.T) {
 | 
					 | 
				
			||||||
	lc := lru.NewLRU[string, string](2, nil, 0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if lc.Cap() != 2 {
 | 
					 | 
				
			||||||
		t.Fatalf("expect cap is 2")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	k, v, ok := lc.RemoveOldest()
 | 
					 | 
				
			||||||
	if k != "" {
 | 
					 | 
				
			||||||
		t.Fatalf("should be empty")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if v != "" {
 | 
					 | 
				
			||||||
		t.Fatalf("should be empty")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should be false")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	ok = lc.Remove("non_existent")
 | 
					 | 
				
			||||||
	if ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should be false")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	lc.Add("key1", "val1")
 | 
					 | 
				
			||||||
	if lc.Len() != 1 {
 | 
					 | 
				
			||||||
		t.Fatalf("length differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	v, ok = lc.Get("key1")
 | 
					 | 
				
			||||||
	if !ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should be true")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if v != "val1" {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !reflect.DeepEqual(lc.Keys(), []string{"key1"}) {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if lc.Len() != 1 {
 | 
					 | 
				
			||||||
		t.Fatalf("length differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	lc.Add("key2", "val2")
 | 
					 | 
				
			||||||
	if !reflect.DeepEqual(lc.Keys(), []string{"key1", "key2"}) {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if lc.Len() != 2 {
 | 
					 | 
				
			||||||
		t.Fatalf("length differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	k, v, ok = lc.RemoveOldest()
 | 
					 | 
				
			||||||
	if k != "key1" {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if v != "val1" {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if !ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should be true")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !reflect.DeepEqual(lc.Keys(), []string{"key2"}) {
 | 
					 | 
				
			||||||
		t.Fatalf("value differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if lc.Len() != 1 {
 | 
					 | 
				
			||||||
		t.Fatalf("length differs from expected")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func getRand(tb testing.TB) int64 {
 | 
					 | 
				
			||||||
	out, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64))
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		tb.Fatal(err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return out.Int64()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@ -2,29 +2,23 @@ package tests_test
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"database/sql"
 | 
					 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"math/rand"
 | 
						"math/rand"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"strconv"
 | 
					 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/stretchr/testify/assert"
 | 
					 | 
				
			||||||
	"gorm.io/driver/gaussdb"
 | 
					 | 
				
			||||||
	"gorm.io/driver/postgres"
 | 
						"gorm.io/driver/postgres"
 | 
				
			||||||
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	"gorm.io/gorm/clause"
 | 
						"gorm.io/gorm/logger"
 | 
				
			||||||
	"gorm.io/gorm/migrator"
 | 
					 | 
				
			||||||
	"gorm.io/gorm/schema"
 | 
						"gorm.io/gorm/schema"
 | 
				
			||||||
	"gorm.io/gorm/utils"
 | 
					 | 
				
			||||||
	. "gorm.io/gorm/utils/tests"
 | 
						. "gorm.io/gorm/utils/tests"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestMigrate(t *testing.T) {
 | 
					func TestMigrate(t *testing.T) {
 | 
				
			||||||
	allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Tools{}}
 | 
						allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}}
 | 
				
			||||||
	rand.Seed(time.Now().UnixNano())
 | 
						rand.Seed(time.Now().UnixNano())
 | 
				
			||||||
	rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
 | 
						rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
 | 
				
			||||||
	DB.Migrator().DropTable("user_speaks", "user_friends", "ccc")
 | 
						DB.Migrator().DropTable("user_speaks", "user_friends", "ccc")
 | 
				
			||||||
@ -40,7 +34,7 @@ func TestMigrate(t *testing.T) {
 | 
				
			|||||||
	if tables, err := DB.Migrator().GetTables(); err != nil {
 | 
						if tables, err := DB.Migrator().GetTables(); err != nil {
 | 
				
			||||||
		t.Fatalf("Failed to get database all tables, but got error %v", err)
 | 
							t.Fatalf("Failed to get database all tables, but got error %v", err)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages", "tools"} {
 | 
							for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages"} {
 | 
				
			||||||
			hasTable := false
 | 
								hasTable := false
 | 
				
			||||||
			for _, t2 := range tables {
 | 
								for _, t2 := range tables {
 | 
				
			||||||
				if t2 == t1 {
 | 
									if t2 == t1 {
 | 
				
			||||||
@ -83,8 +77,8 @@ func TestMigrate(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestAutoMigrateInt8PGAndGaussDB(t *testing.T) {
 | 
					func TestAutoMigrateInt8PG(t *testing.T) {
 | 
				
			||||||
	if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
 | 
						if DB.Dialector.Name() != "postgres" {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -99,8 +93,7 @@ func TestAutoMigrateInt8PGAndGaussDB(t *testing.T) {
 | 
				
			|||||||
		Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
 | 
							Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
 | 
				
			||||||
			sql, _ := fc()
 | 
								sql, _ := fc()
 | 
				
			||||||
			if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") {
 | 
								if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") {
 | 
				
			||||||
				t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s",
 | 
									t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql)
 | 
				
			||||||
					sql)
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -140,137 +133,8 @@ func TestAutoMigrateSelfReferential(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestAutoMigrateNullable(t *testing.T) {
 | 
					 | 
				
			||||||
	type MigrateNullableColumn struct {
 | 
					 | 
				
			||||||
		ID    uint
 | 
					 | 
				
			||||||
		Bonus float64 `gorm:"not null"`
 | 
					 | 
				
			||||||
		Stock float64
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.Migrator().DropTable(&MigrateNullableColumn{})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.AutoMigrate(&MigrateNullableColumn{})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type MigrateNullableColumn2 struct {
 | 
					 | 
				
			||||||
		ID    uint
 | 
					 | 
				
			||||||
		Bonus float64
 | 
					 | 
				
			||||||
		Stock float64 `gorm:"not null"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.Table("migrate_nullable_columns").AutoMigrate(&MigrateNullableColumn2{}); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to auto migrate, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	columnTypes, err := DB.Table("migrate_nullable_columns").Migrator().ColumnTypes(&MigrateNullableColumn{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to get column types, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, columnType := range columnTypes {
 | 
					 | 
				
			||||||
		switch columnType.Name() {
 | 
					 | 
				
			||||||
		case "bonus":
 | 
					 | 
				
			||||||
			// allow to change non-nullable to nullable
 | 
					 | 
				
			||||||
			if nullable, _ := columnType.Nullable(); !nullable {
 | 
					 | 
				
			||||||
				t.Fatalf("bonus's nullable should be true, bug got %t", nullable)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		case "stock":
 | 
					 | 
				
			||||||
			// do not allow to change nullable to non-nullable
 | 
					 | 
				
			||||||
			if nullable, _ := columnType.Nullable(); !nullable {
 | 
					 | 
				
			||||||
				t.Fatalf("stock's nullable should be true, bug got %t", nullable)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestSmartMigrateColumn(t *testing.T) {
 | 
					func TestSmartMigrateColumn(t *testing.T) {
 | 
				
			||||||
	fullSupported := map[string]bool{"mysql": true, "postgres": true, "gaussdb": true}[DB.Dialector.Name()]
 | 
						fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()]
 | 
				
			||||||
 | 
					 | 
				
			||||||
	type UserMigrateColumn struct {
 | 
					 | 
				
			||||||
		ID       uint
 | 
					 | 
				
			||||||
		Name     string
 | 
					 | 
				
			||||||
		Salary   float64
 | 
					 | 
				
			||||||
		Birthday time.Time `gorm:"precision:4"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.Migrator().DropTable(&UserMigrateColumn{})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.AutoMigrate(&UserMigrateColumn{})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type UserMigrateColumn2 struct {
 | 
					 | 
				
			||||||
		ID                  uint
 | 
					 | 
				
			||||||
		Name                string    `gorm:"size:128"`
 | 
					 | 
				
			||||||
		Salary              float64   `gorm:"precision:2"`
 | 
					 | 
				
			||||||
		Birthday            time.Time `gorm:"precision:2"`
 | 
					 | 
				
			||||||
		NameIgnoreMigration string    `gorm:"size:100"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to auto migrate, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to get column types, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, columnType := range columnTypes {
 | 
					 | 
				
			||||||
		switch columnType.Name() {
 | 
					 | 
				
			||||||
		case "name":
 | 
					 | 
				
			||||||
			if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 128 {
 | 
					 | 
				
			||||||
				t.Fatalf("name's length should be 128, but got %v", length)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		case "salary":
 | 
					 | 
				
			||||||
			if precision, o, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
 | 
					 | 
				
			||||||
				t.Fatalf("salary's precision should be 2, but got %v %v", precision, o)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		case "birthday":
 | 
					 | 
				
			||||||
			if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
 | 
					 | 
				
			||||||
				t.Fatalf("birthday's precision should be 2, but got %v", precision)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type UserMigrateColumn3 struct {
 | 
					 | 
				
			||||||
		ID                  uint
 | 
					 | 
				
			||||||
		Name                string    `gorm:"size:256"`
 | 
					 | 
				
			||||||
		Salary              float64   `gorm:"precision:3"`
 | 
					 | 
				
			||||||
		Birthday            time.Time `gorm:"precision:3"`
 | 
					 | 
				
			||||||
		NameIgnoreMigration string    `gorm:"size:128;-:migration"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to auto migrate, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to get column types, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, columnType := range columnTypes {
 | 
					 | 
				
			||||||
		switch columnType.Name() {
 | 
					 | 
				
			||||||
		case "name":
 | 
					 | 
				
			||||||
			if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 256 {
 | 
					 | 
				
			||||||
				t.Fatalf("name's length should be 128, but got %v", length)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		case "salary":
 | 
					 | 
				
			||||||
			if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
 | 
					 | 
				
			||||||
				t.Fatalf("salary's precision should be 2, but got %v", precision)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		case "birthday":
 | 
					 | 
				
			||||||
			if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
 | 
					 | 
				
			||||||
				t.Fatalf("birthday's precision should be 2, but got %v", precision)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		case "name_ignore_migration":
 | 
					 | 
				
			||||||
			if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 100 {
 | 
					 | 
				
			||||||
				t.Fatalf("name_ignore_migration's length should still be 100 but got %v", length)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestSmartMigrateColumnGaussDB(t *testing.T) {
 | 
					 | 
				
			||||||
	fullSupported := map[string]bool{"mysql": true, "gaussdb": true}[DB.Dialector.Name()]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	type UserMigrateColumn struct {
 | 
						type UserMigrateColumn struct {
 | 
				
			||||||
		ID       uint
 | 
							ID       uint
 | 
				
			||||||
@ -568,50 +432,40 @@ func TestTiDBMigrateColumns(t *testing.T) {
 | 
				
			|||||||
			switch columnType.Name() {
 | 
								switch columnType.Name() {
 | 
				
			||||||
			case "id":
 | 
								case "id":
 | 
				
			||||||
				if v, ok := columnType.PrimaryKey(); !ok || !v {
 | 
									if v, ok := columnType.PrimaryKey(); !ok || !v {
 | 
				
			||||||
					t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(),
 | 
										t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
						columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			case "name":
 | 
								case "name":
 | 
				
			||||||
				dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
 | 
									dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
 | 
				
			||||||
				if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
 | 
									if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
 | 
				
			||||||
					t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v",
 | 
										t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
 | 
				
			||||||
						columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				if length, ok := columnType.Length(); !ok || length != 100 {
 | 
									if length, ok := columnType.Length(); !ok || length != 100 {
 | 
				
			||||||
					t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v",
 | 
										t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType)
 | 
				
			||||||
						columnType.Name(), length, 100, columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			case "age":
 | 
								case "age":
 | 
				
			||||||
				if v, ok := columnType.DefaultValue(); !ok || v != "18" {
 | 
									if v, ok := columnType.DefaultValue(); !ok || v != "18" {
 | 
				
			||||||
					t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(),
 | 
										t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
						columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				if v, ok := columnType.Comment(); !ok || v != "my age" {
 | 
									if v, ok := columnType.Comment(); !ok || v != "my age" {
 | 
				
			||||||
					t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(),
 | 
										t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
						columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			case "code":
 | 
								case "code":
 | 
				
			||||||
				if v, ok := columnType.Unique(); !ok || !v {
 | 
									if v, ok := columnType.Unique(); !ok || !v {
 | 
				
			||||||
					t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(),
 | 
										t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
						columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				if v, ok := columnType.DefaultValue(); !ok || v != "hello" {
 | 
									if v, ok := columnType.DefaultValue(); !ok || v != "hello" {
 | 
				
			||||||
					t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v",
 | 
										t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
 | 
				
			||||||
						columnType.Name(), columnType, v)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				if v, ok := columnType.Comment(); !ok || v != "my code2" {
 | 
									if v, ok := columnType.Comment(); !ok || v != "my code2" {
 | 
				
			||||||
					t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(),
 | 
										t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
						columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			case "code2":
 | 
								case "code2":
 | 
				
			||||||
				// Code2 string `gorm:"comment:my code2;default:hello"`
 | 
									// Code2 string `gorm:"comment:my code2;default:hello"`
 | 
				
			||||||
				if v, ok := columnType.DefaultValue(); !ok || v != "hello" {
 | 
									if v, ok := columnType.DefaultValue(); !ok || v != "hello" {
 | 
				
			||||||
					t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v",
 | 
										t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
 | 
				
			||||||
						columnType.Name(), columnType, v)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				if v, ok := columnType.Comment(); !ok || v != "my code2" {
 | 
									if v, ok := columnType.Comment(); !ok || v != "my code2" {
 | 
				
			||||||
					t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(),
 | 
										t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
						columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -643,8 +497,7 @@ func TestTiDBMigrateColumns(t *testing.T) {
 | 
				
			|||||||
		t.Fatalf("Failed to add column, got %v", err)
 | 
							t.Fatalf("Failed to add column, got %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName",
 | 
						if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil {
 | 
				
			||||||
		"new_new_name"); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Failed to add column, got %v", err)
 | 
							t.Fatalf("Failed to add column, got %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -708,45 +561,36 @@ func TestMigrateColumns(t *testing.T) {
 | 
				
			|||||||
			switch columnType.Name() {
 | 
								switch columnType.Name() {
 | 
				
			||||||
			case "id":
 | 
								case "id":
 | 
				
			||||||
				if v, ok := columnType.PrimaryKey(); !ok || !v {
 | 
									if v, ok := columnType.PrimaryKey(); !ok || !v {
 | 
				
			||||||
					t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(),
 | 
										t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
						columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			case "name":
 | 
								case "name":
 | 
				
			||||||
				dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
 | 
									dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
 | 
				
			||||||
				if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
 | 
									if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
 | 
				
			||||||
					t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v",
 | 
										t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
 | 
				
			||||||
						columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) {
 | 
									if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) {
 | 
				
			||||||
					t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v",
 | 
										t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType)
 | 
				
			||||||
						columnType.Name(), length, 100, columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			case "age":
 | 
								case "age":
 | 
				
			||||||
				if v, ok := columnType.DefaultValue(); !ok || v != "18" {
 | 
									if v, ok := columnType.DefaultValue(); !ok || v != "18" {
 | 
				
			||||||
					t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(),
 | 
										t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
						columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") {
 | 
									if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") {
 | 
				
			||||||
					t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(),
 | 
										t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
						columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			case "code":
 | 
								case "code":
 | 
				
			||||||
				if v, ok := columnType.Unique(); !ok || !v {
 | 
									if v, ok := columnType.Unique(); !ok || !v {
 | 
				
			||||||
					t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(),
 | 
										t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
						columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") {
 | 
									if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") {
 | 
				
			||||||
					t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v",
 | 
										t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
 | 
				
			||||||
						columnType.Name(), columnType, v)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") {
 | 
									if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") {
 | 
				
			||||||
					t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(),
 | 
										t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
						columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			case "code2":
 | 
								case "code2":
 | 
				
			||||||
				if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) {
 | 
									if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) {
 | 
				
			||||||
					t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(),
 | 
										t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
						columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			case "code3":
 | 
								case "code3":
 | 
				
			||||||
				// TODO
 | 
									// TODO
 | 
				
			||||||
@ -783,8 +627,7 @@ func TestMigrateColumns(t *testing.T) {
 | 
				
			|||||||
		t.Fatalf("Failed to add column, got %v", err)
 | 
							t.Fatalf("Failed to add column, got %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName",
 | 
						if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil {
 | 
				
			||||||
		"new_new_name"); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Failed to add column, got %v", err)
 | 
							t.Fatalf("Failed to add column, got %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -938,7 +781,7 @@ func TestMigrateColumnOrder(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// https://github.com/go-gorm/gorm/issues/5047
 | 
					// https://github.com/go-gorm/gorm/issues/5047
 | 
				
			||||||
func TestMigrateSerialColumn(t *testing.T) {
 | 
					func TestMigrateSerialColumn(t *testing.T) {
 | 
				
			||||||
	if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
 | 
						if DB.Dialector.Name() != "postgres" {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1019,48 +862,6 @@ func TestMigrateWithSpecialName(t *testing.T) {
 | 
				
			|||||||
	AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2"))
 | 
						AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2"))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// https://github.com/go-gorm/gorm/issues/4760
 | 
					 | 
				
			||||||
func TestMigrateAutoIncrement(t *testing.T) {
 | 
					 | 
				
			||||||
	type AutoIncrementStruct struct {
 | 
					 | 
				
			||||||
		ID     int64   `gorm:"primarykey;autoIncrement"`
 | 
					 | 
				
			||||||
		Field1 uint32  `gorm:"column:field1"`
 | 
					 | 
				
			||||||
		Field2 float32 `gorm:"column:field2"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.AutoMigrate(&AutoIncrementStruct{}); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("AutoMigrate err: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	const ROWS = 10
 | 
					 | 
				
			||||||
	for idx := 0; idx < ROWS; idx++ {
 | 
					 | 
				
			||||||
		if err := DB.Create(&AutoIncrementStruct{}).Error; err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("create auto_increment_struct fail, err: %v", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	rows := make([]*AutoIncrementStruct, 0, ROWS)
 | 
					 | 
				
			||||||
	if err := DB.Order("id ASC").Find(&rows).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("find auto_increment_struct fail, err: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	ids := make([]int64, 0, len(rows))
 | 
					 | 
				
			||||||
	for _, row := range rows {
 | 
					 | 
				
			||||||
		ids = append(ids, row.ID)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	lastID := ids[len(ids)-1]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.Where("id IN (?)", ids).Delete(&AutoIncrementStruct{}).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("delete auto_increment_struct fail, err: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	newRow := &AutoIncrementStruct{}
 | 
					 | 
				
			||||||
	if err := DB.Create(newRow).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("create auto_increment_struct fail, err: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	AssertEqual(t, newRow.ID, lastID+1)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// https://github.com/go-gorm/gorm/issues/5320
 | 
					// https://github.com/go-gorm/gorm/issues/5320
 | 
				
			||||||
func TestPrimarykeyID(t *testing.T) {
 | 
					func TestPrimarykeyID(t *testing.T) {
 | 
				
			||||||
	if DB.Dialector.Name() != "postgres" {
 | 
						if DB.Dialector.Name() != "postgres" {
 | 
				
			||||||
@ -1097,42 +898,6 @@ func TestPrimarykeyID(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestPrimarykeyIDGaussDB(t *testing.T) {
 | 
					 | 
				
			||||||
	t.Skipf("This test case skipped, because of gaussdb not support uuid-ossp plugin (SQLSTATE 58P01)")
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() != "gaussdb" {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type MissPKLanguage struct {
 | 
					 | 
				
			||||||
		ID   string `gorm:"type:uuid;default:uuid_generate_v4()"`
 | 
					 | 
				
			||||||
		Name string
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type MissPKUser struct {
 | 
					 | 
				
			||||||
		ID              string           `gorm:"type:uuid;default:uuid_generate_v4()"`
 | 
					 | 
				
			||||||
		MissPKLanguages []MissPKLanguage `gorm:"many2many:miss_pk_user_languages;"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var err error
 | 
					 | 
				
			||||||
	err = DB.Migrator().DropTable(&MissPKUser{}, &MissPKLanguage{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("DropTable err:%v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	// TODO: ERROR: could not open extension control file: No such file or directory (SQLSTATE 58P01)
 | 
					 | 
				
			||||||
	DB.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("AutoMigrate err:%v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// patch
 | 
					 | 
				
			||||||
	err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("AutoMigrate err:%v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestCurrentTimestamp(t *testing.T) {
 | 
					func TestCurrentTimestamp(t *testing.T) {
 | 
				
			||||||
	if DB.Dialector.Name() != "mysql" {
 | 
						if DB.Dialector.Name() != "mysql" {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
@ -1155,8 +920,7 @@ func TestCurrentTimestamp(t *testing.T) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatalf("AutoMigrate err:%v", err)
 | 
							t.Fatalf("AutoMigrate err:%v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	AssertEqual(t, true, DB.Migrator().HasConstraint(&CurrentTimestampTest{}, "uni_current_timestamp_tests_time_at"))
 | 
						AssertEqual(t, true, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at"))
 | 
				
			||||||
	AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at"))
 | 
					 | 
				
			||||||
	AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at_2"))
 | 
						AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at_2"))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1218,8 +982,7 @@ func TestUniqueColumn(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// not trigger alert column
 | 
						// not trigger alert column
 | 
				
			||||||
	AssertEqual(t, true, DB.Migrator().HasConstraint(&UniqueTest{}, "uni_unique_tests_name"))
 | 
						AssertEqual(t, true, DB.Migrator().HasIndex(&UniqueTest{}, "name"))
 | 
				
			||||||
	AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name"))
 | 
					 | 
				
			||||||
	AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_1"))
 | 
						AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_1"))
 | 
				
			||||||
	AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_2"))
 | 
						AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_2"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1333,24 +1096,35 @@ func TestInvalidCachedPlanSimpleProtocol(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TODO: ERROR: must have at least one column (SQLSTATE 0A000)
 | 
					func TestInvalidCachedPlanPrepareStmt(t *testing.T) {
 | 
				
			||||||
func TestInvalidCachedPlanSimpleProtocolGaussDB(t *testing.T) {
 | 
						if DB.Dialector.Name() != "postgres" {
 | 
				
			||||||
	t.Skipf("This test case skipped, because of gaussdb not support creaing empty table(SQLSTATE 0A000)")
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() != "gaussdb" {
 | 
					 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	db, err := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{})
 | 
						db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Errorf("Open err:%v", err)
 | 
							t.Errorf("Open err:%v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						if debug := os.Getenv("DEBUG"); debug == "true" {
 | 
				
			||||||
 | 
							db.Logger = db.Logger.LogMode(logger.Info)
 | 
				
			||||||
 | 
						} else if debug == "false" {
 | 
				
			||||||
 | 
							db.Logger = db.Logger.LogMode(logger.Silent)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	type Object1 struct{}
 | 
						type Object1 struct {
 | 
				
			||||||
 | 
							ID uint
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	type Object2 struct {
 | 
						type Object2 struct {
 | 
				
			||||||
		Field1 string
 | 
							ID     uint
 | 
				
			||||||
 | 
							Field1 int `gorm:"type:int8"`
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	type Object3 struct {
 | 
						type Object3 struct {
 | 
				
			||||||
		Field2 string
 | 
							ID     uint
 | 
				
			||||||
 | 
							Field1 int `gorm:"type:int4"`
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						type Object4 struct {
 | 
				
			||||||
 | 
							ID     uint
 | 
				
			||||||
 | 
							Field2 int
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	db.Migrator().DropTable("objects")
 | 
						db.Migrator().DropTable("objects")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1358,16 +1132,63 @@ func TestInvalidCachedPlanSimpleProtocolGaussDB(t *testing.T) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Errorf("AutoMigrate err:%v", err)
 | 
							t.Errorf("AutoMigrate err:%v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						err = db.Table("objects").Create(&Object1{}).Error
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("create err:%v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// AddColumn
 | 
				
			||||||
	err = db.Table("objects").AutoMigrate(&Object2{})
 | 
						err = db.Table("objects").AutoMigrate(&Object2{})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Errorf("AutoMigrate err:%v", err)
 | 
							t.Errorf("AutoMigrate err:%v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = db.Table("objects").Take(&Object2{}).Error
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("take err:%v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// AlterColumn
 | 
				
			||||||
	err = db.Table("objects").AutoMigrate(&Object3{})
 | 
						err = db.Table("objects").AutoMigrate(&Object3{})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Errorf("AutoMigrate err:%v", err)
 | 
							t.Errorf("AutoMigrate err:%v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = db.Table("objects").Take(&Object3{}).Error
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("take err:%v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// AddColumn
 | 
				
			||||||
 | 
						err = db.Table("objects").AutoMigrate(&Object4{})
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("AutoMigrate err:%v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = db.Table("objects").Take(&Object4{}).Error
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("take err:%v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("RenameColumn err:%v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = db.Table("objects").Take(&Object4{}).Error
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("take err:%v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						db.Table("objects").Migrator().DropColumn(&Object4{}, "field3")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("RenameColumn err:%v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = db.Table("objects").Take(&Object4{}).Error
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("take err:%v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestDifferentTypeWithoutDeclaredLength(t *testing.T) {
 | 
					func TestDifferentTypeWithoutDeclaredLength(t *testing.T) {
 | 
				
			||||||
@ -1410,7 +1231,7 @@ func TestDifferentTypeWithoutDeclaredLength(t *testing.T) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestMigrateArrayTypeModel(t *testing.T) {
 | 
					func TestMigrateArrayTypeModel(t *testing.T) {
 | 
				
			||||||
	if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
 | 
						if DB.Dialector.Name() != "postgres" {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1520,14 +1341,14 @@ func TestMigrateSameEmbeddedFieldName(t *testing.T) {
 | 
				
			|||||||
	err = DB.Table("game_users").AutoMigrate(&GameUser1{})
 | 
						err = DB.Table("game_users").AutoMigrate(&GameUser1{})
 | 
				
			||||||
	AssertEqual(t, nil, err)
 | 
						AssertEqual(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_, err = findColumnType(&GameUser{}, "stat_ab_ground_destroy_count")
 | 
						_, err = findColumnType(&GameUser{}, "stat_ab_ground_destory_count")
 | 
				
			||||||
	AssertEqual(t, nil, err)
 | 
						AssertEqual(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destroy_count")
 | 
						_, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destory_count")
 | 
				
			||||||
	AssertEqual(t, nil, err)
 | 
						AssertEqual(t, nil, err)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestMigrateWithDefaultValue(t *testing.T) {
 | 
					func TestMigrateDefaultNullString(t *testing.T) {
 | 
				
			||||||
	if DB.Dialector.Name() == "sqlserver" {
 | 
						if DB.Dialector.Name() == "sqlserver" {
 | 
				
			||||||
		// sqlserver driver treats NULL and 'NULL' the same
 | 
							// sqlserver driver treats NULL and 'NULL' the same
 | 
				
			||||||
		t.Skip("skip sqlserver")
 | 
							t.Skip("skip sqlserver")
 | 
				
			||||||
@ -1541,7 +1362,6 @@ func TestMigrateWithDefaultValue(t *testing.T) {
 | 
				
			|||||||
	type NullStringModel struct {
 | 
						type NullStringModel struct {
 | 
				
			||||||
		ID      uint
 | 
							ID      uint
 | 
				
			||||||
		Content string `gorm:"default:'null'"`
 | 
							Content string `gorm:"default:'null'"`
 | 
				
			||||||
		Active  bool   `gorm:"default:false"`
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tableName := "null_string_model"
 | 
						tableName := "null_string_model"
 | 
				
			||||||
@ -1562,14 +1382,6 @@ func TestMigrateWithDefaultValue(t *testing.T) {
 | 
				
			|||||||
	AssertEqual(t, defVal, "null")
 | 
						AssertEqual(t, defVal, "null")
 | 
				
			||||||
	AssertEqual(t, ok, true)
 | 
						AssertEqual(t, ok, true)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	columnType2, err := findColumnType(tableName, "active")
 | 
					 | 
				
			||||||
	AssertEqual(t, err, nil)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	defVal, ok = columnType2.DefaultValue()
 | 
					 | 
				
			||||||
	bv, _ := strconv.ParseBool(defVal)
 | 
					 | 
				
			||||||
	AssertEqual(t, bv, false)
 | 
					 | 
				
			||||||
	AssertEqual(t, ok, true)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// default 'null' -> 'null'
 | 
						// default 'null' -> 'null'
 | 
				
			||||||
	session := DB.Session(&gorm.Session{Logger: Tracer{
 | 
						session := DB.Session(&gorm.Session{Logger: Tracer{
 | 
				
			||||||
		Logger: DB.Config.Logger,
 | 
							Logger: DB.Config.Logger,
 | 
				
			||||||
@ -1701,8 +1513,7 @@ func TestMigrateIgnoreRelations(t *testing.T) {
 | 
				
			|||||||
func TestMigrateView(t *testing.T) {
 | 
					func TestMigrateView(t *testing.T) {
 | 
				
			||||||
	DB.Save(GetUser("joins-args-db", Config{Pets: 2}))
 | 
						DB.Save(GetUser("joins-args-db", Config{Pets: 2}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := DB.Migrator().CreateView("invalid_users_pets",
 | 
						if err := DB.Migrator().CreateView("invalid_users_pets", gorm.ViewOption{Query: nil}); err != gorm.ErrSubQueryRequired {
 | 
				
			||||||
		gorm.ViewOption{Query: nil}); err != gorm.ErrSubQueryRequired {
 | 
					 | 
				
			||||||
		t.Fatalf("no view should be created, got %v", err)
 | 
							t.Fatalf("no view should be created, got %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1732,8 +1543,8 @@ func TestMigrateView(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestMigrateExistingBoolColumnPGAndGaussDB(t *testing.T) {
 | 
					func TestMigrateExistingBoolColumnPG(t *testing.T) {
 | 
				
			||||||
	if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
 | 
						if DB.Dialector.Name() != "postgres" {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1771,20 +1582,17 @@ func TestMigrateExistingBoolColumnPGAndGaussDB(t *testing.T) {
 | 
				
			|||||||
			switch columnType.Name() {
 | 
								switch columnType.Name() {
 | 
				
			||||||
			case "id":
 | 
								case "id":
 | 
				
			||||||
				if v, ok := columnType.PrimaryKey(); !ok || !v {
 | 
									if v, ok := columnType.PrimaryKey(); !ok || !v {
 | 
				
			||||||
					t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(),
 | 
										t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
						columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			case "string_bool":
 | 
								case "string_bool":
 | 
				
			||||||
				dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
 | 
									dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
 | 
				
			||||||
				if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
 | 
									if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
 | 
				
			||||||
					t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v",
 | 
										t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
 | 
				
			||||||
						columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			case "smallint_bool":
 | 
								case "smallint_bool":
 | 
				
			||||||
				dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
 | 
									dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
 | 
				
			||||||
				if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
 | 
									if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
 | 
				
			||||||
					t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v",
 | 
										t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
 | 
				
			||||||
						columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -1809,8 +1617,7 @@ func TestTableType(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	DB.Migrator().DropTable(&City{})
 | 
						DB.Migrator().DropTable(&City{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := DB.Set("gorm:table_options",
 | 
						if err := DB.Set("gorm:table_options", fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil {
 | 
				
			||||||
		fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to migrate cities tables, got error: %v", err)
 | 
							t.Fatalf("failed to migrate cities tables, got error: %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1836,329 +1643,3 @@ func TestTableType(t *testing.T) {
 | 
				
			|||||||
		t.Fatalf("expected comment %s got %s", tblComment, comment)
 | 
							t.Fatalf("expected comment %s got %s", tblComment, comment)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestMigrateWithUniqueIndexAndUnique(t *testing.T) {
 | 
					 | 
				
			||||||
	const table = "unique_struct"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	checkField := func(model interface{}, fieldName string, unique bool, uniqueIndex string) {
 | 
					 | 
				
			||||||
		stmt := &gorm.Statement{DB: DB}
 | 
					 | 
				
			||||||
		err := stmt.Parse(model)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("%v: failed to parse schema, got error: %v", utils.FileWithLineNum(), err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		_ = stmt.Schema.ParseIndexes()
 | 
					 | 
				
			||||||
		field := stmt.Schema.LookUpField(fieldName)
 | 
					 | 
				
			||||||
		if field == nil {
 | 
					 | 
				
			||||||
			t.Fatalf("%v: failed to find column %q", utils.FileWithLineNum(), fieldName)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if field.Unique != unique {
 | 
					 | 
				
			||||||
			t.Fatalf("%v: %q column %q unique should be %v but got %v", utils.FileWithLineNum(), stmt.Schema.Table, fieldName, unique, field.Unique)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if field.UniqueIndex != uniqueIndex {
 | 
					 | 
				
			||||||
			t.Fatalf("%v: %q column %q uniqueIndex should be %v but got %v", utils.FileWithLineNum(), stmt.Schema, fieldName, uniqueIndex, field.UniqueIndex)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type ( // not unique
 | 
					 | 
				
			||||||
		UniqueStruct1 struct {
 | 
					 | 
				
			||||||
			Name string `gorm:"size:10"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		UniqueStruct2 struct {
 | 
					 | 
				
			||||||
			Name string `gorm:"size:20"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	checkField(&UniqueStruct1{}, "name", false, "")
 | 
					 | 
				
			||||||
	checkField(&UniqueStruct2{}, "name", false, "")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type ( // unique
 | 
					 | 
				
			||||||
		UniqueStruct3 struct {
 | 
					 | 
				
			||||||
			Name string `gorm:"size:30;unique"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		UniqueStruct4 struct {
 | 
					 | 
				
			||||||
			Name string `gorm:"size:40;unique"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	checkField(&UniqueStruct3{}, "name", true, "")
 | 
					 | 
				
			||||||
	checkField(&UniqueStruct4{}, "name", true, "")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type ( // uniqueIndex
 | 
					 | 
				
			||||||
		UniqueStruct5 struct {
 | 
					 | 
				
			||||||
			Name string `gorm:"size:50;uniqueIndex"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		UniqueStruct6 struct {
 | 
					 | 
				
			||||||
			Name string `gorm:"size:60;uniqueIndex"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		UniqueStruct7 struct {
 | 
					 | 
				
			||||||
			Name     string `gorm:"size:70;uniqueIndex:idx_us6_all_names"`
 | 
					 | 
				
			||||||
			NickName string `gorm:"size:70;uniqueIndex:idx_us6_all_names"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	checkField(&UniqueStruct5{}, "name", false, "idx_unique_struct5_name")
 | 
					 | 
				
			||||||
	checkField(&UniqueStruct6{}, "name", false, "idx_unique_struct6_name")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	checkField(&UniqueStruct7{}, "name", false, "")
 | 
					 | 
				
			||||||
	checkField(&UniqueStruct7{}, "nick_name", false, "")
 | 
					 | 
				
			||||||
	checkField(&UniqueStruct7{}, "nick_name", false, "")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type UniqueStruct8 struct { // unique and uniqueIndex
 | 
					 | 
				
			||||||
		Name string `gorm:"size:60;unique;index:my_us8_index,unique;"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	checkField(&UniqueStruct8{}, "name", true, "my_us8_index")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type TestCase struct {
 | 
					 | 
				
			||||||
		name      string
 | 
					 | 
				
			||||||
		from, to  interface{}
 | 
					 | 
				
			||||||
		checkFunc func(t *testing.T)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	checkColumnType := func(t *testing.T, fieldName string, unique bool) {
 | 
					 | 
				
			||||||
		columnTypes, err := DB.Migrator().ColumnTypes(table)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("%v: failed to get column types, got error: %v", utils.FileWithLineNum(), err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		var found gorm.ColumnType
 | 
					 | 
				
			||||||
		for _, columnType := range columnTypes {
 | 
					 | 
				
			||||||
			if columnType.Name() == fieldName {
 | 
					 | 
				
			||||||
				found = columnType
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if found == nil {
 | 
					 | 
				
			||||||
			t.Fatalf("%v: failed to find column type %q", utils.FileWithLineNum(), fieldName)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if actualUnique, ok := found.Unique(); !ok || actualUnique != unique {
 | 
					 | 
				
			||||||
			t.Fatalf("%v: column %q unique should be %v but got %v", utils.FileWithLineNum(), fieldName, unique, actualUnique)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	checkIndex := func(t *testing.T, expected []gorm.Index) {
 | 
					 | 
				
			||||||
		indexes, err := DB.Migrator().GetIndexes(table)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("%v: failed to get indexes, got error: %v", utils.FileWithLineNum(), err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		assert.ElementsMatch(t, expected, indexes)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	uniqueIndex := &migrator.Index{TableName: table, NameValue: DB.Config.NamingStrategy.IndexName(table, "name"), ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}}
 | 
					 | 
				
			||||||
	myIndex := &migrator.Index{TableName: table, NameValue: "my_us8_index", ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}}
 | 
					 | 
				
			||||||
	mulIndex := &migrator.Index{TableName: table, NameValue: "idx_us6_all_names", ColumnList: []string{"name", "nick_name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var checkNotUnique, checkUnique, checkUniqueIndex, checkMyIndex, checkMulIndex func(t *testing.T)
 | 
					 | 
				
			||||||
	// UniqueAffectedByUniqueIndex is true
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() == "mysql" {
 | 
					 | 
				
			||||||
		uniqueConstraintIndex := &migrator.Index{TableName: table, NameValue: DB.Config.NamingStrategy.UniqueName(table, "name"), ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}}
 | 
					 | 
				
			||||||
		checkNotUnique = func(t *testing.T) {
 | 
					 | 
				
			||||||
			checkColumnType(t, "name", false)
 | 
					 | 
				
			||||||
			checkIndex(t, nil)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		checkUnique = func(t *testing.T) {
 | 
					 | 
				
			||||||
			checkColumnType(t, "name", true)
 | 
					 | 
				
			||||||
			checkIndex(t, []gorm.Index{uniqueConstraintIndex})
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		checkUniqueIndex = func(t *testing.T) {
 | 
					 | 
				
			||||||
			checkColumnType(t, "name", true)
 | 
					 | 
				
			||||||
			checkIndex(t, []gorm.Index{uniqueIndex})
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		checkMyIndex = func(t *testing.T) {
 | 
					 | 
				
			||||||
			checkColumnType(t, "name", true)
 | 
					 | 
				
			||||||
			checkIndex(t, []gorm.Index{uniqueConstraintIndex, myIndex})
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		checkMulIndex = func(t *testing.T) {
 | 
					 | 
				
			||||||
			checkColumnType(t, "name", false)
 | 
					 | 
				
			||||||
			checkColumnType(t, "nick_name", false)
 | 
					 | 
				
			||||||
			checkIndex(t, []gorm.Index{mulIndex})
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		checkNotUnique = func(t *testing.T) { checkColumnType(t, "name", false) }
 | 
					 | 
				
			||||||
		checkUnique = func(t *testing.T) { checkColumnType(t, "name", true) }
 | 
					 | 
				
			||||||
		checkUniqueIndex = func(t *testing.T) {
 | 
					 | 
				
			||||||
			checkColumnType(t, "name", false)
 | 
					 | 
				
			||||||
			checkIndex(t, []gorm.Index{uniqueIndex})
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		checkMyIndex = func(t *testing.T) {
 | 
					 | 
				
			||||||
			checkColumnType(t, "name", true)
 | 
					 | 
				
			||||||
			if !DB.Migrator().HasIndex(table, myIndex.Name()) {
 | 
					 | 
				
			||||||
				t.Errorf("%v: should has index %s but not", utils.FileWithLineNum(), myIndex.Name())
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		checkMulIndex = func(t *testing.T) {
 | 
					 | 
				
			||||||
			checkColumnType(t, "name", false)
 | 
					 | 
				
			||||||
			checkColumnType(t, "nick_name", false)
 | 
					 | 
				
			||||||
			if !DB.Migrator().HasIndex(table, mulIndex.Name()) {
 | 
					 | 
				
			||||||
				t.Errorf("%v: should has index %s but not", utils.FileWithLineNum(), mulIndex.Name())
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	tests := []TestCase{
 | 
					 | 
				
			||||||
		{name: "notUnique to notUnique", from: &UniqueStruct1{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique},
 | 
					 | 
				
			||||||
		{name: "notUnique to unique", from: &UniqueStruct1{}, to: &UniqueStruct3{}, checkFunc: checkUnique},
 | 
					 | 
				
			||||||
		{name: "notUnique to uniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex},
 | 
					 | 
				
			||||||
		{name: "notUnique to uniqueAndUniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex},
 | 
					 | 
				
			||||||
		{name: "unique to unique", from: &UniqueStruct3{}, to: &UniqueStruct4{}, checkFunc: checkUnique},
 | 
					 | 
				
			||||||
		{name: "unique to uniqueIndex", from: &UniqueStruct3{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex},
 | 
					 | 
				
			||||||
		{name: "unique to uniqueAndUniqueIndex", from: &UniqueStruct3{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex},
 | 
					 | 
				
			||||||
		{name: "uniqueIndex to uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct6{}, checkFunc: checkUniqueIndex},
 | 
					 | 
				
			||||||
		{name: "uniqueIndex to uniqueAndUniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex},
 | 
					 | 
				
			||||||
		{name: "uniqueIndex to multi uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct7{}, checkFunc: checkMulIndex},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	for _, test := range tests {
 | 
					 | 
				
			||||||
		t.Run(test.name, func(t *testing.T) {
 | 
					 | 
				
			||||||
			if err := DB.Migrator().DropTable(table); err != nil {
 | 
					 | 
				
			||||||
				t.Fatalf("failed to drop table, got error: %v", err)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			if err := DB.Table(table).AutoMigrate(test.from); err != nil {
 | 
					 | 
				
			||||||
				t.Fatalf("failed to migrate table, got error: %v", err)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			if err := DB.Table(table).AutoMigrate(test.to); err != nil {
 | 
					 | 
				
			||||||
				t.Fatalf("failed to migrate table, got error: %v", err)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			test.checkFunc(t)
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() != "sqlserver" {
 | 
					 | 
				
			||||||
		// In SQLServer, If an index or constraint depends on the column,
 | 
					 | 
				
			||||||
		// this column will not be able to run ALTER
 | 
					 | 
				
			||||||
		// see https://stackoverflow.com/questions/19460912/the-object-df-is-dependent-on-column-changing-int-to-double/19461205#19461205
 | 
					 | 
				
			||||||
		// may we need to create another PR to fix it, see https://github.com/go-gorm/sqlserver/pull/106
 | 
					 | 
				
			||||||
		tests = []TestCase{
 | 
					 | 
				
			||||||
			{name: "unique to notUnique", from: &UniqueStruct3{}, to: &UniqueStruct1{}, checkFunc: checkNotUnique},
 | 
					 | 
				
			||||||
			{name: "uniqueIndex to notUnique", from: &UniqueStruct5{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique},
 | 
					 | 
				
			||||||
			{name: "uniqueIndex to unique", from: &UniqueStruct5{}, to: &UniqueStruct3{}, checkFunc: checkUnique},
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() == "mysql" {
 | 
					 | 
				
			||||||
		compatibilityTests := []TestCase{
 | 
					 | 
				
			||||||
			{name: "oldUnique to notUnique", to: UniqueStruct1{}, checkFunc: checkNotUnique},
 | 
					 | 
				
			||||||
			{name: "oldUnique to unique", to: UniqueStruct3{}, checkFunc: checkUnique},
 | 
					 | 
				
			||||||
			{name: "oldUnique to uniqueIndex", to: UniqueStruct5{}, checkFunc: checkUniqueIndex},
 | 
					 | 
				
			||||||
			{name: "oldUnique to uniqueAndUniqueIndex", to: UniqueStruct8{}, checkFunc: checkMyIndex},
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		for _, test := range compatibilityTests {
 | 
					 | 
				
			||||||
			t.Run(test.name, func(t *testing.T) {
 | 
					 | 
				
			||||||
				if err := DB.Migrator().DropTable(table); err != nil {
 | 
					 | 
				
			||||||
					t.Fatalf("failed to drop table, got error: %v", err)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				if err := DB.Exec("CREATE TABLE ? (`name` varchar(10) UNIQUE)", clause.Table{Name: table}).Error; err != nil {
 | 
					 | 
				
			||||||
					t.Fatalf("failed to create table, got error: %v", err)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				if err := DB.Table(table).AutoMigrate(test.to); err != nil {
 | 
					 | 
				
			||||||
					t.Fatalf("failed to migrate table, got error: %v", err)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				test.checkFunc(t)
 | 
					 | 
				
			||||||
			})
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func testAutoMigrateDecimal(t *testing.T, model1, model2 any) []string {
 | 
					 | 
				
			||||||
	tracer := Tracer{
 | 
					 | 
				
			||||||
		Logger: DB.Config.Logger,
 | 
					 | 
				
			||||||
		Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
 | 
					 | 
				
			||||||
			sql, _ := fc()
 | 
					 | 
				
			||||||
			if strings.HasPrefix(sql, "ALTER TABLE ") {
 | 
					 | 
				
			||||||
				t.Fatalf("shouldn't execute ALTER COLUMN TYPE if decimal is not change: sql: %s", sql)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	session := DB.Session(&gorm.Session{Logger: tracer})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.Migrator().DropTable(model1)
 | 
					 | 
				
			||||||
	var modifySql []string
 | 
					 | 
				
			||||||
	if err := session.AutoMigrate(model1); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to auto migrate, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if err := session.AutoMigrate(model1); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to auto migrate, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	tracer2 := Tracer{
 | 
					 | 
				
			||||||
		Logger: DB.Config.Logger,
 | 
					 | 
				
			||||||
		Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
 | 
					 | 
				
			||||||
			sql, _ := fc()
 | 
					 | 
				
			||||||
			modifySql = append(modifySql, sql)
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	session2 := DB.Session(&gorm.Session{Logger: tracer2})
 | 
					 | 
				
			||||||
	err := session2.Table("migrate_decimal_columns").Migrator().AutoMigrate(model2)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to get column types, got error: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return modifySql
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func decimalColumnsTest[T, T2 any](t *testing.T, expectedSql []string) {
 | 
					 | 
				
			||||||
	var t1 T
 | 
					 | 
				
			||||||
	var t2 T2
 | 
					 | 
				
			||||||
	modSql := testAutoMigrateDecimal(t, t1, t2)
 | 
					 | 
				
			||||||
	var alterSQL []string
 | 
					 | 
				
			||||||
	for _, sql := range modSql {
 | 
					 | 
				
			||||||
		if strings.HasPrefix(sql, "ALTER TABLE ") {
 | 
					 | 
				
			||||||
			alterSQL = append(alterSQL, sql)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(alterSQL) != 3 {
 | 
					 | 
				
			||||||
		t.Fatalf("decimal changed error,expected: %+v,got: %+v.", expectedSql, alterSQL)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	for i := range alterSQL {
 | 
					 | 
				
			||||||
		if alterSQL[i] != expectedSql[i] {
 | 
					 | 
				
			||||||
			t.Fatalf("decimal changed error,expected: %+v,got: %+v.", expectedSql, alterSQL)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestAutoMigrateDecimal(t *testing.T) {
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() == "sqlserver" { // database/sql will replace numeric to decimal. so only support decimal.
 | 
					 | 
				
			||||||
		type MigrateDecimalColumn struct {
 | 
					 | 
				
			||||||
			RecID1 int64 `gorm:"column:recid1;type:decimal(9,0);not null" json:"recid1"`
 | 
					 | 
				
			||||||
			RecID2 int64 `gorm:"column:recid2;type:decimal(8);not null" json:"recid2"`
 | 
					 | 
				
			||||||
			RecID3 int64 `gorm:"column:recid3;type:decimal(8,1);not null" json:"recid3"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		type MigrateDecimalColumn2 struct {
 | 
					 | 
				
			||||||
			RecID1 int64 `gorm:"column:recid1;type:decimal(8);not null" json:"recid1"`
 | 
					 | 
				
			||||||
			RecID2 int64 `gorm:"column:recid2;type:decimal(9,1);not null" json:"recid2"`
 | 
					 | 
				
			||||||
			RecID3 int64 `gorm:"column:recid3;type:decimal(9,2);not null" json:"recid3"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		expectedSql := []string{
 | 
					 | 
				
			||||||
			`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid1" decimal(8) NOT NULL`,
 | 
					 | 
				
			||||||
			`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid2" decimal(9,1) NOT NULL`,
 | 
					 | 
				
			||||||
			`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid3" decimal(9,2) NOT NULL`,
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql)
 | 
					 | 
				
			||||||
	} else if DB.Dialector.Name() == "postgres" || DB.Dialector.Name() == "gaussdb" {
 | 
					 | 
				
			||||||
		type MigrateDecimalColumn struct {
 | 
					 | 
				
			||||||
			RecID1 int64 `gorm:"column:recid1;type:numeric(9,0);not null" json:"recid1"`
 | 
					 | 
				
			||||||
			RecID2 int64 `gorm:"column:recid2;type:numeric(8);not null" json:"recid2"`
 | 
					 | 
				
			||||||
			RecID3 int64 `gorm:"column:recid3;type:numeric(8,1);not null" json:"recid3"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		type MigrateDecimalColumn2 struct {
 | 
					 | 
				
			||||||
			RecID1 int64 `gorm:"column:recid1;type:numeric(8);not null" json:"recid1"`
 | 
					 | 
				
			||||||
			RecID2 int64 `gorm:"column:recid2;type:numeric(9,1);not null" json:"recid2"`
 | 
					 | 
				
			||||||
			RecID3 int64 `gorm:"column:recid3;type:numeric(9,2);not null" json:"recid3"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		expectedSql := []string{
 | 
					 | 
				
			||||||
			`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid1" TYPE numeric(8) USING "recid1"::numeric(8)`,
 | 
					 | 
				
			||||||
			`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid2" TYPE numeric(9,1) USING "recid2"::numeric(9,1)`,
 | 
					 | 
				
			||||||
			`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid3" TYPE numeric(9,2) USING "recid3"::numeric(9,2)`,
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql)
 | 
					 | 
				
			||||||
	} else if DB.Dialector.Name() == "mysql" {
 | 
					 | 
				
			||||||
		type MigrateDecimalColumn struct {
 | 
					 | 
				
			||||||
			RecID1 int64 `gorm:"column:recid1;type:decimal(9,0);not null" json:"recid1"`
 | 
					 | 
				
			||||||
			RecID2 int64 `gorm:"column:recid2;type:decimal(8);not null" json:"recid2"`
 | 
					 | 
				
			||||||
			RecID3 int64 `gorm:"column:recid3;type:decimal(8,1);not null" json:"recid3"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		type MigrateDecimalColumn2 struct {
 | 
					 | 
				
			||||||
			RecID1 int64 `gorm:"column:recid1;type:decimal(8);not null" json:"recid1"`
 | 
					 | 
				
			||||||
			RecID2 int64 `gorm:"column:recid2;type:decimal(9,1);not null" json:"recid2"`
 | 
					 | 
				
			||||||
			RecID3 int64 `gorm:"column:recid3;type:decimal(9,2);not null" json:"recid3"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		expectedSql := []string{
 | 
					 | 
				
			||||||
			"ALTER TABLE `migrate_decimal_columns` MODIFY COLUMN `recid1` decimal(8) NOT NULL",
 | 
					 | 
				
			||||||
			"ALTER TABLE `migrate_decimal_columns` MODIFY COLUMN `recid2` decimal(9,1) NOT NULL",
 | 
					 | 
				
			||||||
			"ALTER TABLE `migrate_decimal_columns` MODIFY COLUMN `recid3` decimal(9,2) NOT NULL",
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -41,7 +41,7 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
 | 
				
			|||||||
		t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
 | 
							t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if name := DB.Dialector.Name(); name == "postgres" || name == "mysql" || name == "gaussdb" {
 | 
						if name := DB.Dialector.Name(); name == "postgres" {
 | 
				
			||||||
		stmt := gorm.Statement{DB: DB}
 | 
							stmt := gorm.Statement{DB: DB}
 | 
				
			||||||
		stmt.Parse(&Blog{})
 | 
							stmt.Parse(&Blog{})
 | 
				
			||||||
		stmt.Schema.LookUpField("ID").Unique = true
 | 
							stmt.Schema.LookUpField("ID").Unique = true
 | 
				
			||||||
@ -142,9 +142,6 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) {
 | 
				
			|||||||
	if name := DB.Dialector.Name(); name == "postgres" {
 | 
						if name := DB.Dialector.Name(); name == "postgres" {
 | 
				
			||||||
		t.Skip("skip postgres due to it only allow unique constraint matching given keys")
 | 
							t.Skip("skip postgres due to it only allow unique constraint matching given keys")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if name := DB.Dialector.Name(); name == "gaussdb" {
 | 
					 | 
				
			||||||
		t.Skip("skip gaussdb due to it only allow unique constraint matching given keys")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags")
 | 
						DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags")
 | 
				
			||||||
	if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil {
 | 
						if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil {
 | 
				
			||||||
@ -267,14 +264,10 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
 | 
				
			|||||||
		t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
 | 
							t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if name := DB.Dialector.Name(); name == "postgres" || name == "mysql" {
 | 
						if name := DB.Dialector.Name(); name == "postgres" {
 | 
				
			||||||
		t.Skip("skip postgres due to it only allow unique constraint matching given keys")
 | 
							t.Skip("skip postgres due to it only allow unique constraint matching given keys")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if name := DB.Dialector.Name(); name == "gaussdb" {
 | 
					 | 
				
			||||||
		t.Skip("skip gaussdb due to it only allow unique constraint matching given keys")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags")
 | 
						DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags")
 | 
				
			||||||
	if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil {
 | 
						if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil {
 | 
				
			||||||
		t.Fatalf("Failed to auto migrate, got error: %v", err)
 | 
							t.Fatalf("Failed to auto migrate, got error: %v", err)
 | 
				
			||||||
@ -339,7 +332,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	DB.Model(&blog2).Association("LocaleTags").Find(&tags)
 | 
						DB.Model(&blog2).Association("LocaleTags").Find(&tags)
 | 
				
			||||||
	if !compareTags(tags, []string{"tag4"}) {
 | 
						if !compareTags(tags, []string{"tag4"}) {
 | 
				
			||||||
		t.Fatalf("Should find 1 tags for EN Blog, but got %v", tags)
 | 
							t.Fatalf("Should find 1 tags  for EN Blog")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Replace
 | 
						// Replace
 | 
				
			||||||
 | 
				
			|||||||
@ -37,7 +37,7 @@ func TestNonStdPrimaryKeyAndDefaultValues(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	animal = Animal{From: "somewhere"}              // No name fields, should be filled with the default value (galeone)
 | 
						animal = Animal{From: "somewhere"}              // No name fields, should be filled with the default value (galeone)
 | 
				
			||||||
	DB.Save(&animal).Update("From", "a nice place") // The name field should be untouched
 | 
						DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched
 | 
				
			||||||
	DB.First(&animal, animal.Counter)
 | 
						DB.First(&animal, animal.Counter)
 | 
				
			||||||
	if animal.Name != "galeone" {
 | 
						if animal.Name != "galeone" {
 | 
				
			||||||
		t.Errorf("Name fields shouldn't be changed if untouched, but got %v", animal.Name)
 | 
							t.Errorf("Name fields shouldn't be changed if untouched, but got %v", animal.Name)
 | 
				
			||||||
 | 
				
			|||||||
@ -696,10 +696,6 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
 | 
				
			|||||||
		t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
 | 
							t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if name := DB.Dialector.Name(); name == "mysql" {
 | 
					 | 
				
			||||||
		t.Skip("skip mysql due to it only allow unique constraint matching given keys")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type (
 | 
						type (
 | 
				
			||||||
		Level1 struct {
 | 
							Level1 struct {
 | 
				
			||||||
			ID           uint   `gorm:"primary_key;"`
 | 
								ID           uint   `gorm:"primary_key;"`
 | 
				
			||||||
 | 
				
			|||||||
@ -1,14 +1,12 @@
 | 
				
			|||||||
package tests_test
 | 
					package tests_test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
					 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
	"sort"
 | 
						"sort"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	"gorm.io/gorm/clause"
 | 
						"gorm.io/gorm/clause"
 | 
				
			||||||
@ -309,189 +307,6 @@ func TestNestedPreloadWithUnscoped(t *testing.T) {
 | 
				
			|||||||
	CheckUserUnscoped(t, *user6, user)
 | 
						CheckUserUnscoped(t, *user6, user)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestNestedPreloadWithNestedJoin(t *testing.T) {
 | 
					 | 
				
			||||||
	type (
 | 
					 | 
				
			||||||
		Preload struct {
 | 
					 | 
				
			||||||
			ID       uint
 | 
					 | 
				
			||||||
			Value    string
 | 
					 | 
				
			||||||
			NestedID uint
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		Join struct {
 | 
					 | 
				
			||||||
			ID       uint
 | 
					 | 
				
			||||||
			Value    string
 | 
					 | 
				
			||||||
			NestedID uint
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		Nested struct {
 | 
					 | 
				
			||||||
			ID       uint
 | 
					 | 
				
			||||||
			Preloads []*Preload
 | 
					 | 
				
			||||||
			Join     Join
 | 
					 | 
				
			||||||
			ValueID  uint
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		Value struct {
 | 
					 | 
				
			||||||
			ID     uint
 | 
					 | 
				
			||||||
			Name   string
 | 
					 | 
				
			||||||
			Nested Nested
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{})
 | 
					 | 
				
			||||||
	DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	value1 := Value{
 | 
					 | 
				
			||||||
		Name: "value",
 | 
					 | 
				
			||||||
		Nested: Nested{
 | 
					 | 
				
			||||||
			Preloads: []*Preload{
 | 
					 | 
				
			||||||
				{Value: "p1"}, {Value: "p2"},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
			Join: Join{Value: "j1"},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	value2 := Value{
 | 
					 | 
				
			||||||
		Name: "value2",
 | 
					 | 
				
			||||||
		Nested: Nested{
 | 
					 | 
				
			||||||
			Preloads: []*Preload{
 | 
					 | 
				
			||||||
				{Value: "p3"}, {Value: "p4"}, {Value: "p5"},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
			Join: Join{Value: "j2"},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	values := []*Value{&value1, &value2}
 | 
					 | 
				
			||||||
	if err := DB.Create(&values).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("failed to create value, got err: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var find1 Value
 | 
					 | 
				
			||||||
	err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1, value1.ID).Error
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("failed to find value, got err: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	AssertEqual(t, find1, value1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var find2 Value
 | 
					 | 
				
			||||||
	// Joins will automatically add Nested queries.
 | 
					 | 
				
			||||||
	err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2, value2.ID).Error
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("failed to find value, got err: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	AssertEqual(t, find2, value2)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var finds []Value
 | 
					 | 
				
			||||||
	err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("failed to find value, got err: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	AssertEqual(t, len(finds), 2)
 | 
					 | 
				
			||||||
	AssertEqual(t, finds[0], value1)
 | 
					 | 
				
			||||||
	AssertEqual(t, finds[1], value2)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestMergeNestedPreloadWithNestedJoin(t *testing.T) {
 | 
					 | 
				
			||||||
	users := []User{
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			Name: "TestMergeNestedPreloadWithNestedJoin-1",
 | 
					 | 
				
			||||||
			Manager: &User{
 | 
					 | 
				
			||||||
				Name: "Alexis Manager",
 | 
					 | 
				
			||||||
				Tools: []Tools{
 | 
					 | 
				
			||||||
					{Name: "Alexis Tool 1"},
 | 
					 | 
				
			||||||
					{Name: "Alexis Tool 2"},
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			Name: "TestMergeNestedPreloadWithNestedJoin-2",
 | 
					 | 
				
			||||||
			Manager: &User{
 | 
					 | 
				
			||||||
				Name: "Jinzhu Manager",
 | 
					 | 
				
			||||||
				Tools: []Tools{
 | 
					 | 
				
			||||||
					{Name: "Jinzhu Tool 1"},
 | 
					 | 
				
			||||||
					{Name: "Jinzhu Tool 2"},
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.Create(&users)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	query := make([]string, 0)
 | 
					 | 
				
			||||||
	sess := DB.Session(&gorm.Session{Logger: Tracer{
 | 
					 | 
				
			||||||
		Logger: DB.Config.Logger,
 | 
					 | 
				
			||||||
		Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
 | 
					 | 
				
			||||||
			sql, _ := fc()
 | 
					 | 
				
			||||||
			query = append(query, sql)
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var result []User
 | 
					 | 
				
			||||||
	err := sess.
 | 
					 | 
				
			||||||
		Joins("Manager").
 | 
					 | 
				
			||||||
		Preload("Manager.Tools").
 | 
					 | 
				
			||||||
		Where("users.name Like ?", "TestMergeNestedPreloadWithNestedJoin%").
 | 
					 | 
				
			||||||
		Find(&result).Error
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to preload and find users: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	AssertEqual(t, result, users)
 | 
					 | 
				
			||||||
	AssertEqual(t, len(query), 2) // Check preload queries are merged
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !regexp.MustCompile(`SELECT \* FROM .*tools.* WHERE .*IN.*`).MatchString(query[0]) {
 | 
					 | 
				
			||||||
		t.Fatalf("Expected first query to preload manager tools, got: %s", query[0])
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestNestedPreloadWithPointerJoin(t *testing.T) {
 | 
					 | 
				
			||||||
	type (
 | 
					 | 
				
			||||||
		Preload struct {
 | 
					 | 
				
			||||||
			ID     uint
 | 
					 | 
				
			||||||
			Value  string
 | 
					 | 
				
			||||||
			JoinID uint
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		Join struct {
 | 
					 | 
				
			||||||
			ID       uint
 | 
					 | 
				
			||||||
			Value    string
 | 
					 | 
				
			||||||
			Preload  Preload
 | 
					 | 
				
			||||||
			NestedID uint
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		Nested struct {
 | 
					 | 
				
			||||||
			ID      uint
 | 
					 | 
				
			||||||
			Join    Join
 | 
					 | 
				
			||||||
			ValueID uint
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		Value struct {
 | 
					 | 
				
			||||||
			ID     uint
 | 
					 | 
				
			||||||
			Name   string
 | 
					 | 
				
			||||||
			Nested *Nested
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{})
 | 
					 | 
				
			||||||
	DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	value := Value{
 | 
					 | 
				
			||||||
		Name: "value",
 | 
					 | 
				
			||||||
		Nested: &Nested{
 | 
					 | 
				
			||||||
			Join: Join{
 | 
					 | 
				
			||||||
				Value: "j1",
 | 
					 | 
				
			||||||
				Preload: Preload{
 | 
					 | 
				
			||||||
					Value: "p1",
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.Create(&value).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("failed to create value, got err: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var find1 Value
 | 
					 | 
				
			||||||
	err := DB.Table("values").Joins("Nested").Joins("Nested.Join").Preload("Nested.Join.Preload").First(&find1).Error
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("failed to find value, got err: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	AssertEqual(t, find1, value)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestEmbedPreload(t *testing.T) {
 | 
					func TestEmbedPreload(t *testing.T) {
 | 
				
			||||||
	type Country struct {
 | 
						type Country struct {
 | 
				
			||||||
		ID   int `gorm:"primaryKey"`
 | 
							ID   int `gorm:"primaryKey"`
 | 
				
			||||||
@ -584,7 +399,7 @@ func TestEmbedPreload(t *testing.T) {
 | 
				
			|||||||
			},
 | 
								},
 | 
				
			||||||
		}, {
 | 
							}, {
 | 
				
			||||||
			name:     "nested address country",
 | 
								name:     "nested address country",
 | 
				
			||||||
			preloads: map[string][]interface{}{"NestedAddress.Country": {}},
 | 
								preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}},
 | 
				
			||||||
			expect: Org{
 | 
								expect: Org{
 | 
				
			||||||
				ID: org.ID,
 | 
									ID: org.ID,
 | 
				
			||||||
				PostalAddress: EmbeddedAddress{
 | 
									PostalAddress: EmbeddedAddress{
 | 
				
			||||||
@ -614,6 +429,7 @@ func TestEmbedPreload(t *testing.T) {
 | 
				
			|||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						DB = DB.Debug()
 | 
				
			||||||
	for _, test := range tests {
 | 
						for _, test := range tests {
 | 
				
			||||||
		t.Run(test.name, func(t *testing.T) {
 | 
							t.Run(test.name, func(t *testing.T) {
 | 
				
			||||||
			actual := Org{}
 | 
								actual := Org{}
 | 
				
			||||||
 | 
				
			|||||||
@ -91,65 +91,6 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
 | 
				
			|||||||
	tx2.Commit()
 | 
						tx2.Commit()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestPreparedStmtLruFromTransaction(t *testing.T) {
 | 
					 | 
				
			||||||
	db, _ := OpenTestConnection(&gorm.Config{PrepareStmt: true, PrepareStmtMaxSize: 10, PrepareStmtTTL: 20 * time.Second})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	tx := db.Begin()
 | 
					 | 
				
			||||||
	defer func() {
 | 
					 | 
				
			||||||
		if r := recover(); r != nil {
 | 
					 | 
				
			||||||
			tx.Rollback()
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}()
 | 
					 | 
				
			||||||
	if err := tx.Error; err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("Failed to start transaction, got error %v\n", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil {
 | 
					 | 
				
			||||||
		tx.Rollback()
 | 
					 | 
				
			||||||
		t.Errorf("Failed to run one transaction, got error %v\n", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil {
 | 
					 | 
				
			||||||
		tx.Rollback()
 | 
					 | 
				
			||||||
		t.Errorf("Failed to run one transaction, got error %v\n", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := tx.Commit().Error; err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("Failed to commit transaction, got error %v\n", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 {
 | 
					 | 
				
			||||||
		t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	tx2 := db.Begin()
 | 
					 | 
				
			||||||
	if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 {
 | 
					 | 
				
			||||||
		t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	tx2.Commit()
 | 
					 | 
				
			||||||
	// Attempt to convert the connection pool of tx to the *gorm.PreparedStmtDB type.
 | 
					 | 
				
			||||||
	// If the conversion is successful, ok will be true and conn will be the converted object;
 | 
					 | 
				
			||||||
	// otherwise, ok will be false and conn will be nil.
 | 
					 | 
				
			||||||
	conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
 | 
					 | 
				
			||||||
	// Get the number of statement keys stored in the PreparedStmtDB.
 | 
					 | 
				
			||||||
	lens := len(conn.Stmts.Keys())
 | 
					 | 
				
			||||||
	// Check if the number of stored statement keys is 0.
 | 
					 | 
				
			||||||
	if lens == 0 {
 | 
					 | 
				
			||||||
		// If the number is 0, it means there are no statements stored in the LRU cache.
 | 
					 | 
				
			||||||
		// The test fails and an error message is output.
 | 
					 | 
				
			||||||
		t.Fatalf("lru should not be empty")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	// Wait for 40 seconds to give the statements in the cache enough time to expire.
 | 
					 | 
				
			||||||
	time.Sleep(time.Second * 40)
 | 
					 | 
				
			||||||
	// Assert whether the connection pool of tx is successfully converted to the *gorm.PreparedStmtDB type.
 | 
					 | 
				
			||||||
	AssertEqual(t, ok, true)
 | 
					 | 
				
			||||||
	// Assert whether the number of statement keys stored in the PreparedStmtDB is 0 after 40 seconds.
 | 
					 | 
				
			||||||
	// If it is not 0, it means the statements in the cache have not expired as expected.
 | 
					 | 
				
			||||||
	AssertEqual(t, len(conn.Stmts.Keys()), 0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestPreparedStmtDeadlock(t *testing.T) {
 | 
					func TestPreparedStmtDeadlock(t *testing.T) {
 | 
				
			||||||
	tx, err := OpenTestConnection(&gorm.Config{})
 | 
						tx, err := OpenTestConnection(&gorm.Config{})
 | 
				
			||||||
	AssertEqual(t, err, nil)
 | 
						AssertEqual(t, err, nil)
 | 
				
			||||||
@ -175,9 +116,9 @@ func TestPreparedStmtDeadlock(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
 | 
						conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
 | 
				
			||||||
	AssertEqual(t, ok, true)
 | 
						AssertEqual(t, ok, true)
 | 
				
			||||||
	AssertEqual(t, len(conn.Stmts.Keys()), 2)
 | 
						AssertEqual(t, len(conn.Stmts), 2)
 | 
				
			||||||
	for _, stmt := range conn.Stmts.Keys() {
 | 
						for _, stmt := range conn.Stmts {
 | 
				
			||||||
		if stmt == "" {
 | 
							if stmt == nil {
 | 
				
			||||||
			t.Fatalf("stmt cannot bee nil")
 | 
								t.Fatalf("stmt cannot bee nil")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -185,6 +126,33 @@ func TestPreparedStmtDeadlock(t *testing.T) {
 | 
				
			|||||||
	AssertEqual(t, sqlDB.Stats().InUse, 0)
 | 
						AssertEqual(t, sqlDB.Stats().InUse, 0)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestPreparedStmtError(t *testing.T) {
 | 
				
			||||||
 | 
						tx, err := OpenTestConnection(&gorm.Config{})
 | 
				
			||||||
 | 
						AssertEqual(t, err, nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sqlDB, _ := tx.DB()
 | 
				
			||||||
 | 
						sqlDB.SetMaxOpenConns(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tx = tx.Session(&gorm.Session{PrepareStmt: true})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						wg := sync.WaitGroup{}
 | 
				
			||||||
 | 
						for i := 0; i < 10; i++ {
 | 
				
			||||||
 | 
							wg.Add(1)
 | 
				
			||||||
 | 
							go func() {
 | 
				
			||||||
 | 
								// err prepare
 | 
				
			||||||
 | 
								tag := Tag{Locale: "zh"}
 | 
				
			||||||
 | 
								tx.Table("users").Find(&tag)
 | 
				
			||||||
 | 
								wg.Done()
 | 
				
			||||||
 | 
							}()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						wg.Wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
 | 
				
			||||||
 | 
						AssertEqual(t, ok, true)
 | 
				
			||||||
 | 
						AssertEqual(t, len(conn.Stmts), 0)
 | 
				
			||||||
 | 
						AssertEqual(t, sqlDB.Stats().InUse, 0)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestPreparedStmtInTransaction(t *testing.T) {
 | 
					func TestPreparedStmtInTransaction(t *testing.T) {
 | 
				
			||||||
	user := User{Name: "jinzhu"}
 | 
						user := User{Name: "jinzhu"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -201,10 +169,10 @@ func TestPreparedStmtInTransaction(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestPreparedStmtClose(t *testing.T) {
 | 
					func TestPreparedStmtReset(t *testing.T) {
 | 
				
			||||||
	tx := DB.Session(&gorm.Session{PrepareStmt: true})
 | 
						tx := DB.Session(&gorm.Session{PrepareStmt: true})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	user := *GetUser("prepared_stmt_close", Config{})
 | 
						user := *GetUser("prepared_stmt_reset", Config{})
 | 
				
			||||||
	tx = tx.Create(&user)
 | 
						tx = tx.Create(&user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
 | 
						pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
 | 
				
			||||||
@ -213,77 +181,16 @@ func TestPreparedStmtClose(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	pdb.Mux.Lock()
 | 
						pdb.Mux.Lock()
 | 
				
			||||||
	if len(pdb.Stmts.Keys()) == 0 {
 | 
						if len(pdb.Stmts) == 0 {
 | 
				
			||||||
		pdb.Mux.Unlock()
 | 
							pdb.Mux.Unlock()
 | 
				
			||||||
		t.Fatalf("prepared stmt can not be empty")
 | 
							t.Fatalf("prepared stmt can not be empty")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	pdb.Mux.Unlock()
 | 
						pdb.Mux.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	pdb.Close()
 | 
						pdb.Reset()
 | 
				
			||||||
	pdb.Mux.Lock()
 | 
						pdb.Mux.Lock()
 | 
				
			||||||
	defer pdb.Mux.Unlock()
 | 
						defer pdb.Mux.Unlock()
 | 
				
			||||||
	if len(pdb.Stmts.Keys()) != 0 {
 | 
						if len(pdb.Stmts) != 0 {
 | 
				
			||||||
		t.Fatalf("prepared stmt should be empty")
 | 
							t.Fatalf("prepared stmt should be empty")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func isUsingClosedConnError(err error) bool {
 | 
					 | 
				
			||||||
	// https://github.com/golang/go/blob/e705a2d16e4ece77e08e80c168382cdb02890f5b/src/database/sql/sql.go#L2717
 | 
					 | 
				
			||||||
	return err.Error() == "sql: statement is closed"
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
 | 
					 | 
				
			||||||
// this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt
 | 
					 | 
				
			||||||
func TestPreparedStmtConcurrentClose(t *testing.T) {
 | 
					 | 
				
			||||||
	name := "prepared_stmt_concurrent_close"
 | 
					 | 
				
			||||||
	user := *GetUser(name, Config{})
 | 
					 | 
				
			||||||
	createTx := DB.Session(&gorm.Session{}).Create(&user)
 | 
					 | 
				
			||||||
	if createTx.Error != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// create a new connection to keep away from other tests
 | 
					 | 
				
			||||||
	tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to open test connection due to %s", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
 | 
					 | 
				
			||||||
	if !ok {
 | 
					 | 
				
			||||||
		t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	loopCount := 100
 | 
					 | 
				
			||||||
	var wg sync.WaitGroup
 | 
					 | 
				
			||||||
	var unexpectedError bool
 | 
					 | 
				
			||||||
	writerFinish := make(chan struct{})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	wg.Add(1)
 | 
					 | 
				
			||||||
	go func(id uint) {
 | 
					 | 
				
			||||||
		defer wg.Done()
 | 
					 | 
				
			||||||
		defer close(writerFinish)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		for j := 0; j < loopCount; j++ {
 | 
					 | 
				
			||||||
			var tmp User
 | 
					 | 
				
			||||||
			err := tx.Session(&gorm.Session{}).First(&tmp, id).Error
 | 
					 | 
				
			||||||
			if err == nil || isUsingClosedConnError(err) {
 | 
					 | 
				
			||||||
				continue
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			t.Errorf("failed to read user of id %d due to %s, there should not be error", id, err)
 | 
					 | 
				
			||||||
			unexpectedError = true
 | 
					 | 
				
			||||||
			break
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}(user.ID)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	wg.Add(1)
 | 
					 | 
				
			||||||
	go func() {
 | 
					 | 
				
			||||||
		defer wg.Done()
 | 
					 | 
				
			||||||
		<-writerFinish
 | 
					 | 
				
			||||||
		pdb.Close()
 | 
					 | 
				
			||||||
	}()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	wg.Wait()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if unexpectedError {
 | 
					 | 
				
			||||||
		t.Fatalf("should is a unexpected error")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -554,16 +554,6 @@ func TestNot(t *testing.T) {
 | 
				
			|||||||
	if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) {
 | 
						if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) {
 | 
				
			||||||
		t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
 | 
							t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					 | 
				
			||||||
	result = dryDB.Not(DB.Where("manager IS NULL").Where("age >= ?", 20)).Find(&User{})
 | 
					 | 
				
			||||||
	if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT \\(manager IS NULL AND age >= .+\\) AND .users.\\..deleted_at. IS NULL").MatchString(result.Statement.SQL.String()) {
 | 
					 | 
				
			||||||
		t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	result = dryDB.Not(DB.Where("manager IS NULL").Or("age >= ?", 20)).Find(&User{})
 | 
					 | 
				
			||||||
	if !regexp.MustCompile(`SELECT \* FROM .*users.* WHERE NOT \(manager IS NULL OR age >= .+\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
 | 
					 | 
				
			||||||
		t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestNotWithAllFields(t *testing.T) {
 | 
					func TestNotWithAllFields(t *testing.T) {
 | 
				
			||||||
@ -632,21 +622,6 @@ func TestOr(t *testing.T) {
 | 
				
			|||||||
		t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
 | 
							t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	sub := dryDB.Clauses(clause.Where{
 | 
					 | 
				
			||||||
		Exprs: []clause.Expression{
 | 
					 | 
				
			||||||
			clause.OrConditions{
 | 
					 | 
				
			||||||
				Exprs: []clause.Expression{
 | 
					 | 
				
			||||||
					clause.Expr{SQL: "role = ?", Vars: []interface{}{"super_admin"}},
 | 
					 | 
				
			||||||
					clause.Expr{SQL: "role = ?", Vars: []interface{}{"admin"}},
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
	result = dryDB.Where(sub).Find(&User{})
 | 
					 | 
				
			||||||
	if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) {
 | 
					 | 
				
			||||||
		t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	result = dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{})
 | 
						result = dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{})
 | 
				
			||||||
	if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) {
 | 
						if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) {
 | 
				
			||||||
		t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
 | 
							t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
 | 
				
			||||||
@ -875,28 +850,6 @@ func TestOmitWithAllFields(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestMapColumns(t *testing.T) {
 | 
					 | 
				
			||||||
	user := User{Name: "MapColumnsUser", Age: 12}
 | 
					 | 
				
			||||||
	DB.Save(&user)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type result struct {
 | 
					 | 
				
			||||||
		Name     string
 | 
					 | 
				
			||||||
		Nickname string
 | 
					 | 
				
			||||||
		Age      uint
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	var res result
 | 
					 | 
				
			||||||
	DB.Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "nickname"}).Scan(&res)
 | 
					 | 
				
			||||||
	if res.Nickname != user.Name {
 | 
					 | 
				
			||||||
		t.Errorf("Expected res.Nickname to be %s, but got %s", user.Name, res.Nickname)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if res.Name != "" {
 | 
					 | 
				
			||||||
		t.Errorf("Expected res.Name to be empty, but got %s", res.Name)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if res.Age != user.Age {
 | 
					 | 
				
			||||||
		t.Errorf("Expected res.Age to be %d, but got %d", user.Age, res.Age)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestPluckWithSelect(t *testing.T) {
 | 
					func TestPluckWithSelect(t *testing.T) {
 | 
				
			||||||
	users := []User{
 | 
						users := []User{
 | 
				
			||||||
		{Name: "pluck_with_select_1", Age: 25},
 | 
							{Name: "pluck_with_select_1", Age: 25},
 | 
				
			||||||
@ -1127,10 +1080,6 @@ func TestSearchWithMap(t *testing.T) {
 | 
				
			|||||||
	DB.First(&user, map[string]interface{}{"name": users[0].Name})
 | 
						DB.First(&user, map[string]interface{}{"name": users[0].Name})
 | 
				
			||||||
	CheckUser(t, user, users[0])
 | 
						CheckUser(t, user, users[0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	user = User{}
 | 
					 | 
				
			||||||
	DB.First(&user, map[string]interface{}{"users.name": users[0].Name})
 | 
					 | 
				
			||||||
	CheckUser(t, user, users[0])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	user = User{}
 | 
						user = User{}
 | 
				
			||||||
	DB.Where(map[string]interface{}{"name": users[1].Name}).First(&user)
 | 
						DB.Where(map[string]interface{}{"name": users[1].Name}).First(&user)
 | 
				
			||||||
	CheckUser(t, user, users[1])
 | 
						CheckUser(t, user, users[1])
 | 
				
			||||||
@ -1169,12 +1118,12 @@ func TestSearchWithStruct(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{})
 | 
						result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{})
 | 
				
			||||||
	if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
 | 
						if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
 | 
				
			||||||
		t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
 | 
							t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{})
 | 
						result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{})
 | 
				
			||||||
	if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
 | 
						if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
 | 
				
			||||||
		t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
 | 
							t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1235,6 +1184,7 @@ func TestSubQueryWithRaw(t *testing.T) {
 | 
				
			|||||||
			Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}).
 | 
								Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}).
 | 
				
			||||||
			Group("name"),
 | 
								Group("name"),
 | 
				
			||||||
	).Count(&count).Error
 | 
						).Count(&count).Error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Errorf("Expected to get no errors, but got %v", err)
 | 
							t.Errorf("Expected to get no errors, but got %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -1250,6 +1200,7 @@ func TestSubQueryWithRaw(t *testing.T) {
 | 
				
			|||||||
			Not("age <= ?", 10).Not("name IN (?)", []string{"subquery_raw_1", "subquery_raw_3"}).
 | 
								Not("age <= ?", 10).Not("name IN (?)", []string{"subquery_raw_1", "subquery_raw_3"}).
 | 
				
			||||||
			Group("name"),
 | 
								Group("name"),
 | 
				
			||||||
	).Count(&count).Error
 | 
						).Count(&count).Error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Errorf("Expected to get no errors, but got %v", err)
 | 
							t.Errorf("Expected to get no errors, but got %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -1376,7 +1327,7 @@ func TestQueryResetNullValue(t *testing.T) {
 | 
				
			|||||||
		Number1 int64      `gorm:"default:NULL"`
 | 
							Number1 int64      `gorm:"default:NULL"`
 | 
				
			||||||
		Number2 uint64     `gorm:"default:NULL"`
 | 
							Number2 uint64     `gorm:"default:NULL"`
 | 
				
			||||||
		Number3 float64    `gorm:"default:NULL"`
 | 
							Number3 float64    `gorm:"default:NULL"`
 | 
				
			||||||
		Now     *time.Time `gorm:"default:NULL"`
 | 
							Now     *time.Time `gorm:"defalut:NULL"`
 | 
				
			||||||
		Item1Id string
 | 
							Item1Id string
 | 
				
			||||||
		Item1   *QueryResetItem `gorm:"references:ID"`
 | 
							Item1   *QueryResetItem `gorm:"references:ID"`
 | 
				
			||||||
		Item2Id string
 | 
							Item2Id string
 | 
				
			||||||
@ -1453,22 +1404,3 @@ func TestQueryError(t *testing.T) {
 | 
				
			|||||||
	}, Value: 1}).Scan(&p2).Error
 | 
						}, Value: 1}).Scan(&p2).Error
 | 
				
			||||||
	AssertEqual(t, err, gorm.ErrModelValueRequired)
 | 
						AssertEqual(t, err, gorm.ErrModelValueRequired)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestQueryScanToArray(t *testing.T) {
 | 
					 | 
				
			||||||
	err := DB.Create(&User{Name: "testname1", Age: 10}).Error
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatal(err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	users := [2]*User{{Name: "1"}, {Name: "2"}}
 | 
					 | 
				
			||||||
	err = DB.Model(&User{}).Where("name = ?", "testname1").Find(&users).Error
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatal(err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if users[0] == nil || users[0].Name != "testname1" {
 | 
					 | 
				
			||||||
		t.Error("users[0] not covered")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if users[1] != nil {
 | 
					 | 
				
			||||||
		t.Error("users[1] should be empty")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -5,7 +5,6 @@ import (
 | 
				
			|||||||
	"sort"
 | 
						"sort"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	. "gorm.io/gorm/utils/tests"
 | 
						. "gorm.io/gorm/utils/tests"
 | 
				
			||||||
@ -127,7 +126,7 @@ func TestScanRows(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
 | 
						rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Errorf("No error should happen, got %v", err)
 | 
							t.Errorf("Not error should happen, got %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	type Result struct {
 | 
						type Result struct {
 | 
				
			||||||
@ -149,7 +148,7 @@ func TestScanRows(t *testing.T) {
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
 | 
						if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
 | 
				
			||||||
		t.Errorf("Should find expected results, got %+v", results)
 | 
							t.Errorf("Should find expected results")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var ages int
 | 
						var ages int
 | 
				
			||||||
@ -159,105 +158,7 @@ func TestScanRows(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	var name string
 | 
						var name string
 | 
				
			||||||
	if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name {
 | 
						if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name {
 | 
				
			||||||
		t.Fatalf("failed to scan name, got error %v, name: %v", err, name)
 | 
							t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name)
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestScanRowsNullValuesScanToFieldDefault(t *testing.T) {
 | 
					 | 
				
			||||||
	DB.Save(&User{})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	rows, err := DB.Table("users").
 | 
					 | 
				
			||||||
		Select(`
 | 
					 | 
				
			||||||
			NULL AS bool_field,
 | 
					 | 
				
			||||||
			NULL AS int_field,
 | 
					 | 
				
			||||||
			NULL AS int8_field,
 | 
					 | 
				
			||||||
			NULL AS int16_field,
 | 
					 | 
				
			||||||
			NULL AS int32_field,
 | 
					 | 
				
			||||||
			NULL AS int64_field,
 | 
					 | 
				
			||||||
			NULL AS uint_field,
 | 
					 | 
				
			||||||
			NULL AS uint8_field,
 | 
					 | 
				
			||||||
			NULL AS uint16_field,
 | 
					 | 
				
			||||||
			NULL AS uint32_field,
 | 
					 | 
				
			||||||
			NULL AS uint64_field,
 | 
					 | 
				
			||||||
			NULL AS float32_field,
 | 
					 | 
				
			||||||
			NULL AS float64_field,
 | 
					 | 
				
			||||||
			NULL AS string_field,
 | 
					 | 
				
			||||||
			NULL AS time_field,
 | 
					 | 
				
			||||||
			NULL AS time_ptr_field,
 | 
					 | 
				
			||||||
			NULL AS embedded_int_field,
 | 
					 | 
				
			||||||
			NULL AS nested_embedded_int_field,
 | 
					 | 
				
			||||||
			NULL AS embedded_ptr_int_field
 | 
					 | 
				
			||||||
        `).Rows()
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("No error should happen, got %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type NestedEmbeddedStruct struct {
 | 
					 | 
				
			||||||
		NestedEmbeddedIntField            int
 | 
					 | 
				
			||||||
		NestedEmbeddedIntFieldWithDefault int `gorm:"default:2"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type EmbeddedStruct struct {
 | 
					 | 
				
			||||||
		EmbeddedIntField     int
 | 
					 | 
				
			||||||
		NestedEmbeddedStruct `gorm:"embedded"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type EmbeddedPtrStruct struct {
 | 
					 | 
				
			||||||
		EmbeddedPtrIntField   int
 | 
					 | 
				
			||||||
		*NestedEmbeddedStruct `gorm:"embedded"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type Result struct {
 | 
					 | 
				
			||||||
		BoolField          bool
 | 
					 | 
				
			||||||
		IntField           int
 | 
					 | 
				
			||||||
		Int8Field          int8
 | 
					 | 
				
			||||||
		Int16Field         int16
 | 
					 | 
				
			||||||
		Int32Field         int32
 | 
					 | 
				
			||||||
		Int64Field         int64
 | 
					 | 
				
			||||||
		UIntField          uint
 | 
					 | 
				
			||||||
		UInt8Field         uint8
 | 
					 | 
				
			||||||
		UInt16Field        uint16
 | 
					 | 
				
			||||||
		UInt32Field        uint32
 | 
					 | 
				
			||||||
		UInt64Field        uint64
 | 
					 | 
				
			||||||
		Float32Field       float32
 | 
					 | 
				
			||||||
		Float64Field       float64
 | 
					 | 
				
			||||||
		StringField        string
 | 
					 | 
				
			||||||
		TimeField          time.Time
 | 
					 | 
				
			||||||
		TimePtrField       *time.Time
 | 
					 | 
				
			||||||
		EmbeddedStruct     `gorm:"embedded"`
 | 
					 | 
				
			||||||
		*EmbeddedPtrStruct `gorm:"embedded"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	currTime := time.Now()
 | 
					 | 
				
			||||||
	reusedVar := Result{
 | 
					 | 
				
			||||||
		BoolField:         true,
 | 
					 | 
				
			||||||
		IntField:          1,
 | 
					 | 
				
			||||||
		Int8Field:         1,
 | 
					 | 
				
			||||||
		Int16Field:        1,
 | 
					 | 
				
			||||||
		Int32Field:        1,
 | 
					 | 
				
			||||||
		Int64Field:        1,
 | 
					 | 
				
			||||||
		UIntField:         1,
 | 
					 | 
				
			||||||
		UInt8Field:        1,
 | 
					 | 
				
			||||||
		UInt16Field:       1,
 | 
					 | 
				
			||||||
		UInt32Field:       1,
 | 
					 | 
				
			||||||
		UInt64Field:       1,
 | 
					 | 
				
			||||||
		Float32Field:      1.1,
 | 
					 | 
				
			||||||
		Float64Field:      1.1,
 | 
					 | 
				
			||||||
		StringField:       "hello",
 | 
					 | 
				
			||||||
		TimeField:         currTime,
 | 
					 | 
				
			||||||
		TimePtrField:      &currTime,
 | 
					 | 
				
			||||||
		EmbeddedStruct:    EmbeddedStruct{EmbeddedIntField: 1, NestedEmbeddedStruct: NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}},
 | 
					 | 
				
			||||||
		EmbeddedPtrStruct: &EmbeddedPtrStruct{EmbeddedPtrIntField: 1, NestedEmbeddedStruct: &NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for rows.Next() {
 | 
					 | 
				
			||||||
		if err := DB.ScanRows(rows, &reusedVar); err != nil {
 | 
					 | 
				
			||||||
			t.Errorf("should get no error, but got %v", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !reflect.DeepEqual(reusedVar, Result{}) {
 | 
					 | 
				
			||||||
		t.Errorf("Should find zero values in struct fields, got %+v\n", reusedVar)
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -84,9 +84,7 @@ func TestComplexScopes(t *testing.T) {
 | 
				
			|||||||
			queryFn: func(tx *gorm.DB) *gorm.DB {
 | 
								queryFn: func(tx *gorm.DB) *gorm.DB {
 | 
				
			||||||
				return tx.Scopes(
 | 
									return tx.Scopes(
 | 
				
			||||||
					func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
 | 
										func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
 | 
				
			||||||
					func(d *gorm.DB) *gorm.DB {
 | 
										func(d *gorm.DB) *gorm.DB { return d.Where(d.Or("b = 2").Or("c = 3")) },
 | 
				
			||||||
						return d.Where(DB.Or("b = 2").Or("c = 3"))
 | 
					 | 
				
			||||||
					},
 | 
					 | 
				
			||||||
				).Find(&Language{})
 | 
									).Find(&Language{})
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`,
 | 
								expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`,
 | 
				
			||||||
@ -95,9 +93,7 @@ func TestComplexScopes(t *testing.T) {
 | 
				
			|||||||
			queryFn: func(tx *gorm.DB) *gorm.DB {
 | 
								queryFn: func(tx *gorm.DB) *gorm.DB {
 | 
				
			||||||
				return tx.Where("z = 0").Scopes(
 | 
									return tx.Where("z = 0").Scopes(
 | 
				
			||||||
					func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
 | 
										func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
 | 
				
			||||||
					func(d *gorm.DB) *gorm.DB {
 | 
										func(d *gorm.DB) *gorm.DB { return d.Or(d.Where("b = 2").Or("c = 3")) },
 | 
				
			||||||
						return d.Or(DB.Where("b = 2").Or("c = 3"))
 | 
					 | 
				
			||||||
					},
 | 
					 | 
				
			||||||
				).Find(&Language{})
 | 
									).Find(&Language{})
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`,
 | 
								expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`,
 | 
				
			||||||
@ -108,7 +104,7 @@ func TestComplexScopes(t *testing.T) {
 | 
				
			|||||||
					func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) },
 | 
										func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) },
 | 
				
			||||||
					func(d *gorm.DB) *gorm.DB {
 | 
										func(d *gorm.DB) *gorm.DB {
 | 
				
			||||||
						return d.
 | 
											return d.
 | 
				
			||||||
							Or(DB.Scopes(
 | 
												Or(d.Scopes(
 | 
				
			||||||
								func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
 | 
													func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
 | 
				
			||||||
								func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") },
 | 
													func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") },
 | 
				
			||||||
							)).
 | 
												)).
 | 
				
			||||||
 | 
				
			|||||||
@ -45,7 +45,7 @@ type SerializerPostgresStruct struct {
 | 
				
			|||||||
func (*SerializerPostgresStruct) TableName() string { return "serializer_structs" }
 | 
					func (*SerializerPostgresStruct) TableName() string { return "serializer_structs" }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func adaptorSerializerModel(s *SerializerStruct) interface{} {
 | 
					func adaptorSerializerModel(s *SerializerStruct) interface{} {
 | 
				
			||||||
	if DB.Dialector.Name() == "postgres" || DB.Dialector.Name() == "gaussdb" {
 | 
						if DB.Dialector.Name() == "postgres" {
 | 
				
			||||||
		sps := SerializerPostgresStruct(*s)
 | 
							sps := SerializerPostgresStruct(*s)
 | 
				
			||||||
		return &sps
 | 
							return &sps
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -388,7 +388,7 @@ func TestToSQL(t *testing.T) {
 | 
				
			|||||||
	sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
 | 
						sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
 | 
				
			||||||
		return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}).Limit(10).Offset(5).Order("name ASC").First(&User{})
 | 
							return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}).Limit(10).Offset(5).Order("name ASC").First(&User{})
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	assertEqualSQL(t, `SELECT * FROM "users" WHERE ("users"."name" = 'foo' AND "users"."age" = 20) AND "users"."deleted_at" IS NULL ORDER BY name ASC,"users"."id" LIMIT 1 OFFSET 5`, sql)
 | 
						assertEqualSQL(t, `SELECT * FROM "users" WHERE "users"."name" = 'foo' AND "users"."age" = 20 AND "users"."deleted_at" IS NULL ORDER BY name ASC,"users"."id" LIMIT 1 OFFSET 5`, sql)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// last and unscoped
 | 
						// last and unscoped
 | 
				
			||||||
	sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
 | 
						sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
 | 
				
			||||||
@ -487,7 +487,7 @@ func replaceQuoteInSQL(sql string) string {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// convert dialect special quote into double quote
 | 
						// convert dialect special quote into double quote
 | 
				
			||||||
	switch DB.Dialector.Name() {
 | 
						switch DB.Dialector.Name() {
 | 
				
			||||||
	case "postgres", "gaussdb":
 | 
						case "postgres":
 | 
				
			||||||
		sql = strings.ReplaceAll(sql, `"`, `"`)
 | 
							sql = strings.ReplaceAll(sql, `"`, `"`)
 | 
				
			||||||
	case "mysql", "sqlite":
 | 
						case "mysql", "sqlite":
 | 
				
			||||||
		sql = strings.ReplaceAll(sql, "`", `"`)
 | 
							sql = strings.ReplaceAll(sql, "`", `"`)
 | 
				
			||||||
 | 
				
			|||||||
@ -2,11 +2,8 @@ package tests_test
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
	"sync"
 | 
					 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gorm.io/driver/gaussdb"
 | 
					 | 
				
			||||||
	"gorm.io/driver/postgres"
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	"gorm.io/gorm/schema"
 | 
						"gorm.io/gorm/schema"
 | 
				
			||||||
	"gorm.io/gorm/utils/tests"
 | 
						"gorm.io/gorm/utils/tests"
 | 
				
			||||||
@ -175,164 +172,3 @@ func TestTableWithNamer(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("Table with namer, got %v", sql)
 | 
							t.Errorf("Table with namer, got %v", sql)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestPostgresTableWithIdentifierLength(t *testing.T) {
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() != "postgres" {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type LongString struct {
 | 
					 | 
				
			||||||
		ThisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString string `gorm:"unique"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	t.Run("default", func(t *testing.T) {
 | 
					 | 
				
			||||||
		db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{})
 | 
					 | 
				
			||||||
		user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("failed to parse user unique, got error %v", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		constraints := user.ParseUniqueConstraints()
 | 
					 | 
				
			||||||
		if len(constraints) != 1 {
 | 
					 | 
				
			||||||
			t.Fatalf("failed to find unique constraint, got %v", constraints)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		for key := range constraints {
 | 
					 | 
				
			||||||
			if len(key) != 63 {
 | 
					 | 
				
			||||||
				t.Errorf("failed to find unique constraint, got %v", constraints)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	t.Run("naming strategy", func(t *testing.T) {
 | 
					 | 
				
			||||||
		db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{
 | 
					 | 
				
			||||||
			NamingStrategy: schema.NamingStrategy{},
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("failed to parse user unique, got error %v", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		constraints := user.ParseUniqueConstraints()
 | 
					 | 
				
			||||||
		if len(constraints) != 1 {
 | 
					 | 
				
			||||||
			t.Fatalf("failed to find unique constraint, got %v", constraints)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		for key := range constraints {
 | 
					 | 
				
			||||||
			if len(key) != 63 {
 | 
					 | 
				
			||||||
				t.Errorf("failed to find unique constraint, got %v", constraints)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	t.Run("namer", func(t *testing.T) {
 | 
					 | 
				
			||||||
		uname := "custom_unique_name"
 | 
					 | 
				
			||||||
		db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{
 | 
					 | 
				
			||||||
			NamingStrategy: mockUniqueNamingStrategy{
 | 
					 | 
				
			||||||
				UName: uname,
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("failed to parse user unique, got error %v", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		constraints := user.ParseUniqueConstraints()
 | 
					 | 
				
			||||||
		if len(constraints) != 1 {
 | 
					 | 
				
			||||||
			t.Fatalf("failed to find unique constraint, got %v", constraints)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		for key := range constraints {
 | 
					 | 
				
			||||||
			if key != uname {
 | 
					 | 
				
			||||||
				t.Errorf("failed to find unique constraint, got %v", constraints)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestGaussDBTableWithIdentifierLength(t *testing.T) {
 | 
					 | 
				
			||||||
	if DB.Dialector.Name() != "gaussdb" {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	type LongString struct {
 | 
					 | 
				
			||||||
		ThisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString string `gorm:"unique"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	t.Run("default", func(t *testing.T) {
 | 
					 | 
				
			||||||
		db, _ := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{})
 | 
					 | 
				
			||||||
		user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("failed to parse user unique, got error %v", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		constraints := user.ParseUniqueConstraints()
 | 
					 | 
				
			||||||
		if len(constraints) != 1 {
 | 
					 | 
				
			||||||
			t.Fatalf("failed to find unique constraint, got %v", constraints)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		for key := range constraints {
 | 
					 | 
				
			||||||
			if len(key) != 63 {
 | 
					 | 
				
			||||||
				t.Errorf("failed to find unique constraint, got %v", constraints)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	t.Run("naming strategy", func(t *testing.T) {
 | 
					 | 
				
			||||||
		db, _ := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{
 | 
					 | 
				
			||||||
			NamingStrategy: schema.NamingStrategy{},
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("failed to parse user unique, got error %v", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		constraints := user.ParseUniqueConstraints()
 | 
					 | 
				
			||||||
		if len(constraints) != 1 {
 | 
					 | 
				
			||||||
			t.Fatalf("failed to find unique constraint, got %v", constraints)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		for key := range constraints {
 | 
					 | 
				
			||||||
			if len(key) != 63 {
 | 
					 | 
				
			||||||
				t.Errorf("failed to find unique constraint, got %v", constraints)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	t.Run("namer", func(t *testing.T) {
 | 
					 | 
				
			||||||
		uname := "custom_unique_name"
 | 
					 | 
				
			||||||
		db, _ := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{
 | 
					 | 
				
			||||||
			NamingStrategy: mockUniqueNamingStrategy{
 | 
					 | 
				
			||||||
				UName: uname,
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("failed to parse user unique, got error %v", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		constraints := user.ParseUniqueConstraints()
 | 
					 | 
				
			||||||
		if len(constraints) != 1 {
 | 
					 | 
				
			||||||
			t.Fatalf("failed to find unique constraint, got %v", constraints)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		for key := range constraints {
 | 
					 | 
				
			||||||
			if key != uname {
 | 
					 | 
				
			||||||
				t.Errorf("failed to find unique constraint, got %v", constraints)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type mockUniqueNamingStrategy struct {
 | 
					 | 
				
			||||||
	UName string
 | 
					 | 
				
			||||||
	schema.NamingStrategy
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (a mockUniqueNamingStrategy) UniqueName(table, column string) string {
 | 
					 | 
				
			||||||
	return a.UName
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,6 @@
 | 
				
			|||||||
#!/bin/bash -e
 | 
					#!/bin/bash -e
 | 
				
			||||||
 | 
					
 | 
				
			||||||
dialects=("sqlite" "mysql" "postgres" "gaussdb" "sqlserver" "tidb")
 | 
					dialects=("sqlite" "mysql" "postgres" "sqlserver" "tidb")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if [[ $(pwd) == *"gorm/tests"* ]]; then
 | 
					if [[ $(pwd) == *"gorm/tests"* ]]; then
 | 
				
			||||||
  cd ..
 | 
					  cd ..
 | 
				
			||||||
@ -16,22 +16,21 @@ then
 | 
				
			|||||||
fi
 | 
					fi
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# SqlServer for Mac M1
 | 
					# SqlServer for Mac M1
 | 
				
			||||||
if [[ -z $GITHUB_ACTION && -d tests ]]; then
 | 
					if [[ -z $GITHUB_ACTION ]]; then
 | 
				
			||||||
  cd tests
 | 
					  if [ -d tests ]
 | 
				
			||||||
  if [[ $(uname -a) == *" arm64" ]]; then
 | 
					  then
 | 
				
			||||||
    MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker compose up -d --wait
 | 
					    cd tests
 | 
				
			||||||
    go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true
 | 
					    if [[ $(uname -a) == *" arm64" ]]; then
 | 
				
			||||||
    for query in \
 | 
					      MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start || true
 | 
				
			||||||
      "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" \
 | 
					      go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true
 | 
				
			||||||
      "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" \
 | 
					      SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" > /dev/null || true
 | 
				
			||||||
      "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];"
 | 
					      SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" > /dev/null || true
 | 
				
			||||||
    do
 | 
					      SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" > /dev/null || true
 | 
				
			||||||
      SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "$query" > /dev/null || true
 | 
					    else
 | 
				
			||||||
    done
 | 
					      docker-compose start
 | 
				
			||||||
  else
 | 
					    fi
 | 
				
			||||||
    MSSQL_IMAGE=mcr.microsoft.com/mssql/server docker compose up -d --wait
 | 
					    cd ..
 | 
				
			||||||
  fi
 | 
					  fi
 | 
				
			||||||
  cd ..
 | 
					 | 
				
			||||||
fi
 | 
					fi
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -1,4 +1,3 @@
 | 
				
			|||||||
//go:debug x509negativeserial=1
 | 
					 | 
				
			||||||
package tests_test
 | 
					package tests_test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
@ -8,7 +7,6 @@ import (
 | 
				
			|||||||
	"path/filepath"
 | 
						"path/filepath"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gorm.io/driver/gaussdb"
 | 
					 | 
				
			||||||
	"gorm.io/driver/mysql"
 | 
						"gorm.io/driver/mysql"
 | 
				
			||||||
	"gorm.io/driver/postgres"
 | 
						"gorm.io/driver/postgres"
 | 
				
			||||||
	"gorm.io/driver/sqlite"
 | 
						"gorm.io/driver/sqlite"
 | 
				
			||||||
@ -22,8 +20,7 @@ var DB *gorm.DB
 | 
				
			|||||||
var (
 | 
					var (
 | 
				
			||||||
	mysqlDSN     = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
 | 
						mysqlDSN     = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
 | 
				
			||||||
	postgresDSN  = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai"
 | 
						postgresDSN  = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai"
 | 
				
			||||||
	gaussdbDSN   = "user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai"
 | 
						sqlserverDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
 | 
				
			||||||
	sqlserverDSN = "sqlserver://sa:LoremIpsum86@localhost:9930?database=master"
 | 
					 | 
				
			||||||
	tidbDSN      = "root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local"
 | 
						tidbDSN      = "root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -46,6 +43,9 @@ func init() {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		RunMigrations()
 | 
							RunMigrations()
 | 
				
			||||||
 | 
							if DB.Dialector.Name() == "sqlite" {
 | 
				
			||||||
 | 
								DB.Exec("PRAGMA foreign_keys = ON")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -67,15 +67,6 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
 | 
				
			|||||||
			DSN:                  dbDSN,
 | 
								DSN:                  dbDSN,
 | 
				
			||||||
			PreferSimpleProtocol: true,
 | 
								PreferSimpleProtocol: true,
 | 
				
			||||||
		}), cfg)
 | 
							}), cfg)
 | 
				
			||||||
	case "gaussdb":
 | 
					 | 
				
			||||||
		log.Println("testing gaussdb...")
 | 
					 | 
				
			||||||
		if dbDSN == "" {
 | 
					 | 
				
			||||||
			dbDSN = gaussdbDSN
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		db, err = gorm.Open(gaussdb.New(gaussdb.Config{
 | 
					 | 
				
			||||||
			DSN:                  dbDSN,
 | 
					 | 
				
			||||||
			PreferSimpleProtocol: true,
 | 
					 | 
				
			||||||
		}), cfg)
 | 
					 | 
				
			||||||
	case "sqlserver":
 | 
						case "sqlserver":
 | 
				
			||||||
		// go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest
 | 
							// go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest
 | 
				
			||||||
		// SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930
 | 
							// SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930
 | 
				
			||||||
@ -98,10 +89,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
 | 
				
			|||||||
		db, err = gorm.Open(mysql.Open(dbDSN), cfg)
 | 
							db, err = gorm.Open(mysql.Open(dbDSN), cfg)
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		log.Println("testing sqlite3...")
 | 
							log.Println("testing sqlite3...")
 | 
				
			||||||
		db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), cfg)
 | 
							db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg)
 | 
				
			||||||
		if err == nil {
 | 
					 | 
				
			||||||
			db.Exec("PRAGMA foreign_keys = ON")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@ -119,7 +107,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func RunMigrations() {
 | 
					func RunMigrations() {
 | 
				
			||||||
	var err error
 | 
						var err error
 | 
				
			||||||
	allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}}
 | 
						allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}}
 | 
				
			||||||
	rand.Seed(time.Now().UnixNano())
 | 
						rand.Seed(time.Now().UnixNano())
 | 
				
			||||||
	rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
 | 
						rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -4,7 +4,6 @@ import (
 | 
				
			|||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	. "gorm.io/gorm/utils/tests"
 | 
						. "gorm.io/gorm/utils/tests"
 | 
				
			||||||
@ -68,7 +67,7 @@ func TestTransaction(t *testing.T) {
 | 
				
			|||||||
				return tx5.First(&User{}, "name = ?", "transaction-2").Error
 | 
									return tx5.First(&User{}, "name = ?", "transaction-2").Error
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
		}); err != nil {
 | 
							}); err != nil {
 | 
				
			||||||
			t.Fatalf("prepare statement and nested transaction coexist" + err.Error())
 | 
								t.Fatalf("prepare statement and nested transcation coexist" + err.Error())
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -298,74 +297,6 @@ func TestNestedTransactionWithBlock(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestDeeplyNestedTransactionWithBlockAndWrappedCallback(t *testing.T) {
 | 
					 | 
				
			||||||
	transaction := func(ctx context.Context, db *gorm.DB, callback func(ctx context.Context, db *gorm.DB) error) error {
 | 
					 | 
				
			||||||
		return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
					 | 
				
			||||||
			return callback(ctx, tx)
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	var (
 | 
					 | 
				
			||||||
		user  = *GetUser("transaction-nested", Config{})
 | 
					 | 
				
			||||||
		user1 = *GetUser("transaction-nested-1", Config{})
 | 
					 | 
				
			||||||
		user2 = *GetUser("transaction-nested-2", Config{})
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := transaction(context.Background(), DB, func(ctx context.Context, tx *gorm.DB) error {
 | 
					 | 
				
			||||||
		tx.Create(&user)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("Should find saved record")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if err := transaction(ctx, tx, func(ctx context.Context, tx1 *gorm.DB) error {
 | 
					 | 
				
			||||||
			tx1.Create(&user1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil {
 | 
					 | 
				
			||||||
				t.Fatalf("Should find saved record")
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if err := transaction(ctx, tx1, func(ctx context.Context, tx2 *gorm.DB) error {
 | 
					 | 
				
			||||||
				tx2.Create(&user2)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil {
 | 
					 | 
				
			||||||
					t.Fatalf("Should find saved record")
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				return errors.New("inner rollback")
 | 
					 | 
				
			||||||
			}); err == nil {
 | 
					 | 
				
			||||||
				t.Fatalf("nested transaction has no error")
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			return errors.New("rollback")
 | 
					 | 
				
			||||||
		}); err == nil {
 | 
					 | 
				
			||||||
			t.Fatalf("nested transaction should returns error")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil {
 | 
					 | 
				
			||||||
			t.Fatalf("Should not find rollbacked record")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil {
 | 
					 | 
				
			||||||
			t.Fatalf("Should find saved record")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}); err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("no error should return, but got %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Should find saved record")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Should not find rollbacked parent record")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("Should not find rollbacked nested record")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestDisabledNestedTransaction(t *testing.T) {
 | 
					func TestDisabledNestedTransaction(t *testing.T) {
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		user  = *GetUser("transaction-nested", Config{})
 | 
							user  = *GetUser("transaction-nested", Config{})
 | 
				
			||||||
@ -460,6 +391,7 @@ func TestTransactionWithHooks(t *testing.T) {
 | 
				
			|||||||
			return tx2.Scan(&User{}).Error
 | 
								return tx2.Scan(&User{}).Error
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Error(err)
 | 
							t.Error(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -473,20 +405,8 @@ func TestTransactionWithHooks(t *testing.T) {
 | 
				
			|||||||
			return tx3.Where("user_id", user.ID).Delete(&Account{}).Error
 | 
								return tx3.Where("user_id", user.ID).Delete(&Account{}).Error
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Error(err)
 | 
							t.Error(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestTransactionWithDefaultTimeout(t *testing.T) {
 | 
					 | 
				
			||||||
	db, err := OpenTestConnection(&gorm.Config{DefaultTransactionTimeout: 2 * time.Second})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		t.Fatalf("failed to connect database, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	tx := db.Begin()
 | 
					 | 
				
			||||||
	time.Sleep(3 * time.Second)
 | 
					 | 
				
			||||||
	if err = tx.Find(&User{}).Error; err == nil {
 | 
					 | 
				
			||||||
		t.Errorf("should return error when transaction timeout, got error %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user