diff --git a/finisher_api.go b/finisher_api.go index ec5b6946..bc6b9324 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -139,8 +139,14 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1) if len(conds) > 0 { - if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: exprs}) + pkValue, ok := conds[0].(string) + if len(conds) == 1 && ok { + cond := []clause.Expression{clause.IN{Column: clause.PrimaryColumn, Values: []interface{}{pkValue}}} + tx.Statement.AddClause(clause.Where{Exprs: cond}) + } else { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } } } tx.Statement.RaiseErrorOnNotFound = true @@ -155,8 +161,14 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { Desc: true, }) if len(conds) > 0 { - if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: exprs}) + pkValue, ok := conds[0].(string) + if len(conds) == 1 && ok { + cond := []clause.Expression{clause.IN{Column: clause.PrimaryColumn, Values: []interface{}{pkValue}}} + tx.Statement.AddClause(clause.Where{Exprs: cond}) + } else { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } } } tx.Statement.RaiseErrorOnNotFound = true