Codebase list golang-github-jinzhu-gorm / 00fe5d3
Import upstream version 1.9.16+git20211120.1.5c235b7 Debian Janitor 2 years ago
46 changed file(s) with 2321 addition(s) and 520 deletion(s). Raw diff Collapse all Expand all
+0
-11
.codeclimate.yml less more
0 ---
1 engines:
2 gofmt:
3 enabled: true
4 govet:
5 enabled: true
6 golint:
7 enabled: true
8 ratings:
9 paths:
10 - "**.go"
00 documents
1 coverage.txt
12 _book
00 # GORM
11
2 The fantastic ORM library for Golang, aims to be developer friendly.
2 GORM V2 moved to https://github.com/go-gorm/gorm
33
4 [![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm)
5 [![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b)
6 [![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
7 [![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm)
8 [![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm)
9 [![MIT license](http://img.shields.io/badge/license-MIT-brightgreen.svg)](http://opensource.org/licenses/MIT)
10 [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm)
11
12 ## Overview
13
14 * Full-Featured ORM (almost)
15 * Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism)
16 * Hooks (Before/After Create/Save/Update/Delete/Find)
17 * Preloading (eager loading)
18 * Transactions
19 * Composite Primary Key
20 * SQL Builder
21 * Auto Migrations
22 * Logger
23 * Extendable, write Plugins based on GORM callbacks
24 * Every feature comes with tests
25 * Developer Friendly
26
27 ## Getting Started
28
29 * GORM Guides [http://gorm.io](http://gorm.io)
30
31 ## Contributing
32
33 [You can help to deliver a better GORM, check out things you can do](http://gorm.io/contribute.html)
34
35 ## License
36
37 © Jinzhu, 2013~time.Now
38
39 Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License)
4 GORM V1 Doc https://v1.gorm.io/
266266 query = scope.DB()
267267 )
268268
269 if relationship.Kind == "many_to_many" {
269 switch relationship.Kind {
270 case "many_to_many":
270271 query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
271 } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
272 case "has_many", "has_one":
272273 primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
273274 query = query.Where(
274275 fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
275276 toQueryValues(primaryKeys)...,
276277 )
277 } else if relationship.Kind == "belongs_to" {
278 case "belongs_to":
278279 primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
279280 query = query.Where(
280281 fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
366367 return association
367368 }
368369
370 // setErr set error when the error is not nil. And return Association.
369371 func (association *Association) setErr(err error) *Association {
370372 if err != nil {
371373 association.Error = err
00 package gorm
11
2 import "log"
2 import "fmt"
33
44 // DefaultCallback default callbacks defined by gorm
5 var DefaultCallback = &Callback{}
5 var DefaultCallback = &Callback{logger: nopLogger{}}
66
77 // Callback is a struct that contains all CRUD callbacks
88 // Field `creates` contains callbacks will be call when creating object
1212 // Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
1313 // Field `processors` contains all callback processors, will be used to generate above callbacks in order
1414 type Callback struct {
15 logger logger
1516 creates []*func(scope *Scope)
1617 updates []*func(scope *Scope)
1718 deletes []*func(scope *Scope)
2223
2324 // CallbackProcessor contains callback informations
2425 type CallbackProcessor struct {
26 logger logger
2527 name string // current callback's name
2628 before string // register current callback before a callback
2729 after string // register current callback after a callback
3234 parent *Callback
3335 }
3436
35 func (c *Callback) clone() *Callback {
37 func (c *Callback) clone(logger logger) *Callback {
3638 return &Callback{
39 logger: logger,
3740 creates: c.creates,
3841 updates: c.updates,
3942 deletes: c.deletes,
5255 // scope.Err(errors.New("error"))
5356 // })
5457 func (c *Callback) Create() *CallbackProcessor {
55 return &CallbackProcessor{kind: "create", parent: c}
58 return &CallbackProcessor{logger: c.logger, kind: "create", parent: c}
5659 }
5760
5861 // Update could be used to register callbacks for updating object, refer `Create` for usage
5962 func (c *Callback) Update() *CallbackProcessor {
60 return &CallbackProcessor{kind: "update", parent: c}
63 return &CallbackProcessor{logger: c.logger, kind: "update", parent: c}
6164 }
6265
6366 // Delete could be used to register callbacks for deleting object, refer `Create` for usage
6467 func (c *Callback) Delete() *CallbackProcessor {
65 return &CallbackProcessor{kind: "delete", parent: c}
68 return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c}
6669 }
6770
6871 // Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
6972 // Refer `Create` for usage
7073 func (c *Callback) Query() *CallbackProcessor {
71 return &CallbackProcessor{kind: "query", parent: c}
74 return &CallbackProcessor{logger: c.logger, kind: "query", parent: c}
7275 }
7376
7477 // RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
7578 func (c *Callback) RowQuery() *CallbackProcessor {
76 return &CallbackProcessor{kind: "row_query", parent: c}
79 return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c}
7780 }
7881
7982 // After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
9295 func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
9396 if cp.kind == "row_query" {
9497 if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
95 log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
98 cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
9699 cp.before = "gorm:row_query"
97100 }
98101 }
99102
103 cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
100104 cp.name = callbackName
101105 cp.processor = &callback
102106 cp.parent.processors = append(cp.parent.processors, cp)
106110 // Remove a registered callback
107111 // db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
108112 func (cp *CallbackProcessor) Remove(callbackName string) {
109 log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
113 cp.logger.Print("info", fmt.Sprintf("[info] removing callback `%v` from %v", callbackName, fileWithLineNum()))
110114 cp.name = callbackName
111115 cp.remove = true
112116 cp.parent.processors = append(cp.parent.processors, cp)
115119
116120 // Replace a registered callback with new callback
117121 // db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
118 // scope.SetColumn("Created", now)
119 // scope.SetColumn("Updated", now)
122 // scope.SetColumn("CreatedAt", now)
123 // scope.SetColumn("UpdatedAt", now)
120124 // })
121125 func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
122 log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
126 cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
123127 cp.name = callbackName
124128 cp.processor = &callback
125129 cp.replace = true
131135 // db.Callback().Create().Get("gorm:create")
132136 func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
133137 for _, p := range cp.parent.processors {
134 if p.name == callbackName && p.kind == cp.kind && !cp.remove {
135 return *p.processor
136 }
137 }
138 return nil
138 if p.name == callbackName && p.kind == cp.kind {
139 if p.remove {
140 callback = nil
141 } else {
142 callback = *p.processor
143 }
144 }
145 }
146 return
139147 }
140148
141149 // getRIndex get right index from string slice
158166 for _, cp := range cps {
159167 // show warning message the callback name already exists
160168 if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
161 log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
169 cp.logger.Print("warning", fmt.Sprintf("[warning] duplicated callback `%v` from %v", cp.name, fileWithLineNum()))
162170 }
163171 allNames = append(allNames, cp.name)
164172 }
3030 // updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
3131 func updateTimeStampForCreateCallback(scope *Scope) {
3232 if !scope.HasError() {
33 now := NowFunc()
33 now := scope.db.nowFunc()
3434
3535 if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
3636 if createdAtField.IsBlank {
5858
5959 for _, field := range scope.Fields() {
6060 if scope.changeableField(field) {
61 if field.IsNormal {
61 if field.IsNormal && !field.IsIgnored {
6262 if field.IsBlank && field.HasDefaultValue {
6363 blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
6464 scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
8282 quotedTableName = scope.QuotedTableName()
8383 primaryField = scope.PrimaryField()
8484 extraOption string
85 insertModifier string
8586 )
8687
8788 if str, ok := scope.Get("gorm:insert_option"); ok {
8889 extraOption = fmt.Sprint(str)
8990 }
91 if str, ok := scope.Get("gorm:insert_modifier"); ok {
92 insertModifier = strings.ToUpper(fmt.Sprint(str))
93 if insertModifier == "INTO" {
94 insertModifier = ""
95 }
96 }
9097
9198 if primaryField != nil {
9299 returningColumn = scope.Quote(primaryField.DBName)
93100 }
94101
95 lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
102 lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns)
103 var lastInsertIDReturningSuffix string
104 if lastInsertIDOutputInterstitial == "" {
105 lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
106 }
96107
97108 if len(columns) == 0 {
98109 scope.Raw(fmt.Sprintf(
99 "INSERT INTO %v %v%v%v",
110 "INSERT%v INTO %v %v%v%v",
111 addExtraSpaceIfExist(insertModifier),
100112 quotedTableName,
101113 scope.Dialect().DefaultValueStr(),
102114 addExtraSpaceIfExist(extraOption),
104116 ))
105117 } else {
106118 scope.Raw(fmt.Sprintf(
107 "INSERT INTO %v (%v) VALUES (%v)%v%v",
119 "INSERT%v INTO %v (%v)%v VALUES (%v)%v%v",
120 addExtraSpaceIfExist(insertModifier),
108121 scope.QuotedTableName(),
109122 strings.Join(columns, ","),
123 addExtraSpaceIfExist(lastInsertIDOutputInterstitial),
110124 strings.Join(placeholders, ","),
111125 addExtraSpaceIfExist(extraOption),
112126 addExtraSpaceIfExist(lastInsertIDReturningSuffix),
113127 ))
114128 }
115129
116 // execute create sql
117 if lastInsertIDReturningSuffix == "" || primaryField == nil {
130 // execute create sql: no primaryField
131 if primaryField == nil {
118132 if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
119133 // set rows affected count
120134 scope.db.RowsAffected, _ = result.RowsAffected()
126140 }
127141 }
128142 }
143 return
144 }
145
146 // execute create sql: lastInsertID implemention for majority of dialects
147 if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
148 if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
149 // set rows affected count
150 scope.db.RowsAffected, _ = result.RowsAffected()
151
152 // set primary value to primary field
153 if primaryField != nil && primaryField.IsBlank {
154 if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
155 scope.Err(primaryField.Set(primaryValue))
156 }
157 }
158 }
159 return
160 }
161
162 // execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
163 if primaryField.Field.CanAddr() {
164 if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
165 primaryField.IsBlank = false
166 scope.db.RowsAffected = 1
167 }
129168 } else {
130 if primaryField.Field.CanAddr() {
131 if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
132 primaryField.IsBlank = false
133 scope.db.RowsAffected = 1
134 }
135 } else {
136 scope.Err(ErrUnaddressable)
137 }
138 }
169 scope.Err(ErrUnaddressable)
170 }
171 return
139172 }
140173 }
141174
142175 // forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
143176 func forceReloadAfterCreateCallback(scope *Scope) {
144177 if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
178 var shouldScan bool
145179 db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
146180 for _, field := range scope.Fields() {
147181 if field.IsPrimaryKey && !field.IsBlank {
148182 db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface())
149 }
150 }
183 shouldScan = true
184 }
185 }
186
187 if !shouldScan {
188 return
189 }
190
151191 db.Scan(scope.Value)
152192 }
153193 }
1616 // beforeDeleteCallback will invoke `BeforeDelete` method before deleting
1717 func beforeDeleteCallback(scope *Scope) {
1818 if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
19 scope.Err(errors.New("Missing WHERE clause while deleting"))
19 scope.Err(errors.New("missing WHERE clause while deleting"))
2020 return
2121 }
2222 if !scope.HasError() {
3939 "UPDATE %v SET %v=%v%v%v",
4040 scope.QuotedTableName(),
4141 scope.Quote(deletedAtField.DBName),
42 scope.AddToVars(NowFunc()),
42 scope.AddToVars(scope.db.nowFunc()),
4343 addExtraSpaceIfExist(scope.CombinedConditionSql()),
4444 addExtraSpaceIfExist(extraOption),
4545 )).Exec()
1515 // queryCallback used to query data from database
1616 func queryCallback(scope *Scope) {
1717 if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
18 return
19 }
20
21 //we are only preloading relations, dont touch base model
22 if _, skip := scope.InstanceGet("gorm:only_preload"); skip {
1823 return
1924 }
2025
5459
5560 if !scope.HasError() {
5661 scope.db.RowsAffected = 0
57 if str, ok := scope.Get("gorm:query_option"); ok {
58 scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
62
63 if str, ok := scope.Get("gorm:query_hint"); ok {
64 scope.SQL = fmt.Sprint(str) + scope.SQL
5965 }
6066
6167 if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
1313 return
1414 }
1515
16 if _, ok := scope.Get("gorm:auto_preload"); ok {
17 autoPreload(scope)
16 if ap, ok := scope.Get("gorm:auto_preload"); ok {
17 // If gorm:auto_preload IS NOT a bool then auto preload.
18 // Else if it IS a bool, use the value
19 if apb, ok := ap.(bool); !ok {
20 autoPreload(scope)
21 } else if apb {
22 autoPreload(scope)
23 }
1824 }
1925
2026 if scope.Search.preload == nil || scope.HasError() {
9399 continue
94100 }
95101
96 if val, ok := field.TagSettings["PRELOAD"]; ok {
102 if val, ok := field.TagSettingsGet("PRELOAD"); ok {
97103 if preload, err := strconv.ParseBool(val); err != nil {
98104 scope.Err(errors.New("invalid preload option"))
99105 return
154160 )
155161
156162 if indirectScopeValue.Kind() == reflect.Slice {
163 foreignValuesToResults := make(map[string]reflect.Value)
164 for i := 0; i < resultsValue.Len(); i++ {
165 result := resultsValue.Index(i)
166 foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames))
167 foreignValuesToResults[foreignValues] = result
168 }
157169 for j := 0; j < indirectScopeValue.Len(); j++ {
158 for i := 0; i < resultsValue.Len(); i++ {
159 result := resultsValue.Index(i)
160 foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
161 if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
162 indirectValue.FieldByName(field.Name).Set(result)
163 break
164 }
170 indirectValue := indirect(indirectScopeValue.Index(j))
171 valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames))
172 if result, found := foreignValuesToResults[valueString]; found {
173 indirectValue.FieldByName(field.Name).Set(result)
165174 }
166175 }
167176 } else {
248257 indirectScopeValue = scope.IndirectValue()
249258 )
250259
260 foreignFieldToObjects := make(map[string][]*reflect.Value)
261 if indirectScopeValue.Kind() == reflect.Slice {
262 for j := 0; j < indirectScopeValue.Len(); j++ {
263 object := indirect(indirectScopeValue.Index(j))
264 valueString := toString(getValueFromFields(object, relation.ForeignFieldNames))
265 foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object)
266 }
267 }
268
251269 for i := 0; i < resultsValue.Len(); i++ {
252270 result := resultsValue.Index(i)
253271 if indirectScopeValue.Kind() == reflect.Slice {
254 value := getValueFromFields(result, relation.AssociationForeignFieldNames)
255 for j := 0; j < indirectScopeValue.Len(); j++ {
256 object := indirect(indirectScopeValue.Index(j))
257 if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
272 valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames))
273 if objects, found := foreignFieldToObjects[valueString]; found {
274 for _, object := range objects {
258275 object.FieldByName(field.Name).Set(result)
259276 }
260277 }
373390 key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
374391 fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
375392 }
376 for source, link := range linkHash {
377 for i, field := range fieldsSourceMap[source] {
393
394 for source, fields := range fieldsSourceMap {
395 for _, f := range fields {
378396 //If not 0 this means Value is a pointer and we already added preloaded models to it
379 if fieldsSourceMap[source][i].Len() != 0 {
397 if f.Len() != 0 {
380398 continue
381399 }
382 field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
383 }
384
385 }
386 }
400
401 v := reflect.MakeSlice(f.Type(), 0, 0)
402 if len(linkHash[source]) > 0 {
403 v = reflect.Append(f, linkHash[source]...)
404 }
405
406 f.Set(v)
407 }
408 }
409 }
00 package gorm
11
2 import "database/sql"
2 import (
3 "database/sql"
4 "fmt"
5 )
36
47 // Define callbacks for row query
58 func init() {
2023 if result, ok := scope.InstanceGet("row_query_result"); ok {
2124 scope.prepareQuerySQL()
2225
26 if str, ok := scope.Get("gorm:query_hint"); ok {
27 scope.SQL = fmt.Sprint(str) + scope.SQL
28 }
29
2330 if rowResult, ok := result.(*RowQueryResult); ok {
2431 rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
2532 } else if rowsResult, ok := result.(*RowsQueryResult); ok {
2020
2121 if v, ok := value.(string); ok {
2222 v = strings.ToLower(v)
23 if v == "false" || v != "skip" {
24 return false
25 }
23 return v == "true"
2624 }
2725
2826 return true
3533 if value, ok := scope.Get("gorm:save_associations"); ok {
3634 autoUpdate = checkTruth(value)
3735 autoCreate = autoUpdate
38 } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok {
36 saveReference = autoUpdate
37 } else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok {
3938 autoUpdate = checkTruth(value)
4039 autoCreate = autoUpdate
40 saveReference = autoUpdate
4141 }
4242
4343 if value, ok := scope.Get("gorm:association_autoupdate"); ok {
4444 autoUpdate = checkTruth(value)
45 } else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok {
45 } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok {
4646 autoUpdate = checkTruth(value)
4747 }
4848
4949 if value, ok := scope.Get("gorm:association_autocreate"); ok {
5050 autoCreate = checkTruth(value)
51 } else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok {
51 } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok {
5252 autoCreate = checkTruth(value)
5353 }
5454
5555 if value, ok := scope.Get("gorm:association_save_reference"); ok {
5656 saveReference = checkTruth(value)
57 } else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok {
57 } else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok {
5858 saveReference = checkTruth(value)
5959 }
6060 }
2222 func afterCreate2(s *Scope) {}
2323
2424 func TestRegisterCallback(t *testing.T) {
25 var callback = &Callback{}
25 var callback = &Callback{logger: defaultLogger}
2626
2727 callback.Create().Register("before_create1", beforeCreate1)
2828 callback.Create().Register("before_create2", beforeCreate2)
3636 }
3737
3838 func TestRegisterCallbackWithOrder(t *testing.T) {
39 var callback1 = &Callback{}
39 var callback1 = &Callback{logger: defaultLogger}
4040 callback1.Create().Register("before_create1", beforeCreate1)
4141 callback1.Create().Register("create", create)
4242 callback1.Create().Register("after_create1", afterCreate1)
4545 t.Errorf("register callback with order")
4646 }
4747
48 var callback2 = &Callback{}
48 var callback2 = &Callback{logger: defaultLogger}
4949
5050 callback2.Update().Register("create", create)
5151 callback2.Update().Before("create").Register("before_create1", beforeCreate1)
5959 }
6060
6161 func TestRegisterCallbackWithComplexOrder(t *testing.T) {
62 var callback1 = &Callback{}
62 var callback1 = &Callback{logger: defaultLogger}
6363
6464 callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
6565 callback1.Query().Register("before_create1", beforeCreate1)
6969 t.Errorf("register callback with order")
7070 }
7171
72 var callback2 = &Callback{}
72 var callback2 = &Callback{logger: defaultLogger}
7373
7474 callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
7575 callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
8585 func replaceCreate(s *Scope) {}
8686
8787 func TestReplaceCallback(t *testing.T) {
88 var callback = &Callback{}
88 var callback = &Callback{logger: defaultLogger}
8989
9090 callback.Create().Before("after_create1").After("before_create1").Register("create", create)
9191 callback.Create().Register("before_create1", beforeCreate1)
9898 }
9999
100100 func TestRemoveCallback(t *testing.T) {
101 var callback = &Callback{}
101 var callback = &Callback{logger: defaultLogger}
102102
103103 callback.Create().Before("after_create1").After("before_create1").Register("create", create)
104104 callback.Create().Register("before_create1", beforeCreate1)
3333 // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
3434 func beforeUpdateCallback(scope *Scope) {
3535 if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
36 scope.Err(errors.New("Missing WHERE clause while updating"))
36 scope.Err(errors.New("missing WHERE clause while updating"))
3737 return
3838 }
3939 if _, ok := scope.Get("gorm:update_column"); !ok {
4949 // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
5050 func updateTimeStampForUpdateCallback(scope *Scope) {
5151 if _, ok := scope.Get("gorm:update_column"); !ok {
52 scope.SetColumn("UpdatedAt", NowFunc())
52 scope.SetColumn("UpdatedAt", scope.db.nowFunc())
5353 }
5454 }
5555
7474 } else {
7575 for _, field := range scope.Fields() {
7676 if scope.changeableField(field) {
77 if !field.IsPrimaryKey && field.IsNormal {
78 sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
77 if !field.IsPrimaryKey && field.IsNormal && (field.Name != "CreatedAt" || !field.IsBlank) {
78 if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue {
79 sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
80 }
7981 } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
8082 for _, foreignKey := range relationship.ForeignDBNames {
8183 if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
11
22 import (
33 "errors"
4
5 "github.com/jinzhu/gorm"
6
74 "reflect"
85 "testing"
6
7 "github.com/jinzhu/gorm"
98 )
109
1110 func (s *Product) BeforeCreate() (err error) {
174173 t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
175174 }
176175 }
176
177 func TestGetCallback(t *testing.T) {
178 scope := DB.NewScope(nil)
179
180 if DB.Callback().Create().Get("gorm:test_callback") != nil {
181 t.Errorf("`gorm:test_callback` should be nil")
182 }
183
184 DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
185 callback := DB.Callback().Create().Get("gorm:test_callback")
186 if callback == nil {
187 t.Errorf("`gorm:test_callback` should be non-nil")
188 }
189 callback(scope)
190 if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
191 t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
192 }
193
194 DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
195 callback = DB.Callback().Create().Get("gorm:test_callback")
196 if callback == nil {
197 t.Errorf("`gorm:test_callback` should be non-nil")
198 }
199 callback(scope)
200 if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
201 t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
202 }
203
204 DB.Callback().Create().Remove("gorm:test_callback")
205 if DB.Callback().Create().Get("gorm:test_callback") != nil {
206 t.Errorf("`gorm:test_callback` should be nil")
207 }
208
209 DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
210 callback = DB.Callback().Create().Get("gorm:test_callback")
211 if callback == nil {
212 t.Errorf("`gorm:test_callback` should be non-nil")
213 }
214 callback(scope)
215 if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
216 t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
217 }
218 }
219
220 func TestUseDefaultCallback(t *testing.T) {
221 createCallbackName := "gorm:test_use_default_callback_for_create"
222 gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
223 // nop
224 })
225 if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
226 t.Errorf("`%s` expected non-nil, but got nil", createCallbackName)
227 }
228 gorm.DefaultCallback.Create().Remove(createCallbackName)
229 if gorm.DefaultCallback.Create().Get(createCallbackName) != nil {
230 t.Errorf("`%s` expected nil, but got non-nil", createCallbackName)
231 }
232
233 updateCallbackName := "gorm:test_use_default_callback_for_update"
234 scopeValueName := "gorm:test_use_default_callback_for_update_value"
235 gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
236 scope.Set(scopeValueName, 1)
237 })
238 gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
239 scope.Set(scopeValueName, 2)
240 })
241
242 scope := DB.NewScope(nil)
243 callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
244 callback(scope)
245 if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
246 t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
247 }
248 }
22 import (
33 "os"
44 "reflect"
5 "strings"
56 "testing"
67 "time"
78
89 "github.com/jinzhu/now"
10
11 "github.com/jinzhu/gorm"
912 )
1013
1114 func TestCreate(t *testing.T) {
97100
98101 if newUser.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
99102 t.Errorf("UpdatedAt should not be changed")
103 }
104 }
105
106 func TestCreateWithNowFuncOverride(t *testing.T) {
107 user1 := User{Name: "CreateUserTimestampOverride"}
108
109 timeA := now.MustParse("2016-01-01")
110
111 // do DB.New() because we don't want this test to affect other tests
112 db1 := DB.New()
113 // set the override to use static timeA
114 db1.SetNowFuncOverride(func() time.Time {
115 return timeA
116 })
117 // call .New again to check the override is carried over as well during clone
118 db1 = db1.New()
119
120 db1.Save(&user1)
121
122 if user1.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
123 t.Errorf("CreatedAt be using the nowFuncOverride")
124 }
125 if user1.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
126 t.Errorf("UpdatedAt be using the nowFuncOverride")
127 }
128
129 // now create another user with a fresh DB.Now() that doesn't have the nowFuncOverride set
130 // to make sure that setting it only affected the above instance
131
132 user2 := User{Name: "CreateUserTimestampOverrideNoMore"}
133
134 db2 := DB.New()
135
136 db2.Save(&user2)
137
138 if user2.CreatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) {
139 t.Errorf("CreatedAt no longer be using the nowFuncOverride")
140 }
141 if user2.UpdatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) {
142 t.Errorf("UpdatedAt no longer be using the nowFuncOverride")
100143 }
101144 }
102145
228271 t.Errorf("Should not create omitted relationships")
229272 }
230273 }
274
275 func TestCreateIgnore(t *testing.T) {
276 float := 35.03554004971999
277 now := time.Now()
278 user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float}
279
280 if !DB.NewRecord(user) || !DB.NewRecord(&user) {
281 t.Error("User should be new record before create")
282 }
283
284 if count := DB.Create(&user).RowsAffected; count != 1 {
285 t.Error("There should be one record be affected when create record")
286 }
287 if DB.Dialect().GetName() == "mysql" && DB.Set("gorm:insert_modifier", "IGNORE").Create(&user).Error != nil {
288 t.Error("Should ignore duplicate user insert by insert modifier:IGNORE ")
289 }
290 }
291
292 func TestFixFullTableScanWhenInsertIgnore(t *testing.T) {
293 pandaYuanYuan := Panda{Number: 200408301001}
294
295 if !DB.NewRecord(pandaYuanYuan) || !DB.NewRecord(&pandaYuanYuan) {
296 t.Error("Panda should be new record before create")
297 }
298
299 if count := DB.Create(&pandaYuanYuan).RowsAffected; count != 1 {
300 t.Error("There should be one record be affected when create record")
301 }
302
303 DB.Callback().Query().Register("gorm:fix_full_table_scan", func(scope *gorm.Scope) {
304 if strings.Contains(scope.SQL, "SELECT") && strings.Contains(scope.SQL, "pandas") && len(scope.SQLVars) == 0 {
305 t.Error("Should skip force reload when ignore duplicate panda insert")
306 }
307 })
308
309 if DB.Dialect().GetName() == "mysql" && DB.Set("gorm:insert_modifier", "IGNORE").Create(&pandaYuanYuan).Error != nil {
310 t.Error("Should ignore duplicate panda insert by insert modifier:IGNORE ")
311 }
312 }
288288 func TestSelfReferencingMany2ManyColumn(t *testing.T) {
289289 DB.DropTable(&SelfReferencingUser{}, "UserFriends")
290290 DB.AutoMigrate(&SelfReferencingUser{})
291 if !DB.HasTable("UserFriends") {
292 t.Errorf("auto migrate error, table UserFriends should be created")
293 }
291294
292295 friend1 := SelfReferencingUser{Name: "friend1_m2m"}
293296 if err := DB.Create(&friend1).Error; err != nil {
310313
311314 if DB.Model(&user).Association("Friends").Count() != 2 {
312315 t.Errorf("Should find created friends correctly")
316 }
317
318 var count int
319 if err := DB.Table("UserFriends").Count(&count).Error; err != nil {
320 t.Errorf("no error should happen, but got %v", err)
321 }
322 if count == 0 {
323 t.Errorf("table UserFriends should have records")
313324 }
314325
315326 var newUser = SelfReferencingUser{}
3636 ModifyColumn(tableName string, columnName string, typ string) error
3737
3838 // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
39 LimitAndOffsetSQL(limit, offset interface{}) string
39 LimitAndOffsetSQL(limit, offset interface{}) (string, error)
4040 // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
4141 SelectFromDummyTable() string
42 // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
43 LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
4244 // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
4345 LastInsertIDReturningSuffix(tableName, columnName string) string
4446 // DefaultValueStr
4648
4749 // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
4850 BuildKeyName(kind, tableName string, fields ...string) string
51
52 // NormalizeIndexAndColumn returns valid index name and column name depending on each dialect
53 NormalizeIndexAndColumn(indexName, columnName string) (string, string)
4954
5055 // CurrentDatabase return current database name
5156 CurrentDatabase() string
7176 dialectsMap[name] = dialect
7277 }
7378
79 // GetDialect gets the dialect for the specified dialect name
80 func GetDialect(name string) (dialect Dialect, ok bool) {
81 dialect, ok = dialectsMap[name]
82 return
83 }
84
7485 // ParseFieldStructForDialect get field's sql data type
7586 var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
7687 // Get redirected field type
7788 var (
7889 reflectType = field.Struct.Type
79 dataType = field.TagSettings["TYPE"]
90 dataType, _ = field.TagSettingsGet("TYPE")
8091 )
8192
8293 for reflectType.Kind() == reflect.Ptr {
105116 }
106117
107118 // Default Size
108 if num, ok := field.TagSettings["SIZE"]; ok {
119 if num, ok := field.TagSettingsGet("SIZE"); ok {
109120 size, _ = strconv.Atoi(num)
110121 } else {
111122 size = 255
112123 }
113124
114125 // Default type from tag setting
115 additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
116 if value, ok := field.TagSettings["DEFAULT"]; ok {
126 notNull, _ := field.TagSettingsGet("NOT NULL")
127 unique, _ := field.TagSettingsGet("UNIQUE")
128 additionalType = notNull + " " + unique
129 if value, ok := field.TagSettingsGet("DEFAULT"); ok {
117130 additionalType = additionalType + " DEFAULT " + value
131 }
132
133 if value, ok := field.TagSettingsGet("COMMENT"); ok && dialect.GetName() != "sqlite3" {
134 additionalType = additionalType + " COMMENT " + value
118135 }
119136
120137 return fieldValue, dataType, size, strings.TrimSpace(additionalType)
77 "strings"
88 "time"
99 )
10
11 var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+")
1012
1113 // DefaultForeignKeyNamer contains the default foreign key name generator method
1214 type DefaultForeignKeyNamer struct {
3840 }
3941
4042 func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool {
41 if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
43 if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
4244 return strings.ToLower(value) != "false"
4345 }
4446 return field.IsPrimaryKey
136138 return
137139 }
138140
139 func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
141 // LimitAndOffsetSQL return generated SQL with Limit and Offset
142 func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
140143 if limit != nil {
141 if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
144 if parsedLimit, err := s.parseInt(limit); err != nil {
145 return "", err
146 } else if parsedLimit >= 0 {
142147 sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
143148 }
144149 }
145150 if offset != nil {
146 if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
151 if parsedOffset, err := s.parseInt(offset); err != nil {
152 return "", err
153 } else if parsedOffset >= 0 {
147154 sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
148155 }
149156 }
151158 }
152159
153160 func (commonDialect) SelectFromDummyTable() string {
161 return ""
162 }
163
164 func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
154165 return ""
155166 }
156167
165176 // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
166177 func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
167178 keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
168 keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_")
179 keyName = keyNameRegex.ReplaceAllString(keyName, "_")
169180 return keyName
181 }
182
183 // NormalizeIndexAndColumn returns argument's index name and column name without doing anything
184 func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
185 return indexName, columnName
186 }
187
188 func (commonDialect) parseInt(value interface{}) (int64, error) {
189 return strconv.ParseInt(fmt.Sprint(value), 0, 0)
170190 }
171191
172192 // IsByteArrayOrSlice returns true of the reflected value is an array or slice
11
22 import (
33 "crypto/sha1"
4 "database/sql"
45 "fmt"
56 "reflect"
67 "regexp"
7 "strconv"
88 "strings"
99 "time"
1010 "unicode/utf8"
1111 )
1212
13 var mysqlIndexRegex = regexp.MustCompile(`^(.+)\((\d+)\)$`)
14
1315 type mysql struct {
1416 commonDialect
1517 }
3234
3335 // MySQL allows only one auto increment column per table, and it must
3436 // be a KEY column.
35 if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
36 if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey {
37 delete(field.TagSettings, "AUTO_INCREMENT")
37 if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
38 if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey {
39 field.TagSettingsDelete("AUTO_INCREMENT")
3840 }
3941 }
4042
4446 sqlType = "boolean"
4547 case reflect.Int8:
4648 if s.fieldCanAutoIncrement(field) {
47 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
49 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
4850 sqlType = "tinyint AUTO_INCREMENT"
4951 } else {
5052 sqlType = "tinyint"
5153 }
5254 case reflect.Int, reflect.Int16, reflect.Int32:
5355 if s.fieldCanAutoIncrement(field) {
54 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
56 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
5557 sqlType = "int AUTO_INCREMENT"
5658 } else {
5759 sqlType = "int"
5860 }
5961 case reflect.Uint8:
6062 if s.fieldCanAutoIncrement(field) {
61 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
63 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
6264 sqlType = "tinyint unsigned AUTO_INCREMENT"
6365 } else {
6466 sqlType = "tinyint unsigned"
6567 }
6668 case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
6769 if s.fieldCanAutoIncrement(field) {
68 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
70 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
6971 sqlType = "int unsigned AUTO_INCREMENT"
7072 } else {
7173 sqlType = "int unsigned"
7274 }
7375 case reflect.Int64:
7476 if s.fieldCanAutoIncrement(field) {
75 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
77 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
7678 sqlType = "bigint AUTO_INCREMENT"
7779 } else {
7880 sqlType = "bigint"
7981 }
8082 case reflect.Uint64:
8183 if s.fieldCanAutoIncrement(field) {
82 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
84 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
8385 sqlType = "bigint unsigned AUTO_INCREMENT"
8486 } else {
8587 sqlType = "bigint unsigned"
9597 case reflect.Struct:
9698 if _, ok := dataValue.Interface().(time.Time); ok {
9799 precision := ""
98 if p, ok := field.TagSettings["PRECISION"]; ok {
100 if p, ok := field.TagSettingsGet("PRECISION"); ok {
99101 precision = fmt.Sprintf("(%s)", p)
100102 }
101103
102 if _, ok := field.TagSettings["NOT NULL"]; ok {
103 sqlType = fmt.Sprintf("timestamp%v", precision)
104 if _, ok := field.TagSettings["NOT NULL"]; ok || field.IsPrimaryKey {
105 sqlType = fmt.Sprintf("DATETIME%v", precision)
104106 } else {
105 sqlType = fmt.Sprintf("timestamp%v NULL", precision)
107 sqlType = fmt.Sprintf("DATETIME%v NULL", precision)
106108 }
107109 }
108110 default:
117119 }
118120
119121 if sqlType == "" {
120 panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
122 panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name))
121123 }
122124
123125 if strings.TrimSpace(additionalType) == "" {
136138 return err
137139 }
138140
139 func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
141 func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
140142 if limit != nil {
141 if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
143 parsedLimit, err := s.parseInt(limit)
144 if err != nil {
145 return "", err
146 }
147 if parsedLimit >= 0 {
142148 sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
143149
144150 if offset != nil {
145 if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
151 parsedOffset, err := s.parseInt(offset)
152 if err != nil {
153 return "", err
154 }
155 if parsedOffset >= 0 {
146156 sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
147157 }
148158 }
156166 currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
157167 s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
158168 return count > 0
169 }
170
171 func (s mysql) HasTable(tableName string) bool {
172 currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
173 var name string
174 // allow mysql database name with '-' character
175 if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil {
176 if err == sql.ErrNoRows {
177 return false
178 }
179 panic(err)
180 } else {
181 return true
182 }
183 }
184
185 func (s mysql) HasIndex(tableName string, indexName string) bool {
186 currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
187 if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil {
188 panic(err)
189 } else {
190 defer rows.Close()
191 return rows.Next()
192 }
193 }
194
195 func (s mysql) HasColumn(tableName string, columnName string) bool {
196 currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
197 if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil {
198 panic(err)
199 } else {
200 defer rows.Close()
201 return rows.Next()
202 }
159203 }
160204
161205 func (s mysql) CurrentDatabase() (name string) {
177221 bs := h.Sum(nil)
178222
179223 // sha1 is 40 characters, keep first 24 characters of destination
180 destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_"))
224 destRunes := []rune(keyNameRegex.ReplaceAllString(fields[0], "_"))
181225 if len(destRunes) > 24 {
182226 destRunes = destRunes[:24]
183227 }
185229 return fmt.Sprintf("%s%x", string(destRunes), bs)
186230 }
187231
232 // NormalizeIndexAndColumn returns index name and column name for specify an index prefix length if needed
233 func (mysql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
234 submatch := mysqlIndexRegex.FindStringSubmatch(indexName)
235 if len(submatch) != 3 {
236 return indexName, columnName
237 }
238 indexName = submatch[1]
239 columnName = fmt.Sprintf("%s(%s)", columnName, submatch[2])
240 return indexName, columnName
241 }
242
188243 func (mysql) DefaultValueStr() string {
189244 return "VALUES()"
190245 }
3333 sqlType = "boolean"
3434 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
3535 if s.fieldCanAutoIncrement(field) {
36 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
36 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
3737 sqlType = "serial"
3838 } else {
3939 sqlType = "integer"
4040 }
4141 case reflect.Int64, reflect.Uint32, reflect.Uint64:
4242 if s.fieldCanAutoIncrement(field) {
43 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
43 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
4444 sqlType = "bigserial"
4545 } else {
4646 sqlType = "bigint"
4848 case reflect.Float32, reflect.Float64:
4949 sqlType = "numeric"
5050 case reflect.String:
51 if _, ok := field.TagSettings["SIZE"]; !ok {
51 if _, ok := field.TagSettingsGet("SIZE"); !ok {
5252 size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
5353 }
5454
119119 return
120120 }
121121
122 func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string {
123 return ""
124 }
125
122126 func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
123127 return fmt.Sprintf("RETURNING %v.%v", tableName, key)
124128 }
2828 sqlType = "bool"
2929 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
3030 if s.fieldCanAutoIncrement(field) {
31 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
31 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
3232 sqlType = "integer primary key autoincrement"
3333 } else {
3434 sqlType = "integer"
3535 }
3636 case reflect.Int64, reflect.Uint64:
3737 if s.fieldCanAutoIncrement(field) {
38 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
38 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
3939 sqlType = "integer primary key autoincrement"
4040 } else {
4141 sqlType = "bigint"
00 package mssql
11
22 import (
3 "database/sql/driver"
4 "encoding/json"
5 "errors"
36 "fmt"
47 "reflect"
58 "strconv"
69 "strings"
710 "time"
811
12 // Importing mssql driver package only in dialect file, otherwide not needed
913 _ "github.com/denisenkom/go-mssqldb"
1014 "github.com/jinzhu/gorm"
1115 )
1317 func setIdentityInsert(scope *gorm.Scope) {
1418 if scope.Dialect().GetName() == "mssql" {
1519 for _, field := range scope.PrimaryFields() {
16 if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank {
20 if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank {
1721 scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
1822 scope.InstanceSet("mssql:identity_insert_on", true)
1923 }
6569 sqlType = "bit"
6670 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
6771 if s.fieldCanAutoIncrement(field) {
68 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
72 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
6973 sqlType = "int IDENTITY(1,1)"
7074 } else {
7175 sqlType = "int"
7276 }
7377 case reflect.Int64, reflect.Uint64:
7478 if s.fieldCanAutoIncrement(field) {
75 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
79 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
7680 sqlType = "bigint IDENTITY(1,1)"
7781 } else {
7882 sqlType = "bigint"
111115 }
112116
113117 func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
114 if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
118 if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
115119 return value != "FALSE"
116120 }
117121 return field.IsPrimaryKey
129133 }
130134
131135 func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
132 return false
136 var count int
137 currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
138 s.db.QueryRow(`SELECT count(*)
139 FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id
140 inner join information_schema.tables as I on I.TABLE_NAME = T.name
141 WHERE F.name = ?
142 AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count)
143 return count > 0
133144 }
134145
135146 func (s mssql) HasTable(tableName string) bool {
156167 return
157168 }
158169
159 func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
170 func parseInt(value interface{}) (int64, error) {
171 return strconv.ParseInt(fmt.Sprint(value), 0, 0)
172 }
173
174 func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
160175 if offset != nil {
161 if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
176 if parsedOffset, err := parseInt(offset); err != nil {
177 return "", err
178 } else if parsedOffset >= 0 {
162179 sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
163180 }
164181 }
165182 if limit != nil {
166 if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
183 if parsedLimit, err := parseInt(limit); err != nil {
184 return "", err
185 } else if parsedLimit >= 0 {
167186 if sql == "" {
168187 // add default zero offset
169188 sql += " OFFSET 0 ROWS"
178197 return ""
179198 }
180199
200 func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
201 if len(columns) == 0 {
202 // No OUTPUT to query
203 return ""
204 }
205 return fmt.Sprintf("OUTPUT Inserted.%v", columnName)
206 }
207
181208 func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
182 return ""
209 // https://stackoverflow.com/questions/5228780/how-to-get-last-inserted-id
210 return "; SELECT SCOPE_IDENTITY()"
183211 }
184212
185213 func (mssql) DefaultValueStr() string {
186214 return "DEFAULT VALUES"
215 }
216
217 // NormalizeIndexAndColumn returns argument's index name and column name without doing anything
218 func (mssql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
219 return indexName, columnName
187220 }
188221
189222 func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
193226 }
194227 return dialect.CurrentDatabase(), tableName
195228 }
229
230 // JSON type to support easy handling of JSON data in character table fields
231 // using golang json.RawMessage for deferred decoding/encoding
232 type JSON struct {
233 json.RawMessage
234 }
235
236 // Value get value of JSON
237 func (j JSON) Value() (driver.Value, error) {
238 if len(j.RawMessage) == 0 {
239 return nil, nil
240 }
241 return j.MarshalJSON()
242 }
243
244 // Scan scan value into JSON
245 func (j *JSON) Scan(value interface{}) error {
246 str, ok := value.(string)
247 if !ok {
248 return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value))
249 }
250 bytes := []byte(str)
251 return json.Unmarshal(bytes, j)
252 }
33 "database/sql"
44 "database/sql/driver"
55
6 _ "github.com/lib/pq"
7 "github.com/lib/pq/hstore"
86 "encoding/json"
97 "errors"
108 "fmt"
9
10 _ "github.com/lib/pq"
11 "github.com/lib/pq/hstore"
1112 )
1213
1314 type Hstore map[string]*string
55 )
66
77 var (
8 // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct
8 // ErrRecordNotFound returns a "record not found error". Occurs only when attempting to query the database with a struct; querying with a slice won't return this error
99 ErrRecordNotFound = errors.New("record not found")
10 // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
10 // ErrInvalidSQL occurs when you attempt a query with invalid SQL
1111 ErrInvalidSQL = errors.New("invalid SQL")
12 // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
12 // ErrInvalidTransaction occurs when you are trying to `Commit` or `Rollback`
1313 ErrInvalidTransaction = errors.New("no valid transaction")
1414 // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
1515 ErrCantStartTransaction = errors.New("can't start transaction")
2020 // Errors contains all happened errors
2121 type Errors []error
2222
23 // IsRecordNotFoundError returns current error has record not found error or not
23 // IsRecordNotFoundError returns true if error contains a RecordNotFound error
2424 func IsRecordNotFoundError(err error) bool {
2525 if errs, ok := err.(Errors); ok {
2626 for _, err := range errs {
3232 return err == ErrRecordNotFound
3333 }
3434
35 // GetErrors gets all happened errors
35 // GetErrors gets all errors that have occurred and returns a slice of errors (Error type)
3636 func (errs Errors) GetErrors() []error {
3737 return errs
3838 }
3939
40 // Add adds an error
40 // Add adds an error to a given slice of errors
4141 func (errs Errors) Add(newErrors ...error) Errors {
4242 for _, err := range newErrors {
4343 if err == nil {
6161 return errs
6262 }
6363
64 // Error format happened errors
64 // Error takes a slice of all errors that have occurred and returns it as a formatted string
6565 func (errs Errors) Error() string {
6666 var errors = []string{}
6767 for _, e := range errs {
11
22 import (
33 "database/sql"
4 "database/sql/driver"
45 "errors"
56 "fmt"
67 "reflect"
4344 if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
4445 fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
4546 } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
46 err = scanner.Scan(reflectValue.Interface())
47 v := reflectValue.Interface()
48 if valuer, ok := v.(driver.Valuer); ok {
49 if v, err = valuer.Value(); err == nil {
50 err = scanner.Scan(v)
51 }
52 } else {
53 err = scanner.Scan(v)
54 }
4755 } else {
4856 err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
4957 }
00 package gorm_test
11
22 import (
3 "database/sql/driver"
4 "encoding/hex"
5 "fmt"
36 "testing"
47
58 "github.com/jinzhu/gorm"
4245
4346 if field, ok := scope.FieldByName("embedded_name"); !ok {
4447 t.Errorf("should find embedded field")
45 } else if _, ok := field.TagSettings["NOT NULL"]; !ok {
48 } else if _, ok := field.TagSettingsGet("NOT NULL"); !ok {
4649 t.Errorf("should find embedded field's tag settings")
4750 }
4851 }
52
53 type UUID [16]byte
54
55 type NullUUID struct {
56 UUID
57 Valid bool
58 }
59
60 func FromString(input string) (u UUID) {
61 src := []byte(input)
62 return FromBytes(src)
63 }
64
65 func FromBytes(src []byte) (u UUID) {
66 dst := u[:]
67 hex.Decode(dst[0:4], src[0:8])
68 hex.Decode(dst[4:6], src[9:13])
69 hex.Decode(dst[6:8], src[14:18])
70 hex.Decode(dst[8:10], src[19:23])
71 hex.Decode(dst[10:], src[24:])
72 return
73 }
74
75 func (u UUID) String() string {
76 buf := make([]byte, 36)
77 src := u[:]
78 hex.Encode(buf[0:8], src[0:4])
79 buf[8] = '-'
80 hex.Encode(buf[9:13], src[4:6])
81 buf[13] = '-'
82 hex.Encode(buf[14:18], src[6:8])
83 buf[18] = '-'
84 hex.Encode(buf[19:23], src[8:10])
85 buf[23] = '-'
86 hex.Encode(buf[24:], src[10:])
87 return string(buf)
88 }
89
90 func (u UUID) Value() (driver.Value, error) {
91 return u.String(), nil
92 }
93
94 func (u *UUID) Scan(src interface{}) error {
95 switch src := src.(type) {
96 case UUID: // support gorm convert from UUID to NullUUID
97 *u = src
98 return nil
99 case []byte:
100 *u = FromBytes(src)
101 return nil
102 case string:
103 *u = FromString(src)
104 return nil
105 }
106 return fmt.Errorf("uuid: cannot convert %T to UUID", src)
107 }
108
109 func (u *NullUUID) Scan(src interface{}) error {
110 u.Valid = true
111 return u.UUID.Scan(src)
112 }
113
114 func TestFieldSet(t *testing.T) {
115 type TestFieldSetNullUUID struct {
116 NullUUID NullUUID
117 }
118 scope := DB.NewScope(&TestFieldSetNullUUID{})
119 field := scope.Fields()[0]
120 err := field.Set(FromString("3034d44a-da03-11e8-b366-4a00070b9f00"))
121 if err != nil {
122 t.Fatal(err)
123 }
124 if id, ok := field.Field.Addr().Interface().(*NullUUID); !ok {
125 t.Fatal()
126 } else if !id.Valid || id.UUID.String() != "3034d44a-da03-11e8-b366-4a00070b9f00" {
127 t.Fatal(id)
128 }
129 }
0 module github.com/jinzhu/gorm
1
2 go 1.12
3
4 require (
5 github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd
6 github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5
7 github.com/go-sql-driver/mysql v1.5.0
8 github.com/jinzhu/inflection v1.0.0
9 github.com/jinzhu/now v1.0.1
10 github.com/lib/pq v1.1.1
11 github.com/mattn/go-sqlite3 v1.14.0
12 golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd // indirect
13 )
0 github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc=
1 github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y=
2 github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd h1:83Wprp6ROGeiHFAP8WJdI2RoxALQYgdllERc3N5N2DM=
3 github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
4 github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y=
5 github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0=
6 github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
7 github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
8 github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
9 github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
10 github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
11 github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
12 github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M=
13 github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
14 github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4=
15 github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
16 github.com/mattn/go-sqlite3 v1.14.0 h1:mLyGNKR8+Vv9CAU7PphKa2hkEqxxhn8i32J6FPj1/QA=
17 github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus=
18 github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw=
19 github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
20 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
21 golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI=
22 golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
23 golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd h1:GGJVjV8waZKRHrgwvtH66z9ZGVurTD1MT0n1Bb+q4aM=
24 golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
25 golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
26 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
27 golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
28 golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
29 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
30 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
31 golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
32 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
00 package gorm
11
2 import "database/sql"
2 import (
3 "context"
4 "database/sql"
5 )
36
47 // SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
58 type SQLCommon interface {
1114
1215 type sqlDb interface {
1316 Begin() (*sql.Tx, error)
17 BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
1418 }
1519
1620 type sqlTx interface {
3838
3939 messages = []interface{}{source, currentTime}
4040
41 if len(values) == 2 {
42 //remove the line break
43 currentTime = currentTime[1:]
44 //remove the brackets
45 source = fmt.Sprintf("\033[35m%v\033[0m", values[1])
46
47 messages = []interface{}{currentTime, source}
48 }
49
4150 if level == "sql" {
4251 // duration
4352 messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
4857 if indirectValue.IsValid() {
4958 value = indirectValue.Interface()
5059 if t, ok := value.(time.Time); ok {
51 formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05")))
60 if t.IsZero() {
61 formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00"))
62 } else {
63 formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05")))
64 }
5265 } else if b, ok := value.([]byte); ok {
5366 if str := string(b); isPrintable(str) {
5467 formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
6275 formattedValues = append(formattedValues, "NULL")
6376 }
6477 } else {
65 formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
78 switch value.(type) {
79 case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
80 formattedValues = append(formattedValues, fmt.Sprintf("%v", value))
81 default:
82 formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
83 }
6684 }
6785 } else {
6886 formattedValues = append(formattedValues, "NULL")
116134 func (logger Logger) Print(values ...interface{}) {
117135 logger.Println(LogFormatter(values...)...)
118136 }
137
138 type nopLogger struct{}
139
140 func (nopLogger) Print(values ...interface{}) {}
00 package gorm
11
22 import (
3 "context"
34 "database/sql"
45 "errors"
56 "fmt"
67 "reflect"
78 "strings"
9 "sync"
810 "time"
911 )
1012
1113 // DB contains information for current db connection
1214 type DB struct {
15 sync.RWMutex
1316 Value interface{}
1417 Error error
1518 RowsAffected int64
1720 // single db
1821 db SQLCommon
1922 blockGlobalUpdate bool
20 logMode int
23 logMode logModeValue
2124 logger logger
2225 search *search
23 values map[string]interface{}
26 values sync.Map
2427
2528 // global db
2629 parent *DB
2730 callbacks *Callback
2831 dialect Dialect
2932 singularTable bool
30 }
33
34 // function to be used to override the creating of a new timestamp
35 nowFuncOverride func() time.Time
36 }
37
38 type logModeValue int
39
40 const (
41 defaultLogMode logModeValue = iota
42 noLogMode
43 detailedLogMode
44 )
3145
3246 // Open initialize a new db connection, need to import driver first, e.g:
3347 //
4761 }
4862 var source string
4963 var dbSQL SQLCommon
64 var ownDbSQL bool
5065
5166 switch value := args[0].(type) {
5267 case string:
5873 source = args[1].(string)
5974 }
6075 dbSQL, err = sql.Open(driver, source)
76 ownDbSQL = true
6177 case SQLCommon:
6278 dbSQL = value
79 ownDbSQL = false
80 default:
81 return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
6382 }
6483
6584 db = &DB{
6685 db: dbSQL,
6786 logger: defaultLogger,
68 values: map[string]interface{}{},
69 callbacks: DefaultCallback,
87
88 // Create a clone of the default logger to avoid mutating a shared object when
89 // multiple gorm connections are created simultaneously.
90 callbacks: DefaultCallback.clone(defaultLogger),
7091 dialect: newDialect(dialect, dbSQL),
7192 }
7293 db.parent = db
7596 }
7697 // Send a ping to make sure the database connection is alive.
7798 if d, ok := dbSQL.(*sql.DB); ok {
78 if err = d.Ping(); err != nil {
99 if err = d.Ping(); err != nil && ownDbSQL {
79100 d.Close()
80101 }
81102 }
105126 // DB get `*sql.DB` from current connection
106127 // If the underlying database connection is not a *sql.DB, returns nil
107128 func (s *DB) DB() *sql.DB {
108 db, _ := s.db.(*sql.DB)
129 db, ok := s.db.(*sql.DB)
130 if !ok {
131 panic("can't support full GORM on currently status, maybe this is a TX instance.")
132 }
109133 return db
110134 }
111135
116140
117141 // Dialect get dialect
118142 func (s *DB) Dialect() Dialect {
119 return s.parent.dialect
143 return s.dialect
120144 }
121145
122146 // Callback return `Callbacks` container, you could add/change/delete callbacks with it
123147 // db.Callback().Create().Register("update_created_at", updateCreated)
124148 // Refer https://jinzhu.github.io/gorm/development.html#callbacks
125149 func (s *DB) Callback() *Callback {
126 s.parent.callbacks = s.parent.callbacks.clone()
150 s.parent.callbacks = s.parent.callbacks.clone(s.logger)
127151 return s.parent.callbacks
128152 }
129153
135159 // LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
136160 func (s *DB) LogMode(enable bool) *DB {
137161 if enable {
138 s.logMode = 2
139 } else {
140 s.logMode = 1
162 s.logMode = detailedLogMode
163 } else {
164 s.logMode = noLogMode
141165 }
142166 return s
167 }
168
169 // SetNowFuncOverride set the function to be used when creating a new timestamp
170 func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB {
171 s.nowFuncOverride = nowFuncOverride
172 return s
173 }
174
175 // Get a new timestamp, using the provided nowFuncOverride on the DB instance if set,
176 // otherwise defaults to the global NowFunc()
177 func (s *DB) nowFunc() time.Time {
178 if s.nowFuncOverride != nil {
179 return s.nowFuncOverride()
180 }
181
182 return NowFunc()
143183 }
144184
145185 // BlockGlobalUpdate if true, generates an error on update/delete without where clause.
156196
157197 // SingularTable use singular table by default
158198 func (s *DB) SingularTable(enable bool) {
159 modelStructsMap = newModelStructsMap()
199 s.parent.Lock()
200 defer s.parent.Unlock()
160201 s.parent.singularTable = enable
161202 }
162203
164205 func (s *DB) NewScope(value interface{}) *Scope {
165206 dbClone := s.clone()
166207 dbClone.Value = value
167 return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
168 }
169
170 // QueryExpr returns the query as expr object
171 func (s *DB) QueryExpr() *expr {
208 scope := &Scope{db: dbClone, Value: value}
209 if s.search != nil {
210 scope.Search = s.search.clone()
211 } else {
212 scope.Search = &search{}
213 }
214 return scope
215 }
216
217 // QueryExpr returns the query as SqlExpr object
218 func (s *DB) QueryExpr() *SqlExpr {
172219 scope := s.NewScope(s.Value)
173220 scope.InstanceSet("skip_bindvar", true)
174221 scope.prepareQuerySQL()
177224 }
178225
179226 // SubQuery returns the query as sub query
180 func (s *DB) SubQuery() *expr {
227 func (s *DB) SubQuery() *SqlExpr {
181228 scope := s.NewScope(s.Value)
182229 scope.InstanceSet("skip_bindvar", true)
183230 scope.prepareQuerySQL()
284331 func (s *DB) First(out interface{}, where ...interface{}) *DB {
285332 newScope := s.NewScope(out)
286333 newScope.Search.Limit(1)
334
287335 return newScope.Set("gorm:order_by_primary_key", "ASC").
288336 inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
289337 }
306354 // Find find records that match given conditions
307355 func (s *DB) Find(out interface{}, where ...interface{}) *DB {
308356 return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
357 }
358
359 //Preloads preloads relations, don`t touch out
360 func (s *DB) Preloads(out interface{}) *DB {
361 return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db
309362 }
310363
311364 // Scan scan value to a struct
386439 }
387440
388441 // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
442 // WARNING when update with struct, GORM will not update fields that with zero value
389443 func (s *DB) Update(attrs ...interface{}) *DB {
390444 return s.Updates(toSearchableMap(attrs...), true)
391445 }
418472 if !scope.PrimaryKeyZero() {
419473 newDB := scope.callCallbacks(s.parent.callbacks.updates).db
420474 if newDB.Error == nil && newDB.RowsAffected == 0 {
421 return s.New().FirstOrCreate(value)
475 return s.New().Table(scope.TableName()).FirstOrCreate(value)
422476 }
423477 return newDB
424478 }
432486 }
433487
434488 // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
489 // WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time
435490 func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
436491 return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
437492 }
475530 return s.clone().LogMode(true)
476531 }
477532
478 // Begin begin a transaction
533 // Transaction start a transaction as a block,
534 // return error will rollback, otherwise to commit.
535 func (s *DB) Transaction(fc func(tx *DB) error) (err error) {
536
537 if _, ok := s.db.(*sql.Tx); ok {
538 return fc(s)
539 }
540
541 panicked := true
542 tx := s.Begin()
543 defer func() {
544 // Make sure to rollback when panic, Block error or Commit error
545 if panicked || err != nil {
546 tx.Rollback()
547 }
548 }()
549
550 err = fc(tx)
551
552 if err == nil {
553 err = tx.Commit().Error
554 }
555
556 panicked = false
557 return
558 }
559
560 // Begin begins a transaction
479561 func (s *DB) Begin() *DB {
562 return s.BeginTx(context.Background(), &sql.TxOptions{})
563 }
564
565 // BeginTx begins a transaction with options
566 func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
480567 c := s.clone()
481568 if db, ok := c.db.(sqlDb); ok && db != nil {
482 tx, err := db.Begin()
569 tx, err := db.BeginTx(ctx, opts)
483570 c.db = interface{}(tx).(SQLCommon)
571
572 c.dialect.SetDB(c.db)
484573 c.AddError(err)
485574 } else {
486575 c.AddError(ErrCantStartTransaction)
490579
491580 // Commit commit a transaction
492581 func (s *DB) Commit() *DB {
493 if db, ok := s.db.(sqlTx); ok && db != nil {
582 var emptySQLTx *sql.Tx
583 if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
494584 s.AddError(db.Commit())
495585 } else {
496586 s.AddError(ErrInvalidTransaction)
500590
501591 // Rollback rollback a transaction
502592 func (s *DB) Rollback() *DB {
503 if db, ok := s.db.(sqlTx); ok && db != nil {
504 s.AddError(db.Rollback())
593 var emptySQLTx *sql.Tx
594 if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
595 if err := db.Rollback(); err != nil && err != sql.ErrTxDone {
596 s.AddError(err)
597 }
598 } else {
599 s.AddError(ErrInvalidTransaction)
600 }
601 return s
602 }
603
604 // RollbackUnlessCommitted rollback a transaction if it has not yet been
605 // committed.
606 func (s *DB) RollbackUnlessCommitted() *DB {
607 var emptySQLTx *sql.Tx
608 if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
609 err := db.Rollback()
610 // Ignore the error indicating that the transaction has already
611 // been committed.
612 if err != sql.ErrTxDone {
613 s.AddError(err)
614 }
505615 } else {
506616 s.AddError(ErrInvalidTransaction)
507617 }
669779
670780 // InstantSet instant set setting, will affect current db
671781 func (s *DB) InstantSet(name string, value interface{}) *DB {
672 s.values[name] = value
782 s.values.Store(name, value)
673783 return s
674784 }
675785
676786 // Get get setting by name
677787 func (s *DB) Get(name string) (value interface{}, ok bool) {
678 value, ok = s.values[name]
788 value, ok = s.values.Load(name)
679789 return
680790 }
681791
684794 scope := s.NewScope(source)
685795 for _, field := range scope.GetModelStruct().StructFields {
686796 if field.Name == column || field.DBName == column {
687 if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
797 if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" {
688798 source := (&Scope{Value: source}).GetModelStruct().ModelType
689799 destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
690800 handler.Setup(field.Relationship, many2many, source, destination)
701811 func (s *DB) AddError(err error) error {
702812 if err != nil {
703813 if err != ErrRecordNotFound {
704 if s.logMode == 0 {
705 go s.print(fileWithLineNum(), err)
814 if s.logMode == defaultLogMode {
815 go s.print("error", fileWithLineNum(), err)
706816 } else {
707817 s.log(err)
708818 }
739849 parent: s.parent,
740850 logger: s.logger,
741851 logMode: s.logMode,
742 values: map[string]interface{}{},
743852 Value: s.Value,
744853 Error: s.Error,
745854 blockGlobalUpdate: s.blockGlobalUpdate,
746 }
747
748 for key, value := range s.values {
749 db.values[key] = value
750 }
855 dialect: newDialect(s.dialect.GetName(), s.db),
856 nowFuncOverride: s.nowFuncOverride,
857 }
858
859 s.values.Range(func(k, v interface{}) bool {
860 db.values.Store(k, v)
861 return true
862 })
751863
752864 if s.search == nil {
753865 db.search = &search{limit: -1, offset: -1}
764876 }
765877
766878 func (s *DB) log(v ...interface{}) {
767 if s != nil && s.logMode == 2 {
879 if s != nil && s.logMode == detailedLogMode {
768880 s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
769881 }
770882 }
771883
772884 func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
773 if s.logMode == 2 {
885 if s.logMode == detailedLogMode {
774886 s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected)
775887 }
776888 }
00 package gorm_test
11
2 // Run tests
3 // $ docker-compose up
4 // $ ./test_all.sh
5
26 import (
7 "context"
38 "database/sql"
49 "database/sql/driver"
10 "errors"
511 "fmt"
612 "os"
713 "path/filepath"
814 "reflect"
15 "regexp"
16 "sort"
917 "strconv"
18 "strings"
19 "sync"
1020 "testing"
1121 "time"
1222
4656 case "postgres":
4757 fmt.Println("testing postgres...")
4858 if dbDSN == "" {
49 dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
59 dbDSN = "user=gorm password=gorm dbname=gorm port=9920 sslmode=disable"
5060 }
5161 db, err = gorm.Open("postgres", dbDSN)
5262 case "mssql":
7888 return
7989 }
8090
91 func TestOpen_ReturnsError_WithBadArgs(t *testing.T) {
92 stringRef := "foo"
93 testCases := []interface{}{42, time.Now(), &stringRef}
94 for _, tc := range testCases {
95 t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) {
96 _, err := gorm.Open("postgresql", tc)
97 if err == nil {
98 t.Error("Should got error with invalid database source")
99 }
100 if !strings.HasPrefix(err.Error(), "invalid database source:") {
101 t.Errorf("Should got error starting with \"invalid database source:\", but got %q", err.Error())
102 }
103 })
104 }
105 }
106
81107 func TestStringPrimaryKey(t *testing.T) {
82108 type UUIDStruct struct {
83109 ID string `gorm:"primary_key"`
156182 DB.Table("deleted_users").Find(&deletedUsers)
157183 if len(deletedUsers) != 1 {
158184 t.Errorf("Query from specified table")
185 }
186
187 var user User
188 DB.Table("deleted_users").First(&user, "name = ?", "DeletedUser")
189
190 user.Age = 20
191 DB.Table("deleted_users").Save(&user)
192 if DB.Table("deleted_users").First(&user, "name = ? AND age = ?", "DeletedUser", 20).RecordNotFound() {
193 t.Errorf("Failed to found updated user")
159194 }
160195
161196 DB.Save(getPreparedUser("normal_user", "reset_table"))
256291 if DB.NewScope([]Cart{}).TableName() != "shopping_cart" {
257292 t.Errorf("[]Cart's singular table name should be shopping_cart")
258293 }
294 DB.SingularTable(false)
295 }
296
297 func TestTableNameConcurrently(t *testing.T) {
298 DB := DB.Model("")
299 if DB.NewScope(Order{}).TableName() != "orders" {
300 t.Errorf("Order's table name should be orders")
301 }
302
303 var wg sync.WaitGroup
304 wg.Add(10)
305
306 for i := 1; i <= 10; i++ {
307 go func(db *gorm.DB) {
308 DB.SingularTable(true)
309 wg.Done()
310 }(DB)
311 }
312 wg.Wait()
313
314 if DB.NewScope(Order{}).TableName() != "order" {
315 t.Errorf("Order's singular table name should be order")
316 }
317
259318 DB.SingularTable(false)
260319 }
261320
376435 if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
377436 t.Errorf("Should be able to find committed record")
378437 }
438
439 tx3 := DB.Begin()
440 u3 := User{Name: "transcation-3"}
441 if err := tx3.Save(&u3).Error; err != nil {
442 t.Errorf("No error should raise")
443 }
444
445 if err := tx3.First(&User{}, "name = ?", "transcation-3").Error; err != nil {
446 t.Errorf("Should find saved record")
447 }
448
449 tx3.RollbackUnlessCommitted()
450
451 if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {
452 t.Errorf("Should not find record after rollback")
453 }
454
455 tx4 := DB.Begin()
456 u4 := User{Name: "transcation-4"}
457 if err := tx4.Save(&u4).Error; err != nil {
458 t.Errorf("No error should raise")
459 }
460
461 if err := tx4.First(&User{}, "name = ?", "transcation-4").Error; err != nil {
462 t.Errorf("Should find saved record")
463 }
464
465 tx4.Commit()
466
467 tx4.RollbackUnlessCommitted()
468
469 if err := DB.First(&User{}, "name = ?", "transcation-4").Error; err != nil {
470 t.Errorf("Should be able to find committed record")
471 }
472 }
473
474 func assertPanic(t *testing.T, f func()) {
475 defer func() {
476 if r := recover(); r == nil {
477 t.Errorf("The code did not panic")
478 }
479 }()
480 f()
481 }
482
483 func TestTransactionWithBlock(t *testing.T) {
484 // rollback
485 err := DB.Transaction(func(tx *gorm.DB) error {
486 u := User{Name: "transcation"}
487 if err := tx.Save(&u).Error; err != nil {
488 t.Errorf("No error should raise")
489 }
490
491 if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
492 t.Errorf("Should find saved record")
493 }
494
495 return errors.New("the error message")
496 })
497
498 if err.Error() != "the error message" {
499 t.Errorf("Transaction return error will equal the block returns error")
500 }
501
502 if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil {
503 t.Errorf("Should not find record after rollback")
504 }
505
506 // commit
507 DB.Transaction(func(tx *gorm.DB) error {
508 u2 := User{Name: "transcation-2"}
509 if err := tx.Save(&u2).Error; err != nil {
510 t.Errorf("No error should raise")
511 }
512
513 if err := tx.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
514 t.Errorf("Should find saved record")
515 }
516 return nil
517 })
518
519 if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
520 t.Errorf("Should be able to find committed record")
521 }
522
523 // panic will rollback
524 assertPanic(t, func() {
525 DB.Transaction(func(tx *gorm.DB) error {
526 u3 := User{Name: "transcation-3"}
527 if err := tx.Save(&u3).Error; err != nil {
528 t.Errorf("No error should raise")
529 }
530
531 if err := tx.First(&User{}, "name = ?", "transcation-3").Error; err != nil {
532 t.Errorf("Should find saved record")
533 }
534
535 panic("force panic")
536 })
537 })
538
539 if err := DB.First(&User{}, "name = ?", "transcation-3").Error; err == nil {
540 t.Errorf("Should not find record after panic rollback")
541 }
542 }
543
544 func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) {
545 tx := DB.Begin()
546 u := User{Name: "transcation"}
547 if err := tx.Save(&u).Error; err != nil {
548 t.Errorf("No error should raise")
549 }
550
551 if err := tx.Commit().Error; err != nil {
552 t.Errorf("Commit should not raise error")
553 }
554
555 if err := tx.Rollback().Error; err != nil {
556 t.Errorf("Rollback should not raise error")
557 }
558 }
559
560 func TestTransactionReadonly(t *testing.T) {
561 dialect := os.Getenv("GORM_DIALECT")
562 if dialect == "" {
563 dialect = "sqlite"
564 }
565 switch dialect {
566 case "mssql", "sqlite":
567 t.Skipf("%s does not support readonly transactions\n", dialect)
568 }
569
570 tx := DB.Begin()
571 u := User{Name: "transcation"}
572 if err := tx.Save(&u).Error; err != nil {
573 t.Errorf("No error should raise")
574 }
575 tx.Commit()
576
577 tx = DB.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
578 if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
579 t.Errorf("Should find saved record")
580 }
581
582 if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil {
583 t.Errorf("Should return the underlying sql.Tx")
584 }
585
586 u = User{Name: "transcation-2"}
587 if err := tx.Save(&u).Error; err == nil {
588 t.Errorf("Error should have been raised in a readonly transaction")
589 }
590
591 tx.Rollback()
379592 }
380593
381594 func TestRow(t *testing.T) {
563776 }
564777 }
565778
779 type JoinedIds struct {
780 UserID int64 `gorm:"column:id"`
781 BillingAddressID int64 `gorm:"column:id"`
782 EmailID int64 `gorm:"column:id"`
783 }
784
785 func TestScanIdenticalColumnNames(t *testing.T) {
786 var user = User{
787 Name: "joinsIds",
788 Email: "joinIds@example.com",
789 BillingAddress: Address{
790 Address1: "One Park Place",
791 },
792 Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
793 }
794 DB.Save(&user)
795
796 var users []JoinedIds
797 DB.Select("users.id, addresses.id, emails.id").Table("users").
798 Joins("left join addresses on users.billing_address_id = addresses.id").
799 Joins("left join emails on emails.user_id = users.id").
800 Where("name = ?", "joinsIds").Scan(&users)
801
802 if len(users) != 2 {
803 t.Fatal("should find two rows using left join")
804 }
805
806 if user.Id != users[0].UserID {
807 t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[0].UserID)
808 }
809 if user.Id != users[1].UserID {
810 t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[1].UserID)
811 }
812
813 if user.BillingAddressID.Int64 != users[0].BillingAddressID {
814 t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID)
815 }
816 if user.BillingAddressID.Int64 != users[1].BillingAddressID {
817 t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID)
818 }
819
820 if users[0].EmailID == users[1].EmailID {
821 t.Errorf("Email ids should be unique. Got %d and %d", users[0].EmailID, users[1].EmailID)
822 }
823
824 if int64(user.Emails[0].Id) != users[0].EmailID && int64(user.Emails[1].Id) != users[0].EmailID {
825 t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[0].EmailID)
826 }
827
828 if int64(user.Emails[0].Id) != users[1].EmailID && int64(user.Emails[1].Id) != users[1].EmailID {
829 t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[1].EmailID)
830 }
831 }
832
566833 func TestJoinsWithSelect(t *testing.T) {
567834 type result struct {
568835 Name string
577844
578845 var results []result
579846 DB.Table("users").Select("name, emails.email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results)
847
848 sort.Slice(results, func(i, j int) bool {
849 return strings.Compare(results[i].Email, results[j].Email) < 0
850 })
851
580852 if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" {
581853 t.Errorf("Should find all two emails with Join select")
582854 }
8611133 }
8621134 }
8631135
1136 func TestSaveAssociations(t *testing.T) {
1137 db := DB.New()
1138 deltaAddressCount := 0
1139 if err := db.Model(&Address{}).Count(&deltaAddressCount).Error; err != nil {
1140 t.Errorf("failed to fetch address count")
1141 t.FailNow()
1142 }
1143
1144 placeAddress := &Address{
1145 Address1: "somewhere on earth",
1146 }
1147 ownerAddress1 := &Address{
1148 Address1: "near place address",
1149 }
1150 ownerAddress2 := &Address{
1151 Address1: "address2",
1152 }
1153 db.Create(placeAddress)
1154
1155 addressCountShouldBe := func(t *testing.T, expectedCount int) {
1156 countFromDB := 0
1157 t.Helper()
1158 err := db.Model(&Address{}).Count(&countFromDB).Error
1159 if err != nil {
1160 t.Error("failed to fetch address count")
1161 }
1162 if countFromDB != expectedCount {
1163 t.Errorf("address count mismatch: %d", countFromDB)
1164 }
1165 }
1166 addressCountShouldBe(t, deltaAddressCount+1)
1167
1168 // owner address should be created, place address should be reused
1169 place1 := &Place{
1170 PlaceAddressID: placeAddress.ID,
1171 PlaceAddress: placeAddress,
1172 OwnerAddress: ownerAddress1,
1173 }
1174 err := db.Create(place1).Error
1175 if err != nil {
1176 t.Errorf("failed to store place: %s", err.Error())
1177 }
1178 addressCountShouldBe(t, deltaAddressCount+2)
1179
1180 // owner address should be created again, place address should be reused
1181 place2 := &Place{
1182 PlaceAddressID: placeAddress.ID,
1183 PlaceAddress: &Address{
1184 ID: 777,
1185 Address1: "address1",
1186 },
1187 OwnerAddress: ownerAddress2,
1188 OwnerAddressID: 778,
1189 }
1190 err = db.Create(place2).Error
1191 if err != nil {
1192 t.Errorf("failed to store place: %s", err.Error())
1193 }
1194 addressCountShouldBe(t, deltaAddressCount+3)
1195
1196 count := 0
1197 db.Model(&Place{}).Where(&Place{
1198 PlaceAddressID: placeAddress.ID,
1199 OwnerAddressID: ownerAddress1.ID,
1200 }).Count(&count)
1201 if count != 1 {
1202 t.Errorf("only one instance of (%d, %d) should be available, found: %d",
1203 placeAddress.ID, ownerAddress1.ID, count)
1204 }
1205
1206 db.Model(&Place{}).Where(&Place{
1207 PlaceAddressID: placeAddress.ID,
1208 OwnerAddressID: ownerAddress2.ID,
1209 }).Count(&count)
1210 if count != 1 {
1211 t.Errorf("only one instance of (%d, %d) should be available, found: %d",
1212 placeAddress.ID, ownerAddress2.ID, count)
1213 }
1214
1215 db.Model(&Place{}).Where(&Place{
1216 PlaceAddressID: placeAddress.ID,
1217 }).Count(&count)
1218 if count != 2 {
1219 t.Errorf("two instances of (%d) should be available, found: %d",
1220 placeAddress.ID, count)
1221 }
1222 }
1223
8641224 func TestBlockGlobalUpdate(t *testing.T) {
8651225 db := DB.New()
8661226 db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
8971257 if err != nil {
8981258 t.Error("Unexpected error on conditional delete")
8991259 }
1260 }
1261
1262 func TestCountWithHaving(t *testing.T) {
1263 db := DB.New()
1264 db.Delete(User{})
1265 defer db.Delete(User{})
1266
1267 DB.Create(getPreparedUser("user1", "pluck_user"))
1268 DB.Create(getPreparedUser("user2", "pluck_user"))
1269 user3 := getPreparedUser("user3", "pluck_user")
1270 user3.Languages = []Language{}
1271 DB.Create(user3)
1272
1273 var count int
1274 err := db.Model(User{}).Select("users.id").
1275 Joins("LEFT JOIN user_languages ON user_languages.user_id = users.id").
1276 Joins("LEFT JOIN languages ON user_languages.language_id = languages.id").
1277 Group("users.id").Having("COUNT(languages.id) > 1").Count(&count).Error
1278
1279 if err != nil {
1280 t.Error("Unexpected error on query count with having")
1281 }
1282
1283 if count != 2 {
1284 t.Error("Unexpected result on query count with having")
1285 }
1286 }
1287
1288 func TestPluck(t *testing.T) {
1289 db := DB.New()
1290 db.Delete(User{})
1291 defer db.Delete(User{})
1292
1293 DB.Create(&User{Id: 1, Name: "user1"})
1294 DB.Create(&User{Id: 2, Name: "user2"})
1295 DB.Create(&User{Id: 3, Name: "user3"})
1296
1297 var ids []int64
1298 err := db.Model(User{}).Order("id").Pluck("id", &ids).Error
1299
1300 if err != nil {
1301 t.Error("Unexpected error on pluck")
1302 }
1303
1304 if len(ids) != 3 || ids[0] != 1 || ids[1] != 2 || ids[2] != 3 {
1305 t.Error("Unexpected result on pluck")
1306 }
1307
1308 err = db.Model(User{}).Order("id").Pluck("id", &ids).Error
1309
1310 if err != nil {
1311 t.Error("Unexpected error on pluck again")
1312 }
1313
1314 if len(ids) != 3 || ids[0] != 1 || ids[1] != 2 || ids[2] != 3 {
1315 t.Error("Unexpected result on pluck again")
1316 }
1317 }
1318
1319 func TestCountWithQueryOption(t *testing.T) {
1320 db := DB.New()
1321 db.Delete(User{})
1322 defer db.Delete(User{})
1323
1324 DB.Create(&User{Name: "user1"})
1325 DB.Create(&User{Name: "user2"})
1326 DB.Create(&User{Name: "user3"})
1327
1328 var count int
1329 err := db.Model(User{}).Select("users.id").
1330 Set("gorm:query_option", "WHERE users.name='user2'").
1331 Count(&count).Error
1332
1333 if err != nil {
1334 t.Error("Unexpected error on query count with query_option")
1335 }
1336
1337 if count != 1 {
1338 t.Error("Unexpected result on query count with query_option")
1339 }
1340 }
1341
1342 func TestSubQueryWithQueryOption(t *testing.T) {
1343 db := DB.New()
1344
1345 subQuery := db.Model(User{}).Select("users.id").
1346 Set("gorm:query_option", "WHERE users.name='user2'").
1347 SubQuery()
1348
1349 matched, _ := regexp.MatchString(
1350 `^&{.+\s+WHERE users\.name='user2'.*\s\[]}$`, fmt.Sprint(subQuery))
1351 if !matched {
1352 t.Error("Unexpected result of SubQuery with query_option")
1353 }
1354 }
1355
1356 func TestQueryExprWithQueryOption(t *testing.T) {
1357 db := DB.New()
1358
1359 queryExpr := db.Model(User{}).Select("users.id").
1360 Set("gorm:query_option", "WHERE users.name='user2'").
1361 QueryExpr()
1362
1363 matched, _ := regexp.MatchString(
1364 `^&{.+\s+WHERE users\.name='user2'.*\s\[]}$`, fmt.Sprint(queryExpr))
1365 if !matched {
1366 t.Error("Unexpected result of QueryExpr with query_option")
1367 }
1368 }
1369
1370 func TestQueryHint1(t *testing.T) {
1371 db := DB.New()
1372
1373 _, err := db.Model(User{}).Raw("select 1").Rows()
1374
1375 if err != nil {
1376 t.Error("Unexpected error on query count with query_option")
1377 }
1378 }
1379
1380 func TestQueryHint2(t *testing.T) {
1381 type TestStruct struct {
1382 ID string `gorm:"primary_key"`
1383 Name string
1384 }
1385 DB.DropTable(&TestStruct{})
1386 DB.AutoMigrate(&TestStruct{})
1387
1388 data := TestStruct{ID: "uuid", Name: "hello"}
1389 if err := DB.Set("gorm:query_hint", "/*master*/").Save(&data).Error; err != nil {
1390 t.Error("Unexpected error on query count with query_option")
1391 }
1392 }
1393
1394 func TestFloatColumnPrecision(t *testing.T) {
1395 if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" && dialect != "sqlite" {
1396 t.Skip()
1397 }
1398
1399 type FloatTest struct {
1400 ID string `gorm:"primary_key"`
1401 FloatValue float64 `gorm:"column:float_value" sql:"type:float(255,5);"`
1402 }
1403 DB.DropTable(&FloatTest{})
1404 DB.AutoMigrate(&FloatTest{})
1405
1406 data := FloatTest{ID: "uuid", FloatValue: 112.57315}
1407 if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.FloatValue != 112.57315 {
1408 t.Errorf("Float value should not lose precision")
1409 }
1410 }
1411
1412 func TestWhereUpdates(t *testing.T) {
1413 type OwnerEntity struct {
1414 gorm.Model
1415 OwnerID uint
1416 OwnerType string
1417 }
1418
1419 type SomeEntity struct {
1420 gorm.Model
1421 Name string
1422 OwnerEntity OwnerEntity `gorm:"polymorphic:Owner"`
1423 }
1424
1425 DB.DropTable(&SomeEntity{})
1426 DB.AutoMigrate(&SomeEntity{})
1427
1428 a := SomeEntity{Name: "test"}
1429 DB.Model(&a).Where(a).Updates(SomeEntity{Name: "test2"})
9001430 }
9011431
9021432 func BenchmarkGorm(b *testing.B) {
117117 Owner *User `sql:"-"`
118118 }
119119
120 type Place struct {
121 Id int64
122 PlaceAddressID int
123 PlaceAddress *Address `gorm:"save_associations:false"`
124 OwnerAddressID int
125 OwnerAddress *Address `gorm:"save_associations:true"`
126 }
127
120128 type EncryptedData []byte
121129
122130 func (data *EncryptedData) Scan(value interface{}) error {
274282 }
275283 }
276284
285 type Panda struct {
286 Number int64 `gorm:"unique_index:number"`
287 Name string `gorm:"column:name;type:varchar(255);default:null"`
288 }
289
277290 func runMigration() {
278291 if err := DB.DropTableIfExists(&User{}).Error; err != nil {
279292 fmt.Printf("Got error when try to delete table users, %+v\n", err)
283296 DB.Exec(fmt.Sprintf("drop table %v;", table))
284297 }
285298
286 values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}}
299 values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}, &Panda{}}
287300 for _, value := range values {
288301 DB.DropTable(value)
289302 }
397410 }
398411 }
399412
413 func TestCreateAndAutomigrateTransaction(t *testing.T) {
414 tx := DB.Begin()
415
416 func() {
417 type Bar struct {
418 ID uint
419 }
420 DB.DropTableIfExists(&Bar{})
421
422 if ok := DB.HasTable("bars"); ok {
423 t.Errorf("Table should not exist, but does")
424 }
425
426 if ok := tx.HasTable("bars"); ok {
427 t.Errorf("Table should not exist, but does")
428 }
429 }()
430
431 func() {
432 type Bar struct {
433 Name string
434 }
435 err := tx.CreateTable(&Bar{}).Error
436
437 if err != nil {
438 t.Errorf("Should have been able to create the table, but couldn't: %s", err)
439 }
440
441 if ok := tx.HasTable(&Bar{}); !ok {
442 t.Errorf("The transaction should be able to see the table")
443 }
444 }()
445
446 func() {
447 type Bar struct {
448 Stuff string
449 }
450
451 err := tx.AutoMigrate(&Bar{}).Error
452 if err != nil {
453 t.Errorf("Should have been able to alter the table, but couldn't")
454 }
455 }()
456
457 tx.Rollback()
458 }
459
400460 type MultipleIndexes struct {
401461 ID int64
402462 UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`
482542 t.Errorf("No error should happen when ModifyColumn, but got %v", err)
483543 }
484544 }
545
546 func TestIndexWithPrefixLength(t *testing.T) {
547 if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" {
548 t.Skip("Skipping this because only mysql support setting an index prefix length")
549 }
550
551 type IndexWithPrefix struct {
552 gorm.Model
553 Name string
554 Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
555 }
556 type IndexesWithPrefix struct {
557 gorm.Model
558 Name string
559 Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
560 Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
561 }
562 type IndexesWithPrefixAndWithoutPrefix struct {
563 gorm.Model
564 Name string `gorm:"index:idx_index_with_prefixes_length"`
565 Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
566 }
567 tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}}
568 for _, table := range tables {
569 scope := DB.NewScope(table)
570 tableName := scope.TableName()
571 t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) {
572 if err := DB.DropTableIfExists(table).Error; err != nil {
573 t.Errorf("Failed to drop %s table: %v", tableName, err)
574 }
575 if err := DB.CreateTable(table).Error; err != nil {
576 t.Errorf("Failed to create %s table: %v", tableName, err)
577 }
578 if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") {
579 t.Errorf("Failed to create %s table index:", tableName)
580 }
581 })
582 }
583 }
1616 return defaultTableName
1717 }
1818
19 type safeModelStructsMap struct {
20 m map[reflect.Type]*ModelStruct
21 l *sync.RWMutex
22 }
23
24 func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) {
19 // lock for mutating global cached model metadata
20 var structsLock sync.Mutex
21
22 // global cache of model metadata
23 var modelStructsMap sync.Map
24
25 // ModelStruct model definition
26 type ModelStruct struct {
27 PrimaryFields []*StructField
28 StructFields []*StructField
29 ModelType reflect.Type
30
31 defaultTableName string
32 l sync.Mutex
33 }
34
35 // TableName returns model's table name
36 func (s *ModelStruct) TableName(db *DB) string {
2537 s.l.Lock()
2638 defer s.l.Unlock()
27 s.m[key] = value
28 }
29
30 func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct {
31 s.l.RLock()
32 defer s.l.RUnlock()
33 return s.m[key]
34 }
35
36 func newModelStructsMap() *safeModelStructsMap {
37 return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)}
38 }
39
40 var modelStructsMap = newModelStructsMap()
41
42 // ModelStruct model definition
43 type ModelStruct struct {
44 PrimaryFields []*StructField
45 StructFields []*StructField
46 ModelType reflect.Type
47 defaultTableName string
48 }
49
50 // TableName get model's table name
51 func (s *ModelStruct) TableName(db *DB) string {
39
5240 if s.defaultTableName == "" && db != nil && s.ModelType != nil {
5341 // Set default table name
5442 if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
5543 s.defaultTableName = tabler.TableName()
5644 } else {
57 tableName := ToDBName(s.ModelType.Name())
58 if db == nil || !db.parent.singularTable {
45 tableName := ToTableName(s.ModelType.Name())
46 db.parent.RLock()
47 if db == nil || (db.parent != nil && !db.parent.singularTable) {
5948 tableName = inflection.Plural(tableName)
6049 }
50 db.parent.RUnlock()
6151 s.defaultTableName = tableName
6252 }
6353 }
8070 Struct reflect.StructField
8171 IsForeignKey bool
8272 Relationship *Relationship
83 }
84
85 func (structField *StructField) clone() *StructField {
73
74 tagSettingsLock sync.RWMutex
75 }
76
77 // TagSettingsSet Sets a tag in the tag settings map
78 func (sf *StructField) TagSettingsSet(key, val string) {
79 sf.tagSettingsLock.Lock()
80 defer sf.tagSettingsLock.Unlock()
81 sf.TagSettings[key] = val
82 }
83
84 // TagSettingsGet returns a tag from the tag settings
85 func (sf *StructField) TagSettingsGet(key string) (string, bool) {
86 sf.tagSettingsLock.RLock()
87 defer sf.tagSettingsLock.RUnlock()
88 val, ok := sf.TagSettings[key]
89 return val, ok
90 }
91
92 // TagSettingsDelete deletes a tag
93 func (sf *StructField) TagSettingsDelete(key string) {
94 sf.tagSettingsLock.Lock()
95 defer sf.tagSettingsLock.Unlock()
96 delete(sf.TagSettings, key)
97 }
98
99 func (sf *StructField) clone() *StructField {
86100 clone := &StructField{
87 DBName: structField.DBName,
88 Name: structField.Name,
89 Names: structField.Names,
90 IsPrimaryKey: structField.IsPrimaryKey,
91 IsNormal: structField.IsNormal,
92 IsIgnored: structField.IsIgnored,
93 IsScanner: structField.IsScanner,
94 HasDefaultValue: structField.HasDefaultValue,
95 Tag: structField.Tag,
101 DBName: sf.DBName,
102 Name: sf.Name,
103 Names: sf.Names,
104 IsPrimaryKey: sf.IsPrimaryKey,
105 IsNormal: sf.IsNormal,
106 IsIgnored: sf.IsIgnored,
107 IsScanner: sf.IsScanner,
108 HasDefaultValue: sf.HasDefaultValue,
109 Tag: sf.Tag,
96110 TagSettings: map[string]string{},
97 Struct: structField.Struct,
98 IsForeignKey: structField.IsForeignKey,
99 }
100
101 if structField.Relationship != nil {
102 relationship := *structField.Relationship
111 Struct: sf.Struct,
112 IsForeignKey: sf.IsForeignKey,
113 }
114
115 if sf.Relationship != nil {
116 relationship := *sf.Relationship
103117 clone.Relationship = &relationship
104118 }
105119
106 for key, value := range structField.TagSettings {
120 // copy the struct field tagSettings, they should be read-locked while they are copied
121 sf.tagSettingsLock.Lock()
122 defer sf.tagSettingsLock.Unlock()
123 for key, value := range sf.TagSettings {
107124 clone.TagSettings[key] = value
108125 }
109126
125142
126143 func getForeignField(column string, fields []*StructField) *StructField {
127144 for _, field := range fields {
128 if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) {
145 if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) {
129146 return field
130147 }
131148 }
134151
135152 // GetModelStruct get value's model struct, relationships based on struct and tag definition
136153 func (scope *Scope) GetModelStruct() *ModelStruct {
154 return scope.getModelStruct(scope, make([]*StructField, 0))
155 }
156
157 func (scope *Scope) getModelStruct(rootScope *Scope, allFields []*StructField) *ModelStruct {
137158 var modelStruct ModelStruct
138159 // Scope value can't be nil
139160 if scope.Value == nil {
151172 }
152173
153174 // Get Cached model struct
154 if value := modelStructsMap.Get(reflectType); value != nil {
155 return value
175 isSingularTable := false
176 if scope.db != nil && scope.db.parent != nil {
177 scope.db.parent.RLock()
178 isSingularTable = scope.db.parent.singularTable
179 scope.db.parent.RUnlock()
180 }
181
182 hashKey := struct {
183 singularTable bool
184 reflectType reflect.Type
185 }{isSingularTable, reflectType}
186 if value, ok := modelStructsMap.Load(hashKey); ok && value != nil {
187 return value.(*ModelStruct)
156188 }
157189
158190 modelStruct.ModelType = reflectType
169201 }
170202
171203 // is ignored field
172 if _, ok := field.TagSettings["-"]; ok {
204 if _, ok := field.TagSettingsGet("-"); ok {
173205 field.IsIgnored = true
174206 } else {
175 if _, ok := field.TagSettings["PRIMARY_KEY"]; ok {
207 if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok {
176208 field.IsPrimaryKey = true
177209 modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
178210 }
179211
180 if _, ok := field.TagSettings["DEFAULT"]; ok {
212 if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey {
181213 field.HasDefaultValue = true
182214 }
183215
184 if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsPrimaryKey {
216 if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey {
185217 field.HasDefaultValue = true
186218 }
187219
197229 if indirectType.Kind() == reflect.Struct {
198230 for i := 0; i < indirectType.NumField(); i++ {
199231 for key, value := range parseTagSetting(indirectType.Field(i).Tag) {
200 if _, ok := field.TagSettings[key]; !ok {
201 field.TagSettings[key] = value
232 if _, ok := field.TagSettingsGet(key); !ok {
233 field.TagSettingsSet(key, value)
202234 }
203235 }
204236 }
206238 } else if _, isTime := fieldValue.(*time.Time); isTime {
207239 // is time
208240 field.IsNormal = true
209 } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
241 } else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous {
210242 // is embedded struct
211 for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields {
243 for _, subField := range scope.New(fieldValue).getModelStruct(rootScope, allFields).StructFields {
212244 subField = subField.clone()
213245 subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
214 if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok {
246 if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok {
215247 subField.DBName = prefix + subField.DBName
216248 }
217249
218250 if subField.IsPrimaryKey {
219 if _, ok := subField.TagSettings["PRIMARY_KEY"]; ok {
251 if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok {
220252 modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField)
221253 } else {
222254 subField.IsPrimaryKey = false
232264 }
233265
234266 modelStruct.StructFields = append(modelStruct.StructFields, subField)
267 allFields = append(allFields, subField)
235268 }
236269 continue
237270 } else {
247280 elemType = field.Struct.Type
248281 )
249282
250 if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
283 if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" {
251284 foreignKeys = strings.Split(foreignKey, ",")
252285 }
253286
254 if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
287 if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" {
255288 associationForeignKeys = strings.Split(foreignKey, ",")
256 } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
289 } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" {
257290 associationForeignKeys = strings.Split(foreignKey, ",")
258291 }
259292
262295 }
263296
264297 if elemType.Kind() == reflect.Struct {
265 if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
298 if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" {
266299 relationship.Kind = "many_to_many"
267300
268301 { // Foreign Keys for Source
269302 joinTableDBNames := []string{}
270303
271 if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
304 if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" {
272305 joinTableDBNames = strings.Split(foreignKey, ",")
273306 }
274307
289322 // if defined join table's foreign key
290323 relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
291324 } else {
292 defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
325 defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName
293326 relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
294327 }
295328 }
299332 { // Foreign Keys for Association (Destination)
300333 associationJoinTableDBNames := []string{}
301334
302 if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
335 if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" {
303336 associationJoinTableDBNames = strings.Split(foreignKey, ",")
304337 }
305338
320353 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
321354 } else {
322355 // join table foreign keys for association
323 joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
356 joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName
324357 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
325358 }
326359 }
337370 var toFields = toScope.GetStructFields()
338371 relationship.Kind = "has_many"
339372
340 if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
373 if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" {
341374 // Dog has many toys, tag polymorphic is Owner, then associationType is Owner
342375 // Toy use OwnerID, OwnerType ('dogs') as foreign key
343376 if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
345378 relationship.PolymorphicType = polymorphicType.Name
346379 relationship.PolymorphicDBName = polymorphicType.DBName
347380 // if Dog has multiple set of toys set name of the set (instead of default 'dogs')
348 if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
381 if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok {
349382 relationship.PolymorphicValue = value
350383 } else {
351384 relationship.PolymorphicValue = scope.TableName()
365398 } else {
366399 // generate foreign keys from defined association foreign keys
367400 for _, scopeFieldName := range associationForeignKeys {
368 if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil {
401 if foreignField := getForeignField(scopeFieldName, allFields); foreignField != nil {
369402 foreignKeys = append(foreignKeys, associationType+foreignField.Name)
370403 associationForeignKeys = append(associationForeignKeys, foreignField.Name)
371404 }
377410 for _, foreignKey := range foreignKeys {
378411 if strings.HasPrefix(foreignKey, associationType) {
379412 associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
380 if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
413 if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil {
381414 associationForeignKeys = append(associationForeignKeys, associationForeignKey)
382415 }
383416 }
384417 }
385418 if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
386 associationForeignKeys = []string{scope.PrimaryKey()}
419 associationForeignKeys = []string{rootScope.PrimaryKey()}
387420 }
388421 } else if len(foreignKeys) != len(associationForeignKeys) {
389422 scope.Err(errors.New("invalid foreign keys, should have same length"))
393426
394427 for idx, foreignKey := range foreignKeys {
395428 if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
396 if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil {
397 // source foreign keys
429 if associationField := getForeignField(associationForeignKeys[idx], allFields); associationField != nil {
430 // mark field as foreignkey, use global lock to avoid race
431 structsLock.Lock()
398432 foreignField.IsForeignKey = true
433 structsLock.Unlock()
434
435 // association foreign keys
399436 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
400437 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
401438
427464 tagAssociationForeignKeys []string
428465 )
429466
430 if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
467 if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" {
431468 tagForeignKeys = strings.Split(foreignKey, ",")
432469 }
433470
434 if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
471 if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" {
435472 tagAssociationForeignKeys = strings.Split(foreignKey, ",")
436 } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
473 } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" {
437474 tagAssociationForeignKeys = strings.Split(foreignKey, ",")
438475 }
439476
440 if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
477 if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" {
441478 // Cat has one toy, tag polymorphic is Owner, then associationType is Owner
442479 // Toy use OwnerID, OwnerType ('cats') as foreign key
443480 if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
445482 relationship.PolymorphicType = polymorphicType.Name
446483 relationship.PolymorphicDBName = polymorphicType.DBName
447484 // if Cat has several different types of toys set name for each (instead of default 'cats')
448 if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
485 if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok {
449486 relationship.PolymorphicValue = value
450487 } else {
451488 relationship.PolymorphicValue = scope.TableName()
469506 } else {
470507 // generate foreign keys form association foreign keys
471508 for _, associationForeignKey := range tagAssociationForeignKeys {
472 if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
509 if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil {
473510 foreignKeys = append(foreignKeys, associationType+foreignField.Name)
474511 associationForeignKeys = append(associationForeignKeys, foreignField.Name)
475512 }
481518 for _, foreignKey := range foreignKeys {
482519 if strings.HasPrefix(foreignKey, associationType) {
483520 associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
484 if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
521 if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil {
485522 associationForeignKeys = append(associationForeignKeys, associationForeignKey)
486523 }
487524 }
488525 }
489526 if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
490 associationForeignKeys = []string{scope.PrimaryKey()}
527 associationForeignKeys = []string{rootScope.PrimaryKey()}
491528 }
492529 } else if len(foreignKeys) != len(associationForeignKeys) {
493530 scope.Err(errors.New("invalid foreign keys, should have same length"))
497534
498535 for idx, foreignKey := range foreignKeys {
499536 if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
500 if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil {
537 if scopeField := getForeignField(associationForeignKeys[idx], allFields); scopeField != nil {
538 // mark field as foreignkey, use global lock to avoid race
539 structsLock.Lock()
501540 foreignField.IsForeignKey = true
502 // source foreign keys
541 structsLock.Unlock()
542
543 // association foreign keys
503544 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name)
504545 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName)
505546
557598 for idx, foreignKey := range foreignKeys {
558599 if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
559600 if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil {
601 // mark field as foreignkey, use global lock to avoid race
602 structsLock.Lock()
560603 foreignField.IsForeignKey = true
604 structsLock.Unlock()
561605
562606 // association foreign keys
563607 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
583627 }
584628
585629 // Even it is ignored, also possible to decode db value into the field
586 if value, ok := field.TagSettings["COLUMN"]; ok {
630 if value, ok := field.TagSettingsGet("COLUMN"); ok {
587631 field.DBName = value
588632 } else {
589 field.DBName = ToDBName(fieldStruct.Name)
633 field.DBName = ToColumnName(fieldStruct.Name)
590634 }
591635
592636 modelStruct.StructFields = append(modelStruct.StructFields, field)
637 allFields = append(allFields, field)
593638 }
594639 }
595640
600645 }
601646 }
602647
603 modelStructsMap.Set(reflectType, &modelStruct)
648 modelStructsMap.Store(hashKey, &modelStruct)
604649
605650 return &modelStruct
606651 }
613658 func parseTagSetting(tags reflect.StructTag) map[string]string {
614659 setting := map[string]string{}
615660 for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
661 if str == "" {
662 continue
663 }
616664 tags := strings.Split(str, ";")
617665 for _, value := range tags {
618666 v := strings.Split(value, ":")
0 package gorm_test
1
2 import (
3 "sync"
4 "testing"
5
6 "github.com/jinzhu/gorm"
7 )
8
9 type ModelA struct {
10 gorm.Model
11 Name string
12
13 ModelCs []ModelC `gorm:"foreignkey:OtherAID"`
14 }
15
16 type ModelB struct {
17 gorm.Model
18 Name string
19
20 ModelCs []ModelC `gorm:"foreignkey:OtherBID"`
21 }
22
23 type ModelC struct {
24 gorm.Model
25 Name string
26
27 OtherAID uint64
28 OtherA *ModelA `gorm:"foreignkey:OtherAID"`
29 OtherBID uint64
30 OtherB *ModelB `gorm:"foreignkey:OtherBID"`
31 }
32
33 type RequestModel struct {
34 Name string
35 Children []ChildModel `gorm:"foreignkey:ParentID"`
36 }
37
38 type ChildModel struct {
39 ID string
40 ParentID string
41 Name string
42 }
43
44 type ResponseModel struct {
45 gorm.Model
46 RequestModel
47 }
48
49 // This test will try to cause a race condition on the model's foreignkey metadata
50 func TestModelStructRaceSameModel(t *testing.T) {
51 // use a WaitGroup to execute as much in-sync as possible
52 // it's more likely to hit a race condition than without
53 n := 32
54 start := sync.WaitGroup{}
55 start.Add(n)
56
57 // use another WaitGroup to know when the test is done
58 done := sync.WaitGroup{}
59 done.Add(n)
60
61 for i := 0; i < n; i++ {
62 go func() {
63 start.Wait()
64
65 // call GetStructFields, this had a race condition before we fixed it
66 DB.NewScope(&ModelA{}).GetStructFields()
67
68 done.Done()
69 }()
70
71 start.Done()
72 }
73
74 done.Wait()
75 }
76
77 // This test will try to cause a race condition on the model's foreignkey metadata
78 func TestModelStructRaceDifferentModel(t *testing.T) {
79 // use a WaitGroup to execute as much in-sync as possible
80 // it's more likely to hit a race condition than without
81 n := 32
82 start := sync.WaitGroup{}
83 start.Add(n)
84
85 // use another WaitGroup to know when the test is done
86 done := sync.WaitGroup{}
87 done.Add(n)
88
89 for i := 0; i < n; i++ {
90 i := i
91 go func() {
92 start.Wait()
93
94 // call GetStructFields, this had a race condition before we fixed it
95 if i%2 == 0 {
96 DB.NewScope(&ModelA{}).GetStructFields()
97 } else {
98 DB.NewScope(&ModelB{}).GetStructFields()
99 }
100
101 done.Done()
102 }()
103
104 start.Done()
105 }
106
107 done.Wait()
108 }
109
110 func TestModelStructEmbeddedHasMany(t *testing.T) {
111 fields := DB.NewScope(&ResponseModel{}).GetStructFields()
112
113 var childrenField *gorm.StructField
114
115 for i := 0; i < len(fields); i++ {
116 field := fields[i]
117
118 if field != nil && field.Name == "Children" {
119 childrenField = field
120 }
121 }
122
123 if childrenField == nil {
124 t.Error("childrenField should not be nil")
125 return
126 }
127
128 if childrenField.Relationship == nil {
129 t.Error("childrenField.Relation should not be nil")
130 return
131 }
132
133 expected := "has_many"
134 actual := childrenField.Relationship.Kind
135
136 if actual != expected {
137 t.Errorf("childrenField.Relationship.Kind should be %v, but was %v", expected, actual)
138 }
139 }
0 package gorm
1
2 import (
3 "bytes"
4 "strings"
5 )
6
7 // Namer is a function type which is given a string and return a string
8 type Namer func(string) string
9
10 // NamingStrategy represents naming strategies
11 type NamingStrategy struct {
12 DB Namer
13 Table Namer
14 Column Namer
15 }
16
17 // TheNamingStrategy is being initialized with defaultNamingStrategy
18 var TheNamingStrategy = &NamingStrategy{
19 DB: defaultNamer,
20 Table: defaultNamer,
21 Column: defaultNamer,
22 }
23
24 // AddNamingStrategy sets the naming strategy
25 func AddNamingStrategy(ns *NamingStrategy) {
26 if ns.DB == nil {
27 ns.DB = defaultNamer
28 }
29 if ns.Table == nil {
30 ns.Table = defaultNamer
31 }
32 if ns.Column == nil {
33 ns.Column = defaultNamer
34 }
35 TheNamingStrategy = ns
36 }
37
38 // DBName alters the given name by DB
39 func (ns *NamingStrategy) DBName(name string) string {
40 return ns.DB(name)
41 }
42
43 // TableName alters the given name by Table
44 func (ns *NamingStrategy) TableName(name string) string {
45 return ns.Table(name)
46 }
47
48 // ColumnName alters the given name by Column
49 func (ns *NamingStrategy) ColumnName(name string) string {
50 return ns.Column(name)
51 }
52
53 // ToDBName convert string to db name
54 func ToDBName(name string) string {
55 return TheNamingStrategy.DBName(name)
56 }
57
58 // ToTableName convert string to table name
59 func ToTableName(name string) string {
60 return TheNamingStrategy.TableName(name)
61 }
62
63 // ToColumnName convert string to db name
64 func ToColumnName(name string) string {
65 return TheNamingStrategy.ColumnName(name)
66 }
67
68 var smap = newSafeMap()
69
70 func defaultNamer(name string) string {
71 const (
72 lower = false
73 upper = true
74 )
75
76 if v := smap.Get(name); v != "" {
77 return v
78 }
79
80 if name == "" {
81 return ""
82 }
83
84 var (
85 value = commonInitialismsReplacer.Replace(name)
86 buf = bytes.NewBufferString("")
87 lastCase, currCase, nextCase, nextNumber bool
88 )
89
90 for i, v := range value[:len(value)-1] {
91 nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z')
92 nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9')
93
94 if i > 0 {
95 if currCase == upper {
96 if lastCase == upper && (nextCase == upper || nextNumber == upper) {
97 buf.WriteRune(v)
98 } else {
99 if value[i-1] != '_' && value[i+1] != '_' {
100 buf.WriteRune('_')
101 }
102 buf.WriteRune(v)
103 }
104 } else {
105 buf.WriteRune(v)
106 if i == len(value)-2 && (nextCase == upper && nextNumber == lower) {
107 buf.WriteRune('_')
108 }
109 }
110 } else {
111 currCase = upper
112 buf.WriteRune(v)
113 }
114 lastCase = currCase
115 currCase = nextCase
116 }
117
118 buf.WriteByte(value[len(value)-1])
119
120 s := strings.ToLower(buf.String())
121 smap.Set(name, s)
122 return s
123 }
0 package gorm_test
1
2 import (
3 "testing"
4
5 "github.com/jinzhu/gorm"
6 )
7
8 func TestTheNamingStrategy(t *testing.T) {
9
10 cases := []struct {
11 name string
12 namer gorm.Namer
13 expected string
14 }{
15 {name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB},
16 {name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table},
17 {name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column},
18 }
19
20 for _, c := range cases {
21 t.Run(c.name, func(t *testing.T) {
22 result := c.namer(c.name)
23 if result != c.expected {
24 t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result)
25 }
26 })
27 }
28
29 }
30
31 func TestNamingStrategy(t *testing.T) {
32
33 dbNameNS := func(name string) string {
34 return "db_" + name
35 }
36 tableNameNS := func(name string) string {
37 return "tbl_" + name
38 }
39 columnNameNS := func(name string) string {
40 return "col_" + name
41 }
42
43 ns := &gorm.NamingStrategy{
44 DB: dbNameNS,
45 Table: tableNameNS,
46 Column: columnNameNS,
47 }
48
49 cases := []struct {
50 name string
51 namer gorm.Namer
52 expected string
53 }{
54 {name: "auth", expected: "db_auth", namer: ns.DB},
55 {name: "user", expected: "tbl_user", namer: ns.Table},
56 {name: "password", expected: "col_password", namer: ns.Column},
57 }
58
59 for _, c := range cases {
60 t.Run(c.name, func(t *testing.T) {
61 result := c.namer(c.name)
62 if result != c.expected {
63 t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result)
64 }
65 })
66 }
67
68 }
122122 }
123123 }
124124
125 func TestAutoPreloadFalseDoesntPreload(t *testing.T) {
126 user1 := getPreloadUser("auto_user1")
127 DB.Save(user1)
128
129 preloadDB := DB.Set("gorm:auto_preload", false).Where("role = ?", "Preload")
130 var user User
131 preloadDB.Find(&user)
132
133 if user.BillingAddress.Address1 != "" {
134 t.Error("AutoPreload was set to fasle, but still fetched data")
135 }
136
137 user2 := getPreloadUser("auto_user2")
138 DB.Save(user2)
139
140 var users []User
141 preloadDB.Find(&users)
142
143 for _, user := range users {
144 if user.BillingAddress.Address1 != "" {
145 t.Error("AutoPreload was set to fasle, but still fetched data")
146 }
147 }
148 }
149
125150 func TestNestedPreload1(t *testing.T) {
126151 type (
127152 Level1 struct {
745770 levelB3 := &LevelB3{
746771 Value: "bar",
747772 LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)},
773 LevelB2s: []*LevelB2{},
748774 }
749775 if err := DB.Create(levelB3).Error; err != nil {
750776 t.Error(err)
16501676 lvl := Level1{
16511677 Name: "l1",
16521678 Level2s: []Level2{
1653 Level2{Name: "l2-1"}, Level2{Name: "l2-2"},
1679 {Name: "l2-1"}, {Name: "l2-2"},
16541680 },
16551681 }
16561682 DB.Save(&lvl)
3737
3838 if user.Email != "" {
3939 t.Errorf("User's Email should be blank as no one set it")
40 }
41 }
42
43 func TestQueryWithAssociation(t *testing.T) {
44 user := &User{Name: "user1", Emails: []Email{{Email: "user1@example.com"}}, Company: Company{Name: "company"}}
45
46 if err := DB.Create(&user).Error; err != nil {
47 t.Fatalf("errors happened when create user: %v", err)
48 }
49
50 user.CreatedAt = time.Time{}
51 user.UpdatedAt = time.Time{}
52 if err := DB.Where(&user).First(&User{}).Error; err != nil {
53 t.Errorf("search with struct with association should returns no error, but got %v", err)
54 }
55
56 if err := DB.Where(user).First(&User{}).Error; err != nil {
57 t.Errorf("search with struct with association should returns no error, but got %v", err)
4058 }
4159 }
4260
180198
181199 scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users)
182200 if len(users) != 2 {
183 t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users))
201 t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users))
184202 }
185203
186204 scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
187205 if len(users) != 2 {
188 t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users))
206 t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users))
189207 }
190208
191209 scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
192210 if len(users) != 1 {
193 t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
211 t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
194212 }
195213
196214 DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
456474 }
457475 }
458476
477 func TestLimitAndOffsetSQL(t *testing.T) {
478 user1 := User{Name: "TestLimitAndOffsetSQL1", Age: 10}
479 user2 := User{Name: "TestLimitAndOffsetSQL2", Age: 20}
480 user3 := User{Name: "TestLimitAndOffsetSQL3", Age: 30}
481 user4 := User{Name: "TestLimitAndOffsetSQL4", Age: 40}
482 user5 := User{Name: "TestLimitAndOffsetSQL5", Age: 50}
483 if err := DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5).Error; err != nil {
484 t.Fatal(err)
485 }
486
487 tests := []struct {
488 name string
489 limit, offset interface{}
490 users []*User
491 ok bool
492 }{
493 {
494 name: "OK",
495 limit: float64(2),
496 offset: float64(2),
497 users: []*User{
498 &User{Name: "TestLimitAndOffsetSQL3", Age: 30},
499 &User{Name: "TestLimitAndOffsetSQL2", Age: 20},
500 },
501 ok: true,
502 },
503 {
504 name: "Limit parse error",
505 limit: float64(1000000), // 1e+06
506 offset: float64(2),
507 ok: false,
508 },
509 {
510 name: "Offset parse error",
511 limit: float64(2),
512 offset: float64(1000000), // 1e+06
513 ok: false,
514 },
515 }
516
517 for _, tt := range tests {
518 t.Run(tt.name, func(t *testing.T) {
519 var users []*User
520 err := DB.Where("name LIKE ?", "TestLimitAndOffsetSQL%").Order("age desc").Limit(tt.limit).Offset(tt.offset).Find(&users).Error
521 if tt.ok {
522 if err != nil {
523 t.Errorf("error expected nil, but got %v", err)
524 }
525 if len(users) != len(tt.users) {
526 t.Errorf("users length expected %d, but got %d", len(tt.users), len(users))
527 }
528 for i := range tt.users {
529 if users[i].Name != tt.users[i].Name {
530 t.Errorf("users[%d] name expected %s, but got %s", i, tt.users[i].Name, users[i].Name)
531 }
532 if users[i].Age != tt.users[i].Age {
533 t.Errorf("users[%d] age expected %d, but got %d", i, tt.users[i].Age, users[i].Age)
534 }
535 }
536 } else {
537 if err == nil {
538 t.Error("error expected not nil, but got nil")
539 }
540 }
541 })
542 }
543 }
544
459545 func TestOr(t *testing.T) {
460546 user1 := User{Name: "OrUser1", Age: 1}
461547 user2 := User{Name: "OrUser2", Age: 10}
531617 DB.Table("users").Where("name = ?", "user3").Count(&name3Count)
532618 DB.Not("name", "user3").Find(&users4)
533619 if len(users1)-len(users4) != int(name3Count) {
534 t.Errorf("Should find all users's name not equal 3")
620 t.Errorf("Should find all users' name not equal 3")
535621 }
536622
537623 DB.Not("name = ?", "user3").Find(&users4)
538624 if len(users1)-len(users4) != int(name3Count) {
539 t.Errorf("Should find all users's name not equal 3")
625 t.Errorf("Should find all users' name not equal 3")
540626 }
541627
542628 DB.Not("name <> ?", "user3").Find(&users4)
543629 if len(users4) != int(name3Count) {
544 t.Errorf("Should find all users's name not equal 3")
630 t.Errorf("Should find all users' name not equal 3")
545631 }
546632
547633 DB.Not(User{Name: "user3"}).Find(&users5)
548634
549635 if len(users1)-len(users5) != int(name3Count) {
550 t.Errorf("Should find all users's name not equal 3")
636 t.Errorf("Should find all users' name not equal 3")
551637 }
552638
553639 DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6)
554640 if len(users1)-len(users6) != int(name3Count) {
555 t.Errorf("Should find all users's name not equal 3")
641 t.Errorf("Should find all users' name not equal 3")
556642 }
557643
558644 DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7)
562648
563649 DB.Not("name", []string{"user3"}).Find(&users8)
564650 if len(users1)-len(users8) != int(name3Count) {
565 t.Errorf("Should find all users's name not equal 3")
651 t.Errorf("Should find all users' name not equal 3")
566652 }
567653
568654 var name2Count int64
569655 DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
570656 DB.Not("name", []string{"user3", "user2"}).Find(&users9)
571657 if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) {
572 t.Errorf("Should find all users's name not equal 3")
658 t.Errorf("Should find all users' name not equal 3")
573659 }
574660 }
575661
6262
6363 // Dialect get dialect
6464 func (scope *Scope) Dialect() Dialect {
65 return scope.db.parent.dialect
65 return scope.db.dialect
6666 }
6767
6868 // Quote used to quote string to escape them for database
6969 func (scope *Scope) Quote(str string) string {
70 if strings.Index(str, ".") != -1 {
70 if strings.Contains(str, ".") {
7171 newStrs := []string{}
7272 for _, str := range strings.Split(str, ".") {
7373 newStrs = append(newStrs, scope.Dialect().Quote(str))
133133 // FieldByName find `gorm.Field` with field name or db name
134134 func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
135135 var (
136 dbName = ToDBName(name)
136 dbName = ToColumnName(name)
137137 mostMatchedField *Field
138138 )
139139
224224 updateAttrs[field.DBName] = value
225225 return field.Set(value)
226226 }
227 if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) {
227 if !field.IsIgnored && ((field.DBName == dbName) || (field.Name == name && mostMatchedField == nil)) {
228228 mostMatchedField = field
229229 }
230230 }
256256 func (scope *Scope) AddToVars(value interface{}) string {
257257 _, skipBindVar := scope.InstanceGet("skip_bindvar")
258258
259 if expr, ok := value.(*expr); ok {
259 if expr, ok := value.(*SqlExpr); ok {
260260 exp := expr.expr
261261 for _, arg := range expr.args {
262262 if skipBindVar {
329329 // QuotedTableName return quoted table name
330330 func (scope *Scope) QuotedTableName() (name string) {
331331 if scope.Search != nil && len(scope.Search.tableName) > 0 {
332 if strings.Index(scope.Search.tableName, " ") != -1 {
332 if strings.Contains(scope.Search.tableName, " ") {
333333 return scope.Search.tableName
334334 }
335335 return scope.Quote(scope.Search.tableName)
401401 // Begin start a transaction
402402 func (scope *Scope) Begin() *Scope {
403403 if db, ok := scope.SQLDB().(sqlDb); ok {
404 if tx, err := db.Begin(); err == nil {
404 if tx, err := db.Begin(); scope.Err(err) == nil {
405405 scope.db.db = interface{}(tx).(SQLCommon)
406406 scope.InstanceSet("gorm:started_transaction", true)
407407 }
485485 values[index] = &ignored
486486
487487 selectFields = fields
488 offset := 0
488489 if idx, ok := selectedColumnsMap[column]; ok {
489 selectFields = selectFields[idx+1:]
490 offset = idx + 1
491 selectFields = selectFields[offset:]
490492 }
491493
492494 for fieldIndex, field := range selectFields {
500502 resetFields[index] = field
501503 }
502504
503 selectedColumnsMap[column] = fieldIndex
505 selectedColumnsMap[column] = offset + fieldIndex
504506
505507 if field.IsNormal {
506508 break
585587 scope.Err(fmt.Errorf("invalid query condition: %v", value))
586588 return
587589 }
588
590 scopeQuotedTableName := newScope.QuotedTableName()
589591 for _, field := range newScope.Fields() {
590 if !field.IsIgnored && !field.IsBlank {
591 sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
592 if !field.IsIgnored && !field.IsBlank && field.Relationship == nil {
593 sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
592594 }
593595 }
594596 return strings.Join(sqls, " AND ")
691693
692694 buff := bytes.NewBuffer([]byte{})
693695 i := 0
694 for pos := range str {
696 for pos, char := range str {
695697 if str[pos] == '?' {
696698 buff.WriteString(replacements[i])
697699 i++
698700 } else {
699 buff.WriteByte(str[pos])
701 buff.WriteRune(char)
700702 }
701703 }
702704
782784 for _, order := range scope.Search.orders {
783785 if str, ok := order.(string); ok {
784786 orders = append(orders, scope.quoteIfPossible(str))
785 } else if expr, ok := order.(*expr); ok {
787 } else if expr, ok := order.(*SqlExpr); ok {
786788 exp := expr.expr
787789 for _, arg := range expr.args {
788790 exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
794796 }
795797
796798 func (scope *Scope) limitAndOffsetSQL() string {
797 return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
799 sql, err := scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
800 scope.Err(err)
801 return sql
798802 }
799803
800804 func (scope *Scope) groupSQL() string {
836840 }
837841
838842 func (scope *Scope) prepareQuerySQL() {
843 var sql string
839844 if scope.Search.raw {
840 scope.Raw(scope.CombinedConditionSql())
845 sql = scope.CombinedConditionSql()
841846 } else {
842 scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
843 }
844 return
847 sql = fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())
848 }
849 if str, ok := scope.Get("gorm:query_option"); ok {
850 sql += addExtraSpaceIfExist(fmt.Sprint(str))
851 }
852 scope.Raw(sql)
845853 }
846854
847855 func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
852860 }
853861
854862 func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
863 defer func() {
864 if err := recover(); err != nil {
865 if db, ok := scope.db.db.(sqlTx); ok {
866 db.Rollback()
867 }
868 panic(err)
869 }
870 }()
855871 for _, f := range funcs {
856872 (*f)(scope)
857873 if scope.skipLeft {
861877 return scope
862878 }
863879
864 func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string]interface{} {
880 func convertInterfaceToMap(values interface{}, withIgnoredField bool, db *DB) map[string]interface{} {
865881 var attrs = map[string]interface{}{}
866882
867883 switch value := values.(type) {
869885 return value
870886 case []interface{}:
871887 for _, v := range value {
872 for key, value := range convertInterfaceToMap(v, withIgnoredField) {
888 for key, value := range convertInterfaceToMap(v, withIgnoredField, db) {
873889 attrs[key] = value
874890 }
875891 }
879895 switch reflectValue.Kind() {
880896 case reflect.Map:
881897 for _, key := range reflectValue.MapKeys() {
882 attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
898 attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
883899 }
884900 default:
885 for _, field := range (&Scope{Value: values}).Fields() {
901 for _, field := range (&Scope{Value: values, db: db}).Fields() {
886902 if !field.IsBlank && (withIgnoredField || !field.IsIgnored) {
887903 attrs[field.DBName] = field.Field.Interface()
888904 }
894910
895911 func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) {
896912 if scope.IndirectValue().Kind() != reflect.Struct {
897 return convertInterfaceToMap(value, false), true
913 return convertInterfaceToMap(value, false, scope.db), true
898914 }
899915
900916 results = map[string]interface{}{}
901917
902 for key, value := range convertInterfaceToMap(value, true) {
903 if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
904 if _, ok := value.(*expr); ok {
905 hasUpdate = true
906 results[field.DBName] = value
907 } else {
908 err := field.Set(value)
909 if field.IsNormal {
918 for key, value := range convertInterfaceToMap(value, true, scope.db) {
919 if field, ok := scope.FieldByName(key); ok {
920 if scope.changeableField(field) {
921 if _, ok := value.(*SqlExpr); ok {
910922 hasUpdate = true
911 if err == ErrUnaddressable {
912 results[field.DBName] = value
913 } else {
914 results[field.DBName] = field.Field.Interface()
923 results[field.DBName] = value
924 } else {
925 err := field.Set(value)
926 if field.IsNormal && !field.IsIgnored {
927 hasUpdate = true
928 if err == ErrUnaddressable {
929 results[field.DBName] = value
930 } else {
931 results[field.DBName] = field.Field.Interface()
932 }
915933 }
916934 }
917935 }
936 } else {
937 results[key] = value
918938 }
919939 }
920940 return
971991 if dest.Kind() != reflect.Slice {
972992 scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
973993 return scope
994 }
995
996 if dest.Len() > 0 {
997 dest.Set(reflect.Zero(dest.Type()))
974998 }
975999
9761000 if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) {
9961020 func (scope *Scope) count(value interface{}) *Scope {
9971021 if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) {
9981022 if len(scope.Search.group) != 0 {
999 scope.Search.Select("count(*) FROM ( SELECT count(*) as name ")
1000 scope.Search.group += " ) AS count_table"
1023 if len(scope.Search.havingConditions) != 0 {
1024 scope.prepareQuerySQL()
1025 scope.Search = &search{}
1026 scope.Search.Select("count(*)")
1027 scope.Search.Table(fmt.Sprintf("( %s ) AS count_table", scope.SQL))
1028 } else {
1029 scope.Search.Select("count(*) FROM ( SELECT count(*) as name ")
1030 scope.Search.group += " ) AS count_table"
1031 }
10011032 } else {
10021033 scope.Search.Select("count(*)")
10031034 }
11121143 if field, ok := scope.FieldByName(fieldName); ok {
11131144 foreignKeyStruct := field.clone()
11141145 foreignKeyStruct.IsPrimaryKey = false
1115 foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
1116 delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT")
1146 foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true")
1147 foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT")
11171148 sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
11181149 primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
11191150 }
11231154 if field, ok := toScope.FieldByName(fieldName); ok {
11241155 foreignKeyStruct := field.clone()
11251156 foreignKeyStruct.IsPrimaryKey = false
1126 foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
1127 delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT")
1157 foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true")
1158 foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT")
11281159 sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
11291160 primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
11301161 }
11721203 }
11731204
11741205 func (scope *Scope) dropTable() *Scope {
1175 scope.Raw(fmt.Sprintf("DROP TABLE %v%s", scope.QuotedTableName(), scope.getTableOptions())).Exec()
1206 scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
11761207 return scope
11771208 }
11781209
12141245 }
12151246
12161247 func (scope *Scope) removeForeignKey(field string, dest string) {
1217 keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest)
1218
1248 keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
12191249 if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
12201250 return
12211251 }
1222 var query = `ALTER TABLE %s DROP CONSTRAINT %s;`
1252 var mysql mysql
1253 var query string
1254 if scope.Dialect().GetName() == mysql.GetName() {
1255 query = `ALTER TABLE %s DROP FOREIGN KEY %s;`
1256 } else {
1257 query = `ALTER TABLE %s DROP CONSTRAINT %s;`
1258 }
1259
12231260 scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
12241261 }
12251262
12531290 var uniqueIndexes = map[string][]string{}
12541291
12551292 for _, field := range scope.GetStructFields() {
1256 if name, ok := field.TagSettings["INDEX"]; ok {
1293 if name, ok := field.TagSettingsGet("INDEX"); ok {
12571294 names := strings.Split(name, ",")
12581295
12591296 for _, name := range names {
12601297 if name == "INDEX" || name == "" {
12611298 name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName)
12621299 }
1263 indexes[name] = append(indexes[name], field.DBName)
1264 }
1265 }
1266
1267 if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok {
1300 name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName)
1301 indexes[name] = append(indexes[name], column)
1302 }
1303 }
1304
1305 if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok {
12681306 names := strings.Split(name, ",")
12691307
12701308 for _, name := range names {
12711309 if name == "UNIQUE_INDEX" || name == "" {
12721310 name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName)
12731311 }
1274 uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
1312 name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName)
1313 uniqueIndexes[name] = append(uniqueIndexes[name], column)
12751314 }
12761315 }
12771316 }
12921331 }
12931332
12941333 func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
1334 resultMap := make(map[string][]interface{})
12951335 for _, value := range values {
12961336 indirectValue := indirect(reflect.ValueOf(value))
12971337
13101350 }
13111351
13121352 if hasValue {
1313 results = append(results, result)
1353 h := fmt.Sprint(result...)
1354 if _, exist := resultMap[h]; !exist {
1355 resultMap[h] = result
1356 }
13141357 }
13151358 }
13161359 case reflect.Struct:
13251368 }
13261369
13271370 if hasValue {
1328 results = append(results, result)
1329 }
1330 }
1331 }
1332
1371 h := fmt.Sprint(result...)
1372 if _, exist := resultMap[h]; !exist {
1373 resultMap[h] = result
1374 }
1375 }
1376 }
1377 }
1378 for _, v := range resultMap {
1379 results = append(results, v)
1380 }
13331381 return
13341382 }
13351383
7777 t.Errorf("The error should be returned from Valuer, but get %v", err)
7878 }
7979 }
80
81 func TestDropTableWithTableOptions(t *testing.T) {
82 type UserWithOptions struct {
83 gorm.Model
84 }
85 DB.AutoMigrate(&UserWithOptions{})
86
87 DB = DB.Set("gorm:table_options", "CHARSET=utf8")
88 err := DB.DropTable(&UserWithOptions{}).Error
89 if err != nil {
90 t.Errorf("Table must be dropped, got error %s", err)
91 }
92 }
3131 }
3232
3333 func (s *search) clone() *search {
34 clone := *s
34 clone := search{
35 db: s.db,
36 whereConditions: make([]map[string]interface{}, len(s.whereConditions)),
37 orConditions: make([]map[string]interface{}, len(s.orConditions)),
38 notConditions: make([]map[string]interface{}, len(s.notConditions)),
39 havingConditions: make([]map[string]interface{}, len(s.havingConditions)),
40 joinConditions: make([]map[string]interface{}, len(s.joinConditions)),
41 initAttrs: make([]interface{}, len(s.initAttrs)),
42 assignAttrs: make([]interface{}, len(s.assignAttrs)),
43 selects: s.selects,
44 omits: make([]string, len(s.omits)),
45 orders: make([]interface{}, len(s.orders)),
46 preload: make([]searchPreload, len(s.preload)),
47 offset: s.offset,
48 limit: s.limit,
49 group: s.group,
50 tableName: s.tableName,
51 raw: s.raw,
52 Unscoped: s.Unscoped,
53 ignoreOrderQuery: s.ignoreOrderQuery,
54 }
55 for i, value := range s.whereConditions {
56 clone.whereConditions[i] = value
57 }
58 for i, value := range s.orConditions {
59 clone.orConditions[i] = value
60 }
61 for i, value := range s.notConditions {
62 clone.notConditions[i] = value
63 }
64 for i, value := range s.havingConditions {
65 clone.havingConditions[i] = value
66 }
67 for i, value := range s.joinConditions {
68 clone.joinConditions[i] = value
69 }
70 for i, value := range s.initAttrs {
71 clone.initAttrs[i] = value
72 }
73 for i, value := range s.assignAttrs {
74 clone.assignAttrs[i] = value
75 }
76 for i, value := range s.omits {
77 clone.omits[i] = value
78 }
79 for i, value := range s.orders {
80 clone.orders[i] = value
81 }
82 for i, value := range s.preload {
83 clone.preload[i] = value
84 }
3585 return &clone
3686 }
3787
97147 }
98148
99149 func (s *search) Having(query interface{}, values ...interface{}) *search {
100 if val, ok := query.(*expr); ok {
150 if val, ok := query.(*SqlExpr); ok {
101151 s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args})
102152 } else {
103153 s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})
00 package gorm
11
22 import (
3 "fmt"
34 "reflect"
45 "testing"
56 )
2728 t.Errorf("selectStr should be copied")
2829 }
2930 }
31
32 func TestWhereCloneCorruption(t *testing.T) {
33 for whereCount := 1; whereCount <= 8; whereCount++ {
34 t.Run(fmt.Sprintf("w=%d", whereCount), func(t *testing.T) {
35 s := new(search)
36 for w := 0; w < whereCount; w++ {
37 s = s.clone().Where(fmt.Sprintf("w%d = ?", w), fmt.Sprintf("value%d", w))
38 }
39 if len(s.whereConditions) != whereCount {
40 t.Errorf("s: where count should be %d", whereCount)
41 }
42
43 q1 := s.clone().Where("finalThing = ?", "THING1")
44 q2 := s.clone().Where("finalThing = ?", "THING2")
45
46 if reflect.DeepEqual(q1.whereConditions, q2.whereConditions) {
47 t.Errorf("Where conditions should be different")
48 }
49 })
50 }
51 }
00 package gorm
11
22 import (
3 "bytes"
43 "database/sql/driver"
54 "fmt"
65 "reflect"
2524 var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
2625 var commonInitialismsReplacer *strings.Replacer
2726
28 var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)
29 var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`)
27 var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`)
28 var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`)
3029
3130 func init() {
3231 var commonInitialismsForReplacer []string
5756 return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
5857 }
5958
60 var smap = newSafeMap()
61
62 type strCase bool
63
64 const (
65 lower strCase = false
66 upper strCase = true
67 )
68
69 // ToDBName convert string to db name
70 func ToDBName(name string) string {
71 if v := smap.Get(name); v != "" {
72 return v
73 }
74
75 if name == "" {
76 return ""
77 }
78
79 var (
80 value = commonInitialismsReplacer.Replace(name)
81 buf = bytes.NewBufferString("")
82 lastCase, currCase, nextCase strCase
83 )
84
85 for i, v := range value[:len(value)-1] {
86 nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z')
87 if i > 0 {
88 if currCase == upper {
89 if lastCase == upper && nextCase == upper {
90 buf.WriteRune(v)
91 } else {
92 if value[i-1] != '_' && value[i+1] != '_' {
93 buf.WriteRune('_')
94 }
95 buf.WriteRune(v)
96 }
97 } else {
98 buf.WriteRune(v)
99 if i == len(value)-2 && nextCase == upper {
100 buf.WriteRune('_')
101 }
102 }
103 } else {
104 currCase = upper
105 buf.WriteRune(v)
106 }
107 lastCase = currCase
108 currCase = nextCase
109 }
110
111 buf.WriteByte(value[len(value)-1])
112
113 s := strings.ToLower(buf.String())
114 smap.Set(name, s)
115 return s
116 }
117
11859 // SQL expression
119 type expr struct {
60 type SqlExpr struct {
12061 expr string
12162 args []interface{}
12263 }
12364
12465 // Expr generate raw SQL expression, for example:
12566 // DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
126 func Expr(expression string, args ...interface{}) *expr {
127 return &expr{expr: expression, args: args}
67 func Expr(expression string, args ...interface{}) *SqlExpr {
68 return &SqlExpr{expr: expression, args: args}
12869 }
12970
13071 func indirect(reflectValue reflect.Value) reflect.Value {
264205 // as FieldByName could panic
265206 if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
266207 for _, fieldName := range fieldNames {
267 if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() {
208 if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() {
268209 result := fieldValue.Interface()
269210 if r, ok := result.(driver.Valuer); ok {
270211 result, _ = r.Value()
+0
-32
utils_test.go less more
0 package gorm_test
1
2 import (
3 "testing"
4
5 "github.com/jinzhu/gorm"
6 )
7
8 func TestToDBNameGenerateFriendlyName(t *testing.T) {
9 var maps = map[string]string{
10 "": "",
11 "X": "x",
12 "ThisIsATest": "this_is_a_test",
13 "PFAndESI": "pf_and_esi",
14 "AbcAndJkl": "abc_and_jkl",
15 "EmployeeID": "employee_id",
16 "SKU_ID": "sku_id",
17 "FieldX": "field_x",
18 "HTTPAndSMTP": "http_and_smtp",
19 "HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
20 "UUID": "uuid",
21 "HTTPURL": "http_url",
22 "HTTP_URL": "http_url",
23 "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
24 }
25
26 for key, value := range maps {
27 if gorm.ToDBName(key) != value {
28 t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key))
29 }
30 }
31 }
33 services:
44 - name: mariadb
55 id: mariadb:latest
6 env:
7 MYSQL_DATABASE: gorm
8 MYSQL_USER: gorm
9 MYSQL_PASSWORD: gorm
10 MYSQL_RANDOM_ROOT_PASSWORD: "yes"
11 - name: mysql
12 id: mysql:latest
613 env:
714 MYSQL_DATABASE: gorm
815 MYSQL_USER: gorm
1724 MYSQL_RANDOM_ROOT_PASSWORD: "yes"
1825 - name: mysql56
1926 id: mysql:5.6
20 env:
21 MYSQL_DATABASE: gorm
22 MYSQL_USER: gorm
23 MYSQL_PASSWORD: gorm
24 MYSQL_RANDOM_ROOT_PASSWORD: "yes"
25 - name: mysql55
26 id: mysql:5.5
2727 env:
2828 MYSQL_DATABASE: gorm
2929 MYSQL_USER: gorm
8282 code: |
8383 cd $WERCKER_SOURCE_DIR
8484 go version
85 go get -t ./...
85 go get -t -v ./...
8686
8787 # Build the project
8888 - script:
9494 - script:
9595 name: test sqlite
9696 code: |
97 go test ./...
97 go test -race -v ./...
9898
9999 - script:
100100 name: test mariadb
101101 code: |
102 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./...
102 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
103
104 - script:
105 name: test mysql
106 code: |
107 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
103108
104109 - script:
105110 name: test mysql5.7
106111 code: |
107 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test ./...
112 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
108113
109114 - script:
110115 name: test mysql5.6
111116 code: |
112 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test ./...
113
114 - script:
115 name: test mysql5.5
116 code: |
117 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test ./...
117 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
118118
119119 - script:
120120 name: test postgres
121121 code: |
122 GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
122 GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
123123
124124 - script:
125125 name: test postgres96
126126 code: |
127 GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
127 GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
128128
129129 - script:
130130 name: test postgres95
131131 code: |
132 GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
132 GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
133133
134134 - script:
135135 name: test postgres94
136136 code: |
137 GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
137 GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
138138
139139 - script:
140140 name: test postgres93
141141 code: |
142 GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
142 GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
143143
144144 - script:
145 name: test mssql
145 name: codecov
146146 code: |
147 GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test ./...
147 go test -race -coverprofile=coverage.txt -covermode=atomic ./...
148 bash <(curl -s https://codecov.io/bash)