diff --git a/.codeclimate.yml b/.codeclimate.yml deleted file mode 100644 index 51aba50..0000000 --- a/.codeclimate.yml +++ /dev/null @@ -1,11 +0,0 @@ ---- -engines: - gofmt: - enabled: true - govet: - enabled: true - golint: - enabled: true -ratings: - paths: - - "**.go" diff --git a/.gitignore b/.gitignore index 01dc5ce..117f92f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ documents +coverage.txt _book diff --git a/README.md b/README.md index 0c5c7ea..85588a7 100644 --- a/README.md +++ b/README.md @@ -1,40 +1,5 @@ # GORM -The fantastic ORM library for Golang, aims to be developer friendly. +GORM V2 moved to https://github.com/go-gorm/gorm -[![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm) -[![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b) -[![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -[![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) -[![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm) -[![MIT license](http://img.shields.io/badge/license-MIT-brightgreen.svg)](http://opensource.org/licenses/MIT) -[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) - -## Overview - -* Full-Featured ORM (almost) -* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism) -* Hooks (Before/After Create/Save/Update/Delete/Find) -* Preloading (eager loading) -* Transactions -* Composite Primary Key -* SQL Builder -* Auto Migrations -* Logger -* Extendable, write Plugins based on GORM callbacks -* Every feature comes with tests -* Developer Friendly - -## Getting Started - -* GORM Guides [http://gorm.io](http://gorm.io) - -## Contributing - -[You can help to deliver a better GORM, check out things you can do](http://gorm.io/contribute.html) - -## License - -© Jinzhu, 2013~time.Now - -Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License) +GORM V1 Doc https://v1.gorm.io/ diff --git a/association.go b/association.go index 8c6d986..a73344f 100644 --- a/association.go +++ b/association.go @@ -267,15 +267,16 @@ query = scope.DB() ) - if relationship.Kind == "many_to_many" { + switch relationship.Kind { + case "many_to_many": query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { + case "has_many", "has_one": primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) query = query.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)..., ) - } else if relationship.Kind == "belongs_to" { + case "belongs_to": primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) query = query.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), @@ -367,6 +368,7 @@ return association } +// setErr set error when the error is not nil. And return Association. func (association *Association) setErr(err error) *Association { if err != nil { association.Error = err diff --git a/callback.go b/callback.go index a438214..1f0e3c7 100644 --- a/callback.go +++ b/callback.go @@ -1,9 +1,9 @@ package gorm -import "log" +import "fmt" // DefaultCallback default callbacks defined by gorm -var DefaultCallback = &Callback{} +var DefaultCallback = &Callback{logger: nopLogger{}} // Callback is a struct that contains all CRUD callbacks // Field `creates` contains callbacks will be call when creating object @@ -13,6 +13,7 @@ // Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... // Field `processors` contains all callback processors, will be used to generate above callbacks in order type Callback struct { + logger logger creates []*func(scope *Scope) updates []*func(scope *Scope) deletes []*func(scope *Scope) @@ -23,6 +24,7 @@ // CallbackProcessor contains callback informations type CallbackProcessor struct { + logger logger name string // current callback's name before string // register current callback before a callback after string // register current callback after a callback @@ -33,8 +35,9 @@ parent *Callback } -func (c *Callback) clone() *Callback { +func (c *Callback) clone(logger logger) *Callback { return &Callback{ + logger: logger, creates: c.creates, updates: c.updates, deletes: c.deletes, @@ -53,28 +56,28 @@ // scope.Err(errors.New("error")) // }) func (c *Callback) Create() *CallbackProcessor { - return &CallbackProcessor{kind: "create", parent: c} + return &CallbackProcessor{logger: c.logger, kind: "create", parent: c} } // Update could be used to register callbacks for updating object, refer `Create` for usage func (c *Callback) Update() *CallbackProcessor { - return &CallbackProcessor{kind: "update", parent: c} + return &CallbackProcessor{logger: c.logger, kind: "update", parent: c} } // Delete could be used to register callbacks for deleting object, refer `Create` for usage func (c *Callback) Delete() *CallbackProcessor { - return &CallbackProcessor{kind: "delete", parent: c} + return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c} } // Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`... // Refer `Create` for usage func (c *Callback) Query() *CallbackProcessor { - return &CallbackProcessor{kind: "query", parent: c} + return &CallbackProcessor{logger: c.logger, kind: "query", parent: c} } // RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage func (c *Callback) RowQuery() *CallbackProcessor { - return &CallbackProcessor{kind: "row_query", parent: c} + return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c} } // After insert a new callback after callback `callbackName`, refer `Callbacks.Create` @@ -93,11 +96,12 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { if cp.kind == "row_query" { if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { - log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName) + cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName)) cp.before = "gorm:row_query" } } + cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum())) cp.name = callbackName cp.processor = &callback cp.parent.processors = append(cp.parent.processors, cp) @@ -107,7 +111,7 @@ // Remove a registered callback // db.Callback().Create().Remove("gorm:update_time_stamp_when_create") func (cp *CallbackProcessor) Remove(callbackName string) { - log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) + cp.logger.Print("info", fmt.Sprintf("[info] removing callback `%v` from %v", callbackName, fileWithLineNum())) cp.name = callbackName cp.remove = true cp.parent.processors = append(cp.parent.processors, cp) @@ -116,11 +120,11 @@ // Replace a registered callback with new callback // db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { -// scope.SetColumn("Created", now) -// scope.SetColumn("Updated", now) +// scope.SetColumn("CreatedAt", now) +// scope.SetColumn("UpdatedAt", now) // }) func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) + cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum())) cp.name = callbackName cp.processor = &callback cp.replace = true @@ -132,11 +136,15 @@ // db.Callback().Create().Get("gorm:create") func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { for _, p := range cp.parent.processors { - if p.name == callbackName && p.kind == cp.kind && !cp.remove { - return *p.processor - } - } - return nil + if p.name == callbackName && p.kind == cp.kind { + if p.remove { + callback = nil + } else { + callback = *p.processor + } + } + } + return } // getRIndex get right index from string slice @@ -159,7 +167,7 @@ for _, cp := range cps { // show warning message the callback name already exists if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) + cp.logger.Print("warning", fmt.Sprintf("[warning] duplicated callback `%v` from %v", cp.name, fileWithLineNum())) } allNames = append(allNames, cp.name) } diff --git a/callback_create.go b/callback_create.go index e7fe6f8..59840f8 100644 --- a/callback_create.go +++ b/callback_create.go @@ -31,7 +31,7 @@ // updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating func updateTimeStampForCreateCallback(scope *Scope) { if !scope.HasError() { - now := NowFunc() + now := scope.db.nowFunc() if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { if createdAtField.IsBlank { @@ -59,7 +59,7 @@ for _, field := range scope.Fields() { if scope.changeableField(field) { - if field.IsNormal { + if field.IsNormal && !field.IsIgnored { if field.IsBlank && field.HasDefaultValue { blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) @@ -83,21 +83,33 @@ quotedTableName = scope.QuotedTableName() primaryField = scope.PrimaryField() extraOption string + insertModifier string ) if str, ok := scope.Get("gorm:insert_option"); ok { extraOption = fmt.Sprint(str) } + if str, ok := scope.Get("gorm:insert_modifier"); ok { + insertModifier = strings.ToUpper(fmt.Sprint(str)) + if insertModifier == "INTO" { + insertModifier = "" + } + } if primaryField != nil { returningColumn = scope.Quote(primaryField.DBName) } - lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) + lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns) + var lastInsertIDReturningSuffix string + if lastInsertIDOutputInterstitial == "" { + lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) + } if len(columns) == 0 { scope.Raw(fmt.Sprintf( - "INSERT INTO %v %v%v%v", + "INSERT%v INTO %v %v%v%v", + addExtraSpaceIfExist(insertModifier), quotedTableName, scope.Dialect().DefaultValueStr(), addExtraSpaceIfExist(extraOption), @@ -105,17 +117,19 @@ )) } else { scope.Raw(fmt.Sprintf( - "INSERT INTO %v (%v) VALUES (%v)%v%v", + "INSERT%v INTO %v (%v)%v VALUES (%v)%v%v", + addExtraSpaceIfExist(insertModifier), scope.QuotedTableName(), strings.Join(columns, ","), + addExtraSpaceIfExist(lastInsertIDOutputInterstitial), strings.Join(placeholders, ","), addExtraSpaceIfExist(extraOption), addExtraSpaceIfExist(lastInsertIDReturningSuffix), )) } - // execute create sql - if lastInsertIDReturningSuffix == "" || primaryField == nil { + // execute create sql: no primaryField + if primaryField == nil { if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { // set rows affected count scope.db.RowsAffected, _ = result.RowsAffected() @@ -127,28 +141,54 @@ } } } + return + } + + // execute create sql: lastInsertID implemention for majority of dialects + if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" { + if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + // set rows affected count + scope.db.RowsAffected, _ = result.RowsAffected() + + // set primary value to primary field + if primaryField != nil && primaryField.IsBlank { + if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { + scope.Err(primaryField.Set(primaryValue)) + } + } + } + return + } + + // execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql) + if primaryField.Field.CanAddr() { + if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { + primaryField.IsBlank = false + scope.db.RowsAffected = 1 + } } else { - if primaryField.Field.CanAddr() { - if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { - primaryField.IsBlank = false - scope.db.RowsAffected = 1 - } - } else { - scope.Err(ErrUnaddressable) - } - } + scope.Err(ErrUnaddressable) + } + return } } // forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object func forceReloadAfterCreateCallback(scope *Scope) { if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { + var shouldScan bool db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string)) for _, field := range scope.Fields() { if field.IsPrimaryKey && !field.IsBlank { db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface()) - } - } + shouldScan = true + } + } + + if !shouldScan { + return + } + db.Scan(scope.Value) } } diff --git a/callback_delete.go b/callback_delete.go index 73d9088..48b97ac 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -17,7 +17,7 @@ // beforeDeleteCallback will invoke `BeforeDelete` method before deleting func beforeDeleteCallback(scope *Scope) { if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("Missing WHERE clause while deleting")) + scope.Err(errors.New("missing WHERE clause while deleting")) return } if !scope.HasError() { @@ -40,7 +40,7 @@ "UPDATE %v SET %v=%v%v%v", scope.QuotedTableName(), scope.Quote(deletedAtField.DBName), - scope.AddToVars(NowFunc()), + scope.AddToVars(scope.db.nowFunc()), addExtraSpaceIfExist(scope.CombinedConditionSql()), addExtraSpaceIfExist(extraOption), )).Exec() diff --git a/callback_query.go b/callback_query.go index ba10cc7..f756271 100644 --- a/callback_query.go +++ b/callback_query.go @@ -16,6 +16,11 @@ // queryCallback used to query data from database func queryCallback(scope *Scope) { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { + return + } + + //we are only preloading relations, dont touch base model + if _, skip := scope.InstanceGet("gorm:only_preload"); skip { return } @@ -55,8 +60,9 @@ if !scope.HasError() { scope.db.RowsAffected = 0 - if str, ok := scope.Get("gorm:query_option"); ok { - scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) + + if str, ok := scope.Get("gorm:query_hint"); ok { + scope.SQL = fmt.Sprint(str) + scope.SQL } if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { diff --git a/callback_query_preload.go b/callback_query_preload.go index 30f6b58..a936180 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -14,8 +14,14 @@ return } - if _, ok := scope.Get("gorm:auto_preload"); ok { - autoPreload(scope) + if ap, ok := scope.Get("gorm:auto_preload"); ok { + // If gorm:auto_preload IS NOT a bool then auto preload. + // Else if it IS a bool, use the value + if apb, ok := ap.(bool); !ok { + autoPreload(scope) + } else if apb { + autoPreload(scope) + } } if scope.Search.preload == nil || scope.HasError() { @@ -94,7 +100,7 @@ continue } - if val, ok := field.TagSettings["PRELOAD"]; ok { + if val, ok := field.TagSettingsGet("PRELOAD"); ok { if preload, err := strconv.ParseBool(val); err != nil { scope.Err(errors.New("invalid preload option")) return @@ -155,14 +161,17 @@ ) if indirectScopeValue.Kind() == reflect.Slice { + foreignValuesToResults := make(map[string]reflect.Value) + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) + foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames)) + foreignValuesToResults[foreignValues] = result + } for j := 0; j < indirectScopeValue.Len(); j++ { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { - indirectValue.FieldByName(field.Name).Set(result) - break - } + indirectValue := indirect(indirectScopeValue.Index(j)) + valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames)) + if result, found := foreignValuesToResults[valueString]; found { + indirectValue.FieldByName(field.Name).Set(result) } } } else { @@ -249,13 +258,21 @@ indirectScopeValue = scope.IndirectValue() ) + foreignFieldToObjects := make(map[string][]*reflect.Value) + if indirectScopeValue.Kind() == reflect.Slice { + for j := 0; j < indirectScopeValue.Len(); j++ { + object := indirect(indirectScopeValue.Index(j)) + valueString := toString(getValueFromFields(object, relation.ForeignFieldNames)) + foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object) + } + } + for i := 0; i < resultsValue.Len(); i++ { result := resultsValue.Index(i) if indirectScopeValue.Kind() == reflect.Slice { - value := getValueFromFields(result, relation.AssociationForeignFieldNames) - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { + valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames)) + if objects, found := foreignFieldToObjects[valueString]; found { + for _, object := range objects { object.FieldByName(field.Name).Set(result) } } @@ -374,14 +391,20 @@ key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) } - for source, link := range linkHash { - for i, field := range fieldsSourceMap[source] { + + for source, fields := range fieldsSourceMap { + for _, f := range fields { //If not 0 this means Value is a pointer and we already added preloaded models to it - if fieldsSourceMap[source][i].Len() != 0 { + if f.Len() != 0 { continue } - field.Set(reflect.Append(fieldsSourceMap[source][i], link...)) - } - - } -} + + v := reflect.MakeSlice(f.Type(), 0, 0) + if len(linkHash[source]) > 0 { + v = reflect.Append(f, linkHash[source]...) + } + + f.Set(v) + } + } +} diff --git a/callback_row_query.go b/callback_row_query.go index c2ff4a0..43b21f8 100644 --- a/callback_row_query.go +++ b/callback_row_query.go @@ -1,6 +1,9 @@ package gorm -import "database/sql" +import ( + "database/sql" + "fmt" +) // Define callbacks for row query func init() { @@ -21,6 +24,10 @@ if result, ok := scope.InstanceGet("row_query_result"); ok { scope.prepareQuerySQL() + if str, ok := scope.Get("gorm:query_hint"); ok { + scope.SQL = fmt.Sprint(str) + scope.SQL + } + if rowResult, ok := result.(*RowQueryResult); ok { rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) } else if rowsResult, ok := result.(*RowsQueryResult); ok { diff --git a/callback_save.go b/callback_save.go index ef26714..3b4e058 100644 --- a/callback_save.go +++ b/callback_save.go @@ -21,9 +21,7 @@ if v, ok := value.(string); ok { v = strings.ToLower(v) - if v == "false" || v != "skip" { - return false - } + return v == "true" } return true @@ -36,26 +34,28 @@ if value, ok := scope.Get("gorm:save_associations"); ok { autoUpdate = checkTruth(value) autoCreate = autoUpdate - } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok { + saveReference = autoUpdate + } else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok { autoUpdate = checkTruth(value) autoCreate = autoUpdate + saveReference = autoUpdate } if value, ok := scope.Get("gorm:association_autoupdate"); ok { autoUpdate = checkTruth(value) - } else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok { + } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok { autoUpdate = checkTruth(value) } if value, ok := scope.Get("gorm:association_autocreate"); ok { autoCreate = checkTruth(value) - } else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok { + } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok { autoCreate = checkTruth(value) } if value, ok := scope.Get("gorm:association_save_reference"); ok { saveReference = checkTruth(value) - } else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { + } else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok { saveReference = checkTruth(value) } } diff --git a/callback_system_test.go b/callback_system_test.go index 13ca3f4..2482eda 100644 --- a/callback_system_test.go +++ b/callback_system_test.go @@ -23,7 +23,7 @@ func afterCreate2(s *Scope) {} func TestRegisterCallback(t *testing.T) { - var callback = &Callback{} + var callback = &Callback{logger: defaultLogger} callback.Create().Register("before_create1", beforeCreate1) callback.Create().Register("before_create2", beforeCreate2) @@ -37,7 +37,7 @@ } func TestRegisterCallbackWithOrder(t *testing.T) { - var callback1 = &Callback{} + var callback1 = &Callback{logger: defaultLogger} callback1.Create().Register("before_create1", beforeCreate1) callback1.Create().Register("create", create) callback1.Create().Register("after_create1", afterCreate1) @@ -46,7 +46,7 @@ t.Errorf("register callback with order") } - var callback2 = &Callback{} + var callback2 = &Callback{logger: defaultLogger} callback2.Update().Register("create", create) callback2.Update().Before("create").Register("before_create1", beforeCreate1) @@ -60,7 +60,7 @@ } func TestRegisterCallbackWithComplexOrder(t *testing.T) { - var callback1 = &Callback{} + var callback1 = &Callback{logger: defaultLogger} callback1.Query().Before("after_create1").After("before_create1").Register("create", create) callback1.Query().Register("before_create1", beforeCreate1) @@ -70,7 +70,7 @@ t.Errorf("register callback with order") } - var callback2 = &Callback{} + var callback2 = &Callback{logger: defaultLogger} callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) callback2.Delete().Before("create").Register("before_create1", beforeCreate1) @@ -86,7 +86,7 @@ func replaceCreate(s *Scope) {} func TestReplaceCallback(t *testing.T) { - var callback = &Callback{} + var callback = &Callback{logger: defaultLogger} callback.Create().Before("after_create1").After("before_create1").Register("create", create) callback.Create().Register("before_create1", beforeCreate1) @@ -99,7 +99,7 @@ } func TestRemoveCallback(t *testing.T) { - var callback = &Callback{} + var callback = &Callback{logger: defaultLogger} callback.Create().Before("after_create1").After("before_create1").Register("create", create) callback.Create().Register("before_create1", beforeCreate1) diff --git a/callback_update.go b/callback_update.go index 373bd72..699e534 100644 --- a/callback_update.go +++ b/callback_update.go @@ -34,7 +34,7 @@ // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating func beforeUpdateCallback(scope *Scope) { if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("Missing WHERE clause while updating")) + scope.Err(errors.New("missing WHERE clause while updating")) return } if _, ok := scope.Get("gorm:update_column"); !ok { @@ -50,7 +50,7 @@ // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating func updateTimeStampForUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { - scope.SetColumn("UpdatedAt", NowFunc()) + scope.SetColumn("UpdatedAt", scope.db.nowFunc()) } } @@ -75,8 +75,10 @@ } else { for _, field := range scope.Fields() { if scope.changeableField(field) { - if !field.IsPrimaryKey && field.IsNormal { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + if !field.IsPrimaryKey && field.IsNormal && (field.Name != "CreatedAt" || !field.IsBlank) { + if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { for _, foreignKey := range relationship.ForeignDBNames { if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { diff --git a/callbacks_test.go b/callbacks_test.go index a58913d..bebd0e3 100644 --- a/callbacks_test.go +++ b/callbacks_test.go @@ -2,11 +2,10 @@ import ( "errors" - - "github.com/jinzhu/gorm" - "reflect" "testing" + + "github.com/jinzhu/gorm" ) func (s *Product) BeforeCreate() (err error) { @@ -175,3 +174,76 @@ t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") } } + +func TestGetCallback(t *testing.T) { + scope := DB.NewScope(nil) + + if DB.Callback().Create().Get("gorm:test_callback") != nil { + t.Errorf("`gorm:test_callback` should be nil") + } + + DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) }) + callback := DB.Callback().Create().Get("gorm:test_callback") + if callback == nil { + t.Errorf("`gorm:test_callback` should be non-nil") + } + callback(scope) + if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 { + t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok) + } + + DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) }) + callback = DB.Callback().Create().Get("gorm:test_callback") + if callback == nil { + t.Errorf("`gorm:test_callback` should be non-nil") + } + callback(scope) + if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 { + t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok) + } + + DB.Callback().Create().Remove("gorm:test_callback") + if DB.Callback().Create().Get("gorm:test_callback") != nil { + t.Errorf("`gorm:test_callback` should be nil") + } + + DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) }) + callback = DB.Callback().Create().Get("gorm:test_callback") + if callback == nil { + t.Errorf("`gorm:test_callback` should be non-nil") + } + callback(scope) + if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 { + t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok) + } +} + +func TestUseDefaultCallback(t *testing.T) { + createCallbackName := "gorm:test_use_default_callback_for_create" + gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) { + // nop + }) + if gorm.DefaultCallback.Create().Get(createCallbackName) == nil { + t.Errorf("`%s` expected non-nil, but got nil", createCallbackName) + } + gorm.DefaultCallback.Create().Remove(createCallbackName) + if gorm.DefaultCallback.Create().Get(createCallbackName) != nil { + t.Errorf("`%s` expected nil, but got non-nil", createCallbackName) + } + + updateCallbackName := "gorm:test_use_default_callback_for_update" + scopeValueName := "gorm:test_use_default_callback_for_update_value" + gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) { + scope.Set(scopeValueName, 1) + }) + gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) { + scope.Set(scopeValueName, 2) + }) + + scope := DB.NewScope(nil) + callback := gorm.DefaultCallback.Update().Get(updateCallbackName) + callback(scope) + if v, ok := scope.Get(scopeValueName); !ok || v != 2 { + t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok) + } +} diff --git a/create_test.go b/create_test.go index 9256064..1f46696 100644 --- a/create_test.go +++ b/create_test.go @@ -3,10 +3,13 @@ import ( "os" "reflect" + "strings" "testing" "time" "github.com/jinzhu/now" + + "github.com/jinzhu/gorm" ) func TestCreate(t *testing.T) { @@ -98,6 +101,46 @@ if newUser.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { t.Errorf("UpdatedAt should not be changed") + } +} + +func TestCreateWithNowFuncOverride(t *testing.T) { + user1 := User{Name: "CreateUserTimestampOverride"} + + timeA := now.MustParse("2016-01-01") + + // do DB.New() because we don't want this test to affect other tests + db1 := DB.New() + // set the override to use static timeA + db1.SetNowFuncOverride(func() time.Time { + return timeA + }) + // call .New again to check the override is carried over as well during clone + db1 = db1.New() + + db1.Save(&user1) + + if user1.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { + t.Errorf("CreatedAt be using the nowFuncOverride") + } + if user1.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { + t.Errorf("UpdatedAt be using the nowFuncOverride") + } + + // now create another user with a fresh DB.Now() that doesn't have the nowFuncOverride set + // to make sure that setting it only affected the above instance + + user2 := User{Name: "CreateUserTimestampOverrideNoMore"} + + db2 := DB.New() + + db2.Save(&user2) + + if user2.CreatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) { + t.Errorf("CreatedAt no longer be using the nowFuncOverride") + } + if user2.UpdatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) { + t.Errorf("UpdatedAt no longer be using the nowFuncOverride") } } @@ -229,3 +272,42 @@ t.Errorf("Should not create omitted relationships") } } + +func TestCreateIgnore(t *testing.T) { + float := 35.03554004971999 + now := time.Now() + user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float} + + if !DB.NewRecord(user) || !DB.NewRecord(&user) { + t.Error("User should be new record before create") + } + + if count := DB.Create(&user).RowsAffected; count != 1 { + t.Error("There should be one record be affected when create record") + } + if DB.Dialect().GetName() == "mysql" && DB.Set("gorm:insert_modifier", "IGNORE").Create(&user).Error != nil { + t.Error("Should ignore duplicate user insert by insert modifier:IGNORE ") + } +} + +func TestFixFullTableScanWhenInsertIgnore(t *testing.T) { + pandaYuanYuan := Panda{Number: 200408301001} + + if !DB.NewRecord(pandaYuanYuan) || !DB.NewRecord(&pandaYuanYuan) { + t.Error("Panda should be new record before create") + } + + if count := DB.Create(&pandaYuanYuan).RowsAffected; count != 1 { + t.Error("There should be one record be affected when create record") + } + + DB.Callback().Query().Register("gorm:fix_full_table_scan", func(scope *gorm.Scope) { + if strings.Contains(scope.SQL, "SELECT") && strings.Contains(scope.SQL, "pandas") && len(scope.SQLVars) == 0 { + t.Error("Should skip force reload when ignore duplicate panda insert") + } + }) + + if DB.Dialect().GetName() == "mysql" && DB.Set("gorm:insert_modifier", "IGNORE").Create(&pandaYuanYuan).Error != nil { + t.Error("Should ignore duplicate panda insert by insert modifier:IGNORE ") + } +} \ No newline at end of file diff --git a/customize_column_test.go b/customize_column_test.go index 5e19d6f..c236ac2 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -289,6 +289,9 @@ func TestSelfReferencingMany2ManyColumn(t *testing.T) { DB.DropTable(&SelfReferencingUser{}, "UserFriends") DB.AutoMigrate(&SelfReferencingUser{}) + if !DB.HasTable("UserFriends") { + t.Errorf("auto migrate error, table UserFriends should be created") + } friend1 := SelfReferencingUser{Name: "friend1_m2m"} if err := DB.Create(&friend1).Error; err != nil { @@ -311,6 +314,14 @@ if DB.Model(&user).Association("Friends").Count() != 2 { t.Errorf("Should find created friends correctly") + } + + var count int + if err := DB.Table("UserFriends").Count(&count).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + if count == 0 { + t.Errorf("table UserFriends should have records") } var newUser = SelfReferencingUser{} diff --git a/dialect.go b/dialect.go index 5f6439c..c742efc 100644 --- a/dialect.go +++ b/dialect.go @@ -37,9 +37,11 @@ ModifyColumn(tableName string, columnName string, typ string) error // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case - LimitAndOffsetSQL(limit, offset interface{}) string + LimitAndOffsetSQL(limit, offset interface{}) (string, error) // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` SelectFromDummyTable() string + // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT` + LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` LastInsertIDReturningSuffix(tableName, columnName string) string // DefaultValueStr @@ -47,6 +49,9 @@ // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference BuildKeyName(kind, tableName string, fields ...string) string + + // NormalizeIndexAndColumn returns valid index name and column name depending on each dialect + NormalizeIndexAndColumn(indexName, columnName string) (string, string) // CurrentDatabase return current database name CurrentDatabase() string @@ -72,12 +77,18 @@ dialectsMap[name] = dialect } +// GetDialect gets the dialect for the specified dialect name +func GetDialect(name string) (dialect Dialect, ok bool) { + dialect, ok = dialectsMap[name] + return +} + // ParseFieldStructForDialect get field's sql data type var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { // Get redirected field type var ( reflectType = field.Struct.Type - dataType = field.TagSettings["TYPE"] + dataType, _ = field.TagSettingsGet("TYPE") ) for reflectType.Kind() == reflect.Ptr { @@ -106,16 +117,22 @@ } // Default Size - if num, ok := field.TagSettings["SIZE"]; ok { + if num, ok := field.TagSettingsGet("SIZE"); ok { size, _ = strconv.Atoi(num) } else { size = 255 } // Default type from tag setting - additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] - if value, ok := field.TagSettings["DEFAULT"]; ok { + notNull, _ := field.TagSettingsGet("NOT NULL") + unique, _ := field.TagSettingsGet("UNIQUE") + additionalType = notNull + " " + unique + if value, ok := field.TagSettingsGet("DEFAULT"); ok { additionalType = additionalType + " DEFAULT " + value + } + + if value, ok := field.TagSettingsGet("COMMENT"); ok && dialect.GetName() != "sqlite3" { + additionalType = additionalType + " COMMENT " + value } return fieldValue, dataType, size, strings.TrimSpace(additionalType) diff --git a/dialect_common.go b/dialect_common.go index b9f0c7d..d549510 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -8,6 +8,8 @@ "strings" "time" ) + +var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+") // DefaultForeignKeyNamer contains the default foreign key name generator method type DefaultForeignKeyNamer struct { @@ -39,7 +41,7 @@ } func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { - if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { return strings.ToLower(value) != "false" } return field.IsPrimaryKey @@ -137,14 +139,19 @@ return } -func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { +// LimitAndOffsetSQL return generated SQL with Limit and Offset +func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { + if parsedLimit, err := s.parseInt(limit); err != nil { + return "", err + } else if parsedLimit >= 0 { sql += fmt.Sprintf(" LIMIT %d", parsedLimit) } } if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { + if parsedOffset, err := s.parseInt(offset); err != nil { + return "", err + } else if parsedOffset >= 0 { sql += fmt.Sprintf(" OFFSET %d", parsedOffset) } } @@ -152,6 +159,10 @@ } func (commonDialect) SelectFromDummyTable() string { + return "" +} + +func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { return "" } @@ -166,8 +177,17 @@ // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) - keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_") + keyName = keyNameRegex.ReplaceAllString(keyName, "_") return keyName +} + +// NormalizeIndexAndColumn returns argument's index name and column name without doing anything +func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { + return indexName, columnName +} + +func (commonDialect) parseInt(value interface{}) (int64, error) { + return strconv.ParseInt(fmt.Sprint(value), 0, 0) } // IsByteArrayOrSlice returns true of the reflected value is an array or slice diff --git a/dialect_mysql.go b/dialect_mysql.go index b162bad..b4467ff 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -2,15 +2,17 @@ import ( "crypto/sha1" + "database/sql" "fmt" "reflect" "regexp" - "strconv" "strings" "time" "unicode/utf8" ) +var mysqlIndexRegex = regexp.MustCompile(`^(.+)\((\d+)\)$`) + type mysql struct { commonDialect } @@ -33,9 +35,9 @@ // MySQL allows only one auto increment column per table, and it must // be a KEY column. - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey { - delete(field.TagSettings, "AUTO_INCREMENT") + if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { + if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey { + field.TagSettingsDelete("AUTO_INCREMENT") } } @@ -45,42 +47,42 @@ sqlType = "boolean" case reflect.Int8: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "tinyint AUTO_INCREMENT" } else { sqlType = "tinyint" } case reflect.Int, reflect.Int16, reflect.Int32: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "int AUTO_INCREMENT" } else { sqlType = "int" } case reflect.Uint8: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "tinyint unsigned AUTO_INCREMENT" } else { sqlType = "tinyint unsigned" } case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "int unsigned AUTO_INCREMENT" } else { sqlType = "int unsigned" } case reflect.Int64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigint AUTO_INCREMENT" } else { sqlType = "bigint" } case reflect.Uint64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigint unsigned AUTO_INCREMENT" } else { sqlType = "bigint unsigned" @@ -96,14 +98,14 @@ case reflect.Struct: if _, ok := dataValue.Interface().(time.Time); ok { precision := "" - if p, ok := field.TagSettings["PRECISION"]; ok { + if p, ok := field.TagSettingsGet("PRECISION"); ok { precision = fmt.Sprintf("(%s)", p) } - if _, ok := field.TagSettings["NOT NULL"]; ok { - sqlType = fmt.Sprintf("timestamp%v", precision) + if _, ok := field.TagSettings["NOT NULL"]; ok || field.IsPrimaryKey { + sqlType = fmt.Sprintf("DATETIME%v", precision) } else { - sqlType = fmt.Sprintf("timestamp%v NULL", precision) + sqlType = fmt.Sprintf("DATETIME%v NULL", precision) } } default: @@ -118,7 +120,7 @@ } if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String())) + panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name)) } if strings.TrimSpace(additionalType) == "" { @@ -137,13 +139,21 @@ return err } -func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { +func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { + parsedLimit, err := s.parseInt(limit) + if err != nil { + return "", err + } + if parsedLimit >= 0 { sql += fmt.Sprintf(" LIMIT %d", parsedLimit) if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { + parsedOffset, err := s.parseInt(offset) + if err != nil { + return "", err + } + if parsedOffset >= 0 { sql += fmt.Sprintf(" OFFSET %d", parsedOffset) } } @@ -157,6 +167,40 @@ currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) return count > 0 +} + +func (s mysql) HasTable(tableName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + var name string + // allow mysql database name with '-' character + if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { + if err == sql.ErrNoRows { + return false + } + panic(err) + } else { + return true + } +} + +func (s mysql) HasIndex(tableName string, indexName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil { + panic(err) + } else { + defer rows.Close() + return rows.Next() + } +} + +func (s mysql) HasColumn(tableName string, columnName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil { + panic(err) + } else { + defer rows.Close() + return rows.Next() + } } func (s mysql) CurrentDatabase() (name string) { @@ -178,7 +222,7 @@ bs := h.Sum(nil) // sha1 is 40 characters, keep first 24 characters of destination - destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_")) + destRunes := []rune(keyNameRegex.ReplaceAllString(fields[0], "_")) if len(destRunes) > 24 { destRunes = destRunes[:24] } @@ -186,6 +230,17 @@ return fmt.Sprintf("%s%x", string(destRunes), bs) } +// NormalizeIndexAndColumn returns index name and column name for specify an index prefix length if needed +func (mysql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { + submatch := mysqlIndexRegex.FindStringSubmatch(indexName) + if len(submatch) != 3 { + return indexName, columnName + } + indexName = submatch[1] + columnName = fmt.Sprintf("%s(%s)", columnName, submatch[2]) + return indexName, columnName +} + func (mysql) DefaultValueStr() string { return "VALUES()" } diff --git a/dialect_postgres.go b/dialect_postgres.go index c44c6a5..d2df313 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -34,14 +34,14 @@ sqlType = "boolean" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "serial" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint32, reflect.Uint64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigserial" } else { sqlType = "bigint" @@ -49,7 +49,7 @@ case reflect.Float32, reflect.Float64: sqlType = "numeric" case reflect.String: - if _, ok := field.TagSettings["SIZE"]; !ok { + if _, ok := field.TagSettingsGet("SIZE"); !ok { size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different } @@ -120,6 +120,10 @@ return } +func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string { + return "" +} + func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { return fmt.Sprintf("RETURNING %v.%v", tableName, key) } diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index f26f6be..5f96c36 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -29,14 +29,14 @@ sqlType = "bool" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "integer primary key autoincrement" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "integer primary key autoincrement" } else { sqlType = "bigint" diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index e060646..a516ed4 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -1,12 +1,16 @@ package mssql import ( + "database/sql/driver" + "encoding/json" + "errors" "fmt" "reflect" "strconv" "strings" "time" + // Importing mssql driver package only in dialect file, otherwide not needed _ "github.com/denisenkom/go-mssqldb" "github.com/jinzhu/gorm" ) @@ -14,7 +18,7 @@ func setIdentityInsert(scope *gorm.Scope) { if scope.Dialect().GetName() == "mssql" { for _, field := range scope.PrimaryFields() { - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank { + if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank { scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) scope.InstanceSet("mssql:identity_insert_on", true) } @@ -66,14 +70,14 @@ sqlType = "bit" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "int IDENTITY(1,1)" } else { sqlType = "int" } case reflect.Int64, reflect.Uint64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigint IDENTITY(1,1)" } else { sqlType = "bigint" @@ -112,7 +116,7 @@ } func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { - if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { return value != "FALSE" } return field.IsPrimaryKey @@ -130,7 +134,14 @@ } func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { - return false + var count int + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow(`SELECT count(*) + FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id + inner join information_schema.tables as I on I.TABLE_NAME = T.name + WHERE F.name = ? + AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count) + return count > 0 } func (s mssql) HasTable(tableName string) bool { @@ -157,14 +168,22 @@ return } -func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { +func parseInt(value interface{}) (int64, error) { + return strconv.ParseInt(fmt.Sprint(value), 0, 0) +} + +func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { + if parsedOffset, err := parseInt(offset); err != nil { + return "", err + } else if parsedOffset >= 0 { sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset) } } if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { + if parsedLimit, err := parseInt(limit); err != nil { + return "", err + } else if parsedLimit >= 0 { if sql == "" { // add default zero offset sql += " OFFSET 0 ROWS" @@ -179,12 +198,26 @@ return "" } +func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { + if len(columns) == 0 { + // No OUTPUT to query + return "" + } + return fmt.Sprintf("OUTPUT Inserted.%v", columnName) +} + func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { - return "" + // https://stackoverflow.com/questions/5228780/how-to-get-last-inserted-id + return "; SELECT SCOPE_IDENTITY()" } func (mssql) DefaultValueStr() string { return "DEFAULT VALUES" +} + +// NormalizeIndexAndColumn returns argument's index name and column name without doing anything +func (mssql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { + return indexName, columnName } func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { @@ -194,3 +227,27 @@ } return dialect.CurrentDatabase(), tableName } + +// JSON type to support easy handling of JSON data in character table fields +// using golang json.RawMessage for deferred decoding/encoding +type JSON struct { + json.RawMessage +} + +// Value get value of JSON +func (j JSON) Value() (driver.Value, error) { + if len(j.RawMessage) == 0 { + return nil, nil + } + return j.MarshalJSON() +} + +// Scan scan value into JSON +func (j *JSON) Scan(value interface{}) error { + str, ok := value.(string) + if !ok { + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value)) + } + bytes := []byte(str) + return json.Unmarshal(bytes, j) +} diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 1d0dcb6..e6c088b 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -4,11 +4,12 @@ "database/sql" "database/sql/driver" - _ "github.com/lib/pq" - "github.com/lib/pq/hstore" "encoding/json" "errors" "fmt" + + _ "github.com/lib/pq" + "github.com/lib/pq/hstore" ) type Hstore map[string]*string diff --git a/errors.go b/errors.go index da2cf13..d5ef8d5 100644 --- a/errors.go +++ b/errors.go @@ -6,11 +6,11 @@ ) var ( - // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct + // ErrRecordNotFound returns a "record not found error". Occurs only when attempting to query the database with a struct; querying with a slice won't return this error ErrRecordNotFound = errors.New("record not found") - // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL + // ErrInvalidSQL occurs when you attempt a query with invalid SQL ErrInvalidSQL = errors.New("invalid SQL") - // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` + // ErrInvalidTransaction occurs when you are trying to `Commit` or `Rollback` ErrInvalidTransaction = errors.New("no valid transaction") // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` ErrCantStartTransaction = errors.New("can't start transaction") @@ -21,7 +21,7 @@ // Errors contains all happened errors type Errors []error -// IsRecordNotFoundError returns current error has record not found error or not +// IsRecordNotFoundError returns true if error contains a RecordNotFound error func IsRecordNotFoundError(err error) bool { if errs, ok := err.(Errors); ok { for _, err := range errs { @@ -33,12 +33,12 @@ return err == ErrRecordNotFound } -// GetErrors gets all happened errors +// GetErrors gets all errors that have occurred and returns a slice of errors (Error type) func (errs Errors) GetErrors() []error { return errs } -// Add adds an error +// Add adds an error to a given slice of errors func (errs Errors) Add(newErrors ...error) Errors { for _, err := range newErrors { if err == nil { @@ -62,7 +62,7 @@ return errs } -// Error format happened errors +// Error takes a slice of all errors that have occurred and returns it as a formatted string func (errs Errors) Error() string { var errors = []string{} for _, e := range errs { diff --git a/field.go b/field.go index 11c410b..acd06e2 100644 --- a/field.go +++ b/field.go @@ -2,6 +2,7 @@ import ( "database/sql" + "database/sql/driver" "errors" "fmt" "reflect" @@ -44,7 +45,14 @@ if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { fieldValue.Set(reflectValue.Convert(fieldValue.Type())) } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - err = scanner.Scan(reflectValue.Interface()) + v := reflectValue.Interface() + if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = scanner.Scan(v) + } + } else { + err = scanner.Scan(v) + } } else { err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) } diff --git a/field_test.go b/field_test.go index 30e9a77..715661f 100644 --- a/field_test.go +++ b/field_test.go @@ -1,6 +1,9 @@ package gorm_test import ( + "database/sql/driver" + "encoding/hex" + "fmt" "testing" "github.com/jinzhu/gorm" @@ -43,7 +46,85 @@ if field, ok := scope.FieldByName("embedded_name"); !ok { t.Errorf("should find embedded field") - } else if _, ok := field.TagSettings["NOT NULL"]; !ok { + } else if _, ok := field.TagSettingsGet("NOT NULL"); !ok { t.Errorf("should find embedded field's tag settings") } } + +type UUID [16]byte + +type NullUUID struct { + UUID + Valid bool +} + +func FromString(input string) (u UUID) { + src := []byte(input) + return FromBytes(src) +} + +func FromBytes(src []byte) (u UUID) { + dst := u[:] + hex.Decode(dst[0:4], src[0:8]) + hex.Decode(dst[4:6], src[9:13]) + hex.Decode(dst[6:8], src[14:18]) + hex.Decode(dst[8:10], src[19:23]) + hex.Decode(dst[10:], src[24:]) + return +} + +func (u UUID) String() string { + buf := make([]byte, 36) + src := u[:] + hex.Encode(buf[0:8], src[0:4]) + buf[8] = '-' + hex.Encode(buf[9:13], src[4:6]) + buf[13] = '-' + hex.Encode(buf[14:18], src[6:8]) + buf[18] = '-' + hex.Encode(buf[19:23], src[8:10]) + buf[23] = '-' + hex.Encode(buf[24:], src[10:]) + return string(buf) +} + +func (u UUID) Value() (driver.Value, error) { + return u.String(), nil +} + +func (u *UUID) Scan(src interface{}) error { + switch src := src.(type) { + case UUID: // support gorm convert from UUID to NullUUID + *u = src + return nil + case []byte: + *u = FromBytes(src) + return nil + case string: + *u = FromString(src) + return nil + } + return fmt.Errorf("uuid: cannot convert %T to UUID", src) +} + +func (u *NullUUID) Scan(src interface{}) error { + u.Valid = true + return u.UUID.Scan(src) +} + +func TestFieldSet(t *testing.T) { + type TestFieldSetNullUUID struct { + NullUUID NullUUID + } + scope := DB.NewScope(&TestFieldSetNullUUID{}) + field := scope.Fields()[0] + err := field.Set(FromString("3034d44a-da03-11e8-b366-4a00070b9f00")) + if err != nil { + t.Fatal(err) + } + if id, ok := field.Field.Addr().Interface().(*NullUUID); !ok { + t.Fatal() + } else if !id.Valid || id.UUID.String() != "3034d44a-da03-11e8-b366-4a00070b9f00" { + t.Fatal(id) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..64a42a4 --- /dev/null +++ b/go.mod @@ -0,0 +1,14 @@ +module github.com/jinzhu/gorm + +go 1.12 + +require ( + github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd + github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 + github.com/go-sql-driver/mysql v1.5.0 + github.com/jinzhu/inflection v1.0.0 + github.com/jinzhu/now v1.0.1 + github.com/lib/pq v1.1.1 + github.com/mattn/go-sqlite3 v1.14.0 + golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..d30630b --- /dev/null +++ b/go.sum @@ -0,0 +1,33 @@ +github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= +github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y= +github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd h1:83Wprp6ROGeiHFAP8WJdI2RoxALQYgdllERc3N5N2DM= +github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= +github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y= +github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= +github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M= +github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4= +github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mattn/go-sqlite3 v1.14.0 h1:mLyGNKR8+Vv9CAU7PphKa2hkEqxxhn8i32J6FPj1/QA= +github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus= +github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw= +github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= +golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd h1:GGJVjV8waZKRHrgwvtH66z9ZGVurTD1MT0n1Bb+q4aM= +golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/interface.go b/interface.go index 55128f7..fe64923 100644 --- a/interface.go +++ b/interface.go @@ -1,6 +1,9 @@ package gorm -import "database/sql" +import ( + "context" + "database/sql" +) // SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. type SQLCommon interface { @@ -12,6 +15,7 @@ type sqlDb interface { Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } type sqlTx interface { diff --git a/logger.go b/logger.go index 4324a2e..88e167d 100644 --- a/logger.go +++ b/logger.go @@ -39,6 +39,15 @@ messages = []interface{}{source, currentTime} + if len(values) == 2 { + //remove the line break + currentTime = currentTime[1:] + //remove the brackets + source = fmt.Sprintf("\033[35m%v\033[0m", values[1]) + + messages = []interface{}{currentTime, source} + } + if level == "sql" { // duration messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) @@ -49,7 +58,11 @@ if indirectValue.IsValid() { value = indirectValue.Interface() if t, ok := value.(time.Time); ok { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05"))) + if t.IsZero() { + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00")) + } else { + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05"))) + } } else if b, ok := value.([]byte); ok { if str := string(b); isPrintable(str) { formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) @@ -63,7 +76,12 @@ formattedValues = append(formattedValues, "NULL") } } else { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) + switch value.(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: + formattedValues = append(formattedValues, fmt.Sprintf("%v", value)) + default: + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) + } } } else { formattedValues = append(formattedValues, "NULL") @@ -117,3 +135,7 @@ func (logger Logger) Print(values ...interface{}) { logger.Println(LogFormatter(values...)...) } + +type nopLogger struct{} + +func (nopLogger) Print(values ...interface{}) {} diff --git a/main.go b/main.go index c26e05c..0c247b6 100644 --- a/main.go +++ b/main.go @@ -1,16 +1,19 @@ package gorm import ( + "context" "database/sql" "errors" "fmt" "reflect" "strings" + "sync" "time" ) // DB contains information for current db connection type DB struct { + sync.RWMutex Value interface{} Error error RowsAffected int64 @@ -18,17 +21,28 @@ // single db db SQLCommon blockGlobalUpdate bool - logMode int + logMode logModeValue logger logger search *search - values map[string]interface{} + values sync.Map // global db parent *DB callbacks *Callback dialect Dialect singularTable bool -} + + // function to be used to override the creating of a new timestamp + nowFuncOverride func() time.Time +} + +type logModeValue int + +const ( + defaultLogMode logModeValue = iota + noLogMode + detailedLogMode +) // Open initialize a new db connection, need to import driver first, e.g: // @@ -48,6 +62,7 @@ } var source string var dbSQL SQLCommon + var ownDbSQL bool switch value := args[0].(type) { case string: @@ -59,15 +74,21 @@ source = args[1].(string) } dbSQL, err = sql.Open(driver, source) + ownDbSQL = true case SQLCommon: dbSQL = value + ownDbSQL = false + default: + return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) } db = &DB{ db: dbSQL, logger: defaultLogger, - values: map[string]interface{}{}, - callbacks: DefaultCallback, + + // Create a clone of the default logger to avoid mutating a shared object when + // multiple gorm connections are created simultaneously. + callbacks: DefaultCallback.clone(defaultLogger), dialect: newDialect(dialect, dbSQL), } db.parent = db @@ -76,7 +97,7 @@ } // Send a ping to make sure the database connection is alive. if d, ok := dbSQL.(*sql.DB); ok { - if err = d.Ping(); err != nil { + if err = d.Ping(); err != nil && ownDbSQL { d.Close() } } @@ -106,7 +127,10 @@ // DB get `*sql.DB` from current connection // If the underlying database connection is not a *sql.DB, returns nil func (s *DB) DB() *sql.DB { - db, _ := s.db.(*sql.DB) + db, ok := s.db.(*sql.DB) + if !ok { + panic("can't support full GORM on currently status, maybe this is a TX instance.") + } return db } @@ -117,14 +141,14 @@ // Dialect get dialect func (s *DB) Dialect() Dialect { - return s.parent.dialect + return s.dialect } // Callback return `Callbacks` container, you could add/change/delete callbacks with it // db.Callback().Create().Register("update_created_at", updateCreated) // Refer https://jinzhu.github.io/gorm/development.html#callbacks func (s *DB) Callback() *Callback { - s.parent.callbacks = s.parent.callbacks.clone() + s.parent.callbacks = s.parent.callbacks.clone(s.logger) return s.parent.callbacks } @@ -136,11 +160,27 @@ // LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs func (s *DB) LogMode(enable bool) *DB { if enable { - s.logMode = 2 - } else { - s.logMode = 1 + s.logMode = detailedLogMode + } else { + s.logMode = noLogMode } return s +} + +// SetNowFuncOverride set the function to be used when creating a new timestamp +func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB { + s.nowFuncOverride = nowFuncOverride + return s +} + +// Get a new timestamp, using the provided nowFuncOverride on the DB instance if set, +// otherwise defaults to the global NowFunc() +func (s *DB) nowFunc() time.Time { + if s.nowFuncOverride != nil { + return s.nowFuncOverride() + } + + return NowFunc() } // BlockGlobalUpdate if true, generates an error on update/delete without where clause. @@ -157,7 +197,8 @@ // SingularTable use singular table by default func (s *DB) SingularTable(enable bool) { - modelStructsMap = newModelStructsMap() + s.parent.Lock() + defer s.parent.Unlock() s.parent.singularTable = enable } @@ -165,11 +206,17 @@ func (s *DB) NewScope(value interface{}) *Scope { dbClone := s.clone() dbClone.Value = value - return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} -} - -// QueryExpr returns the query as expr object -func (s *DB) QueryExpr() *expr { + scope := &Scope{db: dbClone, Value: value} + if s.search != nil { + scope.Search = s.search.clone() + } else { + scope.Search = &search{} + } + return scope +} + +// QueryExpr returns the query as SqlExpr object +func (s *DB) QueryExpr() *SqlExpr { scope := s.NewScope(s.Value) scope.InstanceSet("skip_bindvar", true) scope.prepareQuerySQL() @@ -178,7 +225,7 @@ } // SubQuery returns the query as sub query -func (s *DB) SubQuery() *expr { +func (s *DB) SubQuery() *SqlExpr { scope := s.NewScope(s.Value) scope.InstanceSet("skip_bindvar", true) scope.prepareQuerySQL() @@ -285,6 +332,7 @@ func (s *DB) First(out interface{}, where ...interface{}) *DB { newScope := s.NewScope(out) newScope.Search.Limit(1) + return newScope.Set("gorm:order_by_primary_key", "ASC"). inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } @@ -307,6 +355,11 @@ // Find find records that match given conditions func (s *DB) Find(out interface{}, where ...interface{}) *DB { return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db +} + +//Preloads preloads relations, don`t touch out +func (s *DB) Preloads(out interface{}) *DB { + return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db } // Scan scan value to a struct @@ -387,6 +440,7 @@ } // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update +// WARNING when update with struct, GORM will not update fields that with zero value func (s *DB) Update(attrs ...interface{}) *DB { return s.Updates(toSearchableMap(attrs...), true) } @@ -419,7 +473,7 @@ if !scope.PrimaryKeyZero() { newDB := scope.callCallbacks(s.parent.callbacks.updates).db if newDB.Error == nil && newDB.RowsAffected == 0 { - return s.New().FirstOrCreate(value) + return s.New().Table(scope.TableName()).FirstOrCreate(value) } return newDB } @@ -433,6 +487,7 @@ } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition +// WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time func (s *DB) Delete(value interface{}, where ...interface{}) *DB { return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db } @@ -476,12 +531,46 @@ return s.clone().LogMode(true) } -// Begin begin a transaction +// Transaction start a transaction as a block, +// return error will rollback, otherwise to commit. +func (s *DB) Transaction(fc func(tx *DB) error) (err error) { + + if _, ok := s.db.(*sql.Tx); ok { + return fc(s) + } + + panicked := true + tx := s.Begin() + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + tx.Rollback() + } + }() + + err = fc(tx) + + if err == nil { + err = tx.Commit().Error + } + + panicked = false + return +} + +// Begin begins a transaction func (s *DB) Begin() *DB { + return s.BeginTx(context.Background(), &sql.TxOptions{}) +} + +// BeginTx begins a transaction with options +func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB { c := s.clone() if db, ok := c.db.(sqlDb); ok && db != nil { - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, opts) c.db = interface{}(tx).(SQLCommon) + + c.dialect.SetDB(c.db) c.AddError(err) } else { c.AddError(ErrCantStartTransaction) @@ -491,7 +580,8 @@ // Commit commit a transaction func (s *DB) Commit() *DB { - if db, ok := s.db.(sqlTx); ok && db != nil { + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { s.AddError(db.Commit()) } else { s.AddError(ErrInvalidTransaction) @@ -501,8 +591,28 @@ // Rollback rollback a transaction func (s *DB) Rollback() *DB { - if db, ok := s.db.(sqlTx); ok && db != nil { - s.AddError(db.Rollback()) + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { + if err := db.Rollback(); err != nil && err != sql.ErrTxDone { + s.AddError(err) + } + } else { + s.AddError(ErrInvalidTransaction) + } + return s +} + +// RollbackUnlessCommitted rollback a transaction if it has not yet been +// committed. +func (s *DB) RollbackUnlessCommitted() *DB { + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { + err := db.Rollback() + // Ignore the error indicating that the transaction has already + // been committed. + if err != sql.ErrTxDone { + s.AddError(err) + } } else { s.AddError(ErrInvalidTransaction) } @@ -670,13 +780,13 @@ // InstantSet instant set setting, will affect current db func (s *DB) InstantSet(name string, value interface{}) *DB { - s.values[name] = value + s.values.Store(name, value) return s } // Get get setting by name func (s *DB) Get(name string) (value interface{}, ok bool) { - value, ok = s.values[name] + value, ok = s.values.Load(name) return } @@ -685,7 +795,7 @@ scope := s.NewScope(source) for _, field := range scope.GetModelStruct().StructFields { if field.Name == column || field.DBName == column { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { source := (&Scope{Value: source}).GetModelStruct().ModelType destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType handler.Setup(field.Relationship, many2many, source, destination) @@ -702,8 +812,8 @@ func (s *DB) AddError(err error) error { if err != nil { if err != ErrRecordNotFound { - if s.logMode == 0 { - go s.print(fileWithLineNum(), err) + if s.logMode == defaultLogMode { + go s.print("error", fileWithLineNum(), err) } else { s.log(err) } @@ -740,15 +850,17 @@ parent: s.parent, logger: s.logger, logMode: s.logMode, - values: map[string]interface{}{}, Value: s.Value, Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, - } - - for key, value := range s.values { - db.values[key] = value - } + dialect: newDialect(s.dialect.GetName(), s.db), + nowFuncOverride: s.nowFuncOverride, + } + + s.values.Range(func(k, v interface{}) bool { + db.values.Store(k, v) + return true + }) if s.search == nil { db.search = &search{limit: -1, offset: -1} @@ -765,13 +877,13 @@ } func (s *DB) log(v ...interface{}) { - if s != nil && s.logMode == 2 { + if s != nil && s.logMode == detailedLogMode { s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) } } func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { - if s.logMode == 2 { + if s.logMode == detailedLogMode { s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) } } diff --git a/main_test.go b/main_test.go index 66c46af..7901e91 100644 --- a/main_test.go +++ b/main_test.go @@ -1,13 +1,23 @@ package gorm_test +// Run tests +// $ docker-compose up +// $ ./test_all.sh + import ( + "context" "database/sql" "database/sql/driver" + "errors" "fmt" "os" "path/filepath" "reflect" + "regexp" + "sort" "strconv" + "strings" + "sync" "testing" "time" @@ -47,7 +57,7 @@ case "postgres": fmt.Println("testing postgres...") if dbDSN == "" { - dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" + dbDSN = "user=gorm password=gorm dbname=gorm port=9920 sslmode=disable" } db, err = gorm.Open("postgres", dbDSN) case "mssql": @@ -79,6 +89,22 @@ return } +func TestOpen_ReturnsError_WithBadArgs(t *testing.T) { + stringRef := "foo" + testCases := []interface{}{42, time.Now(), &stringRef} + for _, tc := range testCases { + t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) { + _, err := gorm.Open("postgresql", tc) + if err == nil { + t.Error("Should got error with invalid database source") + } + if !strings.HasPrefix(err.Error(), "invalid database source:") { + t.Errorf("Should got error starting with \"invalid database source:\", but got %q", err.Error()) + } + }) + } +} + func TestStringPrimaryKey(t *testing.T) { type UUIDStruct struct { ID string `gorm:"primary_key"` @@ -157,6 +183,15 @@ DB.Table("deleted_users").Find(&deletedUsers) if len(deletedUsers) != 1 { t.Errorf("Query from specified table") + } + + var user User + DB.Table("deleted_users").First(&user, "name = ?", "DeletedUser") + + user.Age = 20 + DB.Table("deleted_users").Save(&user) + if DB.Table("deleted_users").First(&user, "name = ? AND age = ?", "DeletedUser", 20).RecordNotFound() { + t.Errorf("Failed to found updated user") } DB.Save(getPreparedUser("normal_user", "reset_table")) @@ -257,6 +292,30 @@ if DB.NewScope([]Cart{}).TableName() != "shopping_cart" { t.Errorf("[]Cart's singular table name should be shopping_cart") } + DB.SingularTable(false) +} + +func TestTableNameConcurrently(t *testing.T) { + DB := DB.Model("") + if DB.NewScope(Order{}).TableName() != "orders" { + t.Errorf("Order's table name should be orders") + } + + var wg sync.WaitGroup + wg.Add(10) + + for i := 1; i <= 10; i++ { + go func(db *gorm.DB) { + DB.SingularTable(true) + wg.Done() + }(DB) + } + wg.Wait() + + if DB.NewScope(Order{}).TableName() != "order" { + t.Errorf("Order's singular table name should be order") + } + DB.SingularTable(false) } @@ -377,6 +436,160 @@ if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { t.Errorf("Should be able to find committed record") } + + tx3 := DB.Begin() + u3 := User{Name: "transcation-3"} + if err := tx3.Save(&u3).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx3.First(&User{}, "name = ?", "transcation-3").Error; err != nil { + t.Errorf("Should find saved record") + } + + tx3.RollbackUnlessCommitted() + + if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { + t.Errorf("Should not find record after rollback") + } + + tx4 := DB.Begin() + u4 := User{Name: "transcation-4"} + if err := tx4.Save(&u4).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx4.First(&User{}, "name = ?", "transcation-4").Error; err != nil { + t.Errorf("Should find saved record") + } + + tx4.Commit() + + tx4.RollbackUnlessCommitted() + + if err := DB.First(&User{}, "name = ?", "transcation-4").Error; err != nil { + t.Errorf("Should be able to find committed record") + } +} + +func assertPanic(t *testing.T, f func()) { + defer func() { + if r := recover(); r == nil { + t.Errorf("The code did not panic") + } + }() + f() +} + +func TestTransactionWithBlock(t *testing.T) { + // rollback + err := DB.Transaction(func(tx *gorm.DB) error { + u := User{Name: "transcation"} + if err := tx.Save(&u).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { + t.Errorf("Should find saved record") + } + + return errors.New("the error message") + }) + + if err.Error() != "the error message" { + t.Errorf("Transaction return error will equal the block returns error") + } + + if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { + t.Errorf("Should not find record after rollback") + } + + // commit + DB.Transaction(func(tx *gorm.DB) error { + u2 := User{Name: "transcation-2"} + if err := tx.Save(&u2).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + t.Errorf("Should find saved record") + } + return nil + }) + + if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + t.Errorf("Should be able to find committed record") + } + + // panic will rollback + assertPanic(t, func() { + DB.Transaction(func(tx *gorm.DB) error { + u3 := User{Name: "transcation-3"} + if err := tx.Save(&u3).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", "transcation-3").Error; err != nil { + t.Errorf("Should find saved record") + } + + panic("force panic") + }) + }) + + if err := DB.First(&User{}, "name = ?", "transcation-3").Error; err == nil { + t.Errorf("Should not find record after panic rollback") + } +} + +func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) { + tx := DB.Begin() + u := User{Name: "transcation"} + if err := tx.Save(&u).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.Commit().Error; err != nil { + t.Errorf("Commit should not raise error") + } + + if err := tx.Rollback().Error; err != nil { + t.Errorf("Rollback should not raise error") + } +} + +func TestTransactionReadonly(t *testing.T) { + dialect := os.Getenv("GORM_DIALECT") + if dialect == "" { + dialect = "sqlite" + } + switch dialect { + case "mssql", "sqlite": + t.Skipf("%s does not support readonly transactions\n", dialect) + } + + tx := DB.Begin() + u := User{Name: "transcation"} + if err := tx.Save(&u).Error; err != nil { + t.Errorf("No error should raise") + } + tx.Commit() + + tx = DB.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) + if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { + t.Errorf("Should find saved record") + } + + if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil { + t.Errorf("Should return the underlying sql.Tx") + } + + u = User{Name: "transcation-2"} + if err := tx.Save(&u).Error; err == nil { + t.Errorf("Error should have been raised in a readonly transaction") + } + + tx.Rollback() } func TestRow(t *testing.T) { @@ -564,6 +777,60 @@ } } +type JoinedIds struct { + UserID int64 `gorm:"column:id"` + BillingAddressID int64 `gorm:"column:id"` + EmailID int64 `gorm:"column:id"` +} + +func TestScanIdenticalColumnNames(t *testing.T) { + var user = User{ + Name: "joinsIds", + Email: "joinIds@example.com", + BillingAddress: Address{ + Address1: "One Park Place", + }, + Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, + } + DB.Save(&user) + + var users []JoinedIds + DB.Select("users.id, addresses.id, emails.id").Table("users"). + Joins("left join addresses on users.billing_address_id = addresses.id"). + Joins("left join emails on emails.user_id = users.id"). + Where("name = ?", "joinsIds").Scan(&users) + + if len(users) != 2 { + t.Fatal("should find two rows using left join") + } + + if user.Id != users[0].UserID { + t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[0].UserID) + } + if user.Id != users[1].UserID { + t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[1].UserID) + } + + if user.BillingAddressID.Int64 != users[0].BillingAddressID { + t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID) + } + if user.BillingAddressID.Int64 != users[1].BillingAddressID { + t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID) + } + + if users[0].EmailID == users[1].EmailID { + t.Errorf("Email ids should be unique. Got %d and %d", users[0].EmailID, users[1].EmailID) + } + + if int64(user.Emails[0].Id) != users[0].EmailID && int64(user.Emails[1].Id) != users[0].EmailID { + t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[0].EmailID) + } + + if int64(user.Emails[0].Id) != users[1].EmailID && int64(user.Emails[1].Id) != users[1].EmailID { + t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[1].EmailID) + } +} + func TestJoinsWithSelect(t *testing.T) { type result struct { Name string @@ -578,6 +845,11 @@ var results []result DB.Table("users").Select("name, emails.email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results) + + sort.Slice(results, func(i, j int) bool { + return strings.Compare(results[i].Email, results[j].Email) < 0 + }) + if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" { t.Errorf("Should find all two emails with Join select") } @@ -862,6 +1134,94 @@ } } +func TestSaveAssociations(t *testing.T) { + db := DB.New() + deltaAddressCount := 0 + if err := db.Model(&Address{}).Count(&deltaAddressCount).Error; err != nil { + t.Errorf("failed to fetch address count") + t.FailNow() + } + + placeAddress := &Address{ + Address1: "somewhere on earth", + } + ownerAddress1 := &Address{ + Address1: "near place address", + } + ownerAddress2 := &Address{ + Address1: "address2", + } + db.Create(placeAddress) + + addressCountShouldBe := func(t *testing.T, expectedCount int) { + countFromDB := 0 + t.Helper() + err := db.Model(&Address{}).Count(&countFromDB).Error + if err != nil { + t.Error("failed to fetch address count") + } + if countFromDB != expectedCount { + t.Errorf("address count mismatch: %d", countFromDB) + } + } + addressCountShouldBe(t, deltaAddressCount+1) + + // owner address should be created, place address should be reused + place1 := &Place{ + PlaceAddressID: placeAddress.ID, + PlaceAddress: placeAddress, + OwnerAddress: ownerAddress1, + } + err := db.Create(place1).Error + if err != nil { + t.Errorf("failed to store place: %s", err.Error()) + } + addressCountShouldBe(t, deltaAddressCount+2) + + // owner address should be created again, place address should be reused + place2 := &Place{ + PlaceAddressID: placeAddress.ID, + PlaceAddress: &Address{ + ID: 777, + Address1: "address1", + }, + OwnerAddress: ownerAddress2, + OwnerAddressID: 778, + } + err = db.Create(place2).Error + if err != nil { + t.Errorf("failed to store place: %s", err.Error()) + } + addressCountShouldBe(t, deltaAddressCount+3) + + count := 0 + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + OwnerAddressID: ownerAddress1.ID, + }).Count(&count) + if count != 1 { + t.Errorf("only one instance of (%d, %d) should be available, found: %d", + placeAddress.ID, ownerAddress1.ID, count) + } + + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + OwnerAddressID: ownerAddress2.ID, + }).Count(&count) + if count != 1 { + t.Errorf("only one instance of (%d, %d) should be available, found: %d", + placeAddress.ID, ownerAddress2.ID, count) + } + + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + }).Count(&count) + if count != 2 { + t.Errorf("two instances of (%d) should be available, found: %d", + placeAddress.ID, count) + } +} + func TestBlockGlobalUpdate(t *testing.T) { db := DB.New() db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) @@ -898,6 +1258,176 @@ if err != nil { t.Error("Unexpected error on conditional delete") } +} + +func TestCountWithHaving(t *testing.T) { + db := DB.New() + db.Delete(User{}) + defer db.Delete(User{}) + + DB.Create(getPreparedUser("user1", "pluck_user")) + DB.Create(getPreparedUser("user2", "pluck_user")) + user3 := getPreparedUser("user3", "pluck_user") + user3.Languages = []Language{} + DB.Create(user3) + + var count int + err := db.Model(User{}).Select("users.id"). + Joins("LEFT JOIN user_languages ON user_languages.user_id = users.id"). + Joins("LEFT JOIN languages ON user_languages.language_id = languages.id"). + Group("users.id").Having("COUNT(languages.id) > 1").Count(&count).Error + + if err != nil { + t.Error("Unexpected error on query count with having") + } + + if count != 2 { + t.Error("Unexpected result on query count with having") + } +} + +func TestPluck(t *testing.T) { + db := DB.New() + db.Delete(User{}) + defer db.Delete(User{}) + + DB.Create(&User{Id: 1, Name: "user1"}) + DB.Create(&User{Id: 2, Name: "user2"}) + DB.Create(&User{Id: 3, Name: "user3"}) + + var ids []int64 + err := db.Model(User{}).Order("id").Pluck("id", &ids).Error + + if err != nil { + t.Error("Unexpected error on pluck") + } + + if len(ids) != 3 || ids[0] != 1 || ids[1] != 2 || ids[2] != 3 { + t.Error("Unexpected result on pluck") + } + + err = db.Model(User{}).Order("id").Pluck("id", &ids).Error + + if err != nil { + t.Error("Unexpected error on pluck again") + } + + if len(ids) != 3 || ids[0] != 1 || ids[1] != 2 || ids[2] != 3 { + t.Error("Unexpected result on pluck again") + } +} + +func TestCountWithQueryOption(t *testing.T) { + db := DB.New() + db.Delete(User{}) + defer db.Delete(User{}) + + DB.Create(&User{Name: "user1"}) + DB.Create(&User{Name: "user2"}) + DB.Create(&User{Name: "user3"}) + + var count int + err := db.Model(User{}).Select("users.id"). + Set("gorm:query_option", "WHERE users.name='user2'"). + Count(&count).Error + + if err != nil { + t.Error("Unexpected error on query count with query_option") + } + + if count != 1 { + t.Error("Unexpected result on query count with query_option") + } +} + +func TestSubQueryWithQueryOption(t *testing.T) { + db := DB.New() + + subQuery := db.Model(User{}).Select("users.id"). + Set("gorm:query_option", "WHERE users.name='user2'"). + SubQuery() + + matched, _ := regexp.MatchString( + `^&{.+\s+WHERE users\.name='user2'.*\s\[]}$`, fmt.Sprint(subQuery)) + if !matched { + t.Error("Unexpected result of SubQuery with query_option") + } +} + +func TestQueryExprWithQueryOption(t *testing.T) { + db := DB.New() + + queryExpr := db.Model(User{}).Select("users.id"). + Set("gorm:query_option", "WHERE users.name='user2'"). + QueryExpr() + + matched, _ := regexp.MatchString( + `^&{.+\s+WHERE users\.name='user2'.*\s\[]}$`, fmt.Sprint(queryExpr)) + if !matched { + t.Error("Unexpected result of QueryExpr with query_option") + } +} + +func TestQueryHint1(t *testing.T) { + db := DB.New() + + _, err := db.Model(User{}).Raw("select 1").Rows() + + if err != nil { + t.Error("Unexpected error on query count with query_option") + } +} + +func TestQueryHint2(t *testing.T) { + type TestStruct struct { + ID string `gorm:"primary_key"` + Name string + } + DB.DropTable(&TestStruct{}) + DB.AutoMigrate(&TestStruct{}) + + data := TestStruct{ID: "uuid", Name: "hello"} + if err := DB.Set("gorm:query_hint", "/*master*/").Save(&data).Error; err != nil { + t.Error("Unexpected error on query count with query_option") + } +} + +func TestFloatColumnPrecision(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" && dialect != "sqlite" { + t.Skip() + } + + type FloatTest struct { + ID string `gorm:"primary_key"` + FloatValue float64 `gorm:"column:float_value" sql:"type:float(255,5);"` + } + DB.DropTable(&FloatTest{}) + DB.AutoMigrate(&FloatTest{}) + + data := FloatTest{ID: "uuid", FloatValue: 112.57315} + if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.FloatValue != 112.57315 { + t.Errorf("Float value should not lose precision") + } +} + +func TestWhereUpdates(t *testing.T) { + type OwnerEntity struct { + gorm.Model + OwnerID uint + OwnerType string + } + + type SomeEntity struct { + gorm.Model + Name string + OwnerEntity OwnerEntity `gorm:"polymorphic:Owner"` + } + + DB.DropTable(&SomeEntity{}) + DB.AutoMigrate(&SomeEntity{}) + + a := SomeEntity{Name: "test"} + DB.Model(&a).Where(a).Updates(SomeEntity{Name: "test2"}) } func BenchmarkGorm(b *testing.B) { diff --git a/migration_test.go b/migration_test.go index 7c69448..063c6f6 100644 --- a/migration_test.go +++ b/migration_test.go @@ -118,6 +118,14 @@ Owner *User `sql:"-"` } +type Place struct { + Id int64 + PlaceAddressID int + PlaceAddress *Address `gorm:"save_associations:false"` + OwnerAddressID int + OwnerAddress *Address `gorm:"save_associations:true"` +} + type EncryptedData []byte func (data *EncryptedData) Scan(value interface{}) error { @@ -275,6 +283,11 @@ } } +type Panda struct { + Number int64 `gorm:"unique_index:number"` + Name string `gorm:"column:name;type:varchar(255);default:null"` +} + func runMigration() { if err := DB.DropTableIfExists(&User{}).Error; err != nil { fmt.Printf("Got error when try to delete table users, %+v\n", err) @@ -284,7 +297,7 @@ DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}} + values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}, &Panda{}} for _, value := range values { DB.DropTable(value) } @@ -398,6 +411,53 @@ } } +func TestCreateAndAutomigrateTransaction(t *testing.T) { + tx := DB.Begin() + + func() { + type Bar struct { + ID uint + } + DB.DropTableIfExists(&Bar{}) + + if ok := DB.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + + if ok := tx.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + }() + + func() { + type Bar struct { + Name string + } + err := tx.CreateTable(&Bar{}).Error + + if err != nil { + t.Errorf("Should have been able to create the table, but couldn't: %s", err) + } + + if ok := tx.HasTable(&Bar{}); !ok { + t.Errorf("The transaction should be able to see the table") + } + }() + + func() { + type Bar struct { + Stuff string + } + + err := tx.AutoMigrate(&Bar{}).Error + if err != nil { + t.Errorf("Should have been able to alter the table, but couldn't") + } + }() + + tx.Rollback() +} + type MultipleIndexes struct { ID int64 UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` @@ -483,3 +543,42 @@ t.Errorf("No error should happen when ModifyColumn, but got %v", err) } } + +func TestIndexWithPrefixLength(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" { + t.Skip("Skipping this because only mysql support setting an index prefix length") + } + + type IndexWithPrefix struct { + gorm.Model + Name string + Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` + } + type IndexesWithPrefix struct { + gorm.Model + Name string + Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` + Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` + } + type IndexesWithPrefixAndWithoutPrefix struct { + gorm.Model + Name string `gorm:"index:idx_index_with_prefixes_length"` + Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` + } + tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}} + for _, table := range tables { + scope := DB.NewScope(table) + tableName := scope.TableName() + t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) { + if err := DB.DropTableIfExists(table).Error; err != nil { + t.Errorf("Failed to drop %s table: %v", tableName, err) + } + if err := DB.CreateTable(table).Error; err != nil { + t.Errorf("Failed to create %s table: %v", tableName, err) + } + if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") { + t.Errorf("Failed to create %s table index:", tableName) + } + }) + } +} diff --git a/model_struct.go b/model_struct.go index f571e2e..57dbec3 100644 --- a/model_struct.go +++ b/model_struct.go @@ -17,48 +17,38 @@ return defaultTableName } -type safeModelStructsMap struct { - m map[reflect.Type]*ModelStruct - l *sync.RWMutex -} - -func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) { +// lock for mutating global cached model metadata +var structsLock sync.Mutex + +// global cache of model metadata +var modelStructsMap sync.Map + +// ModelStruct model definition +type ModelStruct struct { + PrimaryFields []*StructField + StructFields []*StructField + ModelType reflect.Type + + defaultTableName string + l sync.Mutex +} + +// TableName returns model's table name +func (s *ModelStruct) TableName(db *DB) string { s.l.Lock() defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newModelStructsMap() *safeModelStructsMap { - return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)} -} - -var modelStructsMap = newModelStructsMap() - -// ModelStruct model definition -type ModelStruct struct { - PrimaryFields []*StructField - StructFields []*StructField - ModelType reflect.Type - defaultTableName string -} - -// TableName get model's table name -func (s *ModelStruct) TableName(db *DB) string { + if s.defaultTableName == "" && db != nil && s.ModelType != nil { // Set default table name if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { s.defaultTableName = tabler.TableName() } else { - tableName := ToDBName(s.ModelType.Name()) - if db == nil || !db.parent.singularTable { + tableName := ToTableName(s.ModelType.Name()) + db.parent.RLock() + if db == nil || (db.parent != nil && !db.parent.singularTable) { tableName = inflection.Plural(tableName) } + db.parent.RUnlock() s.defaultTableName = tableName } } @@ -81,30 +71,57 @@ Struct reflect.StructField IsForeignKey bool Relationship *Relationship -} - -func (structField *StructField) clone() *StructField { + + tagSettingsLock sync.RWMutex +} + +// TagSettingsSet Sets a tag in the tag settings map +func (sf *StructField) TagSettingsSet(key, val string) { + sf.tagSettingsLock.Lock() + defer sf.tagSettingsLock.Unlock() + sf.TagSettings[key] = val +} + +// TagSettingsGet returns a tag from the tag settings +func (sf *StructField) TagSettingsGet(key string) (string, bool) { + sf.tagSettingsLock.RLock() + defer sf.tagSettingsLock.RUnlock() + val, ok := sf.TagSettings[key] + return val, ok +} + +// TagSettingsDelete deletes a tag +func (sf *StructField) TagSettingsDelete(key string) { + sf.tagSettingsLock.Lock() + defer sf.tagSettingsLock.Unlock() + delete(sf.TagSettings, key) +} + +func (sf *StructField) clone() *StructField { clone := &StructField{ - DBName: structField.DBName, - Name: structField.Name, - Names: structField.Names, - IsPrimaryKey: structField.IsPrimaryKey, - IsNormal: structField.IsNormal, - IsIgnored: structField.IsIgnored, - IsScanner: structField.IsScanner, - HasDefaultValue: structField.HasDefaultValue, - Tag: structField.Tag, + DBName: sf.DBName, + Name: sf.Name, + Names: sf.Names, + IsPrimaryKey: sf.IsPrimaryKey, + IsNormal: sf.IsNormal, + IsIgnored: sf.IsIgnored, + IsScanner: sf.IsScanner, + HasDefaultValue: sf.HasDefaultValue, + Tag: sf.Tag, TagSettings: map[string]string{}, - Struct: structField.Struct, - IsForeignKey: structField.IsForeignKey, - } - - if structField.Relationship != nil { - relationship := *structField.Relationship + Struct: sf.Struct, + IsForeignKey: sf.IsForeignKey, + } + + if sf.Relationship != nil { + relationship := *sf.Relationship clone.Relationship = &relationship } - for key, value := range structField.TagSettings { + // copy the struct field tagSettings, they should be read-locked while they are copied + sf.tagSettingsLock.Lock() + defer sf.tagSettingsLock.Unlock() + for key, value := range sf.TagSettings { clone.TagSettings[key] = value } @@ -126,7 +143,7 @@ func getForeignField(column string, fields []*StructField) *StructField { for _, field := range fields { - if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) { + if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) { return field } } @@ -135,6 +152,10 @@ // GetModelStruct get value's model struct, relationships based on struct and tag definition func (scope *Scope) GetModelStruct() *ModelStruct { + return scope.getModelStruct(scope, make([]*StructField, 0)) +} + +func (scope *Scope) getModelStruct(rootScope *Scope, allFields []*StructField) *ModelStruct { var modelStruct ModelStruct // Scope value can't be nil if scope.Value == nil { @@ -152,8 +173,19 @@ } // Get Cached model struct - if value := modelStructsMap.Get(reflectType); value != nil { - return value + isSingularTable := false + if scope.db != nil && scope.db.parent != nil { + scope.db.parent.RLock() + isSingularTable = scope.db.parent.singularTable + scope.db.parent.RUnlock() + } + + hashKey := struct { + singularTable bool + reflectType reflect.Type + }{isSingularTable, reflectType} + if value, ok := modelStructsMap.Load(hashKey); ok && value != nil { + return value.(*ModelStruct) } modelStruct.ModelType = reflectType @@ -170,19 +202,19 @@ } // is ignored field - if _, ok := field.TagSettings["-"]; ok { + if _, ok := field.TagSettingsGet("-"); ok { field.IsIgnored = true } else { - if _, ok := field.TagSettings["PRIMARY_KEY"]; ok { + if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok { field.IsPrimaryKey = true modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) } - if _, ok := field.TagSettings["DEFAULT"]; ok { + if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey { field.HasDefaultValue = true } - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsPrimaryKey { + if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey { field.HasDefaultValue = true } @@ -198,8 +230,8 @@ if indirectType.Kind() == reflect.Struct { for i := 0; i < indirectType.NumField(); i++ { for key, value := range parseTagSetting(indirectType.Field(i).Tag) { - if _, ok := field.TagSettings[key]; !ok { - field.TagSettings[key] = value + if _, ok := field.TagSettingsGet(key); !ok { + field.TagSettingsSet(key, value) } } } @@ -207,17 +239,17 @@ } else if _, isTime := fieldValue.(*time.Time); isTime { // is time field.IsNormal = true - } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { + } else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous { // is embedded struct - for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields { + for _, subField := range scope.New(fieldValue).getModelStruct(rootScope, allFields).StructFields { subField = subField.clone() subField.Names = append([]string{fieldStruct.Name}, subField.Names...) - if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok { + if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok { subField.DBName = prefix + subField.DBName } if subField.IsPrimaryKey { - if _, ok := subField.TagSettings["PRIMARY_KEY"]; ok { + if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok { modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) } else { subField.IsPrimaryKey = false @@ -233,6 +265,7 @@ } modelStruct.StructFields = append(modelStruct.StructFields, subField) + allFields = append(allFields, subField) } continue } else { @@ -248,13 +281,13 @@ elemType = field.Struct.Type ) - if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { foreignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { associationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { associationForeignKeys = strings.Split(foreignKey, ",") } @@ -263,13 +296,13 @@ } if elemType.Kind() == reflect.Struct { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { relationship.Kind = "many_to_many" { // Foreign Keys for Source joinTableDBNames := []string{} - if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" { joinTableDBNames = strings.Split(foreignKey, ",") } @@ -290,7 +323,7 @@ // if defined join table's foreign key relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) } else { - defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName + defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) } } @@ -300,7 +333,7 @@ { // Foreign Keys for Association (Destination) associationJoinTableDBNames := []string{} - if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" { associationJoinTableDBNames = strings.Split(foreignKey, ",") } @@ -321,7 +354,7 @@ relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) } else { // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) } } @@ -338,7 +371,7 @@ var toFields = toScope.GetStructFields() relationship.Kind = "has_many" - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { // Dog has many toys, tag polymorphic is Owner, then associationType is Owner // Toy use OwnerID, OwnerType ('dogs') as foreign key if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { @@ -346,7 +379,7 @@ relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName // if Dog has multiple set of toys set name of the set (instead of default 'dogs') - if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { + if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { relationship.PolymorphicValue = value } else { relationship.PolymorphicValue = scope.TableName() @@ -366,7 +399,7 @@ } else { // generate foreign keys from defined association foreign keys for _, scopeFieldName := range associationForeignKeys { - if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil { + if foreignField := getForeignField(scopeFieldName, allFields); foreignField != nil { foreignKeys = append(foreignKeys, associationType+foreignField.Name) associationForeignKeys = append(associationForeignKeys, foreignField.Name) } @@ -378,13 +411,13 @@ for _, foreignKey := range foreignKeys { if strings.HasPrefix(foreignKey, associationType) { associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil { associationForeignKeys = append(associationForeignKeys, associationForeignKey) } } } if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} + associationForeignKeys = []string{rootScope.PrimaryKey()} } } else if len(foreignKeys) != len(associationForeignKeys) { scope.Err(errors.New("invalid foreign keys, should have same length")) @@ -394,9 +427,13 @@ for idx, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { - // source foreign keys + if associationField := getForeignField(associationForeignKeys[idx], allFields); associationField != nil { + // mark field as foreignkey, use global lock to avoid race + structsLock.Lock() foreignField.IsForeignKey = true + structsLock.Unlock() + + // association foreign keys relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) @@ -428,17 +465,17 @@ tagAssociationForeignKeys []string ) - if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { tagForeignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { tagAssociationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { tagAssociationForeignKeys = strings.Split(foreignKey, ",") } - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { // Cat has one toy, tag polymorphic is Owner, then associationType is Owner // Toy use OwnerID, OwnerType ('cats') as foreign key if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { @@ -446,7 +483,7 @@ relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName // if Cat has several different types of toys set name for each (instead of default 'cats') - if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { + if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { relationship.PolymorphicValue = value } else { relationship.PolymorphicValue = scope.TableName() @@ -470,7 +507,7 @@ } else { // generate foreign keys form association foreign keys for _, associationForeignKey := range tagAssociationForeignKeys { - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil { foreignKeys = append(foreignKeys, associationType+foreignField.Name) associationForeignKeys = append(associationForeignKeys, foreignField.Name) } @@ -482,13 +519,13 @@ for _, foreignKey := range foreignKeys { if strings.HasPrefix(foreignKey, associationType) { associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil { associationForeignKeys = append(associationForeignKeys, associationForeignKey) } } } if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} + associationForeignKeys = []string{rootScope.PrimaryKey()} } } else if len(foreignKeys) != len(associationForeignKeys) { scope.Err(errors.New("invalid foreign keys, should have same length")) @@ -498,9 +535,13 @@ for idx, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { + if scopeField := getForeignField(associationForeignKeys[idx], allFields); scopeField != nil { + // mark field as foreignkey, use global lock to avoid race + structsLock.Lock() foreignField.IsForeignKey = true - // source foreign keys + structsLock.Unlock() + + // association foreign keys relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) @@ -558,7 +599,10 @@ for idx, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { + // mark field as foreignkey, use global lock to avoid race + structsLock.Lock() foreignField.IsForeignKey = true + structsLock.Unlock() // association foreign keys relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) @@ -584,13 +628,14 @@ } // Even it is ignored, also possible to decode db value into the field - if value, ok := field.TagSettings["COLUMN"]; ok { + if value, ok := field.TagSettingsGet("COLUMN"); ok { field.DBName = value } else { - field.DBName = ToDBName(fieldStruct.Name) + field.DBName = ToColumnName(fieldStruct.Name) } modelStruct.StructFields = append(modelStruct.StructFields, field) + allFields = append(allFields, field) } } @@ -601,7 +646,7 @@ } } - modelStructsMap.Set(reflectType, &modelStruct) + modelStructsMap.Store(hashKey, &modelStruct) return &modelStruct } @@ -614,6 +659,9 @@ func parseTagSetting(tags reflect.StructTag) map[string]string { setting := map[string]string{} for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { + if str == "" { + continue + } tags := strings.Split(str, ";") for _, value := range tags { v := strings.Split(value, ":") diff --git a/model_struct_test.go b/model_struct_test.go new file mode 100644 index 0000000..7543ccf --- /dev/null +++ b/model_struct_test.go @@ -0,0 +1,140 @@ +package gorm_test + +import ( + "sync" + "testing" + + "github.com/jinzhu/gorm" +) + +type ModelA struct { + gorm.Model + Name string + + ModelCs []ModelC `gorm:"foreignkey:OtherAID"` +} + +type ModelB struct { + gorm.Model + Name string + + ModelCs []ModelC `gorm:"foreignkey:OtherBID"` +} + +type ModelC struct { + gorm.Model + Name string + + OtherAID uint64 + OtherA *ModelA `gorm:"foreignkey:OtherAID"` + OtherBID uint64 + OtherB *ModelB `gorm:"foreignkey:OtherBID"` +} + +type RequestModel struct { + Name string + Children []ChildModel `gorm:"foreignkey:ParentID"` +} + +type ChildModel struct { + ID string + ParentID string + Name string +} + +type ResponseModel struct { + gorm.Model + RequestModel +} + +// This test will try to cause a race condition on the model's foreignkey metadata +func TestModelStructRaceSameModel(t *testing.T) { + // use a WaitGroup to execute as much in-sync as possible + // it's more likely to hit a race condition than without + n := 32 + start := sync.WaitGroup{} + start.Add(n) + + // use another WaitGroup to know when the test is done + done := sync.WaitGroup{} + done.Add(n) + + for i := 0; i < n; i++ { + go func() { + start.Wait() + + // call GetStructFields, this had a race condition before we fixed it + DB.NewScope(&ModelA{}).GetStructFields() + + done.Done() + }() + + start.Done() + } + + done.Wait() +} + +// This test will try to cause a race condition on the model's foreignkey metadata +func TestModelStructRaceDifferentModel(t *testing.T) { + // use a WaitGroup to execute as much in-sync as possible + // it's more likely to hit a race condition than without + n := 32 + start := sync.WaitGroup{} + start.Add(n) + + // use another WaitGroup to know when the test is done + done := sync.WaitGroup{} + done.Add(n) + + for i := 0; i < n; i++ { + i := i + go func() { + start.Wait() + + // call GetStructFields, this had a race condition before we fixed it + if i%2 == 0 { + DB.NewScope(&ModelA{}).GetStructFields() + } else { + DB.NewScope(&ModelB{}).GetStructFields() + } + + done.Done() + }() + + start.Done() + } + + done.Wait() +} + +func TestModelStructEmbeddedHasMany(t *testing.T) { + fields := DB.NewScope(&ResponseModel{}).GetStructFields() + + var childrenField *gorm.StructField + + for i := 0; i < len(fields); i++ { + field := fields[i] + + if field != nil && field.Name == "Children" { + childrenField = field + } + } + + if childrenField == nil { + t.Error("childrenField should not be nil") + return + } + + if childrenField.Relationship == nil { + t.Error("childrenField.Relation should not be nil") + return + } + + expected := "has_many" + actual := childrenField.Relationship.Kind + + if actual != expected { + t.Errorf("childrenField.Relationship.Kind should be %v, but was %v", expected, actual) + } +} diff --git a/naming.go b/naming.go new file mode 100644 index 0000000..6b0a4fd --- /dev/null +++ b/naming.go @@ -0,0 +1,124 @@ +package gorm + +import ( + "bytes" + "strings" +) + +// Namer is a function type which is given a string and return a string +type Namer func(string) string + +// NamingStrategy represents naming strategies +type NamingStrategy struct { + DB Namer + Table Namer + Column Namer +} + +// TheNamingStrategy is being initialized with defaultNamingStrategy +var TheNamingStrategy = &NamingStrategy{ + DB: defaultNamer, + Table: defaultNamer, + Column: defaultNamer, +} + +// AddNamingStrategy sets the naming strategy +func AddNamingStrategy(ns *NamingStrategy) { + if ns.DB == nil { + ns.DB = defaultNamer + } + if ns.Table == nil { + ns.Table = defaultNamer + } + if ns.Column == nil { + ns.Column = defaultNamer + } + TheNamingStrategy = ns +} + +// DBName alters the given name by DB +func (ns *NamingStrategy) DBName(name string) string { + return ns.DB(name) +} + +// TableName alters the given name by Table +func (ns *NamingStrategy) TableName(name string) string { + return ns.Table(name) +} + +// ColumnName alters the given name by Column +func (ns *NamingStrategy) ColumnName(name string) string { + return ns.Column(name) +} + +// ToDBName convert string to db name +func ToDBName(name string) string { + return TheNamingStrategy.DBName(name) +} + +// ToTableName convert string to table name +func ToTableName(name string) string { + return TheNamingStrategy.TableName(name) +} + +// ToColumnName convert string to db name +func ToColumnName(name string) string { + return TheNamingStrategy.ColumnName(name) +} + +var smap = newSafeMap() + +func defaultNamer(name string) string { + const ( + lower = false + upper = true + ) + + if v := smap.Get(name); v != "" { + return v + } + + if name == "" { + return "" + } + + var ( + value = commonInitialismsReplacer.Replace(name) + buf = bytes.NewBufferString("") + lastCase, currCase, nextCase, nextNumber bool + ) + + for i, v := range value[:len(value)-1] { + nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z') + nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9') + + if i > 0 { + if currCase == upper { + if lastCase == upper && (nextCase == upper || nextNumber == upper) { + buf.WriteRune(v) + } else { + if value[i-1] != '_' && value[i+1] != '_' { + buf.WriteRune('_') + } + buf.WriteRune(v) + } + } else { + buf.WriteRune(v) + if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { + buf.WriteRune('_') + } + } + } else { + currCase = upper + buf.WriteRune(v) + } + lastCase = currCase + currCase = nextCase + } + + buf.WriteByte(value[len(value)-1]) + + s := strings.ToLower(buf.String()) + smap.Set(name, s) + return s +} diff --git a/naming_test.go b/naming_test.go new file mode 100644 index 0000000..0c6f771 --- /dev/null +++ b/naming_test.go @@ -0,0 +1,69 @@ +package gorm_test + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestTheNamingStrategy(t *testing.T) { + + cases := []struct { + name string + namer gorm.Namer + expected string + }{ + {name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB}, + {name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table}, + {name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := c.namer(c.name) + if result != c.expected { + t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) + } + }) + } + +} + +func TestNamingStrategy(t *testing.T) { + + dbNameNS := func(name string) string { + return "db_" + name + } + tableNameNS := func(name string) string { + return "tbl_" + name + } + columnNameNS := func(name string) string { + return "col_" + name + } + + ns := &gorm.NamingStrategy{ + DB: dbNameNS, + Table: tableNameNS, + Column: columnNameNS, + } + + cases := []struct { + name string + namer gorm.Namer + expected string + }{ + {name: "auth", expected: "db_auth", namer: ns.DB}, + {name: "user", expected: "tbl_user", namer: ns.Table}, + {name: "password", expected: "col_password", namer: ns.Column}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := c.namer(c.name) + if result != c.expected { + t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) + } + }) + } + +} diff --git a/preload_test.go b/preload_test.go index 311ad0b..dd29fb5 100644 --- a/preload_test.go +++ b/preload_test.go @@ -123,6 +123,31 @@ } } +func TestAutoPreloadFalseDoesntPreload(t *testing.T) { + user1 := getPreloadUser("auto_user1") + DB.Save(user1) + + preloadDB := DB.Set("gorm:auto_preload", false).Where("role = ?", "Preload") + var user User + preloadDB.Find(&user) + + if user.BillingAddress.Address1 != "" { + t.Error("AutoPreload was set to fasle, but still fetched data") + } + + user2 := getPreloadUser("auto_user2") + DB.Save(user2) + + var users []User + preloadDB.Find(&users) + + for _, user := range users { + if user.BillingAddress.Address1 != "" { + t.Error("AutoPreload was set to fasle, but still fetched data") + } + } +} + func TestNestedPreload1(t *testing.T) { type ( Level1 struct { @@ -746,6 +771,7 @@ levelB3 := &LevelB3{ Value: "bar", LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)}, + LevelB2s: []*LevelB2{}, } if err := DB.Create(levelB3).Error; err != nil { t.Error(err) @@ -1651,7 +1677,7 @@ lvl := Level1{ Name: "l1", Level2s: []Level2{ - Level2{Name: "l2-1"}, Level2{Name: "l2-2"}, + {Name: "l2-1"}, {Name: "l2-2"}, }, } DB.Save(&lvl) diff --git a/query_test.go b/query_test.go index fac7d4d..b40538e 100644 --- a/query_test.go +++ b/query_test.go @@ -38,6 +38,24 @@ if user.Email != "" { t.Errorf("User's Email should be blank as no one set it") + } +} + +func TestQueryWithAssociation(t *testing.T) { + user := &User{Name: "user1", Emails: []Email{{Email: "user1@example.com"}}, Company: Company{Name: "company"}} + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create user: %v", err) + } + + user.CreatedAt = time.Time{} + user.UpdatedAt = time.Time{} + if err := DB.Where(&user).First(&User{}).Error; err != nil { + t.Errorf("search with struct with association should returns no error, but got %v", err) + } + + if err := DB.Where(user).First(&User{}).Error; err != nil { + t.Errorf("search with struct with association should returns no error, but got %v", err) } } @@ -181,17 +199,17 @@ scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users) if len(users) != 2 { - t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) + t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users)) } scopedb.Where("birthday > ?", "2002-10-10").Find(&users) if len(users) != 2 { - t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) + t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users)) } scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) if len(users) != 1 { - t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) + t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) } DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) @@ -457,6 +475,74 @@ } } +func TestLimitAndOffsetSQL(t *testing.T) { + user1 := User{Name: "TestLimitAndOffsetSQL1", Age: 10} + user2 := User{Name: "TestLimitAndOffsetSQL2", Age: 20} + user3 := User{Name: "TestLimitAndOffsetSQL3", Age: 30} + user4 := User{Name: "TestLimitAndOffsetSQL4", Age: 40} + user5 := User{Name: "TestLimitAndOffsetSQL5", Age: 50} + if err := DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5).Error; err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + limit, offset interface{} + users []*User + ok bool + }{ + { + name: "OK", + limit: float64(2), + offset: float64(2), + users: []*User{ + &User{Name: "TestLimitAndOffsetSQL3", Age: 30}, + &User{Name: "TestLimitAndOffsetSQL2", Age: 20}, + }, + ok: true, + }, + { + name: "Limit parse error", + limit: float64(1000000), // 1e+06 + offset: float64(2), + ok: false, + }, + { + name: "Offset parse error", + limit: float64(2), + offset: float64(1000000), // 1e+06 + ok: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var users []*User + err := DB.Where("name LIKE ?", "TestLimitAndOffsetSQL%").Order("age desc").Limit(tt.limit).Offset(tt.offset).Find(&users).Error + if tt.ok { + if err != nil { + t.Errorf("error expected nil, but got %v", err) + } + if len(users) != len(tt.users) { + t.Errorf("users length expected %d, but got %d", len(tt.users), len(users)) + } + for i := range tt.users { + if users[i].Name != tt.users[i].Name { + t.Errorf("users[%d] name expected %s, but got %s", i, tt.users[i].Name, users[i].Name) + } + if users[i].Age != tt.users[i].Age { + t.Errorf("users[%d] age expected %d, but got %d", i, tt.users[i].Age, users[i].Age) + } + } + } else { + if err == nil { + t.Error("error expected not nil, but got nil") + } + } + }) + } +} + func TestOr(t *testing.T) { user1 := User{Name: "OrUser1", Age: 1} user2 := User{Name: "OrUser2", Age: 10} @@ -532,28 +618,28 @@ DB.Table("users").Where("name = ?", "user3").Count(&name3Count) DB.Not("name", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not("name = ?", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not("name <> ?", "user3").Find(&users4) if len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(User{Name: "user3"}).Find(&users5) if len(users1)-len(users5) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) if len(users1)-len(users6) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) @@ -563,14 +649,14 @@ DB.Not("name", []string{"user3"}).Find(&users8) if len(users1)-len(users8) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } var name2Count int64 DB.Table("users").Where("name = ?", "user2").Count(&name2Count) DB.Not("name", []string{"user3", "user2"}).Find(&users9) if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } } diff --git a/scope.go b/scope.go index 150ac71..ea12ee2 100644 --- a/scope.go +++ b/scope.go @@ -63,12 +63,12 @@ // Dialect get dialect func (scope *Scope) Dialect() Dialect { - return scope.db.parent.dialect + return scope.db.dialect } // Quote used to quote string to escape them for database func (scope *Scope) Quote(str string) string { - if strings.Index(str, ".") != -1 { + if strings.Contains(str, ".") { newStrs := []string{} for _, str := range strings.Split(str, ".") { newStrs = append(newStrs, scope.Dialect().Quote(str)) @@ -134,7 +134,7 @@ // FieldByName find `gorm.Field` with field name or db name func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { var ( - dbName = ToDBName(name) + dbName = ToColumnName(name) mostMatchedField *Field ) @@ -225,7 +225,7 @@ updateAttrs[field.DBName] = value return field.Set(value) } - if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) { + if !field.IsIgnored && ((field.DBName == dbName) || (field.Name == name && mostMatchedField == nil)) { mostMatchedField = field } } @@ -257,7 +257,7 @@ func (scope *Scope) AddToVars(value interface{}) string { _, skipBindVar := scope.InstanceGet("skip_bindvar") - if expr, ok := value.(*expr); ok { + if expr, ok := value.(*SqlExpr); ok { exp := expr.expr for _, arg := range expr.args { if skipBindVar { @@ -330,7 +330,7 @@ // QuotedTableName return quoted table name func (scope *Scope) QuotedTableName() (name string) { if scope.Search != nil && len(scope.Search.tableName) > 0 { - if strings.Index(scope.Search.tableName, " ") != -1 { + if strings.Contains(scope.Search.tableName, " ") { return scope.Search.tableName } return scope.Quote(scope.Search.tableName) @@ -402,7 +402,7 @@ // Begin start a transaction func (scope *Scope) Begin() *Scope { if db, ok := scope.SQLDB().(sqlDb); ok { - if tx, err := db.Begin(); err == nil { + if tx, err := db.Begin(); scope.Err(err) == nil { scope.db.db = interface{}(tx).(SQLCommon) scope.InstanceSet("gorm:started_transaction", true) } @@ -486,8 +486,10 @@ values[index] = &ignored selectFields = fields + offset := 0 if idx, ok := selectedColumnsMap[column]; ok { - selectFields = selectFields[idx+1:] + offset = idx + 1 + selectFields = selectFields[offset:] } for fieldIndex, field := range selectFields { @@ -501,7 +503,7 @@ resetFields[index] = field } - selectedColumnsMap[column] = fieldIndex + selectedColumnsMap[column] = offset + fieldIndex if field.IsNormal { break @@ -586,10 +588,10 @@ scope.Err(fmt.Errorf("invalid query condition: %v", value)) return } - + scopeQuotedTableName := newScope.QuotedTableName() for _, field := range newScope.Fields() { - if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) + if !field.IsIgnored && !field.IsBlank && field.Relationship == nil { + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") @@ -692,12 +694,12 @@ buff := bytes.NewBuffer([]byte{}) i := 0 - for pos := range str { + for pos, char := range str { if str[pos] == '?' { buff.WriteString(replacements[i]) i++ } else { - buff.WriteByte(str[pos]) + buff.WriteRune(char) } } @@ -783,7 +785,7 @@ for _, order := range scope.Search.orders { if str, ok := order.(string); ok { orders = append(orders, scope.quoteIfPossible(str)) - } else if expr, ok := order.(*expr); ok { + } else if expr, ok := order.(*SqlExpr); ok { exp := expr.expr for _, arg := range expr.args { exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) @@ -795,7 +797,9 @@ } func (scope *Scope) limitAndOffsetSQL() string { - return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) + sql, err := scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) + scope.Err(err) + return sql } func (scope *Scope) groupSQL() string { @@ -837,12 +841,16 @@ } func (scope *Scope) prepareQuerySQL() { + var sql string if scope.Search.raw { - scope.Raw(scope.CombinedConditionSql()) + sql = scope.CombinedConditionSql() } else { - scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) - } - return + sql = fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()) + } + if str, ok := scope.Get("gorm:query_option"); ok { + sql += addExtraSpaceIfExist(fmt.Sprint(str)) + } + scope.Raw(sql) } func (scope *Scope) inlineCondition(values ...interface{}) *Scope { @@ -853,6 +861,14 @@ } func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { + defer func() { + if err := recover(); err != nil { + if db, ok := scope.db.db.(sqlTx); ok { + db.Rollback() + } + panic(err) + } + }() for _, f := range funcs { (*f)(scope) if scope.skipLeft { @@ -862,7 +878,7 @@ return scope } -func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string]interface{} { +func convertInterfaceToMap(values interface{}, withIgnoredField bool, db *DB) map[string]interface{} { var attrs = map[string]interface{}{} switch value := values.(type) { @@ -870,7 +886,7 @@ return value case []interface{}: for _, v := range value { - for key, value := range convertInterfaceToMap(v, withIgnoredField) { + for key, value := range convertInterfaceToMap(v, withIgnoredField, db) { attrs[key] = value } } @@ -880,10 +896,10 @@ switch reflectValue.Kind() { case reflect.Map: for _, key := range reflectValue.MapKeys() { - attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() + attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() } default: - for _, field := range (&Scope{Value: values}).Fields() { + for _, field := range (&Scope{Value: values, db: db}).Fields() { if !field.IsBlank && (withIgnoredField || !field.IsIgnored) { attrs[field.DBName] = field.Field.Interface() } @@ -895,27 +911,31 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) { if scope.IndirectValue().Kind() != reflect.Struct { - return convertInterfaceToMap(value, false), true + return convertInterfaceToMap(value, false, scope.db), true } results = map[string]interface{}{} - for key, value := range convertInterfaceToMap(value, true) { - if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { - if _, ok := value.(*expr); ok { - hasUpdate = true - results[field.DBName] = value - } else { - err := field.Set(value) - if field.IsNormal { + for key, value := range convertInterfaceToMap(value, true, scope.db) { + if field, ok := scope.FieldByName(key); ok { + if scope.changeableField(field) { + if _, ok := value.(*SqlExpr); ok { hasUpdate = true - if err == ErrUnaddressable { - results[field.DBName] = value - } else { - results[field.DBName] = field.Field.Interface() + results[field.DBName] = value + } else { + err := field.Set(value) + if field.IsNormal && !field.IsIgnored { + hasUpdate = true + if err == ErrUnaddressable { + results[field.DBName] = value + } else { + results[field.DBName] = field.Field.Interface() + } } } } + } else { + results[key] = value } } return @@ -972,6 +992,10 @@ if dest.Kind() != reflect.Slice { scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) return scope + } + + if dest.Len() > 0 { + dest.Set(reflect.Zero(dest.Type())) } if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) { @@ -997,8 +1021,15 @@ func (scope *Scope) count(value interface{}) *Scope { if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { if len(scope.Search.group) != 0 { - scope.Search.Select("count(*) FROM ( SELECT count(*) as name ") - scope.Search.group += " ) AS count_table" + if len(scope.Search.havingConditions) != 0 { + scope.prepareQuerySQL() + scope.Search = &search{} + scope.Search.Select("count(*)") + scope.Search.Table(fmt.Sprintf("( %s ) AS count_table", scope.SQL)) + } else { + scope.Search.Select("count(*) FROM ( SELECT count(*) as name ") + scope.Search.group += " ) AS count_table" + } } else { scope.Search.Select("count(*)") } @@ -1113,8 +1144,8 @@ if field, ok := scope.FieldByName(fieldName); ok { foreignKeyStruct := field.clone() foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" - delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT") + foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") + foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) } @@ -1124,8 +1155,8 @@ if field, ok := toScope.FieldByName(fieldName); ok { foreignKeyStruct := field.clone() foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" - delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT") + foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") + foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) } @@ -1173,7 +1204,7 @@ } func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v%s", scope.QuotedTableName(), scope.getTableOptions())).Exec() + scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() return scope } @@ -1215,12 +1246,18 @@ } func (scope *Scope) removeForeignKey(field string, dest string) { - keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest) - + keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return } - var query = `ALTER TABLE %s DROP CONSTRAINT %s;` + var mysql mysql + var query string + if scope.Dialect().GetName() == mysql.GetName() { + query = `ALTER TABLE %s DROP FOREIGN KEY %s;` + } else { + query = `ALTER TABLE %s DROP CONSTRAINT %s;` + } + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() } @@ -1254,25 +1291,27 @@ var uniqueIndexes = map[string][]string{} for _, field := range scope.GetStructFields() { - if name, ok := field.TagSettings["INDEX"]; ok { + if name, ok := field.TagSettingsGet("INDEX"); ok { names := strings.Split(name, ",") for _, name := range names { if name == "INDEX" || name == "" { name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName) } - indexes[name] = append(indexes[name], field.DBName) - } - } - - if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok { + name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName) + indexes[name] = append(indexes[name], column) + } + } + + if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok { names := strings.Split(name, ",") for _, name := range names { if name == "UNIQUE_INDEX" || name == "" { name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName) } - uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) + name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName) + uniqueIndexes[name] = append(uniqueIndexes[name], column) } } } @@ -1293,6 +1332,7 @@ } func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { + resultMap := make(map[string][]interface{}) for _, value := range values { indirectValue := indirect(reflect.ValueOf(value)) @@ -1311,7 +1351,10 @@ } if hasValue { - results = append(results, result) + h := fmt.Sprint(result...) + if _, exist := resultMap[h]; !exist { + resultMap[h] = result + } } } case reflect.Struct: @@ -1326,11 +1369,16 @@ } if hasValue { - results = append(results, result) - } - } - } - + h := fmt.Sprint(result...) + if _, exist := resultMap[h]; !exist { + resultMap[h] = result + } + } + } + } + for _, v := range resultMap { + results = append(results, v) + } return } diff --git a/scope_test.go b/scope_test.go index 3018f35..f7f1ed0 100644 --- a/scope_test.go +++ b/scope_test.go @@ -78,3 +78,16 @@ t.Errorf("The error should be returned from Valuer, but get %v", err) } } + +func TestDropTableWithTableOptions(t *testing.T) { + type UserWithOptions struct { + gorm.Model + } + DB.AutoMigrate(&UserWithOptions{}) + + DB = DB.Set("gorm:table_options", "CHARSET=utf8") + err := DB.DropTable(&UserWithOptions{}).Error + if err != nil { + t.Errorf("Table must be dropped, got error %s", err) + } +} diff --git a/search.go b/search.go index 9013859..52ae2ef 100644 --- a/search.go +++ b/search.go @@ -32,7 +32,57 @@ } func (s *search) clone() *search { - clone := *s + clone := search{ + db: s.db, + whereConditions: make([]map[string]interface{}, len(s.whereConditions)), + orConditions: make([]map[string]interface{}, len(s.orConditions)), + notConditions: make([]map[string]interface{}, len(s.notConditions)), + havingConditions: make([]map[string]interface{}, len(s.havingConditions)), + joinConditions: make([]map[string]interface{}, len(s.joinConditions)), + initAttrs: make([]interface{}, len(s.initAttrs)), + assignAttrs: make([]interface{}, len(s.assignAttrs)), + selects: s.selects, + omits: make([]string, len(s.omits)), + orders: make([]interface{}, len(s.orders)), + preload: make([]searchPreload, len(s.preload)), + offset: s.offset, + limit: s.limit, + group: s.group, + tableName: s.tableName, + raw: s.raw, + Unscoped: s.Unscoped, + ignoreOrderQuery: s.ignoreOrderQuery, + } + for i, value := range s.whereConditions { + clone.whereConditions[i] = value + } + for i, value := range s.orConditions { + clone.orConditions[i] = value + } + for i, value := range s.notConditions { + clone.notConditions[i] = value + } + for i, value := range s.havingConditions { + clone.havingConditions[i] = value + } + for i, value := range s.joinConditions { + clone.joinConditions[i] = value + } + for i, value := range s.initAttrs { + clone.initAttrs[i] = value + } + for i, value := range s.assignAttrs { + clone.assignAttrs[i] = value + } + for i, value := range s.omits { + clone.omits[i] = value + } + for i, value := range s.orders { + clone.orders[i] = value + } + for i, value := range s.preload { + clone.preload[i] = value + } return &clone } @@ -98,7 +148,7 @@ } func (s *search) Having(query interface{}, values ...interface{}) *search { - if val, ok := query.(*expr); ok { + if val, ok := query.(*SqlExpr); ok { s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args}) } else { s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) diff --git a/search_test.go b/search_test.go index 4db7ab6..bf8e84f 100644 --- a/search_test.go +++ b/search_test.go @@ -1,6 +1,7 @@ package gorm import ( + "fmt" "reflect" "testing" ) @@ -28,3 +29,24 @@ t.Errorf("selectStr should be copied") } } + +func TestWhereCloneCorruption(t *testing.T) { + for whereCount := 1; whereCount <= 8; whereCount++ { + t.Run(fmt.Sprintf("w=%d", whereCount), func(t *testing.T) { + s := new(search) + for w := 0; w < whereCount; w++ { + s = s.clone().Where(fmt.Sprintf("w%d = ?", w), fmt.Sprintf("value%d", w)) + } + if len(s.whereConditions) != whereCount { + t.Errorf("s: where count should be %d", whereCount) + } + + q1 := s.clone().Where("finalThing = ?", "THING1") + q2 := s.clone().Where("finalThing = ?", "THING2") + + if reflect.DeepEqual(q1.whereConditions, q2.whereConditions) { + t.Errorf("Where conditions should be different") + } + }) + } +} diff --git a/utils.go b/utils.go index dfaae93..d2ae946 100644 --- a/utils.go +++ b/utils.go @@ -1,7 +1,6 @@ package gorm import ( - "bytes" "database/sql/driver" "fmt" "reflect" @@ -26,8 +25,8 @@ var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialismsReplacer *strings.Replacer -var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) -var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`) +var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) +var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) func init() { var commonInitialismsForReplacer []string @@ -58,74 +57,16 @@ return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} } -var smap = newSafeMap() - -type strCase bool - -const ( - lower strCase = false - upper strCase = true -) - -// ToDBName convert string to db name -func ToDBName(name string) string { - if v := smap.Get(name); v != "" { - return v - } - - if name == "" { - return "" - } - - var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase strCase - ) - - for i, v := range value[:len(value)-1] { - nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') - if i > 0 { - if currCase == upper { - if lastCase == upper && nextCase == upper { - buf.WriteRune(v) - } else { - if value[i-1] != '_' && value[i+1] != '_' { - buf.WriteRune('_') - } - buf.WriteRune(v) - } - } else { - buf.WriteRune(v) - if i == len(value)-2 && nextCase == upper { - buf.WriteRune('_') - } - } - } else { - currCase = upper - buf.WriteRune(v) - } - lastCase = currCase - currCase = nextCase - } - - buf.WriteByte(value[len(value)-1]) - - s := strings.ToLower(buf.String()) - smap.Set(name, s) - return s -} - // SQL expression -type expr struct { +type SqlExpr struct { expr string args []interface{} } // Expr generate raw SQL expression, for example: // DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) -func Expr(expression string, args ...interface{}) *expr { - return &expr{expr: expression, args: args} +func Expr(expression string, args ...interface{}) *SqlExpr { + return &SqlExpr{expr: expression, args: args} } func indirect(reflectValue reflect.Value) reflect.Value { @@ -265,7 +206,7 @@ // as FieldByName could panic if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { for _, fieldName := range fieldNames { - if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { + if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() { result := fieldValue.Interface() if r, ok := result.(driver.Valuer); ok { result, _ = r.Value() diff --git a/utils_test.go b/utils_test.go deleted file mode 100644 index 152296d..0000000 --- a/utils_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package gorm_test - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestToDBNameGenerateFriendlyName(t *testing.T) { - var maps = map[string]string{ - "": "", - "X": "x", - "ThisIsATest": "this_is_a_test", - "PFAndESI": "pf_and_esi", - "AbcAndJkl": "abc_and_jkl", - "EmployeeID": "employee_id", - "SKU_ID": "sku_id", - "FieldX": "field_x", - "HTTPAndSMTP": "http_and_smtp", - "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", - "UUID": "uuid", - "HTTPURL": "http_url", - "HTTP_URL": "http_url", - "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", - } - - for key, value := range maps { - if gorm.ToDBName(key) != value { - t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key)) - } - } -} diff --git a/wercker.yml b/wercker.yml index 0c3e73e..1de947b 100644 --- a/wercker.yml +++ b/wercker.yml @@ -4,6 +4,13 @@ services: - name: mariadb id: mariadb:latest + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql + id: mysql:latest env: MYSQL_DATABASE: gorm MYSQL_USER: gorm @@ -18,13 +25,6 @@ MYSQL_RANDOM_ROOT_PASSWORD: "yes" - name: mysql56 id: mysql:5.6 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql55 - id: mysql:5.5 env: MYSQL_DATABASE: gorm MYSQL_USER: gorm @@ -83,7 +83,7 @@ code: | cd $WERCKER_SOURCE_DIR go version - go get -t ./... + go get -t -v ./... # Build the project - script: @@ -95,54 +95,55 @@ - script: name: test sqlite code: | - go test ./... + go test -race -v ./... - script: name: test mariadb code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./... + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... + + - script: + name: test mysql + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - script: name: test mysql5.7 code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test ./... + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - script: name: test mysql5.6 code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test ./... - - - script: - name: test mysql5.5 - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test ./... + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - script: name: test postgres code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - script: name: test postgres96 code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - script: name: test postgres95 code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - script: name: test postgres94 code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - script: name: test postgres93 code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - script: - name: test mssql + name: codecov code: | - GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test ./... + go test -race -coverprofile=coverage.txt -covermode=atomic ./... + bash <(curl -s https://codecov.io/bash)