Codebase list golang-github-jinzhu-gorm / d22a1d9
Import Upstream version 1.9.6 Shengjing Zhu 1 year, 3 months ago
46 changed file(s) with 2246 addition(s) and 510 deletion(s). Raw diff Collapse all Expand all
+0
-11
.codeclimate.yml less more
0 ---
1 engines:
2 gofmt:
3 enabled: true
4 govet:
5 enabled: true
6 golint:
7 enabled: true
8 ratings:
9 paths:
10 - "**.go"
00 documents
1 coverage.txt
12 _book
00 # GORM
11
2 The fantastic ORM library for Golang, aims to be developer friendly.
2 GORM V2 moved to https://github.com/go-gorm/gorm
33
4 [![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm)
5 [![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b)
6 [![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
7 [![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm)
8 [![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm)
9 [![MIT license](http://img.shields.io/badge/license-MIT-brightgreen.svg)](http://opensource.org/licenses/MIT)
10 [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm)
11
12 ## Overview
13
14 * Full-Featured ORM (almost)
15 * Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism)
16 * Hooks (Before/After Create/Save/Update/Delete/Find)
17 * Preloading (eager loading)
18 * Transactions
19 * Composite Primary Key
20 * SQL Builder
21 * Auto Migrations
22 * Logger
23 * Extendable, write Plugins based on GORM callbacks
24 * Every feature comes with tests
25 * Developer Friendly
26
27 ## Getting Started
28
29 * GORM Guides [http://gorm.io](http://gorm.io)
30
31 ## Contributing
32
33 [You can help to deliver a better GORM, check out things you can do](http://gorm.io/contribute.html)
34
35 ## License
36
37 © Jinzhu, 2013~time.Now
38
39 Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License)
4 GORM V1 Doc https://v1.gorm.io/
266266 query = scope.DB()
267267 )
268268
269 if relationship.Kind == "many_to_many" {
269 switch relationship.Kind {
270 case "many_to_many":
270271 query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
271 } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
272 case "has_many", "has_one":
272273 primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
273274 query = query.Where(
274275 fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
275276 toQueryValues(primaryKeys)...,
276277 )
277 } else if relationship.Kind == "belongs_to" {
278 case "belongs_to":
278279 primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
279280 query = query.Where(
280281 fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
366367 return association
367368 }
368369
370 // setErr set error when the error is not nil. And return Association.
369371 func (association *Association) setErr(err error) *Association {
370372 if err != nil {
371373 association.Error = err
00 package gorm
11
2 import "log"
2 import "fmt"
33
44 // DefaultCallback default callbacks defined by gorm
5 var DefaultCallback = &Callback{}
5 var DefaultCallback = &Callback{logger: nopLogger{}}
66
77 // Callback is a struct that contains all CRUD callbacks
88 // Field `creates` contains callbacks will be call when creating object
1212 // Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
1313 // Field `processors` contains all callback processors, will be used to generate above callbacks in order
1414 type Callback struct {
15 logger logger
1516 creates []*func(scope *Scope)
1617 updates []*func(scope *Scope)
1718 deletes []*func(scope *Scope)
2223
2324 // CallbackProcessor contains callback informations
2425 type CallbackProcessor struct {
26 logger logger
2527 name string // current callback's name
2628 before string // register current callback before a callback
2729 after string // register current callback after a callback
3234 parent *Callback
3335 }
3436
35 func (c *Callback) clone() *Callback {
37 func (c *Callback) clone(logger logger) *Callback {
3638 return &Callback{
39 logger: logger,
3740 creates: c.creates,
3841 updates: c.updates,
3942 deletes: c.deletes,
5255 // scope.Err(errors.New("error"))
5356 // })
5457 func (c *Callback) Create() *CallbackProcessor {
55 return &CallbackProcessor{kind: "create", parent: c}
58 return &CallbackProcessor{logger: c.logger, kind: "create", parent: c}
5659 }
5760
5861 // Update could be used to register callbacks for updating object, refer `Create` for usage
5962 func (c *Callback) Update() *CallbackProcessor {
60 return &CallbackProcessor{kind: "update", parent: c}
63 return &CallbackProcessor{logger: c.logger, kind: "update", parent: c}
6164 }
6265
6366 // Delete could be used to register callbacks for deleting object, refer `Create` for usage
6467 func (c *Callback) Delete() *CallbackProcessor {
65 return &CallbackProcessor{kind: "delete", parent: c}
68 return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c}
6669 }
6770
6871 // Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
6972 // Refer `Create` for usage
7073 func (c *Callback) Query() *CallbackProcessor {
71 return &CallbackProcessor{kind: "query", parent: c}
74 return &CallbackProcessor{logger: c.logger, kind: "query", parent: c}
7275 }
7376
7477 // RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
7578 func (c *Callback) RowQuery() *CallbackProcessor {
76 return &CallbackProcessor{kind: "row_query", parent: c}
79 return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c}
7780 }
7881
7982 // After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
9295 func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
9396 if cp.kind == "row_query" {
9497 if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
95 log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
98 cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
9699 cp.before = "gorm:row_query"
97100 }
98101 }
99102
103 cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
100104 cp.name = callbackName
101105 cp.processor = &callback
102106 cp.parent.processors = append(cp.parent.processors, cp)
106110 // Remove a registered callback
107111 // db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
108112 func (cp *CallbackProcessor) Remove(callbackName string) {
109 log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
113 cp.logger.Print("info", fmt.Sprintf("[info] removing callback `%v` from %v", callbackName, fileWithLineNum()))
110114 cp.name = callbackName
111115 cp.remove = true
112116 cp.parent.processors = append(cp.parent.processors, cp)
115119
116120 // Replace a registered callback with new callback
117121 // db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
118 // scope.SetColumn("Created", now)
119 // scope.SetColumn("Updated", now)
122 // scope.SetColumn("CreatedAt", now)
123 // scope.SetColumn("UpdatedAt", now)
120124 // })
121125 func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
122 log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
126 cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
123127 cp.name = callbackName
124128 cp.processor = &callback
125129 cp.replace = true
131135 // db.Callback().Create().Get("gorm:create")
132136 func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
133137 for _, p := range cp.parent.processors {
134 if p.name == callbackName && p.kind == cp.kind && !cp.remove {
135 return *p.processor
136 }
137 }
138 return nil
138 if p.name == callbackName && p.kind == cp.kind {
139 if p.remove {
140 callback = nil
141 } else {
142 callback = *p.processor
143 }
144 }
145 }
146 return
139147 }
140148
141149 // getRIndex get right index from string slice
158166 for _, cp := range cps {
159167 // show warning message the callback name already exists
160168 if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
161 log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
169 cp.logger.Print("warning", fmt.Sprintf("[warning] duplicated callback `%v` from %v", cp.name, fileWithLineNum()))
162170 }
163171 allNames = append(allNames, cp.name)
164172 }
3030 // updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
3131 func updateTimeStampForCreateCallback(scope *Scope) {
3232 if !scope.HasError() {
33 now := NowFunc()
33 now := scope.db.nowFunc()
3434
3535 if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
3636 if createdAtField.IsBlank {
5858
5959 for _, field := range scope.Fields() {
6060 if scope.changeableField(field) {
61 if field.IsNormal {
61 if field.IsNormal && !field.IsIgnored {
6262 if field.IsBlank && field.HasDefaultValue {
6363 blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
6464 scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
8282 quotedTableName = scope.QuotedTableName()
8383 primaryField = scope.PrimaryField()
8484 extraOption string
85 insertModifier string
8586 )
8687
8788 if str, ok := scope.Get("gorm:insert_option"); ok {
8889 extraOption = fmt.Sprint(str)
90 }
91 if str, ok := scope.Get("gorm:insert_modifier"); ok {
92 insertModifier = strings.ToUpper(fmt.Sprint(str))
93 if insertModifier == "INTO" {
94 insertModifier = ""
95 }
8996 }
9097
9198 if primaryField != nil {
9299 returningColumn = scope.Quote(primaryField.DBName)
93100 }
94101
95 lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
102 lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns)
103 var lastInsertIDReturningSuffix string
104 if lastInsertIDOutputInterstitial == "" {
105 lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
106 }
96107
97108 if len(columns) == 0 {
98109 scope.Raw(fmt.Sprintf(
99 "INSERT INTO %v %v%v%v",
110 "INSERT%v INTO %v %v%v%v",
111 addExtraSpaceIfExist(insertModifier),
100112 quotedTableName,
101113 scope.Dialect().DefaultValueStr(),
102114 addExtraSpaceIfExist(extraOption),
104116 ))
105117 } else {
106118 scope.Raw(fmt.Sprintf(
107 "INSERT INTO %v (%v) VALUES (%v)%v%v",
119 "INSERT%v INTO %v (%v)%v VALUES (%v)%v%v",
120 addExtraSpaceIfExist(insertModifier),
108121 scope.QuotedTableName(),
109122 strings.Join(columns, ","),
123 addExtraSpaceIfExist(lastInsertIDOutputInterstitial),
110124 strings.Join(placeholders, ","),
111125 addExtraSpaceIfExist(extraOption),
112126 addExtraSpaceIfExist(lastInsertIDReturningSuffix),
113127 ))
114128 }
115129
116 // execute create sql
117 if lastInsertIDReturningSuffix == "" || primaryField == nil {
130 // execute create sql: no primaryField
131 if primaryField == nil {
118132 if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
119133 // set rows affected count
120134 scope.db.RowsAffected, _ = result.RowsAffected()
126140 }
127141 }
128142 }
143 return
144 }
145
146 // execute create sql: lastInsertID implemention for majority of dialects
147 if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
148 if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
149 // set rows affected count
150 scope.db.RowsAffected, _ = result.RowsAffected()
151
152 // set primary value to primary field
153 if primaryField != nil && primaryField.IsBlank {
154 if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
155 scope.Err(primaryField.Set(primaryValue))
156 }
157 }
158 }
159 return
160 }
161
162 // execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
163 if primaryField.Field.CanAddr() {
164 if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
165 primaryField.IsBlank = false
166 scope.db.RowsAffected = 1
167 }
129168 } else {
130 if primaryField.Field.CanAddr() {
131 if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
132 primaryField.IsBlank = false
133 scope.db.RowsAffected = 1
134 }
135 } else {
136 scope.Err(ErrUnaddressable)
137 }
169 scope.Err(ErrUnaddressable)
138170 }
171 return
139172 }
140173 }
141174
1616 // beforeDeleteCallback will invoke `BeforeDelete` method before deleting
1717 func beforeDeleteCallback(scope *Scope) {
1818 if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
19 scope.Err(errors.New("Missing WHERE clause while deleting"))
19 scope.Err(errors.New("missing WHERE clause while deleting"))
2020 return
2121 }
2222 if !scope.HasError() {
3939 "UPDATE %v SET %v=%v%v%v",
4040 scope.QuotedTableName(),
4141 scope.Quote(deletedAtField.DBName),
42 scope.AddToVars(NowFunc()),
42 scope.AddToVars(scope.db.nowFunc()),
4343 addExtraSpaceIfExist(scope.CombinedConditionSql()),
4444 addExtraSpaceIfExist(extraOption),
4545 )).Exec()
1515 // queryCallback used to query data from database
1616 func queryCallback(scope *Scope) {
1717 if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
18 return
19 }
20
21 //we are only preloading relations, dont touch base model
22 if _, skip := scope.InstanceGet("gorm:only_preload"); skip {
1823 return
1924 }
2025
5459
5560 if !scope.HasError() {
5661 scope.db.RowsAffected = 0
62
63 if str, ok := scope.Get("gorm:query_hint"); ok {
64 scope.SQL = fmt.Sprint(str) + scope.SQL
65 }
66
5767 if str, ok := scope.Get("gorm:query_option"); ok {
5868 scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
5969 }
1313 return
1414 }
1515
16 if _, ok := scope.Get("gorm:auto_preload"); ok {
17 autoPreload(scope)
16 if ap, ok := scope.Get("gorm:auto_preload"); ok {
17 // If gorm:auto_preload IS NOT a bool then auto preload.
18 // Else if it IS a bool, use the value
19 if apb, ok := ap.(bool); !ok {
20 autoPreload(scope)
21 } else if apb {
22 autoPreload(scope)
23 }
1824 }
1925
2026 if scope.Search.preload == nil || scope.HasError() {
9399 continue
94100 }
95101
96 if val, ok := field.TagSettings["PRELOAD"]; ok {
102 if val, ok := field.TagSettingsGet("PRELOAD"); ok {
97103 if preload, err := strconv.ParseBool(val); err != nil {
98104 scope.Err(errors.New("invalid preload option"))
99105 return
154160 )
155161
156162 if indirectScopeValue.Kind() == reflect.Slice {
163 foreignValuesToResults := make(map[string]reflect.Value)
164 for i := 0; i < resultsValue.Len(); i++ {
165 result := resultsValue.Index(i)
166 foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames))
167 foreignValuesToResults[foreignValues] = result
168 }
157169 for j := 0; j < indirectScopeValue.Len(); j++ {
158 for i := 0; i < resultsValue.Len(); i++ {
159 result := resultsValue.Index(i)
160 foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
161 if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
162 indirectValue.FieldByName(field.Name).Set(result)
163 break
164 }
170 indirectValue := indirect(indirectScopeValue.Index(j))
171 valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames))
172 if result, found := foreignValuesToResults[valueString]; found {
173 indirectValue.FieldByName(field.Name).Set(result)
165174 }
166175 }
167176 } else {
248257 indirectScopeValue = scope.IndirectValue()
249258 )
250259
260 foreignFieldToObjects := make(map[string][]*reflect.Value)
261 if indirectScopeValue.Kind() == reflect.Slice {
262 for j := 0; j < indirectScopeValue.Len(); j++ {
263 object := indirect(indirectScopeValue.Index(j))
264 valueString := toString(getValueFromFields(object, relation.ForeignFieldNames))
265 foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object)
266 }
267 }
268
251269 for i := 0; i < resultsValue.Len(); i++ {
252270 result := resultsValue.Index(i)
253271 if indirectScopeValue.Kind() == reflect.Slice {
254 value := getValueFromFields(result, relation.AssociationForeignFieldNames)
255 for j := 0; j < indirectScopeValue.Len(); j++ {
256 object := indirect(indirectScopeValue.Index(j))
257 if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
272 valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames))
273 if objects, found := foreignFieldToObjects[valueString]; found {
274 for _, object := range objects {
258275 object.FieldByName(field.Name).Set(result)
259276 }
260277 }
373390 key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
374391 fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
375392 }
376 for source, link := range linkHash {
377 for i, field := range fieldsSourceMap[source] {
393
394 for source, fields := range fieldsSourceMap {
395 for _, f := range fields {
378396 //If not 0 this means Value is a pointer and we already added preloaded models to it
379 if fieldsSourceMap[source][i].Len() != 0 {
397 if f.Len() != 0 {
380398 continue
381399 }
382 field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
383 }
384
385 }
386 }
400
401 v := reflect.MakeSlice(f.Type(), 0, 0)
402 if len(linkHash[source]) > 0 {
403 v = reflect.Append(f, linkHash[source]...)
404 }
405
406 f.Set(v)
407 }
408 }
409 }
00 package gorm
11
2 import "database/sql"
2 import (
3 "database/sql"
4 "fmt"
5 )
36
47 // Define callbacks for row query
58 func init() {
2023 if result, ok := scope.InstanceGet("row_query_result"); ok {
2124 scope.prepareQuerySQL()
2225
26 if str, ok := scope.Get("gorm:query_hint"); ok {
27 scope.SQL = fmt.Sprint(str) + scope.SQL
28 }
29
30 if str, ok := scope.Get("gorm:query_option"); ok {
31 scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
32 }
33
2334 if rowResult, ok := result.(*RowQueryResult); ok {
2435 rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
2536 } else if rowsResult, ok := result.(*RowsQueryResult); ok {
2020
2121 if v, ok := value.(string); ok {
2222 v = strings.ToLower(v)
23 if v == "false" || v != "skip" {
24 return false
25 }
23 return v == "true"
2624 }
2725
2826 return true
3533 if value, ok := scope.Get("gorm:save_associations"); ok {
3634 autoUpdate = checkTruth(value)
3735 autoCreate = autoUpdate
38 } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok {
36 saveReference = autoUpdate
37 } else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok {
3938 autoUpdate = checkTruth(value)
4039 autoCreate = autoUpdate
40 saveReference = autoUpdate
4141 }
4242
4343 if value, ok := scope.Get("gorm:association_autoupdate"); ok {
4444 autoUpdate = checkTruth(value)
45 } else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok {
45 } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok {
4646 autoUpdate = checkTruth(value)
4747 }
4848
4949 if value, ok := scope.Get("gorm:association_autocreate"); ok {
5050 autoCreate = checkTruth(value)
51 } else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok {
51 } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok {
5252 autoCreate = checkTruth(value)
5353 }
5454
5555 if value, ok := scope.Get("gorm:association_save_reference"); ok {
5656 saveReference = checkTruth(value)
57 } else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok {
57 } else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok {
5858 saveReference = checkTruth(value)
5959 }
6060 }
2222 func afterCreate2(s *Scope) {}
2323
2424 func TestRegisterCallback(t *testing.T) {
25 var callback = &Callback{}
25 var callback = &Callback{logger: defaultLogger}
2626
2727 callback.Create().Register("before_create1", beforeCreate1)
2828 callback.Create().Register("before_create2", beforeCreate2)
3636 }
3737
3838 func TestRegisterCallbackWithOrder(t *testing.T) {
39 var callback1 = &Callback{}
39 var callback1 = &Callback{logger: defaultLogger}
4040 callback1.Create().Register("before_create1", beforeCreate1)
4141 callback1.Create().Register("create", create)
4242 callback1.Create().Register("after_create1", afterCreate1)
4545 t.Errorf("register callback with order")
4646 }
4747
48 var callback2 = &Callback{}
48 var callback2 = &Callback{logger: defaultLogger}
4949
5050 callback2.Update().Register("create", create)
5151 callback2.Update().Before("create").Register("before_create1", beforeCreate1)
5959 }
6060
6161 func TestRegisterCallbackWithComplexOrder(t *testing.T) {
62 var callback1 = &Callback{}
62 var callback1 = &Callback{logger: defaultLogger}
6363
6464 callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
6565 callback1.Query().Register("before_create1", beforeCreate1)
6969 t.Errorf("register callback with order")
7070 }
7171
72 var callback2 = &Callback{}
72 var callback2 = &Callback{logger: defaultLogger}
7373
7474 callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
7575 callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
8585 func replaceCreate(s *Scope) {}
8686
8787 func TestReplaceCallback(t *testing.T) {
88 var callback = &Callback{}
88 var callback = &Callback{logger: defaultLogger}
8989
9090 callback.Create().Before("after_create1").After("before_create1").Register("create", create)
9191 callback.Create().Register("before_create1", beforeCreate1)
9898 }
9999
100100 func TestRemoveCallback(t *testing.T) {
101 var callback = &Callback{}
101 var callback = &Callback{logger: defaultLogger}
102102
103103 callback.Create().Before("after_create1").After("before_create1").Register("create", create)
104104 callback.Create().Register("before_create1", beforeCreate1)
3333 // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
3434 func beforeUpdateCallback(scope *Scope) {
3535 if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
36 scope.Err(errors.New("Missing WHERE clause while updating"))
36 scope.Err(errors.New("missing WHERE clause while updating"))
3737 return
3838 }
3939 if _, ok := scope.Get("gorm:update_column"); !ok {
4949 // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
5050 func updateTimeStampForUpdateCallback(scope *Scope) {
5151 if _, ok := scope.Get("gorm:update_column"); !ok {
52 scope.SetColumn("UpdatedAt", NowFunc())
52 scope.SetColumn("UpdatedAt", scope.db.nowFunc())
5353 }
5454 }
5555
7474 } else {
7575 for _, field := range scope.Fields() {
7676 if scope.changeableField(field) {
77 if !field.IsPrimaryKey && field.IsNormal {
78 sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
77 if !field.IsPrimaryKey && field.IsNormal && (field.Name != "CreatedAt" || !field.IsBlank) {
78 if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue {
79 sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
80 }
7981 } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
8082 for _, foreignKey := range relationship.ForeignDBNames {
8183 if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
11
22 import (
33 "errors"
4
5 "github.com/jinzhu/gorm"
6
74 "reflect"
85 "testing"
6
7 "github.com/jinzhu/gorm"
98 )
109
1110 func (s *Product) BeforeCreate() (err error) {
174173 t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
175174 }
176175 }
176
177 func TestGetCallback(t *testing.T) {
178 scope := DB.NewScope(nil)
179
180 if DB.Callback().Create().Get("gorm:test_callback") != nil {
181 t.Errorf("`gorm:test_callback` should be nil")
182 }
183
184 DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
185 callback := DB.Callback().Create().Get("gorm:test_callback")
186 if callback == nil {
187 t.Errorf("`gorm:test_callback` should be non-nil")
188 }
189 callback(scope)
190 if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
191 t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
192 }
193
194 DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
195 callback = DB.Callback().Create().Get("gorm:test_callback")
196 if callback == nil {
197 t.Errorf("`gorm:test_callback` should be non-nil")
198 }
199 callback(scope)
200 if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
201 t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
202 }
203
204 DB.Callback().Create().Remove("gorm:test_callback")
205 if DB.Callback().Create().Get("gorm:test_callback") != nil {
206 t.Errorf("`gorm:test_callback` should be nil")
207 }
208
209 DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
210 callback = DB.Callback().Create().Get("gorm:test_callback")
211 if callback == nil {
212 t.Errorf("`gorm:test_callback` should be non-nil")
213 }
214 callback(scope)
215 if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
216 t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
217 }
218 }
219
220 func TestUseDefaultCallback(t *testing.T) {
221 createCallbackName := "gorm:test_use_default_callback_for_create"
222 gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
223 // nop
224 })
225 if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
226 t.Errorf("`%s` expected non-nil, but got nil", createCallbackName)
227 }
228 gorm.DefaultCallback.Create().Remove(createCallbackName)
229 if gorm.DefaultCallback.Create().Get(createCallbackName) != nil {
230 t.Errorf("`%s` expected nil, but got non-nil", createCallbackName)
231 }
232
233 updateCallbackName := "gorm:test_use_default_callback_for_update"
234 scopeValueName := "gorm:test_use_default_callback_for_update_value"
235 gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
236 scope.Set(scopeValueName, 1)
237 })
238 gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
239 scope.Set(scopeValueName, 2)
240 })
241
242 scope := DB.NewScope(nil)
243 callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
244 callback(scope)
245 if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
246 t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
247 }
248 }
100100 }
101101 }
102102
103 func TestCreateWithNowFuncOverride(t *testing.T) {
104 user1 := User{Name: "CreateUserTimestampOverride"}
105
106 timeA := now.MustParse("2016-01-01")
107
108 // do DB.New() because we don't want this test to affect other tests
109 db1 := DB.New()
110 // set the override to use static timeA
111 db1.SetNowFuncOverride(func() time.Time {
112 return timeA
113 })
114 // call .New again to check the override is carried over as well during clone
115 db1 = db1.New()
116
117 db1.Save(&user1)
118
119 if user1.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
120 t.Errorf("CreatedAt be using the nowFuncOverride")
121 }
122 if user1.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
123 t.Errorf("UpdatedAt be using the nowFuncOverride")
124 }
125
126 // now create another user with a fresh DB.Now() that doesn't have the nowFuncOverride set
127 // to make sure that setting it only affected the above instance
128
129 user2 := User{Name: "CreateUserTimestampOverrideNoMore"}
130
131 db2 := DB.New()
132
133 db2.Save(&user2)
134
135 if user2.CreatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) {
136 t.Errorf("CreatedAt no longer be using the nowFuncOverride")
137 }
138 if user2.UpdatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) {
139 t.Errorf("UpdatedAt no longer be using the nowFuncOverride")
140 }
141 }
142
103143 type AutoIncrementUser struct {
104144 User
105145 Sequence uint `gorm:"AUTO_INCREMENT"`
228268 t.Errorf("Should not create omitted relationships")
229269 }
230270 }
271
272 func TestCreateIgnore(t *testing.T) {
273 float := 35.03554004971999
274 now := time.Now()
275 user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float}
276
277 if !DB.NewRecord(user) || !DB.NewRecord(&user) {
278 t.Error("User should be new record before create")
279 }
280
281 if count := DB.Create(&user).RowsAffected; count != 1 {
282 t.Error("There should be one record be affected when create record")
283 }
284 if DB.Dialect().GetName() == "mysql" && DB.Set("gorm:insert_modifier", "IGNORE").Create(&user).Error != nil {
285 t.Error("Should ignore duplicate user insert by insert modifier:IGNORE ")
286 }
287 }
288288 func TestSelfReferencingMany2ManyColumn(t *testing.T) {
289289 DB.DropTable(&SelfReferencingUser{}, "UserFriends")
290290 DB.AutoMigrate(&SelfReferencingUser{})
291 if !DB.HasTable("UserFriends") {
292 t.Errorf("auto migrate error, table UserFriends should be created")
293 }
291294
292295 friend1 := SelfReferencingUser{Name: "friend1_m2m"}
293296 if err := DB.Create(&friend1).Error; err != nil {
310313
311314 if DB.Model(&user).Association("Friends").Count() != 2 {
312315 t.Errorf("Should find created friends correctly")
316 }
317
318 var count int
319 if err := DB.Table("UserFriends").Count(&count).Error; err != nil {
320 t.Errorf("no error should happen, but got %v", err)
321 }
322 if count == 0 {
323 t.Errorf("table UserFriends should have records")
313324 }
314325
315326 var newUser = SelfReferencingUser{}
3636 ModifyColumn(tableName string, columnName string, typ string) error
3737
3838 // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
39 LimitAndOffsetSQL(limit, offset interface{}) string
39 LimitAndOffsetSQL(limit, offset interface{}) (string, error)
4040 // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
4141 SelectFromDummyTable() string
42 // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
43 LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
4244 // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
4345 LastInsertIDReturningSuffix(tableName, columnName string) string
4446 // DefaultValueStr
4648
4749 // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
4850 BuildKeyName(kind, tableName string, fields ...string) string
51
52 // NormalizeIndexAndColumn returns valid index name and column name depending on each dialect
53 NormalizeIndexAndColumn(indexName, columnName string) (string, string)
4954
5055 // CurrentDatabase return current database name
5156 CurrentDatabase() string
7176 dialectsMap[name] = dialect
7277 }
7378
79 // GetDialect gets the dialect for the specified dialect name
80 func GetDialect(name string) (dialect Dialect, ok bool) {
81 dialect, ok = dialectsMap[name]
82 return
83 }
84
7485 // ParseFieldStructForDialect get field's sql data type
7586 var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
7687 // Get redirected field type
7788 var (
7889 reflectType = field.Struct.Type
79 dataType = field.TagSettings["TYPE"]
90 dataType, _ = field.TagSettingsGet("TYPE")
8091 )
8192
8293 for reflectType.Kind() == reflect.Ptr {
105116 }
106117
107118 // Default Size
108 if num, ok := field.TagSettings["SIZE"]; ok {
119 if num, ok := field.TagSettingsGet("SIZE"); ok {
109120 size, _ = strconv.Atoi(num)
110121 } else {
111122 size = 255
112123 }
113124
114125 // Default type from tag setting
115 additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
116 if value, ok := field.TagSettings["DEFAULT"]; ok {
126 notNull, _ := field.TagSettingsGet("NOT NULL")
127 unique, _ := field.TagSettingsGet("UNIQUE")
128 additionalType = notNull + " " + unique
129 if value, ok := field.TagSettingsGet("DEFAULT"); ok {
117130 additionalType = additionalType + " DEFAULT " + value
131 }
132
133 if value, ok := field.TagSettingsGet("COMMENT"); ok {
134 additionalType = additionalType + " COMMENT " + value
118135 }
119136
120137 return fieldValue, dataType, size, strings.TrimSpace(additionalType)
77 "strings"
88 "time"
99 )
10
11 var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+")
1012
1113 // DefaultForeignKeyNamer contains the default foreign key name generator method
1214 type DefaultForeignKeyNamer struct {
3840 }
3941
4042 func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool {
41 if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
43 if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
4244 return strings.ToLower(value) != "false"
4345 }
4446 return field.IsPrimaryKey
136138 return
137139 }
138140
139 func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
141 // LimitAndOffsetSQL return generated SQL with Limit and Offset
142 func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
140143 if limit != nil {
141 if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
144 if parsedLimit, err := s.parseInt(limit); err != nil {
145 return "", err
146 } else if parsedLimit >= 0 {
142147 sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
143148 }
144149 }
145150 if offset != nil {
146 if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
151 if parsedOffset, err := s.parseInt(offset); err != nil {
152 return "", err
153 } else if parsedOffset >= 0 {
147154 sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
148155 }
149156 }
151158 }
152159
153160 func (commonDialect) SelectFromDummyTable() string {
161 return ""
162 }
163
164 func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
154165 return ""
155166 }
156167
165176 // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
166177 func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
167178 keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
168 keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_")
179 keyName = keyNameRegex.ReplaceAllString(keyName, "_")
169180 return keyName
181 }
182
183 // NormalizeIndexAndColumn returns argument's index name and column name without doing anything
184 func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
185 return indexName, columnName
186 }
187
188 func (commonDialect) parseInt(value interface{}) (int64, error) {
189 return strconv.ParseInt(fmt.Sprint(value), 0, 0)
170190 }
171191
172192 // IsByteArrayOrSlice returns true of the reflected value is an array or slice
11
22 import (
33 "crypto/sha1"
4 "database/sql"
45 "fmt"
56 "reflect"
67 "regexp"
7 "strconv"
88 "strings"
99 "time"
1010 "unicode/utf8"
1111 )
1212
13 var mysqlIndexRegex = regexp.MustCompile(`^(.+)\((\d+)\)$`)
14
1315 type mysql struct {
1416 commonDialect
1517 }
3234
3335 // MySQL allows only one auto increment column per table, and it must
3436 // be a KEY column.
35 if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
36 if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey {
37 delete(field.TagSettings, "AUTO_INCREMENT")
37 if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
38 if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey {
39 field.TagSettingsDelete("AUTO_INCREMENT")
3840 }
3941 }
4042
4446 sqlType = "boolean"
4547 case reflect.Int8:
4648 if s.fieldCanAutoIncrement(field) {
47 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
49 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
4850 sqlType = "tinyint AUTO_INCREMENT"
4951 } else {
5052 sqlType = "tinyint"
5153 }
5254 case reflect.Int, reflect.Int16, reflect.Int32:
5355 if s.fieldCanAutoIncrement(field) {
54 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
56 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
5557 sqlType = "int AUTO_INCREMENT"
5658 } else {
5759 sqlType = "int"
5860 }
5961 case reflect.Uint8:
6062 if s.fieldCanAutoIncrement(field) {
61 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
63 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
6264 sqlType = "tinyint unsigned AUTO_INCREMENT"
6365 } else {
6466 sqlType = "tinyint unsigned"
6567 }
6668 case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
6769 if s.fieldCanAutoIncrement(field) {
68 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
70 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
6971 sqlType = "int unsigned AUTO_INCREMENT"
7072 } else {
7173 sqlType = "int unsigned"
7274 }
7375 case reflect.Int64:
7476 if s.fieldCanAutoIncrement(field) {
75 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
77 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
7678 sqlType = "bigint AUTO_INCREMENT"
7779 } else {
7880 sqlType = "bigint"
7981 }
8082 case reflect.Uint64:
8183 if s.fieldCanAutoIncrement(field) {
82 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
84 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
8385 sqlType = "bigint unsigned AUTO_INCREMENT"
8486 } else {
8587 sqlType = "bigint unsigned"
9597 case reflect.Struct:
9698 if _, ok := dataValue.Interface().(time.Time); ok {
9799 precision := ""
98 if p, ok := field.TagSettings["PRECISION"]; ok {
100 if p, ok := field.TagSettingsGet("PRECISION"); ok {
99101 precision = fmt.Sprintf("(%s)", p)
100102 }
101103
102 if _, ok := field.TagSettings["NOT NULL"]; ok {
103 sqlType = fmt.Sprintf("timestamp%v", precision)
104 if _, ok := field.TagSettings["NOT NULL"]; ok || field.IsPrimaryKey {
105 sqlType = fmt.Sprintf("DATETIME%v", precision)
104106 } else {
105 sqlType = fmt.Sprintf("timestamp%v NULL", precision)
107 sqlType = fmt.Sprintf("DATETIME%v NULL", precision)
106108 }
107109 }
108110 default:
117119 }
118120
119121 if sqlType == "" {
120 panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
122 panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name))
121123 }
122124
123125 if strings.TrimSpace(additionalType) == "" {
136138 return err
137139 }
138140
139 func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
141 func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
140142 if limit != nil {
141 if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
143 parsedLimit, err := s.parseInt(limit)
144 if err != nil {
145 return "", err
146 }
147 if parsedLimit >= 0 {
142148 sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
143149
144150 if offset != nil {
145 if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
151 parsedOffset, err := s.parseInt(offset)
152 if err != nil {
153 return "", err
154 }
155 if parsedOffset >= 0 {
146156 sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
147157 }
148158 }
156166 currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
157167 s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
158168 return count > 0
169 }
170
171 func (s mysql) HasTable(tableName string) bool {
172 currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
173 var name string
174 // allow mysql database name with '-' character
175 if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil {
176 if err == sql.ErrNoRows {
177 return false
178 }
179 panic(err)
180 } else {
181 return true
182 }
183 }
184
185 func (s mysql) HasIndex(tableName string, indexName string) bool {
186 currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
187 if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil {
188 panic(err)
189 } else {
190 defer rows.Close()
191 return rows.Next()
192 }
193 }
194
195 func (s mysql) HasColumn(tableName string, columnName string) bool {
196 currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
197 if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil {
198 panic(err)
199 } else {
200 defer rows.Close()
201 return rows.Next()
202 }
159203 }
160204
161205 func (s mysql) CurrentDatabase() (name string) {
177221 bs := h.Sum(nil)
178222
179223 // sha1 is 40 characters, keep first 24 characters of destination
180 destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_"))
224 destRunes := []rune(keyNameRegex.ReplaceAllString(fields[0], "_"))
181225 if len(destRunes) > 24 {
182226 destRunes = destRunes[:24]
183227 }
185229 return fmt.Sprintf("%s%x", string(destRunes), bs)
186230 }
187231
232 // NormalizeIndexAndColumn returns index name and column name for specify an index prefix length if needed
233 func (mysql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
234 submatch := mysqlIndexRegex.FindStringSubmatch(indexName)
235 if len(submatch) != 3 {
236 return indexName, columnName
237 }
238 indexName = submatch[1]
239 columnName = fmt.Sprintf("%s(%s)", columnName, submatch[2])
240 return indexName, columnName
241 }
242
188243 func (mysql) DefaultValueStr() string {
189244 return "VALUES()"
190245 }
3333 sqlType = "boolean"
3434 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
3535 if s.fieldCanAutoIncrement(field) {
36 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
36 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
3737 sqlType = "serial"
3838 } else {
3939 sqlType = "integer"
4040 }
4141 case reflect.Int64, reflect.Uint32, reflect.Uint64:
4242 if s.fieldCanAutoIncrement(field) {
43 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
43 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
4444 sqlType = "bigserial"
4545 } else {
4646 sqlType = "bigint"
4848 case reflect.Float32, reflect.Float64:
4949 sqlType = "numeric"
5050 case reflect.String:
51 if _, ok := field.TagSettings["SIZE"]; !ok {
51 if _, ok := field.TagSettingsGet("SIZE"); !ok {
5252 size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
5353 }
5454
119119 return
120120 }
121121
122 func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string {
123 return ""
124 }
125
122126 func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
123127 return fmt.Sprintf("RETURNING %v.%v", tableName, key)
124128 }
2828 sqlType = "bool"
2929 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
3030 if s.fieldCanAutoIncrement(field) {
31 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
31 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
3232 sqlType = "integer primary key autoincrement"
3333 } else {
3434 sqlType = "integer"
3535 }
3636 case reflect.Int64, reflect.Uint64:
3737 if s.fieldCanAutoIncrement(field) {
38 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
38 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
3939 sqlType = "integer primary key autoincrement"
4040 } else {
4141 sqlType = "bigint"
00 package mssql
11
22 import (
3 "database/sql/driver"
4 "encoding/json"
5 "errors"
36 "fmt"
47 "reflect"
58 "strconv"
69 "strings"
710 "time"
811
12 // Importing mssql driver package only in dialect file, otherwide not needed
913 _ "github.com/denisenkom/go-mssqldb"
1014 "github.com/jinzhu/gorm"
1115 )
1317 func setIdentityInsert(scope *gorm.Scope) {
1418 if scope.Dialect().GetName() == "mssql" {
1519 for _, field := range scope.PrimaryFields() {
16 if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank {
20 if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank {
1721 scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
1822 scope.InstanceSet("mssql:identity_insert_on", true)
1923 }
6569 sqlType = "bit"
6670 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
6771 if s.fieldCanAutoIncrement(field) {
68 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
72 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
6973 sqlType = "int IDENTITY(1,1)"
7074 } else {
7175 sqlType = "int"
7276 }
7377 case reflect.Int64, reflect.Uint64:
7478 if s.fieldCanAutoIncrement(field) {
75 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
79 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
7680 sqlType = "bigint IDENTITY(1,1)"
7781 } else {
7882 sqlType = "bigint"
111115 }
112116
113117 func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
114 if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
118 if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
115119 return value != "FALSE"
116120 }
117121 return field.IsPrimaryKey
129133 }
130134
131135 func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
132 return false
136 var count int
137 currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
138 s.db.QueryRow(`SELECT count(*)
139 FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id
140 inner join information_schema.tables as I on I.TABLE_NAME = T.name
141 WHERE F.name = ?
142 AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count)
143 return count > 0
133144 }
134145
135146 func (s mssql) HasTable(tableName string) bool {
156167 return
157168 }
158169
159 func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
170 func parseInt(value interface{}) (int64, error) {
171 return strconv.ParseInt(fmt.Sprint(value), 0, 0)
172 }
173
174 func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
160175 if offset != nil {
161 if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
176 if parsedOffset, err := parseInt(offset); err != nil {
177 return "", err
178 } else if parsedOffset >= 0 {
162179 sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
163180 }
164181 }
165182 if limit != nil {
166 if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
183 if parsedLimit, err := parseInt(limit); err != nil {
184 return "", err
185 } else if parsedLimit >= 0 {
167186 if sql == "" {
168187 // add default zero offset
169188 sql += " OFFSET 0 ROWS"
178197 return ""
179198 }
180199
200 func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
201 if len(columns) == 0 {
202 // No OUTPUT to query
203 return ""
204 }
205 return fmt.Sprintf("OUTPUT Inserted.%v", columnName)
206 }
207
181208 func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
182 return ""
209 // https://stackoverflow.com/questions/5228780/how-to-get-last-inserted-id
210 return "; SELECT SCOPE_IDENTITY()"
183211 }
184212
185213 func (mssql) DefaultValueStr() string {
186214 return "DEFAULT VALUES"
215 }
216
217 // NormalizeIndexAndColumn returns argument's index name and column name without doing anything
218 func (mssql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
219 return indexName, columnName
187220 }
188221
189222 func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
193226 }
194227 return dialect.CurrentDatabase(), tableName
195228 }
229
230 // JSON type to support easy handling of JSON data in character table fields
231 // using golang json.RawMessage for deferred decoding/encoding
232 type JSON struct {
233 json.RawMessage
234 }
235
236 // Value get value of JSON
237 func (j JSON) Value() (driver.Value, error) {
238 if len(j.RawMessage) == 0 {
239 return nil, nil
240 }
241 return j.MarshalJSON()
242 }
243
244 // Scan scan value into JSON
245 func (j *JSON) Scan(value interface{}) error {
246 str, ok := value.(string)
247 if !ok {
248 return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value))
249 }
250 bytes := []byte(str)
251 return json.Unmarshal(bytes, j)
252 }
33 "database/sql"
44 "database/sql/driver"
55
6 _ "github.com/lib/pq"
7 "github.com/lib/pq/hstore"
86 "encoding/json"
97 "errors"
108 "fmt"
9
10 _ "github.com/lib/pq"
11 "github.com/lib/pq/hstore"
1112 )
1213
1314 type Hstore map[string]*string
55 )
66
77 var (
8 // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct
8 // ErrRecordNotFound returns a "record not found error". Occurs only when attempting to query the database with a struct; querying with a slice won't return this error
99 ErrRecordNotFound = errors.New("record not found")
10 // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
10 // ErrInvalidSQL occurs when you attempt a query with invalid SQL
1111 ErrInvalidSQL = errors.New("invalid SQL")
12 // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
12 // ErrInvalidTransaction occurs when you are trying to `Commit` or `Rollback`
1313 ErrInvalidTransaction = errors.New("no valid transaction")
1414 // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
1515 ErrCantStartTransaction = errors.New("can't start transaction")
2020 // Errors contains all happened errors
2121 type Errors []error
2222
23 // IsRecordNotFoundError returns current error has record not found error or not
23 // IsRecordNotFoundError returns true if error contains a RecordNotFound error
2424 func IsRecordNotFoundError(err error) bool {
2525 if errs, ok := err.(Errors); ok {
2626 for _, err := range errs {
3232 return err == ErrRecordNotFound
3333 }
3434
35 // GetErrors gets all happened errors
35 // GetErrors gets all errors that have occurred and returns a slice of errors (Error type)
3636 func (errs Errors) GetErrors() []error {
3737 return errs
3838 }
3939
40 // Add adds an error
40 // Add adds an error to a given slice of errors
4141 func (errs Errors) Add(newErrors ...error) Errors {
4242 for _, err := range newErrors {
4343 if err == nil {
6161 return errs
6262 }
6363
64 // Error format happened errors
64 // Error takes a slice of all errors that have occurred and returns it as a formatted string
6565 func (errs Errors) Error() string {
6666 var errors = []string{}
6767 for _, e := range errs {
11
22 import (
33 "database/sql"
4 "database/sql/driver"
45 "errors"
56 "fmt"
67 "reflect"
4344 if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
4445 fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
4546 } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
46 err = scanner.Scan(reflectValue.Interface())
47 v := reflectValue.Interface()
48 if valuer, ok := v.(driver.Valuer); ok {
49 if v, err = valuer.Value(); err == nil {
50 err = scanner.Scan(v)
51 }
52 } else {
53 err = scanner.Scan(v)
54 }
4755 } else {
4856 err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
4957 }
00 package gorm_test
11
22 import (
3 "database/sql/driver"
4 "encoding/hex"
5 "fmt"
36 "testing"
47
58 "github.com/jinzhu/gorm"
4245
4346 if field, ok := scope.FieldByName("embedded_name"); !ok {
4447 t.Errorf("should find embedded field")
45 } else if _, ok := field.TagSettings["NOT NULL"]; !ok {
48 } else if _, ok := field.TagSettingsGet("NOT NULL"); !ok {
4649 t.Errorf("should find embedded field's tag settings")
4750 }
4851 }
52
53 type UUID [16]byte
54
55 type NullUUID struct {
56 UUID
57 Valid bool
58 }
59
60 func FromString(input string) (u UUID) {
61 src := []byte(input)
62 return FromBytes(src)
63 }
64
65 func FromBytes(src []byte) (u UUID) {
66 dst := u[:]
67 hex.Decode(dst[0:4], src[0:8])
68 hex.Decode(dst[4:6], src[9:13])
69 hex.Decode(dst[6:8], src[14:18])
70 hex.Decode(dst[8:10], src[19:23])
71 hex.Decode(dst[10:], src[24:])
72 return
73 }
74
75 func (u UUID) String() string {
76 buf := make([]byte, 36)
77 src := u[:]
78 hex.Encode(buf[0:8], src[0:4])
79 buf[8] = '-'
80 hex.Encode(buf[9:13], src[4:6])
81 buf[13] = '-'
82 hex.Encode(buf[14:18], src[6:8])
83 buf[18] = '-'
84 hex.Encode(buf[19:23], src[8:10])
85 buf[23] = '-'
86 hex.Encode(buf[24:], src[10:])
87 return string(buf)
88 }
89
90 func (u UUID) Value() (driver.Value, error) {
91 return u.String(), nil
92 }
93
94 func (u *UUID) Scan(src interface{}) error {
95 switch src := src.(type) {
96 case UUID: // support gorm convert from UUID to NullUUID
97 *u = src
98 return nil
99 case []byte:
100 *u = FromBytes(src)
101 return nil
102 case string:
103 *u = FromString(src)
104 return nil
105 }
106 return fmt.Errorf("uuid: cannot convert %T to UUID", src)
107 }
108
109 func (u *NullUUID) Scan(src interface{}) error {
110 u.Valid = true
111 return u.UUID.Scan(src)
112 }
113
114 func TestFieldSet(t *testing.T) {
115 type TestFieldSetNullUUID struct {
116 NullUUID NullUUID
117 }
118 scope := DB.NewScope(&TestFieldSetNullUUID{})
119 field := scope.Fields()[0]
120 err := field.Set(FromString("3034d44a-da03-11e8-b366-4a00070b9f00"))
121 if err != nil {
122 t.Fatal(err)
123 }
124 if id, ok := field.Field.Addr().Interface().(*NullUUID); !ok {
125 t.Fatal()
126 } else if !id.Valid || id.UUID.String() != "3034d44a-da03-11e8-b366-4a00070b9f00" {
127 t.Fatal(id)
128 }
129 }
0 module github.com/jinzhu/gorm
1
2 go 1.12
3
4 require (
5 github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd
6 github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5
7 github.com/go-sql-driver/mysql v1.5.0
8 github.com/jinzhu/inflection v1.0.0
9 github.com/jinzhu/now v1.0.1
10 github.com/lib/pq v1.1.1
11 github.com/mattn/go-sqlite3 v1.14.0
12 golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd // indirect
13 )
0 github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc=
1 github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y=
2 github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd h1:83Wprp6ROGeiHFAP8WJdI2RoxALQYgdllERc3N5N2DM=
3 github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
4 github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y=
5 github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0=
6 github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
7 github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
8 github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
9 github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
10 github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
11 github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
12 github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M=
13 github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
14 github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4=
15 github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
16 github.com/mattn/go-sqlite3 v1.14.0 h1:mLyGNKR8+Vv9CAU7PphKa2hkEqxxhn8i32J6FPj1/QA=
17 github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus=
18 github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw=
19 github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
20 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
21 golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI=
22 golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
23 golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd h1:GGJVjV8waZKRHrgwvtH66z9ZGVurTD1MT0n1Bb+q4aM=
24 golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
25 golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
26 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
27 golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
28 golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
29 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
30 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
31 golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
32 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
00 package gorm
11
2 import "database/sql"
2 import (
3 "context"
4 "database/sql"
5 )
36
47 // SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
58 type SQLCommon interface {
1114
1215 type sqlDb interface {
1316 Begin() (*sql.Tx, error)
17 BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
1418 }
1519
1620 type sqlTx interface {
3838
3939 messages = []interface{}{source, currentTime}
4040
41 if len(values) == 2 {
42 //remove the line break
43 currentTime = currentTime[1:]
44 //remove the brackets
45 source = fmt.Sprintf("\033[35m%v\033[0m", values[1])
46
47 messages = []interface{}{currentTime, source}
48 }
49
4150 if level == "sql" {
4251 // duration
4352 messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
4857 if indirectValue.IsValid() {
4958 value = indirectValue.Interface()
5059 if t, ok := value.(time.Time); ok {
51 formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05")))
60 if t.IsZero() {
61 formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00"))
62 } else {
63 formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05")))
64 }
5265 } else if b, ok := value.([]byte); ok {
5366 if str := string(b); isPrintable(str) {
5467 formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
6275 formattedValues = append(formattedValues, "NULL")
6376 }
6477 } else {
65 formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
78 switch value.(type) {
79 case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
80 formattedValues = append(formattedValues, fmt.Sprintf("%v", value))
81 default:
82 formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
83 }
6684 }
6785 } else {
6886 formattedValues = append(formattedValues, "NULL")
116134 func (logger Logger) Print(values ...interface{}) {
117135 logger.Println(LogFormatter(values...)...)
118136 }
137
138 type nopLogger struct{}
139
140 func (nopLogger) Print(values ...interface{}) {}
00 package gorm
11
22 import (
3 "context"
34 "database/sql"
45 "errors"
56 "fmt"
67 "reflect"
78 "strings"
9 "sync"
810 "time"
911 )
1012
1113 // DB contains information for current db connection
1214 type DB struct {
15 sync.RWMutex
1316 Value interface{}
1417 Error error
1518 RowsAffected int64
1720 // single db
1821 db SQLCommon
1922 blockGlobalUpdate bool
20 logMode int
23 logMode logModeValue
2124 logger logger
2225 search *search
23 values map[string]interface{}
26 values sync.Map
2427
2528 // global db
2629 parent *DB
2730 callbacks *Callback
2831 dialect Dialect
2932 singularTable bool
30 }
33
34 // function to be used to override the creating of a new timestamp
35 nowFuncOverride func() time.Time
36 }
37
38 type logModeValue int
39
40 const (
41 defaultLogMode logModeValue = iota
42 noLogMode
43 detailedLogMode
44 )
3145
3246 // Open initialize a new db connection, need to import driver first, e.g:
3347 //
4761 }
4862 var source string
4963 var dbSQL SQLCommon
64 var ownDbSQL bool
5065
5166 switch value := args[0].(type) {
5267 case string:
5873 source = args[1].(string)
5974 }
6075 dbSQL, err = sql.Open(driver, source)
76 ownDbSQL = true
6177 case SQLCommon:
6278 dbSQL = value
79 ownDbSQL = false
80 default:
81 return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
6382 }
6483
6584 db = &DB{
6685 db: dbSQL,
6786 logger: defaultLogger,
68 values: map[string]interface{}{},
6987 callbacks: DefaultCallback,
7088 dialect: newDialect(dialect, dbSQL),
7189 }
7593 }
7694 // Send a ping to make sure the database connection is alive.
7795 if d, ok := dbSQL.(*sql.DB); ok {
78 if err = d.Ping(); err != nil {
96 if err = d.Ping(); err != nil && ownDbSQL {
7997 d.Close()
8098 }
8199 }
105123 // DB get `*sql.DB` from current connection
106124 // If the underlying database connection is not a *sql.DB, returns nil
107125 func (s *DB) DB() *sql.DB {
108 db, _ := s.db.(*sql.DB)
126 db, ok := s.db.(*sql.DB)
127 if !ok {
128 panic("can't support full GORM on currently status, maybe this is a TX instance.")
129 }
109130 return db
110131 }
111132
116137
117138 // Dialect get dialect
118139 func (s *DB) Dialect() Dialect {
119 return s.parent.dialect
140 return s.dialect
120141 }
121142
122143 // Callback return `Callbacks` container, you could add/change/delete callbacks with it
123144 // db.Callback().Create().Register("update_created_at", updateCreated)
124145 // Refer https://jinzhu.github.io/gorm/development.html#callbacks
125146 func (s *DB) Callback() *Callback {
126 s.parent.callbacks = s.parent.callbacks.clone()
147 s.parent.callbacks = s.parent.callbacks.clone(s.logger)
127148 return s.parent.callbacks
128149 }
129150
135156 // LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
136157 func (s *DB) LogMode(enable bool) *DB {
137158 if enable {
138 s.logMode = 2
139 } else {
140 s.logMode = 1
159 s.logMode = detailedLogMode
160 } else {
161 s.logMode = noLogMode
141162 }
142163 return s
164 }
165
166 // SetNowFuncOverride set the function to be used when creating a new timestamp
167 func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB {
168 s.nowFuncOverride = nowFuncOverride
169 return s
170 }
171
172 // Get a new timestamp, using the provided nowFuncOverride on the DB instance if set,
173 // otherwise defaults to the global NowFunc()
174 func (s *DB) nowFunc() time.Time {
175 if s.nowFuncOverride != nil {
176 return s.nowFuncOverride()
177 }
178
179 return NowFunc()
143180 }
144181
145182 // BlockGlobalUpdate if true, generates an error on update/delete without where clause.
156193
157194 // SingularTable use singular table by default
158195 func (s *DB) SingularTable(enable bool) {
159 modelStructsMap = newModelStructsMap()
196 s.parent.Lock()
197 defer s.parent.Unlock()
160198 s.parent.singularTable = enable
161199 }
162200
164202 func (s *DB) NewScope(value interface{}) *Scope {
165203 dbClone := s.clone()
166204 dbClone.Value = value
167 return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
168 }
169
170 // QueryExpr returns the query as expr object
171 func (s *DB) QueryExpr() *expr {
205 scope := &Scope{db: dbClone, Value: value}
206 if s.search != nil {
207 scope.Search = s.search.clone()
208 } else {
209 scope.Search = &search{}
210 }
211 return scope
212 }
213
214 // QueryExpr returns the query as SqlExpr object
215 func (s *DB) QueryExpr() *SqlExpr {
172216 scope := s.NewScope(s.Value)
173217 scope.InstanceSet("skip_bindvar", true)
174218 scope.prepareQuerySQL()
177221 }
178222
179223 // SubQuery returns the query as sub query
180 func (s *DB) SubQuery() *expr {
224 func (s *DB) SubQuery() *SqlExpr {
181225 scope := s.NewScope(s.Value)
182226 scope.InstanceSet("skip_bindvar", true)
183227 scope.prepareQuerySQL()
284328 func (s *DB) First(out interface{}, where ...interface{}) *DB {
285329 newScope := s.NewScope(out)
286330 newScope.Search.Limit(1)
331
287332 return newScope.Set("gorm:order_by_primary_key", "ASC").
288333 inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
289334 }
306351 // Find find records that match given conditions
307352 func (s *DB) Find(out interface{}, where ...interface{}) *DB {
308353 return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
354 }
355
356 //Preloads preloads relations, don`t touch out
357 func (s *DB) Preloads(out interface{}) *DB {
358 return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db
309359 }
310360
311361 // Scan scan value to a struct
386436 }
387437
388438 // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
439 // WARNING when update with struct, GORM will not update fields that with zero value
389440 func (s *DB) Update(attrs ...interface{}) *DB {
390441 return s.Updates(toSearchableMap(attrs...), true)
391442 }
418469 if !scope.PrimaryKeyZero() {
419470 newDB := scope.callCallbacks(s.parent.callbacks.updates).db
420471 if newDB.Error == nil && newDB.RowsAffected == 0 {
421 return s.New().FirstOrCreate(value)
472 return s.New().Table(scope.TableName()).FirstOrCreate(value)
422473 }
423474 return newDB
424475 }
432483 }
433484
434485 // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
486 // WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time
435487 func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
436488 return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
437489 }
475527 return s.clone().LogMode(true)
476528 }
477529
478 // Begin begin a transaction
530 // Transaction start a transaction as a block,
531 // return error will rollback, otherwise to commit.
532 func (s *DB) Transaction(fc func(tx *DB) error) (err error) {
533
534 if _, ok := s.db.(*sql.Tx); ok {
535 return fc(s)
536 }
537
538 panicked := true
539 tx := s.Begin()
540 defer func() {
541 // Make sure to rollback when panic, Block error or Commit error
542 if panicked || err != nil {
543 tx.Rollback()
544 }
545 }()
546
547 err = fc(tx)
548
549 if err == nil {
550 err = tx.Commit().Error
551 }
552
553 panicked = false
554 return
555 }
556
557 // Begin begins a transaction
479558 func (s *DB) Begin() *DB {
559 return s.BeginTx(context.Background(), &sql.TxOptions{})
560 }
561
562 // BeginTx begins a transaction with options
563 func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
480564 c := s.clone()
481565 if db, ok := c.db.(sqlDb); ok && db != nil {
482 tx, err := db.Begin()
566 tx, err := db.BeginTx(ctx, opts)
483567 c.db = interface{}(tx).(SQLCommon)
568
569 c.dialect.SetDB(c.db)
484570 c.AddError(err)
485571 } else {
486572 c.AddError(ErrCantStartTransaction)
490576
491577 // Commit commit a transaction
492578 func (s *DB) Commit() *DB {
493 if db, ok := s.db.(sqlTx); ok && db != nil {
579 var emptySQLTx *sql.Tx
580 if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
494581 s.AddError(db.Commit())
495582 } else {
496583 s.AddError(ErrInvalidTransaction)
500587
501588 // Rollback rollback a transaction
502589 func (s *DB) Rollback() *DB {
503 if db, ok := s.db.(sqlTx); ok && db != nil {
504 s.AddError(db.Rollback())
590 var emptySQLTx *sql.Tx
591 if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
592 if err := db.Rollback(); err != nil && err != sql.ErrTxDone {
593 s.AddError(err)
594 }
595 } else {
596 s.AddError(ErrInvalidTransaction)
597 }
598 return s
599 }
600
601 // RollbackUnlessCommitted rollback a transaction if it has not yet been
602 // committed.
603 func (s *DB) RollbackUnlessCommitted() *DB {
604 var emptySQLTx *sql.Tx
605 if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
606 err := db.Rollback()
607 // Ignore the error indicating that the transaction has already
608 // been committed.
609 if err != sql.ErrTxDone {
610 s.AddError(err)
611 }
505612 } else {
506613 s.AddError(ErrInvalidTransaction)
507614 }
669776
670777 // InstantSet instant set setting, will affect current db
671778 func (s *DB) InstantSet(name string, value interface{}) *DB {
672 s.values[name] = value
779 s.values.Store(name, value)
673780 return s
674781 }
675782
676783 // Get get setting by name
677784 func (s *DB) Get(name string) (value interface{}, ok bool) {
678 value, ok = s.values[name]
785 value, ok = s.values.Load(name)
679786 return
680787 }
681788
684791 scope := s.NewScope(source)
685792 for _, field := range scope.GetModelStruct().StructFields {
686793 if field.Name == column || field.DBName == column {
687 if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
794 if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" {
688795 source := (&Scope{Value: source}).GetModelStruct().ModelType
689796 destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
690797 handler.Setup(field.Relationship, many2many, source, destination)
701808 func (s *DB) AddError(err error) error {
702809 if err != nil {
703810 if err != ErrRecordNotFound {
704 if s.logMode == 0 {
705 go s.print(fileWithLineNum(), err)
811 if s.logMode == defaultLogMode {
812 go s.print("error", fileWithLineNum(), err)
706813 } else {
707814 s.log(err)
708815 }
739846 parent: s.parent,
740847 logger: s.logger,
741848 logMode: s.logMode,
742 values: map[string]interface{}{},
743849 Value: s.Value,
744850 Error: s.Error,
745851 blockGlobalUpdate: s.blockGlobalUpdate,
746 }
747
748 for key, value := range s.values {
749 db.values[key] = value
750 }
852 dialect: newDialect(s.dialect.GetName(), s.db),
853 nowFuncOverride: s.nowFuncOverride,
854 }
855
856 s.values.Range(func(k, v interface{}) bool {
857 db.values.Store(k, v)
858 return true
859 })
751860
752861 if s.search == nil {
753862 db.search = &search{limit: -1, offset: -1}
764873 }
765874
766875 func (s *DB) log(v ...interface{}) {
767 if s != nil && s.logMode == 2 {
876 if s != nil && s.logMode == detailedLogMode {
768877 s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
769878 }
770879 }
771880
772881 func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
773 if s.logMode == 2 {
882 if s.logMode == detailedLogMode {
774883 s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected)
775884 }
776885 }
00 package gorm_test
11
2 // Run tests
3 // $ docker-compose up
4 // $ ./test_all.sh
5
26 import (
7 "context"
38 "database/sql"
49 "database/sql/driver"
10 "errors"
511 "fmt"
612 "os"
713 "path/filepath"
814 "reflect"
15 "sort"
916 "strconv"
17 "strings"
18 "sync"
1019 "testing"
1120 "time"
1221
4655 case "postgres":
4756 fmt.Println("testing postgres...")
4857 if dbDSN == "" {
49 dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
58 dbDSN = "user=gorm password=gorm dbname=gorm port=9920 sslmode=disable"
5059 }
5160 db, err = gorm.Open("postgres", dbDSN)
5261 case "mssql":
7887 return
7988 }
8089
90 func TestOpen_ReturnsError_WithBadArgs(t *testing.T) {
91 stringRef := "foo"
92 testCases := []interface{}{42, time.Now(), &stringRef}
93 for _, tc := range testCases {
94 t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) {
95 _, err := gorm.Open("postgresql", tc)
96 if err == nil {
97 t.Error("Should got error with invalid database source")
98 }
99 if !strings.HasPrefix(err.Error(), "invalid database source:") {
100 t.Errorf("Should got error starting with \"invalid database source:\", but got %q", err.Error())
101 }
102 })
103 }
104 }
105
81106 func TestStringPrimaryKey(t *testing.T) {
82107 type UUIDStruct struct {
83108 ID string `gorm:"primary_key"`
156181 DB.Table("deleted_users").Find(&deletedUsers)
157182 if len(deletedUsers) != 1 {
158183 t.Errorf("Query from specified table")
184 }
185
186 var user User
187 DB.Table("deleted_users").First(&user, "name = ?", "DeletedUser")
188
189 user.Age = 20
190 DB.Table("deleted_users").Save(&user)
191 if DB.Table("deleted_users").First(&user, "name = ? AND age = ?", "DeletedUser", 20).RecordNotFound() {
192 t.Errorf("Failed to found updated user")
159193 }
160194
161195 DB.Save(getPreparedUser("normal_user", "reset_table"))
256290 if DB.NewScope([]Cart{}).TableName() != "shopping_cart" {
257291 t.Errorf("[]Cart's singular table name should be shopping_cart")
258292 }
293 DB.SingularTable(false)
294 }
295
296 func TestTableNameConcurrently(t *testing.T) {
297 DB := DB.Model("")
298 if DB.NewScope(Order{}).TableName() != "orders" {
299 t.Errorf("Order's table name should be orders")
300 }
301
302 var wg sync.WaitGroup
303 wg.Add(10)
304
305 for i := 1; i <= 10; i++ {
306 go func(db *gorm.DB) {
307 DB.SingularTable(true)
308 wg.Done()
309 }(DB)
310 }
311 wg.Wait()
312
313 if DB.NewScope(Order{}).TableName() != "order" {
314 t.Errorf("Order's singular table name should be order")
315 }
316
259317 DB.SingularTable(false)
260318 }
261319
376434 if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
377435 t.Errorf("Should be able to find committed record")
378436 }
437
438 tx3 := DB.Begin()
439 u3 := User{Name: "transcation-3"}
440 if err := tx3.Save(&u3).Error; err != nil {
441 t.Errorf("No error should raise")
442 }
443
444 if err := tx3.First(&User{}, "name = ?", "transcation-3").Error; err != nil {
445 t.Errorf("Should find saved record")
446 }
447
448 tx3.RollbackUnlessCommitted()
449
450 if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {
451 t.Errorf("Should not find record after rollback")
452 }
453
454 tx4 := DB.Begin()
455 u4 := User{Name: "transcation-4"}
456 if err := tx4.Save(&u4).Error; err != nil {
457 t.Errorf("No error should raise")
458 }
459
460 if err := tx4.First(&User{}, "name = ?", "transcation-4").Error; err != nil {
461 t.Errorf("Should find saved record")
462 }
463
464 tx4.Commit()
465
466 tx4.RollbackUnlessCommitted()
467
468 if err := DB.First(&User{}, "name = ?", "transcation-4").Error; err != nil {
469 t.Errorf("Should be able to find committed record")
470 }
471 }
472
473 func assertPanic(t *testing.T, f func()) {
474 defer func() {
475 if r := recover(); r == nil {
476 t.Errorf("The code did not panic")
477 }
478 }()
479 f()
480 }
481
482 func TestTransactionWithBlock(t *testing.T) {
483 // rollback
484 err := DB.Transaction(func(tx *gorm.DB) error {
485 u := User{Name: "transcation"}
486 if err := tx.Save(&u).Error; err != nil {
487 t.Errorf("No error should raise")
488 }
489
490 if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
491 t.Errorf("Should find saved record")
492 }
493
494 return errors.New("the error message")
495 })
496
497 if err.Error() != "the error message" {
498 t.Errorf("Transaction return error will equal the block returns error")
499 }
500
501 if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil {
502 t.Errorf("Should not find record after rollback")
503 }
504
505 // commit
506 DB.Transaction(func(tx *gorm.DB) error {
507 u2 := User{Name: "transcation-2"}
508 if err := tx.Save(&u2).Error; err != nil {
509 t.Errorf("No error should raise")
510 }
511
512 if err := tx.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
513 t.Errorf("Should find saved record")
514 }
515 return nil
516 })
517
518 if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
519 t.Errorf("Should be able to find committed record")
520 }
521
522 // panic will rollback
523 assertPanic(t, func() {
524 DB.Transaction(func(tx *gorm.DB) error {
525 u3 := User{Name: "transcation-3"}
526 if err := tx.Save(&u3).Error; err != nil {
527 t.Errorf("No error should raise")
528 }
529
530 if err := tx.First(&User{}, "name = ?", "transcation-3").Error; err != nil {
531 t.Errorf("Should find saved record")
532 }
533
534 panic("force panic")
535 })
536 })
537
538 if err := DB.First(&User{}, "name = ?", "transcation-3").Error; err == nil {
539 t.Errorf("Should not find record after panic rollback")
540 }
541 }
542
543 func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) {
544 tx := DB.Begin()
545 u := User{Name: "transcation"}
546 if err := tx.Save(&u).Error; err != nil {
547 t.Errorf("No error should raise")
548 }
549
550 if err := tx.Commit().Error; err != nil {
551 t.Errorf("Commit should not raise error")
552 }
553
554 if err := tx.Rollback().Error; err != nil {
555 t.Errorf("Rollback should not raise error")
556 }
557 }
558
559 func TestTransactionReadonly(t *testing.T) {
560 dialect := os.Getenv("GORM_DIALECT")
561 if dialect == "" {
562 dialect = "sqlite"
563 }
564 switch dialect {
565 case "mssql", "sqlite":
566 t.Skipf("%s does not support readonly transactions\n", dialect)
567 }
568
569 tx := DB.Begin()
570 u := User{Name: "transcation"}
571 if err := tx.Save(&u).Error; err != nil {
572 t.Errorf("No error should raise")
573 }
574 tx.Commit()
575
576 tx = DB.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
577 if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
578 t.Errorf("Should find saved record")
579 }
580
581 if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil {
582 t.Errorf("Should return the underlying sql.Tx")
583 }
584
585 u = User{Name: "transcation-2"}
586 if err := tx.Save(&u).Error; err == nil {
587 t.Errorf("Error should have been raised in a readonly transaction")
588 }
589
590 tx.Rollback()
379591 }
380592
381593 func TestRow(t *testing.T) {
563775 }
564776 }
565777
778 type JoinedIds struct {
779 UserID int64 `gorm:"column:id"`
780 BillingAddressID int64 `gorm:"column:id"`
781 EmailID int64 `gorm:"column:id"`
782 }
783
784 func TestScanIdenticalColumnNames(t *testing.T) {
785 var user = User{
786 Name: "joinsIds",
787 Email: "joinIds@example.com",
788 BillingAddress: Address{
789 Address1: "One Park Place",
790 },
791 Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
792 }
793 DB.Save(&user)
794
795 var users []JoinedIds
796 DB.Select("users.id, addresses.id, emails.id").Table("users").
797 Joins("left join addresses on users.billing_address_id = addresses.id").
798 Joins("left join emails on emails.user_id = users.id").
799 Where("name = ?", "joinsIds").Scan(&users)
800
801 if len(users) != 2 {
802 t.Fatal("should find two rows using left join")
803 }
804
805 if user.Id != users[0].UserID {
806 t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[0].UserID)
807 }
808 if user.Id != users[1].UserID {
809 t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[1].UserID)
810 }
811
812 if user.BillingAddressID.Int64 != users[0].BillingAddressID {
813 t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID)
814 }
815 if user.BillingAddressID.Int64 != users[1].BillingAddressID {
816 t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID)
817 }
818
819 if users[0].EmailID == users[1].EmailID {
820 t.Errorf("Email ids should be unique. Got %d and %d", users[0].EmailID, users[1].EmailID)
821 }
822
823 if int64(user.Emails[0].Id) != users[0].EmailID && int64(user.Emails[1].Id) != users[0].EmailID {
824 t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[0].EmailID)
825 }
826
827 if int64(user.Emails[0].Id) != users[1].EmailID && int64(user.Emails[1].Id) != users[1].EmailID {
828 t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[1].EmailID)
829 }
830 }
831
566832 func TestJoinsWithSelect(t *testing.T) {
567833 type result struct {
568834 Name string
577843
578844 var results []result
579845 DB.Table("users").Select("name, emails.email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results)
846
847 sort.Slice(results, func(i, j int) bool {
848 return strings.Compare(results[i].Email, results[j].Email) < 0
849 })
850
580851 if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" {
581852 t.Errorf("Should find all two emails with Join select")
582853 }
8611132 }
8621133 }
8631134
1135 func TestSaveAssociations(t *testing.T) {
1136 db := DB.New()
1137 deltaAddressCount := 0
1138 if err := db.Model(&Address{}).Count(&deltaAddressCount).Error; err != nil {
1139 t.Errorf("failed to fetch address count")
1140 t.FailNow()
1141 }
1142
1143 placeAddress := &Address{
1144 Address1: "somewhere on earth",
1145 }
1146 ownerAddress1 := &Address{
1147 Address1: "near place address",
1148 }
1149 ownerAddress2 := &Address{
1150 Address1: "address2",
1151 }
1152 db.Create(placeAddress)
1153
1154 addressCountShouldBe := func(t *testing.T, expectedCount int) {
1155 countFromDB := 0
1156 t.Helper()
1157 err := db.Model(&Address{}).Count(&countFromDB).Error
1158 if err != nil {
1159 t.Error("failed to fetch address count")
1160 }
1161 if countFromDB != expectedCount {
1162 t.Errorf("address count mismatch: %d", countFromDB)
1163 }
1164 }
1165 addressCountShouldBe(t, deltaAddressCount+1)
1166
1167 // owner address should be created, place address should be reused
1168 place1 := &Place{
1169 PlaceAddressID: placeAddress.ID,
1170 PlaceAddress: placeAddress,
1171 OwnerAddress: ownerAddress1,
1172 }
1173 err := db.Create(place1).Error
1174 if err != nil {
1175 t.Errorf("failed to store place: %s", err.Error())
1176 }
1177 addressCountShouldBe(t, deltaAddressCount+2)
1178
1179 // owner address should be created again, place address should be reused
1180 place2 := &Place{
1181 PlaceAddressID: placeAddress.ID,
1182 PlaceAddress: &Address{
1183 ID: 777,
1184 Address1: "address1",
1185 },
1186 OwnerAddress: ownerAddress2,
1187 OwnerAddressID: 778,
1188 }
1189 err = db.Create(place2).Error
1190 if err != nil {
1191 t.Errorf("failed to store place: %s", err.Error())
1192 }
1193 addressCountShouldBe(t, deltaAddressCount+3)
1194
1195 count := 0
1196 db.Model(&Place{}).Where(&Place{
1197 PlaceAddressID: placeAddress.ID,
1198 OwnerAddressID: ownerAddress1.ID,
1199 }).Count(&count)
1200 if count != 1 {
1201 t.Errorf("only one instance of (%d, %d) should be available, found: %d",
1202 placeAddress.ID, ownerAddress1.ID, count)
1203 }
1204
1205 db.Model(&Place{}).Where(&Place{
1206 PlaceAddressID: placeAddress.ID,
1207 OwnerAddressID: ownerAddress2.ID,
1208 }).Count(&count)
1209 if count != 1 {
1210 t.Errorf("only one instance of (%d, %d) should be available, found: %d",
1211 placeAddress.ID, ownerAddress2.ID, count)
1212 }
1213
1214 db.Model(&Place{}).Where(&Place{
1215 PlaceAddressID: placeAddress.ID,
1216 }).Count(&count)
1217 if count != 2 {
1218 t.Errorf("two instances of (%d) should be available, found: %d",
1219 placeAddress.ID, count)
1220 }
1221 }
1222
8641223 func TestBlockGlobalUpdate(t *testing.T) {
8651224 db := DB.New()
8661225 db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
8971256 if err != nil {
8981257 t.Error("Unexpected error on conditional delete")
8991258 }
1259 }
1260
1261 func TestCountWithHaving(t *testing.T) {
1262 db := DB.New()
1263 db.Delete(User{})
1264 defer db.Delete(User{})
1265
1266 DB.Create(getPreparedUser("user1", "pluck_user"))
1267 DB.Create(getPreparedUser("user2", "pluck_user"))
1268 user3 := getPreparedUser("user3", "pluck_user")
1269 user3.Languages = []Language{}
1270 DB.Create(user3)
1271
1272 var count int
1273 err := db.Model(User{}).Select("users.id").
1274 Joins("LEFT JOIN user_languages ON user_languages.user_id = users.id").
1275 Joins("LEFT JOIN languages ON user_languages.language_id = languages.id").
1276 Group("users.id").Having("COUNT(languages.id) > 1").Count(&count).Error
1277
1278 if err != nil {
1279 t.Error("Unexpected error on query count with having")
1280 }
1281
1282 if count != 2 {
1283 t.Error("Unexpected result on query count with having")
1284 }
1285 }
1286
1287 func TestPluck(t *testing.T) {
1288 db := DB.New()
1289 db.Delete(User{})
1290 defer db.Delete(User{})
1291
1292 DB.Create(&User{Id: 1, Name: "user1"})
1293 DB.Create(&User{Id: 2, Name: "user2"})
1294 DB.Create(&User{Id: 3, Name: "user3"})
1295
1296 var ids []int64
1297 err := db.Model(User{}).Order("id").Pluck("id", &ids).Error
1298
1299 if err != nil {
1300 t.Error("Unexpected error on pluck")
1301 }
1302
1303 if len(ids) != 3 || ids[0] != 1 || ids[1] != 2 || ids[2] != 3 {
1304 t.Error("Unexpected result on pluck")
1305 }
1306
1307 err = db.Model(User{}).Order("id").Pluck("id", &ids).Error
1308
1309 if err != nil {
1310 t.Error("Unexpected error on pluck again")
1311 }
1312
1313 if len(ids) != 3 || ids[0] != 1 || ids[1] != 2 || ids[2] != 3 {
1314 t.Error("Unexpected result on pluck again")
1315 }
1316 }
1317
1318 func TestCountWithQueryOption(t *testing.T) {
1319 db := DB.New()
1320 db.Delete(User{})
1321 defer db.Delete(User{})
1322
1323 DB.Create(&User{Name: "user1"})
1324 DB.Create(&User{Name: "user2"})
1325 DB.Create(&User{Name: "user3"})
1326
1327 var count int
1328 err := db.Model(User{}).Select("users.id").
1329 Set("gorm:query_option", "WHERE users.name='user2'").
1330 Count(&count).Error
1331
1332 if err != nil {
1333 t.Error("Unexpected error on query count with query_option")
1334 }
1335
1336 if count != 1 {
1337 t.Error("Unexpected result on query count with query_option")
1338 }
1339 }
1340
1341 func TestQueryHint1(t *testing.T) {
1342 db := DB.New()
1343
1344 _, err := db.Model(User{}).Raw("select 1").Rows()
1345
1346 if err != nil {
1347 t.Error("Unexpected error on query count with query_option")
1348 }
1349 }
1350
1351 func TestQueryHint2(t *testing.T) {
1352 type TestStruct struct {
1353 ID string `gorm:"primary_key"`
1354 Name string
1355 }
1356 DB.DropTable(&TestStruct{})
1357 DB.AutoMigrate(&TestStruct{})
1358
1359 data := TestStruct{ID: "uuid", Name: "hello"}
1360 if err := DB.Set("gorm:query_hint", "/*master*/").Save(&data).Error; err != nil {
1361 t.Error("Unexpected error on query count with query_option")
1362 }
1363 }
1364
1365 func TestFloatColumnPrecision(t *testing.T) {
1366 if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" && dialect != "sqlite" {
1367 t.Skip()
1368 }
1369
1370 type FloatTest struct {
1371 ID string `gorm:"primary_key"`
1372 FloatValue float64 `gorm:"column:float_value" sql:"type:float(255,5);"`
1373 }
1374 DB.DropTable(&FloatTest{})
1375 DB.AutoMigrate(&FloatTest{})
1376
1377 data := FloatTest{ID: "uuid", FloatValue: 112.57315}
1378 if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.FloatValue != 112.57315 {
1379 t.Errorf("Float value should not lose precision")
1380 }
1381 }
1382
1383 func TestWhereUpdates(t *testing.T) {
1384 type OwnerEntity struct {
1385 gorm.Model
1386 OwnerID uint
1387 OwnerType string
1388 }
1389
1390 type SomeEntity struct {
1391 gorm.Model
1392 Name string
1393 OwnerEntity OwnerEntity `gorm:"polymorphic:Owner"`
1394 }
1395
1396 DB.DropTable(&SomeEntity{})
1397 DB.AutoMigrate(&SomeEntity{})
1398
1399 a := SomeEntity{Name: "test"}
1400 DB.Model(&a).Where(a).Updates(SomeEntity{Name: "test2"})
9001401 }
9011402
9021403 func BenchmarkGorm(b *testing.B) {
117117 Owner *User `sql:"-"`
118118 }
119119
120 type Place struct {
121 Id int64
122 PlaceAddressID int
123 PlaceAddress *Address `gorm:"save_associations:false"`
124 OwnerAddressID int
125 OwnerAddress *Address `gorm:"save_associations:true"`
126 }
127
120128 type EncryptedData []byte
121129
122130 func (data *EncryptedData) Scan(value interface{}) error {
283291 DB.Exec(fmt.Sprintf("drop table %v;", table))
284292 }
285293
286 values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}}
294 values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}}
287295 for _, value := range values {
288296 DB.DropTable(value)
289297 }
397405 }
398406 }
399407
408 func TestCreateAndAutomigrateTransaction(t *testing.T) {
409 tx := DB.Begin()
410
411 func() {
412 type Bar struct {
413 ID uint
414 }
415 DB.DropTableIfExists(&Bar{})
416
417 if ok := DB.HasTable("bars"); ok {
418 t.Errorf("Table should not exist, but does")
419 }
420
421 if ok := tx.HasTable("bars"); ok {
422 t.Errorf("Table should not exist, but does")
423 }
424 }()
425
426 func() {
427 type Bar struct {
428 Name string
429 }
430 err := tx.CreateTable(&Bar{}).Error
431
432 if err != nil {
433 t.Errorf("Should have been able to create the table, but couldn't: %s", err)
434 }
435
436 if ok := tx.HasTable(&Bar{}); !ok {
437 t.Errorf("The transaction should be able to see the table")
438 }
439 }()
440
441 func() {
442 type Bar struct {
443 Stuff string
444 }
445
446 err := tx.AutoMigrate(&Bar{}).Error
447 if err != nil {
448 t.Errorf("Should have been able to alter the table, but couldn't")
449 }
450 }()
451
452 tx.Rollback()
453 }
454
400455 type MultipleIndexes struct {
401456 ID int64
402457 UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`
482537 t.Errorf("No error should happen when ModifyColumn, but got %v", err)
483538 }
484539 }
540
541 func TestIndexWithPrefixLength(t *testing.T) {
542 if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" {
543 t.Skip("Skipping this because only mysql support setting an index prefix length")
544 }
545
546 type IndexWithPrefix struct {
547 gorm.Model
548 Name string
549 Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
550 }
551 type IndexesWithPrefix struct {
552 gorm.Model
553 Name string
554 Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
555 Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
556 }
557 type IndexesWithPrefixAndWithoutPrefix struct {
558 gorm.Model
559 Name string `gorm:"index:idx_index_with_prefixes_length"`
560 Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
561 }
562 tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}}
563 for _, table := range tables {
564 scope := DB.NewScope(table)
565 tableName := scope.TableName()
566 t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) {
567 if err := DB.DropTableIfExists(table).Error; err != nil {
568 t.Errorf("Failed to drop %s table: %v", tableName, err)
569 }
570 if err := DB.CreateTable(table).Error; err != nil {
571 t.Errorf("Failed to create %s table: %v", tableName, err)
572 }
573 if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") {
574 t.Errorf("Failed to create %s table index:", tableName)
575 }
576 })
577 }
578 }
1616 return defaultTableName
1717 }
1818
19 type safeModelStructsMap struct {
20 m map[reflect.Type]*ModelStruct
21 l *sync.RWMutex
22 }
23
24 func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) {
19 // lock for mutating global cached model metadata
20 var structsLock sync.Mutex
21
22 // global cache of model metadata
23 var modelStructsMap sync.Map
24
25 // ModelStruct model definition
26 type ModelStruct struct {
27 PrimaryFields []*StructField
28 StructFields []*StructField
29 ModelType reflect.Type
30
31 defaultTableName string
32 l sync.Mutex
33 }
34
35 // TableName returns model's table name
36 func (s *ModelStruct) TableName(db *DB) string {
2537 s.l.Lock()
2638 defer s.l.Unlock()
27 s.m[key] = value
28 }
29
30 func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct {
31 s.l.RLock()
32 defer s.l.RUnlock()
33 return s.m[key]
34 }
35
36 func newModelStructsMap() *safeModelStructsMap {
37 return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)}
38 }
39
40 var modelStructsMap = newModelStructsMap()
41
42 // ModelStruct model definition
43 type ModelStruct struct {
44 PrimaryFields []*StructField
45 StructFields []*StructField
46 ModelType reflect.Type
47 defaultTableName string
48 }
49
50 // TableName get model's table name
51 func (s *ModelStruct) TableName(db *DB) string {
39
5240 if s.defaultTableName == "" && db != nil && s.ModelType != nil {
5341 // Set default table name
5442 if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
5543 s.defaultTableName = tabler.TableName()
5644 } else {
57 tableName := ToDBName(s.ModelType.Name())
58 if db == nil || !db.parent.singularTable {
45 tableName := ToTableName(s.ModelType.Name())
46 db.parent.RLock()
47 if db == nil || (db.parent != nil && !db.parent.singularTable) {
5948 tableName = inflection.Plural(tableName)
6049 }
50 db.parent.RUnlock()
6151 s.defaultTableName = tableName
6252 }
6353 }
8070 Struct reflect.StructField
8171 IsForeignKey bool
8272 Relationship *Relationship
83 }
84
85 func (structField *StructField) clone() *StructField {
73
74 tagSettingsLock sync.RWMutex
75 }
76
77 // TagSettingsSet Sets a tag in the tag settings map
78 func (sf *StructField) TagSettingsSet(key, val string) {
79 sf.tagSettingsLock.Lock()
80 defer sf.tagSettingsLock.Unlock()
81 sf.TagSettings[key] = val
82 }
83
84 // TagSettingsGet returns a tag from the tag settings
85 func (sf *StructField) TagSettingsGet(key string) (string, bool) {
86 sf.tagSettingsLock.RLock()
87 defer sf.tagSettingsLock.RUnlock()
88 val, ok := sf.TagSettings[key]
89 return val, ok
90 }
91
92 // TagSettingsDelete deletes a tag
93 func (sf *StructField) TagSettingsDelete(key string) {
94 sf.tagSettingsLock.Lock()
95 defer sf.tagSettingsLock.Unlock()
96 delete(sf.TagSettings, key)
97 }
98
99 func (sf *StructField) clone() *StructField {
86100 clone := &StructField{
87 DBName: structField.DBName,
88 Name: structField.Name,
89 Names: structField.Names,
90 IsPrimaryKey: structField.IsPrimaryKey,
91 IsNormal: structField.IsNormal,
92 IsIgnored: structField.IsIgnored,
93 IsScanner: structField.IsScanner,
94 HasDefaultValue: structField.HasDefaultValue,
95 Tag: structField.Tag,
101 DBName: sf.DBName,
102 Name: sf.Name,
103 Names: sf.Names,
104 IsPrimaryKey: sf.IsPrimaryKey,
105 IsNormal: sf.IsNormal,
106 IsIgnored: sf.IsIgnored,
107 IsScanner: sf.IsScanner,
108 HasDefaultValue: sf.HasDefaultValue,
109 Tag: sf.Tag,
96110 TagSettings: map[string]string{},
97 Struct: structField.Struct,
98 IsForeignKey: structField.IsForeignKey,
99 }
100
101 if structField.Relationship != nil {
102 relationship := *structField.Relationship
111 Struct: sf.Struct,
112 IsForeignKey: sf.IsForeignKey,
113 }
114
115 if sf.Relationship != nil {
116 relationship := *sf.Relationship
103117 clone.Relationship = &relationship
104118 }
105119
106 for key, value := range structField.TagSettings {
120 // copy the struct field tagSettings, they should be read-locked while they are copied
121 sf.tagSettingsLock.Lock()
122 defer sf.tagSettingsLock.Unlock()
123 for key, value := range sf.TagSettings {
107124 clone.TagSettings[key] = value
108125 }
109126
125142
126143 func getForeignField(column string, fields []*StructField) *StructField {
127144 for _, field := range fields {
128 if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) {
145 if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) {
129146 return field
130147 }
131148 }
134151
135152 // GetModelStruct get value's model struct, relationships based on struct and tag definition
136153 func (scope *Scope) GetModelStruct() *ModelStruct {
154 return scope.getModelStruct(scope, make([]*StructField, 0))
155 }
156
157 func (scope *Scope) getModelStruct(rootScope *Scope, allFields []*StructField) *ModelStruct {
137158 var modelStruct ModelStruct
138159 // Scope value can't be nil
139160 if scope.Value == nil {
151172 }
152173
153174 // Get Cached model struct
154 if value := modelStructsMap.Get(reflectType); value != nil {
155 return value
175 isSingularTable := false
176 if scope.db != nil && scope.db.parent != nil {
177 scope.db.parent.RLock()
178 isSingularTable = scope.db.parent.singularTable
179 scope.db.parent.RUnlock()
180 }
181
182 hashKey := struct {
183 singularTable bool
184 reflectType reflect.Type
185 }{isSingularTable, reflectType}
186 if value, ok := modelStructsMap.Load(hashKey); ok && value != nil {
187 return value.(*ModelStruct)
156188 }
157189
158190 modelStruct.ModelType = reflectType
169201 }
170202
171203 // is ignored field
172 if _, ok := field.TagSettings["-"]; ok {
204 if _, ok := field.TagSettingsGet("-"); ok {
173205 field.IsIgnored = true
174206 } else {
175 if _, ok := field.TagSettings["PRIMARY_KEY"]; ok {
207 if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok {
176208 field.IsPrimaryKey = true
177209 modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
178210 }
179211
180 if _, ok := field.TagSettings["DEFAULT"]; ok {
212 if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey {
181213 field.HasDefaultValue = true
182214 }
183215
184 if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsPrimaryKey {
216 if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey {
185217 field.HasDefaultValue = true
186218 }
187219
197229 if indirectType.Kind() == reflect.Struct {
198230 for i := 0; i < indirectType.NumField(); i++ {
199231 for key, value := range parseTagSetting(indirectType.Field(i).Tag) {
200 if _, ok := field.TagSettings[key]; !ok {
201 field.TagSettings[key] = value
232 if _, ok := field.TagSettingsGet(key); !ok {
233 field.TagSettingsSet(key, value)
202234 }
203235 }
204236 }
206238 } else if _, isTime := fieldValue.(*time.Time); isTime {
207239 // is time
208240 field.IsNormal = true
209 } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
241 } else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous {
210242 // is embedded struct
211 for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields {
243 for _, subField := range scope.New(fieldValue).getModelStruct(rootScope, allFields).StructFields {
212244 subField = subField.clone()
213245 subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
214 if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok {
246 if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok {
215247 subField.DBName = prefix + subField.DBName
216248 }
217249
218250 if subField.IsPrimaryKey {
219 if _, ok := subField.TagSettings["PRIMARY_KEY"]; ok {
251 if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok {
220252 modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField)
221253 } else {
222254 subField.IsPrimaryKey = false
232264 }
233265
234266 modelStruct.StructFields = append(modelStruct.StructFields, subField)
267 allFields = append(allFields, subField)
235268 }
236269 continue
237270 } else {
247280 elemType = field.Struct.Type
248281 )
249282
250 if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
283 if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" {
251284 foreignKeys = strings.Split(foreignKey, ",")
252285 }
253286
254 if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
287 if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" {
255288 associationForeignKeys = strings.Split(foreignKey, ",")
256 } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
289 } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" {
257290 associationForeignKeys = strings.Split(foreignKey, ",")
258291 }
259292
262295 }
263296
264297 if elemType.Kind() == reflect.Struct {
265 if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
298 if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" {
266299 relationship.Kind = "many_to_many"
267300
268301 { // Foreign Keys for Source
269302 joinTableDBNames := []string{}
270303
271 if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
304 if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" {
272305 joinTableDBNames = strings.Split(foreignKey, ",")
273306 }
274307
289322 // if defined join table's foreign key
290323 relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
291324 } else {
292 defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
325 defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName
293326 relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
294327 }
295328 }
299332 { // Foreign Keys for Association (Destination)
300333 associationJoinTableDBNames := []string{}
301334
302 if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
335 if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" {
303336 associationJoinTableDBNames = strings.Split(foreignKey, ",")
304337 }
305338
320353 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
321354 } else {
322355 // join table foreign keys for association
323 joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
356 joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName
324357 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
325358 }
326359 }
337370 var toFields = toScope.GetStructFields()
338371 relationship.Kind = "has_many"
339372
340 if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
373 if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" {
341374 // Dog has many toys, tag polymorphic is Owner, then associationType is Owner
342375 // Toy use OwnerID, OwnerType ('dogs') as foreign key
343376 if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
345378 relationship.PolymorphicType = polymorphicType.Name
346379 relationship.PolymorphicDBName = polymorphicType.DBName
347380 // if Dog has multiple set of toys set name of the set (instead of default 'dogs')
348 if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
381 if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok {
349382 relationship.PolymorphicValue = value
350383 } else {
351384 relationship.PolymorphicValue = scope.TableName()
365398 } else {
366399 // generate foreign keys from defined association foreign keys
367400 for _, scopeFieldName := range associationForeignKeys {
368 if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil {
401 if foreignField := getForeignField(scopeFieldName, allFields); foreignField != nil {
369402 foreignKeys = append(foreignKeys, associationType+foreignField.Name)
370403 associationForeignKeys = append(associationForeignKeys, foreignField.Name)
371404 }
377410 for _, foreignKey := range foreignKeys {
378411 if strings.HasPrefix(foreignKey, associationType) {
379412 associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
380 if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
413 if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil {
381414 associationForeignKeys = append(associationForeignKeys, associationForeignKey)
382415 }
383416 }
384417 }
385418 if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
386 associationForeignKeys = []string{scope.PrimaryKey()}
419 associationForeignKeys = []string{rootScope.PrimaryKey()}
387420 }
388421 } else if len(foreignKeys) != len(associationForeignKeys) {
389422 scope.Err(errors.New("invalid foreign keys, should have same length"))
393426
394427 for idx, foreignKey := range foreignKeys {
395428 if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
396 if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil {
397 // source foreign keys
429 if associationField := getForeignField(associationForeignKeys[idx], allFields); associationField != nil {
430 // mark field as foreignkey, use global lock to avoid race
431 structsLock.Lock()
398432 foreignField.IsForeignKey = true
433 structsLock.Unlock()
434
435 // association foreign keys
399436 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
400437 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
401438
427464 tagAssociationForeignKeys []string
428465 )
429466
430 if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
467 if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" {
431468 tagForeignKeys = strings.Split(foreignKey, ",")
432469 }
433470
434 if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
471 if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" {
435472 tagAssociationForeignKeys = strings.Split(foreignKey, ",")
436 } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
473 } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" {
437474 tagAssociationForeignKeys = strings.Split(foreignKey, ",")
438475 }
439476
440 if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
477 if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" {
441478 // Cat has one toy, tag polymorphic is Owner, then associationType is Owner
442479 // Toy use OwnerID, OwnerType ('cats') as foreign key
443480 if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
445482 relationship.PolymorphicType = polymorphicType.Name
446483 relationship.PolymorphicDBName = polymorphicType.DBName
447484 // if Cat has several different types of toys set name for each (instead of default 'cats')
448 if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
485 if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok {
449486 relationship.PolymorphicValue = value
450487 } else {
451488 relationship.PolymorphicValue = scope.TableName()
469506 } else {
470507 // generate foreign keys form association foreign keys
471508 for _, associationForeignKey := range tagAssociationForeignKeys {
472 if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
509 if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil {
473510 foreignKeys = append(foreignKeys, associationType+foreignField.Name)
474511 associationForeignKeys = append(associationForeignKeys, foreignField.Name)
475512 }
481518 for _, foreignKey := range foreignKeys {
482519 if strings.HasPrefix(foreignKey, associationType) {
483520 associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
484 if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
521 if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil {
485522 associationForeignKeys = append(associationForeignKeys, associationForeignKey)
486523 }
487524 }
488525 }
489526 if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
490 associationForeignKeys = []string{scope.PrimaryKey()}
527 associationForeignKeys = []string{rootScope.PrimaryKey()}
491528 }
492529 } else if len(foreignKeys) != len(associationForeignKeys) {
493530 scope.Err(errors.New("invalid foreign keys, should have same length"))
497534
498535 for idx, foreignKey := range foreignKeys {
499536 if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
500 if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil {
537 if scopeField := getForeignField(associationForeignKeys[idx], allFields); scopeField != nil {
538 // mark field as foreignkey, use global lock to avoid race
539 structsLock.Lock()
501540 foreignField.IsForeignKey = true
502 // source foreign keys
541 structsLock.Unlock()
542
543 // association foreign keys
503544 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name)
504545 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName)
505546
557598 for idx, foreignKey := range foreignKeys {
558599 if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
559600 if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil {
601 // mark field as foreignkey, use global lock to avoid race
602 structsLock.Lock()
560603 foreignField.IsForeignKey = true
604 structsLock.Unlock()
561605
562606 // association foreign keys
563607 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
583627 }
584628
585629 // Even it is ignored, also possible to decode db value into the field
586 if value, ok := field.TagSettings["COLUMN"]; ok {
630 if value, ok := field.TagSettingsGet("COLUMN"); ok {
587631 field.DBName = value
588632 } else {
589 field.DBName = ToDBName(fieldStruct.Name)
633 field.DBName = ToColumnName(fieldStruct.Name)
590634 }
591635
592636 modelStruct.StructFields = append(modelStruct.StructFields, field)
637 allFields = append(allFields, field)
593638 }
594639 }
595640
600645 }
601646 }
602647
603 modelStructsMap.Set(reflectType, &modelStruct)
648 modelStructsMap.Store(hashKey, &modelStruct)
604649
605650 return &modelStruct
606651 }
613658 func parseTagSetting(tags reflect.StructTag) map[string]string {
614659 setting := map[string]string{}
615660 for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
661 if str == "" {
662 continue
663 }
616664 tags := strings.Split(str, ";")
617665 for _, value := range tags {
618666 v := strings.Split(value, ":")
0 package gorm_test
1
2 import (
3 "sync"
4 "testing"
5
6 "github.com/jinzhu/gorm"
7 )
8
9 type ModelA struct {
10 gorm.Model
11 Name string
12
13 ModelCs []ModelC `gorm:"foreignkey:OtherAID"`
14 }
15
16 type ModelB struct {
17 gorm.Model
18 Name string
19
20 ModelCs []ModelC `gorm:"foreignkey:OtherBID"`
21 }
22
23 type ModelC struct {
24 gorm.Model
25 Name string
26
27 OtherAID uint64
28 OtherA *ModelA `gorm:"foreignkey:OtherAID"`
29 OtherBID uint64
30 OtherB *ModelB `gorm:"foreignkey:OtherBID"`
31 }
32
33 type RequestModel struct {
34 Name string
35 Children []ChildModel `gorm:"foreignkey:ParentID"`
36 }
37
38 type ChildModel struct {
39 ID string
40 ParentID string
41 Name string
42 }
43
44 type ResponseModel struct {
45 gorm.Model
46 RequestModel
47 }
48
49 // This test will try to cause a race condition on the model's foreignkey metadata
50 func TestModelStructRaceSameModel(t *testing.T) {
51 // use a WaitGroup to execute as much in-sync as possible
52 // it's more likely to hit a race condition than without
53 n := 32
54 start := sync.WaitGroup{}
55 start.Add(n)
56
57 // use another WaitGroup to know when the test is done
58 done := sync.WaitGroup{}
59 done.Add(n)
60
61 for i := 0; i < n; i++ {
62 go func() {
63 start.Wait()
64
65 // call GetStructFields, this had a race condition before we fixed it
66 DB.NewScope(&ModelA{}).GetStructFields()
67
68 done.Done()
69 }()
70
71 start.Done()
72 }
73
74 done.Wait()
75 }
76
77 // This test will try to cause a race condition on the model's foreignkey metadata
78 func TestModelStructRaceDifferentModel(t *testing.T) {
79 // use a WaitGroup to execute as much in-sync as possible
80 // it's more likely to hit a race condition than without
81 n := 32
82 start := sync.WaitGroup{}
83 start.Add(n)
84
85 // use another WaitGroup to know when the test is done
86 done := sync.WaitGroup{}
87 done.Add(n)
88
89 for i := 0; i < n; i++ {
90 i := i
91 go func() {
92 start.Wait()
93
94 // call GetStructFields, this had a race condition before we fixed it
95 if i%2 == 0 {
96 DB.NewScope(&ModelA{}).GetStructFields()
97 } else {
98 DB.NewScope(&ModelB{}).GetStructFields()
99 }
100
101 done.Done()
102 }()
103
104 start.Done()
105 }
106
107 done.Wait()
108 }
109
110 func TestModelStructEmbeddedHasMany(t *testing.T) {
111 fields := DB.NewScope(&ResponseModel{}).GetStructFields()
112
113 var childrenField *gorm.StructField
114
115 for i := 0; i < len(fields); i++ {
116 field := fields[i]
117
118 if field != nil && field.Name == "Children" {
119 childrenField = field
120 }
121 }
122
123 if childrenField == nil {
124 t.Error("childrenField should not be nil")
125 return
126 }
127
128 if childrenField.Relationship == nil {
129 t.Error("childrenField.Relation should not be nil")
130 return
131 }
132
133 expected := "has_many"
134 actual := childrenField.Relationship.Kind
135
136 if actual != expected {
137 t.Errorf("childrenField.Relationship.Kind should be %v, but was %v", expected, actual)
138 }
139 }
0 package gorm
1
2 import (
3 "bytes"
4 "strings"
5 )
6
7 // Namer is a function type which is given a string and return a string
8 type Namer func(string) string
9
10 // NamingStrategy represents naming strategies
11 type NamingStrategy struct {
12 DB Namer
13 Table Namer
14 Column Namer
15 }
16
17 // TheNamingStrategy is being initialized with defaultNamingStrategy
18 var TheNamingStrategy = &NamingStrategy{
19 DB: defaultNamer,
20 Table: defaultNamer,
21 Column: defaultNamer,
22 }
23
24 // AddNamingStrategy sets the naming strategy
25 func AddNamingStrategy(ns *NamingStrategy) {
26 if ns.DB == nil {
27 ns.DB = defaultNamer
28 }
29 if ns.Table == nil {
30 ns.Table = defaultNamer
31 }
32 if ns.Column == nil {
33 ns.Column = defaultNamer
34 }
35 TheNamingStrategy = ns
36 }
37
38 // DBName alters the given name by DB
39 func (ns *NamingStrategy) DBName(name string) string {
40 return ns.DB(name)
41 }
42
43 // TableName alters the given name by Table
44 func (ns *NamingStrategy) TableName(name string) string {
45 return ns.Table(name)
46 }
47
48 // ColumnName alters the given name by Column
49 func (ns *NamingStrategy) ColumnName(name string) string {
50 return ns.Column(name)
51 }
52
53 // ToDBName convert string to db name
54 func ToDBName(name string) string {
55 return TheNamingStrategy.DBName(name)
56 }
57
58 // ToTableName convert string to table name
59 func ToTableName(name string) string {
60 return TheNamingStrategy.TableName(name)
61 }
62
63 // ToColumnName convert string to db name
64 func ToColumnName(name string) string {
65 return TheNamingStrategy.ColumnName(name)
66 }
67
68 var smap = newSafeMap()
69
70 func defaultNamer(name string) string {
71 const (
72 lower = false
73 upper = true
74 )
75
76 if v := smap.Get(name); v != "" {
77 return v
78 }
79
80 if name == "" {
81 return ""
82 }
83
84 var (
85 value = commonInitialismsReplacer.Replace(name)
86 buf = bytes.NewBufferString("")
87 lastCase, currCase, nextCase, nextNumber bool
88 )
89
90 for i, v := range value[:len(value)-1] {
91 nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z')
92 nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9')
93
94 if i > 0 {
95 if currCase == upper {
96 if lastCase == upper && (nextCase == upper || nextNumber == upper) {
97 buf.WriteRune(v)
98 } else {
99 if value[i-1] != '_' && value[i+1] != '_' {
100 buf.WriteRune('_')
101 }
102 buf.WriteRune(v)
103 }
104 } else {
105 buf.WriteRune(v)
106 if i == len(value)-2 && (nextCase == upper && nextNumber == lower) {
107 buf.WriteRune('_')
108 }
109 }
110 } else {
111 currCase = upper
112 buf.WriteRune(v)
113 }
114 lastCase = currCase
115 currCase = nextCase
116 }
117
118 buf.WriteByte(value[len(value)-1])
119
120 s := strings.ToLower(buf.String())
121 smap.Set(name, s)
122 return s
123 }
0 package gorm_test
1
2 import (
3 "testing"
4
5 "github.com/jinzhu/gorm"
6 )
7
8 func TestTheNamingStrategy(t *testing.T) {
9
10 cases := []struct {
11 name string
12 namer gorm.Namer
13 expected string
14 }{
15 {name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB},
16 {name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table},
17 {name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column},
18 }
19
20 for _, c := range cases {
21 t.Run(c.name, func(t *testing.T) {
22 result := c.namer(c.name)
23 if result != c.expected {
24 t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result)
25 }
26 })
27 }
28
29 }
30
31 func TestNamingStrategy(t *testing.T) {
32
33 dbNameNS := func(name string) string {
34 return "db_" + name
35 }
36 tableNameNS := func(name string) string {
37 return "tbl_" + name
38 }
39 columnNameNS := func(name string) string {
40 return "col_" + name
41 }
42
43 ns := &gorm.NamingStrategy{
44 DB: dbNameNS,
45 Table: tableNameNS,
46 Column: columnNameNS,
47 }
48
49 cases := []struct {
50 name string
51 namer gorm.Namer
52 expected string
53 }{
54 {name: "auth", expected: "db_auth", namer: ns.DB},
55 {name: "user", expected: "tbl_user", namer: ns.Table},
56 {name: "password", expected: "col_password", namer: ns.Column},
57 }
58
59 for _, c := range cases {
60 t.Run(c.name, func(t *testing.T) {
61 result := c.namer(c.name)
62 if result != c.expected {
63 t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result)
64 }
65 })
66 }
67
68 }
122122 }
123123 }
124124
125 func TestAutoPreloadFalseDoesntPreload(t *testing.T) {
126 user1 := getPreloadUser("auto_user1")
127 DB.Save(user1)
128
129 preloadDB := DB.Set("gorm:auto_preload", false).Where("role = ?", "Preload")
130 var user User
131 preloadDB.Find(&user)
132
133 if user.BillingAddress.Address1 != "" {
134 t.Error("AutoPreload was set to fasle, but still fetched data")
135 }
136
137 user2 := getPreloadUser("auto_user2")
138 DB.Save(user2)
139
140 var users []User
141 preloadDB.Find(&users)
142
143 for _, user := range users {
144 if user.BillingAddress.Address1 != "" {
145 t.Error("AutoPreload was set to fasle, but still fetched data")
146 }
147 }
148 }
149
125150 func TestNestedPreload1(t *testing.T) {
126151 type (
127152 Level1 struct {
745770 levelB3 := &LevelB3{
746771 Value: "bar",
747772 LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)},
773 LevelB2s: []*LevelB2{},
748774 }
749775 if err := DB.Create(levelB3).Error; err != nil {
750776 t.Error(err)
16501676 lvl := Level1{
16511677 Name: "l1",
16521678 Level2s: []Level2{
1653 Level2{Name: "l2-1"}, Level2{Name: "l2-2"},
1679 {Name: "l2-1"}, {Name: "l2-2"},
16541680 },
16551681 }
16561682 DB.Save(&lvl)
3737
3838 if user.Email != "" {
3939 t.Errorf("User's Email should be blank as no one set it")
40 }
41 }
42
43 func TestQueryWithAssociation(t *testing.T) {
44 user := &User{Name: "user1", Emails: []Email{{Email: "user1@example.com"}}, Company: Company{Name: "company"}}
45
46 if err := DB.Create(&user).Error; err != nil {
47 t.Fatalf("errors happened when create user: %v", err)
48 }
49
50 user.CreatedAt = time.Time{}
51 user.UpdatedAt = time.Time{}
52 if err := DB.Where(&user).First(&User{}).Error; err != nil {
53 t.Errorf("search with struct with association should returns no error, but got %v", err)
54 }
55
56 if err := DB.Where(user).First(&User{}).Error; err != nil {
57 t.Errorf("search with struct with association should returns no error, but got %v", err)
4058 }
4159 }
4260
180198
181199 scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users)
182200 if len(users) != 2 {
183 t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users))
201 t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users))
184202 }
185203
186204 scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
187205 if len(users) != 2 {
188 t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users))
206 t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users))
189207 }
190208
191209 scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
192210 if len(users) != 1 {
193 t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
211 t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
194212 }
195213
196214 DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
456474 }
457475 }
458476
477 func TestLimitAndOffsetSQL(t *testing.T) {
478 user1 := User{Name: "TestLimitAndOffsetSQL1", Age: 10}
479 user2 := User{Name: "TestLimitAndOffsetSQL2", Age: 20}
480 user3 := User{Name: "TestLimitAndOffsetSQL3", Age: 30}
481 user4 := User{Name: "TestLimitAndOffsetSQL4", Age: 40}
482 user5 := User{Name: "TestLimitAndOffsetSQL5", Age: 50}
483 if err := DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5).Error; err != nil {
484 t.Fatal(err)
485 }
486
487 tests := []struct {
488 name string
489 limit, offset interface{}
490 users []*User
491 ok bool
492 }{
493 {
494 name: "OK",
495 limit: float64(2),
496 offset: float64(2),
497 users: []*User{
498 &User{Name: "TestLimitAndOffsetSQL3", Age: 30},
499 &User{Name: "TestLimitAndOffsetSQL2", Age: 20},
500 },
501 ok: true,
502 },
503 {
504 name: "Limit parse error",
505 limit: float64(1000000), // 1e+06
506 offset: float64(2),
507 ok: false,
508 },
509 {
510 name: "Offset parse error",
511 limit: float64(2),
512 offset: float64(1000000), // 1e+06
513 ok: false,
514 },
515 }
516
517 for _, tt := range tests {
518 t.Run(tt.name, func(t *testing.T) {
519 var users []*User
520 err := DB.Where("name LIKE ?", "TestLimitAndOffsetSQL%").Order("age desc").Limit(tt.limit).Offset(tt.offset).Find(&users).Error
521 if tt.ok {
522 if err != nil {
523 t.Errorf("error expected nil, but got %v", err)
524 }
525 if len(users) != len(tt.users) {
526 t.Errorf("users length expected %d, but got %d", len(tt.users), len(users))
527 }
528 for i := range tt.users {
529 if users[i].Name != tt.users[i].Name {
530 t.Errorf("users[%d] name expected %s, but got %s", i, tt.users[i].Name, users[i].Name)
531 }
532 if users[i].Age != tt.users[i].Age {
533 t.Errorf("users[%d] age expected %d, but got %d", i, tt.users[i].Age, users[i].Age)
534 }
535 }
536 } else {
537 if err == nil {
538 t.Error("error expected not nil, but got nil")
539 }
540 }
541 })
542 }
543 }
544
459545 func TestOr(t *testing.T) {
460546 user1 := User{Name: "OrUser1", Age: 1}
461547 user2 := User{Name: "OrUser2", Age: 10}
531617 DB.Table("users").Where("name = ?", "user3").Count(&name3Count)
532618 DB.Not("name", "user3").Find(&users4)
533619 if len(users1)-len(users4) != int(name3Count) {
534 t.Errorf("Should find all users's name not equal 3")
620 t.Errorf("Should find all users' name not equal 3")
535621 }
536622
537623 DB.Not("name = ?", "user3").Find(&users4)
538624 if len(users1)-len(users4) != int(name3Count) {
539 t.Errorf("Should find all users's name not equal 3")
625 t.Errorf("Should find all users' name not equal 3")
540626 }
541627
542628 DB.Not("name <> ?", "user3").Find(&users4)
543629 if len(users4) != int(name3Count) {
544 t.Errorf("Should find all users's name not equal 3")
630 t.Errorf("Should find all users' name not equal 3")
545631 }
546632
547633 DB.Not(User{Name: "user3"}).Find(&users5)
548634
549635 if len(users1)-len(users5) != int(name3Count) {
550 t.Errorf("Should find all users's name not equal 3")
636 t.Errorf("Should find all users' name not equal 3")
551637 }
552638
553639 DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6)
554640 if len(users1)-len(users6) != int(name3Count) {
555 t.Errorf("Should find all users's name not equal 3")
641 t.Errorf("Should find all users' name not equal 3")
556642 }
557643
558644 DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7)
562648
563649 DB.Not("name", []string{"user3"}).Find(&users8)
564650 if len(users1)-len(users8) != int(name3Count) {
565 t.Errorf("Should find all users's name not equal 3")
651 t.Errorf("Should find all users' name not equal 3")
566652 }
567653
568654 var name2Count int64
569655 DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
570656 DB.Not("name", []string{"user3", "user2"}).Find(&users9)
571657 if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) {
572 t.Errorf("Should find all users's name not equal 3")
658 t.Errorf("Should find all users' name not equal 3")
573659 }
574660 }
575661
6262
6363 // Dialect get dialect
6464 func (scope *Scope) Dialect() Dialect {
65 return scope.db.parent.dialect
65 return scope.db.dialect
6666 }
6767
6868 // Quote used to quote string to escape them for database
6969 func (scope *Scope) Quote(str string) string {
70 if strings.Index(str, ".") != -1 {
70 if strings.Contains(str, ".") {
7171 newStrs := []string{}
7272 for _, str := range strings.Split(str, ".") {
7373 newStrs = append(newStrs, scope.Dialect().Quote(str))
133133 // FieldByName find `gorm.Field` with field name or db name
134134 func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
135135 var (
136 dbName = ToDBName(name)
136 dbName = ToColumnName(name)
137137 mostMatchedField *Field
138138 )
139139
224224 updateAttrs[field.DBName] = value
225225 return field.Set(value)
226226 }
227 if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) {
227 if !field.IsIgnored && ((field.DBName == dbName) || (field.Name == name && mostMatchedField == nil)) {
228228 mostMatchedField = field
229229 }
230230 }
256256 func (scope *Scope) AddToVars(value interface{}) string {
257257 _, skipBindVar := scope.InstanceGet("skip_bindvar")
258258
259 if expr, ok := value.(*expr); ok {
259 if expr, ok := value.(*SqlExpr); ok {
260260 exp := expr.expr
261261 for _, arg := range expr.args {
262262 if skipBindVar {
329329 // QuotedTableName return quoted table name
330330 func (scope *Scope) QuotedTableName() (name string) {
331331 if scope.Search != nil && len(scope.Search.tableName) > 0 {
332 if strings.Index(scope.Search.tableName, " ") != -1 {
332 if strings.Contains(scope.Search.tableName, " ") {
333333 return scope.Search.tableName
334334 }
335335 return scope.Quote(scope.Search.tableName)
401401 // Begin start a transaction
402402 func (scope *Scope) Begin() *Scope {
403403 if db, ok := scope.SQLDB().(sqlDb); ok {
404 if tx, err := db.Begin(); err == nil {
404 if tx, err := db.Begin(); scope.Err(err) == nil {
405405 scope.db.db = interface{}(tx).(SQLCommon)
406406 scope.InstanceSet("gorm:started_transaction", true)
407407 }
485485 values[index] = &ignored
486486
487487 selectFields = fields
488 offset := 0
488489 if idx, ok := selectedColumnsMap[column]; ok {
489 selectFields = selectFields[idx+1:]
490 offset = idx + 1
491 selectFields = selectFields[offset:]
490492 }
491493
492494 for fieldIndex, field := range selectFields {
500502 resetFields[index] = field
501503 }
502504
503 selectedColumnsMap[column] = fieldIndex
505 selectedColumnsMap[column] = offset + fieldIndex
504506
505507 if field.IsNormal {
506508 break
585587 scope.Err(fmt.Errorf("invalid query condition: %v", value))
586588 return
587589 }
588
590 scopeQuotedTableName := newScope.QuotedTableName()
589591 for _, field := range newScope.Fields() {
590 if !field.IsIgnored && !field.IsBlank {
591 sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
592 if !field.IsIgnored && !field.IsBlank && field.Relationship == nil {
593 sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
592594 }
593595 }
594596 return strings.Join(sqls, " AND ")
691693
692694 buff := bytes.NewBuffer([]byte{})
693695 i := 0
694 for pos := range str {
696 for pos, char := range str {
695697 if str[pos] == '?' {
696698 buff.WriteString(replacements[i])
697699 i++
698700 } else {
699 buff.WriteByte(str[pos])
701 buff.WriteRune(char)
700702 }
701703 }
702704
782784 for _, order := range scope.Search.orders {
783785 if str, ok := order.(string); ok {
784786 orders = append(orders, scope.quoteIfPossible(str))
785 } else if expr, ok := order.(*expr); ok {
787 } else if expr, ok := order.(*SqlExpr); ok {
786788 exp := expr.expr
787789 for _, arg := range expr.args {
788790 exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
794796 }
795797
796798 func (scope *Scope) limitAndOffsetSQL() string {
797 return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
799 sql, err := scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
800 scope.Err(err)
801 return sql
798802 }
799803
800804 func (scope *Scope) groupSQL() string {
852856 }
853857
854858 func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
859 defer func() {
860 if err := recover(); err != nil {
861 if db, ok := scope.db.db.(sqlTx); ok {
862 db.Rollback()
863 }
864 panic(err)
865 }
866 }()
855867 for _, f := range funcs {
856868 (*f)(scope)
857869 if scope.skipLeft {
861873 return scope
862874 }
863875
864 func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string]interface{} {
876 func convertInterfaceToMap(values interface{}, withIgnoredField bool, db *DB) map[string]interface{} {
865877 var attrs = map[string]interface{}{}
866878
867879 switch value := values.(type) {
869881 return value
870882 case []interface{}:
871883 for _, v := range value {
872 for key, value := range convertInterfaceToMap(v, withIgnoredField) {
884 for key, value := range convertInterfaceToMap(v, withIgnoredField, db) {
873885 attrs[key] = value
874886 }
875887 }
879891 switch reflectValue.Kind() {
880892 case reflect.Map:
881893 for _, key := range reflectValue.MapKeys() {
882 attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
894 attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
883895 }
884896 default:
885 for _, field := range (&Scope{Value: values}).Fields() {
897 for _, field := range (&Scope{Value: values, db: db}).Fields() {
886898 if !field.IsBlank && (withIgnoredField || !field.IsIgnored) {
887899 attrs[field.DBName] = field.Field.Interface()
888900 }
894906
895907 func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) {
896908 if scope.IndirectValue().Kind() != reflect.Struct {
897 return convertInterfaceToMap(value, false), true
909 return convertInterfaceToMap(value, false, scope.db), true
898910 }
899911
900912 results = map[string]interface{}{}
901913
902 for key, value := range convertInterfaceToMap(value, true) {
903 if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
904 if _, ok := value.(*expr); ok {
905 hasUpdate = true
906 results[field.DBName] = value
907 } else {
908 err := field.Set(value)
909 if field.IsNormal {
914 for key, value := range convertInterfaceToMap(value, true, scope.db) {
915 if field, ok := scope.FieldByName(key); ok {
916 if scope.changeableField(field) {
917 if _, ok := value.(*SqlExpr); ok {
910918 hasUpdate = true
911 if err == ErrUnaddressable {
912 results[field.DBName] = value
913 } else {
914 results[field.DBName] = field.Field.Interface()
919 results[field.DBName] = value
920 } else {
921 err := field.Set(value)
922 if field.IsNormal && !field.IsIgnored {
923 hasUpdate = true
924 if err == ErrUnaddressable {
925 results[field.DBName] = value
926 } else {
927 results[field.DBName] = field.Field.Interface()
928 }
915929 }
916930 }
917931 }
932 } else {
933 results[key] = value
918934 }
919935 }
920936 return
971987 if dest.Kind() != reflect.Slice {
972988 scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
973989 return scope
990 }
991
992 if dest.Len() > 0 {
993 dest.Set(reflect.Zero(dest.Type()))
974994 }
975995
976996 if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) {
9961016 func (scope *Scope) count(value interface{}) *Scope {
9971017 if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) {
9981018 if len(scope.Search.group) != 0 {
999 scope.Search.Select("count(*) FROM ( SELECT count(*) as name ")
1000 scope.Search.group += " ) AS count_table"
1019 if len(scope.Search.havingConditions) != 0 {
1020 scope.prepareQuerySQL()
1021 scope.Search = &search{}
1022 scope.Search.Select("count(*)")
1023 scope.Search.Table(fmt.Sprintf("( %s ) AS count_table", scope.SQL))
1024 } else {
1025 scope.Search.Select("count(*) FROM ( SELECT count(*) as name ")
1026 scope.Search.group += " ) AS count_table"
1027 }
10011028 } else {
10021029 scope.Search.Select("count(*)")
10031030 }
11121139 if field, ok := scope.FieldByName(fieldName); ok {
11131140 foreignKeyStruct := field.clone()
11141141 foreignKeyStruct.IsPrimaryKey = false
1115 foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
1116 delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT")
1142 foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true")
1143 foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT")
11171144 sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
11181145 primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
11191146 }
11231150 if field, ok := toScope.FieldByName(fieldName); ok {
11241151 foreignKeyStruct := field.clone()
11251152 foreignKeyStruct.IsPrimaryKey = false
1126 foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
1127 delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT")
1153 foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true")
1154 foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT")
11281155 sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
11291156 primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
11301157 }
11721199 }
11731200
11741201 func (scope *Scope) dropTable() *Scope {
1175 scope.Raw(fmt.Sprintf("DROP TABLE %v%s", scope.QuotedTableName(), scope.getTableOptions())).Exec()
1202 scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
11761203 return scope
11771204 }
11781205
12141241 }
12151242
12161243 func (scope *Scope) removeForeignKey(field string, dest string) {
1217 keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest)
1218
1244 keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
12191245 if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
12201246 return
12211247 }
1222 var query = `ALTER TABLE %s DROP CONSTRAINT %s;`
1248 var mysql mysql
1249 var query string
1250 if scope.Dialect().GetName() == mysql.GetName() {
1251 query = `ALTER TABLE %s DROP FOREIGN KEY %s;`
1252 } else {
1253 query = `ALTER TABLE %s DROP CONSTRAINT %s;`
1254 }
1255
12231256 scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
12241257 }
12251258
12531286 var uniqueIndexes = map[string][]string{}
12541287
12551288 for _, field := range scope.GetStructFields() {
1256 if name, ok := field.TagSettings["INDEX"]; ok {
1289 if name, ok := field.TagSettingsGet("INDEX"); ok {
12571290 names := strings.Split(name, ",")
12581291
12591292 for _, name := range names {
12601293 if name == "INDEX" || name == "" {
12611294 name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName)
12621295 }
1263 indexes[name] = append(indexes[name], field.DBName)
1264 }
1265 }
1266
1267 if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok {
1296 name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName)
1297 indexes[name] = append(indexes[name], column)
1298 }
1299 }
1300
1301 if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok {
12681302 names := strings.Split(name, ",")
12691303
12701304 for _, name := range names {
12711305 if name == "UNIQUE_INDEX" || name == "" {
12721306 name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName)
12731307 }
1274 uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
1308 name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName)
1309 uniqueIndexes[name] = append(uniqueIndexes[name], column)
12751310 }
12761311 }
12771312 }
12921327 }
12931328
12941329 func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
1330 resultMap := make(map[string][]interface{})
12951331 for _, value := range values {
12961332 indirectValue := indirect(reflect.ValueOf(value))
12971333
13101346 }
13111347
13121348 if hasValue {
1313 results = append(results, result)
1349 h := fmt.Sprint(result...)
1350 if _, exist := resultMap[h]; !exist {
1351 resultMap[h] = result
1352 }
13141353 }
13151354 }
13161355 case reflect.Struct:
13251364 }
13261365
13271366 if hasValue {
1328 results = append(results, result)
1329 }
1330 }
1331 }
1332
1367 h := fmt.Sprint(result...)
1368 if _, exist := resultMap[h]; !exist {
1369 resultMap[h] = result
1370 }
1371 }
1372 }
1373 }
1374 for _, v := range resultMap {
1375 results = append(results, v)
1376 }
13331377 return
13341378 }
13351379
7777 t.Errorf("The error should be returned from Valuer, but get %v", err)
7878 }
7979 }
80
81 func TestDropTableWithTableOptions(t *testing.T) {
82 type UserWithOptions struct {
83 gorm.Model
84 }
85 DB.AutoMigrate(&UserWithOptions{})
86
87 DB = DB.Set("gorm:table_options", "CHARSET=utf8")
88 err := DB.DropTable(&UserWithOptions{}).Error
89 if err != nil {
90 t.Errorf("Table must be dropped, got error %s", err)
91 }
92 }
3131 }
3232
3333 func (s *search) clone() *search {
34 clone := *s
34 clone := search{
35 db: s.db,
36 whereConditions: make([]map[string]interface{}, len(s.whereConditions)),
37 orConditions: make([]map[string]interface{}, len(s.orConditions)),
38 notConditions: make([]map[string]interface{}, len(s.notConditions)),
39 havingConditions: make([]map[string]interface{}, len(s.havingConditions)),
40 joinConditions: make([]map[string]interface{}, len(s.joinConditions)),
41 initAttrs: make([]interface{}, len(s.initAttrs)),
42 assignAttrs: make([]interface{}, len(s.assignAttrs)),
43 selects: s.selects,
44 omits: make([]string, len(s.omits)),
45 orders: make([]interface{}, len(s.orders)),
46 preload: make([]searchPreload, len(s.preload)),
47 offset: s.offset,
48 limit: s.limit,
49 group: s.group,
50 tableName: s.tableName,
51 raw: s.raw,
52 Unscoped: s.Unscoped,
53 ignoreOrderQuery: s.ignoreOrderQuery,
54 }
55 for i, value := range s.whereConditions {
56 clone.whereConditions[i] = value
57 }
58 for i, value := range s.orConditions {
59 clone.orConditions[i] = value
60 }
61 for i, value := range s.notConditions {
62 clone.notConditions[i] = value
63 }
64 for i, value := range s.havingConditions {
65 clone.havingConditions[i] = value
66 }
67 for i, value := range s.joinConditions {
68 clone.joinConditions[i] = value
69 }
70 for i, value := range s.initAttrs {
71 clone.initAttrs[i] = value
72 }
73 for i, value := range s.assignAttrs {
74 clone.assignAttrs[i] = value
75 }
76 for i, value := range s.omits {
77 clone.omits[i] = value
78 }
79 for i, value := range s.orders {
80 clone.orders[i] = value
81 }
82 for i, value := range s.preload {
83 clone.preload[i] = value
84 }
3585 return &clone
3686 }
3787
97147 }
98148
99149 func (s *search) Having(query interface{}, values ...interface{}) *search {
100 if val, ok := query.(*expr); ok {
150 if val, ok := query.(*SqlExpr); ok {
101151 s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args})
102152 } else {
103153 s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})
00 package gorm
11
22 import (
3 "fmt"
34 "reflect"
45 "testing"
56 )
2728 t.Errorf("selectStr should be copied")
2829 }
2930 }
31
32 func TestWhereCloneCorruption(t *testing.T) {
33 for whereCount := 1; whereCount <= 8; whereCount++ {
34 t.Run(fmt.Sprintf("w=%d", whereCount), func(t *testing.T) {
35 s := new(search)
36 for w := 0; w < whereCount; w++ {
37 s = s.clone().Where(fmt.Sprintf("w%d = ?", w), fmt.Sprintf("value%d", w))
38 }
39 if len(s.whereConditions) != whereCount {
40 t.Errorf("s: where count should be %d", whereCount)
41 }
42
43 q1 := s.clone().Where("finalThing = ?", "THING1")
44 q2 := s.clone().Where("finalThing = ?", "THING2")
45
46 if reflect.DeepEqual(q1.whereConditions, q2.whereConditions) {
47 t.Errorf("Where conditions should be different")
48 }
49 })
50 }
51 }
00 package gorm
11
22 import (
3 "bytes"
43 "database/sql/driver"
54 "fmt"
65 "reflect"
2524 var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
2625 var commonInitialismsReplacer *strings.Replacer
2726
28 var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)
29 var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`)
27 var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`)
28 var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`)
3029
3130 func init() {
3231 var commonInitialismsForReplacer []string
5756 return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
5857 }
5958
60 var smap = newSafeMap()
61
62 type strCase bool
63
64 const (
65 lower strCase = false
66 upper strCase = true
67 )
68
69 // ToDBName convert string to db name
70 func ToDBName(name string) string {
71 if v := smap.Get(name); v != "" {
72 return v
73 }
74
75 if name == "" {
76 return ""
77 }
78
79 var (
80 value = commonInitialismsReplacer.Replace(name)
81 buf = bytes.NewBufferString("")
82 lastCase, currCase, nextCase strCase
83 )
84
85 for i, v := range value[:len(value)-1] {
86 nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z')
87 if i > 0 {
88 if currCase == upper {
89 if lastCase == upper && nextCase == upper {
90 buf.WriteRune(v)
91 } else {
92 if value[i-1] != '_' && value[i+1] != '_' {
93 buf.WriteRune('_')
94 }
95 buf.WriteRune(v)
96 }
97 } else {
98 buf.WriteRune(v)
99 if i == len(value)-2 && nextCase == upper {
100 buf.WriteRune('_')
101 }
102 }
103 } else {
104 currCase = upper
105 buf.WriteRune(v)
106 }
107 lastCase = currCase
108 currCase = nextCase
109 }
110
111 buf.WriteByte(value[len(value)-1])
112
113 s := strings.ToLower(buf.String())
114 smap.Set(name, s)
115 return s
116 }
117
11859 // SQL expression
119 type expr struct {
60 type SqlExpr struct {
12061 expr string
12162 args []interface{}
12263 }
12364
12465 // Expr generate raw SQL expression, for example:
12566 // DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
126 func Expr(expression string, args ...interface{}) *expr {
127 return &expr{expr: expression, args: args}
67 func Expr(expression string, args ...interface{}) *SqlExpr {
68 return &SqlExpr{expr: expression, args: args}
12869 }
12970
13071 func indirect(reflectValue reflect.Value) reflect.Value {
264205 // as FieldByName could panic
265206 if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
266207 for _, fieldName := range fieldNames {
267 if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() {
208 if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() {
268209 result := fieldValue.Interface()
269210 if r, ok := result.(driver.Valuer); ok {
270211 result, _ = r.Value()
+0
-32
utils_test.go less more
0 package gorm_test
1
2 import (
3 "testing"
4
5 "github.com/jinzhu/gorm"
6 )
7
8 func TestToDBNameGenerateFriendlyName(t *testing.T) {
9 var maps = map[string]string{
10 "": "",
11 "X": "x",
12 "ThisIsATest": "this_is_a_test",
13 "PFAndESI": "pf_and_esi",
14 "AbcAndJkl": "abc_and_jkl",
15 "EmployeeID": "employee_id",
16 "SKU_ID": "sku_id",
17 "FieldX": "field_x",
18 "HTTPAndSMTP": "http_and_smtp",
19 "HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
20 "UUID": "uuid",
21 "HTTPURL": "http_url",
22 "HTTP_URL": "http_url",
23 "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
24 }
25
26 for key, value := range maps {
27 if gorm.ToDBName(key) != value {
28 t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key))
29 }
30 }
31 }
33 services:
44 - name: mariadb
55 id: mariadb:latest
6 env:
7 MYSQL_DATABASE: gorm
8 MYSQL_USER: gorm
9 MYSQL_PASSWORD: gorm
10 MYSQL_RANDOM_ROOT_PASSWORD: "yes"
11 - name: mysql
12 id: mysql:latest
613 env:
714 MYSQL_DATABASE: gorm
815 MYSQL_USER: gorm
1724 MYSQL_RANDOM_ROOT_PASSWORD: "yes"
1825 - name: mysql56
1926 id: mysql:5.6
20 env:
21 MYSQL_DATABASE: gorm
22 MYSQL_USER: gorm
23 MYSQL_PASSWORD: gorm
24 MYSQL_RANDOM_ROOT_PASSWORD: "yes"
25 - name: mysql55
26 id: mysql:5.5
2727 env:
2828 MYSQL_DATABASE: gorm
2929 MYSQL_USER: gorm
8282 code: |
8383 cd $WERCKER_SOURCE_DIR
8484 go version
85 go get -t ./...
85 go get -t -v ./...
8686
8787 # Build the project
8888 - script:
9494 - script:
9595 name: test sqlite
9696 code: |
97 go test ./...
97 go test -race -v ./...
9898
9999 - script:
100100 name: test mariadb
101101 code: |
102 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./...
102 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
103
104 - script:
105 name: test mysql
106 code: |
107 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
103108
104109 - script:
105110 name: test mysql5.7
106111 code: |
107 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test ./...
112 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
108113
109114 - script:
110115 name: test mysql5.6
111116 code: |
112 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test ./...
113
114 - script:
115 name: test mysql5.5
116 code: |
117 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test ./...
117 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
118118
119119 - script:
120120 name: test postgres
121121 code: |
122 GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
122 GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
123123
124124 - script:
125125 name: test postgres96
126126 code: |
127 GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
127 GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
128128
129129 - script:
130130 name: test postgres95
131131 code: |
132 GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
132 GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
133133
134134 - script:
135135 name: test postgres94
136136 code: |
137 GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
137 GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
138138
139139 - script:
140140 name: test postgres93
141141 code: |
142 GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
142 GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
143143
144144 - script:
145 name: test mssql
145 name: codecov
146146 code: |
147 GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test ./...
147 go test -race -coverprofile=coverage.txt -covermode=atomic ./...
148 bash <(curl -s https://codecov.io/bash)