Codebase list golang-github-jinzhu-gorm / 87c6039
New upstream version 1.0+git20180218.58e3472 Michael Stapelberg 6 years ago
73 changed file(s) with 9776 addition(s) and 5745 deletion(s). Raw diff Collapse all Expand all
0 Your issue may already be reported! Please search on the [issue track](https://github.com/jinzhu/gorm/issues) before creating one.
1
2 ### What version of Go are you using (`go version`)?
3
4
5 ### Which database and its version are you using?
6
7
8 ### Please provide a complete runnable program to reproduce your issue. **IMPORTANT**
9
10 Need to runnable with [GORM's docker compose config](https://github.com/jinzhu/gorm/blob/master/docker-compose.yml) or please provides your config.
11
12 ```go
13 package main
14
15 import (
16 "github.com/jinzhu/gorm"
17 _ "github.com/jinzhu/gorm/dialects/mssql"
18 _ "github.com/jinzhu/gorm/dialects/mysql"
19 _ "github.com/jinzhu/gorm/dialects/postgres"
20 _ "github.com/jinzhu/gorm/dialects/sqlite"
21 )
22
23 var db *gorm.DB
24
25 func init() {
26 var err error
27 db, err = gorm.Open("sqlite3", "test.db")
28 // db, err = gorm.Open("postgres", "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable")
29 // db, err = gorm.Open("mysql", "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True")
30 // db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm")
31 if err != nil {
32 panic(err)
33 }
34 db.LogMode(true)
35 }
36
37 func main() {
38 if /* failure condition */ {
39 fmt.Println("failed")
40 } else {
41 fmt.Println("success")
42 }
43 }
44 ```
0 Make sure these boxes checked before submitting your pull request.
1
2 - [] Do only one thing
3 - [] No API-breaking changes
4 - [] New code/logic commented & tested
5
6 For significant changes like big bug fixes, new features, please open an issue to make an agreement on an implementation design/plan first before starting it.
7
8 ### What did this pull request do?
0 documents
1 _book
00 # GORM
1
2 [![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
31
42 The fantastic ORM library for Golang, aims to be developer friendly.
53
6 [![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921)
4 [![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm)
5 [![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b)
6 [![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
7 [![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm)
8 [![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm)
9 [![MIT license](http://img.shields.io/badge/license-MIT-brightgreen.svg)](http://opensource.org/licenses/MIT)
10 [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm)
711
812 ## Overview
913
1014 * Full-Featured ORM (almost)
11 * Chainable API
12 * Auto Migrations
13 * Relations (Has One, Has Many, Belongs To, Many To Many, [Polymorphism](#polymorphism))
14 * Callbacks (Before/After Create/Save/Update/Delete/Find)
15 * Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism)
16 * Hooks (Before/After Create/Save/Update/Delete/Find)
1517 * Preloading (eager loading)
1618 * Transactions
17 * Embed Anonymous Struct
18 * Soft Deletes
19 * Customizable Logger
20 * Iteration Support via [Rows](#row--rows)
19 * Composite Primary Key
20 * SQL Builder
21 * Auto Migrations
22 * Logger
23 * Extendable, write Plugins based on GORM callbacks
2124 * Every feature comes with tests
2225 * Developer Friendly
2326
24 # Getting Started
27 ## Getting Started
2528
26 ## Install
29 * GORM Guides [http://gorm.io](http://gorm.io)
2730
28 ```
29 go get -u github.com/jinzhu/gorm
30 ```
31 ## Contributing
3132
32 ## Documentation
33
34 [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm)
35
36 `go doc` format documentation for this project can be viewed online without
37 installing the package by using the GoDoc page at:
38 http://godoc.org/github.com/jinzhu/gorm
39
40 ## Table of Contents
41
42 - [Define Models (Structs)](#define-models-structs)
43 - [Conventions](#conventions)
44 - [Initialize Database](#initialize-database)
45 - [Migration](#migration)
46 - [Basic CRUD](#basic-crud)
47 - [Create](#create-record)
48 - [Query](#query)
49 - [Query With Where (Plain SQL)](#query-with-where-plain-sql)
50 - [Query With Where (Struct & Map)](#query-with-where-struct--map)
51 - [Query With Not](#query-with-not)
52 - [Query With Inline Condition](#query-with-inline-condition)
53 - [Query With Or](#query-with-or)
54 - [Query Chains](#query-chains)
55 - [Preloading (Eager loading)](#preloading-eager-loading)
56 - [Update](#update)
57 - [Update Without Callbacks](#update-without-callbacks)
58 - [Batch Updates](#batch-updates)
59 - [Update with SQL Expression](#update-with-sql-expression)
60 - [Delete](#delete)
61 - [Batch Delete](#batch-delete)
62 - [Soft Delete](#soft-delete)
63 - [Associations](#associations)
64 - [Has One](#has-one)
65 - [Belongs To](#belongs-to)
66 - [Has Many](#has-many)
67 - [Many To Many](#many-to-many)
68 - [Polymorphism](#polymorphism)
69 - [Advanced Usage](#advanced-usage)
70 - [FirstOrInit](#firstorinit)
71 - [FirstOrCreate](#firstorcreate)
72 - [Select](#select)
73 - [Order](#order)
74 - [Limit](#limit)
75 - [Offset](#offset)
76 - [Count](#count)
77 - [Pluck](#pluck)
78 - [Raw SQL](#raw-sql)
79 - [Row & Rows](#row--rows)
80 - [Scan](#scan)
81 - [Group & Having](#group--having)
82 - [Joins](#joins)
83 - [Transactions](#transactions)
84 - [Scopes](#scopes)
85 - [Callbacks](#callbacks)
86 - [Specifying The Table Name](#specifying-the-table-name)
87 - [Error Handling](#error-handling)
88 - [Logger](#logger)
89 - [Existing Schema](#existing-schema)
90 - [Composite Primary Key](#composite-primary-key)
91 - [Database Indexes & Foreign Key](#database-indexes--foreign-key)
92 - [Default values](#default-values)
93 - [More examples with query chain](#more-examples-with-query-chain)
94
95 ## Define Models (Structs)
96
97 ```go
98 type User struct {
99 ID int
100 Birthday time.Time
101 Age int
102 Name string `sql:"size:255"` // Default size for string is 255, you could reset it with this tag
103 Num int `sql:"AUTO_INCREMENT"`
104 CreatedAt time.Time
105 UpdatedAt time.Time
106 DeletedAt *time.Time
107
108 Emails []Email // One-To-Many relationship (has many)
109 BillingAddress Address // One-To-One relationship (has one)
110 BillingAddressID sql.NullInt64 // Foreign key of BillingAddress
111 ShippingAddress Address // One-To-One relationship (has one)
112 ShippingAddressID int // Foreign key of ShippingAddress
113 IgnoreMe int `sql:"-"` // Ignore this field
114 Languages []Language `gorm:"many2many:user_languages;"` // Many-To-Many relationship, 'user_languages' is join table
115 }
116
117 type Email struct {
118 ID int
119 UserID int `sql:"index"` // Foreign key (belongs to), tag `index` will create index for this field when using AutoMigrate
120 Email string `sql:"type:varchar(100);unique_index"` // Set field's sql type, tag `unique_index` will create unique index
121 Subscribed bool
122 }
123
124 type Address struct {
125 ID int
126 Address1 string `sql:"not null;unique"` // Set field as not nullable and unique
127 Address2 string `sql:"type:varchar(100);unique"`
128 Post sql.NullString `sql:"not null"`
129 }
130
131 type Language struct {
132 ID int
133 Name string `sql:"index:idx_name_code"` // Create index with name, and will create combined index if find other fields defined same name
134 Code string `sql:"index:idx_name_code"` // `unique_index` also works
135 }
136 ```
137
138 ## Conventions
139
140 * Table name is the plural of struct name's snake case, you can disable pluralization with `db.SingularTable(true)`, or [Specifying The Table Name For A Struct Permanently With TableName](#specifying-the-table-name-for-a-struct-permanently-with-tablename)
141
142 ```go
143 type User struct{} // struct User's database table name is "users" by default, will be "user" if you disabled pluralisation
144 ```
145
146 * Column name is the snake case of field's name
147 * Use `ID` field as primary key
148 * Use `CreatedAt` to store record's created time if field exists
149 * Use `UpdatedAt` to store record's updated time if field exists
150 * Use `DeletedAt` to store record's deleted time if field exists [Soft Delete](#soft-delete)
151 * Gorm provide a default model struct, you could embed it in your struct
152
153 ```go
154 type Model struct {
155 ID uint `gorm:"primary_key"`
156 CreatedAt time.Time
157 UpdatedAt time.Time
158 DeletedAt *time.Time
159 }
160
161 type User struct {
162 gorm.Model
163 Name string
164 }
165 ```
166
167 ## Initialize Database
168
169 ```go
170 import (
171 "github.com/jinzhu/gorm"
172 _ "github.com/lib/pq"
173 _ "github.com/go-sql-driver/mysql"
174 _ "github.com/mattn/go-sqlite3"
175 )
176
177 db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
178 // db, err := gorm.Open("foundation", "dbname=gorm") // FoundationDB.
179 // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
180 // db, err := gorm.Open("sqlite3", "/tmp/gorm.db")
181
182 // You can also use an existing database connection handle
183 // dbSql, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
184 // db, _ := gorm.Open("postgres", dbSql)
185
186 // Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB)
187 db.DB()
188
189 // Then you could invoke `*sql.DB`'s functions with it
190 db.DB().Ping()
191 db.DB().SetMaxIdleConns(10)
192 db.DB().SetMaxOpenConns(100)
193
194 // Disable table name's pluralization
195 db.SingularTable(true)
196 ```
197
198 ## Migration
199
200 ```go
201 // Create table
202 db.CreateTable(&User{})
203 db.Set("gorm:table_options", "ENGINE=InnoDB").CreateTable(&User{})
204
205 // Drop table
206 db.DropTable(&User{})
207
208 // ModifyColumn
209 db.Model(&User{}).ModifyColumn("description", "text")
210
211 // DropColumn
212 db.Model(&User{}).DropColumn("description")
213
214 // Automating Migration
215 db.AutoMigrate(&User{})
216 db.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(&User{})
217 db.AutoMigrate(&User{}, &Product{}, &Order{})
218 // Feel free to change your struct, AutoMigrate will keep your database up-to-date.
219 // AutoMigrate will ONLY add *new columns* and *new indexes*,
220 // WON'T update current column's type or delete unused columns, to protect your data.
221 // If the table is not existing, AutoMigrate will create the table automatically.
222 ```
223
224 # Basic CRUD
225
226 ## Create Record
227
228 ```go
229 user := User{Name: "Jinzhu", Age: 18, Birthday: time.Now()}
230
231 db.NewRecord(user) // => returns `true` if primary key is blank
232
233 db.Create(&user)
234
235 db.NewRecord(user) // => return `false` after `user` created
236
237 // Associations will be inserted automatically when save the record
238 user := User{
239 Name: "jinzhu",
240 BillingAddress: Address{Address1: "Billing Address - Address 1"},
241 ShippingAddress: Address{Address1: "Shipping Address - Address 1"},
242 Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}},
243 Languages: []Language{{Name: "ZH"}, {Name: "EN"}},
244 }
245
246 db.Create(&user)
247 //// BEGIN TRANSACTION;
248 //// INSERT INTO "addresses" (address1) VALUES ("Billing Address - Address 1");
249 //// INSERT INTO "addresses" (address1) VALUES ("Shipping Address - Address 1");
250 //// INSERT INTO "users" (name,billing_address_id,shipping_address_id) VALUES ("jinzhu", 1, 2);
251 //// INSERT INTO "emails" (user_id,email) VALUES (111, "jinzhu@example.com");
252 //// INSERT INTO "emails" (user_id,email) VALUES (111, "jinzhu-2@example.com");
253 //// INSERT INTO "languages" ("name") VALUES ('ZH');
254 //// INSERT INTO user_languages ("user_id","language_id") VALUES (111, 1);
255 //// INSERT INTO "languages" ("name") VALUES ('EN');
256 //// INSERT INTO user_languages ("user_id","language_id") VALUES (111, 2);
257 //// COMMIT;
258 ```
259
260 Refer [Associations](#associations) for more details
261
262 ## Query
263
264 ```go
265 // Get the first record
266 db.First(&user)
267 //// SELECT * FROM users ORDER BY id LIMIT 1;
268
269 // Get the last record
270 db.Last(&user)
271 //// SELECT * FROM users ORDER BY id DESC LIMIT 1;
272
273 // Get all records
274 db.Find(&users)
275 //// SELECT * FROM users;
276
277 // Get record with primary key
278 db.First(&user, 10)
279 //// SELECT * FROM users WHERE id = 10;
280 ```
281
282 ### Query With Where (Plain SQL)
283
284 ```go
285 // Get the first matched record
286 db.Where("name = ?", "jinzhu").First(&user)
287 //// SELECT * FROM users WHERE name = 'jinzhu' limit 1;
288
289 // Get all matched records
290 db.Where("name = ?", "jinzhu").Find(&users)
291 //// SELECT * FROM users WHERE name = 'jinzhu';
292
293 db.Where("name <> ?", "jinzhu").Find(&users)
294
295 // IN
296 db.Where("name in (?)", []string{"jinzhu", "jinzhu 2"}).Find(&users)
297
298 // LIKE
299 db.Where("name LIKE ?", "%jin%").Find(&users)
300
301 // AND
302 db.Where("name = ? and age >= ?", "jinzhu", "22").Find(&users)
303
304 // Time
305 db.Where("updated_at > ?", lastWeek).Find(&users)
306
307 db.Where("created_at BETWEEN ? AND ?", lastWeek, today).Find(&users)
308 ```
309
310 ### Query With Where (Struct & Map)
311
312 ```go
313 // Struct
314 db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
315 //// SELECT * FROM users WHERE name = "jinzhu" AND age = 20 LIMIT 1;
316
317 // Map
318 db.Where(map[string]interface{}{"name": "jinzhu", "age": 20}).Find(&users)
319 //// SELECT * FROM users WHERE name = "jinzhu" AND age = 20;
320
321 // Slice of primary keys
322 db.Where([]int64{20, 21, 22}).Find(&users)
323 //// SELECT * FROM users WHERE id IN (20, 21, 22);
324 ```
325
326 ### Query With Not
327
328 ```go
329 db.Not("name", "jinzhu").First(&user)
330 //// SELECT * FROM users WHERE name <> "jinzhu" LIMIT 1;
331
332 // Not In
333 db.Not("name", []string{"jinzhu", "jinzhu 2"}).Find(&users)
334 //// SELECT * FROM users WHERE name NOT IN ("jinzhu", "jinzhu 2");
335
336 // Not In slice of primary keys
337 db.Not([]int64{1,2,3}).First(&user)
338 //// SELECT * FROM users WHERE id NOT IN (1,2,3);
339
340 db.Not([]int64{}).First(&user)
341 //// SELECT * FROM users;
342
343 // Plain SQL
344 db.Not("name = ?", "jinzhu").First(&user)
345 //// SELECT * FROM users WHERE NOT(name = "jinzhu");
346
347 // Struct
348 db.Not(User{Name: "jinzhu"}).First(&user)
349 //// SELECT * FROM users WHERE name <> "jinzhu";
350 ```
351
352 ### Query With Inline Condition
353
354 ```go
355 // Get by primary key
356 db.First(&user, 23)
357 //// SELECT * FROM users WHERE id = 23 LIMIT 1;
358
359 // Plain SQL
360 db.Find(&user, "name = ?", "jinzhu")
361 //// SELECT * FROM users WHERE name = "jinzhu";
362
363 db.Find(&users, "name <> ? AND age > ?", "jinzhu", 20)
364 //// SELECT * FROM users WHERE name <> "jinzhu" AND age > 20;
365
366 // Struct
367 db.Find(&users, User{Age: 20})
368 //// SELECT * FROM users WHERE age = 20;
369
370 // Map
371 db.Find(&users, map[string]interface{}{"age": 20})
372 //// SELECT * FROM users WHERE age = 20;
373 ```
374
375 ### Query With Or
376
377 ```go
378 db.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&users)
379 //// SELECT * FROM users WHERE role = 'admin' OR role = 'super_admin';
380
381 // Struct
382 db.Where("name = 'jinzhu'").Or(User{Name: "jinzhu 2"}).Find(&users)
383 //// SELECT * FROM users WHERE name = 'jinzhu' OR name = 'jinzhu 2';
384
385 // Map
386 db.Where("name = 'jinzhu'").Or(map[string]interface{}{"name": "jinzhu 2"}).Find(&users)
387 ```
388
389 ### Query Chains
390
391 Gorm has a chainable API, you could use it like this
392
393 ```go
394 db.Where("name <> ?","jinzhu").Where("age >= ? and role <> ?",20,"admin").Find(&users)
395 //// SELECT * FROM users WHERE name <> 'jinzhu' AND age >= 20 AND role <> 'admin';
396
397 db.Where("role = ?", "admin").Or("role = ?", "super_admin").Not("name = ?", "jinzhu").Find(&users)
398 ```
399
400 ### Preloading (Eager loading)
401
402 ```go
403 db.Preload("Orders").Find(&users)
404 //// SELECT * FROM users;
405 //// SELECT * FROM orders WHERE user_id IN (1,2,3,4);
406
407 db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
408 //// SELECT * FROM users;
409 //// SELECT * FROM orders WHERE user_id IN (1,2,3,4) AND state NOT IN ('cancelled');
410
411 db.Where("state = ?", "active").Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
412 //// SELECT * FROM users WHERE state = 'active';
413 //// SELECT * FROM orders WHERE user_id IN (1,2) AND state NOT IN ('cancelled');
414
415 db.Preload("Orders").Preload("Profile").Preload("Role").Find(&users)
416 //// SELECT * FROM users;
417 //// SELECT * FROM orders WHERE user_id IN (1,2,3,4); // has many
418 //// SELECT * FROM profiles WHERE user_id IN (1,2,3,4); // has one
419 //// SELECT * FROM roles WHERE id IN (4,5,6); // belongs to
420 ```
421
422 #### Nested Preloading
423
424 ```go
425 db.Preload("Orders.OrderItems").Find(&users)
426 db.Preload("Orders", "state = ?", "paid").Preload("Orders.OrderItems").Find(&users)
427 ```
428
429 ## Update
430
431 ```go
432 // Update an existing struct
433 db.First(&user)
434 user.Name = "jinzhu 2"
435 user.Age = 100
436 db.Save(&user)
437 //// UPDATE users SET name='jinzhu 2', age=100, updated_at = '2013-11-17 21:34:10' WHERE id=111;
438
439 db.Where("active = ?", true).Save(&user)
440 //// UPDATE users SET name='jinzhu 2', age=100, updated_at = '2013-11-17 21:34:10' WHERE id=111 AND active = true;
441
442 // Update an attribute if it is changed
443 db.Model(&user).Update("name", "hello")
444 //// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111;
445
446 db.Model(&user).Where("active = ?", true).Update("name", "hello")
447 //// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111 AND active = true;
448
449 db.First(&user, 111).Update("name", "hello")
450 //// SELECT * FROM users LIMIT 1;
451 //// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111;
452
453 // Update multiple attributes if they are changed
454 db.Model(&user).Updates(map[string]interface{}{"name": "hello", "age": 18, "actived": false})
455
456 // Update multiple attributes if they are changed (update with struct only works with none zero values)
457 db.Model(&user).Updates(User{Name: "hello", Age: 18})
458 //// UPDATE users SET name='hello', age=18, updated_at = '2013-11-17 21:34:10' WHERE id = 111;
459 ```
460
461 ### Update Without Callbacks
462
463 By default, update will call BeforeUpdate, AfterUpdate callbacks, if you want to update w/o callbacks and w/o saving associations:
464
465 ```go
466 db.Model(&user).UpdateColumn("name", "hello")
467 //// UPDATE users SET name='hello' WHERE id = 111;
468
469 // Update with struct only works with none zero values, or use map[string]interface{}
470 db.Model(&user).UpdateColumns(User{Name: "hello", Age: 18})
471 //// UPDATE users SET name='hello', age=18 WHERE id = 111;
472 ```
473
474 ### Batch Updates
475
476 ```go
477 db.Table("users").Where("id = ?", 10).Updates(map[string]interface{}{"name": "hello", "age": 18})
478 //// UPDATE users SET name='hello', age=18 WHERE id = 10;
479
480 // Update with struct only works with none zero values, or use map[string]interface{}
481 db.Model(User{}).Updates(User{Name: "hello", Age: 18})
482 //// UPDATE users SET name='hello', age=18;
483
484 // Callbacks won't run when do batch updates
485
486 // Use `RowsAffected` to get the count of affected records
487 db.Model(User{}).Updates(User{Name: "hello", Age: 18}).RowsAffected
488 ```
489
490 ### Update with SQL Expression
491
492 ```go
493 DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
494 //// UPDATE "products" SET "code" = 'L1212', "price" = price * '2' + '100', "updated_at" = '2013-11-17 21:34:10' WHERE "id" = '2';
495
496 DB.Model(&product).Updates(map[string]interface{}{"price": gorm.Expr("price * ? + ?", 2, 100)})
497 //// UPDATE "products" SET "code" = 'L1212', "price" = price * '2' + '100', "updated_at" = '2013-11-17 21:34:10' WHERE "id" = '2';
498
499 DB.Model(&product).UpdateColumn("quantity", gorm.Expr("quantity - ?", 1))
500 //// UPDATE "products" SET "quantity" = quantity - 1 WHERE "id" = '2';
501
502 DB.Model(&product).Where("quantity > 1").UpdateColumn("quantity", gorm.Expr("quantity - ?", 1))
503 //// UPDATE "products" SET "quantity" = quantity - 1 WHERE "id" = '2' AND quantity > 1;
504 ```
505
506 ## Delete
507
508 ```go
509 // Delete an existing record
510 db.Delete(&email)
511 //// DELETE from emails where id=10;
512 ```
513
514 ### Batch Delete
515
516 ```go
517 db.Where("email LIKE ?", "%jinzhu%").Delete(Email{})
518 //// DELETE from emails where email LIKE "%jinhu%";
519 ```
520
521 ### Soft Delete
522
523 If struct has `DeletedAt` field, it will get soft delete ability automatically!
524 Then it won't be deleted from database permanently when call `Delete`.
525
526 ```go
527 db.Delete(&user)
528 //// UPDATE users SET deleted_at="2013-10-29 10:23" WHERE id = 111;
529
530 // Batch Delete
531 db.Where("age = ?", 20).Delete(&User{})
532 //// UPDATE users SET deleted_at="2013-10-29 10:23" WHERE age = 20;
533
534 // Soft deleted records will be ignored when query them
535 db.Where("age = 20").Find(&user)
536 //// SELECT * FROM users WHERE age = 20 AND (deleted_at IS NULL OR deleted_at <= '0001-01-02');
537
538 // Find soft deleted records with Unscoped
539 db.Unscoped().Where("age = 20").Find(&users)
540 //// SELECT * FROM users WHERE age = 20;
541
542 // Delete record permanently with Unscoped
543 db.Unscoped().Delete(&order)
544 //// DELETE FROM orders WHERE id=10;
545 ```
546
547 ## Associations
548
549 ### Has One
550
551 ```go
552 // User has one address
553 db.Model(&user).Related(&address)
554 //// SELECT * FROM addresses WHERE id = 123; // 123 is user's foreign key AddressId
555
556 // Specify the foreign key
557 db.Model(&user).Related(&address1, "BillingAddressId")
558 //// SELECT * FROM addresses WHERE id = 123; // 123 is user's foreign key BillingAddressId
559 ```
560
561 ### Belongs To
562
563 ```go
564 // Email belongs to user
565 db.Model(&email).Related(&user)
566 //// SELECT * FROM users WHERE id = 111; // 111 is email's foreign key UserId
567
568 // Specify the foreign key
569 db.Model(&email).Related(&user, "ProfileId")
570 //// SELECT * FROM users WHERE id = 111; // 111 is email's foreign key ProfileId
571 ```
572
573 ### Has Many
574
575 ```go
576 // User has many emails
577 db.Model(&user).Related(&emails)
578 //// SELECT * FROM emails WHERE user_id = 111;
579 // user_id is the foreign key, 111 is user's primary key's value
580
581 // Specify the foreign key
582 db.Model(&user).Related(&emails, "ProfileId")
583 //// SELECT * FROM emails WHERE profile_id = 111;
584 // profile_id is the foreign key, 111 is user's primary key's value
585 ```
586
587 ### Many To Many
588
589 ```go
590 // User has many languages and belongs to many languages
591 db.Model(&user).Related(&languages, "Languages")
592 //// SELECT * FROM "languages" INNER JOIN "user_languages" ON "user_languages"."language_id" = "languages"."id" WHERE "user_languages"."user_id" = 111
593 // `Languages` is user's column name, this column's tag defined join table like this `gorm:"many2many:user_languages;"`
594 ```
595
596 There is also a mode used to handle many to many relations easily
597
598 ```go
599 // Query
600 db.Model(&user).Association("Languages").Find(&languages)
601 // same as `db.Model(&user).Related(&languages, "Languages")`
602
603 db.Where("name = ?", "ZH").First(&languageZH)
604 db.Where("name = ?", "EN").First(&languageEN)
605
606 // Append
607 db.Model(&user).Association("Languages").Append([]Language{languageZH, languageEN})
608 db.Model(&user).Association("Languages").Append([]Language{{Name: "DE"}})
609 db.Model(&user).Association("Languages").Append(Language{Name: "DE"})
610
611 // Delete
612 db.Model(&user).Association("Languages").Delete([]Language{languageZH, languageEN})
613 db.Model(&user).Association("Languages").Delete(languageZH, languageEN)
614
615 // Replace
616 db.Model(&user).Association("Languages").Replace([]Language{languageZH, languageEN})
617 db.Model(&user).Association("Languages").Replace(Language{Name: "DE"}, languageEN)
618
619 // Count
620 db.Model(&user).Association("Languages").Count()
621 // Return the count of languages the user has
622
623 // Clear
624 db.Model(&user).Association("Languages").Clear()
625 // Remove all relations between the user and languages
626 ```
627
628 ### Polymorphism
629
630 Supports polymorphic has-many and has-one associations.
631
632 ```go
633 type Cat struct {
634 Id int
635 Name string
636 Toy Toy `gorm:"polymorphic:Owner;"`
637 }
638
639 type Dog struct {
640 Id int
641 Name string
642 Toy Toy `gorm:"polymorphic:Owner;"`
643 }
644
645 type Toy struct {
646 Id int
647 Name string
648 OwnerId int
649 OwnerType string
650 }
651 ```
652 Note: polymorphic belongs-to and many-to-many are explicitly NOT supported, and will throw errors.
653
654 ## Advanced Usage
655
656 ## FirstOrInit
657
658 Get the first matched record, or initialize a record with search conditions.
659
660 ```go
661 // Unfound
662 db.FirstOrInit(&user, User{Name: "non_existing"})
663 //// user -> User{Name: "non_existing"}
664
665 // Found
666 db.Where(User{Name: "Jinzhu"}).FirstOrInit(&user)
667 //// user -> User{Id: 111, Name: "Jinzhu", Age: 20}
668 db.FirstOrInit(&user, map[string]interface{}{"name": "jinzhu"})
669 //// user -> User{Id: 111, Name: "Jinzhu", Age: 20}
670 ```
671
672 ### Attrs
673
674 Ignore some values when searching, but use them to initialize the struct if record is not found.
675
676 ```go
677 // Unfound
678 db.Where(User{Name: "non_existing"}).Attrs(User{Age: 20}).FirstOrInit(&user)
679 //// SELECT * FROM USERS WHERE name = 'non_existing';
680 //// user -> User{Name: "non_existing", Age: 20}
681
682 db.Where(User{Name: "noexisting_user"}).Attrs("age", 20).FirstOrInit(&user)
683 //// SELECT * FROM USERS WHERE name = 'non_existing';
684 //// user -> User{Name: "non_existing", Age: 20}
685
686 // Found
687 db.Where(User{Name: "Jinzhu"}).Attrs(User{Age: 30}).FirstOrInit(&user)
688 //// SELECT * FROM USERS WHERE name = jinzhu';
689 //// user -> User{Id: 111, Name: "Jinzhu", Age: 20}
690 ```
691
692 ### Assign
693
694 Ignore some values when searching, but assign it to the result regardless it is found or not.
695
696 ```go
697 // Unfound
698 db.Where(User{Name: "non_existing"}).Assign(User{Age: 20}).FirstOrInit(&user)
699 //// user -> User{Name: "non_existing", Age: 20}
700
701 // Found
702 db.Where(User{Name: "Jinzhu"}).Assign(User{Age: 30}).FirstOrInit(&user)
703 //// SELECT * FROM USERS WHERE name = jinzhu';
704 //// user -> User{Id: 111, Name: "Jinzhu", Age: 30}
705 ```
706
707 ## FirstOrCreate
708
709 Get the first matched record, or create with search conditions.
710
711 ```go
712 // Unfound
713 db.FirstOrCreate(&user, User{Name: "non_existing"})
714 //// INSERT INTO "users" (name) VALUES ("non_existing");
715 //// user -> User{Id: 112, Name: "non_existing"}
716
717 // Found
718 db.Where(User{Name: "Jinzhu"}).FirstOrCreate(&user)
719 //// user -> User{Id: 111, Name: "Jinzhu"}
720 ```
721
722 ### Attrs
723
724 Ignore some values when searching, but use them to create the struct if record is not found. like `FirstOrInit`
725
726 ```go
727 // Unfound
728 db.Where(User{Name: "non_existing"}).Attrs(User{Age: 20}).FirstOrCreate(&user)
729 //// SELECT * FROM users WHERE name = 'non_existing';
730 //// INSERT INTO "users" (name, age) VALUES ("non_existing", 20);
731 //// user -> User{Id: 112, Name: "non_existing", Age: 20}
732
733 // Found
734 db.Where(User{Name: "jinzhu"}).Attrs(User{Age: 30}).FirstOrCreate(&user)
735 //// SELECT * FROM users WHERE name = 'jinzhu';
736 //// user -> User{Id: 111, Name: "jinzhu", Age: 20}
737 ```
738
739 ### Assign
740
741 Ignore some values when searching, but assign it to the record regardless it is found or not, then save back to database. like `FirstOrInit`
742
743 ```go
744 // Unfound
745 db.Where(User{Name: "non_existing"}).Assign(User{Age: 20}).FirstOrCreate(&user)
746 //// SELECT * FROM users WHERE name = 'non_existing';
747 //// INSERT INTO "users" (name, age) VALUES ("non_existing", 20);
748 //// user -> User{Id: 112, Name: "non_existing", Age: 20}
749
750 // Found
751 db.Where(User{Name: "jinzhu"}).Assign(User{Age: 30}).FirstOrCreate(&user)
752 //// SELECT * FROM users WHERE name = 'jinzhu';
753 //// UPDATE users SET age=30 WHERE id = 111;
754 //// user -> User{Id: 111, Name: "jinzhu", Age: 30}
755 ```
756
757 ## Select
758
759 ```go
760 db.Select("name, age").Find(&users)
761 //// SELECT name, age FROM users;
762
763 db.Select([]string{"name", "age"}).Find(&users)
764 //// SELECT name, age FROM users;
765
766 db.Table("users").Select("COALESCE(age,?)", 42).Rows()
767 //// SELECT COALESCE(age,'42') FROM users;
768 ```
769
770 ## Order
771
772 ```go
773 db.Order("age desc, name").Find(&users)
774 //// SELECT * FROM users ORDER BY age desc, name;
775
776 // Multiple orders
777 db.Order("age desc").Order("name").Find(&users)
778 //// SELECT * FROM users ORDER BY age desc, name;
779
780 // ReOrder
781 db.Order("age desc").Find(&users1).Order("age", true).Find(&users2)
782 //// SELECT * FROM users ORDER BY age desc; (users1)
783 //// SELECT * FROM users ORDER BY age; (users2)
784 ```
785
786 ## Limit
787
788 ```go
789 db.Limit(3).Find(&users)
790 //// SELECT * FROM users LIMIT 3;
791
792 // Cancel limit condition with -1
793 db.Limit(10).Find(&users1).Limit(-1).Find(&users2)
794 //// SELECT * FROM users LIMIT 10; (users1)
795 //// SELECT * FROM users; (users2)
796 ```
797
798 ## Offset
799
800 ```go
801 db.Offset(3).Find(&users)
802 //// SELECT * FROM users OFFSET 3;
803
804 // Cancel offset condition with -1
805 db.Offset(10).Find(&users1).Offset(-1).Find(&users2)
806 //// SELECT * FROM users OFFSET 10; (users1)
807 //// SELECT * FROM users; (users2)
808 ```
809
810 ## Count
811
812 ```go
813 db.Where("name = ?", "jinzhu").Or("name = ?", "jinzhu 2").Find(&users).Count(&count)
814 //// SELECT * from USERS WHERE name = 'jinzhu' OR name = 'jinzhu 2'; (users)
815 //// SELECT count(*) FROM users WHERE name = 'jinzhu' OR name = 'jinzhu 2'; (count)
816
817 db.Model(User{}).Where("name = ?", "jinzhu").Count(&count)
818 //// SELECT count(*) FROM users WHERE name = 'jinzhu'; (count)
819
820 db.Table("deleted_users").Count(&count)
821 //// SELECT count(*) FROM deleted_users;
822 ```
823
824 ## Pluck
825
826 Get selected attributes as map
827
828 ```go
829 var ages []int64
830 db.Find(&users).Pluck("age", &ages)
831
832 var names []string
833 db.Model(&User{}).Pluck("name", &names)
834
835 db.Table("deleted_users").Pluck("name", &names)
836
837 // Requesting more than one column? Do it like this:
838 db.Select("name, age").Find(&users)
839 ```
840
841 ## Raw SQL
842
843 ```go
844 db.Exec("DROP TABLE users;")
845 db.Exec("UPDATE orders SET shipped_at=? WHERE id IN (?)", time.Now, []int64{11,22,33})
846 ```
847
848 ## Row & Rows
849
850 It is even possible to get query result as `*sql.Row` or `*sql.Rows`
851
852 ```go
853 row := db.Table("users").Where("name = ?", "jinzhu").Select("name, age").Row() // (*sql.Row)
854 row.Scan(&name, &age)
855
856 rows, err := db.Model(User{}).Where("name = ?", "jinzhu").Select("name, age, email").Rows() // (*sql.Rows, error)
857 defer rows.Close()
858 for rows.Next() {
859 ...
860 rows.Scan(&name, &age, &email)
861 ...
862 }
863
864 // Raw SQL
865 rows, err := db.Raw("select name, age, email from users where name = ?", "jinzhu").Rows() // (*sql.Rows, error)
866 defer rows.Close()
867 for rows.Next() {
868 ...
869 rows.Scan(&name, &age, &email)
870 ...
871 }
872 ```
873
874 ## Scan
875
876 Scan results into another struct.
877
878 ```go
879 type Result struct {
880 Name string
881 Age int
882 }
883
884 var result Result
885 db.Table("users").Select("name, age").Where("name = ?", 3).Scan(&result)
886
887 // Raw SQL
888 db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result)
889 ```
890
891 ## Group & Having
892
893 ```go
894 rows, err := db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Rows()
895 for rows.Next() {
896 ...
897 }
898
899 rows, err := db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Having("sum(amount) > ?", 100).Rows()
900 for rows.Next() {
901 ...
902 }
903
904 type Result struct {
905 Date time.Time
906 Total int64
907 }
908 db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Having("sum(amount) > ?", 100).Scan(&results)
909 ```
910
911 ## Joins
912
913 ```go
914 rows, err := db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Rows()
915 for rows.Next() {
916 ...
917 }
918
919 db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Scan(&results)
920
921 // find a user by email address
922 db.Joins("inner join emails on emails.user_id = users.id").Where("emails.email = ?", "x@example.org").Find(&user)
923
924 // find all email addresses for a user
925 db.Joins("left join users on users.id = emails.user_id").Where("users.name = ?", "jinzhu").Find(&emails)
926 ```
927
928 ## Transactions
929
930 To perform a set of operations within a transaction, the general flow is as below.
931 The database handle returned from ``` db.Begin() ``` should be used for all operations within the transaction.
932 (Note that all individual save and delete operations are run in a transaction by default.)
933
934 ```go
935 // begin
936 tx := db.Begin()
937
938 // do some database operations (use 'tx' from this point, not 'db')
939 tx.Create(...)
940 ...
941
942 // rollback in case of error
943 tx.Rollback()
944
945 // Or commit if all is ok
946 tx.Commit()
947 ```
948
949 ### A Specific Example
950 ```
951 func CreateAnimals(db *gorm.DB) err {
952 tx := db.Begin()
953 // Note the use of tx as the database handle once you are within a transaction
954
955 if err := tx.Create(&Animal{Name: "Giraffe"}).Error; err != nil {
956 tx.Rollback()
957 return err
958 }
959
960 if err := tx.Create(&Animal{Name: "Lion"}).Error; err != nil {
961 tx.Rollback()
962 return err
963 }
964
965 tx.Commit()
966 return nil
967 }
968 ```
969
970 ## Scopes
971
972 ```go
973 func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
974 return db.Where("amount > ?", 1000)
975 }
976
977 func PaidWithCreditCard(db *gorm.DB) *gorm.DB {
978 return db.Where("pay_mode_sign = ?", "C")
979 }
980
981 func PaidWithCod(db *gorm.DB) *gorm.DB {
982 return db.Where("pay_mode_sign = ?", "C")
983 }
984
985 func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
986 return func (db *gorm.DB) *gorm.DB {
987 return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
988 }
989 }
990
991 db.Scopes(AmountGreaterThan1000, PaidWithCreditCard).Find(&orders)
992 // Find all credit card orders and amount greater than 1000
993
994 db.Scopes(AmountGreaterThan1000, PaidWithCod).Find(&orders)
995 // Find all COD orders and amount greater than 1000
996
997 db.Scopes(OrderStatus([]string{"paid", "shipped"})).Find(&orders)
998 // Find all paid, shipped orders
999 ```
1000
1001 ## Callbacks
1002
1003 Callbacks are methods defined on the pointer of struct.
1004 If any callback returns an error, gorm will stop future operations and rollback all changes.
1005
1006 Here is the list of all available callbacks:
1007 (listed in the same order in which they will get called during the respective operations)
1008
1009 ### Creating An Object
1010
1011 ```go
1012 BeforeSave
1013 BeforeCreate
1014 // save before associations
1015 // save self
1016 // save after associations
1017 AfterCreate
1018 AfterSave
1019 ```
1020 ### Updating An Object
1021
1022 ```go
1023 BeforeSave
1024 BeforeUpdate
1025 // save before associations
1026 // save self
1027 // save after associations
1028 AfterUpdate
1029 AfterSave
1030 ```
1031
1032 ### Destroying An Object
1033
1034 ```go
1035 BeforeDelete
1036 // delete self
1037 AfterDelete
1038 ```
1039
1040 ### After Find
1041
1042 ```go
1043 // load data from database
1044 AfterFind
1045 ```
1046
1047 ### Example
1048
1049 ```go
1050 func (u *User) BeforeUpdate() (err error) {
1051 if u.readonly() {
1052 err = errors.New("read only user")
1053 }
1054 return
1055 }
1056
1057 // Rollback the insertion if user's id greater than 1000
1058 func (u *User) AfterCreate() (err error) {
1059 if (u.Id > 1000) {
1060 err = errors.New("user id is already greater than 1000")
1061 }
1062 return
1063 }
1064 ```
1065
1066 As you know, save/delete operations in gorm are running in a transaction,
1067 This is means if changes made in the transaction is not visiable unless it is commited,
1068 So if you want to use those changes in your callbacks, you need to run SQL in same transaction.
1069 Fortunately, gorm support pass transaction to callbacks as you needed, you could do it like this:
1070
1071 ```go
1072 func (u *User) AfterCreate(tx *gorm.DB) (err error) {
1073 tx.Model(u).Update("role", "admin")
1074 return
1075 }
1076 ```
1077
1078 ## Specifying The Table Name
1079
1080 ```go
1081 // Create `deleted_users` table with struct User's definition
1082 db.Table("deleted_users").CreateTable(&User{})
1083
1084 var deleted_users []User
1085 db.Table("deleted_users").Find(&deleted_users)
1086 //// SELECT * FROM deleted_users;
1087
1088 db.Table("deleted_users").Where("name = ?", "jinzhu").Delete()
1089 //// DELETE FROM deleted_users WHERE name = 'jinzhu';
1090 ```
1091
1092 ### Specifying The Table Name For A Struct Permanently with TableName
1093
1094 ```go
1095 type Cart struct {
1096 }
1097
1098 func (c Cart) TableName() string {
1099 return "shopping_cart"
1100 }
1101
1102 func (u User) TableName() string {
1103 if u.Role == "admin" {
1104 return "admin_users"
1105 } else {
1106 return "users"
1107 }
1108 }
1109 ```
1110
1111 ## Error Handling
1112
1113 ```go
1114 query := db.Where("name = ?", "jinzhu").First(&user)
1115 query := db.First(&user).Limit(10).Find(&users)
1116 // query.Error will return the last happened error
1117
1118 // So you could do error handing in your application like this:
1119 if err := db.Where("name = ?", "jinzhu").First(&user).Error; err != nil {
1120 // error handling...
1121 }
1122
1123 // RecordNotFound
1124 // If no record found when you query data, gorm will return RecordNotFound error, you could check it like this:
1125 db.Where("name = ?", "hello world").First(&User{}).Error == gorm.RecordNotFound
1126 // Or use the shortcut method
1127 db.Where("name = ?", "hello world").First(&user).RecordNotFound()
1128
1129 if db.Model(&user).Related(&credit_card).RecordNotFound() {
1130 // no credit card found error handling
1131 }
1132 ```
1133
1134 ## Logger
1135
1136 Gorm has built-in logger support
1137
1138 ```go
1139 // Enable Logger
1140 db.LogMode(true)
1141
1142 // Diable Logger
1143 db.LogMode(false)
1144
1145 // Debug a single operation
1146 db.Debug().Where("name = ?", "jinzhu").First(&User{})
1147 ```
1148
1149 ![logger](https://raw.github.com/jinzhu/gorm/master/images/logger.png)
1150
1151 ### Customize Logger
1152
1153 ```go
1154 // Refer gorm's default logger for how to: https://github.com/jinzhu/gorm/blob/master/logger.go#files
1155 db.SetLogger(gorm.Logger{revel.TRACE})
1156 db.SetLogger(log.New(os.Stdout, "\r\n", 0))
1157 ```
1158
1159 ## Existing Schema
1160
1161 If you have an existing database schema, and the primary key field is different from `id`, you can add a tag to the field structure to specify that this field is a primary key.
1162
1163 ```go
1164 type Animal struct {
1165 AnimalId int64 `gorm:"primary_key"`
1166 Birthday time.Time `sql:"DEFAULT:current_timestamp"`
1167 Name string `sql:"default:'galeone'"`
1168 Age int64
1169 }
1170 ```
1171
1172 If your column names differ from the struct fields, you can specify them like this:
1173
1174 ```go
1175 type Animal struct {
1176 AnimalId int64 `gorm:"column:beast_id;primary_key"`
1177 Birthday time.Time `gorm:"column:day_of_the_beast"`
1178 Age int64 `gorm:"column:age_of_the_beast"`
1179 }
1180 ```
1181
1182 ## Composite Primary Key
1183
1184 ```go
1185 type Product struct {
1186 ID string `gorm:"primary_key"`
1187 LanguageCode string `gorm:"primary_key"`
1188 }
1189 ```
1190
1191 ## Database Indexes & Foreign Key
1192
1193 ```go
1194 // Add foreign key
1195 // 1st param : foreignkey field
1196 // 2nd param : destination table(id)
1197 // 3rd param : ONDELETE
1198 // 4th param : ONUPDATE
1199 db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
1200
1201 // Add index
1202 db.Model(&User{}).AddIndex("idx_user_name", "name")
1203
1204 // Multiple column index
1205 db.Model(&User{}).AddIndex("idx_user_name_age", "name", "age")
1206
1207 // Add unique index
1208 db.Model(&User{}).AddUniqueIndex("idx_user_name", "name")
1209
1210 // Multiple column unique index
1211 db.Model(&User{}).AddUniqueIndex("idx_user_name_age", "name", "age")
1212
1213 // Remove index
1214 db.Model(&User{}).RemoveIndex("idx_user_name")
1215 ```
1216
1217 ## Default values
1218
1219 ```go
1220 type Animal struct {
1221 ID int64
1222 Name string `sql:"default:'galeone'"`
1223 Age int64
1224 }
1225 ```
1226
1227 If you have defined a default value in the `sql` tag, the generated create SQl will ignore these fields if it is blank.
1228
1229 Eg.
1230
1231 ```go
1232 db.Create(&Animal{Age: 99, Name: ""})
1233 ```
1234
1235 The generated SQL will be:
1236
1237 ```sql
1238 INSERT INTO animals("age") values('99');
1239 ```
1240
1241 The same thing occurs in update statements.
1242
1243 ## More examples with query chain
1244
1245 ```go
1246 db.First(&first_article).Count(&total_count).Limit(10).Find(&first_page_articles).Offset(10).Find(&second_page_articles)
1247 //// SELECT * FROM articles LIMIT 1; (first_article)
1248 //// SELECT count(*) FROM articles; (total_count)
1249 //// SELECT * FROM articles LIMIT 10; (first_page_articles)
1250 //// SELECT * FROM articles LIMIT 10 OFFSET 10; (second_page_articles)
1251
1252
1253 db.Where("created_at > ?", "2013-10-10").Find(&cancelled_orders, "state = ?", "cancelled").Find(&shipped_orders, "state = ?", "shipped")
1254 //// SELECT * FROM orders WHERE created_at > '2013/10/10' AND state = 'cancelled'; (cancelled_orders)
1255 //// SELECT * FROM orders WHERE created_at > '2013/10/10' AND state = 'shipped'; (shipped_orders)
1256
1257
1258 // Use variables to keep query chain
1259 todays_orders := db.Where("created_at > ?", "2013-10-29")
1260 cancelled_orders := todays_orders.Where("state = ?", "cancelled")
1261 shipped_orders := todays_orders.Where("state = ?", "shipped")
1262
1263
1264 // Search with shared conditions for different tables
1265 db.Where("product_name = ?", "fancy_product").Find(&orders).Find(&shopping_carts)
1266 //// SELECT * FROM orders WHERE product_name = 'fancy_product'; (orders)
1267 //// SELECT * FROM carts WHERE product_name = 'fancy_product'; (shopping_carts)
1268
1269
1270 // Search with shared conditions from different tables with specified table
1271 db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").Find(&users2)
1272 //// SELECT * FROM users WHERE mail_type = 'TEXT'; (users1)
1273 //// SELECT * FROM deleted_users WHERE mail_type = 'TEXT'; (users2)
1274
1275
1276 // FirstOrCreate example
1277 db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111"}).FirstOrCreate(&user)
1278 //// SELECT * FROM users WHERE email = 'x@example.org';
1279 //// INSERT INTO "users" (email,registered_ip) VALUES ("x@example.org", "111.111.111.111") // if record not found
1280 ```
1281
1282 ## TODO
1283 * Github Pages
1284
1285 # Author
1286
1287 **jinzhu**
1288
1289 * <http://github.com/jinzhu>
1290 * <wosmvp@gmail.com>
1291 * <http://twitter.com/zhangjinzhu>
33 [You can help to deliver a better GORM, check out things you can do](http://gorm.io/contribute.html)
129234
129335 ## License
129436
1295 Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License).
37 © Jinzhu, 2013~time.Now
38
39 Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License)
33 "errors"
44 "fmt"
55 "reflect"
6 "strings"
76 )
87
8 // Association Mode contains some helper methods to handle relationship things easily.
99 type Association struct {
10 Scope *Scope
11 Column string
1210 Error error
13 Field *Field
11 scope *Scope
12 column string
13 field *Field
14 }
15
16 // Find find out all related associations
17 func (association *Association) Find(value interface{}) *Association {
18 association.scope.related(value, association.column)
19 return association.setErr(association.scope.db.Error)
20 }
21
22 // Append append new associations for many2many, has_many, replace current association for has_one, belongs_to
23 func (association *Association) Append(values ...interface{}) *Association {
24 if association.Error != nil {
25 return association
26 }
27
28 if relationship := association.field.Relationship; relationship.Kind == "has_one" {
29 return association.Replace(values...)
30 }
31 return association.saveAssociations(values...)
32 }
33
34 // Replace replace current associations with new one
35 func (association *Association) Replace(values ...interface{}) *Association {
36 if association.Error != nil {
37 return association
38 }
39
40 var (
41 relationship = association.field.Relationship
42 scope = association.scope
43 field = association.field.Field
44 newDB = scope.NewDB()
45 )
46
47 // Append new values
48 association.field.Set(reflect.Zero(association.field.Field.Type()))
49 association.saveAssociations(values...)
50
51 // Belongs To
52 if relationship.Kind == "belongs_to" {
53 // Set foreign key to be null when clearing value (length equals 0)
54 if len(values) == 0 {
55 // Set foreign key to be nil
56 var foreignKeyMap = map[string]interface{}{}
57 for _, foreignKey := range relationship.ForeignDBNames {
58 foreignKeyMap[foreignKey] = nil
59 }
60 association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
61 }
62 } else {
63 // Polymorphic Relations
64 if relationship.PolymorphicDBName != "" {
65 newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue)
66 }
67
68 // Delete Relations except new created
69 if len(values) > 0 {
70 var associationForeignFieldNames, associationForeignDBNames []string
71 if relationship.Kind == "many_to_many" {
72 // if many to many relations, get association fields name from association foreign keys
73 associationScope := scope.New(reflect.New(field.Type()).Interface())
74 for idx, dbName := range relationship.AssociationForeignFieldNames {
75 if field, ok := associationScope.FieldByName(dbName); ok {
76 associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
77 associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx])
78 }
79 }
80 } else {
81 // If has one/many relations, use primary keys
82 for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
83 associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
84 associationForeignDBNames = append(associationForeignDBNames, field.DBName)
85 }
86 }
87
88 newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface())
89
90 if len(newPrimaryKeys) > 0 {
91 sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys))
92 newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...)
93 }
94 }
95
96 if relationship.Kind == "many_to_many" {
97 // if many to many relations, delete related relations from join table
98 var sourceForeignFieldNames []string
99
100 for _, dbName := range relationship.ForeignFieldNames {
101 if field, ok := scope.FieldByName(dbName); ok {
102 sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name)
103 }
104 }
105
106 if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
107 newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
108
109 association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
110 }
111 } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
112 // has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
113 var foreignKeyMap = map[string]interface{}{}
114 for idx, foreignKey := range relationship.ForeignDBNames {
115 foreignKeyMap[foreignKey] = nil
116 if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok {
117 newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
118 }
119 }
120
121 fieldValue := reflect.New(association.field.Field.Type()).Interface()
122 association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
123 }
124 }
125 return association
126 }
127
128 // Delete remove relationship between source & passed arguments, but won't delete those arguments
129 func (association *Association) Delete(values ...interface{}) *Association {
130 if association.Error != nil {
131 return association
132 }
133
134 var (
135 relationship = association.field.Relationship
136 scope = association.scope
137 field = association.field.Field
138 newDB = scope.NewDB()
139 )
140
141 if len(values) == 0 {
142 return association
143 }
144
145 var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
146 for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
147 deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
148 deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
149 }
150
151 deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...)
152
153 if relationship.Kind == "many_to_many" {
154 // source value's foreign keys
155 for idx, foreignKey := range relationship.ForeignDBNames {
156 if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
157 newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
158 }
159 }
160
161 // get association's foreign fields name
162 var associationScope = scope.New(reflect.New(field.Type()).Interface())
163 var associationForeignFieldNames []string
164 for _, associationDBName := range relationship.AssociationForeignFieldNames {
165 if field, ok := associationScope.FieldByName(associationDBName); ok {
166 associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
167 }
168 }
169
170 // association value's foreign keys
171 deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...)
172 sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
173 newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
174
175 association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
176 } else {
177 var foreignKeyMap = map[string]interface{}{}
178 for _, foreignKey := range relationship.ForeignDBNames {
179 foreignKeyMap[foreignKey] = nil
180 }
181
182 if relationship.Kind == "belongs_to" {
183 // find with deleting relation's foreign keys
184 primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...)
185 newDB = newDB.Where(
186 fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
187 toQueryValues(primaryKeys)...,
188 )
189
190 // set foreign key to be null if there are some records affected
191 modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
192 if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
193 if results.RowsAffected > 0 {
194 scope.updatedAttrsWithValues(foreignKeyMap)
195 }
196 } else {
197 association.setErr(results.Error)
198 }
199 } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
200 // find all relations
201 primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
202 newDB = newDB.Where(
203 fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
204 toQueryValues(primaryKeys)...,
205 )
206
207 // only include those deleting relations
208 newDB = newDB.Where(
209 fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)),
210 toQueryValues(deletingPrimaryKeys)...,
211 )
212
213 // set matched relation's foreign key to be null
214 fieldValue := reflect.New(association.field.Field.Type()).Interface()
215 association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
216 }
217 }
218
219 // Remove deleted records from source's field
220 if association.Error == nil {
221 if field.Kind() == reflect.Slice {
222 leftValues := reflect.Zero(field.Type())
223
224 for i := 0; i < field.Len(); i++ {
225 reflectValue := field.Index(i)
226 primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
227 var isDeleted = false
228 for _, pk := range deletingPrimaryKeys {
229 if equalAsString(primaryKey, pk) {
230 isDeleted = true
231 break
232 }
233 }
234 if !isDeleted {
235 leftValues = reflect.Append(leftValues, reflectValue)
236 }
237 }
238
239 association.field.Set(leftValues)
240 } else if field.Kind() == reflect.Struct {
241 primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0]
242 for _, pk := range deletingPrimaryKeys {
243 if equalAsString(primaryKey, pk) {
244 association.field.Set(reflect.Zero(field.Type()))
245 break
246 }
247 }
248 }
249 }
250
251 return association
252 }
253
254 // Clear remove relationship between source & current associations, won't delete those associations
255 func (association *Association) Clear() *Association {
256 return association.Replace()
257 }
258
259 // Count return the count of current associations
260 func (association *Association) Count() int {
261 var (
262 count = 0
263 relationship = association.field.Relationship
264 scope = association.scope
265 fieldValue = association.field.Field.Interface()
266 query = scope.DB()
267 )
268
269 if relationship.Kind == "many_to_many" {
270 query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
271 } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
272 primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
273 query = query.Where(
274 fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
275 toQueryValues(primaryKeys)...,
276 )
277 } else if relationship.Kind == "belongs_to" {
278 primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
279 query = query.Where(
280 fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
281 toQueryValues(primaryKeys)...,
282 )
283 }
284
285 if relationship.PolymorphicType != "" {
286 query = query.Where(
287 fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)),
288 relationship.PolymorphicValue,
289 )
290 }
291
292 if err := query.Model(fieldValue).Count(&count).Error; err != nil {
293 association.Error = err
294 }
295 return count
296 }
297
298 // saveAssociations save passed values as associations
299 func (association *Association) saveAssociations(values ...interface{}) *Association {
300 var (
301 scope = association.scope
302 field = association.field
303 relationship = field.Relationship
304 )
305
306 saveAssociation := func(reflectValue reflect.Value) {
307 // value has to been pointer
308 if reflectValue.Kind() != reflect.Ptr {
309 reflectPtr := reflect.New(reflectValue.Type())
310 reflectPtr.Elem().Set(reflectValue)
311 reflectValue = reflectPtr
312 }
313
314 // value has to been saved for many2many
315 if relationship.Kind == "many_to_many" {
316 if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
317 association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
318 }
319 }
320
321 // Assign Fields
322 var fieldType = field.Field.Type()
323 var setFieldBackToValue, setSliceFieldBackToValue bool
324 if reflectValue.Type().AssignableTo(fieldType) {
325 field.Set(reflectValue)
326 } else if reflectValue.Type().Elem().AssignableTo(fieldType) {
327 // if field's type is struct, then need to set value back to argument after save
328 setFieldBackToValue = true
329 field.Set(reflectValue.Elem())
330 } else if fieldType.Kind() == reflect.Slice {
331 if reflectValue.Type().AssignableTo(fieldType.Elem()) {
332 field.Set(reflect.Append(field.Field, reflectValue))
333 } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
334 // if field's type is slice of struct, then need to set value back to argument after save
335 setSliceFieldBackToValue = true
336 field.Set(reflect.Append(field.Field, reflectValue.Elem()))
337 }
338 }
339
340 if relationship.Kind == "many_to_many" {
341 association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
342 } else {
343 association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
344
345 if setFieldBackToValue {
346 reflectValue.Elem().Set(field.Field)
347 } else if setSliceFieldBackToValue {
348 reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
349 }
350 }
351 }
352
353 for _, value := range values {
354 reflectValue := reflect.ValueOf(value)
355 indirectReflectValue := reflect.Indirect(reflectValue)
356 if indirectReflectValue.Kind() == reflect.Struct {
357 saveAssociation(reflectValue)
358 } else if indirectReflectValue.Kind() == reflect.Slice {
359 for i := 0; i < indirectReflectValue.Len(); i++ {
360 saveAssociation(indirectReflectValue.Index(i))
361 }
362 } else {
363 association.setErr(errors.New("invalid value type"))
364 }
365 }
366 return association
14367 }
15368
16369 func (association *Association) setErr(err error) *Association {
19372 }
20373 return association
21374 }
22
23 func (association *Association) Find(value interface{}) *Association {
24 association.Scope.related(value, association.Column)
25 return association.setErr(association.Scope.db.Error)
26 }
27
28 func (association *Association) Append(values ...interface{}) *Association {
29 scope := association.Scope
30 field := association.Field
31
32 createJoinTable := func(reflectValue reflect.Value) {
33 var value = reflectValue.Interface()
34 if reflectValue.Kind() != reflect.Ptr {
35 reflectPtr := reflect.New(reflectValue.Type())
36 reflectPtr.Elem().Set(reflectValue)
37 value = reflectPtr.Interface()
38 }
39
40 if scope.New(value).PrimaryKeyZero() {
41 scope.NewDB().Save(value)
42 }
43
44 relationship := association.Field.Relationship
45 association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, value))
46
47 result := reflect.ValueOf(value)
48 fieldElemType := field.Field.Type().Elem()
49 if result.Type().AssignableTo(fieldElemType) {
50 field.Set(reflect.Append(field.Field, result))
51 } else if result.Type().Elem().AssignableTo(fieldElemType) {
52 field.Set(reflect.Append(field.Field, result.Elem()))
53 }
54 }
55
56 for _, value := range values {
57 reflectValue := reflect.Indirect(reflect.ValueOf(value))
58
59 if reflectValue.Kind() == reflect.Struct {
60 createJoinTable(reflectValue)
61 } else if reflectValue.Kind() == reflect.Slice {
62 for i := 0; i < reflectValue.Len(); i++ {
63 createJoinTable(reflectValue.Index(i))
64 }
65 } else {
66 association.setErr(errors.New("invalid association type"))
67 }
68 }
69 return association
70 }
71
72 func (association *Association) Delete(values ...interface{}) *Association {
73 scope := association.Scope
74 relationship := association.Field.Relationship
75
76 // many to many
77 if relationship.Kind == "many_to_many" {
78 query := scope.NewDB()
79 for idx, foreignKey := range relationship.ForeignDBNames {
80 if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
81 query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
82 }
83 }
84
85 primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...)
86 sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys))
87 query = query.Where(sql, toQueryValues(primaryKeys)...)
88
89 if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
90 leftValues := reflect.Zero(association.Field.Field.Type())
91 for i := 0; i < association.Field.Field.Len(); i++ {
92 reflectValue := association.Field.Field.Index(i)
93 primaryKey := association.getPrimaryKeys(relationship.ForeignFieldNames, reflectValue.Interface())[0]
94 var included = false
95 for _, pk := range primaryKeys {
96 if equalAsString(primaryKey, pk) {
97 included = true
98 }
99 }
100 if !included {
101 leftValues = reflect.Append(leftValues, reflectValue)
102 }
103 }
104 association.Field.Set(leftValues)
105 }
106 } else {
107 association.setErr(errors.New("delete only support many to many"))
108 }
109 return association
110 }
111
112 func (association *Association) Replace(values ...interface{}) *Association {
113 relationship := association.Field.Relationship
114 scope := association.Scope
115 if relationship.Kind == "many_to_many" {
116 field := association.Field.Field
117
118 oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface())
119 association.Field.Set(reflect.Zero(association.Field.Field.Type()))
120 association.Append(values...)
121 newPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface())
122
123 var addedPrimaryKeys = [][]interface{}{}
124 for _, newKey := range newPrimaryKeys {
125 hasEqual := false
126 for _, oldKey := range oldPrimaryKeys {
127 if equalAsString(newKey, oldKey) {
128 hasEqual = true
129 break
130 }
131 }
132 if !hasEqual {
133 addedPrimaryKeys = append(addedPrimaryKeys, newKey)
134 }
135 }
136
137 for _, primaryKey := range association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) {
138 addedPrimaryKeys = append(addedPrimaryKeys, primaryKey)
139 }
140
141 query := scope.NewDB()
142 for idx, foreignKey := range relationship.ForeignDBNames {
143 if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
144 query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
145 }
146 }
147
148 if len(addedPrimaryKeys) > 0 {
149 sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys))
150 query = query.Where(sql, toQueryValues(addedPrimaryKeys)...)
151 }
152 association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship))
153 } else {
154 association.setErr(errors.New("replace only support many to many"))
155 }
156 return association
157 }
158
159 func (association *Association) Clear() *Association {
160 relationship := association.Field.Relationship
161 scope := association.Scope
162 if relationship.Kind == "many_to_many" {
163 query := scope.NewDB()
164 for idx, foreignKey := range relationship.ForeignDBNames {
165 if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
166 query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
167 }
168 }
169
170 if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
171 association.Field.Set(reflect.Zero(association.Field.Field.Type()))
172 } else {
173 association.setErr(err)
174 }
175 } else {
176 association.setErr(errors.New("clear only support many to many"))
177 }
178 return association
179 }
180
181 func (association *Association) Count() int {
182 count := -1
183 relationship := association.Field.Relationship
184 scope := association.Scope
185 newScope := scope.New(association.Field.Field.Interface())
186
187 if relationship.Kind == "many_to_many" {
188 relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count)
189 } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
190 query := scope.DB()
191 for idx, foreignKey := range relationship.ForeignDBNames {
192 if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
193 query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)),
194 field.Field.Interface())
195 }
196 }
197
198 if relationship.PolymorphicType != "" {
199 query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName())
200 }
201 query.Table(newScope.TableName()).Count(&count)
202 } else if relationship.Kind == "belongs_to" {
203 query := scope.DB()
204 for idx, foreignKey := range relationship.ForeignDBNames {
205 if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
206 query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)),
207 field.Field.Interface())
208 }
209 }
210 query.Table(newScope.TableName()).Count(&count)
211 }
212
213 return count
214 }
215
216 func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) [][]interface{} {
217 results := [][]interface{}{}
218 scope := association.Scope
219
220 for _, value := range values {
221 reflectValue := reflect.Indirect(reflect.ValueOf(value))
222 if reflectValue.Kind() == reflect.Slice {
223 for i := 0; i < reflectValue.Len(); i++ {
224 primaryKeys := []interface{}{}
225 newScope := scope.New(reflectValue.Index(i).Interface())
226 for _, column := range columns {
227 if field, ok := newScope.FieldByName(column); ok {
228 primaryKeys = append(primaryKeys, field.Field.Interface())
229 } else {
230 primaryKeys = append(primaryKeys, "")
231 }
232 }
233 results = append(results, primaryKeys)
234 }
235 } else if reflectValue.Kind() == reflect.Struct {
236 newScope := scope.New(value)
237 var primaryKeys []interface{}
238 for _, column := range columns {
239 if field, ok := newScope.FieldByName(column); ok {
240 primaryKeys = append(primaryKeys, field.Field.Interface())
241 } else {
242 primaryKeys = append(primaryKeys, "")
243 }
244 }
245
246 results = append(results, primaryKeys)
247 }
248 }
249 return results
250 }
251
252 func toQueryMarks(primaryValues [][]interface{}) string {
253 var results []string
254
255 for _, primaryValue := range primaryValues {
256 var marks []string
257 for _, _ = range primaryValue {
258 marks = append(marks, "?")
259 }
260
261 if len(marks) > 1 {
262 results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
263 } else {
264 results = append(results, strings.Join(marks, ""))
265 }
266 }
267 return strings.Join(results, ",")
268 }
269
270 func toQueryCondition(scope *Scope, columns []string) string {
271 var newColumns []string
272 for _, column := range columns {
273 newColumns = append(newColumns, scope.Quote(column))
274 }
275
276 if len(columns) > 1 {
277 return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
278 } else {
279 return strings.Join(columns, ",")
280 }
281 }
282
283 func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
284 for _, primaryValue := range primaryValues {
285 for _, value := range primaryValue {
286 values = append(values, value)
287 }
288 }
289 return values
290 }
11
22 import (
33 "fmt"
4 "os"
5 "reflect"
6 "sort"
47 "testing"
8
9 "github.com/jinzhu/gorm"
510 )
611
7 func TestHasOneAndHasManyAssociation(t *testing.T) {
8 DB.DropTable(Category{}, Post{}, Comment{})
9 DB.CreateTable(Category{}, Post{}, Comment{})
10
12 func TestBelongsTo(t *testing.T) {
1113 post := Post{
12 Title: "post 1",
13 Body: "body 1",
14 Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}},
14 Title: "post belongs to",
15 Body: "body belongs to",
1516 Category: Category{Name: "Category 1"},
1617 MainCategory: Category{Name: "Main Category 1"},
1718 }
1819
1920 if err := DB.Save(&post).Error; err != nil {
20 t.Errorf("Got errors when save post", err.Error())
21 }
22
23 if err := DB.First(&Category{}, "name = ?", "Category 1").Error; err != nil {
24 t.Errorf("Category should be saved", err.Error())
25 }
26
27 var p Post
28 DB.First(&p, post.Id)
29
30 if post.CategoryId.Int64 == 0 || p.CategoryId.Int64 == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 {
31 t.Errorf("Category Id should exist")
32 }
33
21 t.Error("Got errors when save post", err)
22 }
23
24 if post.Category.ID == 0 || post.MainCategory.ID == 0 {
25 t.Errorf("Category's primary key should be updated")
26 }
27
28 if post.CategoryId.Int64 == 0 || post.MainCategoryId == 0 {
29 t.Errorf("post's foreign key should be updated")
30 }
31
32 // Query
33 var category1 Category
34 DB.Model(&post).Association("Category").Find(&category1)
35 if category1.Name != "Category 1" {
36 t.Errorf("Query belongs to relations with Association")
37 }
38
39 var mainCategory1 Category
40 DB.Model(&post).Association("MainCategory").Find(&mainCategory1)
41 if mainCategory1.Name != "Main Category 1" {
42 t.Errorf("Query belongs to relations with Association")
43 }
44
45 var category11 Category
46 DB.Model(&post).Related(&category11)
47 if category11.Name != "Category 1" {
48 t.Errorf("Query belongs to relations with Related")
49 }
50
51 if DB.Model(&post).Association("Category").Count() != 1 {
52 t.Errorf("Post's category count should be 1")
53 }
54
55 if DB.Model(&post).Association("MainCategory").Count() != 1 {
56 t.Errorf("Post's main category count should be 1")
57 }
58
59 // Append
60 var category2 = Category{
61 Name: "Category 2",
62 }
63 DB.Model(&post).Association("Category").Append(&category2)
64
65 if category2.ID == 0 {
66 t.Errorf("Category should has ID when created with Append")
67 }
68
69 var category21 Category
70 DB.Model(&post).Related(&category21)
71
72 if category21.Name != "Category 2" {
73 t.Errorf("Category should be updated with Append")
74 }
75
76 if DB.Model(&post).Association("Category").Count() != 1 {
77 t.Errorf("Post's category count should be 1")
78 }
79
80 // Replace
81 var category3 = Category{
82 Name: "Category 3",
83 }
84 DB.Model(&post).Association("Category").Replace(&category3)
85
86 if category3.ID == 0 {
87 t.Errorf("Category should has ID when created with Replace")
88 }
89
90 var category31 Category
91 DB.Model(&post).Related(&category31)
92 if category31.Name != "Category 3" {
93 t.Errorf("Category should be updated with Replace")
94 }
95
96 if DB.Model(&post).Association("Category").Count() != 1 {
97 t.Errorf("Post's category count should be 1")
98 }
99
100 // Delete
101 DB.Model(&post).Association("Category").Delete(&category2)
102 if DB.Model(&post).Related(&Category{}).RecordNotFound() {
103 t.Errorf("Should not delete any category when Delete a unrelated Category")
104 }
105
106 if post.Category.Name == "" {
107 t.Errorf("Post's category should not be reseted when Delete a unrelated Category")
108 }
109
110 DB.Model(&post).Association("Category").Delete(&category3)
111
112 if post.Category.Name != "" {
113 t.Errorf("Post's category should be reseted after Delete")
114 }
115
116 var category41 Category
117 DB.Model(&post).Related(&category41)
118 if category41.Name != "" {
119 t.Errorf("Category should be deleted with Delete")
120 }
121
122 if count := DB.Model(&post).Association("Category").Count(); count != 0 {
123 t.Errorf("Post's category count should be 0 after Delete, but got %v", count)
124 }
125
126 // Clear
127 DB.Model(&post).Association("Category").Append(&Category{
128 Name: "Category 2",
129 })
130
131 if DB.Model(&post).Related(&Category{}).RecordNotFound() {
132 t.Errorf("Should find category after append")
133 }
134
135 if post.Category.Name == "" {
136 t.Errorf("Post's category should has value after Append")
137 }
138
139 DB.Model(&post).Association("Category").Clear()
140
141 if post.Category.Name != "" {
142 t.Errorf("Post's category should be cleared after Clear")
143 }
144
145 if !DB.Model(&post).Related(&Category{}).RecordNotFound() {
146 t.Errorf("Should not find any category after Clear")
147 }
148
149 if count := DB.Model(&post).Association("Category").Count(); count != 0 {
150 t.Errorf("Post's category count should be 0 after Clear, but got %v", count)
151 }
152
153 // Check Association mode with soft delete
154 category6 := Category{
155 Name: "Category 6",
156 }
157 DB.Model(&post).Association("Category").Append(&category6)
158
159 if count := DB.Model(&post).Association("Category").Count(); count != 1 {
160 t.Errorf("Post's category count should be 1 after Append, but got %v", count)
161 }
162
163 DB.Delete(&category6)
164
165 if count := DB.Model(&post).Association("Category").Count(); count != 0 {
166 t.Errorf("Post's category count should be 0 after the category has been deleted, but got %v", count)
167 }
168
169 if err := DB.Model(&post).Association("Category").Find(&Category{}).Error; err == nil {
170 t.Errorf("Post's category is not findable after Delete")
171 }
172
173 if count := DB.Unscoped().Model(&post).Association("Category").Count(); count != 1 {
174 t.Errorf("Post's category count should be 1 when query with Unscoped, but got %v", count)
175 }
176
177 if err := DB.Unscoped().Model(&post).Association("Category").Find(&Category{}).Error; err != nil {
178 t.Errorf("Post's category should be findable when query with Unscoped, got %v", err)
179 }
180 }
181
182 func TestBelongsToOverrideForeignKey1(t *testing.T) {
183 type Profile struct {
184 gorm.Model
185 Name string
186 }
187
188 type User struct {
189 gorm.Model
190 Profile Profile `gorm:"ForeignKey:ProfileRefer"`
191 ProfileRefer int
192 }
193
194 if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
195 if relation.Relationship.Kind != "belongs_to" ||
196 !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) ||
197 !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
198 t.Errorf("Override belongs to foreign key with tag")
199 }
200 }
201 }
202
203 func TestBelongsToOverrideForeignKey2(t *testing.T) {
204 type Profile struct {
205 gorm.Model
206 Refer string
207 Name string
208 }
209
210 type User struct {
211 gorm.Model
212 Profile Profile `gorm:"ForeignKey:ProfileID;AssociationForeignKey:Refer"`
213 ProfileID int
214 }
215
216 if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
217 if relation.Relationship.Kind != "belongs_to" ||
218 !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) ||
219 !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
220 t.Errorf("Override belongs to foreign key with tag")
221 }
222 }
223 }
224
225 func TestHasOne(t *testing.T) {
226 user := User{
227 Name: "has one",
228 CreditCard: CreditCard{Number: "411111111111"},
229 }
230
231 if err := DB.Save(&user).Error; err != nil {
232 t.Error("Got errors when save user", err.Error())
233 }
234
235 if user.CreditCard.UserId.Int64 == 0 {
236 t.Errorf("CreditCard's foreign key should be updated")
237 }
238
239 // Query
240 var creditCard1 CreditCard
241 DB.Model(&user).Related(&creditCard1)
242
243 if creditCard1.Number != "411111111111" {
244 t.Errorf("Query has one relations with Related")
245 }
246
247 var creditCard11 CreditCard
248 DB.Model(&user).Association("CreditCard").Find(&creditCard11)
249
250 if creditCard11.Number != "411111111111" {
251 t.Errorf("Query has one relations with Related")
252 }
253
254 if DB.Model(&user).Association("CreditCard").Count() != 1 {
255 t.Errorf("User's credit card count should be 1")
256 }
257
258 // Append
259 var creditcard2 = CreditCard{
260 Number: "411111111112",
261 }
262 DB.Model(&user).Association("CreditCard").Append(&creditcard2)
263
264 if creditcard2.ID == 0 {
265 t.Errorf("Creditcard should has ID when created with Append")
266 }
267
268 var creditcard21 CreditCard
269 DB.Model(&user).Related(&creditcard21)
270 if creditcard21.Number != "411111111112" {
271 t.Errorf("CreditCard should be updated with Append")
272 }
273
274 if DB.Model(&user).Association("CreditCard").Count() != 1 {
275 t.Errorf("User's credit card count should be 1")
276 }
277
278 // Replace
279 var creditcard3 = CreditCard{
280 Number: "411111111113",
281 }
282 DB.Model(&user).Association("CreditCard").Replace(&creditcard3)
283
284 if creditcard3.ID == 0 {
285 t.Errorf("Creditcard should has ID when created with Replace")
286 }
287
288 var creditcard31 CreditCard
289 DB.Model(&user).Related(&creditcard31)
290 if creditcard31.Number != "411111111113" {
291 t.Errorf("CreditCard should be updated with Replace")
292 }
293
294 if DB.Model(&user).Association("CreditCard").Count() != 1 {
295 t.Errorf("User's credit card count should be 1")
296 }
297
298 // Delete
299 DB.Model(&user).Association("CreditCard").Delete(&creditcard2)
300 var creditcard4 CreditCard
301 DB.Model(&user).Related(&creditcard4)
302 if creditcard4.Number != "411111111113" {
303 t.Errorf("Should not delete credit card when Delete a unrelated CreditCard")
304 }
305
306 if DB.Model(&user).Association("CreditCard").Count() != 1 {
307 t.Errorf("User's credit card count should be 1")
308 }
309
310 DB.Model(&user).Association("CreditCard").Delete(&creditcard3)
311 if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() {
312 t.Errorf("Should delete credit card with Delete")
313 }
314
315 if DB.Model(&user).Association("CreditCard").Count() != 0 {
316 t.Errorf("User's credit card count should be 0 after Delete")
317 }
318
319 // Clear
320 var creditcard5 = CreditCard{
321 Number: "411111111115",
322 }
323 DB.Model(&user).Association("CreditCard").Append(&creditcard5)
324
325 if DB.Model(&user).Related(&CreditCard{}).RecordNotFound() {
326 t.Errorf("Should added credit card with Append")
327 }
328
329 if DB.Model(&user).Association("CreditCard").Count() != 1 {
330 t.Errorf("User's credit card count should be 1")
331 }
332
333 DB.Model(&user).Association("CreditCard").Clear()
334 if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() {
335 t.Errorf("Credit card should be deleted with Clear")
336 }
337
338 if DB.Model(&user).Association("CreditCard").Count() != 0 {
339 t.Errorf("User's credit card count should be 0 after Clear")
340 }
341
342 // Check Association mode with soft delete
343 var creditcard6 = CreditCard{
344 Number: "411111111116",
345 }
346 DB.Model(&user).Association("CreditCard").Append(&creditcard6)
347
348 if count := DB.Model(&user).Association("CreditCard").Count(); count != 1 {
349 t.Errorf("User's credit card count should be 1 after Append, but got %v", count)
350 }
351
352 DB.Delete(&creditcard6)
353
354 if count := DB.Model(&user).Association("CreditCard").Count(); count != 0 {
355 t.Errorf("User's credit card count should be 0 after credit card deleted, but got %v", count)
356 }
357
358 if err := DB.Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err == nil {
359 t.Errorf("User's creditcard is not findable after Delete")
360 }
361
362 if count := DB.Unscoped().Model(&user).Association("CreditCard").Count(); count != 1 {
363 t.Errorf("User's credit card count should be 1 when query with Unscoped, but got %v", count)
364 }
365
366 if err := DB.Unscoped().Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err != nil {
367 t.Errorf("User's creditcard should be findable when query with Unscoped, got %v", err)
368 }
369 }
370
371 func TestHasOneOverrideForeignKey1(t *testing.T) {
372 type Profile struct {
373 gorm.Model
374 Name string
375 UserRefer uint
376 }
377
378 type User struct {
379 gorm.Model
380 Profile Profile `gorm:"ForeignKey:UserRefer"`
381 }
382
383 if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
384 if relation.Relationship.Kind != "has_one" ||
385 !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
386 !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
387 t.Errorf("Override belongs to foreign key with tag")
388 }
389 }
390 }
391
392 func TestHasOneOverrideForeignKey2(t *testing.T) {
393 type Profile struct {
394 gorm.Model
395 Name string
396 UserID uint
397 }
398
399 type User struct {
400 gorm.Model
401 Refer string
402 Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
403 }
404
405 if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
406 if relation.Relationship.Kind != "has_one" ||
407 !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
408 !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
409 t.Errorf("Override belongs to foreign key with tag")
410 }
411 }
412 }
413
414 func TestHasMany(t *testing.T) {
415 post := Post{
416 Title: "post has many",
417 Body: "body has many",
418 Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}},
419 }
420
421 if err := DB.Save(&post).Error; err != nil {
422 t.Error("Got errors when save post", err)
423 }
424
425 for _, comment := range post.Comments {
426 if comment.PostId == 0 {
427 t.Errorf("comment's PostID should be updated")
428 }
429 }
430
431 var compareComments = func(comments []Comment, contents []string) bool {
432 var commentContents []string
433 for _, comment := range comments {
434 commentContents = append(commentContents, comment.Content)
435 }
436 sort.Strings(commentContents)
437 sort.Strings(contents)
438 return reflect.DeepEqual(commentContents, contents)
439 }
440
441 // Query
34442 if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil {
35443 t.Errorf("Comment 1 should be saved")
36444 }
37 if post.Comments[0].PostId == 0 {
38 t.Errorf("Comment Should have post id")
39 }
40
41 var comment Comment
42 if DB.First(&comment, "content = ?", "Comment 2").Error != nil {
43 t.Errorf("Comment 2 should be saved")
44 }
45
46 if comment.PostId == 0 {
47 t.Errorf("Comment 2 Should have post id")
48 }
49
50 comment3 := Comment{Content: "Comment 3", Post: Post{Title: "Title 3", Body: "Body 3"}}
51 DB.Save(&comment3)
445
446 var comments1 []Comment
447 DB.Model(&post).Association("Comments").Find(&comments1)
448 if !compareComments(comments1, []string{"Comment 1", "Comment 2"}) {
449 t.Errorf("Query has many relations with Association")
450 }
451
452 var comments11 []Comment
453 DB.Model(&post).Related(&comments11)
454 if !compareComments(comments11, []string{"Comment 1", "Comment 2"}) {
455 t.Errorf("Query has many relations with Related")
456 }
457
458 if DB.Model(&post).Association("Comments").Count() != 2 {
459 t.Errorf("Post's comments count should be 2")
460 }
461
462 // Append
463 DB.Model(&post).Association("Comments").Append(&Comment{Content: "Comment 3"})
464
465 var comments2 []Comment
466 DB.Model(&post).Related(&comments2)
467 if !compareComments(comments2, []string{"Comment 1", "Comment 2", "Comment 3"}) {
468 t.Errorf("Append new record to has many relations")
469 }
470
471 if DB.Model(&post).Association("Comments").Count() != 3 {
472 t.Errorf("Post's comments count should be 3 after Append")
473 }
474
475 // Delete
476 DB.Model(&post).Association("Comments").Delete(comments11)
477
478 var comments3 []Comment
479 DB.Model(&post).Related(&comments3)
480 if !compareComments(comments3, []string{"Comment 3"}) {
481 t.Errorf("Delete an existing resource for has many relations")
482 }
483
484 if DB.Model(&post).Association("Comments").Count() != 1 {
485 t.Errorf("Post's comments count should be 1 after Delete 2")
486 }
487
488 // Replace
489 DB.Model(&Post{Id: 999}).Association("Comments").Replace()
490
491 var comments4 []Comment
492 DB.Model(&post).Related(&comments4)
493 if len(comments4) == 0 {
494 t.Errorf("Replace for other resource should not clear all comments")
495 }
496
497 DB.Model(&post).Association("Comments").Replace(&Comment{Content: "Comment 4"}, &Comment{Content: "Comment 5"})
498
499 var comments41 []Comment
500 DB.Model(&post).Related(&comments41)
501 if !compareComments(comments41, []string{"Comment 4", "Comment 5"}) {
502 t.Errorf("Replace has many relations")
503 }
504
505 // Clear
506 DB.Model(&Post{Id: 999}).Association("Comments").Clear()
507
508 var comments5 []Comment
509 DB.Model(&post).Related(&comments5)
510 if len(comments5) == 0 {
511 t.Errorf("Clear should not clear all comments")
512 }
513
514 DB.Model(&post).Association("Comments").Clear()
515
516 var comments51 []Comment
517 DB.Model(&post).Related(&comments51)
518 if len(comments51) != 0 {
519 t.Errorf("Clear has many relations")
520 }
521
522 // Check Association mode with soft delete
523 var comment6 = Comment{
524 Content: "comment 6",
525 }
526 DB.Model(&post).Association("Comments").Append(&comment6)
527
528 if count := DB.Model(&post).Association("Comments").Count(); count != 1 {
529 t.Errorf("post's comments count should be 1 after Append, but got %v", count)
530 }
531
532 DB.Delete(&comment6)
533
534 if count := DB.Model(&post).Association("Comments").Count(); count != 0 {
535 t.Errorf("post's comments count should be 0 after comment been deleted, but got %v", count)
536 }
537
538 var comments6 []Comment
539 if DB.Model(&post).Association("Comments").Find(&comments6); len(comments6) != 0 {
540 t.Errorf("post's comments count should be 0 when find with Find, but got %v", len(comments6))
541 }
542
543 if count := DB.Unscoped().Model(&post).Association("Comments").Count(); count != 1 {
544 t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", count)
545 }
546
547 var comments61 []Comment
548 if DB.Unscoped().Model(&post).Association("Comments").Find(&comments61); len(comments61) != 1 {
549 t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", len(comments61))
550 }
551 }
552
553 func TestHasManyOverrideForeignKey1(t *testing.T) {
554 type Profile struct {
555 gorm.Model
556 Name string
557 UserRefer uint
558 }
559
560 type User struct {
561 gorm.Model
562 Profile []Profile `gorm:"ForeignKey:UserRefer"`
563 }
564
565 if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
566 if relation.Relationship.Kind != "has_many" ||
567 !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
568 !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
569 t.Errorf("Override belongs to foreign key with tag")
570 }
571 }
572 }
573
574 func TestHasManyOverrideForeignKey2(t *testing.T) {
575 type Profile struct {
576 gorm.Model
577 Name string
578 UserID uint
579 }
580
581 type User struct {
582 gorm.Model
583 Refer string
584 Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
585 }
586
587 if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
588 if relation.Relationship.Kind != "has_many" ||
589 !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
590 !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
591 t.Errorf("Override belongs to foreign key with tag")
592 }
593 }
594 }
595
596 func TestManyToMany(t *testing.T) {
597 DB.Raw("delete from languages")
598 var languages = []Language{{Name: "ZH"}, {Name: "EN"}}
599 user := User{Name: "Many2Many", Languages: languages}
600 DB.Save(&user)
601
602 // Query
603 var newLanguages []Language
604 DB.Model(&user).Related(&newLanguages, "Languages")
605 if len(newLanguages) != len([]string{"ZH", "EN"}) {
606 t.Errorf("Query many to many relations")
607 }
608
609 DB.Model(&user).Association("Languages").Find(&newLanguages)
610 if len(newLanguages) != len([]string{"ZH", "EN"}) {
611 t.Errorf("Should be able to find many to many relations")
612 }
613
614 if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) {
615 t.Errorf("Count should return correct result")
616 }
617
618 // Append
619 DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"})
620 if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() {
621 t.Errorf("New record should be saved when append")
622 }
623
624 languageA := Language{Name: "AA"}
625 DB.Save(&languageA)
626 DB.Model(&User{Id: user.Id}).Association("Languages").Append(&languageA)
627
628 languageC := Language{Name: "CC"}
629 DB.Save(&languageC)
630 DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC})
631
632 DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}})
633
634 totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"}
635
636 if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) {
637 t.Errorf("All appended languages should be saved")
638 }
639
640 // Delete
641 user.Languages = []Language{}
642 DB.Model(&user).Association("Languages").Find(&user.Languages)
643
644 var language Language
645 DB.Where("name = ?", "EE").First(&language)
646 DB.Model(&user).Association("Languages").Delete(language, &language)
647
648 if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 {
649 t.Errorf("Relations should be deleted with Delete")
650 }
651 if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() {
652 t.Errorf("Language EE should not be deleted")
653 }
654
655 DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages)
656
657 user2 := User{Name: "Many2Many_User2", Languages: languages}
658 DB.Save(&user2)
659
660 DB.Model(&user).Association("Languages").Delete(languages, &languages)
661 if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 {
662 t.Errorf("Relations should be deleted with Delete")
663 }
664
665 if DB.Model(&user2).Association("Languages").Count() == 0 {
666 t.Errorf("Other user's relations should not be deleted")
667 }
668
669 // Replace
670 var languageB Language
671 DB.Where("name = ?", "BB").First(&languageB)
672 DB.Model(&user).Association("Languages").Replace(languageB)
673 if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 {
674 t.Errorf("Relations should be replaced")
675 }
676
677 DB.Model(&user).Association("Languages").Replace()
678 if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 {
679 t.Errorf("Relations should be replaced with empty")
680 }
681
682 DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}})
683 if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) {
684 t.Errorf("Relations should be replaced")
685 }
686
687 // Clear
688 DB.Model(&user).Association("Languages").Clear()
689 if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 {
690 t.Errorf("Relations should be cleared")
691 }
692
693 // Check Association mode with soft delete
694 var language6 = Language{
695 Name: "language 6",
696 }
697 DB.Model(&user).Association("Languages").Append(&language6)
698
699 if count := DB.Model(&user).Association("Languages").Count(); count != 1 {
700 t.Errorf("user's languages count should be 1 after Append, but got %v", count)
701 }
702
703 DB.Delete(&language6)
704
705 if count := DB.Model(&user).Association("Languages").Count(); count != 0 {
706 t.Errorf("user's languages count should be 0 after language been deleted, but got %v", count)
707 }
708
709 var languages6 []Language
710 if DB.Model(&user).Association("Languages").Find(&languages6); len(languages6) != 0 {
711 t.Errorf("user's languages count should be 0 when find with Find, but got %v", len(languages6))
712 }
713
714 if count := DB.Unscoped().Model(&user).Association("Languages").Count(); count != 1 {
715 t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", count)
716 }
717
718 var languages61 []Language
719 if DB.Unscoped().Model(&user).Association("Languages").Find(&languages61); len(languages61) != 1 {
720 t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", len(languages61))
721 }
52722 }
53723
54724 func TestRelated(t *testing.T) {
61731 Company: Company{Name: "company1"},
62732 }
63733
64 DB.Save(&user)
734 if err := DB.Save(&user).Error; err != nil {
735 t.Errorf("No error should happen when saving user")
736 }
65737
66738 if user.CreditCard.ID == 0 {
67739 t.Errorf("After user save, credit card should have id")
84756 var emails2 []Email
85757 DB.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2)
86758 if len(emails2) != 1 {
759 t.Errorf("Should have two emails")
760 }
761
762 var emails3 []*Email
763 DB.Model(&user).Related(&emails3)
764 if len(emails3) != 2 {
87765 t.Errorf("Should have two emails")
88766 }
89767
129807 }
130808 }
131809
132 func TestManyToMany(t *testing.T) {
133 DB.Raw("delete from languages")
134 var languages = []Language{{Name: "ZH"}, {Name: "EN"}}
135 user := User{Name: "Many2Many", Languages: languages}
136 DB.Save(&user)
137
138 // Query
139 var newLanguages []Language
140 DB.Model(&user).Related(&newLanguages, "Languages")
141 if len(newLanguages) != len([]string{"ZH", "EN"}) {
142 t.Errorf("Query many to many relations")
143 }
144
145 DB.Model(&user).Association("Languages").Find(&newLanguages)
146 if len(newLanguages) != len([]string{"ZH", "EN"}) {
147 t.Errorf("Should be able to find many to many relations")
148 }
149
150 if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) {
151 t.Errorf("Count should return correct result")
152 }
153
154 // Append
155 DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"})
156 if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() {
157 t.Errorf("New record should be saved when append")
158 }
159
160 languageA := Language{Name: "AA"}
161 DB.Save(&languageA)
162 DB.Model(&User{Id: user.Id}).Association("Languages").Append(&languageA)
163
164 languageC := Language{Name: "CC"}
165 DB.Save(&languageC)
166 DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC})
167
168 DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}})
169
170 totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"}
171
172 if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) {
173 t.Errorf("All appended languages should be saved")
174 }
175
176 // Delete
177 user.Languages = []Language{}
178 DB.Model(&user).Association("Languages").Find(&user.Languages)
179
180 var language Language
181 DB.Where("name = ?", "EE").First(&language)
182 DB.Model(&user).Association("Languages").Delete(language, &language)
183
184 if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 {
185 t.Errorf("Relations should be deleted with Delete")
186 }
187 if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() {
188 t.Errorf("Language EE should not be deleted")
189 }
190
191 DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages)
192
193 user2 := User{Name: "Many2Many_User2", Languages: languages}
194 DB.Save(&user2)
195
196 DB.Model(&user).Association("Languages").Delete(languages, &languages)
197 if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 {
198 t.Errorf("Relations should be deleted with Delete")
199 }
200
201 if DB.Model(&user2).Association("Languages").Count() == 0 {
202 t.Errorf("Other user's relations should not be deleted")
203 }
204
205 // Replace
206 var languageB Language
207 DB.Where("name = ?", "BB").First(&languageB)
208 DB.Model(&user).Association("Languages").Replace(languageB)
209 if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 {
210 t.Errorf("Relations should be replaced")
211 }
212
213 DB.Model(&user).Association("Languages").Replace()
214 if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 {
215 t.Errorf("Relations should be replaced with empty")
216 }
217
218 DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}})
219 if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) {
220 t.Errorf("Relations should be replaced")
221 }
222
223 // Clear
224 DB.Model(&user).Association("Languages").Clear()
225 if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 {
226 t.Errorf("Relations should be cleared")
227 }
228 }
229
230810 func TestForeignKey(t *testing.T) {
231811 for _, structField := range DB.NewScope(&User{}).GetStructFields() {
232812 for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} {
260840 }
261841 }
262842 }
843
844 func testForeignKey(t *testing.T, source interface{}, sourceFieldName string, target interface{}, targetFieldName string) {
845 if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" {
846 // sqlite does not support ADD CONSTRAINT in ALTER TABLE
847 return
848 }
849 targetScope := DB.NewScope(target)
850 targetTableName := targetScope.TableName()
851 modelScope := DB.NewScope(source)
852 modelField, ok := modelScope.FieldByName(sourceFieldName)
853 if !ok {
854 t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", sourceFieldName))
855 }
856 targetField, ok := targetScope.FieldByName(targetFieldName)
857 if !ok {
858 t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", targetFieldName))
859 }
860 dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName)
861 err := DB.Model(source).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error
862 if err != nil {
863 t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err))
864 }
865 }
866
867 func TestLongForeignKey(t *testing.T) {
868 testForeignKey(t, &NotSoLongTableName{}, "ReallyLongThingID", &ReallyLongTableNameToTestMySQLNameLengthLimit{}, "ID")
869 }
870
871 func TestLongForeignKeyWithShortDest(t *testing.T) {
872 testForeignKey(t, &ReallyLongThingThatReferencesShort{}, "ShortID", &Short{}, "ID")
873 }
874
875 func TestHasManyChildrenWithOneStruct(t *testing.T) {
876 category := Category{
877 Name: "main",
878 Categories: []Category{
879 {Name: "sub1"},
880 {Name: "sub2"},
881 },
882 }
883
884 DB.Save(&category)
885 }
886
887 func TestAutoSaveBelongsToAssociation(t *testing.T) {
888 type Company struct {
889 gorm.Model
890 Name string
891 }
892
893 type User struct {
894 gorm.Model
895 Name string
896 CompanyID uint
897 Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"`
898 }
899
900 DB.Where("name = ?", "auto_save_association").Delete(&Company{})
901 DB.AutoMigrate(&Company{}, &User{})
902
903 DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_association"}})
904
905 if !DB.Where("name = ?", "auto_save_association").First(&Company{}).RecordNotFound() {
906 t.Errorf("Company auto_save_association should not have been saved when autosave is false")
907 }
908
909 // if foreign key is set, this should be saved even if association isn't
910 company := Company{Name: "auto_save_association"}
911 DB.Save(&company)
912
913 company.Name = "auto_save_association_new_name"
914 user := User{Name: "jinzhu", Company: company}
915
916 DB.Save(&user)
917
918 if !DB.Where("name = ?", "auto_save_association_new_name").First(&Company{}).RecordNotFound() {
919 t.Errorf("Company should not have been updated")
920 }
921
922 if DB.Where("id = ? AND company_id = ?", user.ID, company.ID).First(&User{}).RecordNotFound() {
923 t.Errorf("User's foreign key should have been saved")
924 }
925
926 user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_association_2"}}
927 DB.Set("gorm:association_autocreate", true).Save(&user2)
928 if DB.Where("name = ?", "auto_save_association_2").First(&Company{}).RecordNotFound() {
929 t.Errorf("Company auto_save_association_2 should been created when autocreate is true")
930 }
931
932 user2.Company.Name = "auto_save_association_2_newname"
933 DB.Set("gorm:association_autoupdate", true).Save(&user2)
934
935 if DB.Where("name = ?", "auto_save_association_2_newname").First(&Company{}).RecordNotFound() {
936 t.Errorf("Company should been updated")
937 }
938 }
939
940 func TestAutoSaveHasOneAssociation(t *testing.T) {
941 type Company struct {
942 gorm.Model
943 UserID uint
944 Name string
945 }
946
947 type User struct {
948 gorm.Model
949 Name string
950 Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"`
951 }
952
953 DB.Where("name = ?", "auto_save_has_one_association").Delete(&Company{})
954 DB.AutoMigrate(&Company{}, &User{})
955
956 DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_has_one_association"}})
957
958 if !DB.Where("name = ?", "auto_save_has_one_association").First(&Company{}).RecordNotFound() {
959 t.Errorf("Company auto_save_has_one_association should not have been saved when autosave is false")
960 }
961
962 company := Company{Name: "auto_save_has_one_association"}
963 DB.Save(&company)
964
965 company.Name = "auto_save_has_one_association_new_name"
966 user := User{Name: "jinzhu", Company: company}
967
968 DB.Save(&user)
969
970 if !DB.Where("name = ?", "auto_save_has_one_association_new_name").First(&Company{}).RecordNotFound() {
971 t.Errorf("Company should not have been updated")
972 }
973
974 if !DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association", user.ID).First(&Company{}).RecordNotFound() {
975 t.Errorf("Company should not have been updated")
976 }
977
978 if user.Company.UserID == 0 {
979 t.Errorf("UserID should be assigned")
980 }
981
982 company.Name = "auto_save_has_one_association_2_new_name"
983 DB.Set("gorm:association_autoupdate", true).Save(&user)
984
985 if DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association_new_name", user.ID).First(&Company{}).RecordNotFound() {
986 t.Errorf("Company should been updated")
987 }
988
989 user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_has_one_association_2"}}
990 DB.Set("gorm:association_autocreate", true).Save(&user2)
991 if DB.Where("name = ?", "auto_save_has_one_association_2").First(&Company{}).RecordNotFound() {
992 t.Errorf("Company auto_save_has_one_association_2 should been created when autocreate is true")
993 }
994 }
995
996 func TestAutoSaveMany2ManyAssociation(t *testing.T) {
997 type Company struct {
998 gorm.Model
999 Name string
1000 }
1001
1002 type User struct {
1003 gorm.Model
1004 Name string
1005 Companies []Company `gorm:"many2many:user_companies;association_autoupdate:false;association_autocreate:false;"`
1006 }
1007
1008 DB.AutoMigrate(&Company{}, &User{})
1009
1010 DB.Save(&User{Name: "jinzhu", Companies: []Company{{Name: "auto_save_m2m_association"}}})
1011
1012 if !DB.Where("name = ?", "auto_save_m2m_association").First(&Company{}).RecordNotFound() {
1013 t.Errorf("Company auto_save_m2m_association should not have been saved when autosave is false")
1014 }
1015
1016 company := Company{Name: "auto_save_m2m_association"}
1017 DB.Save(&company)
1018
1019 company.Name = "auto_save_m2m_association_new_name"
1020 user := User{Name: "jinzhu", Companies: []Company{company, {Name: "auto_save_m2m_association_new_name_2"}}}
1021
1022 DB.Save(&user)
1023
1024 if !DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() {
1025 t.Errorf("Company should not have been updated")
1026 }
1027
1028 if !DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() {
1029 t.Errorf("Company should not been created")
1030 }
1031
1032 if DB.Model(&user).Association("Companies").Count() != 1 {
1033 t.Errorf("Relationship should been saved")
1034 }
1035
1036 DB.Set("gorm:association_autoupdate", true).Set("gorm:association_autocreate", true).Save(&user)
1037
1038 if DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() {
1039 t.Errorf("Company should been updated")
1040 }
1041
1042 if DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() {
1043 t.Errorf("Company should been created")
1044 }
1045
1046 if DB.Model(&user).Association("Companies").Count() != 2 {
1047 t.Errorf("Relationship should been updated")
1048 }
1049 }
00 package gorm
11
2 import (
3 "fmt"
4 )
5
6 type callback struct {
2 import "log"
3
4 // DefaultCallback default callbacks defined by gorm
5 var DefaultCallback = &Callback{}
6
7 // Callback is a struct that contains all CRUD callbacks
8 // Field `creates` contains callbacks will be call when creating object
9 // Field `updates` contains callbacks will be call when updating object
10 // Field `deletes` contains callbacks will be call when deleting object
11 // Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association...
12 // Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
13 // Field `processors` contains all callback processors, will be used to generate above callbacks in order
14 type Callback struct {
715 creates []*func(scope *Scope)
816 updates []*func(scope *Scope)
917 deletes []*func(scope *Scope)
1018 queries []*func(scope *Scope)
1119 rowQueries []*func(scope *Scope)
12 processors []*callbackProcessor
13 }
14
15 type callbackProcessor struct {
16 name string
17 before string
18 after string
19 replace bool
20 remove bool
21 typ string
22 processor *func(scope *Scope)
23 callback *callback
24 }
25
26 func (c *callback) addProcessor(typ string) *callbackProcessor {
27 cp := &callbackProcessor{typ: typ, callback: c}
28 c.processors = append(c.processors, cp)
29 return cp
30 }
31
32 func (c *callback) clone() *callback {
33 return &callback{
20 processors []*CallbackProcessor
21 }
22
23 // CallbackProcessor contains callback informations
24 type CallbackProcessor struct {
25 name string // current callback's name
26 before string // register current callback before a callback
27 after string // register current callback after a callback
28 replace bool // replace callbacks with same name
29 remove bool // delete callbacks with same name
30 kind string // callback type: create, update, delete, query, row_query
31 processor *func(scope *Scope) // callback handler
32 parent *Callback
33 }
34
35 func (c *Callback) clone() *Callback {
36 return &Callback{
3437 creates: c.creates,
3538 updates: c.updates,
3639 deletes: c.deletes,
3740 queries: c.queries,
41 rowQueries: c.rowQueries,
3842 processors: c.processors,
3943 }
4044 }
4145
42 func (c *callback) Create() *callbackProcessor {
43 return c.addProcessor("create")
44 }
45
46 func (c *callback) Update() *callbackProcessor {
47 return c.addProcessor("update")
48 }
49
50 func (c *callback) Delete() *callbackProcessor {
51 return c.addProcessor("delete")
52 }
53
54 func (c *callback) Query() *callbackProcessor {
55 return c.addProcessor("query")
56 }
57
58 func (c *callback) RowQuery() *callbackProcessor {
59 return c.addProcessor("row_query")
60 }
61
62 func (cp *callbackProcessor) Before(name string) *callbackProcessor {
63 cp.before = name
46 // Create could be used to register callbacks for creating object
47 // db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
48 // // business logic
49 // ...
50 //
51 // // set error if some thing wrong happened, will rollback the creating
52 // scope.Err(errors.New("error"))
53 // })
54 func (c *Callback) Create() *CallbackProcessor {
55 return &CallbackProcessor{kind: "create", parent: c}
56 }
57
58 // Update could be used to register callbacks for updating object, refer `Create` for usage
59 func (c *Callback) Update() *CallbackProcessor {
60 return &CallbackProcessor{kind: "update", parent: c}
61 }
62
63 // Delete could be used to register callbacks for deleting object, refer `Create` for usage
64 func (c *Callback) Delete() *CallbackProcessor {
65 return &CallbackProcessor{kind: "delete", parent: c}
66 }
67
68 // Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
69 // Refer `Create` for usage
70 func (c *Callback) Query() *CallbackProcessor {
71 return &CallbackProcessor{kind: "query", parent: c}
72 }
73
74 // RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
75 func (c *Callback) RowQuery() *CallbackProcessor {
76 return &CallbackProcessor{kind: "row_query", parent: c}
77 }
78
79 // After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
80 func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
81 cp.after = callbackName
6482 return cp
6583 }
6684
67 func (cp *callbackProcessor) After(name string) *callbackProcessor {
68 cp.after = name
85 // Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
86 func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
87 cp.before = callbackName
6988 return cp
7089 }
7190
72 func (cp *callbackProcessor) Register(name string, fc func(scope *Scope)) {
73 cp.name = name
74 cp.processor = &fc
75 cp.callback.sort()
76 }
77
78 func (cp *callbackProcessor) Remove(name string) {
79 fmt.Printf("[info] removing callback `%v` from %v\n", name, fileWithLineNum())
80 cp.name = name
91 // Register a new callback, refer `Callbacks.Create`
92 func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
93 if cp.kind == "row_query" {
94 if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
95 log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
96 cp.before = "gorm:row_query"
97 }
98 }
99
100 cp.name = callbackName
101 cp.processor = &callback
102 cp.parent.processors = append(cp.parent.processors, cp)
103 cp.parent.reorder()
104 }
105
106 // Remove a registered callback
107 // db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
108 func (cp *CallbackProcessor) Remove(callbackName string) {
109 log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
110 cp.name = callbackName
81111 cp.remove = true
82 cp.callback.sort()
83 }
84
85 func (cp *callbackProcessor) Replace(name string, fc func(scope *Scope)) {
86 fmt.Printf("[info] replacing callback `%v` from %v\n", name, fileWithLineNum())
87 cp.name = name
88 cp.processor = &fc
112 cp.parent.processors = append(cp.parent.processors, cp)
113 cp.parent.reorder()
114 }
115
116 // Replace a registered callback with new callback
117 // db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
118 // scope.SetColumn("Created", now)
119 // scope.SetColumn("Updated", now)
120 // })
121 func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
122 log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
123 cp.name = callbackName
124 cp.processor = &callback
89125 cp.replace = true
90 cp.callback.sort()
91 }
92
126 cp.parent.processors = append(cp.parent.processors, cp)
127 cp.parent.reorder()
128 }
129
130 // Get registered callback
131 // db.Callback().Create().Get("gorm:create")
132 func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
133 for _, p := range cp.parent.processors {
134 if p.name == callbackName && p.kind == cp.kind && !cp.remove {
135 return *p.processor
136 }
137 }
138 return nil
139 }
140
141 // getRIndex get right index from string slice
93142 func getRIndex(strs []string, str string) int {
94143 for i := len(strs) - 1; i >= 0; i-- {
95144 if strs[i] == str {
99148 return -1
100149 }
101150
102 func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
103 var sortCallbackProcessor func(c *callbackProcessor)
104 var names, sortedNames = []string{}, []string{}
151 // sortProcessors sort callback processors based on its before, after, remove, replace
152 func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
153 var (
154 allNames, sortedNames []string
155 sortCallbackProcessor func(c *CallbackProcessor)
156 )
105157
106158 for _, cp := range cps {
107 if index := getRIndex(names, cp.name); index > -1 {
108 if !cp.replace && !cp.remove {
109 fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
110 }
111 }
112 names = append(names, cp.name)
113 }
114
115 sortCallbackProcessor = func(c *callbackProcessor) {
116 if getRIndex(sortedNames, c.name) > -1 {
117 return
118 }
119
120 if len(c.before) > 0 {
121 if index := getRIndex(sortedNames, c.before); index > -1 {
122 sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
123 } else if index := getRIndex(names, c.before); index > -1 {
159 // show warning message the callback name already exists
160 if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
161 log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
162 }
163 allNames = append(allNames, cp.name)
164 }
165
166 sortCallbackProcessor = func(c *CallbackProcessor) {
167 if getRIndex(sortedNames, c.name) == -1 { // if not sorted
168 if c.before != "" { // if defined before callback
169 if index := getRIndex(sortedNames, c.before); index != -1 {
170 // if before callback already sorted, append current callback just after it
171 sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
172 } else if index := getRIndex(allNames, c.before); index != -1 {
173 // if before callback exists but haven't sorted, append current callback to last
174 sortedNames = append(sortedNames, c.name)
175 sortCallbackProcessor(cps[index])
176 }
177 }
178
179 if c.after != "" { // if defined after callback
180 if index := getRIndex(sortedNames, c.after); index != -1 {
181 // if after callback already sorted, append current callback just before it
182 sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
183 } else if index := getRIndex(allNames, c.after); index != -1 {
184 // if after callback exists but haven't sorted
185 cp := cps[index]
186 // set after callback's before callback to current callback
187 if cp.before == "" {
188 cp.before = c.name
189 }
190 sortCallbackProcessor(cp)
191 }
192 }
193
194 // if current callback haven't been sorted, append it to last
195 if getRIndex(sortedNames, c.name) == -1 {
124196 sortedNames = append(sortedNames, c.name)
125 sortCallbackProcessor(cps[index])
126 } else {
127 sortedNames = append(sortedNames, c.name)
128 }
129 }
130
131 if len(c.after) > 0 {
132 if index := getRIndex(sortedNames, c.after); index > -1 {
133 sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
134 } else if index := getRIndex(names, c.after); index > -1 {
135 cp := cps[index]
136 if len(cp.before) == 0 {
137 cp.before = c.name
138 }
139 sortCallbackProcessor(cp)
140 } else {
141 sortedNames = append(sortedNames, c.name)
142 }
143 }
144
145 if getRIndex(sortedNames, c.name) == -1 {
146 sortedNames = append(sortedNames, c.name)
197 }
147198 }
148199 }
149200
151202 sortCallbackProcessor(cp)
152203 }
153204
154 var funcs = []*func(scope *Scope){}
155 var sortedFuncs = []*func(scope *Scope){}
205 var sortedFuncs []*func(scope *Scope)
156206 for _, name := range sortedNames {
157 index := getRIndex(names, name)
158 if !cps[index].remove {
207 if index := getRIndex(allNames, name); !cps[index].remove {
159208 sortedFuncs = append(sortedFuncs, cps[index].processor)
160209 }
161210 }
162211
163 for _, cp := range cps {
164 if sindex := getRIndex(sortedNames, cp.name); sindex == -1 {
165 if !cp.remove {
166 funcs = append(funcs, cp.processor)
167 }
168 }
169 }
170
171 return append(sortedFuncs, funcs...)
172 }
173
174 func (c *callback) sort() {
175 var creates, updates, deletes, queries, rowQueries []*callbackProcessor
212 return sortedFuncs
213 }
214
215 // reorder all registered processors, and reset CRUD callbacks
216 func (c *Callback) reorder() {
217 var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
176218
177219 for _, processor := range c.processors {
178 switch processor.typ {
179 case "create":
180 creates = append(creates, processor)
181 case "update":
182 updates = append(updates, processor)
183 case "delete":
184 deletes = append(deletes, processor)
185 case "query":
186 queries = append(queries, processor)
187 case "row_query":
188 rowQueries = append(rowQueries, processor)
220 if processor.name != "" {
221 switch processor.kind {
222 case "create":
223 creates = append(creates, processor)
224 case "update":
225 updates = append(updates, processor)
226 case "delete":
227 deletes = append(deletes, processor)
228 case "query":
229 queries = append(queries, processor)
230 case "row_query":
231 rowQueries = append(rowQueries, processor)
232 }
189233 }
190234 }
191235
195239 c.queries = sortProcessors(queries)
196240 c.rowQueries = sortProcessors(rowQueries)
197241 }
198
199 var DefaultCallback = &callback{processors: []*callbackProcessor{}}
44 "strings"
55 )
66
7 func BeforeCreate(scope *Scope) {
8 scope.CallMethodWithErrorCheck("BeforeSave")
9 scope.CallMethodWithErrorCheck("BeforeCreate")
7 // Define callbacks for creating
8 func init() {
9 DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
10 DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
11 DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
12 DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
13 DefaultCallback.Create().Register("gorm:create", createCallback)
14 DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
15 DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
16 DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
17 DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
1018 }
1119
12 func UpdateTimeStampWhenCreate(scope *Scope) {
20 // beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
21 func beforeCreateCallback(scope *Scope) {
1322 if !scope.HasError() {
14 now := NowFunc()
15 scope.SetColumn("CreatedAt", now)
16 scope.SetColumn("UpdatedAt", now)
23 scope.CallMethod("BeforeSave")
24 }
25 if !scope.HasError() {
26 scope.CallMethod("BeforeCreate")
1727 }
1828 }
1929
20 func Create(scope *Scope) {
21 defer scope.Trace(NowFunc())
30 // updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
31 func updateTimeStampForCreateCallback(scope *Scope) {
32 if !scope.HasError() {
33 now := NowFunc()
2234
35 if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
36 if createdAtField.IsBlank {
37 createdAtField.Set(now)
38 }
39 }
40
41 if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok {
42 if updatedAtField.IsBlank {
43 updatedAtField.Set(now)
44 }
45 }
46 }
47 }
48
49 // createCallback the callback used to insert data into database
50 func createCallback(scope *Scope) {
2351 if !scope.HasError() {
24 // set create sql
25 var sqls, columns []string
26 fields := scope.Fields()
27 for _, field := range fields {
52 defer scope.trace(NowFunc())
53
54 var (
55 columns, placeholders []string
56 blankColumnsWithDefaultValue []string
57 )
58
59 for _, field := range scope.Fields() {
2860 if scope.changeableField(field) {
2961 if field.IsNormal {
30 if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
31 if !field.IsBlank || !field.HasDefaultValue {
32 columns = append(columns, scope.Quote(field.DBName))
33 sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
34 } else if field.HasDefaultValue {
35 scope.InstanceSet("gorm:force_reload_after_create", true)
36 }
62 if field.IsBlank && field.HasDefaultValue {
63 blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
64 scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
65 } else if !field.IsPrimaryKey || !field.IsBlank {
66 columns = append(columns, scope.Quote(field.DBName))
67 placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
3768 }
38 } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
39 for _, dbName := range relationship.ForeignDBNames {
40 if relationField := fields[dbName]; !scope.changeableField(relationField) {
41 columns = append(columns, scope.Quote(relationField.DBName))
42 sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
69 } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
70 for _, foreignKey := range field.Relationship.ForeignDBNames {
71 if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
72 columns = append(columns, scope.Quote(foreignField.DBName))
73 placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
4374 }
4475 }
4576 }
4677 }
4778 }
4879
49 returningKey := "*"
50 primaryField := scope.PrimaryField()
51 if primaryField != nil {
52 returningKey = scope.Quote(primaryField.DBName)
80 var (
81 returningColumn = "*"
82 quotedTableName = scope.QuotedTableName()
83 primaryField = scope.PrimaryField()
84 extraOption string
85 )
86
87 if str, ok := scope.Get("gorm:insert_option"); ok {
88 extraOption = fmt.Sprint(str)
5389 }
5490
91 if primaryField != nil {
92 returningColumn = scope.Quote(primaryField.DBName)
93 }
94
95 lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
96
5597 if len(columns) == 0 {
56 scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v",
57 scope.QuotedTableName(),
58 scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
98 scope.Raw(fmt.Sprintf(
99 "INSERT INTO %v %v%v%v",
100 quotedTableName,
101 scope.Dialect().DefaultValueStr(),
102 addExtraSpaceIfExist(extraOption),
103 addExtraSpaceIfExist(lastInsertIDReturningSuffix),
59104 ))
60105 } else {
61106 scope.Raw(fmt.Sprintf(
62 "INSERT INTO %v (%v) VALUES (%v) %v",
107 "INSERT INTO %v (%v) VALUES (%v)%v%v",
63108 scope.QuotedTableName(),
64109 strings.Join(columns, ","),
65 strings.Join(sqls, ","),
66 scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
110 strings.Join(placeholders, ","),
111 addExtraSpaceIfExist(extraOption),
112 addExtraSpaceIfExist(lastInsertIDReturningSuffix),
67113 ))
68114 }
69115
70116 // execute create sql
71 if scope.Dialect().SupportLastInsertId() {
72 if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
73 id, err := result.LastInsertId()
74 if scope.Err(err) == nil {
75 scope.db.RowsAffected, _ = result.RowsAffected()
76 if primaryField != nil && primaryField.IsBlank {
77 scope.Err(scope.SetColumn(primaryField, id))
117 if lastInsertIDReturningSuffix == "" || primaryField == nil {
118 if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
119 // set rows affected count
120 scope.db.RowsAffected, _ = result.RowsAffected()
121
122 // set primary value to primary field
123 if primaryField != nil && primaryField.IsBlank {
124 if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
125 scope.Err(primaryField.Set(primaryValue))
78126 }
79127 }
80128 }
81129 } else {
82 if primaryField == nil {
83 if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
84 scope.db.RowsAffected, _ = results.RowsAffected()
85 } else {
86 scope.Err(err)
130 if primaryField.Field.CanAddr() {
131 if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
132 primaryField.IsBlank = false
133 scope.db.RowsAffected = 1
87134 }
88135 } else {
89 if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
90 scope.db.RowsAffected = 1
91 } else {
92 scope.Err(err)
93 }
136 scope.Err(ErrUnaddressable)
94137 }
95138 }
96139 }
97140 }
98141
99 func ForceReloadAfterCreate(scope *Scope) {
100 if _, ok := scope.InstanceGet("gorm:force_reload_after_create"); ok {
101 scope.DB().New().First(scope.Value)
142 // forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
143 func forceReloadAfterCreateCallback(scope *Scope) {
144 if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
145 db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
146 for _, field := range scope.Fields() {
147 if field.IsPrimaryKey && !field.IsBlank {
148 db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface())
149 }
150 }
151 db.Scan(scope.Value)
102152 }
103153 }
104154
105 func AfterCreate(scope *Scope) {
106 scope.CallMethodWithErrorCheck("AfterCreate")
107 scope.CallMethodWithErrorCheck("AfterSave")
155 // afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
156 func afterCreateCallback(scope *Scope) {
157 if !scope.HasError() {
158 scope.CallMethod("AfterCreate")
159 }
160 if !scope.HasError() {
161 scope.CallMethod("AfterSave")
162 }
108163 }
109
110 func init() {
111 DefaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction)
112 DefaultCallback.Create().Register("gorm:before_create", BeforeCreate)
113 DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
114 DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
115 DefaultCallback.Create().Register("gorm:create", Create)
116 DefaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate)
117 DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
118 DefaultCallback.Create().Register("gorm:after_create", AfterCreate)
119 DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
120 }
00 package gorm
11
2 import "fmt"
2 import (
3 "errors"
4 "fmt"
5 )
36
4 func BeforeDelete(scope *Scope) {
5 scope.CallMethodWithErrorCheck("BeforeDelete")
7 // Define callbacks for deleting
8 func init() {
9 DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
10 DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
11 DefaultCallback.Delete().Register("gorm:delete", deleteCallback)
12 DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
13 DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
614 }
715
8 func Delete(scope *Scope) {
16 // beforeDeleteCallback will invoke `BeforeDelete` method before deleting
17 func beforeDeleteCallback(scope *Scope) {
18 if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
19 scope.Err(errors.New("Missing WHERE clause while deleting"))
20 return
21 }
922 if !scope.HasError() {
10 if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") {
11 scope.Raw(
12 fmt.Sprintf("UPDATE %v SET deleted_at=%v %v",
13 scope.QuotedTableName(),
14 scope.AddToVars(NowFunc()),
15 scope.CombinedConditionSql(),
16 ))
17 } else {
18 scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.QuotedTableName(), scope.CombinedConditionSql()))
19 }
20
21 scope.Exec()
23 scope.CallMethod("BeforeDelete")
2224 }
2325 }
2426
25 func AfterDelete(scope *Scope) {
26 scope.CallMethodWithErrorCheck("AfterDelete")
27 // deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
28 func deleteCallback(scope *Scope) {
29 if !scope.HasError() {
30 var extraOption string
31 if str, ok := scope.Get("gorm:delete_option"); ok {
32 extraOption = fmt.Sprint(str)
33 }
34
35 deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt")
36
37 if !scope.Search.Unscoped && hasDeletedAtField {
38 scope.Raw(fmt.Sprintf(
39 "UPDATE %v SET %v=%v%v%v",
40 scope.QuotedTableName(),
41 scope.Quote(deletedAtField.DBName),
42 scope.AddToVars(NowFunc()),
43 addExtraSpaceIfExist(scope.CombinedConditionSql()),
44 addExtraSpaceIfExist(extraOption),
45 )).Exec()
46 } else {
47 scope.Raw(fmt.Sprintf(
48 "DELETE FROM %v%v%v",
49 scope.QuotedTableName(),
50 addExtraSpaceIfExist(scope.CombinedConditionSql()),
51 addExtraSpaceIfExist(extraOption),
52 )).Exec()
53 }
54 }
2755 }
2856
29 func init() {
30 DefaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction)
31 DefaultCallback.Delete().Register("gorm:before_delete", BeforeDelete)
32 DefaultCallback.Delete().Register("gorm:delete", Delete)
33 DefaultCallback.Delete().Register("gorm:after_delete", AfterDelete)
34 DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
57 // afterDeleteCallback will invoke `AfterDelete` method after deleting
58 func afterDeleteCallback(scope *Scope) {
59 if !scope.HasError() {
60 scope.CallMethod("AfterDelete")
61 }
3562 }
55 "reflect"
66 )
77
8 func Query(scope *Scope) {
9 defer scope.Trace(NowFunc())
8 // Define callbacks for querying
9 func init() {
10 DefaultCallback.Query().Register("gorm:query", queryCallback)
11 DefaultCallback.Query().Register("gorm:preload", preloadCallback)
12 DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
13 }
14
15 // queryCallback used to query data from database
16 func queryCallback(scope *Scope) {
17 if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
18 return
19 }
20
21 defer scope.trace(NowFunc())
1022
1123 var (
12 isSlice bool
13 isPtr bool
14 anyRecordFound bool
15 destType reflect.Type
24 isSlice, isPtr bool
25 resultType reflect.Type
26 results = scope.IndirectValue()
1627 )
1728
1829 if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
19 if primaryKey := scope.PrimaryKey(); primaryKey != "" {
20 scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryKey), orderBy))
30 if primaryField := scope.PrimaryField(); primaryField != nil {
31 scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
2132 }
2233 }
2334
24 var dest = scope.IndirectValue()
2535 if value, ok := scope.Get("gorm:query_destination"); ok {
26 dest = reflect.Indirect(reflect.ValueOf(value))
36 results = indirect(reflect.ValueOf(value))
2737 }
2838
29 if kind := dest.Kind(); kind == reflect.Slice {
39 if kind := results.Kind(); kind == reflect.Slice {
3040 isSlice = true
31 destType = dest.Type().Elem()
32 dest.Set(reflect.MakeSlice(dest.Type(), 0, 0))
41 resultType = results.Type().Elem()
42 results.Set(reflect.MakeSlice(results.Type(), 0, 0))
3343
34 if destType.Kind() == reflect.Ptr {
44 if resultType.Kind() == reflect.Ptr {
3545 isPtr = true
36 destType = destType.Elem()
46 resultType = resultType.Elem()
3747 }
3848 } else if kind != reflect.Struct {
3949 scope.Err(errors.New("unsupported destination, should be slice or struct"))
4050 return
4151 }
4252
43 scope.prepareQuerySql()
53 scope.prepareQuerySQL()
4454
4555 if !scope.HasError() {
46 rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
4756 scope.db.RowsAffected = 0
57 if str, ok := scope.Get("gorm:query_option"); ok {
58 scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
59 }
4860
49 if scope.Err(err) != nil {
50 return
51 }
52 defer rows.Close()
61 if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
62 defer rows.Close()
5363
54 columns, _ := rows.Columns()
55 for rows.Next() {
56 scope.db.RowsAffected++
64 columns, _ := rows.Columns()
65 for rows.Next() {
66 scope.db.RowsAffected++
5767
58 anyRecordFound = true
59 elem := dest
60 if isSlice {
61 elem = reflect.New(destType).Elem()
62 }
68 elem := results
69 if isSlice {
70 elem = reflect.New(resultType).Elem()
71 }
6372
64 var values = make([]interface{}, len(columns))
73 scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
6574
66 fields := scope.New(elem.Addr().Interface()).Fields()
67
68 for index, column := range columns {
69 if field, ok := fields[column]; ok {
70 if field.Field.Kind() == reflect.Ptr {
71 values[index] = field.Field.Addr().Interface()
75 if isSlice {
76 if isPtr {
77 results.Set(reflect.Append(results, elem.Addr()))
7278 } else {
73 values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
74 }
75 } else {
76 var value interface{}
77 values[index] = &value
78 }
79 }
80
81 scope.Err(rows.Scan(values...))
82
83 for index, column := range columns {
84 value := values[index]
85 if field, ok := fields[column]; ok {
86 if field.Field.Kind() == reflect.Ptr {
87 field.Field.Set(reflect.ValueOf(value).Elem())
88 } else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
89 field.Field.Set(v)
79 results.Set(reflect.Append(results, elem))
9080 }
9181 }
9282 }
9383
94 if isSlice {
95 if isPtr {
96 dest.Set(reflect.Append(dest, elem.Addr()))
97 } else {
98 dest.Set(reflect.Append(dest, elem))
99 }
84 if err := rows.Err(); err != nil {
85 scope.Err(err)
86 } else if scope.db.RowsAffected == 0 && !isSlice {
87 scope.Err(ErrRecordNotFound)
10088 }
101 }
102
103 if !anyRecordFound && !isSlice {
104 scope.Err(RecordNotFound)
10589 }
10690 }
10791 }
10892
109 func AfterQuery(scope *Scope) {
110 scope.CallMethodWithErrorCheck("AfterFind")
93 // afterQueryCallback will invoke `AfterFind` method after querying
94 func afterQueryCallback(scope *Scope) {
95 if !scope.HasError() {
96 scope.CallMethod("AfterFind")
97 }
11198 }
112
113 func init() {
114 DefaultCallback.Query().Register("gorm:query", Query)
115 DefaultCallback.Query().Register("gorm:after_query", AfterQuery)
116 DefaultCallback.Query().Register("gorm:preload", Preload)
117 }
0 package gorm
1
2 import (
3 "errors"
4 "fmt"
5 "reflect"
6 "strconv"
7 "strings"
8 )
9
10 // preloadCallback used to preload associations
11 func preloadCallback(scope *Scope) {
12 if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
13 return
14 }
15
16 if _, ok := scope.Get("gorm:auto_preload"); ok {
17 autoPreload(scope)
18 }
19
20 if scope.Search.preload == nil || scope.HasError() {
21 return
22 }
23
24 var (
25 preloadedMap = map[string]bool{}
26 fields = scope.Fields()
27 )
28
29 for _, preload := range scope.Search.preload {
30 var (
31 preloadFields = strings.Split(preload.schema, ".")
32 currentScope = scope
33 currentFields = fields
34 )
35
36 for idx, preloadField := range preloadFields {
37 var currentPreloadConditions []interface{}
38
39 if currentScope == nil {
40 continue
41 }
42
43 // if not preloaded
44 if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
45
46 // assign search conditions to last preload
47 if idx == len(preloadFields)-1 {
48 currentPreloadConditions = preload.conditions
49 }
50
51 for _, field := range currentFields {
52 if field.Name != preloadField || field.Relationship == nil {
53 continue
54 }
55
56 switch field.Relationship.Kind {
57 case "has_one":
58 currentScope.handleHasOnePreload(field, currentPreloadConditions)
59 case "has_many":
60 currentScope.handleHasManyPreload(field, currentPreloadConditions)
61 case "belongs_to":
62 currentScope.handleBelongsToPreload(field, currentPreloadConditions)
63 case "many_to_many":
64 currentScope.handleManyToManyPreload(field, currentPreloadConditions)
65 default:
66 scope.Err(errors.New("unsupported relation"))
67 }
68
69 preloadedMap[preloadKey] = true
70 break
71 }
72
73 if !preloadedMap[preloadKey] {
74 scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
75 return
76 }
77 }
78
79 // preload next level
80 if idx < len(preloadFields)-1 {
81 currentScope = currentScope.getColumnAsScope(preloadField)
82 if currentScope != nil {
83 currentFields = currentScope.Fields()
84 }
85 }
86 }
87 }
88 }
89
90 func autoPreload(scope *Scope) {
91 for _, field := range scope.Fields() {
92 if field.Relationship == nil {
93 continue
94 }
95
96 if val, ok := field.TagSettings["PRELOAD"]; ok {
97 if preload, err := strconv.ParseBool(val); err != nil {
98 scope.Err(errors.New("invalid preload option"))
99 return
100 } else if !preload {
101 continue
102 }
103 }
104
105 scope.Search.Preload(field.Name)
106 }
107 }
108
109 func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
110 var (
111 preloadDB = scope.NewDB()
112 preloadConditions []interface{}
113 )
114
115 for _, condition := range conditions {
116 if scopes, ok := condition.(func(*DB) *DB); ok {
117 preloadDB = scopes(preloadDB)
118 } else {
119 preloadConditions = append(preloadConditions, condition)
120 }
121 }
122
123 return preloadDB, preloadConditions
124 }
125
126 // handleHasOnePreload used to preload has one associations
127 func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
128 relation := field.Relationship
129
130 // get relations's primary keys
131 primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
132 if len(primaryKeys) == 0 {
133 return
134 }
135
136 // preload conditions
137 preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
138
139 // find relations
140 query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
141 values := toQueryValues(primaryKeys)
142 if relation.PolymorphicType != "" {
143 query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
144 values = append(values, relation.PolymorphicValue)
145 }
146
147 results := makeSlice(field.Struct.Type)
148 scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
149
150 // assign find results
151 var (
152 resultsValue = indirect(reflect.ValueOf(results))
153 indirectScopeValue = scope.IndirectValue()
154 )
155
156 if indirectScopeValue.Kind() == reflect.Slice {
157 for j := 0; j < indirectScopeValue.Len(); j++ {
158 for i := 0; i < resultsValue.Len(); i++ {
159 result := resultsValue.Index(i)
160 foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
161 if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
162 indirectValue.FieldByName(field.Name).Set(result)
163 break
164 }
165 }
166 }
167 } else {
168 for i := 0; i < resultsValue.Len(); i++ {
169 result := resultsValue.Index(i)
170 scope.Err(field.Set(result))
171 }
172 }
173 }
174
175 // handleHasManyPreload used to preload has many associations
176 func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
177 relation := field.Relationship
178
179 // get relations's primary keys
180 primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
181 if len(primaryKeys) == 0 {
182 return
183 }
184
185 // preload conditions
186 preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
187
188 // find relations
189 query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
190 values := toQueryValues(primaryKeys)
191 if relation.PolymorphicType != "" {
192 query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
193 values = append(values, relation.PolymorphicValue)
194 }
195
196 results := makeSlice(field.Struct.Type)
197 scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
198
199 // assign find results
200 var (
201 resultsValue = indirect(reflect.ValueOf(results))
202 indirectScopeValue = scope.IndirectValue()
203 )
204
205 if indirectScopeValue.Kind() == reflect.Slice {
206 preloadMap := make(map[string][]reflect.Value)
207 for i := 0; i < resultsValue.Len(); i++ {
208 result := resultsValue.Index(i)
209 foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
210 preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
211 }
212
213 for j := 0; j < indirectScopeValue.Len(); j++ {
214 object := indirect(indirectScopeValue.Index(j))
215 objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
216 f := object.FieldByName(field.Name)
217 if results, ok := preloadMap[toString(objectRealValue)]; ok {
218 f.Set(reflect.Append(f, results...))
219 } else {
220 f.Set(reflect.MakeSlice(f.Type(), 0, 0))
221 }
222 }
223 } else {
224 scope.Err(field.Set(resultsValue))
225 }
226 }
227
228 // handleBelongsToPreload used to preload belongs to associations
229 func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
230 relation := field.Relationship
231
232 // preload conditions
233 preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
234
235 // get relations's primary keys
236 primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
237 if len(primaryKeys) == 0 {
238 return
239 }
240
241 // find relations
242 results := makeSlice(field.Struct.Type)
243 scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
244
245 // assign find results
246 var (
247 resultsValue = indirect(reflect.ValueOf(results))
248 indirectScopeValue = scope.IndirectValue()
249 )
250
251 for i := 0; i < resultsValue.Len(); i++ {
252 result := resultsValue.Index(i)
253 if indirectScopeValue.Kind() == reflect.Slice {
254 value := getValueFromFields(result, relation.AssociationForeignFieldNames)
255 for j := 0; j < indirectScopeValue.Len(); j++ {
256 object := indirect(indirectScopeValue.Index(j))
257 if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
258 object.FieldByName(field.Name).Set(result)
259 }
260 }
261 } else {
262 scope.Err(field.Set(result))
263 }
264 }
265 }
266
267 // handleManyToManyPreload used to preload many to many associations
268 func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
269 var (
270 relation = field.Relationship
271 joinTableHandler = relation.JoinTableHandler
272 fieldType = field.Struct.Type.Elem()
273 foreignKeyValue interface{}
274 foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
275 linkHash = map[string][]reflect.Value{}
276 isPtr bool
277 )
278
279 if fieldType.Kind() == reflect.Ptr {
280 isPtr = true
281 fieldType = fieldType.Elem()
282 }
283
284 var sourceKeys = []string{}
285 for _, key := range joinTableHandler.SourceForeignKeys() {
286 sourceKeys = append(sourceKeys, key.DBName)
287 }
288
289 // preload conditions
290 preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
291
292 // generate query with join table
293 newScope := scope.New(reflect.New(fieldType).Interface())
294 preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
295
296 if len(preloadDB.search.selects) == 0 {
297 preloadDB = preloadDB.Select("*")
298 }
299
300 preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
301
302 // preload inline conditions
303 if len(preloadConditions) > 0 {
304 preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
305 }
306
307 rows, err := preloadDB.Rows()
308
309 if scope.Err(err) != nil {
310 return
311 }
312 defer rows.Close()
313
314 columns, _ := rows.Columns()
315 for rows.Next() {
316 var (
317 elem = reflect.New(fieldType).Elem()
318 fields = scope.New(elem.Addr().Interface()).Fields()
319 )
320
321 // register foreign keys in join tables
322 var joinTableFields []*Field
323 for _, sourceKey := range sourceKeys {
324 joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
325 }
326
327 scope.scan(rows, columns, append(fields, joinTableFields...))
328
329 scope.New(elem.Addr().Interface()).
330 InstanceSet("gorm:skip_query_callback", true).
331 callCallbacks(scope.db.parent.callbacks.queries)
332
333 var foreignKeys = make([]interface{}, len(sourceKeys))
334 // generate hashed forkey keys in join table
335 for idx, joinTableField := range joinTableFields {
336 if !joinTableField.Field.IsNil() {
337 foreignKeys[idx] = joinTableField.Field.Elem().Interface()
338 }
339 }
340 hashedSourceKeys := toString(foreignKeys)
341
342 if isPtr {
343 linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
344 } else {
345 linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
346 }
347 }
348
349 if err := rows.Err(); err != nil {
350 scope.Err(err)
351 }
352
353 // assign find results
354 var (
355 indirectScopeValue = scope.IndirectValue()
356 fieldsSourceMap = map[string][]reflect.Value{}
357 foreignFieldNames = []string{}
358 )
359
360 for _, dbName := range relation.ForeignFieldNames {
361 if field, ok := scope.FieldByName(dbName); ok {
362 foreignFieldNames = append(foreignFieldNames, field.Name)
363 }
364 }
365
366 if indirectScopeValue.Kind() == reflect.Slice {
367 for j := 0; j < indirectScopeValue.Len(); j++ {
368 object := indirect(indirectScopeValue.Index(j))
369 key := toString(getValueFromFields(object, foreignFieldNames))
370 fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
371 }
372 } else if indirectScopeValue.IsValid() {
373 key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
374 fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
375 }
376 for source, link := range linkHash {
377 for i, field := range fieldsSourceMap[source] {
378 //If not 0 this means Value is a pointer and we already added preloaded models to it
379 if fieldsSourceMap[source][i].Len() != 0 {
380 continue
381 }
382 field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
383 }
384
385 }
386 }
0 package gorm
1
2 import "database/sql"
3
4 // Define callbacks for row query
5 func init() {
6 DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback)
7 }
8
9 type RowQueryResult struct {
10 Row *sql.Row
11 }
12
13 type RowsQueryResult struct {
14 Rows *sql.Rows
15 Error error
16 }
17
18 // queryCallback used to query data from database
19 func rowQueryCallback(scope *Scope) {
20 if result, ok := scope.InstanceGet("row_query_result"); ok {
21 scope.prepareQuerySQL()
22
23 if rowResult, ok := result.(*RowQueryResult); ok {
24 rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
25 } else if rowsResult, ok := result.(*RowsQueryResult); ok {
26 rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
27 }
28 }
29 }
0 package gorm
1
2 import (
3 "reflect"
4 "strings"
5 )
6
7 func beginTransactionCallback(scope *Scope) {
8 scope.Begin()
9 }
10
11 func commitOrRollbackTransactionCallback(scope *Scope) {
12 scope.CommitOrRollback()
13 }
14
15 func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) {
16 checkTruth := func(value interface{}) bool {
17 if v, ok := value.(bool); ok && !v {
18 return false
19 }
20
21 if v, ok := value.(string); ok {
22 v = strings.ToLower(v)
23 if v == "false" || v != "skip" {
24 return false
25 }
26 }
27
28 return true
29 }
30
31 if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
32 if r = field.Relationship; r != nil {
33 autoUpdate, autoCreate, saveReference = true, true, true
34
35 if value, ok := scope.Get("gorm:save_associations"); ok {
36 autoUpdate = checkTruth(value)
37 autoCreate = autoUpdate
38 } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok {
39 autoUpdate = checkTruth(value)
40 autoCreate = autoUpdate
41 }
42
43 if value, ok := scope.Get("gorm:association_autoupdate"); ok {
44 autoUpdate = checkTruth(value)
45 } else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok {
46 autoUpdate = checkTruth(value)
47 }
48
49 if value, ok := scope.Get("gorm:association_autocreate"); ok {
50 autoCreate = checkTruth(value)
51 } else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok {
52 autoCreate = checkTruth(value)
53 }
54
55 if value, ok := scope.Get("gorm:association_save_reference"); ok {
56 saveReference = checkTruth(value)
57 } else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok {
58 saveReference = checkTruth(value)
59 }
60 }
61 }
62
63 return
64 }
65
66 func saveBeforeAssociationsCallback(scope *Scope) {
67 for _, field := range scope.Fields() {
68 autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
69
70 if relationship != nil && relationship.Kind == "belongs_to" {
71 fieldValue := field.Field.Addr().Interface()
72 newScope := scope.New(fieldValue)
73
74 if newScope.PrimaryKeyZero() {
75 if autoCreate {
76 scope.Err(scope.NewDB().Save(fieldValue).Error)
77 }
78 } else if autoUpdate {
79 scope.Err(scope.NewDB().Save(fieldValue).Error)
80 }
81
82 if saveReference {
83 if len(relationship.ForeignFieldNames) != 0 {
84 // set value's foreign key
85 for idx, fieldName := range relationship.ForeignFieldNames {
86 associationForeignName := relationship.AssociationForeignDBNames[idx]
87 if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok {
88 scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface()))
89 }
90 }
91 }
92 }
93 }
94 }
95 }
96
97 func saveAfterAssociationsCallback(scope *Scope) {
98 for _, field := range scope.Fields() {
99 autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
100
101 if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
102 value := field.Field
103
104 switch value.Kind() {
105 case reflect.Slice:
106 for i := 0; i < value.Len(); i++ {
107 newDB := scope.NewDB()
108 elem := value.Index(i).Addr().Interface()
109 newScope := newDB.NewScope(elem)
110
111 if saveReference {
112 if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
113 for idx, fieldName := range relationship.ForeignFieldNames {
114 associationForeignName := relationship.AssociationForeignDBNames[idx]
115 if f, ok := scope.FieldByName(associationForeignName); ok {
116 scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
117 }
118 }
119 }
120
121 if relationship.PolymorphicType != "" {
122 scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
123 }
124 }
125
126 if newScope.PrimaryKeyZero() {
127 if autoCreate {
128 scope.Err(newDB.Save(elem).Error)
129 }
130 } else if autoUpdate {
131 scope.Err(newDB.Save(elem).Error)
132 }
133
134 if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference {
135 if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
136 scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
137 }
138 }
139 }
140 default:
141 elem := value.Addr().Interface()
142 newScope := scope.New(elem)
143
144 if saveReference {
145 if len(relationship.ForeignFieldNames) != 0 {
146 for idx, fieldName := range relationship.ForeignFieldNames {
147 associationForeignName := relationship.AssociationForeignDBNames[idx]
148 if f, ok := scope.FieldByName(associationForeignName); ok {
149 scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
150 }
151 }
152 }
153
154 if relationship.PolymorphicType != "" {
155 scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
156 }
157 }
158
159 if newScope.PrimaryKeyZero() {
160 if autoCreate {
161 scope.Err(scope.NewDB().Save(elem).Error)
162 }
163 } else if autoUpdate {
164 scope.Err(scope.NewDB().Save(elem).Error)
165 }
166 }
167 }
168 }
169 }
+0
-91
callback_shared.go less more
0 package gorm
1
2 import "reflect"
3
4 func BeginTransaction(scope *Scope) {
5 scope.Begin()
6 }
7
8 func CommitOrRollbackTransaction(scope *Scope) {
9 scope.CommitOrRollback()
10 }
11
12 func SaveBeforeAssociations(scope *Scope) {
13 if !scope.shouldSaveAssociations() {
14 return
15 }
16 for _, field := range scope.Fields() {
17 if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
18 if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
19 value := field.Field
20 scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
21 if len(relationship.ForeignFieldNames) != 0 {
22 for idx, fieldName := range relationship.ForeignFieldNames {
23 associationForeignName := relationship.AssociationForeignDBNames[idx]
24 if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok {
25 scope.Err(scope.SetColumn(fieldName, f.Field.Interface()))
26 }
27 }
28 }
29 }
30 }
31 }
32 }
33
34 func SaveAfterAssociations(scope *Scope) {
35 if !scope.shouldSaveAssociations() {
36 return
37 }
38 for _, field := range scope.Fields() {
39 if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
40 if relationship := field.Relationship; relationship != nil &&
41 (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
42 value := field.Field
43
44 switch value.Kind() {
45 case reflect.Slice:
46 for i := 0; i < value.Len(); i++ {
47 newDB := scope.NewDB()
48 elem := value.Index(i).Addr().Interface()
49 newScope := newDB.NewScope(elem)
50
51 if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
52 for idx, fieldName := range relationship.ForeignFieldNames {
53 associationForeignName := relationship.AssociationForeignDBNames[idx]
54 if f, ok := scope.FieldByName(associationForeignName); ok {
55 scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
56 }
57 }
58 }
59
60 if relationship.PolymorphicType != "" {
61 scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName()))
62 }
63
64 scope.Err(newDB.Save(elem).Error)
65
66 if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
67 scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value))
68 }
69 }
70 default:
71 elem := value.Addr().Interface()
72 newScope := scope.New(elem)
73 if len(relationship.ForeignFieldNames) != 0 {
74 for idx, fieldName := range relationship.ForeignFieldNames {
75 associationForeignName := relationship.AssociationForeignDBNames[idx]
76 if f, ok := scope.FieldByName(associationForeignName); ok {
77 scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
78 }
79 }
80 }
81
82 if relationship.PolymorphicType != "" {
83 scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName()))
84 }
85 scope.Err(scope.NewDB().Save(elem).Error)
86 }
87 }
88 }
89 }
90 }
0 package gorm
1
2 import (
3 "reflect"
4 "runtime"
5 "strings"
6 "testing"
7 )
8
9 func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
10 var names []string
11 for _, f := range funcs {
12 fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
13 names = append(names, fnames[len(fnames)-1])
14 }
15 return reflect.DeepEqual(names, fnames)
16 }
17
18 func create(s *Scope) {}
19 func beforeCreate1(s *Scope) {}
20 func beforeCreate2(s *Scope) {}
21 func afterCreate1(s *Scope) {}
22 func afterCreate2(s *Scope) {}
23
24 func TestRegisterCallback(t *testing.T) {
25 var callback = &Callback{}
26
27 callback.Create().Register("before_create1", beforeCreate1)
28 callback.Create().Register("before_create2", beforeCreate2)
29 callback.Create().Register("create", create)
30 callback.Create().Register("after_create1", afterCreate1)
31 callback.Create().Register("after_create2", afterCreate2)
32
33 if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
34 t.Errorf("register callback")
35 }
36 }
37
38 func TestRegisterCallbackWithOrder(t *testing.T) {
39 var callback1 = &Callback{}
40 callback1.Create().Register("before_create1", beforeCreate1)
41 callback1.Create().Register("create", create)
42 callback1.Create().Register("after_create1", afterCreate1)
43 callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
44 if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
45 t.Errorf("register callback with order")
46 }
47
48 var callback2 = &Callback{}
49
50 callback2.Update().Register("create", create)
51 callback2.Update().Before("create").Register("before_create1", beforeCreate1)
52 callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
53 callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
54 callback2.Update().Register("after_create2", afterCreate2)
55
56 if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
57 t.Errorf("register callback with order")
58 }
59 }
60
61 func TestRegisterCallbackWithComplexOrder(t *testing.T) {
62 var callback1 = &Callback{}
63
64 callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
65 callback1.Query().Register("before_create1", beforeCreate1)
66 callback1.Query().Register("after_create1", afterCreate1)
67
68 if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
69 t.Errorf("register callback with order")
70 }
71
72 var callback2 = &Callback{}
73
74 callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
75 callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
76 callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
77 callback2.Delete().Register("after_create1", afterCreate1)
78 callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
79
80 if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
81 t.Errorf("register callback with order")
82 }
83 }
84
85 func replaceCreate(s *Scope) {}
86
87 func TestReplaceCallback(t *testing.T) {
88 var callback = &Callback{}
89
90 callback.Create().Before("after_create1").After("before_create1").Register("create", create)
91 callback.Create().Register("before_create1", beforeCreate1)
92 callback.Create().Register("after_create1", afterCreate1)
93 callback.Create().Replace("create", replaceCreate)
94
95 if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
96 t.Errorf("replace callback")
97 }
98 }
99
100 func TestRemoveCallback(t *testing.T) {
101 var callback = &Callback{}
102
103 callback.Create().Before("after_create1").After("before_create1").Register("create", create)
104 callback.Create().Register("before_create1", beforeCreate1)
105 callback.Create().Register("after_create1", afterCreate1)
106 callback.Create().Remove("create")
107
108 if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
109 t.Errorf("remove callback")
110 }
111 }
+0
-112
callback_test.go less more
0 package gorm
1
2 import (
3 "reflect"
4 "runtime"
5 "strings"
6 "testing"
7 )
8
9 func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
10 var names []string
11 for _, f := range funcs {
12 fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
13 names = append(names, fnames[len(fnames)-1])
14 }
15 return reflect.DeepEqual(names, fnames)
16 }
17
18 func create(s *Scope) {}
19 func beforeCreate1(s *Scope) {}
20 func beforeCreate2(s *Scope) {}
21 func afterCreate1(s *Scope) {}
22 func afterCreate2(s *Scope) {}
23
24 func TestRegisterCallback(t *testing.T) {
25 var callback = &callback{processors: []*callbackProcessor{}}
26
27 callback.Create().Register("before_create1", beforeCreate1)
28 callback.Create().Register("before_create2", beforeCreate2)
29 callback.Create().Register("create", create)
30 callback.Create().Register("after_create1", afterCreate1)
31 callback.Create().Register("after_create2", afterCreate2)
32
33 if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
34 t.Errorf("register callback")
35 }
36 }
37
38 func TestRegisterCallbackWithOrder(t *testing.T) {
39 var callback1 = &callback{processors: []*callbackProcessor{}}
40 callback1.Create().Register("before_create1", beforeCreate1)
41 callback1.Create().Register("create", create)
42 callback1.Create().Register("after_create1", afterCreate1)
43 callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
44 if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
45 t.Errorf("register callback with order")
46 }
47
48 var callback2 = &callback{processors: []*callbackProcessor{}}
49
50 callback2.Update().Register("create", create)
51 callback2.Update().Before("create").Register("before_create1", beforeCreate1)
52 callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
53 callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
54 callback2.Update().Register("after_create2", afterCreate2)
55
56 if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
57 t.Errorf("register callback with order")
58 }
59 }
60
61 func TestRegisterCallbackWithComplexOrder(t *testing.T) {
62 var callback1 = &callback{processors: []*callbackProcessor{}}
63
64 callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
65 callback1.Query().Register("before_create1", beforeCreate1)
66 callback1.Query().Register("after_create1", afterCreate1)
67
68 if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
69 t.Errorf("register callback with order")
70 }
71
72 var callback2 = &callback{processors: []*callbackProcessor{}}
73
74 callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
75 callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
76 callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
77 callback2.Delete().Register("after_create1", afterCreate1)
78 callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
79
80 if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
81 t.Errorf("register callback with order")
82 }
83 }
84
85 func replaceCreate(s *Scope) {}
86
87 func TestReplaceCallback(t *testing.T) {
88 var callback = &callback{processors: []*callbackProcessor{}}
89
90 callback.Create().Before("after_create1").After("before_create1").Register("create", create)
91 callback.Create().Register("before_create1", beforeCreate1)
92 callback.Create().Register("after_create1", afterCreate1)
93 callback.Create().Replace("create", replaceCreate)
94
95 if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
96 t.Errorf("replace callback")
97 }
98 }
99
100 func TestRemoveCallback(t *testing.T) {
101 var callback = &callback{processors: []*callbackProcessor{}}
102
103 callback.Create().Before("after_create1").After("before_create1").Register("create", create)
104 callback.Create().Register("before_create1", beforeCreate1)
105 callback.Create().Register("after_create1", afterCreate1)
106 callback.Create().Remove("create")
107
108 if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
109 t.Errorf("remove callback")
110 }
111 }
00 package gorm
11
22 import (
3 "errors"
34 "fmt"
5 "sort"
46 "strings"
57 )
68
7 func AssignUpdateAttributes(scope *Scope) {
9 // Define callbacks for updating
10 func init() {
11 DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
12 DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
13 DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
14 DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
15 DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
16 DefaultCallback.Update().Register("gorm:update", updateCallback)
17 DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
18 DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
19 DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
20 }
21
22 // assignUpdatingAttributesCallback assign updating attributes to model
23 func assignUpdatingAttributesCallback(scope *Scope) {
824 if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
9 if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
10 protected, ok := scope.Get("gorm:ignore_protected_attrs")
11 _, updateColumn := scope.Get("gorm:update_column")
12 updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool))
13
14 if updateColumn {
15 scope.InstanceSet("gorm:update_attrs", maps)
16 } else if len(updateAttrs) > 0 {
17 scope.InstanceSet("gorm:update_attrs", updateAttrs)
18 } else if !hasUpdate {
19 scope.SkipLeft()
20 return
21 }
25 if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
26 scope.InstanceSet("gorm:update_attrs", updateMaps)
27 } else {
28 scope.SkipLeft()
2229 }
2330 }
2431 }
2532
26 func BeforeUpdate(scope *Scope) {
33 // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
34 func beforeUpdateCallback(scope *Scope) {
35 if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
36 scope.Err(errors.New("Missing WHERE clause while updating"))
37 return
38 }
2739 if _, ok := scope.Get("gorm:update_column"); !ok {
28 scope.CallMethodWithErrorCheck("BeforeSave")
29 scope.CallMethodWithErrorCheck("BeforeUpdate")
40 if !scope.HasError() {
41 scope.CallMethod("BeforeSave")
42 }
43 if !scope.HasError() {
44 scope.CallMethod("BeforeUpdate")
45 }
3046 }
3147 }
3248
33 func UpdateTimeStampWhenUpdate(scope *Scope) {
49 // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
50 func updateTimeStampForUpdateCallback(scope *Scope) {
3451 if _, ok := scope.Get("gorm:update_column"); !ok {
3552 scope.SetColumn("UpdatedAt", NowFunc())
3653 }
3754 }
3855
39 func Update(scope *Scope) {
56 // updateCallback the callback used to update data to database
57 func updateCallback(scope *Scope) {
4058 if !scope.HasError() {
4159 var sqls []string
4260
4361 if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
44 for key, value := range updateAttrs.(map[string]interface{}) {
45 if scope.changeableDBColumn(key) {
46 sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
47 }
62 // Sort the column names so that the generated SQL is the same every time.
63 updateMap := updateAttrs.(map[string]interface{})
64 var columns []string
65 for c := range updateMap {
66 columns = append(columns, c)
67 }
68 sort.Strings(columns)
69
70 for _, column := range columns {
71 value := updateMap[column]
72 sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
4873 }
4974 } else {
50 fields := scope.Fields()
51 for _, field := range fields {
52 if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal {
53 sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
54 } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
55 for _, dbName := range relationship.ForeignDBNames {
56 if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank {
57 sql := fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface()))
58 sqls = append(sqls, sql)
75 for _, field := range scope.Fields() {
76 if scope.changeableField(field) {
77 if !field.IsPrimaryKey && field.IsNormal {
78 sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
79 } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
80 for _, foreignKey := range relationship.ForeignDBNames {
81 if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
82 sqls = append(sqls,
83 fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
84 }
5985 }
6086 }
6187 }
6288 }
6389 }
6490
91 var extraOption string
92 if str, ok := scope.Get("gorm:update_option"); ok {
93 extraOption = fmt.Sprint(str)
94 }
95
6596 if len(sqls) > 0 {
6697 scope.Raw(fmt.Sprintf(
67 "UPDATE %v SET %v %v",
98 "UPDATE %v SET %v%v%v",
6899 scope.QuotedTableName(),
69100 strings.Join(sqls, ", "),
70 scope.CombinedConditionSql(),
71 ))
72 scope.Exec()
101 addExtraSpaceIfExist(scope.CombinedConditionSql()),
102 addExtraSpaceIfExist(extraOption),
103 )).Exec()
73104 }
74105 }
75106 }
76107
77 func AfterUpdate(scope *Scope) {
108 // afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
109 func afterUpdateCallback(scope *Scope) {
78110 if _, ok := scope.Get("gorm:update_column"); !ok {
79 scope.CallMethodWithErrorCheck("AfterUpdate")
80 scope.CallMethodWithErrorCheck("AfterSave")
111 if !scope.HasError() {
112 scope.CallMethod("AfterUpdate")
113 }
114 if !scope.HasError() {
115 scope.CallMethod("AfterSave")
116 }
81117 }
82118 }
83
84 func init() {
85 DefaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes)
86 DefaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction)
87 DefaultCallback.Update().Register("gorm:before_update", BeforeUpdate)
88 DefaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations)
89 DefaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate)
90 DefaultCallback.Update().Register("gorm:update", Update)
91 DefaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations)
92 DefaultCallback.Update().Register("gorm:after_update", AfterUpdate)
93 DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
94 }
+0
-117
common_dialect.go less more
0 package gorm
1
2 import (
3 "fmt"
4 "reflect"
5 "time"
6 )
7
8 type commonDialect struct{}
9
10 func (commonDialect) BinVar(i int) string {
11 return "$$" // ?
12 }
13
14 func (commonDialect) SupportLastInsertId() bool {
15 return true
16 }
17
18 func (commonDialect) HasTop() bool {
19 return false
20 }
21
22 func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
23 switch value.Kind() {
24 case reflect.Bool:
25 return "BOOLEAN"
26 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
27 if autoIncrease {
28 return "INTEGER AUTO_INCREMENT"
29 }
30 return "INTEGER"
31 case reflect.Int64, reflect.Uint64:
32 if autoIncrease {
33 return "BIGINT AUTO_INCREMENT"
34 }
35 return "BIGINT"
36 case reflect.Float32, reflect.Float64:
37 return "FLOAT"
38 case reflect.String:
39 if size > 0 && size < 65532 {
40 return fmt.Sprintf("VARCHAR(%d)", size)
41 }
42 return "VARCHAR(65532)"
43 case reflect.Struct:
44 if _, ok := value.Interface().(time.Time); ok {
45 return "TIMESTAMP"
46 }
47 default:
48 if _, ok := value.Interface().([]byte); ok {
49 if size > 0 && size < 65532 {
50 return fmt.Sprintf("BINARY(%d)", size)
51 }
52 return "BINARY(65532)"
53 }
54 }
55 panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
56 }
57
58 func (commonDialect) ReturningStr(tableName, key string) string {
59 return ""
60 }
61
62 func (commonDialect) SelectFromDummyTable() string {
63 return ""
64 }
65
66 func (commonDialect) Quote(key string) string {
67 return fmt.Sprintf(`"%s"`, key)
68 }
69
70 func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
71 var (
72 count int
73 databaseName = c.CurrentDatabase(scope)
74 )
75 c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", databaseName, tableName)
76 return count > 0
77 }
78
79 func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
80 var (
81 count int
82 databaseName = c.CurrentDatabase(scope)
83 )
84 c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
85 return count > 0
86 }
87
88 func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
89 var (
90 count int
91 databaseName = c.CurrentDatabase(scope)
92 )
93 c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
94 return count > 0
95 }
96
97 func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
98 scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error)
99 }
100
101 // RawScanInt scans the first column of the first row into the `scan' int pointer.
102 // This function captures raw query errors and propagates them to the original scope.
103 func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) {
104 scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
105 }
106
107 // RawScanString scans the first column of the first row into the `scan' string pointer.
108 // This function captures raw query errors and propagates them to the original scope.
109 func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string, args ...interface{}) {
110 scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
111 }
112
113 func (commonDialect) CurrentDatabase(scope *Scope) (name string) {
114 scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
115 return
116 }
00 package gorm_test
11
22 import (
3 "os"
34 "reflect"
45 "testing"
56 "time"
7
8 "github.com/jinzhu/now"
69 )
710
811 func TestCreate(t *testing.T) {
912 float := 35.03554004971999
10 user := User{Name: "CreateUser", Age: 18, Birthday: time.Now(), UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float}
13 now := time.Now()
14 user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float}
1115
1216 if !DB.NewRecord(user) || !DB.NewRecord(&user) {
1317 t.Error("User should be new record before create")
2226 }
2327
2428 var newUser User
25 DB.First(&newUser, user.Id)
29 if err := DB.First(&newUser, user.Id).Error; err != nil {
30 t.Errorf("No error should happen, but got %v", err)
31 }
2632
2733 if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) {
2834 t.Errorf("User's PasswordHash should be saved ([]byte)")
3339 }
3440
3541 if newUser.UserNum != Num(111) {
36 t.Errorf("User's UserNum should be saved (custom type)")
42 t.Errorf("User's UserNum should be saved (custom type), but got %v", newUser.UserNum)
3743 }
3844
3945 if newUser.Latitude != float {
5056
5157 DB.Model(user).Update("name", "create_user_new_name")
5258 DB.First(&user, user.Id)
53 if user.CreatedAt != newUser.CreatedAt {
59 if user.CreatedAt.Format(time.RFC3339Nano) != newUser.CreatedAt.Format(time.RFC3339Nano) {
5460 t.Errorf("CreatedAt should not be changed after update")
5561 }
5662 }
5763
64 func TestCreateEmptyStrut(t *testing.T) {
65 type EmptyStruct struct {
66 ID uint
67 }
68 DB.AutoMigrate(&EmptyStruct{})
69
70 if err := DB.Create(&EmptyStruct{}).Error; err != nil {
71 t.Errorf("No error should happen when creating user, but got %v", err)
72 }
73 }
74
75 func TestCreateWithExistingTimestamp(t *testing.T) {
76 user := User{Name: "CreateUserExistingTimestamp"}
77
78 timeA := now.MustParse("2016-01-01")
79 user.CreatedAt = timeA
80 user.UpdatedAt = timeA
81 DB.Save(&user)
82
83 if user.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
84 t.Errorf("CreatedAt should not be changed")
85 }
86
87 if user.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
88 t.Errorf("UpdatedAt should not be changed")
89 }
90
91 var newUser User
92 DB.First(&newUser, user.Id)
93
94 if newUser.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
95 t.Errorf("CreatedAt should not be changed")
96 }
97
98 if newUser.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
99 t.Errorf("UpdatedAt should not be changed")
100 }
101 }
102
103 type AutoIncrementUser struct {
104 User
105 Sequence uint `gorm:"AUTO_INCREMENT"`
106 }
107
108 func TestCreateWithAutoIncrement(t *testing.T) {
109 if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
110 t.Skip("Skipping this because only postgres properly support auto_increment on a non-primary_key column")
111 }
112
113 DB.AutoMigrate(&AutoIncrementUser{})
114
115 user1 := AutoIncrementUser{}
116 user2 := AutoIncrementUser{}
117
118 DB.Create(&user1)
119 DB.Create(&user2)
120
121 if user2.Sequence-user1.Sequence != 1 {
122 t.Errorf("Auto increment should apply on Sequence")
123 }
124 }
125
58126 func TestCreateWithNoGORMPrimayKey(t *testing.T) {
127 if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" {
128 t.Skip("Skipping this because MSSQL will return identity only if the table has an Id column")
129 }
130
59131 jt := JoinTable{From: 1, To: 2}
60132 err := DB.Create(&jt).Error
61133 if err != nil {
86158
87159 // We must fetch the value again, to have the default fields updated
88160 // (We can't do this in the update statements, since sql default can be expressions
89 // And be different from the fields' type (eg. a time.Time fiels has a default value of "now()"
161 // And be different from the fields' type (eg. a time.Time fields has a default value of "now()"
90162 DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an)
91163
92164 if an.Name != "galeone" {
104176 t.Errorf("Should be able to get anonymous scanner")
105177 }
106178
107 if !user2.IsAdmin() {
179 if !user2.Role.IsAdmin() {
108180 t.Errorf("Should be able to get anonymous scanner")
109181 }
110182 }
153225
154226 if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 ||
155227 queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 {
156 t.Errorf("Should not create omited relationships")
157 }
158 }
228 t.Errorf("Should not create omitted relationships")
229 }
230 }
22 import (
33 "testing"
44 "time"
5
6 "github.com/jinzhu/gorm"
57 )
68
79 type CustomizeColumn struct {
8 ID int64 `gorm:"column:mapped_id; primary_key:yes"`
9 Name string `gorm:"column:mapped_name"`
10 Date time.Time `gorm:"column:mapped_time"`
10 ID int64 `gorm:"column:mapped_id; primary_key:yes"`
11 Name string `gorm:"column:mapped_name"`
12 Date *time.Time `gorm:"column:mapped_time"`
1113 }
1214
1315 // Make sure an ignored field does not interfere with another field's custom
2325 DB.AutoMigrate(&CustomizeColumn{})
2426
2527 scope := DB.NewScope(&CustomizeColumn{})
26 if !scope.Dialect().HasColumn(scope, scope.TableName(), col) {
28 if !scope.Dialect().HasColumn(scope.TableName(), col) {
2729 t.Errorf("CustomizeColumn should have column %s", col)
2830 }
2931
3335 }
3436
3537 expected := "foo"
36 cc := CustomizeColumn{ID: 666, Name: expected, Date: time.Now()}
38 now := time.Now()
39 cc := CustomizeColumn{ID: 666, Name: expected, Date: &now}
3740
3841 if count := DB.Create(&cc).RowsAffected; count != 1 {
3942 t.Error("There should be one record be affected when create record")
6265 t.Errorf("Should not raise error: %s", err)
6366 }
6467 }
68
69 type CustomizePerson struct {
70 IdPerson string `gorm:"column:idPerson;primary_key:true"`
71 Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"`
72 }
73
74 type CustomizeAccount struct {
75 IdAccount string `gorm:"column:idAccount;primary_key:true"`
76 Name string
77 }
78
79 func TestManyToManyWithCustomizedColumn(t *testing.T) {
80 DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount")
81 DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{})
82
83 account := CustomizeAccount{IdAccount: "account", Name: "id1"}
84 person := CustomizePerson{
85 IdPerson: "person",
86 Accounts: []CustomizeAccount{account},
87 }
88
89 if err := DB.Create(&account).Error; err != nil {
90 t.Errorf("no error should happen, but got %v", err)
91 }
92
93 if err := DB.Create(&person).Error; err != nil {
94 t.Errorf("no error should happen, but got %v", err)
95 }
96
97 var person1 CustomizePerson
98 scope := DB.NewScope(nil)
99 if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil {
100 t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err)
101 }
102
103 if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" {
104 t.Errorf("should preload correct accounts")
105 }
106 }
107
108 type CustomizeUser struct {
109 gorm.Model
110 Email string `sql:"column:email_address"`
111 }
112
113 type CustomizeInvitation struct {
114 gorm.Model
115 Address string `sql:"column:invitation"`
116 Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"`
117 }
118
119 func TestOneToOneWithCustomizedColumn(t *testing.T) {
120 DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{})
121 DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{})
122
123 user := CustomizeUser{
124 Email: "hello@example.com",
125 }
126 invitation := CustomizeInvitation{
127 Address: "hello@example.com",
128 }
129
130 DB.Create(&user)
131 DB.Create(&invitation)
132
133 var invitation2 CustomizeInvitation
134 if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil {
135 t.Errorf("no error should happen, but got %v", err)
136 }
137
138 if invitation2.Person.Email != user.Email {
139 t.Errorf("Should preload one to one relation with customize foreign keys")
140 }
141 }
142
143 type PromotionDiscount struct {
144 gorm.Model
145 Name string
146 Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"`
147 Rule *PromotionRule `gorm:"ForeignKey:discount_id"`
148 Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"`
149 }
150
151 type PromotionBenefit struct {
152 gorm.Model
153 Name string
154 PromotionID uint
155 Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"`
156 }
157
158 type PromotionCoupon struct {
159 gorm.Model
160 Code string
161 DiscountID uint
162 Discount PromotionDiscount
163 }
164
165 type PromotionRule struct {
166 gorm.Model
167 Name string
168 Begin *time.Time
169 End *time.Time
170 DiscountID uint
171 Discount *PromotionDiscount
172 }
173
174 func TestOneToManyWithCustomizedColumn(t *testing.T) {
175 DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{})
176 DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{})
177
178 discount := PromotionDiscount{
179 Name: "Happy New Year",
180 Coupons: []*PromotionCoupon{
181 {Code: "newyear1"},
182 {Code: "newyear2"},
183 },
184 }
185
186 if err := DB.Create(&discount).Error; err != nil {
187 t.Errorf("no error should happen but got %v", err)
188 }
189
190 var discount1 PromotionDiscount
191 if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil {
192 t.Errorf("no error should happen but got %v", err)
193 }
194
195 if len(discount.Coupons) != 2 {
196 t.Errorf("should find two coupons")
197 }
198
199 var coupon PromotionCoupon
200 if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil {
201 t.Errorf("no error should happen but got %v", err)
202 }
203
204 if coupon.Discount.Name != "Happy New Year" {
205 t.Errorf("should preload discount from coupon")
206 }
207 }
208
209 func TestHasOneWithPartialCustomizedColumn(t *testing.T) {
210 DB.DropTable(&PromotionDiscount{}, &PromotionRule{})
211 DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{})
212
213 var begin = time.Now()
214 var end = time.Now().Add(24 * time.Hour)
215 discount := PromotionDiscount{
216 Name: "Happy New Year 2",
217 Rule: &PromotionRule{
218 Name: "time_limited",
219 Begin: &begin,
220 End: &end,
221 },
222 }
223
224 if err := DB.Create(&discount).Error; err != nil {
225 t.Errorf("no error should happen but got %v", err)
226 }
227
228 var discount1 PromotionDiscount
229 if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil {
230 t.Errorf("no error should happen but got %v", err)
231 }
232
233 if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) {
234 t.Errorf("Should be able to preload Rule")
235 }
236
237 var rule PromotionRule
238 if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil {
239 t.Errorf("no error should happen but got %v", err)
240 }
241
242 if rule.Discount.Name != "Happy New Year 2" {
243 t.Errorf("should preload discount from rule")
244 }
245 }
246
247 func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
248 DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{})
249 DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{})
250
251 discount := PromotionDiscount{
252 Name: "Happy New Year 3",
253 Benefits: []PromotionBenefit{
254 {Name: "free cod"},
255 {Name: "free shipping"},
256 },
257 }
258
259 if err := DB.Create(&discount).Error; err != nil {
260 t.Errorf("no error should happen but got %v", err)
261 }
262
263 var discount1 PromotionDiscount
264 if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil {
265 t.Errorf("no error should happen but got %v", err)
266 }
267
268 if len(discount.Benefits) != 2 {
269 t.Errorf("should find two benefits")
270 }
271
272 var benefit PromotionBenefit
273 if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil {
274 t.Errorf("no error should happen but got %v", err)
275 }
276
277 if benefit.Discount.Name != "Happy New Year 3" {
278 t.Errorf("should preload discount from coupon")
279 }
280 }
281
282 type SelfReferencingUser struct {
283 gorm.Model
284 Name string
285 Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"`
286 }
287
288 func TestSelfReferencingMany2ManyColumn(t *testing.T) {
289 DB.DropTable(&SelfReferencingUser{}, "UserFriends")
290 DB.AutoMigrate(&SelfReferencingUser{})
291
292 friend1 := SelfReferencingUser{Name: "friend1_m2m"}
293 if err := DB.Create(&friend1).Error; err != nil {
294 t.Errorf("no error should happen, but got %v", err)
295 }
296
297 friend2 := SelfReferencingUser{Name: "friend2_m2m"}
298 if err := DB.Create(&friend2).Error; err != nil {
299 t.Errorf("no error should happen, but got %v", err)
300 }
301
302 user := SelfReferencingUser{
303 Name: "self_m2m",
304 Friends: []*SelfReferencingUser{&friend1, &friend2},
305 }
306
307 if err := DB.Create(&user).Error; err != nil {
308 t.Errorf("no error should happen, but got %v", err)
309 }
310
311 if DB.Model(&user).Association("Friends").Count() != 2 {
312 t.Errorf("Should find created friends correctly")
313 }
314
315 var newUser = SelfReferencingUser{}
316
317 if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil {
318 t.Errorf("no error should happen, but got %v", err)
319 }
320
321 if len(newUser.Friends) != 2 {
322 t.Errorf("Should preload created frineds for self reference m2m")
323 }
324
325 DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"})
326 if DB.Model(&user).Association("Friends").Count() != 3 {
327 t.Errorf("Should find created friends correctly")
328 }
329
330 DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"})
331 if DB.Model(&user).Association("Friends").Count() != 1 {
332 t.Errorf("Should find created friends correctly")
333 }
334
335 friend := SelfReferencingUser{}
336 DB.Model(&newUser).Association("Friends").Find(&friend)
337 if friend.Name != "friend4_m2m" {
338 t.Errorf("Should find created friends correctly")
339 }
340
341 DB.Model(&newUser).Association("Friends").Delete(friend)
342 if DB.Model(&user).Association("Friends").Count() != 0 {
343 t.Errorf("All friends should be deleted")
344 }
345 }
+0
-24
ddl_errors_test.go less more
0 package gorm_test
1
2 import (
3 "testing"
4 )
5
6 func TestDdlErrors(t *testing.T) {
7 var err error
8
9 if err = DB.Close(); err != nil {
10 t.Errorf("Closing DDL test db connection err=%s", err)
11 }
12 defer func() {
13 // Reopen DB connection.
14 if DB, err = OpenTestConnection(); err != nil {
15 t.Fatalf("Failed re-opening db connection: %s", err)
16 }
17 }()
18
19 DB.HasTable("foobarbaz")
20 if DB.Error == nil {
21 t.Errorf("Expected operation on closed db to produce an error, but err was nil")
22 }
23 }
4444 type User struct {
4545 Id int64
4646 Name string
47 DeletedAt time.Time
47 DeletedAt *time.Time
4848 }
4949 DB.AutoMigrate(&User{})
5050
6565 t.Errorf("Can't find permanently deleted record")
6666 }
6767 }
68
69 func TestSoftDeleteWithCustomizedDeletedAtColumnName(t *testing.T) {
70 creditCard := CreditCard{Number: "411111111234567"}
71 DB.Save(&creditCard)
72 DB.Delete(&creditCard)
73
74 if deletedAtField, ok := DB.NewScope(&CreditCard{}).FieldByName("DeletedAt"); !ok || deletedAtField.DBName != "deleted_time" {
75 t.Errorf("CreditCard's DeletedAt's column name should be `deleted_time`")
76 }
77
78 if DB.First(&CreditCard{}, "number = ?", creditCard.Number).Error == nil {
79 t.Errorf("Can't find a soft deleted record")
80 }
81
82 if err := DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).Error; err != nil {
83 t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err)
84 }
85
86 DB.Unscoped().Delete(&creditCard)
87 if !DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).RecordNotFound() {
88 t.Errorf("Can't find permanently deleted record")
89 }
90 }
00 package gorm
11
22 import (
3 "database/sql"
34 "fmt"
45 "reflect"
6 "strconv"
7 "strings"
58 )
69
10 // Dialect interface contains behaviors that differ across SQL database
711 type Dialect interface {
8 BinVar(i int) string
9 SupportLastInsertId() bool
10 HasTop() bool
11 SqlTag(value reflect.Value, size int, autoIncrease bool) string
12 ReturningStr(tableName, key string) string
12 // GetName get dialect's name
13 GetName() string
14
15 // SetDB set db for dialect
16 SetDB(db SQLCommon)
17
18 // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
19 BindVar(i int) string
20 // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
21 Quote(key string) string
22 // DataTypeOf return data's sql type
23 DataTypeOf(field *StructField) string
24
25 // HasIndex check has index or not
26 HasIndex(tableName string, indexName string) bool
27 // HasForeignKey check has foreign key or not
28 HasForeignKey(tableName string, foreignKeyName string) bool
29 // RemoveIndex remove index
30 RemoveIndex(tableName string, indexName string) error
31 // HasTable check has table or not
32 HasTable(tableName string) bool
33 // HasColumn check has column or not
34 HasColumn(tableName string, columnName string) bool
35 // ModifyColumn modify column's type
36 ModifyColumn(tableName string, columnName string, typ string) error
37
38 // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
39 LimitAndOffsetSQL(limit, offset interface{}) string
40 // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
1341 SelectFromDummyTable() string
14 Quote(key string) string
15 HasTable(scope *Scope, tableName string) bool
16 HasColumn(scope *Scope, tableName string, columnName string) bool
17 HasIndex(scope *Scope, tableName string, indexName string) bool
18 RemoveIndex(scope *Scope, indexName string)
19 CurrentDatabase(scope *Scope) string
42 // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
43 LastInsertIDReturningSuffix(tableName, columnName string) string
44 // DefaultValueStr
45 DefaultValueStr() string
46
47 // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
48 BuildKeyName(kind, tableName string, fields ...string) string
49
50 // CurrentDatabase return current database name
51 CurrentDatabase() string
2052 }
2153
22 func NewDialect(driver string) Dialect {
23 var d Dialect
24 switch driver {
25 case "postgres":
26 d = &postgres{}
27 case "foundation":
28 d = &foundation{}
29 case "mysql":
30 d = &mysql{}
31 case "sqlite3":
32 d = &sqlite3{}
33 case "mssql":
34 d = &mssql{}
35 default:
36 fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver)
37 d = &commonDialect{}
54 var dialectsMap = map[string]Dialect{}
55
56 func newDialect(name string, db SQLCommon) Dialect {
57 if value, ok := dialectsMap[name]; ok {
58 dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
59 dialect.SetDB(db)
60 return dialect
3861 }
39 return d
62
63 fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
64 commontDialect := &commonDialect{}
65 commontDialect.SetDB(db)
66 return commontDialect
4067 }
68
69 // RegisterDialect register new dialect
70 func RegisterDialect(name string, dialect Dialect) {
71 dialectsMap[name] = dialect
72 }
73
74 // ParseFieldStructForDialect get field's sql data type
75 var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
76 // Get redirected field type
77 var (
78 reflectType = field.Struct.Type
79 dataType = field.TagSettings["TYPE"]
80 )
81
82 for reflectType.Kind() == reflect.Ptr {
83 reflectType = reflectType.Elem()
84 }
85
86 // Get redirected field value
87 fieldValue = reflect.Indirect(reflect.New(reflectType))
88
89 if gormDataType, ok := fieldValue.Interface().(interface {
90 GormDataType(Dialect) string
91 }); ok {
92 dataType = gormDataType.GormDataType(dialect)
93 }
94
95 // Get scanner's real value
96 if dataType == "" {
97 var getScannerValue func(reflect.Value)
98 getScannerValue = func(value reflect.Value) {
99 fieldValue = value
100 if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct {
101 getScannerValue(fieldValue.Field(0))
102 }
103 }
104 getScannerValue(fieldValue)
105 }
106
107 // Default Size
108 if num, ok := field.TagSettings["SIZE"]; ok {
109 size, _ = strconv.Atoi(num)
110 } else {
111 size = 255
112 }
113
114 // Default type from tag setting
115 additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
116 if value, ok := field.TagSettings["DEFAULT"]; ok {
117 additionalType = additionalType + " DEFAULT " + value
118 }
119
120 return fieldValue, dataType, size, strings.TrimSpace(additionalType)
121 }
122
123 func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) {
124 if strings.Contains(tableName, ".") {
125 splitStrings := strings.SplitN(tableName, ".", 2)
126 return splitStrings[0], splitStrings[1]
127 }
128 return dialect.CurrentDatabase(), tableName
129 }
0 package gorm
1
2 import (
3 "fmt"
4 "reflect"
5 "regexp"
6 "strconv"
7 "strings"
8 "time"
9 )
10
11 // DefaultForeignKeyNamer contains the default foreign key name generator method
12 type DefaultForeignKeyNamer struct {
13 }
14
15 type commonDialect struct {
16 db SQLCommon
17 DefaultForeignKeyNamer
18 }
19
20 func init() {
21 RegisterDialect("common", &commonDialect{})
22 }
23
24 func (commonDialect) GetName() string {
25 return "common"
26 }
27
28 func (s *commonDialect) SetDB(db SQLCommon) {
29 s.db = db
30 }
31
32 func (commonDialect) BindVar(i int) string {
33 return "$$$" // ?
34 }
35
36 func (commonDialect) Quote(key string) string {
37 return fmt.Sprintf(`"%s"`, key)
38 }
39
40 func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool {
41 if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
42 return strings.ToLower(value) != "false"
43 }
44 return field.IsPrimaryKey
45 }
46
47 func (s *commonDialect) DataTypeOf(field *StructField) string {
48 var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
49
50 if sqlType == "" {
51 switch dataValue.Kind() {
52 case reflect.Bool:
53 sqlType = "BOOLEAN"
54 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
55 if s.fieldCanAutoIncrement(field) {
56 sqlType = "INTEGER AUTO_INCREMENT"
57 } else {
58 sqlType = "INTEGER"
59 }
60 case reflect.Int64, reflect.Uint64:
61 if s.fieldCanAutoIncrement(field) {
62 sqlType = "BIGINT AUTO_INCREMENT"
63 } else {
64 sqlType = "BIGINT"
65 }
66 case reflect.Float32, reflect.Float64:
67 sqlType = "FLOAT"
68 case reflect.String:
69 if size > 0 && size < 65532 {
70 sqlType = fmt.Sprintf("VARCHAR(%d)", size)
71 } else {
72 sqlType = "VARCHAR(65532)"
73 }
74 case reflect.Struct:
75 if _, ok := dataValue.Interface().(time.Time); ok {
76 sqlType = "TIMESTAMP"
77 }
78 default:
79 if _, ok := dataValue.Interface().([]byte); ok {
80 if size > 0 && size < 65532 {
81 sqlType = fmt.Sprintf("BINARY(%d)", size)
82 } else {
83 sqlType = "BINARY(65532)"
84 }
85 }
86 }
87 }
88
89 if sqlType == "" {
90 panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
91 }
92
93 if strings.TrimSpace(additionalType) == "" {
94 return sqlType
95 }
96 return fmt.Sprintf("%v %v", sqlType, additionalType)
97 }
98
99 func (s commonDialect) HasIndex(tableName string, indexName string) bool {
100 var count int
101 currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
102 s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
103 return count > 0
104 }
105
106 func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
107 _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
108 return err
109 }
110
111 func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
112 return false
113 }
114
115 func (s commonDialect) HasTable(tableName string) bool {
116 var count int
117 currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
118 s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
119 return count > 0
120 }
121
122 func (s commonDialect) HasColumn(tableName string, columnName string) bool {
123 var count int
124 currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
125 s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
126 return count > 0
127 }
128
129 func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
130 _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
131 return err
132 }
133
134 func (s commonDialect) CurrentDatabase() (name string) {
135 s.db.QueryRow("SELECT DATABASE()").Scan(&name)
136 return
137 }
138
139 func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
140 if limit != nil {
141 if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
142 sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
143 }
144 }
145 if offset != nil {
146 if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
147 sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
148 }
149 }
150 return
151 }
152
153 func (commonDialect) SelectFromDummyTable() string {
154 return ""
155 }
156
157 func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
158 return ""
159 }
160
161 func (commonDialect) DefaultValueStr() string {
162 return "DEFAULT VALUES"
163 }
164
165 // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
166 func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
167 keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
168 keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_")
169 return keyName
170 }
171
172 // IsByteArrayOrSlice returns true of the reflected value is an array or slice
173 func IsByteArrayOrSlice(value reflect.Value) bool {
174 return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
175 }
0 package gorm
1
2 import (
3 "crypto/sha1"
4 "fmt"
5 "reflect"
6 "regexp"
7 "strconv"
8 "strings"
9 "time"
10 "unicode/utf8"
11 )
12
13 type mysql struct {
14 commonDialect
15 }
16
17 func init() {
18 RegisterDialect("mysql", &mysql{})
19 }
20
21 func (mysql) GetName() string {
22 return "mysql"
23 }
24
25 func (mysql) Quote(key string) string {
26 return fmt.Sprintf("`%s`", key)
27 }
28
29 // Get Data Type for MySQL Dialect
30 func (s *mysql) DataTypeOf(field *StructField) string {
31 var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
32
33 // MySQL allows only one auto increment column per table, and it must
34 // be a KEY column.
35 if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
36 if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey {
37 delete(field.TagSettings, "AUTO_INCREMENT")
38 }
39 }
40
41 if sqlType == "" {
42 switch dataValue.Kind() {
43 case reflect.Bool:
44 sqlType = "boolean"
45 case reflect.Int8:
46 if s.fieldCanAutoIncrement(field) {
47 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
48 sqlType = "tinyint AUTO_INCREMENT"
49 } else {
50 sqlType = "tinyint"
51 }
52 case reflect.Int, reflect.Int16, reflect.Int32:
53 if s.fieldCanAutoIncrement(field) {
54 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
55 sqlType = "int AUTO_INCREMENT"
56 } else {
57 sqlType = "int"
58 }
59 case reflect.Uint8:
60 if s.fieldCanAutoIncrement(field) {
61 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
62 sqlType = "tinyint unsigned AUTO_INCREMENT"
63 } else {
64 sqlType = "tinyint unsigned"
65 }
66 case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
67 if s.fieldCanAutoIncrement(field) {
68 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
69 sqlType = "int unsigned AUTO_INCREMENT"
70 } else {
71 sqlType = "int unsigned"
72 }
73 case reflect.Int64:
74 if s.fieldCanAutoIncrement(field) {
75 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
76 sqlType = "bigint AUTO_INCREMENT"
77 } else {
78 sqlType = "bigint"
79 }
80 case reflect.Uint64:
81 if s.fieldCanAutoIncrement(field) {
82 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
83 sqlType = "bigint unsigned AUTO_INCREMENT"
84 } else {
85 sqlType = "bigint unsigned"
86 }
87 case reflect.Float32, reflect.Float64:
88 sqlType = "double"
89 case reflect.String:
90 if size > 0 && size < 65532 {
91 sqlType = fmt.Sprintf("varchar(%d)", size)
92 } else {
93 sqlType = "longtext"
94 }
95 case reflect.Struct:
96 if _, ok := dataValue.Interface().(time.Time); ok {
97 precision := ""
98 if p, ok := field.TagSettings["PRECISION"]; ok {
99 precision = fmt.Sprintf("(%s)", p)
100 }
101
102 if _, ok := field.TagSettings["NOT NULL"]; ok {
103 sqlType = fmt.Sprintf("timestamp%v", precision)
104 } else {
105 sqlType = fmt.Sprintf("timestamp%v NULL", precision)
106 }
107 }
108 default:
109 if IsByteArrayOrSlice(dataValue) {
110 if size > 0 && size < 65532 {
111 sqlType = fmt.Sprintf("varbinary(%d)", size)
112 } else {
113 sqlType = "longblob"
114 }
115 }
116 }
117 }
118
119 if sqlType == "" {
120 panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
121 }
122
123 if strings.TrimSpace(additionalType) == "" {
124 return sqlType
125 }
126 return fmt.Sprintf("%v %v", sqlType, additionalType)
127 }
128
129 func (s mysql) RemoveIndex(tableName string, indexName string) error {
130 _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
131 return err
132 }
133
134 func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error {
135 _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
136 return err
137 }
138
139 func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
140 if limit != nil {
141 if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
142 sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
143
144 if offset != nil {
145 if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
146 sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
147 }
148 }
149 }
150 }
151 return
152 }
153
154 func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
155 var count int
156 currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
157 s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
158 return count > 0
159 }
160
161 func (s mysql) CurrentDatabase() (name string) {
162 s.db.QueryRow("SELECT DATABASE()").Scan(&name)
163 return
164 }
165
166 func (mysql) SelectFromDummyTable() string {
167 return "FROM DUAL"
168 }
169
170 func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string {
171 keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...)
172 if utf8.RuneCountInString(keyName) <= 64 {
173 return keyName
174 }
175 h := sha1.New()
176 h.Write([]byte(keyName))
177 bs := h.Sum(nil)
178
179 // sha1 is 40 characters, keep first 24 characters of destination
180 destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_"))
181 if len(destRunes) > 24 {
182 destRunes = destRunes[:24]
183 }
184
185 return fmt.Sprintf("%s%x", string(destRunes), bs)
186 }
187
188 func (mysql) DefaultValueStr() string {
189 return "VALUES()"
190 }
0 package gorm
1
2 import (
3 "encoding/json"
4 "fmt"
5 "reflect"
6 "strings"
7 "time"
8 )
9
10 type postgres struct {
11 commonDialect
12 }
13
14 func init() {
15 RegisterDialect("postgres", &postgres{})
16 RegisterDialect("cloudsqlpostgres", &postgres{})
17 }
18
19 func (postgres) GetName() string {
20 return "postgres"
21 }
22
23 func (postgres) BindVar(i int) string {
24 return fmt.Sprintf("$%v", i)
25 }
26
27 func (s *postgres) DataTypeOf(field *StructField) string {
28 var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
29
30 if sqlType == "" {
31 switch dataValue.Kind() {
32 case reflect.Bool:
33 sqlType = "boolean"
34 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
35 if s.fieldCanAutoIncrement(field) {
36 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
37 sqlType = "serial"
38 } else {
39 sqlType = "integer"
40 }
41 case reflect.Int64, reflect.Uint32, reflect.Uint64:
42 if s.fieldCanAutoIncrement(field) {
43 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
44 sqlType = "bigserial"
45 } else {
46 sqlType = "bigint"
47 }
48 case reflect.Float32, reflect.Float64:
49 sqlType = "numeric"
50 case reflect.String:
51 if _, ok := field.TagSettings["SIZE"]; !ok {
52 size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
53 }
54
55 if size > 0 && size < 65532 {
56 sqlType = fmt.Sprintf("varchar(%d)", size)
57 } else {
58 sqlType = "text"
59 }
60 case reflect.Struct:
61 if _, ok := dataValue.Interface().(time.Time); ok {
62 sqlType = "timestamp with time zone"
63 }
64 case reflect.Map:
65 if dataValue.Type().Name() == "Hstore" {
66 sqlType = "hstore"
67 }
68 default:
69 if IsByteArrayOrSlice(dataValue) {
70 sqlType = "bytea"
71
72 if isUUID(dataValue) {
73 sqlType = "uuid"
74 }
75
76 if isJSON(dataValue) {
77 sqlType = "jsonb"
78 }
79 }
80 }
81 }
82
83 if sqlType == "" {
84 panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String()))
85 }
86
87 if strings.TrimSpace(additionalType) == "" {
88 return sqlType
89 }
90 return fmt.Sprintf("%v %v", sqlType, additionalType)
91 }
92
93 func (s postgres) HasIndex(tableName string, indexName string) bool {
94 var count int
95 s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
96 return count > 0
97 }
98
99 func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
100 var count int
101 s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
102 return count > 0
103 }
104
105 func (s postgres) HasTable(tableName string) bool {
106 var count int
107 s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
108 return count > 0
109 }
110
111 func (s postgres) HasColumn(tableName string, columnName string) bool {
112 var count int
113 s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
114 return count > 0
115 }
116
117 func (s postgres) CurrentDatabase() (name string) {
118 s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
119 return
120 }
121
122 func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
123 return fmt.Sprintf("RETURNING %v.%v", tableName, key)
124 }
125
126 func (postgres) SupportLastInsertID() bool {
127 return false
128 }
129
130 func isUUID(value reflect.Value) bool {
131 if value.Kind() != reflect.Array || value.Type().Len() != 16 {
132 return false
133 }
134 typename := value.Type().Name()
135 lower := strings.ToLower(typename)
136 return "uuid" == lower || "guid" == lower
137 }
138
139 func isJSON(value reflect.Value) bool {
140 _, ok := value.Interface().(json.RawMessage)
141 return ok
142 }
0 package gorm
1
2 import (
3 "fmt"
4 "reflect"
5 "strings"
6 "time"
7 )
8
9 type sqlite3 struct {
10 commonDialect
11 }
12
13 func init() {
14 RegisterDialect("sqlite3", &sqlite3{})
15 }
16
17 func (sqlite3) GetName() string {
18 return "sqlite3"
19 }
20
21 // Get Data Type for Sqlite Dialect
22 func (s *sqlite3) DataTypeOf(field *StructField) string {
23 var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
24
25 if sqlType == "" {
26 switch dataValue.Kind() {
27 case reflect.Bool:
28 sqlType = "bool"
29 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
30 if s.fieldCanAutoIncrement(field) {
31 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
32 sqlType = "integer primary key autoincrement"
33 } else {
34 sqlType = "integer"
35 }
36 case reflect.Int64, reflect.Uint64:
37 if s.fieldCanAutoIncrement(field) {
38 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
39 sqlType = "integer primary key autoincrement"
40 } else {
41 sqlType = "bigint"
42 }
43 case reflect.Float32, reflect.Float64:
44 sqlType = "real"
45 case reflect.String:
46 if size > 0 && size < 65532 {
47 sqlType = fmt.Sprintf("varchar(%d)", size)
48 } else {
49 sqlType = "text"
50 }
51 case reflect.Struct:
52 if _, ok := dataValue.Interface().(time.Time); ok {
53 sqlType = "datetime"
54 }
55 default:
56 if IsByteArrayOrSlice(dataValue) {
57 sqlType = "blob"
58 }
59 }
60 }
61
62 if sqlType == "" {
63 panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
64 }
65
66 if strings.TrimSpace(additionalType) == "" {
67 return sqlType
68 }
69 return fmt.Sprintf("%v %v", sqlType, additionalType)
70 }
71
72 func (s sqlite3) HasIndex(tableName string, indexName string) bool {
73 var count int
74 s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
75 return count > 0
76 }
77
78 func (s sqlite3) HasTable(tableName string) bool {
79 var count int
80 s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
81 return count > 0
82 }
83
84 func (s sqlite3) HasColumn(tableName string, columnName string) bool {
85 var count int
86 s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
87 return count > 0
88 }
89
90 func (s sqlite3) CurrentDatabase() (name string) {
91 var (
92 ifaces = make([]interface{}, 3)
93 pointers = make([]*string, 3)
94 i int
95 )
96 for i = 0; i < 3; i++ {
97 ifaces[i] = &pointers[i]
98 }
99 if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
100 return
101 }
102 if pointers[1] != nil {
103 name = *pointers[1]
104 }
105 return
106 }
0 package mssql
1
2 import (
3 "fmt"
4 "reflect"
5 "strconv"
6 "strings"
7 "time"
8
9 _ "github.com/denisenkom/go-mssqldb"
10 "github.com/jinzhu/gorm"
11 )
12
13 func setIdentityInsert(scope *gorm.Scope) {
14 if scope.Dialect().GetName() == "mssql" {
15 for _, field := range scope.PrimaryFields() {
16 if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank {
17 scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
18 scope.InstanceSet("mssql:identity_insert_on", true)
19 }
20 }
21 }
22 }
23
24 func turnOffIdentityInsert(scope *gorm.Scope) {
25 if scope.Dialect().GetName() == "mssql" {
26 if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok {
27 scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName()))
28 }
29 }
30 }
31
32 func init() {
33 gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert)
34 gorm.DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert)
35 gorm.RegisterDialect("mssql", &mssql{})
36 }
37
38 type mssql struct {
39 db gorm.SQLCommon
40 gorm.DefaultForeignKeyNamer
41 }
42
43 func (mssql) GetName() string {
44 return "mssql"
45 }
46
47 func (s *mssql) SetDB(db gorm.SQLCommon) {
48 s.db = db
49 }
50
51 func (mssql) BindVar(i int) string {
52 return "$$$" // ?
53 }
54
55 func (mssql) Quote(key string) string {
56 return fmt.Sprintf(`[%s]`, key)
57 }
58
59 func (s *mssql) DataTypeOf(field *gorm.StructField) string {
60 var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s)
61
62 if sqlType == "" {
63 switch dataValue.Kind() {
64 case reflect.Bool:
65 sqlType = "bit"
66 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
67 if s.fieldCanAutoIncrement(field) {
68 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
69 sqlType = "int IDENTITY(1,1)"
70 } else {
71 sqlType = "int"
72 }
73 case reflect.Int64, reflect.Uint64:
74 if s.fieldCanAutoIncrement(field) {
75 field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
76 sqlType = "bigint IDENTITY(1,1)"
77 } else {
78 sqlType = "bigint"
79 }
80 case reflect.Float32, reflect.Float64:
81 sqlType = "float"
82 case reflect.String:
83 if size > 0 && size < 8000 {
84 sqlType = fmt.Sprintf("nvarchar(%d)", size)
85 } else {
86 sqlType = "nvarchar(max)"
87 }
88 case reflect.Struct:
89 if _, ok := dataValue.Interface().(time.Time); ok {
90 sqlType = "datetimeoffset"
91 }
92 default:
93 if gorm.IsByteArrayOrSlice(dataValue) {
94 if size > 0 && size < 8000 {
95 sqlType = fmt.Sprintf("varbinary(%d)", size)
96 } else {
97 sqlType = "varbinary(max)"
98 }
99 }
100 }
101 }
102
103 if sqlType == "" {
104 panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
105 }
106
107 if strings.TrimSpace(additionalType) == "" {
108 return sqlType
109 }
110 return fmt.Sprintf("%v %v", sqlType, additionalType)
111 }
112
113 func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
114 if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
115 return value != "FALSE"
116 }
117 return field.IsPrimaryKey
118 }
119
120 func (s mssql) HasIndex(tableName string, indexName string) bool {
121 var count int
122 s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
123 return count > 0
124 }
125
126 func (s mssql) RemoveIndex(tableName string, indexName string) error {
127 _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
128 return err
129 }
130
131 func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
132 return false
133 }
134
135 func (s mssql) HasTable(tableName string) bool {
136 var count int
137 currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
138 s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
139 return count > 0
140 }
141
142 func (s mssql) HasColumn(tableName string, columnName string) bool {
143 var count int
144 currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
145 s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
146 return count > 0
147 }
148
149 func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
150 _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
151 return err
152 }
153
154 func (s mssql) CurrentDatabase() (name string) {
155 s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
156 return
157 }
158
159 func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
160 if offset != nil {
161 if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
162 sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
163 }
164 }
165 if limit != nil {
166 if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
167 if sql == "" {
168 // add default zero offset
169 sql += " OFFSET 0 ROWS"
170 }
171 sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit)
172 }
173 }
174 return
175 }
176
177 func (mssql) SelectFromDummyTable() string {
178 return ""
179 }
180
181 func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
182 return ""
183 }
184
185 func (mssql) DefaultValueStr() string {
186 return "DEFAULT VALUES"
187 }
188
189 func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
190 if strings.Contains(tableName, ".") {
191 splitStrings := strings.SplitN(tableName, ".", 2)
192 return splitStrings[0], splitStrings[1]
193 }
194 return dialect.CurrentDatabase(), tableName
195 }
0 package mysql
1
2 import _ "github.com/go-sql-driver/mysql"
0 package postgres
1
2 import (
3 "database/sql"
4 "database/sql/driver"
5
6 _ "github.com/lib/pq"
7 "github.com/lib/pq/hstore"
8 "encoding/json"
9 "errors"
10 "fmt"
11 )
12
13 type Hstore map[string]*string
14
15 // Value get value of Hstore
16 func (h Hstore) Value() (driver.Value, error) {
17 hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
18 if len(h) == 0 {
19 return nil, nil
20 }
21
22 for key, value := range h {
23 var s sql.NullString
24 if value != nil {
25 s.String = *value
26 s.Valid = true
27 }
28 hstore.Map[key] = s
29 }
30 return hstore.Value()
31 }
32
33 // Scan scan value into Hstore
34 func (h *Hstore) Scan(value interface{}) error {
35 hstore := hstore.Hstore{}
36
37 if err := hstore.Scan(value); err != nil {
38 return err
39 }
40
41 if len(hstore.Map) == 0 {
42 return nil
43 }
44
45 *h = Hstore{}
46 for k := range hstore.Map {
47 if hstore.Map[k].Valid {
48 s := hstore.Map[k].String
49 (*h)[k] = &s
50 } else {
51 (*h)[k] = nil
52 }
53 }
54
55 return nil
56 }
57
58 // Jsonb Postgresql's JSONB data type
59 type Jsonb struct {
60 json.RawMessage
61 }
62
63 // Value get value of Jsonb
64 func (j Jsonb) Value() (driver.Value, error) {
65 if len(j.RawMessage) == 0 {
66 return nil, nil
67 }
68 return j.MarshalJSON()
69 }
70
71 // Scan scan value into Jsonb
72 func (j *Jsonb) Scan(value interface{}) error {
73 bytes, ok := value.([]byte)
74 if !ok {
75 return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
76 }
77
78 return json.Unmarshal(bytes, j)
79 }
0 package sqlite
1
2 import _ "github.com/mattn/go-sqlite3"
+0
-68
doc/development.md less more
0 # Gorm Development
1
2 ## Architecture
3
4 The most notable component of Gorm is`gorm.DB`, which hold database connection. It could be initialized like this:
5
6 db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
7
8 Gorm has chainable API, `gorm.DB` is the bridge of chains, it save related information and pass it to the next chain.
9
10 Lets use below code to explain how it works:
11
12 db.Where("name = ?", "jinzhu").Find(&users)
13
14 // equivalent code
15 newdb := db.Where("name =?", "jinzhu")
16 newdb.Find(&user)
17
18 `newdb` is `db`'s clone, in addition, it contains search conditions from the `Where` method.
19 `Find` is a query method, it creates a `Scope` instance, and pass it as argument to query callbacks.
20
21 There are four kinds of callbacks corresponds to sql's CURD: create callbacks, update callbacks, query callbacks, delete callbacks.
22
23 ## Callbacks
24
25 ### Register a new callback
26
27 func updateCreated(scope *Scope) {
28 if scope.HasColumn("Created") {
29 scope.SetColumn("Created", NowFunc())
30 }
31 }
32
33 db.Callback().Create().Register("update_created_at", updateCreated)
34 // register a callback for Create process
35
36 ### Delete an existing callback
37
38 db.Callback().Create().Remove("gorm:create")
39 // delete callback `gorm:create` from Create callbacks
40
41 ### Replace an existing callback
42
43 db.Callback().Create().Replace("gorm:create", newCreateFunction)
44 // replace callback `gorm:create` with new function `newCreateFunction` for Create process
45
46 ### Register callback orders
47
48 db.Callback().Create().Before("gorm:create").Register("update_created_at", updateCreated)
49 db.Callback().Create().After("gorm:create").Register("update_created_at", updateCreated)
50 db.Callback().Query().After("gorm:query").Register("my_plugin:after_query", afterQuery)
51 db.Callback().Delete().After("gorm:delete").Register("my_plugin:after_delete", afterDelete)
52 db.Callback().Update().Before("gorm:update").Register("my_plugin:before_update", beforeUpdate)
53 db.Callback().Create().Before("gorm:create").After("gorm:before_create").Register("my_plugin:before_create", beforeCreate)
54
55 ### Callback API
56
57 Gorm is powered by callbacks, so you could refer below links to learn how to write callbacks
58
59 [Create callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go)
60
61 [Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go)
62
63 [Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go)
64
65 [Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go)
66
67 View [https://github.com/jinzhu/gorm/blob/master/scope.go](https://github.com/jinzhu/gorm/blob/master/scope.go) for all available API
0 version: '3'
1
2 services:
3 mysql:
4 image: 'mysql:latest'
5 ports:
6 - 9910:3306
7 environment:
8 - MYSQL_DATABASE=gorm
9 - MYSQL_USER=gorm
10 - MYSQL_PASSWORD=gorm
11 - MYSQL_RANDOM_ROOT_PASSWORD="yes"
12 postgres:
13 image: 'postgres:latest'
14 ports:
15 - 9920:5432
16 environment:
17 - POSTGRES_USER=gorm
18 - POSTGRES_DB=gorm
19 - POSTGRES_PASSWORD=gorm
20 mssql:
21 image: 'mcmoe/mssqldocker:latest'
22 ports:
23 - 9930:1433
24 environment:
25 - ACCEPT_EULA=Y
26 - SA_PASSWORD=LoremIpsum86
27 - MSSQL_DB=gorm
28 - MSSQL_USER=gorm
29 - MSSQL_PASSWORD=LoremIpsum86
77 URL string
88 }
99
10 type Author struct {
11 ID string
12 Name string
13 Email string
14 }
15
1016 type HNPost struct {
1117 BasePost
18 Author `gorm:"embedded_prefix:user_"` // Embedded struct
1219 Upvotes int32
1320 }
1421
1522 type EngadgetPost struct {
1623 BasePost BasePost `gorm:"embedded"`
24 Author Author `gorm:"embedded;embedded_prefix:author_"` // Embedded struct
1725 ImageUrl string
26 }
27
28 func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
29 dialect := DB.NewScope(&EngadgetPost{}).Dialect()
30 engadgetPostScope := DB.NewScope(&EngadgetPost{})
31 if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") {
32 t.Errorf("should has prefix for embedded columns")
33 }
34
35 if len(engadgetPostScope.PrimaryFields()) != 1 {
36 t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields()))
37 }
38
39 hnScope := DB.NewScope(&HNPost{})
40 if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") {
41 t.Errorf("should has prefix for embedded columns")
42 }
1843 }
1944
2045 func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
4570 }
4671 }
4772 }
73
74 func TestEmbeddedPointerTypeStruct(t *testing.T) {
75 type HNPost struct {
76 *BasePost
77 Upvotes int32
78 }
79
80 DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}})
81
82 var hnPost HNPost
83 if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil {
84 t.Errorf("No error should happen when find embedded pointer type, but got %v", err)
85 }
86
87 if hnPost.Title != "embedded_pointer_type" {
88 t.Errorf("Should find correct value for embedded pointer type")
89 }
90 }
55 )
66
77 var (
8 RecordNotFound = errors.New("record not found")
9 InvalidSql = errors.New("invalid sql")
10 NoNewAttrs = errors.New("no new attributes")
11 NoValidTransaction = errors.New("no valid transaction")
12 CantStartTransaction = errors.New("can't start transaction")
8 // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct
9 ErrRecordNotFound = errors.New("record not found")
10 // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
11 ErrInvalidSQL = errors.New("invalid SQL")
12 // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
13 ErrInvalidTransaction = errors.New("no valid transaction")
14 // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
15 ErrCantStartTransaction = errors.New("can't start transaction")
16 // ErrUnaddressable unaddressable value
17 ErrUnaddressable = errors.New("using unaddressable value")
1318 )
1419
15 type errorsInterface interface {
16 GetErrors() []error
20 // Errors contains all happened errors
21 type Errors []error
22
23 // IsRecordNotFoundError returns current error has record not found error or not
24 func IsRecordNotFoundError(err error) bool {
25 if errs, ok := err.(Errors); ok {
26 for _, err := range errs {
27 if err == ErrRecordNotFound {
28 return true
29 }
30 }
31 }
32 return err == ErrRecordNotFound
1733 }
1834
19 type Errors struct {
20 errors []error
35 // GetErrors gets all happened errors
36 func (errs Errors) GetErrors() []error {
37 return errs
2138 }
2239
23 func (errs Errors) GetErrors() []error {
24 return errs.errors
40 // Add adds an error
41 func (errs Errors) Add(newErrors ...error) Errors {
42 for _, err := range newErrors {
43 if err == nil {
44 continue
45 }
46
47 if errors, ok := err.(Errors); ok {
48 errs = errs.Add(errors...)
49 } else {
50 ok = true
51 for _, e := range errs {
52 if err == e {
53 ok = false
54 }
55 }
56 if ok {
57 errs = append(errs, err)
58 }
59 }
60 }
61 return errs
2562 }
2663
27 func (errs *Errors) Add(err error) {
28 if errors, ok := err.(errorsInterface); ok {
29 for _, err := range errors.GetErrors() {
30 errs.Add(err)
31 }
32 } else {
33 for _, e := range errs.errors {
34 if err == e {
35 return
36 }
37 }
38 errs.errors = append(errs.errors, err)
39 }
40 }
41
64 // Error format happened errors
4265 func (errs Errors) Error() string {
4366 var errors = []string{}
44 for _, e := range errs.errors {
67 for _, e := range errs {
4568 errors = append(errors, e.Error())
4669 }
4770 return strings.Join(errors, "; ")
0 package gorm_test
1
2 import (
3 "errors"
4 "testing"
5
6 "github.com/jinzhu/gorm"
7 )
8
9 func TestErrorsCanBeUsedOutsideGorm(t *testing.T) {
10 errs := []error{errors.New("First"), errors.New("Second")}
11
12 gErrs := gorm.Errors(errs)
13 gErrs = gErrs.Add(errors.New("Third"))
14 gErrs = gErrs.Add(gErrs)
15
16 if gErrs.Error() != "First; Second; Third" {
17 t.Fatalf("Gave wrong error, got %s", gErrs.Error())
18 }
19 }
22 import (
33 "database/sql"
44 "errors"
5 "fmt"
56 "reflect"
67 )
78
9 // Field model field definition
810 type Field struct {
911 *StructField
1012 IsBlank bool
1113 Field reflect.Value
1214 }
1315
14 func (field *Field) Set(value interface{}) error {
16 // Set set a value to the field
17 func (field *Field) Set(value interface{}) (err error) {
1518 if !field.Field.IsValid() {
1619 return errors.New("field value not valid")
1720 }
1821
1922 if !field.Field.CanAddr() {
20 return errors.New("unaddressable value")
23 return ErrUnaddressable
2124 }
2225
23 if rvalue, ok := value.(reflect.Value); ok {
24 value = rvalue.Interface()
26 reflectValue, ok := value.(reflect.Value)
27 if !ok {
28 reflectValue = reflect.ValueOf(value)
2529 }
2630
27 if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
28 if v, ok := value.(reflect.Value); ok {
29 if err := scanner.Scan(v.Interface()); err != nil {
30 return err
31 fieldValue := field.Field
32 if reflectValue.IsValid() {
33 if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
34 fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
35 } else {
36 if fieldValue.Kind() == reflect.Ptr {
37 if fieldValue.IsNil() {
38 fieldValue.Set(reflect.New(field.Struct.Type.Elem()))
39 }
40 fieldValue = fieldValue.Elem()
3141 }
32 } else {
33 if err := scanner.Scan(value); err != nil {
34 return err
42
43 if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
44 fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
45 } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
46 err = scanner.Scan(reflectValue.Interface())
47 } else {
48 err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
3549 }
3650 }
3751 } else {
38 reflectValue, ok := value.(reflect.Value)
39 if !ok {
40 reflectValue = reflect.ValueOf(value)
41 }
42
43 if reflectValue.Type().ConvertibleTo(field.Field.Type()) {
44 field.Field.Set(reflectValue.Convert(field.Field.Type()))
45 } else {
46 return errors.New("could not convert argument")
47 }
52 field.Field.Set(reflect.Zero(field.Field.Type()))
4853 }
4954
5055 field.IsBlank = isBlank(field.Field)
51 return nil
56 return err
5257 }
53
54 // Fields get value's fields
55 func (scope *Scope) Fields() map[string]*Field {
56 if scope.fields == nil {
57 fields := map[string]*Field{}
58 modelStruct := scope.GetModelStruct()
59
60 indirectValue := scope.IndirectValue()
61 isStruct := indirectValue.Kind() == reflect.Struct
62 for _, structField := range modelStruct.StructFields {
63 if isStruct {
64 fields[structField.DBName] = getField(indirectValue, structField)
65 } else {
66 fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
67 }
68 }
69
70 if modelStruct.cached {
71 scope.fields = fields
72 }
73 return fields
74 }
75 return scope.fields
76 }
77
78 func getField(indirectValue reflect.Value, structField *StructField) *Field {
79 field := &Field{StructField: structField}
80 for _, name := range structField.Names {
81 indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
82 }
83 field.Field = indirectValue
84 field.IsBlank = isBlank(indirectValue)
85 return field
86 }
1010 Name string
1111 Children []CalculateFieldChild
1212 Category CalculateFieldCategory
13 EmbeddedField
14 }
15
16 type EmbeddedField struct {
17 EmbeddedName string `sql:"NOT NULL;DEFAULT:'hello'"`
1318 }
1419
1520 type CalculateFieldChild struct {
2631
2732 func TestCalculateField(t *testing.T) {
2833 var field CalculateField
29 fields := DB.NewScope(&field).Fields()
30 if fields["children"].Relationship == nil || fields["category"].Relationship == nil {
34 var scope = DB.NewScope(&field)
35 if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil {
3136 t.Errorf("Should calculate fields correctly for the first time")
3237 }
38
39 if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil {
40 t.Errorf("Should calculate fields correctly for the first time")
41 }
42
43 if field, ok := scope.FieldByName("embedded_name"); !ok {
44 t.Errorf("should find embedded field")
45 } else if _, ok := field.TagSettings["NOT NULL"]; !ok {
46 t.Errorf("should find embedded field's tag settings")
47 }
3348 }
+0
-83
foundation.go less more
0 package gorm
1
2 import (
3 "fmt"
4 "reflect"
5 "time"
6 )
7
8 type foundation struct {
9 commonDialect
10 }
11
12 func (foundation) BinVar(i int) string {
13 return fmt.Sprintf("$%v", i)
14 }
15
16 func (foundation) SupportLastInsertId() bool {
17 return false
18 }
19
20 func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
21 switch value.Kind() {
22 case reflect.Bool:
23 return "boolean"
24 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
25 if autoIncrease {
26 return "serial"
27 }
28 return "int"
29 case reflect.Int64, reflect.Uint64:
30 if autoIncrease {
31 return "bigserial"
32 }
33 return "bigint"
34 case reflect.Float32, reflect.Float64:
35 return "double"
36 case reflect.String:
37 if size > 0 && size < 65532 {
38 return fmt.Sprintf("varchar(%d)", size)
39 }
40 return "clob"
41 case reflect.Struct:
42 if _, ok := value.Interface().(time.Time); ok {
43 return "datetime"
44 }
45 default:
46 if _, ok := value.Interface().([]byte); ok {
47 return "blob"
48 }
49 }
50 panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String()))
51 }
52
53 func (s foundation) ReturningStr(tableName, key string) string {
54 return fmt.Sprintf("RETURNING %v.%v", tableName, key)
55 }
56
57 func (s foundation) HasTable(scope *Scope, tableName string) bool {
58 var count int
59 s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName)
60 return count > 0
61 }
62
63 func (s foundation) HasColumn(scope *Scope, tableName string, columnName string) bool {
64 var count int
65 s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName)
66 return count > 0
67 }
68
69 func (s foundation) RemoveIndex(scope *Scope, indexName string) {
70 scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", s.Quote(indexName)))
71 }
72
73 func (s foundation) HasIndex(scope *Scope, tableName string, indexName string) bool {
74 var count int
75 s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName)
76 return count > 0
77 }
78
79 func (s foundation) CurrentDatabase(scope *Scope) (name string) {
80 s.RawScanString(scope, &name, "SELECT CURRENT_SCHEMA")
81 return
82 }
images/logger.png less more
Binary diff not shown
11
22 import "database/sql"
33
4 type sqlCommon interface {
4 // SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
5 type SQLCommon interface {
56 Exec(query string, args ...interface{}) (sql.Result, error)
67 Prepare(query string) (*sql.Stmt, error)
78 Query(query string, args ...interface{}) (*sql.Rows, error)
66 "strings"
77 )
88
9 // JoinTableHandlerInterface is an interface for how to handle many2many relations
910 type JoinTableHandlerInterface interface {
11 // initialize join table handler
1012 Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
13 // Table return join table's table name
1114 Table(db *DB) string
15 // Add create relationship in join table for source and destination
1216 Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
17 // Delete delete relationship in join table for sources
1318 Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
19 // JoinWith query with `Join` conditions
1420 JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
21 // SourceForeignKeys return source foreign keys
1522 SourceForeignKeys() []JoinTableForeignKey
23 // DestinationForeignKeys return destination foreign keys
1624 DestinationForeignKeys() []JoinTableForeignKey
1725 }
1826
27 // JoinTableForeignKey join table foreign key struct
1928 type JoinTableForeignKey struct {
2029 DBName string
2130 AssociationDBName string
2231 }
2332
33 // JoinTableSource is a struct that contains model type and foreign keys
2434 type JoinTableSource struct {
2535 ModelType reflect.Type
2636 ForeignKeys []JoinTableForeignKey
2737 }
2838
39 // JoinTableHandler default join table handler
2940 type JoinTableHandler struct {
3041 TableName string `sql:"-"`
3142 Source JoinTableSource `sql:"-"`
3243 Destination JoinTableSource `sql:"-"`
3344 }
3445
46 // SourceForeignKeys return source foreign keys
3547 func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
3648 return s.Source.ForeignKeys
3749 }
3850
51 // DestinationForeignKeys return destination foreign keys
3952 func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
4053 return s.Destination.ForeignKeys
4154 }
4255
56 // Setup initialize a default join table handler
4357 func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
4458 s.TableName = tableName
4559
4660 s.Source = JoinTableSource{ModelType: source}
61 s.Source.ForeignKeys = []JoinTableForeignKey{}
4762 for idx, dbName := range relationship.ForeignFieldNames {
4863 s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
4964 DBName: relationship.ForeignDBNames[idx],
5267 }
5368
5469 s.Destination = JoinTableSource{ModelType: destination}
70 s.Destination.ForeignKeys = []JoinTableForeignKey{}
5571 for idx, dbName := range relationship.AssociationForeignFieldNames {
5672 s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
5773 DBName: relationship.AssociationForeignDBNames[idx],
6076 }
6177 }
6278
79 // Table return join table's table name
6380 func (s JoinTableHandler) Table(db *DB) string {
64 return s.TableName
65 }
66
67 func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
68 values := map[string]interface{}{}
69
81 return DefaultTableNameHandler(db, s.TableName)
82 }
83
84 func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) {
7085 for _, source := range sources {
7186 scope := db.NewScope(source)
7287 modelType := scope.GetModelStruct().ModelType
7388
74 if s.Source.ModelType == modelType {
75 for _, foreignKey := range s.Source.ForeignKeys {
76 values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
89 for _, joinTableSource := range joinTableSources {
90 if joinTableSource.ModelType == modelType {
91 for _, foreignKey := range joinTableSource.ForeignKeys {
92 if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
93 conditionMap[foreignKey.DBName] = field.Field.Interface()
94 }
95 }
96 break
7797 }
78 } else if s.Destination.ModelType == modelType {
79 for _, foreignKey := range s.Destination.ForeignKeys {
80 values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
81 }
82 }
83 }
84 return values
85 }
86
87 func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error {
88 scope := db.NewScope("")
89 searchMap := s.GetSearchMap(db, source1, source2)
98 }
99 }
100 }
101
102 // Add create relationship in join table for source and destination
103 func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
104 var (
105 scope = db.NewScope("")
106 conditionMap = map[string]interface{}{}
107 )
108
109 // Update condition map for source
110 s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source)
111
112 // Update condition map for destination
113 s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination)
90114
91115 var assignColumns, binVars, conditions []string
92116 var values []interface{}
93 for key, value := range searchMap {
94 assignColumns = append(assignColumns, key)
117 for key, value := range conditionMap {
118 assignColumns = append(assignColumns, scope.Quote(key))
95119 binVars = append(binVars, `?`)
96120 conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
97121 values = append(values, value)
101125 values = append(values, value)
102126 }
103127
104 quotedTable := handler.Table(db)
128 quotedTable := scope.Quote(handler.Table(db))
105129 sql := fmt.Sprintf(
106130 "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
107131 quotedTable,
115139 return db.Exec(sql, values...).Error
116140 }
117141
142 // Delete delete relationship in join table for sources
118143 func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
119 var conditions []string
120 var values []interface{}
121
122 for key, value := range s.GetSearchMap(db, sources...) {
123 conditions = append(conditions, fmt.Sprintf("%v = ?", key))
144 var (
145 scope = db.NewScope(nil)
146 conditions []string
147 values []interface{}
148 conditionMap = map[string]interface{}{}
149 )
150
151 s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...)
152
153 for key, value := range conditionMap {
154 conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
124155 values = append(values, value)
125156 }
126157
127158 return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
128159 }
129160
161 // JoinWith query with `Join` conditions
130162 func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
131 quotedTable := handler.Table(db)
132
133 scope := db.NewScope(source)
134 modelType := scope.GetModelStruct().ModelType
135 var joinConditions []string
136 var queryConditions []string
137 var values []interface{}
138 if s.Source.ModelType == modelType {
163 var (
164 scope = db.NewScope(source)
165 tableName = handler.Table(db)
166 quotedTableName = scope.Quote(tableName)
167 joinConditions []string
168 values []interface{}
169 )
170
171 if s.Source.ModelType == scope.GetModelStruct().ModelType {
139172 destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
140173 for _, foreignKey := range s.Destination.ForeignKeys {
141 joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
174 joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
142175 }
143176
144177 var foreignDBNames []string
146179
147180 for _, foreignKey := range s.Source.ForeignKeys {
148181 foreignDBNames = append(foreignDBNames, foreignKey.DBName)
149 foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name)
150 }
151
152 foreignFieldValues := scope.getColumnAsArray(foreignFieldNames)
153
154 condString := fmt.Sprintf("%v in (%v)", toQueryCondition(scope, foreignDBNames), toQueryMarks(foreignFieldValues))
155
156 keys := scope.getColumnAsArray(foreignFieldNames)
157 values = append(values, toQueryValues(keys))
158
159 queryConditions = append(queryConditions, condString)
160
161 return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
182 if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
183 foreignFieldNames = append(foreignFieldNames, field.Name)
184 }
185 }
186
187 foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
188
189 var condString string
190 if len(foreignFieldValues) > 0 {
191 var quotedForeignDBNames []string
192 for _, dbName := range foreignDBNames {
193 quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName)
194 }
195
196 condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
197
198 keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
199 values = append(values, toQueryValues(keys))
200 } else {
201 condString = fmt.Sprintf("1 <> 1")
202 }
203
204 return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
162205 Where(condString, toQueryValues(foreignFieldValues)...)
163 } else {
164 db.Error = errors.New("wrong source type for join table handler")
165 return db
166 }
167 }
206 }
207
208 db.Error = errors.New("wrong source type for join table handler")
209 return db
210 }
11
22 import (
33 "fmt"
4 "strconv"
45 "testing"
56 "time"
67
1718 gorm.JoinTableHandler
1819 PersonID int
1920 AddressID int
20 DeletedAt time.Time
21 DeletedAt *time.Time
2122 CreatedAt time.Time
2223 }
2324
2425 func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error {
25 return db.Where(map[string]interface{}{
26 "person_id": db.NewScope(foreignValue).PrimaryKeyValue(),
27 "address_id": db.NewScope(associationValue).PrimaryKeyValue(),
28 }).Assign(map[string]interface{}{
29 "person_id": foreignValue,
30 "address_id": associationValue,
26 foreignPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(foreignValue).PrimaryKeyValue()))
27 associationPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(associationValue).PrimaryKeyValue()))
28 if result := db.Unscoped().Model(&PersonAddress{}).Where(map[string]interface{}{
29 "person_id": foreignPrimaryKey,
30 "address_id": associationPrimaryKey,
31 }).Update(map[string]interface{}{
32 "person_id": foreignPrimaryKey,
33 "address_id": associationPrimaryKey,
3134 "deleted_at": gorm.Expr("NULL"),
32 }).FirstOrCreate(&PersonAddress{}).Error
35 }).RowsAffected; result == 0 {
36 return db.Create(&PersonAddress{
37 PersonID: foreignPrimaryKey,
38 AddressID: associationPrimaryKey,
39 }).Error
40 }
41
42 return nil
3343 }
3444
3545 func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error {
3848
3949 func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB {
4050 table := pa.Table(db)
41 return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
51 return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
4252 }
4353
4454 func TestJoinTable(t *testing.T) {
6979 t.Errorf("Should deleted all addresses")
7080 }
7181 }
82
83 func TestEmbeddedMany2ManyRelationship(t *testing.T) {
84 type EmbeddedPerson struct {
85 ID int
86 Name string
87 Addresses []*Address `gorm:"many2many:person_addresses;"`
88 }
89
90 type NewPerson struct {
91 EmbeddedPerson
92 ExternalID uint
93 }
94 DB.Exec("drop table person_addresses;")
95 DB.AutoMigrate(&NewPerson{})
96
97 address1 := &Address{Address1: "address 1"}
98 address2 := &Address{Address1: "address 2"}
99 person := &NewPerson{ExternalID: 100, EmbeddedPerson: EmbeddedPerson{Name: "person", Addresses: []*Address{address1, address2}}}
100 if err := DB.Save(person).Error; err != nil {
101 t.Errorf("no error should return when save embedded many2many relationship, but got %v", err)
102 }
103
104 if err := DB.Model(person).Association("Addresses").Delete(address1).Error; err != nil {
105 t.Errorf("no error should return when delete embedded many2many relationship, but got %v", err)
106 }
107
108 association := DB.Model(person).Association("Addresses")
109 if count := association.Count(); count != 1 || association.Error != nil {
110 t.Errorf("Should found one address, but got %v, error is %v", count, association.Error)
111 }
112
113 if association.Clear(); association.Count() != 0 {
114 t.Errorf("Should deleted all addresses")
115 }
116 }
66 "os"
77 "reflect"
88 "regexp"
9 "strconv"
910 "time"
11 "unicode"
1012 )
13
14 var (
15 defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
16 sqlRegexp = regexp.MustCompile(`\?`)
17 numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`)
18 )
19
20 func isPrintable(s string) bool {
21 for _, r := range s {
22 if !unicode.IsPrint(r) {
23 return false
24 }
25 }
26 return true
27 }
28
29 var LogFormatter = func(values ...interface{}) (messages []interface{}) {
30 if len(values) > 1 {
31 var (
32 sql string
33 formattedValues []string
34 level = values[0]
35 currentTime = "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m"
36 source = fmt.Sprintf("\033[35m(%v)\033[0m", values[1])
37 )
38
39 messages = []interface{}{source, currentTime}
40
41 if level == "sql" {
42 // duration
43 messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
44 // sql
45
46 for _, value := range values[4].([]interface{}) {
47 indirectValue := reflect.Indirect(reflect.ValueOf(value))
48 if indirectValue.IsValid() {
49 value = indirectValue.Interface()
50 if t, ok := value.(time.Time); ok {
51 formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05")))
52 } else if b, ok := value.([]byte); ok {
53 if str := string(b); isPrintable(str) {
54 formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
55 } else {
56 formattedValues = append(formattedValues, "'<binary>'")
57 }
58 } else if r, ok := value.(driver.Valuer); ok {
59 if value, err := r.Value(); err == nil && value != nil {
60 formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
61 } else {
62 formattedValues = append(formattedValues, "NULL")
63 }
64 } else {
65 formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
66 }
67 } else {
68 formattedValues = append(formattedValues, "NULL")
69 }
70 }
71
72 // differentiate between $n placeholders or else treat like ?
73 if numericPlaceHolderRegexp.MatchString(values[3].(string)) {
74 sql = values[3].(string)
75 for index, value := range formattedValues {
76 placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1)
77 sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1")
78 }
79 } else {
80 formattedValuesLength := len(formattedValues)
81 for index, value := range sqlRegexp.Split(values[3].(string), -1) {
82 sql += value
83 if index < formattedValuesLength {
84 sql += formattedValues[index]
85 }
86 }
87 }
88
89 messages = append(messages, sql)
90 messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned "))
91 } else {
92 messages = append(messages, "\033[31;1m")
93 messages = append(messages, values[2:]...)
94 messages = append(messages, "\033[0m")
95 }
96 }
97
98 return
99 }
11100
12101 type logger interface {
13102 Print(v ...interface{})
14103 }
15104
105 // LogWriter log writer interface
16106 type LogWriter interface {
17107 Println(v ...interface{})
18108 }
19109
110 // Logger default logger
20111 type Logger struct {
21112 LogWriter
22113 }
23114
24 var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
25
26 // Format log
27 var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
28
115 // Print format & print log
29116 func (logger Logger) Print(values ...interface{}) {
30 if len(values) > 1 {
31 level := values[0]
32 currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m"
33 source := fmt.Sprintf("\033[35m(%v)\033[0m", values[1])
34 messages := []interface{}{source, currentTime}
35
36 if level == "sql" {
37 // duration
38 messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
39 // sql
40 var formatedValues []interface{}
41 for _, value := range values[4].([]interface{}) {
42 indirectValue := reflect.Indirect(reflect.ValueOf(value))
43 if indirectValue.IsValid() {
44 value = indirectValue.Interface()
45 if t, ok := value.(time.Time); ok {
46 formatedValues = append(formatedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339)))
47 } else if b, ok := value.([]byte); ok {
48 formatedValues = append(formatedValues, fmt.Sprintf("'%v'", string(b)))
49 } else if r, ok := value.(driver.Valuer); ok {
50 if value, err := r.Value(); err == nil && value != nil {
51 formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
52 } else {
53 formatedValues = append(formatedValues, "NULL")
54 }
55 } else {
56 formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
57 }
58 } else {
59 formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
60 }
61 }
62 messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(values[3].(string), "%v"), formatedValues...))
63 } else {
64 messages = append(messages, "\033[31;1m")
65 messages = append(messages, values[2:]...)
66 messages = append(messages, "\033[0m")
67 }
68 logger.Println(messages...)
69 }
117 logger.Println(LogFormatter(values...)...)
70118 }
+435
-219
main.go less more
88 "time"
99 )
1010
11 // NowFunc returns current time, this function is exported in order to be able
12 // to give the flexibility to the developer to customize it according to their
13 // needs
14 //
15 // e.g: return time.Now().UTC()
16 //
17 var NowFunc = func() time.Time {
18 return time.Now()
19 }
20
11 // DB contains information for current db connection
2112 type DB struct {
22 Value interface{}
23 Error error
24 RowsAffected int64
25 callback *callback
26 db sqlCommon
27 parent *DB
28 search *search
13 Value interface{}
14 Error error
15 RowsAffected int64
16
17 // single db
18 db SQLCommon
19 blockGlobalUpdate bool
2920 logMode int
3021 logger logger
31 dialect Dialect
32 singularTable bool
33 source string
22 search *search
3423 values map[string]interface{}
35 joinTableHandlers map[string]JoinTableHandler
36 }
37
38 func Open(dialect string, args ...interface{}) (DB, error) {
39 var db DB
40 var err error
41
24
25 // global db
26 parent *DB
27 callbacks *Callback
28 dialect Dialect
29 singularTable bool
30 }
31
32 // Open initialize a new db connection, need to import driver first, e.g:
33 //
34 // import _ "github.com/go-sql-driver/mysql"
35 // func main() {
36 // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
37 // }
38 // GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with
39 // import _ "github.com/jinzhu/gorm/dialects/mysql"
40 // // import _ "github.com/jinzhu/gorm/dialects/postgres"
41 // // import _ "github.com/jinzhu/gorm/dialects/sqlite"
42 // // import _ "github.com/jinzhu/gorm/dialects/mssql"
43 func Open(dialect string, args ...interface{}) (db *DB, err error) {
4244 if len(args) == 0 {
4345 err = errors.New("invalid database source")
44 } else {
45 var source string
46 var dbSql sqlCommon
47
48 switch value := args[0].(type) {
49 case string:
50 var driver = dialect
51 if len(args) == 1 {
52 source = value
53 } else if len(args) >= 2 {
54 driver = value
55 source = args[1].(string)
56 }
57 if driver == "foundation" {
58 driver = "postgres" // FoundationDB speaks a postgres-compatible protocol.
59 }
60 dbSql, err = sql.Open(driver, source)
61 case sqlCommon:
62 source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
63 dbSql = value
64 }
65
66 db = DB{
67 dialect: NewDialect(dialect),
68 logger: defaultLogger,
69 callback: DefaultCallback,
70 source: source,
71 values: map[string]interface{}{},
72 db: dbSql,
73 }
74 db.parent = &db
75
76 if err == nil {
77 err = db.DB().Ping() // Send a ping to make sure the database connection is alive.
78 }
79 }
80
81 return db, err
82 }
83
84 func (s *DB) Close() error {
85 return s.parent.db.(*sql.DB).Close()
86 }
87
88 func (s *DB) DB() *sql.DB {
89 return s.db.(*sql.DB)
90 }
91
46 return nil, err
47 }
48 var source string
49 var dbSQL SQLCommon
50
51 switch value := args[0].(type) {
52 case string:
53 var driver = dialect
54 if len(args) == 1 {
55 source = value
56 } else if len(args) >= 2 {
57 driver = value
58 source = args[1].(string)
59 }
60 dbSQL, err = sql.Open(driver, source)
61 case SQLCommon:
62 dbSQL = value
63 }
64
65 db = &DB{
66 db: dbSQL,
67 logger: defaultLogger,
68 values: map[string]interface{}{},
69 callbacks: DefaultCallback,
70 dialect: newDialect(dialect, dbSQL),
71 }
72 db.parent = db
73 if err != nil {
74 return
75 }
76 // Send a ping to make sure the database connection is alive.
77 if d, ok := dbSQL.(*sql.DB); ok {
78 if err = d.Ping(); err != nil {
79 d.Close()
80 }
81 }
82 return
83 }
84
85 // New clone a new db connection without search conditions
9286 func (s *DB) New() *DB {
9387 clone := s.clone()
9488 clone.search = nil
9690 return clone
9791 }
9892
99 // NewScope create scope for callbacks, including DB's search information
100 func (db *DB) NewScope(value interface{}) *Scope {
101 dbClone := db.clone()
102 dbClone.Value = value
103 return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
104 }
105
106 // CommonDB Return the underlying sql.DB or sql.Tx instance.
107 // Use of this method is discouraged. It's mainly intended to allow
108 // coexistence with legacy non-GORM code.
109 func (s *DB) CommonDB() sqlCommon {
93 type closer interface {
94 Close() error
95 }
96
97 // Close close current db connection. If database connection is not an io.Closer, returns an error.
98 func (s *DB) Close() error {
99 if db, ok := s.parent.db.(closer); ok {
100 return db.Close()
101 }
102 return errors.New("can't close current db")
103 }
104
105 // DB get `*sql.DB` from current connection
106 // If the underlying database connection is not a *sql.DB, returns nil
107 func (s *DB) DB() *sql.DB {
108 db, _ := s.db.(*sql.DB)
109 return db
110 }
111
112 // CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
113 func (s *DB) CommonDB() SQLCommon {
110114 return s.db
111115 }
112116
113 func (s *DB) Callback() *callback {
114 s.parent.callback = s.parent.callback.clone()
115 return s.parent.callback
116 }
117
118 func (s *DB) SetLogger(l logger) {
119 s.logger = l
120 }
121
117 // Dialect get dialect
118 func (s *DB) Dialect() Dialect {
119 return s.parent.dialect
120 }
121
122 // Callback return `Callbacks` container, you could add/change/delete callbacks with it
123 // db.Callback().Create().Register("update_created_at", updateCreated)
124 // Refer https://jinzhu.github.io/gorm/development.html#callbacks
125 func (s *DB) Callback() *Callback {
126 s.parent.callbacks = s.parent.callbacks.clone()
127 return s.parent.callbacks
128 }
129
130 // SetLogger replace default logger
131 func (s *DB) SetLogger(log logger) {
132 s.logger = log
133 }
134
135 // LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
122136 func (s *DB) LogMode(enable bool) *DB {
123137 if enable {
124138 s.logMode = 2
128142 return s
129143 }
130144
145 // BlockGlobalUpdate if true, generates an error on update/delete without where clause.
146 // This is to prevent eventual error with empty objects updates/deletions
147 func (s *DB) BlockGlobalUpdate(enable bool) *DB {
148 s.blockGlobalUpdate = enable
149 return s
150 }
151
152 // HasBlockGlobalUpdate return state of block
153 func (s *DB) HasBlockGlobalUpdate() bool {
154 return s.blockGlobalUpdate
155 }
156
157 // SingularTable use singular table by default
131158 func (s *DB) SingularTable(enable bool) {
132159 modelStructsMap = newModelStructsMap()
133160 s.parent.singularTable = enable
134161 }
135162
163 // NewScope create a scope for current operation
164 func (s *DB) NewScope(value interface{}) *Scope {
165 dbClone := s.clone()
166 dbClone.Value = value
167 return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
168 }
169
170 // QueryExpr returns the query as expr object
171 func (s *DB) QueryExpr() *expr {
172 scope := s.NewScope(s.Value)
173 scope.InstanceSet("skip_bindvar", true)
174 scope.prepareQuerySQL()
175
176 return Expr(scope.SQL, scope.SQLVars...)
177 }
178
179 // Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query
136180 func (s *DB) Where(query interface{}, args ...interface{}) *DB {
137181 return s.clone().search.Where(query, args...).db
138182 }
139183
184 // Or filter records that match before conditions or this one, similar to `Where`
140185 func (s *DB) Or(query interface{}, args ...interface{}) *DB {
141186 return s.clone().search.Or(query, args...).db
142187 }
143188
189 // Not filter records that don't match current conditions, similar to `Where`
144190 func (s *DB) Not(query interface{}, args ...interface{}) *DB {
145191 return s.clone().search.Not(query, args...).db
146192 }
147193
148 func (s *DB) Limit(value interface{}) *DB {
149 return s.clone().search.Limit(value).db
150 }
151
152 func (s *DB) Offset(value interface{}) *DB {
153 return s.clone().search.Offset(value).db
154 }
155
156 func (s *DB) Order(value string, reorder ...bool) *DB {
194 // Limit specify the number of records to be retrieved
195 func (s *DB) Limit(limit interface{}) *DB {
196 return s.clone().search.Limit(limit).db
197 }
198
199 // Offset specify the number of records to skip before starting to return the records
200 func (s *DB) Offset(offset interface{}) *DB {
201 return s.clone().search.Offset(offset).db
202 }
203
204 // Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions
205 // db.Order("name DESC")
206 // db.Order("name DESC", true) // reorder
207 // db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
208 func (s *DB) Order(value interface{}, reorder ...bool) *DB {
157209 return s.clone().search.Order(value, reorder...).db
158210 }
159211
212 // Select specify fields that you want to retrieve from database when querying, by default, will select all fields;
213 // When creating/updating, specify fields that you want to save to database
160214 func (s *DB) Select(query interface{}, args ...interface{}) *DB {
161215 return s.clone().search.Select(query, args...).db
162216 }
163217
218 // Omit specify fields that you want to ignore when saving to database for creating, updating
164219 func (s *DB) Omit(columns ...string) *DB {
165220 return s.clone().search.Omit(columns...).db
166221 }
167222
223 // Group specify the group method on the find
168224 func (s *DB) Group(query string) *DB {
169225 return s.clone().search.Group(query).db
170226 }
171227
172 func (s *DB) Having(query string, values ...interface{}) *DB {
228 // Having specify HAVING conditions for GROUP BY
229 func (s *DB) Having(query interface{}, values ...interface{}) *DB {
173230 return s.clone().search.Having(query, values...).db
174231 }
175232
176 func (s *DB) Joins(query string) *DB {
177 return s.clone().search.Joins(query).db
178 }
179
233 // Joins specify Joins conditions
234 // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
235 func (s *DB) Joins(query string, args ...interface{}) *DB {
236 return s.clone().search.Joins(query, args...).db
237 }
238
239 // Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically
240 // func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
241 // return db.Where("amount > ?", 1000)
242 // }
243 //
244 // func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
245 // return func (db *gorm.DB) *gorm.DB {
246 // return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
247 // }
248 // }
249 //
250 // db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
251 // Refer https://jinzhu.github.io/gorm/crud.html#scopes
180252 func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
181253 for _, f := range funcs {
182254 s = f(s)
184256 return s
185257 }
186258
259 // Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/crud.html#soft-delete
187260 func (s *DB) Unscoped() *DB {
188261 return s.clone().search.unscoped().db
189262 }
190263
264 // Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate
191265 func (s *DB) Attrs(attrs ...interface{}) *DB {
192266 return s.clone().search.Attrs(attrs...).db
193267 }
194268
269 // Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate
195270 func (s *DB) Assign(attrs ...interface{}) *DB {
196271 return s.clone().search.Assign(attrs...).db
197272 }
198273
274 // First find first record that match given conditions, order by primary key
199275 func (s *DB) First(out interface{}, where ...interface{}) *DB {
200 newScope := s.clone().NewScope(out)
276 newScope := s.NewScope(out)
201277 newScope.Search.Limit(1)
202278 return newScope.Set("gorm:order_by_primary_key", "ASC").
203 inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
204 }
205
279 inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
280 }
281
282 // Take return a record that match given conditions, the order will depend on the database implementation
283 func (s *DB) Take(out interface{}, where ...interface{}) *DB {
284 newScope := s.NewScope(out)
285 newScope.Search.Limit(1)
286 return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
287 }
288
289 // Last find last record that match given conditions, order by primary key
206290 func (s *DB) Last(out interface{}, where ...interface{}) *DB {
207 newScope := s.clone().NewScope(out)
291 newScope := s.NewScope(out)
208292 newScope.Search.Limit(1)
209293 return newScope.Set("gorm:order_by_primary_key", "DESC").
210 inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
211 }
212
294 inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
295 }
296
297 // Find find records that match given conditions
213298 func (s *DB) Find(out interface{}, where ...interface{}) *DB {
214 return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
215 }
216
299 return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
300 }
301
302 // Scan scan value to a struct
217303 func (s *DB) Scan(dest interface{}) *DB {
218 return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db
219 }
220
304 return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
305 }
306
307 // Row return `*sql.Row` with given conditions
221308 func (s *DB) Row() *sql.Row {
222309 return s.NewScope(s.Value).row()
223310 }
224311
312 // Rows return `*sql.Rows` with given conditions
225313 func (s *DB) Rows() (*sql.Rows, error) {
226314 return s.NewScope(s.Value).rows()
227315 }
228316
317 // ScanRows scan `*sql.Rows` to give struct
318 func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
319 var (
320 scope = s.NewScope(result)
321 clone = scope.db
322 columns, err = rows.Columns()
323 )
324
325 if clone.AddError(err) == nil {
326 scope.scan(rows, columns, scope.Fields())
327 }
328
329 return clone.Error
330 }
331
332 // Pluck used to query single column from a model as a map
333 // var ages []int64
334 // db.Find(&users).Pluck("age", &ages)
229335 func (s *DB) Pluck(column string, value interface{}) *DB {
230336 return s.NewScope(s.Value).pluck(column, value).db
231337 }
232338
339 // Count get how many records for a model
233340 func (s *DB) Count(value interface{}) *DB {
234341 return s.NewScope(s.Value).count(value).db
235342 }
236343
344 // Related get related associations
237345 func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
238 return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
239 }
240
346 return s.NewScope(s.Value).related(value, foreignKeys...).db
347 }
348
349 // FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)
350 // https://jinzhu.github.io/gorm/crud.html#firstorinit
241351 func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
242352 c := s.clone()
243353 if result := c.First(out, where...); result.Error != nil {
246356 }
247357 c.NewScope(out).inlineCondition(where...).initialize()
248358 } else {
249 c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs), false)
359 c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs)
250360 }
251361 return c
252362 }
253363
364 // FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions)
365 // https://jinzhu.github.io/gorm/crud.html#firstorcreate
254366 func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
255367 c := s.clone()
256 if result := c.First(out, where...); result.Error != nil {
368 if result := s.First(out, where...); result.Error != nil {
257369 if !result.RecordNotFound() {
258370 return result
259371 }
260 c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callback.creates).db.Error)
372 return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db
261373 } else if len(c.search.assignAttrs) > 0 {
262 c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callback.updates).db.Error)
374 return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db
263375 }
264376 return c
265377 }
266378
379 // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
267380 func (s *DB) Update(attrs ...interface{}) *DB {
268381 return s.Updates(toSearchableMap(attrs...), true)
269382 }
270383
384 // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
271385 func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
272 return s.clone().NewScope(s.Value).
386 return s.NewScope(s.Value).
273387 Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
274388 InstanceSet("gorm:update_interface", values).
275 callCallbacks(s.parent.callback.updates).db
276 }
277
389 callCallbacks(s.parent.callbacks.updates).db
390 }
391
392 // UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
278393 func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
279394 return s.UpdateColumns(toSearchableMap(attrs...))
280395 }
281396
397 // UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
282398 func (s *DB) UpdateColumns(values interface{}) *DB {
283 return s.clone().NewScope(s.Value).
399 return s.NewScope(s.Value).
284400 Set("gorm:update_column", true).
285401 Set("gorm:save_associations", false).
286402 InstanceSet("gorm:update_interface", values).
287 callCallbacks(s.parent.callback.updates).db
288 }
289
403 callCallbacks(s.parent.callbacks.updates).db
404 }
405
406 // Save update value in database, if the value doesn't have primary key, will insert it
290407 func (s *DB) Save(value interface{}) *DB {
291 scope := s.clone().NewScope(value)
292 if scope.PrimaryKeyZero() {
293 return scope.callCallbacks(s.parent.callback.creates).db
294 }
295 return scope.callCallbacks(s.parent.callback.updates).db
296 }
297
408 scope := s.NewScope(value)
409 if !scope.PrimaryKeyZero() {
410 newDB := scope.callCallbacks(s.parent.callbacks.updates).db
411 if newDB.Error == nil && newDB.RowsAffected == 0 {
412 return s.New().FirstOrCreate(value)
413 }
414 return newDB
415 }
416 return scope.callCallbacks(s.parent.callbacks.creates).db
417 }
418
419 // Create insert the value into database
298420 func (s *DB) Create(value interface{}) *DB {
299 scope := s.clone().NewScope(value)
300 return scope.callCallbacks(s.parent.callback.creates).db
301 }
302
421 scope := s.NewScope(value)
422 return scope.callCallbacks(s.parent.callbacks.creates).db
423 }
424
425 // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
303426 func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
304 return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callback.deletes).db
305 }
306
427 return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
428 }
429
430 // Raw use raw sql as conditions, won't run it unless invoked by other methods
431 // db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result)
307432 func (s *DB) Raw(sql string, values ...interface{}) *DB {
308433 return s.clone().search.Raw(true).Where(sql, values...).db
309434 }
310435
436 // Exec execute raw sql
311437 func (s *DB) Exec(sql string, values ...interface{}) *DB {
312 scope := s.clone().NewScope(nil)
313 generatedSql := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
314 generatedSql = strings.TrimSuffix(strings.TrimPrefix(generatedSql, "("), ")")
315 scope.Raw(generatedSql)
438 scope := s.NewScope(nil)
439 generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true)
440 generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
441 scope.Raw(generatedSQL)
316442 return scope.Exec().db
317443 }
318444
445 // Model specify the model you would like to run db operations
446 // // update all users's name to `hello`
447 // db.Model(&User{}).Update("name", "hello")
448 // // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
449 // db.Model(&user).Update("name", "hello")
319450 func (s *DB) Model(value interface{}) *DB {
320451 c := s.clone()
321452 c.Value = value
322453 return c
323454 }
324455
456 // Table specify the table you would like to run db operations
325457 func (s *DB) Table(name string) *DB {
326458 clone := s.clone()
327459 clone.search.Table(name)
329461 return clone
330462 }
331463
464 // Debug start debug mode
332465 func (s *DB) Debug() *DB {
333466 return s.clone().LogMode(true)
334467 }
335468
469 // Begin begin a transaction
336470 func (s *DB) Begin() *DB {
337471 c := s.clone()
338 if db, ok := c.db.(sqlDb); ok {
472 if db, ok := c.db.(sqlDb); ok && db != nil {
339473 tx, err := db.Begin()
340 c.db = interface{}(tx).(sqlCommon)
474 c.db = interface{}(tx).(SQLCommon)
341475 c.AddError(err)
342476 } else {
343 c.AddError(CantStartTransaction)
477 c.AddError(ErrCantStartTransaction)
344478 }
345479 return c
346480 }
347481
482 // Commit commit a transaction
348483 func (s *DB) Commit() *DB {
349 if db, ok := s.db.(sqlTx); ok {
484 if db, ok := s.db.(sqlTx); ok && db != nil {
350485 s.AddError(db.Commit())
351486 } else {
352 s.AddError(NoValidTransaction)
487 s.AddError(ErrInvalidTransaction)
353488 }
354489 return s
355490 }
356491
492 // Rollback rollback a transaction
357493 func (s *DB) Rollback() *DB {
358 if db, ok := s.db.(sqlTx); ok {
494 if db, ok := s.db.(sqlTx); ok && db != nil {
359495 s.AddError(db.Rollback())
360496 } else {
361 s.AddError(NoValidTransaction)
497 s.AddError(ErrInvalidTransaction)
362498 }
363499 return s
364500 }
365501
502 // NewRecord check if value's primary key is blank
366503 func (s *DB) NewRecord(value interface{}) bool {
367 return s.clone().NewScope(value).PrimaryKeyZero()
368 }
369
504 return s.NewScope(value).PrimaryKeyZero()
505 }
506
507 // RecordNotFound check if returning ErrRecordNotFound error
370508 func (s *DB) RecordNotFound() bool {
371 return s.Error == RecordNotFound
372 }
373
374 // Migrations
375 func (s *DB) CreateTable(values ...interface{}) *DB {
376 db := s.clone()
377 for _, value := range values {
378 db = db.NewScope(value).createTable().db
509 for _, err := range s.GetErrors() {
510 if err == ErrRecordNotFound {
511 return true
512 }
513 }
514 return false
515 }
516
517 // CreateTable create table for models
518 func (s *DB) CreateTable(models ...interface{}) *DB {
519 db := s.Unscoped()
520 for _, model := range models {
521 db = db.NewScope(model).createTable().db
379522 }
380523 return db
381524 }
382525
526 // DropTable drop table for models
383527 func (s *DB) DropTable(values ...interface{}) *DB {
384528 db := s.clone()
385529 for _, value := range values {
530 if tableName, ok := value.(string); ok {
531 db = db.Table(tableName)
532 }
533
386534 db = db.NewScope(value).dropTable().db
387535 }
388536 return db
389537 }
390538
539 // DropTableIfExists drop table if it is exist
391540 func (s *DB) DropTableIfExists(values ...interface{}) *DB {
392541 db := s.clone()
393542 for _, value := range values {
394 db = db.NewScope(value).dropTableIfExists().db
543 if s.HasTable(value) {
544 db.AddError(s.DropTable(value).Error)
545 }
395546 }
396547 return db
397548 }
398549
550 // HasTable check has table or not
399551 func (s *DB) HasTable(value interface{}) bool {
400 scope := s.clone().NewScope(value)
401 tableName := scope.TableName()
402 has := scope.Dialect().HasTable(scope, tableName)
552 var (
553 scope = s.NewScope(value)
554 tableName string
555 )
556
557 if name, ok := value.(string); ok {
558 tableName = name
559 } else {
560 tableName = scope.TableName()
561 }
562
563 has := scope.Dialect().HasTable(tableName)
403564 s.AddError(scope.db.Error)
404565 return has
405566 }
406567
568 // AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
407569 func (s *DB) AutoMigrate(values ...interface{}) *DB {
408 db := s.clone()
570 db := s.Unscoped()
409571 for _, value := range values {
410 db = db.NewScope(value).NeedPtr().autoMigrate().db
572 db = db.NewScope(value).autoMigrate().db
411573 }
412574 return db
413575 }
414576
577 // ModifyColumn modify column to type
415578 func (s *DB) ModifyColumn(column string, typ string) *DB {
416 scope := s.clone().NewScope(s.Value)
579 scope := s.NewScope(s.Value)
417580 scope.modifyColumn(column, typ)
418581 return scope.db
419582 }
420583
584 // DropColumn drop a column
421585 func (s *DB) DropColumn(column string) *DB {
422 scope := s.clone().NewScope(s.Value)
586 scope := s.NewScope(s.Value)
423587 scope.dropColumn(column)
424588 return scope.db
425589 }
426590
427 func (s *DB) AddIndex(indexName string, column ...string) *DB {
428 scope := s.clone().NewScope(s.Value)
429 scope.addIndex(false, indexName, column...)
591 // AddIndex add index for columns with given name
592 func (s *DB) AddIndex(indexName string, columns ...string) *DB {
593 scope := s.Unscoped().NewScope(s.Value)
594 scope.addIndex(false, indexName, columns...)
430595 return scope.db
431596 }
432597
433 func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB {
434 scope := s.clone().NewScope(s.Value)
435 scope.addIndex(true, indexName, column...)
598 // AddUniqueIndex add unique index for columns with given name
599 func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB {
600 scope := s.Unscoped().NewScope(s.Value)
601 scope.addIndex(true, indexName, columns...)
436602 return scope.db
437603 }
438604
605 // RemoveIndex remove index with name
439606 func (s *DB) RemoveIndex(indexName string) *DB {
440 scope := s.clone().NewScope(s.Value)
607 scope := s.NewScope(s.Value)
441608 scope.removeIndex(indexName)
442609 return scope.db
443610 }
444611
445 func (s *DB) CurrentDatabase() string {
446 var (
447 scope = s.clone().NewScope(s.Value)
448 name = s.dialect.CurrentDatabase(scope)
449 )
450 return name
451 }
452
453 /*
454 Add foreign key to the given scope
455
456 Example:
457 db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
458 */
612 // AddForeignKey Add foreign key to the given scope, e.g:
613 // db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
459614 func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
460 scope := s.clone().NewScope(s.Value)
615 scope := s.NewScope(s.Value)
461616 scope.addForeignKey(field, dest, onDelete, onUpdate)
462617 return scope.db
463618 }
464619
620 // RemoveForeignKey Remove foreign key from the given scope, e.g:
621 // db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)")
622 func (s *DB) RemoveForeignKey(field string, dest string) *DB {
623 scope := s.clone().NewScope(s.Value)
624 scope.removeForeignKey(field, dest)
625 return scope.db
626 }
627
628 // Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode
465629 func (s *DB) Association(column string) *Association {
466630 var err error
467 scope := s.clone().NewScope(s.Value)
631 var scope = s.Set("gorm:association:source", s.Value).NewScope(s.Value)
468632
469633 if primaryField := scope.PrimaryField(); primaryField.IsBlank {
470634 err = errors.New("primary key can't be nil")
473637 if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
474638 err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
475639 } else {
476 return &Association{Scope: scope, Column: column, Field: field}
640 return &Association{scope: scope, column: column, field: field}
477641 }
478642 } else {
479643 err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
483647 return &Association{Error: err}
484648 }
485649
650 // Preload preload associations with given conditions
651 // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
486652 func (s *DB) Preload(column string, conditions ...interface{}) *DB {
487653 return s.clone().search.Preload(column, conditions...).db
488654 }
489655
490 // Set set value by name
656 // Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting
491657 func (s *DB) Set(name string, value interface{}) *DB {
492658 return s.clone().InstantSet(name, value)
493659 }
494660
661 // InstantSet instant set setting, will affect current db
495662 func (s *DB) InstantSet(name string, value interface{}) *DB {
496663 s.values[name] = value
497664 return s
498665 }
499666
500 // Get get value by name
667 // Get get setting by name
501668 func (s *DB) Get(name string) (value interface{}, ok bool) {
502669 value, ok = s.values[name]
503670 return
504671 }
505672
673 // SetJoinTableHandler set a model's join table handler for a relation
506674 func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
507675 scope := s.NewScope(source)
508676 for _, field := range scope.GetModelStruct().StructFields {
509677 if field.Name == column || field.DBName == column {
510 if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" {
678 if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
511679 source := (&Scope{Value: source}).GetModelStruct().ModelType
512680 destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
513681 handler.Setup(field.Relationship, many2many, source, destination)
514682 field.Relationship.JoinTableHandler = handler
515 if table := handler.Table(s); scope.Dialect().HasTable(scope, table) {
683 if table := handler.Table(s); scope.Dialect().HasTable(table) {
516684 s.Table(table).AutoMigrate(handler)
517685 }
518686 }
520688 }
521689 }
522690
691 // AddError add error to the db
523692 func (s *DB) AddError(err error) error {
524693 if err != nil {
525 if err != RecordNotFound {
694 if err != ErrRecordNotFound {
526695 if s.logMode == 0 {
527696 go s.print(fileWithLineNum(), err)
528697 } else {
529698 s.log(err)
530699 }
531700
532 errors := Errors{errors: s.GetErrors()}
533 errors.Add(err)
534 if len(errors.GetErrors()) > 1 {
701 errors := Errors(s.GetErrors())
702 errors = errors.Add(err)
703 if len(errors) > 1 {
535704 err = errors
536705 }
537706 }
541710 return err
542711 }
543712
544 func (s *DB) GetErrors() (errors []error) {
545 if errs, ok := s.Error.(errorsInterface); ok {
546 return errs.GetErrors()
713 // GetErrors get happened errors from the db
714 func (s *DB) GetErrors() []error {
715 if errs, ok := s.Error.(Errors); ok {
716 return errs
547717 } else if s.Error != nil {
548718 return []error{s.Error}
549719 }
550 return
551 }
720 return []error{}
721 }
722
723 ////////////////////////////////////////////////////////////////////////////////
724 // Private Methods For DB
725 ////////////////////////////////////////////////////////////////////////////////
726
727 func (s *DB) clone() *DB {
728 db := &DB{
729 db: s.db,
730 parent: s.parent,
731 logger: s.logger,
732 logMode: s.logMode,
733 values: map[string]interface{}{},
734 Value: s.Value,
735 Error: s.Error,
736 blockGlobalUpdate: s.blockGlobalUpdate,
737 }
738
739 for key, value := range s.values {
740 db.values[key] = value
741 }
742
743 if s.search == nil {
744 db.search = &search{limit: -1, offset: -1}
745 } else {
746 db.search = s.search.clone()
747 }
748
749 db.search.db = db
750 return db
751 }
752
753 func (s *DB) print(v ...interface{}) {
754 s.logger.Print(v...)
755 }
756
757 func (s *DB) log(v ...interface{}) {
758 if s != nil && s.logMode == 2 {
759 s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
760 }
761 }
762
763 func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
764 if s.logMode == 2 {
765 s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected)
766 }
767 }
+0
-36
main_private.go less more
0 package gorm
1
2 import "time"
3
4 func (s *DB) clone() *DB {
5 db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error}
6
7 for key, value := range s.values {
8 db.values[key] = value
9 }
10
11 if s.search == nil {
12 db.search = &search{}
13 } else {
14 db.search = s.search.clone()
15 }
16
17 db.search.db = &db
18 return &db
19 }
20
21 func (s *DB) print(v ...interface{}) {
22 s.logger.(logger).Print(v...)
23 }
24
25 func (s *DB) log(v ...interface{}) {
26 if s != nil && s.logMode == 2 {
27 s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
28 }
29 }
30
31 func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
32 if s.logMode == 2 {
33 s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars)
34 }
35 }
33 "database/sql"
44 "database/sql/driver"
55 "fmt"
6 "os"
7 "path/filepath"
8 "reflect"
69 "strconv"
7
8 _ "github.com/denisenkom/go-mssqldb"
9 testdb "github.com/erikstmartin/go-testdb"
10 _ "github.com/go-sql-driver/mysql"
11 "github.com/jinzhu/gorm"
12 "github.com/jinzhu/now"
13 _ "github.com/lib/pq"
14 _ "github.com/mattn/go-sqlite3"
15
16 "os"
1710 "testing"
1811 "time"
12
13 "github.com/erikstmartin/go-testdb"
14 "github.com/jinzhu/gorm"
15 _ "github.com/jinzhu/gorm/dialects/mssql"
16 _ "github.com/jinzhu/gorm/dialects/mysql"
17 "github.com/jinzhu/gorm/dialects/postgres"
18 _ "github.com/jinzhu/gorm/dialects/sqlite"
19 "github.com/jinzhu/now"
1920 )
2021
2122 var (
22 DB gorm.DB
23 DB *gorm.DB
2324 t1, t2, t3, t4, t5 time.Time
2425 )
2526
3031 panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err))
3132 }
3233
33 // DB.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)})
34 // DB.SetLogger(log.New(os.Stdout, "\r\n", 0))
35 // DB.LogMode(true)
36 DB.LogMode(false)
37
38 DB.DB().SetMaxIdleConns(10)
39
4034 runMigration()
4135 }
4236
43 func OpenTestConnection() (db gorm.DB, err error) {
37 func OpenTestConnection() (db *gorm.DB, err error) {
38 dbDSN := os.Getenv("GORM_DSN")
4439 switch os.Getenv("GORM_DIALECT") {
4540 case "mysql":
46 // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm';
47 // CREATE DATABASE gorm;
48 // GRANT ALL ON gorm.* TO 'gorm'@'localhost';
4941 fmt.Println("testing mysql...")
50 db, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True")
42 if dbDSN == "" {
43 dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"
44 }
45 db, err = gorm.Open("mysql", dbDSN)
5146 case "postgres":
5247 fmt.Println("testing postgres...")
53 db, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable")
54 case "foundation":
55 fmt.Println("testing foundation...")
56 db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable")
48 if dbDSN == "" {
49 dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
50 }
51 db, err = gorm.Open("postgres", dbDSN)
5752 case "mssql":
53 // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
54 // CREATE DATABASE gorm;
55 // USE gorm;
56 // CREATE USER gorm FROM LOGIN gorm;
57 // sp_changedbowner 'gorm';
5858 fmt.Println("testing mssql...")
59 db, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433")
59 if dbDSN == "" {
60 dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
61 }
62 db, err = gorm.Open("mssql", dbDSN)
6063 default:
6164 fmt.Println("testing sqlite3...")
62 db, err = gorm.Open("sqlite3", "/tmp/gorm.db")
63 }
65 db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db"))
66 }
67
68 // db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)})
69 // db.SetLogger(log.New(os.Stdout, "\r\n", 0))
70 if debug := os.Getenv("DEBUG"); debug == "true" {
71 db.LogMode(true)
72 } else if debug == "false" {
73 db.LogMode(false)
74 }
75
76 db.DB().SetMaxIdleConns(10)
77
6478 return
6579 }
6680
6983 ID string `gorm:"primary_key"`
7084 Name string
7185 }
86 DB.DropTable(&UUIDStruct{})
7287 DB.AutoMigrate(&UUIDStruct{})
7388
7489 data := UUIDStruct{ID: "uuid", Name: "hello"}
75 if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" {
90 if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.Name != "hello" {
91 t.Errorf("string primary key should not be populated")
92 }
93
94 data = UUIDStruct{ID: "uuid", Name: "hello world"}
95 if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.Name != "hello world" {
7696 t.Errorf("string primary key should not be populated")
7797 }
7898 }
113133 DB.Create(getPreparedUser("pluck_user3", "pluck_user"))
114134
115135 if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil {
116 t.Errorf("No errors should happen if set table for pluck", err.Error())
136 t.Error("No errors should happen if set table for pluck", err)
117137 }
118138
119139 var users []User
163183 Stuff string
164184 }
165185 DB.DropTable(&Foo{})
186
187 // Table should not exist at this point, HasTable should return false
188 if ok := DB.HasTable("foos"); ok {
189 t.Errorf("Table should not exist, but does")
190 }
166191 if ok := DB.HasTable(&Foo{}); ok {
167192 t.Errorf("Table should not exist, but does")
168193 }
194
195 // We create the table
169196 if err := DB.CreateTable(&Foo{}).Error; err != nil {
170197 t.Errorf("Table should be created")
198 }
199
200 // And now it should exits, and HasTable should return true
201 if ok := DB.HasTable("foos"); !ok {
202 t.Errorf("Table should exist, but HasTable informs it does not")
171203 }
172204 if ok := DB.HasTable(&Foo{}); !ok {
173205 t.Errorf("Table should exist, but HasTable informs it does not")
227259 DB.SingularTable(false)
228260 }
229261
230 func TestSqlNullValue(t *testing.T) {
262 func TestNullValues(t *testing.T) {
231263 DB.DropTable(&NullValue{})
232264 DB.AutoMigrate(&NullValue{})
233265
234 if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello", Valid: true},
266 if err := DB.Save(&NullValue{
267 Name: sql.NullString{String: "hello", Valid: true},
268 Gender: &sql.NullString{String: "M", Valid: true},
235269 Age: sql.NullInt64{Int64: 18, Valid: true},
236270 Male: sql.NullBool{Bool: true, Valid: true},
237271 Height: sql.NullFloat64{Float64: 100.11, Valid: true},
243277 var nv NullValue
244278 DB.First(&nv, "name = ?", "hello")
245279
246 if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
280 if nv.Name.String != "hello" || nv.Gender.String != "M" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
247281 t.Errorf("Should be able to fetch null value")
248282 }
249283
250 if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello-2", Valid: true},
284 if err := DB.Save(&NullValue{
285 Name: sql.NullString{String: "hello-2", Valid: true},
286 Gender: &sql.NullString{String: "F", Valid: true},
251287 Age: sql.NullInt64{Int64: 18, Valid: false},
252288 Male: sql.NullBool{Bool: true, Valid: true},
253289 Height: sql.NullFloat64{Float64: 100.11, Valid: true},
258294
259295 var nv2 NullValue
260296 DB.First(&nv2, "name = ?", "hello-2")
261 if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
297 if nv2.Name.String != "hello-2" || nv2.Gender.String != "F" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
262298 t.Errorf("Should be able to fetch null value")
263299 }
264300
265 if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello-3", Valid: false},
301 if err := DB.Save(&NullValue{
302 Name: sql.NullString{String: "hello-3", Valid: false},
303 Gender: &sql.NullString{String: "M", Valid: true},
266304 Age: sql.NullInt64{Int64: 18, Valid: false},
267305 Male: sql.NullBool{Bool: true, Valid: true},
268306 Height: sql.NullFloat64{Float64: 100.11, Valid: true},
272310 }
273311 }
274312
313 func TestNullValuesWithFirstOrCreate(t *testing.T) {
314 var nv1 = NullValue{
315 Name: sql.NullString{String: "first_or_create", Valid: true},
316 Gender: &sql.NullString{String: "M", Valid: true},
317 }
318
319 var nv2 NullValue
320 result := DB.Where(nv1).FirstOrCreate(&nv2)
321
322 if result.RowsAffected != 1 {
323 t.Errorf("RowsAffected should be 1 after create some record")
324 }
325
326 if result.Error != nil {
327 t.Errorf("Should not raise any error, but got %v", result.Error)
328 }
329
330 if nv2.Name.String != "first_or_create" || nv2.Gender.String != "M" {
331 t.Errorf("first or create with nullvalues")
332 }
333
334 if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil {
335 t.Errorf("Should not raise any error, but got %v", err)
336 }
337
338 if nv2.Age.Int64 != 18 {
339 t.Errorf("should update age to 18")
340 }
341 }
342
275343 func TestTransaction(t *testing.T) {
276344 tx := DB.Begin()
277345 u := User{Name: "transcation"}
311379 }
312380
313381 func TestRow(t *testing.T) {
314 user1 := User{Name: "RowUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
315 user2 := User{Name: "RowUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
316 user3 := User{Name: "RowUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
382 user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")}
383 user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}
384 user3 := User{Name: "RowUser3", Age: 20, Birthday: parseTime("2020-1-1")}
317385 DB.Save(&user1).Save(&user2).Save(&user3)
318386
319387 row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row()
325393 }
326394
327395 func TestRows(t *testing.T) {
328 user1 := User{Name: "RowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
329 user2 := User{Name: "RowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
330 user3 := User{Name: "RowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
396 user1 := User{Name: "RowsUser1", Age: 1, Birthday: parseTime("2000-1-1")}
397 user2 := User{Name: "RowsUser2", Age: 10, Birthday: parseTime("2010-1-1")}
398 user3 := User{Name: "RowsUser3", Age: 20, Birthday: parseTime("2020-1-1")}
331399 DB.Save(&user1).Save(&user2).Save(&user3)
332400
333401 rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
334402 if err != nil {
335 t.Errorf("Not error should happen, but got")
403 t.Errorf("Not error should happen, got %v", err)
336404 }
337405
338406 count := 0
342410 rows.Scan(&name, &age)
343411 count++
344412 }
413
345414 if count != 2 {
346 t.Errorf("Should found two records with name 3")
415 t.Errorf("Should found two records")
416 }
417 }
418
419 func TestScanRows(t *testing.T) {
420 user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: parseTime("2000-1-1")}
421 user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: parseTime("2010-1-1")}
422 user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: parseTime("2020-1-1")}
423 DB.Save(&user1).Save(&user2).Save(&user3)
424
425 rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
426 if err != nil {
427 t.Errorf("Not error should happen, got %v", err)
428 }
429
430 type Result struct {
431 Name string
432 Age int
433 }
434
435 var results []Result
436 for rows.Next() {
437 var result Result
438 if err := DB.ScanRows(rows, &result); err != nil {
439 t.Errorf("should get no error, but got %v", err)
440 }
441 results = append(results, result)
442 }
443
444 if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
445 t.Errorf("Should find expected results")
347446 }
348447 }
349448
350449 func TestScan(t *testing.T) {
351 user1 := User{Name: "ScanUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
352 user2 := User{Name: "ScanUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
353 user3 := User{Name: "ScanUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
450 user1 := User{Name: "ScanUser1", Age: 1, Birthday: parseTime("2000-1-1")}
451 user2 := User{Name: "ScanUser2", Age: 10, Birthday: parseTime("2010-1-1")}
452 user3 := User{Name: "ScanUser3", Age: 20, Birthday: parseTime("2020-1-1")}
354453 DB.Save(&user1).Save(&user2).Save(&user3)
355454
356455 type result struct {
364463 t.Errorf("Scan into struct should work")
365464 }
366465
367 var doubleAgeRes result
368 DB.Table("users").Select("age + age as age").Where("name = ?", user3.Name).Scan(&doubleAgeRes)
466 var doubleAgeRes = &result{}
467 if err := DB.Table("users").Select("age + age as age").Where("name = ?", user3.Name).Scan(&doubleAgeRes).Error; err != nil {
468 t.Errorf("Scan to pointer of pointer")
469 }
369470 if doubleAgeRes.Age != res.Age*2 {
370471 t.Errorf("Scan double age as age")
371472 }
378479 }
379480
380481 func TestRaw(t *testing.T) {
381 user1 := User{Name: "ExecRawSqlUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
382 user2 := User{Name: "ExecRawSqlUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
383 user3 := User{Name: "ExecRawSqlUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
482 user1 := User{Name: "ExecRawSqlUser1", Age: 1, Birthday: parseTime("2000-1-1")}
483 user2 := User{Name: "ExecRawSqlUser2", Age: 10, Birthday: parseTime("2010-1-1")}
484 user3 := User{Name: "ExecRawSqlUser3", Age: 20, Birthday: parseTime("2020-1-1")}
384485 DB.Save(&user1).Save(&user2).Save(&user3)
385486
386487 type result struct {
404505 }
405506
406507 DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
407 if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound {
508 if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound {
408509 t.Error("Raw sql to update records")
409510 }
410511 }
425526
426527 func TestJoins(t *testing.T) {
427528 var user = User{
428 Name: "joins",
429 Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
430 }
431 DB.Save(&user)
432
433 var result User
434 DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result)
435 if result.Name != "joins" || result.Id != user.Id {
436 t.Errorf("Should find all two emails with Join")
529 Name: "joins",
530 CreditCard: CreditCard{Number: "411111111111"},
531 Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
532 }
533 DB.Save(&user)
534
535 var users1 []User
536 DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1)
537 if len(users1) != 2 {
538 t.Errorf("should find two users using left join")
539 }
540
541 var users2 []User
542 DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2)
543 if len(users2) != 1 {
544 t.Errorf("should find one users using left join with conditions")
545 }
546
547 var users3 []User
548 DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where("name = ?", "joins").First(&users3)
549 if len(users3) != 1 {
550 t.Errorf("should find one users using multiple left join conditions")
551 }
552
553 var users4 []User
554 DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "422222222222").Where("name = ?", "joins").First(&users4)
555 if len(users4) != 0 {
556 t.Errorf("should find no user when searching with unexisting credit card")
557 }
558
559 var users5 []User
560 db5 := DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where(User{Id: 1}).Where(Email{Id: 1}).Not(Email{Id: 10}).First(&users5)
561 if db5.Error != nil {
562 t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error())
437563 }
438564 }
439565
450576 DB.Save(&user)
451577
452578 var results []result
453 DB.Table("users").Select("name, email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results)
579 DB.Table("users").Select("name, emails.email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results)
454580 if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" {
455581 t.Errorf("Should find all two emails with Join select")
456582 }
478604 }
479605 }
480606
607 func TestQueryBuilderSubselectInWhere(t *testing.T) {
608 user := User{Name: "query_expr_select_ruser1", Email: "root@user1.com", Age: 32}
609 DB.Save(&user)
610 user = User{Name: "query_expr_select_ruser2", Email: "nobody@user2.com", Age: 16}
611 DB.Save(&user)
612 user = User{Name: "query_expr_select_ruser3", Email: "root@user3.com", Age: 64}
613 DB.Save(&user)
614 user = User{Name: "query_expr_select_ruser4", Email: "somebody@user3.com", Age: 128}
615 DB.Save(&user)
616
617 var users []User
618 DB.Select("*").Where("name IN (?)", DB.
619 Select("name").Table("users").Where("name LIKE ?", "query_expr_select%").QueryExpr()).Find(&users)
620
621 if len(users) != 4 {
622 t.Errorf("Four users should be found, instead found %d", len(users))
623 }
624
625 DB.Select("*").Where("name LIKE ?", "query_expr_select%").Where("age >= (?)", DB.
626 Select("AVG(age)").Table("users").Where("name LIKE ?", "query_expr_select%").QueryExpr()).Find(&users)
627
628 if len(users) != 2 {
629 t.Errorf("Two users should be found, instead found %d", len(users))
630 }
631 }
632
633 func TestQueryBuilderRawQueryWithSubquery(t *testing.T) {
634 user := User{Name: "subquery_test_user1", Age: 10}
635 DB.Save(&user)
636 user = User{Name: "subquery_test_user2", Age: 11}
637 DB.Save(&user)
638 user = User{Name: "subquery_test_user3", Age: 12}
639 DB.Save(&user)
640
641 var count int
642 err := DB.Raw("select count(*) from (?) tmp",
643 DB.Table("users").
644 Select("name").
645 Where("age >= ? and name in (?)", 10, []string{"subquery_test_user1", "subquery_test_user2"}).
646 Group("name").
647 QueryExpr(),
648 ).Count(&count).Error
649
650 if err != nil {
651 t.Errorf("Expected to get no errors, but got %v", err)
652 }
653 if count != 2 {
654 t.Errorf("Row count must be 2, instead got %d", count)
655 }
656
657 err = DB.Raw("select count(*) from (?) tmp",
658 DB.Table("users").
659 Select("name").
660 Where("name LIKE ?", "subquery_test%").
661 Not("age <= ?", 10).Not("name in (?)", []string{"subquery_test_user1", "subquery_test_user2"}).
662 Group("name").
663 QueryExpr(),
664 ).Count(&count).Error
665
666 if err != nil {
667 t.Errorf("Expected to get no errors, but got %v", err)
668 }
669 if count != 1 {
670 t.Errorf("Row count must be 1, instead got %d", count)
671 }
672 }
673
674 func TestQueryBuilderSubselectInHaving(t *testing.T) {
675 user := User{Name: "query_expr_having_ruser1", Email: "root@user1.com", Age: 64}
676 DB.Save(&user)
677 user = User{Name: "query_expr_having_ruser2", Email: "root@user2.com", Age: 128}
678 DB.Save(&user)
679 user = User{Name: "query_expr_having_ruser3", Email: "root@user1.com", Age: 64}
680 DB.Save(&user)
681 user = User{Name: "query_expr_having_ruser4", Email: "root@user2.com", Age: 128}
682 DB.Save(&user)
683
684 var users []User
685 DB.Select("AVG(age) as avgage").Where("name LIKE ?", "query_expr_having_%").Group("email").Having("AVG(age) > (?)", DB.
686 Select("AVG(age)").Where("name LIKE ?", "query_expr_having_%").Table("users").QueryExpr()).Find(&users)
687
688 if len(users) != 1 {
689 t.Errorf("Two user group should be found, instead found %d", len(users))
690 }
691 }
692
481693 func DialectHasTzSupport() bool {
482694 // NB: mssql and FoundationDB do not support time zones.
483 if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" {
695 if dialect := os.Getenv("GORM_DIALECT"); dialect == "foundation" {
484696 return false
485697 }
486698 return true
495707
496708 for index, vtime := range times {
497709 name := "time_with_zone_" + strconv.Itoa(index)
498 user := User{Name: name, Birthday: vtime}
710 user := User{Name: name, Birthday: &vtime}
499711
500712 if !DialectHasTzSupport() {
501713 // If our driver dialect doesn't support TZ's, just use UTC for everything here.
502 user.Birthday = vtime.UTC()
714 utcBirthday := user.Birthday.UTC()
715 user.Birthday = &utcBirthday
503716 }
504717
505718 DB.Save(&user)
513726 DB.First(&findUser, "name = ?", name)
514727 foundBirthday = findUser.Birthday.UTC().Format(format)
515728 if foundBirthday != expectedBirthday {
516 t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v or %+v", name, expectedBirthday, foundBirthday)
729 t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday)
517730 }
518731
519732 if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
529742 func TestHstore(t *testing.T) {
530743 type Details struct {
531744 Id int64
532 Bulk gorm.Hstore
745 Bulk postgres.Hstore
533746 }
534747
535748 if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
615828 }
616829
617830 var user User
618 if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound {
831 if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound {
619832 t.Errorf("Should have found existing record")
833 }
834 }
835
836 func TestDdlErrors(t *testing.T) {
837 var err error
838
839 if err = DB.Close(); err != nil {
840 t.Errorf("Closing DDL test db connection err=%s", err)
841 }
842 defer func() {
843 // Reopen DB connection.
844 if DB, err = OpenTestConnection(); err != nil {
845 t.Fatalf("Failed re-opening db connection: %s", err)
846 }
847 }()
848
849 if err := DB.Find(&User{}).Error; err == nil {
850 t.Errorf("Expected operation on closed db to produce an error, but err was nil")
851 }
852 }
853
854 func TestOpenWithOneParameter(t *testing.T) {
855 db, err := gorm.Open("dialect")
856 if db != nil {
857 t.Error("Open with one parameter returned non nil for db")
858 }
859 if err == nil {
860 t.Error("Open with one parameter returned err as nil")
861 }
862 }
863
864 func TestBlockGlobalUpdate(t *testing.T) {
865 db := DB.New()
866 db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
867
868 err := db.Model(&Toy{}).Update("OwnerType", "Human").Error
869 if err != nil {
870 t.Error("Unexpected error on global update")
871 }
872
873 err = db.Delete(&Toy{}).Error
874 if err != nil {
875 t.Error("Unexpected error on global delete")
876 }
877
878 db.BlockGlobalUpdate(true)
879
880 db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
881
882 err = db.Model(&Toy{}).Update("OwnerType", "Human").Error
883 if err == nil {
884 t.Error("Expected error on global update")
885 }
886
887 err = db.Model(&Toy{}).Where(&Toy{OwnerType: "Martian"}).Update("OwnerType", "Astronaut").Error
888 if err != nil {
889 t.Error("Unxpected error on conditional update")
890 }
891
892 err = db.Delete(&Toy{}).Error
893 if err == nil {
894 t.Error("Expected error on global delete")
895 }
896 err = db.Where(&Toy{OwnerType: "Martian"}).Delete(&Toy{}).Error
897 if err != nil {
898 t.Error("Unexpected error on conditional delete")
620899 }
621900 }
622901
624903 b.N = 2000
625904 for x := 0; x < b.N; x++ {
626905 e := strconv.Itoa(x) + "benchmark@example.org"
627 email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
906 now := time.Now()
907 email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now}
628908 // Insert
629909 DB.Save(&email)
630910 // Query
631 DB.First(&BigEmail{}, "email = ?", e)
911 DB.First(&EmailWithIdx{}, "email = ?", e)
632912 // Update
633913 DB.Model(&email).UpdateColumn("email", "new-"+e)
634914 // Delete
648928 for x := 0; x < b.N; x++ {
649929 var id int64
650930 e := strconv.Itoa(x) + "benchmark@example.org"
651 email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
931 now := time.Now()
932 email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now}
652933 // Insert
653934 DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
654935 // Query
660941 DB.Exec(deleteSql, id)
661942 }
662943 }
944
945 func parseTime(str string) *time.Time {
946 t := now.New(time.Now().UTC()).MustParse(str)
947 return &t
948 }
00 package gorm_test
11
22 import (
3 "database/sql"
4 "database/sql/driver"
5 "errors"
36 "fmt"
7 "os"
8 "reflect"
9 "strconv"
410 "testing"
511 "time"
12
13 "github.com/jinzhu/gorm"
614 )
15
16 type User struct {
17 Id int64
18 Age int64
19 UserNum Num
20 Name string `sql:"size:255"`
21 Email string
22 Birthday *time.Time // Time
23 CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
24 UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
25 Emails []Email // Embedded structs
26 BillingAddress Address // Embedded struct
27 BillingAddressID sql.NullInt64 // Embedded struct's foreign key
28 ShippingAddress Address // Embedded struct
29 ShippingAddressId int64 // Embedded struct's foreign key
30 CreditCard CreditCard
31 Latitude float64
32 Languages []Language `gorm:"many2many:user_languages;"`
33 CompanyID *int
34 Company Company
35 Role Role
36 Password EncryptedData
37 PasswordHash []byte
38 IgnoreMe int64 `sql:"-"`
39 IgnoreStringSlice []string `sql:"-"`
40 Ignored struct{ Name string } `sql:"-"`
41 IgnoredPointer *User `sql:"-"`
42 }
43
44 type NotSoLongTableName struct {
45 Id int64
46 ReallyLongThingID int64
47 ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit
48 }
49
50 type ReallyLongTableNameToTestMySQLNameLengthLimit struct {
51 Id int64
52 }
53
54 type ReallyLongThingThatReferencesShort struct {
55 Id int64
56 ShortID int64
57 Short Short
58 }
59
60 type Short struct {
61 Id int64
62 }
63
64 type CreditCard struct {
65 ID int8
66 Number string
67 UserId sql.NullInt64
68 CreatedAt time.Time `sql:"not null"`
69 UpdatedAt time.Time
70 DeletedAt *time.Time `sql:"column:deleted_time"`
71 }
72
73 type Email struct {
74 Id int16
75 UserId int
76 Email string `sql:"type:varchar(100);"`
77 CreatedAt time.Time
78 UpdatedAt time.Time
79 }
80
81 type Address struct {
82 ID int
83 Address1 string
84 Address2 string
85 Post string
86 CreatedAt time.Time
87 UpdatedAt time.Time
88 DeletedAt *time.Time
89 }
90
91 type Language struct {
92 gorm.Model
93 Name string
94 Users []User `gorm:"many2many:user_languages;"`
95 }
96
97 type Product struct {
98 Id int64
99 Code string
100 Price int64
101 CreatedAt time.Time
102 UpdatedAt time.Time
103 AfterFindCallTimes int64
104 BeforeCreateCallTimes int64
105 AfterCreateCallTimes int64
106 BeforeUpdateCallTimes int64
107 AfterUpdateCallTimes int64
108 BeforeSaveCallTimes int64
109 AfterSaveCallTimes int64
110 BeforeDeleteCallTimes int64
111 AfterDeleteCallTimes int64
112 }
113
114 type Company struct {
115 Id int64
116 Name string
117 Owner *User `sql:"-"`
118 }
119
120 type EncryptedData []byte
121
122 func (data *EncryptedData) Scan(value interface{}) error {
123 if b, ok := value.([]byte); ok {
124 if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' {
125 return errors.New("Too short")
126 }
127
128 *data = b[3:]
129 return nil
130 }
131
132 return errors.New("Bytes expected")
133 }
134
135 func (data EncryptedData) Value() (driver.Value, error) {
136 if len(data) > 0 && data[0] == 'x' {
137 //needed to test failures
138 return nil, errors.New("Should not start with 'x'")
139 }
140
141 //prepend asterisks
142 return append([]byte("***"), data...), nil
143 }
144
145 type Role struct {
146 Name string `gorm:"size:256"`
147 }
148
149 func (role *Role) Scan(value interface{}) error {
150 if b, ok := value.([]uint8); ok {
151 role.Name = string(b)
152 } else {
153 role.Name = value.(string)
154 }
155 return nil
156 }
157
158 func (role Role) Value() (driver.Value, error) {
159 return role.Name, nil
160 }
161
162 func (role Role) IsAdmin() bool {
163 return role.Name == "admin"
164 }
165
166 type Num int64
167
168 func (i *Num) Scan(src interface{}) error {
169 switch s := src.(type) {
170 case []byte:
171 n, _ := strconv.Atoi(string(s))
172 *i = Num(n)
173 case int64:
174 *i = Num(s)
175 default:
176 return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
177 }
178 return nil
179 }
180
181 type Animal struct {
182 Counter uint64 `gorm:"primary_key:yes"`
183 Name string `sql:"DEFAULT:'galeone'"`
184 From string //test reserved sql keyword as field name
185 Age time.Time `sql:"DEFAULT:current_timestamp"`
186 unexported string // unexported value
187 CreatedAt time.Time
188 UpdatedAt time.Time
189 }
190
191 type JoinTable struct {
192 From uint64
193 To uint64
194 Time time.Time `sql:"default: null"`
195 }
196
197 type Post struct {
198 Id int64
199 CategoryId sql.NullInt64
200 MainCategoryId int64
201 Title string
202 Body string
203 Comments []*Comment
204 Category Category
205 MainCategory Category
206 }
207
208 type Category struct {
209 gorm.Model
210 Name string
211
212 Categories []Category
213 CategoryID *uint
214 }
215
216 type Comment struct {
217 gorm.Model
218 PostId int64
219 Content string
220 Post Post
221 }
222
223 // Scanner
224 type NullValue struct {
225 Id int64
226 Name sql.NullString `sql:"not null"`
227 Gender *sql.NullString `sql:"not null"`
228 Age sql.NullInt64
229 Male sql.NullBool
230 Height sql.NullFloat64
231 AddedAt NullTime
232 }
233
234 type NullTime struct {
235 Time time.Time
236 Valid bool
237 }
238
239 func (nt *NullTime) Scan(value interface{}) error {
240 if value == nil {
241 nt.Valid = false
242 return nil
243 }
244 nt.Time, nt.Valid = value.(time.Time), true
245 return nil
246 }
247
248 func (nt NullTime) Value() (driver.Value, error) {
249 if !nt.Valid {
250 return nil, nil
251 }
252 return nt.Time, nil
253 }
254
255 func getPreparedUser(name string, role string) *User {
256 var company Company
257 DB.Where(Company{Name: role}).FirstOrCreate(&company)
258
259 return &User{
260 Name: name,
261 Age: 20,
262 Role: Role{role},
263 BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
264 ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
265 CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
266 Emails: []Email{
267 {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
268 },
269 Company: company,
270 Languages: []Language{
271 {Name: fmt.Sprintf("lang_1_%v", name)},
272 {Name: fmt.Sprintf("lang_2_%v", name)},
273 },
274 }
275 }
7276
8277 func runMigration() {
9278 if err := DB.DropTableIfExists(&User{}).Error; err != nil {
14283 DB.Exec(fmt.Sprintf("drop table %v;", table))
15284 }
16285
17 values := []interface{}{&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}}
286 values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}}
18287 for _, value := range values {
19288 DB.DropTable(value)
20289 }
21
22290 if err := DB.AutoMigrate(values...).Error; err != nil {
23291 panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
24292 }
30298 }
31299
32300 scope := DB.NewScope(&Email{})
33 if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
301 if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
34302 t.Errorf("Email should have index idx_email_email")
35303 }
36304
38306 t.Errorf("Got error when tried to remove index: %+v", err)
39307 }
40308
41 if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
309 if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
42310 t.Errorf("Email's index idx_email_email should be deleted")
43311 }
44312
46314 t.Errorf("Got error when tried to create index: %+v", err)
47315 }
48316
49 if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
317 if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
50318 t.Errorf("Email should have index idx_email_email_and_user_id")
51319 }
52320
54322 t.Errorf("Got error when tried to remove index: %+v", err)
55323 }
56324
57 if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
325 if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
58326 t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
59327 }
60328
62330 t.Errorf("Got error when tried to create index: %+v", err)
63331 }
64332
65 if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
333 if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
66334 t.Errorf("Email should have index idx_email_email_and_user_id")
67335 }
68336
70338 t.Errorf("Should get to create duplicate record when having unique index")
71339 }
72340
341 var user = User{Name: "sample_user"}
342 DB.Save(&user)
343 if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil {
344 t.Errorf("Should get no error when append two emails for user")
345 }
346
347 if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil {
348 t.Errorf("Should get no duplicated email error when insert duplicated emails for a user")
349 }
350
73351 if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
74352 t.Errorf("Got error when tried to remove index: %+v", err)
75353 }
76354
77 if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
355 if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
78356 t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
79357 }
80358
83361 }
84362 }
85363
86 type BigEmail struct {
364 type EmailWithIdx struct {
87365 Id int64
88366 UserId int64
89 Email string `sql:"index:idx_email_agent"`
90 UserAgent string `sql:"index:idx_email_agent"`
91 RegisteredAt time.Time `sql:"unique_index"`
367 Email string `sql:"index:idx_email_agent"`
368 UserAgent string `sql:"index:idx_email_agent"`
369 RegisteredAt *time.Time `sql:"unique_index"`
92370 CreatedAt time.Time
93371 UpdatedAt time.Time
94372 }
95373
96 func (b BigEmail) TableName() string {
97 return "emails"
98 }
99
100374 func TestAutoMigration(t *testing.T) {
101375 DB.AutoMigrate(&Address{})
102 if err := DB.Table("emails").AutoMigrate(&BigEmail{}).Error; err != nil {
376 DB.DropTable(&EmailWithIdx{})
377 if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
103378 t.Errorf("Auto Migrate should not raise any error")
104379 }
105380
106 DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()})
107
108 scope := DB.NewScope(&BigEmail{})
109 if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_agent") {
110 t.Errorf("Failed to create index")
111 }
112
113 if !scope.Dialect().HasIndex(scope, scope.TableName(), "uix_emails_registered_at") {
114 t.Errorf("Failed to create index")
115 }
116
117 var bigemail BigEmail
381 now := time.Now()
382 DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now})
383
384 scope := DB.NewScope(&EmailWithIdx{})
385 if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
386 t.Errorf("Failed to create index")
387 }
388
389 if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") {
390 t.Errorf("Failed to create index")
391 }
392
393 var bigemail EmailWithIdx
118394 DB.First(&bigemail, "user_agent = ?", "pc")
119395 if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() {
120396 t.Error("Big Emails should be saved and fetched correctly")
121397 }
122398 }
399
400 type MultipleIndexes struct {
401 ID int64
402 UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`
403 Name string `sql:"unique_index:uix_multipleindexes_user_name"`
404 Email string `sql:"unique_index:,uix_multipleindexes_user_email"`
405 Other string `sql:"index:,idx_multipleindexes_user_other"`
406 }
407
408 func TestMultipleIndexes(t *testing.T) {
409 if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil {
410 fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err)
411 }
412
413 DB.AutoMigrate(&MultipleIndexes{})
414 if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
415 t.Errorf("Auto Migrate should not raise any error")
416 }
417
418 DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"})
419
420 scope := DB.NewScope(&MultipleIndexes{})
421 if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") {
422 t.Errorf("Failed to create index")
423 }
424
425 if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") {
426 t.Errorf("Failed to create index")
427 }
428
429 if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") {
430 t.Errorf("Failed to create index")
431 }
432
433 if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") {
434 t.Errorf("Failed to create index")
435 }
436
437 if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") {
438 t.Errorf("Failed to create index")
439 }
440
441 var mutipleIndexes MultipleIndexes
442 DB.First(&mutipleIndexes, "name = ?", "jinzhu")
443 if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" {
444 t.Error("MutipleIndexes should be saved and fetched correctly")
445 }
446
447 // Check unique constraints
448 if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil {
449 t.Error("MultipleIndexes unique index failed")
450 }
451
452 if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil {
453 t.Error("MultipleIndexes unique index failed")
454 }
455
456 if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil {
457 t.Error("MultipleIndexes unique index failed")
458 }
459
460 if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil {
461 t.Error("MultipleIndexes unique index failed")
462 }
463 }
464
465 func TestModifyColumnType(t *testing.T) {
466 if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" {
467 t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type")
468 }
469
470 type ModifyColumnType struct {
471 gorm.Model
472 Name1 string `gorm:"length:100"`
473 Name2 string `gorm:"length:200"`
474 }
475 DB.DropTable(&ModifyColumnType{})
476 DB.CreateTable(&ModifyColumnType{})
477
478 name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2")
479 name2Type := DB.Dialect().DataTypeOf(name2Field.StructField)
480
481 if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil {
482 t.Errorf("No error should happen when ModifyColumn, but got %v", err)
483 }
484 }
11
22 import "time"
33
4 // Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models
5 // type User struct {
6 // gorm.Model
7 // }
48 type Model struct {
59 ID uint `gorm:"primary_key"`
610 CreatedAt time.Time
11
22 import (
33 "database/sql"
4 "fmt"
4 "errors"
55 "go/ast"
66 "reflect"
7 "strconv"
87 "strings"
98 "sync"
109 "time"
1110
12 "github.com/qor/inflection"
11 "github.com/jinzhu/inflection"
1312 )
1413
14 // DefaultTableNameHandler default table name handler
1515 var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
1616 return defaultTableName
1717 }
3939
4040 var modelStructsMap = newModelStructsMap()
4141
42 // ModelStruct model definition
4243 type ModelStruct struct {
4344 PrimaryFields []*StructField
4445 StructFields []*StructField
4546 ModelType reflect.Type
4647 defaultTableName string
47 cached bool
48 }
49
50 func (s ModelStruct) TableName(db *DB) string {
48 }
49
50 // TableName get model's table name
51 func (s *ModelStruct) TableName(db *DB) string {
52 if s.defaultTableName == "" && db != nil && s.ModelType != nil {
53 // Set default table name
54 if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
55 s.defaultTableName = tabler.TableName()
56 } else {
57 tableName := ToDBName(s.ModelType.Name())
58 if db == nil || !db.parent.singularTable {
59 tableName = inflection.Plural(tableName)
60 }
61 s.defaultTableName = tableName
62 }
63 }
64
5165 return DefaultTableNameHandler(db, s.defaultTableName)
5266 }
5367
68 // StructField model field's struct definition
5469 type StructField struct {
5570 DBName string
5671 Name string
6176 IsScanner bool
6277 HasDefaultValue bool
6378 Tag reflect.StructTag
79 TagSettings map[string]string
6480 Struct reflect.StructField
6581 IsForeignKey bool
6682 Relationship *Relationship
6783 }
6884
6985 func (structField *StructField) clone() *StructField {
70 return &StructField{
86 clone := &StructField{
7187 DBName: structField.DBName,
7288 Name: structField.Name,
7389 Names: structField.Names,
7793 IsScanner: structField.IsScanner,
7894 HasDefaultValue: structField.HasDefaultValue,
7995 Tag: structField.Tag,
96 TagSettings: map[string]string{},
8097 Struct: structField.Struct,
8198 IsForeignKey: structField.IsForeignKey,
82 Relationship: structField.Relationship,
83 }
84 }
85
99 }
100
101 if structField.Relationship != nil {
102 relationship := *structField.Relationship
103 clone.Relationship = &relationship
104 }
105
106 for key, value := range structField.TagSettings {
107 clone.TagSettings[key] = value
108 }
109
110 return clone
111 }
112
113 // Relationship described the relationship between models
86114 type Relationship struct {
87 Kind string
88 PolymorphicType string
89 PolymorphicDBName string
90 ForeignFieldNames []string
91 ForeignDBNames []string
92 AssociationForeignFieldNames []string
93 AssociationForeignStructFieldNames []string
94 AssociationForeignDBNames []string
95 JoinTableHandler JoinTableHandlerInterface
96 }
97
115 Kind string
116 PolymorphicType string
117 PolymorphicDBName string
118 PolymorphicValue string
119 ForeignFieldNames []string
120 ForeignDBNames []string
121 AssociationForeignFieldNames []string
122 AssociationForeignDBNames []string
123 JoinTableHandler JoinTableHandlerInterface
124 }
125
126 func getForeignField(column string, fields []*StructField) *StructField {
127 for _, field := range fields {
128 if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) {
129 return field
130 }
131 }
132 return nil
133 }
134
135 // GetModelStruct get value's model struct, relationships based on struct and tag definition
98136 func (scope *Scope) GetModelStruct() *ModelStruct {
99137 var modelStruct ModelStruct
100
101 reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
102 if !reflectValue.IsValid() {
138 // Scope value can't be nil
139 if scope.Value == nil {
103140 return &modelStruct
104141 }
105142
106 if reflectValue.Kind() == reflect.Slice {
107 reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem()))
108 }
109
110 scopeType := reflectValue.Type()
111
112 if scopeType.Kind() == reflect.Ptr {
113 scopeType = scopeType.Elem()
114 }
115
116 if value := modelStructsMap.Get(scopeType); value != nil {
143 reflectType := reflect.ValueOf(scope.Value).Type()
144 for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr {
145 reflectType = reflectType.Elem()
146 }
147
148 // Scope value need to be a struct
149 if reflectType.Kind() != reflect.Struct {
150 return &modelStruct
151 }
152
153 // Get Cached model struct
154 if value := modelStructsMap.Get(reflectType); value != nil {
117155 return value
118156 }
119157
120 modelStruct.ModelType = scopeType
121 if scopeType.Kind() != reflect.Struct {
122 return &modelStruct
123 }
124
125 if tabler, ok := reflect.New(scopeType).Interface().(interface {
126 TableName() string
127 }); ok {
128 modelStruct.defaultTableName = tabler.TableName()
129 } else {
130 name := ToDBName(scopeType.Name())
131 if scope.db == nil || !scope.db.parent.singularTable {
132 name = inflection.Plural(name)
133 }
134
135 modelStruct.defaultTableName = name
136 }
158 modelStruct.ModelType = reflectType
137159
138160 // Get all fields
139 fields := []*StructField{}
140 for i := 0; i < scopeType.NumField(); i++ {
141 if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) {
161 for i := 0; i < reflectType.NumField(); i++ {
162 if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) {
142163 field := &StructField{
143 Struct: fieldStruct,
144 Name: fieldStruct.Name,
145 Names: []string{fieldStruct.Name},
146 Tag: fieldStruct.Tag,
164 Struct: fieldStruct,
165 Name: fieldStruct.Name,
166 Names: []string{fieldStruct.Name},
167 Tag: fieldStruct.Tag,
168 TagSettings: parseTagSetting(fieldStruct.Tag),
147169 }
148170
149 if fieldStruct.Tag.Get("sql") == "-" {
171 // is ignored field
172 if _, ok := field.TagSettings["-"]; ok {
150173 field.IsIgnored = true
151174 } else {
152 sqlSettings := parseTagSetting(field.Tag.Get("sql"))
153 gormSettings := parseTagSetting(field.Tag.Get("gorm"))
154 if _, ok := gormSettings["PRIMARY_KEY"]; ok {
175 if _, ok := field.TagSettings["PRIMARY_KEY"]; ok {
155176 field.IsPrimaryKey = true
156177 modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
157178 }
158179
159 if _, ok := sqlSettings["DEFAULT"]; ok {
180 if _, ok := field.TagSettings["DEFAULT"]; ok {
160181 field.HasDefaultValue = true
161182 }
162183
163 if value, ok := gormSettings["COLUMN"]; ok {
164 field.DBName = value
165 } else {
166 field.DBName = ToDBName(fieldStruct.Name)
184 if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsPrimaryKey {
185 field.HasDefaultValue = true
167186 }
168 }
169 fields = append(fields, field)
170 }
171 }
172
173 var finished = make(chan bool)
174 go func(finished chan bool) {
175 for _, field := range fields {
176 if !field.IsIgnored {
177 fieldStruct := field.Struct
187
178188 indirectType := fieldStruct.Type
179 if indirectType.Kind() == reflect.Ptr {
189 for indirectType.Kind() == reflect.Ptr {
180190 indirectType = indirectType.Elem()
181191 }
182192
183 if _, isScanner := reflect.New(indirectType).Interface().(sql.Scanner); isScanner {
193 fieldValue := reflect.New(indirectType).Interface()
194 if _, isScanner := fieldValue.(sql.Scanner); isScanner {
195 // is scanner
184196 field.IsScanner, field.IsNormal = true, true
185 }
186
187 if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime {
188 field.IsNormal = true
189 }
190
191 if !field.IsNormal {
192 gormSettings := parseTagSetting(field.Tag.Get("gorm"))
193 toScope := scope.New(reflect.New(fieldStruct.Type).Interface())
194
195 getForeignField := func(column string, fields []*StructField) *StructField {
196 for _, field := range fields {
197 if field.Name == column || field.DBName == ToDBName(column) {
198 return field
199 }
200 }
201 return nil
202 }
203
204 var relationship = &Relationship{}
205
206 if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" {
207 if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil {
208 if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil {
209 relationship.ForeignFieldNames = []string{polymorphicField.Name}
210 relationship.ForeignDBNames = []string{polymorphicField.DBName}
211 relationship.AssociationForeignFieldNames = []string{scope.PrimaryField().Name}
212 relationship.AssociationForeignDBNames = []string{scope.PrimaryField().DBName}
213 relationship.PolymorphicType = polymorphicType.Name
214 relationship.PolymorphicDBName = polymorphicType.DBName
215 polymorphicType.IsForeignKey = true
216 polymorphicField.IsForeignKey = true
197 if indirectType.Kind() == reflect.Struct {
198 for i := 0; i < indirectType.NumField(); i++ {
199 for key, value := range parseTagSetting(indirectType.Field(i).Tag) {
200 if _, ok := field.TagSettings[key]; !ok {
201 field.TagSettings[key] = value
202 }
217203 }
218204 }
219205 }
220
221 var foreignKeys []string
222 if foreignKey, ok := gormSettings["FOREIGNKEY"]; ok {
223 foreignKeys = append(foreignKeys, foreignKey)
206 } else if _, isTime := fieldValue.(*time.Time); isTime {
207 // is time
208 field.IsNormal = true
209 } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
210 // is embedded struct
211 for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields {
212 subField = subField.clone()
213 subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
214 if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok {
215 subField.DBName = prefix + subField.DBName
216 }
217
218 if subField.IsPrimaryKey {
219 if _, ok := subField.TagSettings["PRIMARY_KEY"]; ok {
220 modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField)
221 } else {
222 subField.IsPrimaryKey = false
223 }
224 }
225
226 if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil {
227 if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok {
228 newJoinTableHandler := &JoinTableHandler{}
229 newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType)
230 subField.Relationship.JoinTableHandler = newJoinTableHandler
231 }
232 }
233
234 modelStruct.StructFields = append(modelStruct.StructFields, subField)
224235 }
236 continue
237 } else {
238 // build relationships
225239 switch indirectType.Kind() {
226240 case reflect.Slice:
227 elemType := indirectType.Elem()
228 if elemType.Kind() == reflect.Ptr {
229 elemType = elemType.Elem()
230 }
231
232 if elemType.Kind() == reflect.Struct {
233 if many2many := gormSettings["MANY2MANY"]; many2many != "" {
234 relationship.Kind = "many_to_many"
235
236 // foreign keys
241 defer func(field *StructField) {
242 var (
243 relationship = &Relationship{}
244 toScope = scope.New(reflect.New(field.Struct.Type).Interface())
245 foreignKeys []string
246 associationForeignKeys []string
247 elemType = field.Struct.Type
248 )
249
250 if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
251 foreignKeys = strings.Split(foreignKey, ",")
252 }
253
254 if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
255 associationForeignKeys = strings.Split(foreignKey, ",")
256 } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
257 associationForeignKeys = strings.Split(foreignKey, ",")
258 }
259
260 for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr {
261 elemType = elemType.Elem()
262 }
263
264 if elemType.Kind() == reflect.Struct {
265 if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
266 relationship.Kind = "many_to_many"
267
268 { // Foreign Keys for Source
269 joinTableDBNames := []string{}
270
271 if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
272 joinTableDBNames = strings.Split(foreignKey, ",")
273 }
274
275 // if no foreign keys defined with tag
276 if len(foreignKeys) == 0 {
277 for _, field := range modelStruct.PrimaryFields {
278 foreignKeys = append(foreignKeys, field.DBName)
279 }
280 }
281
282 for idx, foreignKey := range foreignKeys {
283 if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
284 // source foreign keys (db names)
285 relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName)
286
287 // setup join table foreign keys for source
288 if len(joinTableDBNames) > idx {
289 // if defined join table's foreign key
290 relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
291 } else {
292 defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
293 relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
294 }
295 }
296 }
297 }
298
299 { // Foreign Keys for Association (Destination)
300 associationJoinTableDBNames := []string{}
301
302 if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
303 associationJoinTableDBNames = strings.Split(foreignKey, ",")
304 }
305
306 // if no association foreign keys defined with tag
307 if len(associationForeignKeys) == 0 {
308 for _, field := range toScope.PrimaryFields() {
309 associationForeignKeys = append(associationForeignKeys, field.DBName)
310 }
311 }
312
313 for idx, name := range associationForeignKeys {
314 if field, ok := toScope.FieldByName(name); ok {
315 // association foreign keys (db names)
316 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
317
318 // setup join table foreign keys for association
319 if len(associationJoinTableDBNames) > idx {
320 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
321 } else {
322 // join table foreign keys for association
323 joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
324 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
325 }
326 }
327 }
328 }
329
330 joinTableHandler := JoinTableHandler{}
331 joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
332 relationship.JoinTableHandler = &joinTableHandler
333 field.Relationship = relationship
334 } else {
335 // User has many comments, associationType is User, comment use UserID as foreign key
336 var associationType = reflectType.Name()
337 var toFields = toScope.GetStructFields()
338 relationship.Kind = "has_many"
339
340 if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
341 // Dog has many toys, tag polymorphic is Owner, then associationType is Owner
342 // Toy use OwnerID, OwnerType ('dogs') as foreign key
343 if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
344 associationType = polymorphic
345 relationship.PolymorphicType = polymorphicType.Name
346 relationship.PolymorphicDBName = polymorphicType.DBName
347 // if Dog has multiple set of toys set name of the set (instead of default 'dogs')
348 if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
349 relationship.PolymorphicValue = value
350 } else {
351 relationship.PolymorphicValue = scope.TableName()
352 }
353 polymorphicType.IsForeignKey = true
354 }
355 }
356
357 // if no foreign keys defined with tag
358 if len(foreignKeys) == 0 {
359 // if no association foreign keys defined with tag
360 if len(associationForeignKeys) == 0 {
361 for _, field := range modelStruct.PrimaryFields {
362 foreignKeys = append(foreignKeys, associationType+field.Name)
363 associationForeignKeys = append(associationForeignKeys, field.Name)
364 }
365 } else {
366 // generate foreign keys from defined association foreign keys
367 for _, scopeFieldName := range associationForeignKeys {
368 if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil {
369 foreignKeys = append(foreignKeys, associationType+foreignField.Name)
370 associationForeignKeys = append(associationForeignKeys, foreignField.Name)
371 }
372 }
373 }
374 } else {
375 // generate association foreign keys from foreign keys
376 if len(associationForeignKeys) == 0 {
377 for _, foreignKey := range foreignKeys {
378 if strings.HasPrefix(foreignKey, associationType) {
379 associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
380 if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
381 associationForeignKeys = append(associationForeignKeys, associationForeignKey)
382 }
383 }
384 }
385 if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
386 associationForeignKeys = []string{scope.PrimaryKey()}
387 }
388 } else if len(foreignKeys) != len(associationForeignKeys) {
389 scope.Err(errors.New("invalid foreign keys, should have same length"))
390 return
391 }
392 }
393
394 for idx, foreignKey := range foreignKeys {
395 if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
396 if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil {
397 // source foreign keys
398 foreignField.IsForeignKey = true
399 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
400 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
401
402 // association foreign keys
403 relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
404 relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
405 }
406 }
407 }
408
409 if len(relationship.ForeignFieldNames) != 0 {
410 field.Relationship = relationship
411 }
412 }
413 } else {
414 field.IsNormal = true
415 }
416 }(field)
417 case reflect.Struct:
418 defer func(field *StructField) {
419 var (
420 // user has one profile, associationType is User, profile use UserID as foreign key
421 // user belongs to profile, associationType is Profile, user use ProfileID as foreign key
422 associationType = reflectType.Name()
423 relationship = &Relationship{}
424 toScope = scope.New(reflect.New(field.Struct.Type).Interface())
425 toFields = toScope.GetStructFields()
426 tagForeignKeys []string
427 tagAssociationForeignKeys []string
428 )
429
430 if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
431 tagForeignKeys = strings.Split(foreignKey, ",")
432 }
433
434 if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
435 tagAssociationForeignKeys = strings.Split(foreignKey, ",")
436 } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
437 tagAssociationForeignKeys = strings.Split(foreignKey, ",")
438 }
439
440 if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
441 // Cat has one toy, tag polymorphic is Owner, then associationType is Owner
442 // Toy use OwnerID, OwnerType ('cats') as foreign key
443 if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
444 associationType = polymorphic
445 relationship.PolymorphicType = polymorphicType.Name
446 relationship.PolymorphicDBName = polymorphicType.DBName
447 // if Cat has several different types of toys set name for each (instead of default 'cats')
448 if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
449 relationship.PolymorphicValue = value
450 } else {
451 relationship.PolymorphicValue = scope.TableName()
452 }
453 polymorphicType.IsForeignKey = true
454 }
455 }
456
457 // Has One
458 {
459 var foreignKeys = tagForeignKeys
460 var associationForeignKeys = tagAssociationForeignKeys
461 // if no foreign keys defined with tag
237462 if len(foreignKeys) == 0 {
238 for _, field := range scope.PrimaryFields() {
239 foreignKeys = append(foreignKeys, field.DBName)
240 }
241 }
242
243 for _, foreignKey := range foreignKeys {
244 if field, ok := scope.FieldByName(foreignKey); ok {
245 relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName)
246 joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName
247 relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
248 }
249 }
250
251 // association foreign keys
252 var associationForeignKeys []string
253 if foreignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
254 associationForeignKeys = []string{gormSettings["ASSOCIATIONFOREIGNKEY"]}
463 // if no association foreign keys defined with tag
464 if len(associationForeignKeys) == 0 {
465 for _, primaryField := range modelStruct.PrimaryFields {
466 foreignKeys = append(foreignKeys, associationType+primaryField.Name)
467 associationForeignKeys = append(associationForeignKeys, primaryField.Name)
468 }
469 } else {
470 // generate foreign keys form association foreign keys
471 for _, associationForeignKey := range tagAssociationForeignKeys {
472 if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
473 foreignKeys = append(foreignKeys, associationType+foreignField.Name)
474 associationForeignKeys = append(associationForeignKeys, foreignField.Name)
475 }
476 }
477 }
255478 } else {
256 for _, field := range toScope.PrimaryFields() {
257 associationForeignKeys = append(associationForeignKeys, field.DBName)
258 }
259 }
260
261 for _, name := range associationForeignKeys {
262 if field, ok := toScope.FieldByName(name); ok {
263 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
264 relationship.AssociationForeignStructFieldNames = append(relationship.AssociationForeignFieldNames, field.Name)
265 joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
266 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
267 }
268 }
269
270 joinTableHandler := JoinTableHandler{}
271 joinTableHandler.Setup(relationship, many2many, scopeType, elemType)
272 relationship.JoinTableHandler = &joinTableHandler
273 field.Relationship = relationship
274 } else {
275 relationship.Kind = "has_many"
276
277 if len(foreignKeys) == 0 {
278 for _, field := range scope.PrimaryFields() {
279 if foreignField := getForeignField(scopeType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil {
280 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name)
281 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName)
479 // generate association foreign keys from foreign keys
480 if len(associationForeignKeys) == 0 {
481 for _, foreignKey := range foreignKeys {
482 if strings.HasPrefix(foreignKey, associationType) {
483 associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
484 if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
485 associationForeignKeys = append(associationForeignKeys, associationForeignKey)
486 }
487 }
488 }
489 if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
490 associationForeignKeys = []string{scope.PrimaryKey()}
491 }
492 } else if len(foreignKeys) != len(associationForeignKeys) {
493 scope.Err(errors.New("invalid foreign keys, should have same length"))
494 return
495 }
496 }
497
498 for idx, foreignKey := range foreignKeys {
499 if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
500 if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil {
501 foreignField.IsForeignKey = true
502 // source foreign keys
503 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name)
504 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName)
505
506 // association foreign keys
282507 relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
283508 relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
284 foreignField.IsForeignKey = true
285 }
286 }
287 } else {
288 for _, foreignKey := range foreignKeys {
289 if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
290 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name)
291 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName)
292 relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
293 relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
294 foreignField.IsForeignKey = true
295 }
296 }
297 }
298
299 if len(relationship.ForeignFieldNames) != 0 {
300 field.Relationship = relationship
301 }
302 }
303 } else {
304 field.IsNormal = true
305 }
306 case reflect.Struct:
307 if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
308 for _, toField := range toScope.GetStructFields() {
309 toField = toField.clone()
310 toField.Names = append([]string{fieldStruct.Name}, toField.Names...)
311 modelStruct.StructFields = append(modelStruct.StructFields, toField)
312 if toField.IsPrimaryKey {
313 modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField)
314 }
315 }
316 continue
317 } else {
318 if len(foreignKeys) == 0 {
319 for _, f := range scope.PrimaryFields() {
320 if foreignField := getForeignField(modelStruct.ModelType.Name()+f.Name, toScope.GetStructFields()); foreignField != nil {
321 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name)
322 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName)
323 relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
324 relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
325 foreignField.IsForeignKey = true
326 }
327 }
328 } else {
329 for _, foreignKey := range foreignKeys {
330 if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
331 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name)
332 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName)
333 relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
334 relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
335 foreignField.IsForeignKey = true
509 }
336510 }
337511 }
338512 }
341515 relationship.Kind = "has_one"
342516 field.Relationship = relationship
343517 } else {
518 var foreignKeys = tagForeignKeys
519 var associationForeignKeys = tagAssociationForeignKeys
520
344521 if len(foreignKeys) == 0 {
345 for _, f := range toScope.PrimaryFields() {
346 if foreignField := getForeignField(field.Name+f.Name, fields); foreignField != nil {
347 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name)
348 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName)
522 // generate foreign keys & association foreign keys
523 if len(associationForeignKeys) == 0 {
524 for _, primaryField := range toScope.PrimaryFields() {
525 foreignKeys = append(foreignKeys, field.Name+primaryField.Name)
526 associationForeignKeys = append(associationForeignKeys, primaryField.Name)
527 }
528 } else {
529 // generate foreign keys with association foreign keys
530 for _, associationForeignKey := range associationForeignKeys {
531 if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil {
532 foreignKeys = append(foreignKeys, field.Name+foreignField.Name)
533 associationForeignKeys = append(associationForeignKeys, foreignField.Name)
534 }
535 }
536 }
537 } else {
538 // generate foreign keys & association foreign keys
539 if len(associationForeignKeys) == 0 {
540 for _, foreignKey := range foreignKeys {
541 if strings.HasPrefix(foreignKey, field.Name) {
542 associationForeignKey := strings.TrimPrefix(foreignKey, field.Name)
543 if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil {
544 associationForeignKeys = append(associationForeignKeys, associationForeignKey)
545 }
546 }
547 }
548 if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
549 associationForeignKeys = []string{toScope.PrimaryKey()}
550 }
551 } else if len(foreignKeys) != len(associationForeignKeys) {
552 scope.Err(errors.New("invalid foreign keys, should have same length"))
553 return
554 }
555 }
556
557 for idx, foreignKey := range foreignKeys {
558 if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
559 if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil {
560 foreignField.IsForeignKey = true
561
562 // association foreign keys
563 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
564 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
565
566 // source foreign keys
349567 relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
350568 relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
351 foreignField.IsForeignKey = true
352 }
353 }
354 } else {
355 for _, foreignKey := range foreignKeys {
356 if foreignField := getForeignField(foreignKey, fields); foreignField != nil {
357 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name)
358 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName)
359 relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
360 relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
361 foreignField.IsForeignKey = true
362569 }
363570 }
364571 }
368575 field.Relationship = relationship
369576 }
370577 }
371 }
578 }(field)
372579 default:
373580 field.IsNormal = true
374581 }
375582 }
376
377 if field.IsNormal {
378 if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" {
379 field.IsPrimaryKey = true
380 modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
381 }
382 }
383583 }
584
585 // Even it is ignored, also possible to decode db value into the field
586 if value, ok := field.TagSettings["COLUMN"]; ok {
587 field.DBName = value
588 } else {
589 field.DBName = ToDBName(fieldStruct.Name)
590 }
591
384592 modelStruct.StructFields = append(modelStruct.StructFields, field)
385593 }
386 finished <- true
387 }(finished)
388
389 modelStructsMap.Set(scopeType, &modelStruct)
390
391 <-finished
392 modelStruct.cached = true
594 }
595
596 if len(modelStruct.PrimaryFields) == 0 {
597 if field := getForeignField("id", modelStruct.StructFields); field != nil {
598 field.IsPrimaryKey = true
599 modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
600 }
601 }
602
603 modelStructsMap.Set(reflectType, &modelStruct)
393604
394605 return &modelStruct
395606 }
396607
608 // GetStructFields get model's field structs
397609 func (scope *Scope) GetStructFields() (fields []*StructField) {
398610 return scope.GetModelStruct().StructFields
399611 }
400612
401 func (scope *Scope) generateSqlTag(field *StructField) string {
402 var sqlType string
403 structType := field.Struct.Type
404 if structType.Kind() == reflect.Ptr {
405 structType = structType.Elem()
406 }
407 reflectValue := reflect.Indirect(reflect.New(structType))
408 sqlSettings := parseTagSetting(field.Tag.Get("sql"))
409
410 if value, ok := sqlSettings["TYPE"]; ok {
411 sqlType = value
412 }
413
414 additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"]
415 if value, ok := sqlSettings["DEFAULT"]; ok {
416 additionalType = additionalType + " DEFAULT " + value
417 }
418
419 if field.IsScanner {
420 var getScannerValue func(reflect.Value)
421 getScannerValue = func(value reflect.Value) {
422 reflectValue = value
423 if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
424 getScannerValue(reflectValue.Field(0))
613 func parseTagSetting(tags reflect.StructTag) map[string]string {
614 setting := map[string]string{}
615 for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
616 tags := strings.Split(str, ";")
617 for _, value := range tags {
618 v := strings.Split(value, ":")
619 k := strings.TrimSpace(strings.ToUpper(v[0]))
620 if len(v) >= 2 {
621 setting[k] = strings.Join(v[1:], ":")
622 } else {
623 setting[k] = k
425624 }
426625 }
427 getScannerValue(reflectValue)
428 }
429
430 if sqlType == "" {
431 var size = 255
432
433 if value, ok := sqlSettings["SIZE"]; ok {
434 size, _ = strconv.Atoi(value)
435 }
436
437 _, autoIncrease := sqlSettings["AUTO_INCREMENT"]
438 if field.IsPrimaryKey {
439 autoIncrease = true
440 }
441
442 sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
443 }
444
445 if strings.TrimSpace(additionalType) == "" {
446 return sqlType
447 } else {
448 return fmt.Sprintf("%v %v", sqlType, additionalType)
449 }
450 }
451
452 func parseTagSetting(str string) map[string]string {
453 tags := strings.Split(str, ";")
454 setting := map[string]string{}
455 for _, value := range tags {
456 v := strings.Split(value, ":")
457 k := strings.TrimSpace(strings.ToUpper(v[0]))
458 if len(v) >= 2 {
459 setting[k] = strings.Join(v[1:], ":")
460 } else {
461 setting[k] = k
462 }
463626 }
464627 return setting
465628 }
+0
-80
mssql.go less more
0 package gorm
1
2 import (
3 "fmt"
4 "reflect"
5 "time"
6 )
7
8 type mssql struct {
9 commonDialect
10 }
11
12 func (mssql) HasTop() bool {
13 return true
14 }
15
16 func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
17 switch value.Kind() {
18 case reflect.Bool:
19 return "bit"
20 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
21 if autoIncrease {
22 return "int IDENTITY(1,1)"
23 }
24 return "int"
25 case reflect.Int64, reflect.Uint64:
26 if autoIncrease {
27 return "bigint IDENTITY(1,1)"
28 }
29 return "bigint"
30 case reflect.Float32, reflect.Float64:
31 return "float"
32 case reflect.String:
33 if size > 0 && size < 65532 {
34 return fmt.Sprintf("nvarchar(%d)", size)
35 }
36 return "text"
37 case reflect.Struct:
38 if _, ok := value.Interface().(time.Time); ok {
39 return "datetime2"
40 }
41 default:
42 if _, ok := value.Interface().([]byte); ok {
43 if size > 0 && size < 65532 {
44 return fmt.Sprintf("varchar(%d)", size)
45 }
46 return "text"
47 }
48 }
49 panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
50 }
51
52 func (s mssql) HasTable(scope *Scope, tableName string) bool {
53 var (
54 count int
55 databaseName = s.CurrentDatabase(scope)
56 )
57 s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName)
58 return count > 0
59 }
60
61 func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
62 var (
63 count int
64 databaseName = s.CurrentDatabase(scope)
65 )
66 s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
67 return count > 0
68 }
69
70 func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
71 var count int
72 s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName)
73 return count > 0
74 }
75
76 func (s mssql) CurrentDatabase(scope *Scope) (name string) {
77 s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]")
78 return
79 }
00 package gorm_test
11
22 import (
3 "fmt"
43 "os"
4 "reflect"
5 "sort"
56 "testing"
67 )
78
89 type Blog struct {
9 ID uint `gorm:"primary_key"`
10 Locale string `gorm:"primary_key"`
11 Subject string
12 Body string
13 Tags []Tag `gorm:"many2many:blog_tags;"`
10 ID uint `gorm:"primary_key"`
11 Locale string `gorm:"primary_key"`
12 Subject string
13 Body string
14 Tags []Tag `gorm:"many2many:blog_tags;"`
15 SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;AssociationForeignKey:id"`
16 LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;AssociationForeignKey:id"`
1417 }
1518
1619 type Tag struct {
1720 ID uint `gorm:"primary_key"`
1821 Locale string `gorm:"primary_key"`
1922 Value string
23 Blogs []*Blog `gorm:"many2many:blogs_tags"`
24 }
25
26 func compareTags(tags []Tag, contents []string) bool {
27 var tagContents []string
28 for _, tag := range tags {
29 tagContents = append(tagContents, tag.Value)
30 }
31 sort.Strings(tagContents)
32 sort.Strings(contents)
33 return reflect.DeepEqual(tagContents, contents)
2034 }
2135
2236 func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
23 if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" {
24 DB.Exec(fmt.Sprintf("drop table blog_tags;"))
25 DB.AutoMigrate(&Blog{}, &Tag{})
37 if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" {
38 DB.DropTable(&Blog{}, &Tag{})
39 DB.DropTable("blog_tags")
40 DB.CreateTable(&Blog{}, &Tag{})
2641 blog := Blog{
2742 Locale: "ZH",
2843 Subject: "subject",
3449 }
3550
3651 DB.Save(&blog)
37 DB.Model(&blog).Association("Tags").Append([]Tag{{Locale: "ZH", Value: "tag3"}})
52 if !compareTags(blog.Tags, []string{"tag1", "tag2"}) {
53 t.Errorf("Blog should has two tags")
54 }
55
56 // Append
57 var tag3 = &Tag{Locale: "ZH", Value: "tag3"}
58 DB.Model(&blog).Association("Tags").Append([]*Tag{tag3})
59 if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) {
60 t.Errorf("Blog should has three tags after Append")
61 }
62
63 if DB.Model(&blog).Association("Tags").Count() != 3 {
64 t.Errorf("Blog should has three tags after Append")
65 }
3866
3967 var tags []Tag
4068 DB.Model(&blog).Related(&tags, "Tags")
41 if len(tags) != 3 {
42 t.Errorf("should found 3 tags with blog")
69 if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) {
70 t.Errorf("Should find 3 tags with Related")
71 }
72
73 var blog1 Blog
74 DB.Preload("Tags").Find(&blog1)
75 if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) {
76 t.Errorf("Preload many2many relations")
77 }
78
79 // Replace
80 var tag5 = &Tag{Locale: "ZH", Value: "tag5"}
81 var tag6 = &Tag{Locale: "ZH", Value: "tag6"}
82 DB.Model(&blog).Association("Tags").Replace(tag5, tag6)
83 var tags2 []Tag
84 DB.Model(&blog).Related(&tags2, "Tags")
85 if !compareTags(tags2, []string{"tag5", "tag6"}) {
86 t.Errorf("Should find 2 tags after Replace")
87 }
88
89 if DB.Model(&blog).Association("Tags").Count() != 2 {
90 t.Errorf("Blog should has three tags after Replace")
91 }
92
93 // Delete
94 DB.Model(&blog).Association("Tags").Delete(tag5)
95 var tags3 []Tag
96 DB.Model(&blog).Related(&tags3, "Tags")
97 if !compareTags(tags3, []string{"tag6"}) {
98 t.Errorf("Should find 1 tags after Delete")
99 }
100
101 if DB.Model(&blog).Association("Tags").Count() != 1 {
102 t.Errorf("Blog should has three tags after Delete")
103 }
104
105 DB.Model(&blog).Association("Tags").Delete(tag3)
106 var tags4 []Tag
107 DB.Model(&blog).Related(&tags4, "Tags")
108 if !compareTags(tags4, []string{"tag6"}) {
109 t.Errorf("Tag should not be deleted when Delete with a unrelated tag")
110 }
111
112 // Clear
113 DB.Model(&blog).Association("Tags").Clear()
114 if DB.Model(&blog).Association("Tags").Count() != 0 {
115 t.Errorf("All tags should be cleared")
43116 }
44117 }
45118 }
119
120 func TestManyToManyWithCustomizedForeignKeys(t *testing.T) {
121 if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" {
122 DB.DropTable(&Blog{}, &Tag{})
123 DB.DropTable("shared_blog_tags")
124 DB.CreateTable(&Blog{}, &Tag{})
125 blog := Blog{
126 Locale: "ZH",
127 Subject: "subject",
128 Body: "body",
129 SharedTags: []Tag{
130 {Locale: "ZH", Value: "tag1"},
131 {Locale: "ZH", Value: "tag2"},
132 },
133 }
134 DB.Save(&blog)
135
136 blog2 := Blog{
137 ID: blog.ID,
138 Locale: "EN",
139 }
140 DB.Create(&blog2)
141
142 if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) {
143 t.Errorf("Blog should has two tags")
144 }
145
146 // Append
147 var tag3 = &Tag{Locale: "ZH", Value: "tag3"}
148 DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3})
149 if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) {
150 t.Errorf("Blog should has three tags after Append")
151 }
152
153 if DB.Model(&blog).Association("SharedTags").Count() != 3 {
154 t.Errorf("Blog should has three tags after Append")
155 }
156
157 if DB.Model(&blog2).Association("SharedTags").Count() != 3 {
158 t.Errorf("Blog should has three tags after Append")
159 }
160
161 var tags []Tag
162 DB.Model(&blog).Related(&tags, "SharedTags")
163 if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) {
164 t.Errorf("Should find 3 tags with Related")
165 }
166
167 DB.Model(&blog2).Related(&tags, "SharedTags")
168 if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) {
169 t.Errorf("Should find 3 tags with Related")
170 }
171
172 var blog1 Blog
173 DB.Preload("SharedTags").Find(&blog1)
174 if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) {
175 t.Errorf("Preload many2many relations")
176 }
177
178 var tag4 = &Tag{Locale: "ZH", Value: "tag4"}
179 DB.Model(&blog2).Association("SharedTags").Append(tag4)
180
181 DB.Model(&blog).Related(&tags, "SharedTags")
182 if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) {
183 t.Errorf("Should find 3 tags with Related")
184 }
185
186 DB.Model(&blog2).Related(&tags, "SharedTags")
187 if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) {
188 t.Errorf("Should find 3 tags with Related")
189 }
190
191 // Replace
192 var tag5 = &Tag{Locale: "ZH", Value: "tag5"}
193 var tag6 = &Tag{Locale: "ZH", Value: "tag6"}
194 DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6)
195 var tags2 []Tag
196 DB.Model(&blog).Related(&tags2, "SharedTags")
197 if !compareTags(tags2, []string{"tag5", "tag6"}) {
198 t.Errorf("Should find 2 tags after Replace")
199 }
200
201 DB.Model(&blog2).Related(&tags2, "SharedTags")
202 if !compareTags(tags2, []string{"tag5", "tag6"}) {
203 t.Errorf("Should find 2 tags after Replace")
204 }
205
206 if DB.Model(&blog).Association("SharedTags").Count() != 2 {
207 t.Errorf("Blog should has three tags after Replace")
208 }
209
210 // Delete
211 DB.Model(&blog).Association("SharedTags").Delete(tag5)
212 var tags3 []Tag
213 DB.Model(&blog).Related(&tags3, "SharedTags")
214 if !compareTags(tags3, []string{"tag6"}) {
215 t.Errorf("Should find 1 tags after Delete")
216 }
217
218 if DB.Model(&blog).Association("SharedTags").Count() != 1 {
219 t.Errorf("Blog should has three tags after Delete")
220 }
221
222 DB.Model(&blog2).Association("SharedTags").Delete(tag3)
223 var tags4 []Tag
224 DB.Model(&blog).Related(&tags4, "SharedTags")
225 if !compareTags(tags4, []string{"tag6"}) {
226 t.Errorf("Tag should not be deleted when Delete with a unrelated tag")
227 }
228
229 // Clear
230 DB.Model(&blog2).Association("SharedTags").Clear()
231 if DB.Model(&blog).Association("SharedTags").Count() != 0 {
232 t.Errorf("All tags should be cleared")
233 }
234 }
235 }
236
237 func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
238 if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" {
239 DB.DropTable(&Blog{}, &Tag{})
240 DB.DropTable("locale_blog_tags")
241 DB.CreateTable(&Blog{}, &Tag{})
242 blog := Blog{
243 Locale: "ZH",
244 Subject: "subject",
245 Body: "body",
246 LocaleTags: []Tag{
247 {Locale: "ZH", Value: "tag1"},
248 {Locale: "ZH", Value: "tag2"},
249 },
250 }
251 DB.Save(&blog)
252
253 blog2 := Blog{
254 ID: blog.ID,
255 Locale: "EN",
256 }
257 DB.Create(&blog2)
258
259 // Append
260 var tag3 = &Tag{Locale: "ZH", Value: "tag3"}
261 DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3})
262 if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) {
263 t.Errorf("Blog should has three tags after Append")
264 }
265
266 if DB.Model(&blog).Association("LocaleTags").Count() != 3 {
267 t.Errorf("Blog should has three tags after Append")
268 }
269
270 if DB.Model(&blog2).Association("LocaleTags").Count() != 0 {
271 t.Errorf("EN Blog should has 0 tags after ZH Blog Append")
272 }
273
274 var tags []Tag
275 DB.Model(&blog).Related(&tags, "LocaleTags")
276 if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) {
277 t.Errorf("Should find 3 tags with Related")
278 }
279
280 DB.Model(&blog2).Related(&tags, "LocaleTags")
281 if len(tags) != 0 {
282 t.Errorf("Should find 0 tags with Related for EN Blog")
283 }
284
285 var blog1 Blog
286 DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID)
287 if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) {
288 t.Errorf("Preload many2many relations")
289 }
290
291 var tag4 = &Tag{Locale: "ZH", Value: "tag4"}
292 DB.Model(&blog2).Association("LocaleTags").Append(tag4)
293
294 DB.Model(&blog).Related(&tags, "LocaleTags")
295 if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) {
296 t.Errorf("Should find 3 tags with Related for EN Blog")
297 }
298
299 DB.Model(&blog2).Related(&tags, "LocaleTags")
300 if !compareTags(tags, []string{"tag4"}) {
301 t.Errorf("Should find 1 tags with Related for EN Blog")
302 }
303
304 // Replace
305 var tag5 = &Tag{Locale: "ZH", Value: "tag5"}
306 var tag6 = &Tag{Locale: "ZH", Value: "tag6"}
307 DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6)
308
309 var tags2 []Tag
310 DB.Model(&blog).Related(&tags2, "LocaleTags")
311 if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) {
312 t.Errorf("CN Blog's tags should not be changed after EN Blog Replace")
313 }
314
315 var blog11 Blog
316 DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale)
317 if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) {
318 t.Errorf("CN Blog's tags should not be changed after EN Blog Replace")
319 }
320
321 DB.Model(&blog2).Related(&tags2, "LocaleTags")
322 if !compareTags(tags2, []string{"tag5", "tag6"}) {
323 t.Errorf("Should find 2 tags after Replace")
324 }
325
326 var blog21 Blog
327 DB.Preload("LocaleTags").First(&blog21, "id = ? AND locale = ?", blog2.ID, blog2.Locale)
328 if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) {
329 t.Errorf("EN Blog's tags should be changed after Replace")
330 }
331
332 if DB.Model(&blog).Association("LocaleTags").Count() != 3 {
333 t.Errorf("ZH Blog should has three tags after Replace")
334 }
335
336 if DB.Model(&blog2).Association("LocaleTags").Count() != 2 {
337 t.Errorf("EN Blog should has two tags after Replace")
338 }
339
340 // Delete
341 DB.Model(&blog).Association("LocaleTags").Delete(tag5)
342
343 if DB.Model(&blog).Association("LocaleTags").Count() != 3 {
344 t.Errorf("ZH Blog should has three tags after Delete with EN's tag")
345 }
346
347 if DB.Model(&blog2).Association("LocaleTags").Count() != 2 {
348 t.Errorf("EN Blog should has two tags after ZH Blog Delete with EN's tag")
349 }
350
351 DB.Model(&blog2).Association("LocaleTags").Delete(tag5)
352
353 if DB.Model(&blog).Association("LocaleTags").Count() != 3 {
354 t.Errorf("ZH Blog should has three tags after EN Blog Delete with EN's tag")
355 }
356
357 if DB.Model(&blog2).Association("LocaleTags").Count() != 1 {
358 t.Errorf("EN Blog should has 1 tags after EN Blog Delete with EN's tag")
359 }
360
361 // Clear
362 DB.Model(&blog2).Association("LocaleTags").Clear()
363 if DB.Model(&blog).Association("LocaleTags").Count() != 3 {
364 t.Errorf("ZH Blog's tags should not be cleared when clear EN Blog's tags")
365 }
366
367 if DB.Model(&blog2).Association("LocaleTags").Count() != 0 {
368 t.Errorf("EN Blog's tags should be cleared when clear EN Blog's tags")
369 }
370
371 DB.Model(&blog).Association("LocaleTags").Clear()
372 if DB.Model(&blog).Association("LocaleTags").Count() != 0 {
373 t.Errorf("ZH Blog's tags should be cleared when clear ZH Blog's tags")
374 }
375
376 if DB.Model(&blog2).Association("LocaleTags").Count() != 0 {
377 t.Errorf("EN Blog's tags should be cleared")
378 }
379 }
380 }
+0
-70
mysql.go less more
0 package gorm
1
2 import (
3 "fmt"
4 "reflect"
5 "time"
6 )
7
8 type mysql struct {
9 commonDialect
10 }
11
12 func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
13 switch value.Kind() {
14 case reflect.Bool:
15 return "boolean"
16 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
17 if autoIncrease {
18 return "int AUTO_INCREMENT"
19 }
20 return "int"
21 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
22 if autoIncrease {
23 return "int unsigned AUTO_INCREMENT"
24 }
25 return "int unsigned"
26 case reflect.Int64:
27 if autoIncrease {
28 return "bigint AUTO_INCREMENT"
29 }
30 return "bigint"
31 case reflect.Uint64:
32 if autoIncrease {
33 return "bigint unsigned AUTO_INCREMENT"
34 }
35 return "bigint unsigned"
36 case reflect.Float32, reflect.Float64:
37 return "double"
38 case reflect.String:
39 if size > 0 && size < 65532 {
40 return fmt.Sprintf("varchar(%d)", size)
41 }
42 return "longtext"
43 case reflect.Struct:
44 if _, ok := value.Interface().(time.Time); ok {
45 return "timestamp NULL"
46 }
47 default:
48 if _, ok := value.Interface().([]byte); ok {
49 if size > 0 && size < 65532 {
50 return fmt.Sprintf("varbinary(%d)", size)
51 }
52 return "longblob"
53 }
54 }
55 panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
56 }
57
58 func (mysql) Quote(key string) string {
59 return fmt.Sprintf("`%s`", key)
60 }
61
62 func (mysql) SelectFromDummyTable() string {
63 return "FROM DUAL"
64 }
65
66 func (s mysql) CurrentDatabase(scope *Scope) (name string) {
67 s.RawScanString(scope, &name, "SELECT DATABASE()")
68 return
69 }
3838
3939 var nilPointerStruct = PointerStruct{}
4040 if err := DB.Create(&nilPointerStruct).Error; err != nil {
41 t.Errorf("Failed to save nil pointer struct", err)
41 t.Error("Failed to save nil pointer struct", err)
4242 }
4343
4444 var pointerStruct2 PointerStruct
4545 if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
46 t.Errorf("Failed to query saved nil pointer struct", err)
46 t.Error("Failed to query saved nil pointer struct", err)
4747 }
4848
4949 var normalStruct2 NormalStruct
5050 if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
51 t.Errorf("Failed to query saved nil pointer struct", err)
51 t.Error("Failed to query saved nil pointer struct", err)
5252 }
5353
5454 var partialNilPointerStruct1 = PointerStruct{Num: &num}
5555 if err := DB.Create(&partialNilPointerStruct1).Error; err != nil {
56 t.Errorf("Failed to save partial nil pointer struct", err)
56 t.Error("Failed to save partial nil pointer struct", err)
5757 }
5858
5959 var pointerStruct3 PointerStruct
6060 if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num {
61 t.Errorf("Failed to query saved partial nil pointer struct", err)
61 t.Error("Failed to query saved partial nil pointer struct", err)
6262 }
6363
6464 var normalStruct3 NormalStruct
6565 if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num {
66 t.Errorf("Failed to query saved partial pointer struct", err)
66 t.Error("Failed to query saved partial pointer struct", err)
6767 }
6868
6969 var partialNilPointerStruct2 = PointerStruct{Name: &name}
7070 if err := DB.Create(&partialNilPointerStruct2).Error; err != nil {
71 t.Errorf("Failed to save partial nil pointer struct", err)
71 t.Error("Failed to save partial nil pointer struct", err)
7272 }
7373
7474 var pointerStruct4 PointerStruct
7575 if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name {
76 t.Errorf("Failed to query saved partial nil pointer struct", err)
76 t.Error("Failed to query saved partial nil pointer struct", err)
7777 }
7878
7979 var normalStruct4 NormalStruct
8080 if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name {
81 t.Errorf("Failed to query saved partial pointer struct", err)
81 t.Error("Failed to query saved partial pointer struct", err)
8282 }
8383 }
00 package gorm_test
11
2 import "testing"
2 import (
3 "reflect"
4 "sort"
5 "testing"
6 )
37
48 type Cat struct {
59 Id int
1115 Id int
1216 Name string
1317 Toys []Toy `gorm:"polymorphic:Owner;"`
18 }
19
20 type Hamster struct {
21 Id int
22 Name string
23 PreferredToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_preferred"`
24 OtherToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_other"`
1425 }
1526
1627 type Toy struct {
2031 OwnerType string
2132 }
2233
34 var compareToys = func(toys []Toy, contents []string) bool {
35 var toyContents []string
36 for _, toy := range toys {
37 toyContents = append(toyContents, toy.Name)
38 }
39 sort.Strings(toyContents)
40 sort.Strings(contents)
41 return reflect.DeepEqual(toyContents, contents)
42 }
43
2344 func TestPolymorphic(t *testing.T) {
24 DB.AutoMigrate(&Cat{})
25 DB.AutoMigrate(&Dog{})
26 DB.AutoMigrate(&Toy{})
27
28 cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat nip"}}
29 dog := Dog{Name: "Pluto", Toys: []Toy{Toy{Name: "orange ball"}, Toy{Name: "yellow ball"}}}
45 cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat toy"}}
46 dog := Dog{Name: "Pluto", Toys: []Toy{{Name: "dog toy 1"}, {Name: "dog toy 2"}}}
3047 DB.Save(&cat).Save(&dog)
3148
49 if DB.Model(&cat).Association("Toy").Count() != 1 {
50 t.Errorf("Cat's toys count should be 1")
51 }
52
53 if DB.Model(&dog).Association("Toys").Count() != 2 {
54 t.Errorf("Dog's toys count should be 2")
55 }
56
57 // Query
3258 var catToys []Toy
3359 if DB.Model(&cat).Related(&catToys, "Toy").RecordNotFound() {
3460 t.Errorf("Did not find any has one polymorphic association")
4571 t.Errorf("Should have found all polymorphic has many associations")
4672 }
4773
48 if DB.Model(&cat).Association("Toy").Count() != 1 {
49 t.Errorf("Should return one polymorphic has one association")
74 var catToy Toy
75 DB.Model(&cat).Association("Toy").Find(&catToy)
76 if catToy.Name != cat.Toy.Name {
77 t.Errorf("Should find has one polymorphic association")
78 }
79
80 var dogToys1 []Toy
81 DB.Model(&dog).Association("Toys").Find(&dogToys1)
82 if !compareToys(dogToys1, []string{"dog toy 1", "dog toy 2"}) {
83 t.Errorf("Should find has many polymorphic association")
84 }
85
86 // Append
87 DB.Model(&cat).Association("Toy").Append(&Toy{
88 Name: "cat toy 2",
89 })
90
91 var catToy2 Toy
92 DB.Model(&cat).Association("Toy").Find(&catToy2)
93 if catToy2.Name != "cat toy 2" {
94 t.Errorf("Should update has one polymorphic association with Append")
95 }
96
97 if DB.Model(&cat).Association("Toy").Count() != 1 {
98 t.Errorf("Cat's toys count should be 1 after Append")
5099 }
51100
52101 if DB.Model(&dog).Association("Toys").Count() != 2 {
53102 t.Errorf("Should return two polymorphic has many associations")
54103 }
55 }
104
105 DB.Model(&dog).Association("Toys").Append(&Toy{
106 Name: "dog toy 3",
107 })
108
109 var dogToys2 []Toy
110 DB.Model(&dog).Association("Toys").Find(&dogToys2)
111 if !compareToys(dogToys2, []string{"dog toy 1", "dog toy 2", "dog toy 3"}) {
112 t.Errorf("Dog's toys should be updated with Append")
113 }
114
115 if DB.Model(&dog).Association("Toys").Count() != 3 {
116 t.Errorf("Should return three polymorphic has many associations")
117 }
118
119 // Replace
120 DB.Model(&cat).Association("Toy").Replace(&Toy{
121 Name: "cat toy 3",
122 })
123
124 var catToy3 Toy
125 DB.Model(&cat).Association("Toy").Find(&catToy3)
126 if catToy3.Name != "cat toy 3" {
127 t.Errorf("Should update has one polymorphic association with Replace")
128 }
129
130 if DB.Model(&cat).Association("Toy").Count() != 1 {
131 t.Errorf("Cat's toys count should be 1 after Replace")
132 }
133
134 if DB.Model(&dog).Association("Toys").Count() != 3 {
135 t.Errorf("Should return three polymorphic has many associations")
136 }
137
138 DB.Model(&dog).Association("Toys").Replace(&Toy{
139 Name: "dog toy 4",
140 }, []Toy{
141 {Name: "dog toy 5"}, {Name: "dog toy 6"}, {Name: "dog toy 7"},
142 })
143
144 var dogToys3 []Toy
145 DB.Model(&dog).Association("Toys").Find(&dogToys3)
146 if !compareToys(dogToys3, []string{"dog toy 4", "dog toy 5", "dog toy 6", "dog toy 7"}) {
147 t.Errorf("Dog's toys should be updated with Replace")
148 }
149
150 if DB.Model(&dog).Association("Toys").Count() != 4 {
151 t.Errorf("Should return three polymorphic has many associations")
152 }
153
154 // Delete
155 DB.Model(&cat).Association("Toy").Delete(&catToy2)
156
157 var catToy4 Toy
158 DB.Model(&cat).Association("Toy").Find(&catToy4)
159 if catToy4.Name != "cat toy 3" {
160 t.Errorf("Should not update has one polymorphic association when Delete a unrelated Toy")
161 }
162
163 if DB.Model(&cat).Association("Toy").Count() != 1 {
164 t.Errorf("Cat's toys count should be 1")
165 }
166
167 if DB.Model(&dog).Association("Toys").Count() != 4 {
168 t.Errorf("Dog's toys count should be 4")
169 }
170
171 DB.Model(&cat).Association("Toy").Delete(&catToy3)
172
173 if !DB.Model(&cat).Related(&Toy{}, "Toy").RecordNotFound() {
174 t.Errorf("Toy should be deleted with Delete")
175 }
176
177 if DB.Model(&cat).Association("Toy").Count() != 0 {
178 t.Errorf("Cat's toys count should be 0 after Delete")
179 }
180
181 if DB.Model(&dog).Association("Toys").Count() != 4 {
182 t.Errorf("Dog's toys count should not be changed when delete cat's toy")
183 }
184
185 DB.Model(&dog).Association("Toys").Delete(&dogToys2)
186
187 if DB.Model(&dog).Association("Toys").Count() != 4 {
188 t.Errorf("Dog's toys count should not be changed when delete unrelated toys")
189 }
190
191 DB.Model(&dog).Association("Toys").Delete(&dogToys3)
192
193 if DB.Model(&dog).Association("Toys").Count() != 0 {
194 t.Errorf("Dog's toys count should be deleted with Delete")
195 }
196
197 // Clear
198 DB.Model(&cat).Association("Toy").Append(&Toy{
199 Name: "cat toy 2",
200 })
201
202 if DB.Model(&cat).Association("Toy").Count() != 1 {
203 t.Errorf("Cat's toys should be added with Append")
204 }
205
206 DB.Model(&cat).Association("Toy").Clear()
207
208 if DB.Model(&cat).Association("Toy").Count() != 0 {
209 t.Errorf("Cat's toys should be cleared with Clear")
210 }
211
212 DB.Model(&dog).Association("Toys").Append(&Toy{
213 Name: "dog toy 8",
214 })
215
216 if DB.Model(&dog).Association("Toys").Count() != 1 {
217 t.Errorf("Dog's toys should be added with Append")
218 }
219
220 DB.Model(&dog).Association("Toys").Clear()
221
222 if DB.Model(&dog).Association("Toys").Count() != 0 {
223 t.Errorf("Dog's toys should be cleared with Clear")
224 }
225 }
226
227 func TestNamedPolymorphic(t *testing.T) {
228 hamster := Hamster{Name: "Mr. Hammond", PreferredToy: Toy{Name: "bike"}, OtherToy: Toy{Name: "treadmill"}}
229 DB.Save(&hamster)
230
231 hamster2 := Hamster{}
232 DB.Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id)
233 if hamster2.PreferredToy.Id != hamster.PreferredToy.Id || hamster2.PreferredToy.Name != hamster.PreferredToy.Name {
234 t.Errorf("Hamster's preferred toy couldn't be preloaded")
235 }
236 if hamster2.OtherToy.Id != hamster.OtherToy.Id || hamster2.OtherToy.Name != hamster.OtherToy.Name {
237 t.Errorf("Hamster's other toy couldn't be preloaded")
238 }
239
240 // clear to omit Toy.Id in count
241 hamster2.PreferredToy = Toy{}
242 hamster2.OtherToy = Toy{}
243
244 if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 {
245 t.Errorf("Hamster's preferred toy count should be 1")
246 }
247
248 if DB.Model(&hamster2).Association("OtherToy").Count() != 1 {
249 t.Errorf("Hamster's other toy count should be 1")
250 }
251
252 // Query
253 var hamsterToys []Toy
254 if DB.Model(&hamster).Related(&hamsterToys, "PreferredToy").RecordNotFound() {
255 t.Errorf("Did not find any has one polymorphic association")
256 } else if len(hamsterToys) != 1 {
257 t.Errorf("Should have found only one polymorphic has one association")
258 } else if hamsterToys[0].Name != hamster.PreferredToy.Name {
259 t.Errorf("Should have found the proper has one polymorphic association")
260 }
261
262 if DB.Model(&hamster).Related(&hamsterToys, "OtherToy").RecordNotFound() {
263 t.Errorf("Did not find any has one polymorphic association")
264 } else if len(hamsterToys) != 1 {
265 t.Errorf("Should have found only one polymorphic has one association")
266 } else if hamsterToys[0].Name != hamster.OtherToy.Name {
267 t.Errorf("Should have found the proper has one polymorphic association")
268 }
269
270 hamsterToy := Toy{}
271 DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy)
272 if hamsterToy.Name != hamster.PreferredToy.Name {
273 t.Errorf("Should find has one polymorphic association")
274 }
275 hamsterToy = Toy{}
276 DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy)
277 if hamsterToy.Name != hamster.OtherToy.Name {
278 t.Errorf("Should find has one polymorphic association")
279 }
280
281 // Append
282 DB.Model(&hamster).Association("PreferredToy").Append(&Toy{
283 Name: "bike 2",
284 })
285 DB.Model(&hamster).Association("OtherToy").Append(&Toy{
286 Name: "treadmill 2",
287 })
288
289 hamsterToy = Toy{}
290 DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy)
291 if hamsterToy.Name != "bike 2" {
292 t.Errorf("Should update has one polymorphic association with Append")
293 }
294
295 hamsterToy = Toy{}
296 DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy)
297 if hamsterToy.Name != "treadmill 2" {
298 t.Errorf("Should update has one polymorphic association with Append")
299 }
300
301 if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 {
302 t.Errorf("Hamster's toys count should be 1 after Append")
303 }
304
305 if DB.Model(&hamster2).Association("OtherToy").Count() != 1 {
306 t.Errorf("Hamster's toys count should be 1 after Append")
307 }
308
309 // Replace
310 DB.Model(&hamster).Association("PreferredToy").Replace(&Toy{
311 Name: "bike 3",
312 })
313 DB.Model(&hamster).Association("OtherToy").Replace(&Toy{
314 Name: "treadmill 3",
315 })
316
317 hamsterToy = Toy{}
318 DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy)
319 if hamsterToy.Name != "bike 3" {
320 t.Errorf("Should update has one polymorphic association with Replace")
321 }
322
323 hamsterToy = Toy{}
324 DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy)
325 if hamsterToy.Name != "treadmill 3" {
326 t.Errorf("Should update has one polymorphic association with Replace")
327 }
328
329 if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 {
330 t.Errorf("hamster's toys count should be 1 after Replace")
331 }
332
333 if DB.Model(&hamster2).Association("OtherToy").Count() != 1 {
334 t.Errorf("hamster's toys count should be 1 after Replace")
335 }
336
337 // Clear
338 DB.Model(&hamster).Association("PreferredToy").Append(&Toy{
339 Name: "bike 2",
340 })
341 DB.Model(&hamster).Association("OtherToy").Append(&Toy{
342 Name: "treadmill 2",
343 })
344
345 if DB.Model(&hamster).Association("PreferredToy").Count() != 1 {
346 t.Errorf("Hamster's toys should be added with Append")
347 }
348 if DB.Model(&hamster).Association("OtherToy").Count() != 1 {
349 t.Errorf("Hamster's toys should be added with Append")
350 }
351
352 DB.Model(&hamster).Association("PreferredToy").Clear()
353
354 if DB.Model(&hamster2).Association("PreferredToy").Count() != 0 {
355 t.Errorf("Hamster's preferred toy should be cleared with Clear")
356 }
357 if DB.Model(&hamster2).Association("OtherToy").Count() != 1 {
358 t.Errorf("Hamster's other toy should be still available")
359 }
360
361 DB.Model(&hamster).Association("OtherToy").Clear()
362 if DB.Model(&hamster).Association("OtherToy").Count() != 0 {
363 t.Errorf("Hamster's other toy should be cleared with Clear")
364 }
365 }
+0
-136
postgres.go less more
0 package gorm
1
2 import (
3 "database/sql"
4 "database/sql/driver"
5 "fmt"
6 "reflect"
7 "time"
8
9 "github.com/lib/pq/hstore"
10 )
11
12 type postgres struct {
13 commonDialect
14 }
15
16 func (postgres) BinVar(i int) string {
17 return fmt.Sprintf("$%v", i)
18 }
19
20 func (postgres) SupportLastInsertId() bool {
21 return false
22 }
23
24 func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
25 switch value.Kind() {
26 case reflect.Bool:
27 return "boolean"
28 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
29 if autoIncrease {
30 return "serial"
31 }
32 return "integer"
33 case reflect.Int64, reflect.Uint64:
34 if autoIncrease {
35 return "bigserial"
36 }
37 return "bigint"
38 case reflect.Float32, reflect.Float64:
39 return "numeric"
40 case reflect.String:
41 if size > 0 && size < 65532 {
42 return fmt.Sprintf("varchar(%d)", size)
43 }
44 return "text"
45 case reflect.Struct:
46 if _, ok := value.Interface().(time.Time); ok {
47 return "timestamp with time zone"
48 }
49 case reflect.Map:
50 if value.Type() == hstoreType {
51 return "hstore"
52 }
53 default:
54 if _, ok := value.Interface().([]byte); ok {
55 return "bytea"
56 }
57 }
58 panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
59 }
60
61 func (s postgres) ReturningStr(tableName, key string) string {
62 return fmt.Sprintf("RETURNING %v.%v", tableName, key)
63 }
64
65 func (s postgres) HasTable(scope *Scope, tableName string) bool {
66 var count int
67 s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName)
68 return count > 0
69 }
70
71 func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
72 var count int
73 s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName)
74 return count > 0
75 }
76
77 func (postgres) RemoveIndex(scope *Scope, indexName string) {
78 scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
79 }
80
81 func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
82 var count int
83 s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName)
84 return count > 0
85 }
86
87 func (s postgres) CurrentDatabase(scope *Scope) (name string) {
88 s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()")
89 return
90 }
91
92 var hstoreType = reflect.TypeOf(Hstore{})
93
94 type Hstore map[string]*string
95
96 func (h Hstore) Value() (driver.Value, error) {
97 hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
98 if len(h) == 0 {
99 return nil, nil
100 }
101
102 for key, value := range h {
103 var s sql.NullString
104 if value != nil {
105 s.String = *value
106 s.Valid = true
107 }
108 hstore.Map[key] = s
109 }
110 return hstore.Value()
111 }
112
113 func (h *Hstore) Scan(value interface{}) error {
114 hstore := hstore.Hstore{}
115
116 if err := hstore.Scan(value); err != nil {
117 return err
118 }
119
120 if len(hstore.Map) == 0 {
121 return nil
122 }
123
124 *h = Hstore{}
125 for k := range hstore.Map {
126 if hstore.Map[k].Valid {
127 s := hstore.Map[k].String
128 (*h)[k] = &s
129 } else {
130 (*h)[k] = nil
131 }
132 }
133
134 return nil
135 }
+0
-360
preload.go less more
0 package gorm
1
2 import (
3 "database/sql/driver"
4 "errors"
5 "fmt"
6 "reflect"
7 "strings"
8 )
9
10 func getRealValue(value reflect.Value, columns []string) (results []interface{}) {
11 for _, column := range columns {
12 if reflect.Indirect(value).FieldByName(column).IsValid() {
13 result := reflect.Indirect(value).FieldByName(column).Interface()
14 if r, ok := result.(driver.Valuer); ok {
15 result, _ = r.Value()
16 }
17 results = append(results, result)
18 }
19 }
20 return
21 }
22
23 func equalAsString(a interface{}, b interface{}) bool {
24 return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b)
25 }
26
27 func Preload(scope *Scope) {
28 if scope.Search.preload == nil {
29 return
30 }
31
32 preloadMap := map[string]bool{}
33 fields := scope.Fields()
34 for _, preload := range scope.Search.preload {
35 schema, conditions := preload.schema, preload.conditions
36 keys := strings.Split(schema, ".")
37 currentScope := scope
38 currentFields := fields
39 originalConditions := conditions
40 conditions = []interface{}{}
41 for i, key := range keys {
42 var found bool
43 if preloadMap[strings.Join(keys[:i+1], ".")] {
44 goto nextLoop
45 }
46
47 if i == len(keys)-1 {
48 conditions = originalConditions
49 }
50
51 for _, field := range currentFields {
52 if field.Name != key || field.Relationship == nil {
53 continue
54 }
55
56 found = true
57 switch field.Relationship.Kind {
58 case "has_one":
59 currentScope.handleHasOnePreload(field, conditions)
60 case "has_many":
61 currentScope.handleHasManyPreload(field, conditions)
62 case "belongs_to":
63 currentScope.handleBelongsToPreload(field, conditions)
64 case "many_to_many":
65 currentScope.handleHasManyToManyPreload(field, conditions)
66 default:
67 currentScope.Err(errors.New("not supported relation"))
68 }
69 break
70 }
71
72 if !found {
73 value := reflect.ValueOf(currentScope.Value)
74 if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface {
75 value = value.Index(0).Elem()
76 }
77 scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type()))
78 return
79 }
80
81 preloadMap[strings.Join(keys[:i+1], ".")] = true
82
83 nextLoop:
84 if i < len(keys)-1 {
85 currentScope = currentScope.getColumnsAsScope(key)
86 currentFields = currentScope.Fields()
87 }
88 }
89 }
90
91 }
92
93 func makeSlice(typ reflect.Type) interface{} {
94 if typ.Kind() == reflect.Slice {
95 typ = typ.Elem()
96 }
97 sliceType := reflect.SliceOf(typ)
98 slice := reflect.New(sliceType)
99 slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
100 return slice.Interface()
101 }
102
103 func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
104 relation := field.Relationship
105
106 primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
107 if len(primaryKeys) == 0 {
108 return
109 }
110
111 results := makeSlice(field.Struct.Type)
112 scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
113 resultValues := reflect.Indirect(reflect.ValueOf(results))
114
115 for i := 0; i < resultValues.Len(); i++ {
116 result := resultValues.Index(i)
117 if scope.IndirectValue().Kind() == reflect.Slice {
118 value := getRealValue(result, relation.ForeignFieldNames)
119 objects := scope.IndirectValue()
120 for j := 0; j < objects.Len(); j++ {
121 if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) {
122 reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
123 break
124 }
125 }
126 } else {
127 if err := scope.SetColumn(field, result); err != nil {
128 scope.Err(err)
129 return
130 }
131 }
132 }
133 }
134
135 func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
136 relation := field.Relationship
137 primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
138 if len(primaryKeys) == 0 {
139 return
140 }
141
142 results := makeSlice(field.Struct.Type)
143 scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
144 resultValues := reflect.Indirect(reflect.ValueOf(results))
145
146 if scope.IndirectValue().Kind() == reflect.Slice {
147 for i := 0; i < resultValues.Len(); i++ {
148 result := resultValues.Index(i)
149 value := getRealValue(result, relation.ForeignFieldNames)
150 objects := scope.IndirectValue()
151 for j := 0; j < objects.Len(); j++ {
152 object := reflect.Indirect(objects.Index(j))
153 if equalAsString(getRealValue(object, relation.AssociationForeignFieldNames), value) {
154 f := object.FieldByName(field.Name)
155 f.Set(reflect.Append(f, result))
156 break
157 }
158 }
159 }
160 } else {
161 scope.SetColumn(field, resultValues)
162 }
163 }
164
165 func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
166 relation := field.Relationship
167 primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames)
168 if len(primaryKeys) == 0 {
169 return
170 }
171
172 results := makeSlice(field.Struct.Type)
173 scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
174 resultValues := reflect.Indirect(reflect.ValueOf(results))
175
176 for i := 0; i < resultValues.Len(); i++ {
177 result := resultValues.Index(i)
178 if scope.IndirectValue().Kind() == reflect.Slice {
179 value := getRealValue(result, relation.AssociationForeignFieldNames)
180 objects := scope.IndirectValue()
181 for j := 0; j < objects.Len(); j++ {
182 object := reflect.Indirect(objects.Index(j))
183 if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) {
184 object.FieldByName(field.Name).Set(result)
185 }
186 }
187 } else {
188 scope.SetColumn(field, result)
189 }
190 }
191 }
192
193 func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interface{}) {
194 relation := field.Relationship
195
196 joinTableHandler := relation.JoinTableHandler
197 destType := field.StructField.Struct.Type.Elem()
198 var isPtr bool
199 if destType.Kind() == reflect.Ptr {
200 isPtr = true
201 destType = destType.Elem()
202 }
203
204 var sourceKeys []string
205 var linkHash = make(map[string][]reflect.Value)
206
207 for _, key := range joinTableHandler.SourceForeignKeys() {
208 sourceKeys = append(sourceKeys, key.DBName)
209 }
210
211 db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*")
212 preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value)
213 if len(conditions) > 0 {
214 preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...)
215 }
216 rows, err := preloadJoinDB.Rows()
217
218 if scope.Err(err) != nil {
219 return
220 }
221 defer rows.Close()
222
223 columns, _ := rows.Columns()
224 for rows.Next() {
225 elem := reflect.New(destType).Elem()
226 var values = make([]interface{}, len(columns))
227
228 fields := scope.New(elem.Addr().Interface()).Fields()
229
230 for index, column := range columns {
231 if field, ok := fields[column]; ok {
232 if field.Field.Kind() == reflect.Ptr {
233 values[index] = field.Field.Addr().Interface()
234 } else {
235 values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
236 }
237 } else {
238 var i interface{}
239 values[index] = &i
240 }
241 }
242
243 scope.Err(rows.Scan(values...))
244
245 var sourceKey []interface{}
246
247 for index, column := range columns {
248 value := values[index]
249 if field, ok := fields[column]; ok {
250 if field.Field.Kind() == reflect.Ptr {
251 field.Field.Set(reflect.ValueOf(value).Elem())
252 } else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
253 field.Field.Set(v)
254 }
255 } else if strInSlice(column, sourceKeys) {
256 sourceKey = append(sourceKey, *(value.(*interface{})))
257 }
258 }
259
260 if len(sourceKey) != 0 {
261 if isPtr {
262 linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem.Addr())
263 } else {
264 linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem)
265 }
266 }
267 }
268
269 var associationForeignStructFieldNames []string
270 for _, dbName := range relation.AssociationForeignFieldNames {
271 if field, ok := scope.FieldByName(dbName); ok {
272 associationForeignStructFieldNames = append(associationForeignStructFieldNames, field.Name)
273 }
274 }
275
276 if scope.IndirectValue().Kind() == reflect.Slice {
277 objects := scope.IndirectValue()
278 for j := 0; j < objects.Len(); j++ {
279 object := reflect.Indirect(objects.Index(j))
280 source := getRealValue(object, associationForeignStructFieldNames)
281 field := object.FieldByName(field.Name)
282 for _, link := range linkHash[toString(source)] {
283 field.Set(reflect.Append(field, link))
284 }
285 }
286 } else {
287 object := scope.IndirectValue()
288 source := getRealValue(object, associationForeignStructFieldNames)
289 field := object.FieldByName(field.Name)
290 for _, link := range linkHash[toString(source)] {
291 field.Set(reflect.Append(field, link))
292 }
293 }
294 }
295
296 func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) {
297 values := scope.IndirectValue()
298 switch values.Kind() {
299 case reflect.Slice:
300 for i := 0; i < values.Len(); i++ {
301 var result []interface{}
302 for _, column := range columns {
303 result = append(result, reflect.Indirect(values.Index(i)).FieldByName(column).Interface())
304 }
305 results = append(results, result)
306 }
307 case reflect.Struct:
308 var result []interface{}
309 for _, column := range columns {
310 result = append(result, values.FieldByName(column).Interface())
311 }
312 return [][]interface{}{result}
313 }
314 return
315 }
316
317 func (scope *Scope) getColumnsAsScope(column string) *Scope {
318 values := scope.IndirectValue()
319 switch values.Kind() {
320 case reflect.Slice:
321 modelType := values.Type().Elem()
322 if modelType.Kind() == reflect.Ptr {
323 modelType = modelType.Elem()
324 }
325 fieldStruct, _ := modelType.FieldByName(column)
326 var columns reflect.Value
327 if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr {
328 columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
329 } else {
330 columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
331 }
332 for i := 0; i < values.Len(); i++ {
333 column := reflect.Indirect(values.Index(i)).FieldByName(column)
334 if column.Kind() == reflect.Ptr {
335 column = column.Elem()
336 }
337 if column.Kind() == reflect.Slice {
338 for i := 0; i < column.Len(); i++ {
339 elem := column.Index(i)
340 if elem.CanAddr() {
341 columns = reflect.Append(columns, elem.Addr())
342 }
343 }
344 } else {
345 if column.CanAddr() {
346 columns = reflect.Append(columns, column.Addr())
347 }
348 }
349 }
350 return scope.New(columns.Interface())
351 case reflect.Struct:
352 field := values.FieldByName(column)
353 if !field.CanAddr() {
354 return nil
355 }
356 return scope.New(field.Addr().Interface())
357 }
358 return nil
359 }
00 package gorm_test
11
22 import (
3 "database/sql"
34 "encoding/json"
45 "os"
56 "reflect"
67 "testing"
8
9 "github.com/jinzhu/gorm"
710 )
811
912 func getPreloadUser(name string) *User {
8689 }
8790 } else if len(user.Emails) != 0 {
8891 t.Errorf("should not preload any emails for other users when with condition")
89 }
92 } else if user.Emails == nil {
93 t.Errorf("should return an empty slice to indicate zero results")
94 }
95 }
96 }
97
98 func TestAutoPreload(t *testing.T) {
99 user1 := getPreloadUser("auto_user1")
100 DB.Save(user1)
101
102 preloadDB := DB.Set("gorm:auto_preload", true).Where("role = ?", "Preload")
103 var user User
104 preloadDB.Find(&user)
105 checkUserHasPreloadData(user, t)
106
107 user2 := getPreloadUser("auto_user2")
108 DB.Save(user2)
109
110 var users []User
111 preloadDB.Find(&users)
112
113 for _, user := range users {
114 checkUserHasPreloadData(user, t)
115 }
116
117 var users2 []*User
118 preloadDB.Find(&users2)
119
120 for _, user := range users2 {
121 checkUserHasPreloadData(*user, t)
90122 }
91123 }
92124
112144 DB.DropTableIfExists(&Level2{})
113145 DB.DropTableIfExists(&Level1{})
114146 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
115 panic(err)
147 t.Error(err)
116148 }
117149
118150 want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}}
119151 if err := DB.Create(&want).Error; err != nil {
120 panic(err)
152 t.Error(err)
121153 }
122154
123155 var got Level3
124156 if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil {
125 panic(err)
126 }
127
128 if !reflect.DeepEqual(got, want) {
129 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
157 t.Error(err)
158 }
159
160 if !reflect.DeepEqual(got, want) {
161 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
162 }
163
164 if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
165 t.Error(err)
130166 }
131167 }
132168
152188 DB.DropTableIfExists(&Level2{})
153189 DB.DropTableIfExists(&Level1{})
154190 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
155 panic(err)
191 t.Error(err)
156192 }
157193
158194 want := Level3{
159195 Level2s: []Level2{
160196 {
161197 Level1s: []*Level1{
162 &Level1{Value: "value1"},
163 &Level1{Value: "value2"},
198 {Value: "value1"},
199 {Value: "value2"},
164200 },
165201 },
166202 {
167203 Level1s: []*Level1{
168 &Level1{Value: "value3"},
204 {Value: "value3"},
169205 },
170206 },
171207 },
172208 }
173209 if err := DB.Create(&want).Error; err != nil {
174 panic(err)
210 t.Error(err)
175211 }
176212
177213 var got Level3
178214 if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil {
179 panic(err)
215 t.Error(err)
180216 }
181217
182218 if !reflect.DeepEqual(got, want) {
206242 DB.DropTableIfExists(&Level2{})
207243 DB.DropTableIfExists(&Level1{})
208244 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
209 panic(err)
245 t.Error(err)
210246 }
211247
212248 want := Level3{
216252 },
217253 }
218254 if err := DB.Create(&want).Error; err != nil {
219 panic(err)
255 t.Error(err)
220256 }
221257
222258 var got Level3
223259 if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil {
224 panic(err)
260 t.Error(err)
225261 }
226262
227263 if !reflect.DeepEqual(got, want) {
251287 DB.DropTableIfExists(&Level2{})
252288 DB.DropTableIfExists(&Level1{})
253289 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
254 panic(err)
290 t.Error(err)
255291 }
256292
257293 want := Level3{
258294 Level2: Level2{
259295 Level1s: []Level1{
260 Level1{Value: "value1"},
261 Level1{Value: "value2"},
296 {Value: "value1"},
297 {Value: "value2"},
262298 },
263299 },
264300 }
265301 if err := DB.Create(&want).Error; err != nil {
266 panic(err)
302 t.Error(err)
267303 }
268304
269305 var got Level3
270306 if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil {
271 panic(err)
307 t.Error(err)
272308 }
273309
274310 if !reflect.DeepEqual(got, want) {
299335 DB.DropTableIfExists(&Level2{})
300336 DB.DropTableIfExists(&Level1{})
301337 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
302 panic(err)
338 t.Error(err)
303339 }
304340
305341 want := make([]Level3, 2)
306342 want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}}
307343 if err := DB.Create(&want[0]).Error; err != nil {
308 panic(err)
344 t.Error(err)
309345 }
310346 want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}}
311347 if err := DB.Create(&want[1]).Error; err != nil {
312 panic(err)
348 t.Error(err)
313349 }
314350
315351 var got []Level3
316352 if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil {
317 panic(err)
353 t.Error(err)
318354 }
319355
320356 if !reflect.DeepEqual(got, want) {
344380 DB.DropTableIfExists(&Level2{})
345381 DB.DropTableIfExists(&Level1{})
346382 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
347 panic(err)
383 t.Error(err)
348384 }
349385
350386 want := make([]Level3, 2)
364400 },
365401 }
366402 if err := DB.Create(&want[0]).Error; err != nil {
367 panic(err)
403 t.Error(err)
368404 }
369405
370406 want[1] = Level3{
383419 },
384420 }
385421 if err := DB.Create(&want[1]).Error; err != nil {
386 panic(err)
422 t.Error(err)
387423 }
388424
389425 var got []Level3
390426 if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil {
391 panic(err)
427 t.Error(err)
392428 }
393429
394430 if !reflect.DeepEqual(got, want) {
418454 DB.DropTableIfExists(&Level2{})
419455 DB.DropTableIfExists(&Level1{})
420456 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
421 panic(err)
457 t.Error(err)
422458 }
423459
424460 want := make([]Level3, 2)
429465 },
430466 }
431467 if err := DB.Create(&want[0]).Error; err != nil {
432 panic(err)
468 t.Error(err)
433469 }
434470
435471 want[1] = Level3{
439475 },
440476 }
441477 if err := DB.Create(&want[1]).Error; err != nil {
442 panic(err)
478 t.Error(err)
443479 }
444480
445481 var got []Level3
446482 if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil {
447 panic(err)
483 t.Error(err)
448484 }
449485
450486 if !reflect.DeepEqual(got, want) {
474510 DB.DropTableIfExists(&Level2{})
475511 DB.DropTableIfExists(&Level1{})
476512 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
477 panic(err)
513 t.Error(err)
478514 }
479515
480516 want := make([]Level3, 2)
481517 want[0] = Level3{
482518 Level2: Level2{
483519 Level1s: []Level1{
484 Level1{Value: "value1"},
485 Level1{Value: "value2"},
520 {Value: "value1"},
521 {Value: "value2"},
486522 },
487523 },
488524 }
489525 if err := DB.Create(&want[0]).Error; err != nil {
490 panic(err)
526 t.Error(err)
491527 }
492528 want[1] = Level3{
493529 Level2: Level2{
494530 Level1s: []Level1{
495 Level1{Value: "value3"},
496 Level1{Value: "value4"},
531 {Value: "value3"},
532 {Value: "value4"},
497533 },
498534 },
499535 }
500536 if err := DB.Create(&want[1]).Error; err != nil {
501 panic(err)
537 t.Error(err)
502538 }
503539
504540 var got []Level3
505541 if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil {
506 panic(err)
542 t.Error(err)
507543 }
508544
509545 if !reflect.DeepEqual(got, want) {
548584 DB.DropTableIfExists(&Level1{})
549585 DB.DropTableIfExists(&Level0{})
550586 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil {
551 panic(err)
587 t.Error(err)
552588 }
553589
554590 want := make([]Level3, 2)
555591 want[0] = Level3{
556592 Level2: Level2{
557593 Level1s: []Level1{
558 Level1{Value: "value1"},
559 Level1{Value: "value2"},
594 {Value: "value1"},
595 {Value: "value2"},
560596 },
561597 },
562598 Level2_1: Level2_1{
563599 Level1s: []Level1{
564 Level1{
600 {
565601 Value: "value1-1",
566602 Level0s: []Level0{{Value: "Level0-1"}},
567603 },
568 Level1{
604 {
569605 Value: "value2-2",
570606 Level0s: []Level0{{Value: "Level0-2"}},
571607 },
573609 },
574610 }
575611 if err := DB.Create(&want[0]).Error; err != nil {
576 panic(err)
612 t.Error(err)
577613 }
578614 want[1] = Level3{
579615 Level2: Level2{
580616 Level1s: []Level1{
581 Level1{Value: "value3"},
582 Level1{Value: "value4"},
617 {Value: "value3"},
618 {Value: "value4"},
583619 },
584620 },
585621 Level2_1: Level2_1{
586622 Level1s: []Level1{
587 Level1{Value: "value3-3"},
588 Level1{Value: "value4-4"},
623 {
624 Value: "value3-3",
625 Level0s: []Level0{},
626 },
627 {
628 Value: "value4-4",
629 Level0s: []Level0{},
630 },
589631 },
590632 },
591633 }
592634 if err := DB.Create(&want[1]).Error; err != nil {
593 panic(err)
635 t.Error(err)
594636 }
595637
596638 var got []Level3
597639 if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil {
598 panic(err)
640 t.Error(err)
641 }
642
643 if !reflect.DeepEqual(got, want) {
644 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
645 }
646 }
647
648 type LevelA1 struct {
649 ID uint
650 Value string
651 }
652
653 type LevelA2 struct {
654 ID uint
655 Value string
656 LevelA3s []*LevelA3
657 }
658
659 type LevelA3 struct {
660 ID uint
661 Value string
662 LevelA1ID sql.NullInt64
663 LevelA1 *LevelA1
664 LevelA2ID sql.NullInt64
665 LevelA2 *LevelA2
666 }
667
668 func TestNestedPreload10(t *testing.T) {
669 DB.DropTableIfExists(&LevelA3{})
670 DB.DropTableIfExists(&LevelA2{})
671 DB.DropTableIfExists(&LevelA1{})
672
673 if err := DB.AutoMigrate(&LevelA1{}, &LevelA2{}, &LevelA3{}).Error; err != nil {
674 t.Error(err)
675 }
676
677 levelA1 := &LevelA1{Value: "foo"}
678 if err := DB.Save(levelA1).Error; err != nil {
679 t.Error(err)
680 }
681
682 want := []*LevelA2{
683 {
684 Value: "bar",
685 LevelA3s: []*LevelA3{
686 {
687 Value: "qux",
688 LevelA1: levelA1,
689 },
690 },
691 },
692 {
693 Value: "bar 2",
694 LevelA3s: []*LevelA3{},
695 },
696 }
697 for _, levelA2 := range want {
698 if err := DB.Save(levelA2).Error; err != nil {
699 t.Error(err)
700 }
701 }
702
703 var got []*LevelA2
704 if err := DB.Preload("LevelA3s.LevelA1").Find(&got).Error; err != nil {
705 t.Error(err)
706 }
707
708 if !reflect.DeepEqual(got, want) {
709 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
710 }
711 }
712
713 type LevelB1 struct {
714 ID uint
715 Value string
716 LevelB3s []*LevelB3
717 }
718
719 type LevelB2 struct {
720 ID uint
721 Value string
722 }
723
724 type LevelB3 struct {
725 ID uint
726 Value string
727 LevelB1ID sql.NullInt64
728 LevelB1 *LevelB1
729 LevelB2s []*LevelB2 `gorm:"many2many:levelb1_levelb3_levelb2s"`
730 }
731
732 func TestNestedPreload11(t *testing.T) {
733 DB.DropTableIfExists(&LevelB2{})
734 DB.DropTableIfExists(&LevelB3{})
735 DB.DropTableIfExists(&LevelB1{})
736 if err := DB.AutoMigrate(&LevelB1{}, &LevelB2{}, &LevelB3{}).Error; err != nil {
737 t.Error(err)
738 }
739
740 levelB1 := &LevelB1{Value: "foo"}
741 if err := DB.Create(levelB1).Error; err != nil {
742 t.Error(err)
743 }
744
745 levelB3 := &LevelB3{
746 Value: "bar",
747 LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)},
748 }
749 if err := DB.Create(levelB3).Error; err != nil {
750 t.Error(err)
751 }
752 levelB1.LevelB3s = []*LevelB3{levelB3}
753
754 want := []*LevelB1{levelB1}
755 var got []*LevelB1
756 if err := DB.Preload("LevelB3s.LevelB2s").Find(&got).Error; err != nil {
757 t.Error(err)
758 }
759
760 if !reflect.DeepEqual(got, want) {
761 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
762 }
763 }
764
765 type LevelC1 struct {
766 ID uint
767 Value string
768 LevelC2ID uint
769 }
770
771 type LevelC2 struct {
772 ID uint
773 Value string
774 LevelC1 LevelC1
775 }
776
777 type LevelC3 struct {
778 ID uint
779 Value string
780 LevelC2ID uint
781 LevelC2 LevelC2
782 }
783
784 func TestNestedPreload12(t *testing.T) {
785 DB.DropTableIfExists(&LevelC2{})
786 DB.DropTableIfExists(&LevelC3{})
787 DB.DropTableIfExists(&LevelC1{})
788 if err := DB.AutoMigrate(&LevelC1{}, &LevelC2{}, &LevelC3{}).Error; err != nil {
789 t.Error(err)
790 }
791
792 level2 := LevelC2{
793 Value: "c2",
794 LevelC1: LevelC1{
795 Value: "c1",
796 },
797 }
798 DB.Create(&level2)
799
800 want := []LevelC3{
801 {
802 Value: "c3-1",
803 LevelC2: level2,
804 }, {
805 Value: "c3-2",
806 LevelC2: level2,
807 },
808 }
809
810 for i := range want {
811 if err := DB.Create(&want[i]).Error; err != nil {
812 t.Error(err)
813 }
814 }
815
816 var got []LevelC3
817 if err := DB.Preload("LevelC2").Preload("LevelC2.LevelC1").Find(&got).Error; err != nil {
818 t.Error(err)
599819 }
600820
601821 if !reflect.DeepEqual(got, want) {
604824 }
605825
606826 func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
607 if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" {
827 if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" || dialect == "mssql" {
608828 return
609829 }
610830
624844
625845 DB.DropTableIfExists(&Level2{})
626846 DB.DropTableIfExists(&Level1{})
627 DB.Table("levels").DropTableIfExists("levels")
847 DB.DropTableIfExists("levels")
628848
629849 if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
630 panic(err)
850 t.Error(err)
631851 }
632852
633853 want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{
635855 {Value: "en", LanguageCode: "en"},
636856 }}
637857 if err := DB.Save(&want).Error; err != nil {
638 panic(err)
858 t.Error(err)
639859 }
640860
641861 want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{
643863 {Value: "de", LanguageCode: "de"},
644864 }}
645865 if err := DB.Save(&want2).Error; err != nil {
646 panic(err)
866 t.Error(err)
647867 }
648868
649869 var got Level2
650870 if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
651 panic(err)
871 t.Error(err)
652872 }
653873
654874 if !reflect.DeepEqual(got, want) {
657877
658878 var got2 Level2
659879 if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
660 panic(err)
880 t.Error(err)
661881 }
662882
663883 if !reflect.DeepEqual(got2, want2) {
666886
667887 var got3 []Level2
668888 if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
669 panic(err)
889 t.Error(err)
670890 }
671891
672892 if !reflect.DeepEqual(got3, []Level2{got, got2}) {
675895
676896 var got4 []Level2
677897 if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
678 panic(err)
898 t.Error(err)
679899 }
680900
681901 var ruLevel1 Level1
688908 if !reflect.DeepEqual(got4, []Level2{got, got2}) {
689909 t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2}))
690910 }
691 }
692
693 func TestManyToManyPreloadForPointer(t *testing.T) {
694 type (
695 Level1 struct {
696 ID uint `gorm:"primary_key;"`
911
912 if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).Error; err != nil {
913 t.Error(err)
914 }
915 }
916
917 func TestManyToManyPreloadForNestedPointer(t *testing.T) {
918 type (
919 Level1 struct {
920 ID uint
697921 Value string
698922 }
699923 Level2 struct {
700 ID uint `gorm:"primary_key;"`
924 ID uint
701925 Value string
702926 Level1s []*Level1 `gorm:"many2many:levels;"`
703927 }
704 )
705
706 DB.DropTableIfExists(&Level2{})
707 DB.DropTableIfExists(&Level1{})
708 DB.Table("levels").DropTableIfExists("levels")
928 Level3 struct {
929 ID uint
930 Value string
931 Level2ID sql.NullInt64
932 Level2 *Level2
933 }
934 )
935
936 DB.DropTableIfExists(&Level3{})
937 DB.DropTableIfExists(&Level2{})
938 DB.DropTableIfExists(&Level1{})
939 DB.DropTableIfExists("levels")
940
941 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
942 t.Error(err)
943 }
944
945 want := Level3{
946 Value: "Bob",
947 Level2: &Level2{
948 Value: "Foo",
949 Level1s: []*Level1{
950 {Value: "ru"},
951 {Value: "en"},
952 },
953 },
954 }
955 if err := DB.Save(&want).Error; err != nil {
956 t.Error(err)
957 }
958
959 want2 := Level3{
960 Value: "Tom",
961 Level2: &Level2{
962 Value: "Bar",
963 Level1s: []*Level1{
964 {Value: "zh"},
965 {Value: "de"},
966 },
967 },
968 }
969 if err := DB.Save(&want2).Error; err != nil {
970 t.Error(err)
971 }
972
973 var got Level3
974 if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
975 t.Error(err)
976 }
977
978 if !reflect.DeepEqual(got, want) {
979 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
980 }
981
982 var got2 Level3
983 if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
984 t.Error(err)
985 }
986
987 if !reflect.DeepEqual(got2, want2) {
988 t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2))
989 }
990
991 var got3 []Level3
992 if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
993 t.Error(err)
994 }
995
996 if !reflect.DeepEqual(got3, []Level3{got, got2}) {
997 t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level3{got, got2}))
998 }
999
1000 var got4 []Level3
1001 if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
1002 t.Error(err)
1003 }
1004
1005 var got5 Level3
1006 DB.Preload("Level2.Level1s").Find(&got5, "value = ?", "bogus")
1007
1008 var ruLevel1 Level1
1009 var zhLevel1 Level1
1010 DB.First(&ruLevel1, "value = ?", "ru")
1011 DB.First(&zhLevel1, "value = ?", "zh")
1012
1013 got.Level2.Level1s = []*Level1{&ruLevel1}
1014 got2.Level2.Level1s = []*Level1{&zhLevel1}
1015 if !reflect.DeepEqual(got4, []Level3{got, got2}) {
1016 t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level3{got, got2}))
1017 }
1018 }
1019
1020 func TestNestedManyToManyPreload(t *testing.T) {
1021 type (
1022 Level1 struct {
1023 ID uint
1024 Value string
1025 }
1026 Level2 struct {
1027 ID uint
1028 Value string
1029 Level1s []*Level1 `gorm:"many2many:level1_level2;"`
1030 }
1031 Level3 struct {
1032 ID uint
1033 Value string
1034 Level2s []Level2 `gorm:"many2many:level2_level3;"`
1035 }
1036 )
1037
1038 DB.DropTableIfExists(&Level1{})
1039 DB.DropTableIfExists(&Level2{})
1040 DB.DropTableIfExists(&Level3{})
1041 DB.DropTableIfExists("level1_level2")
1042 DB.DropTableIfExists("level2_level3")
1043
1044 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
1045 t.Error(err)
1046 }
1047
1048 want := Level3{
1049 Value: "Level3",
1050 Level2s: []Level2{
1051 {
1052 Value: "Bob",
1053 Level1s: []*Level1{
1054 {Value: "ru"},
1055 {Value: "en"},
1056 },
1057 }, {
1058 Value: "Tom",
1059 Level1s: []*Level1{
1060 {Value: "zh"},
1061 {Value: "de"},
1062 },
1063 },
1064 },
1065 }
1066
1067 if err := DB.Save(&want).Error; err != nil {
1068 t.Error(err)
1069 }
1070
1071 var got Level3
1072 if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil {
1073 t.Error(err)
1074 }
1075
1076 if !reflect.DeepEqual(got, want) {
1077 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
1078 }
1079
1080 if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
1081 t.Error(err)
1082 }
1083 }
1084
1085 func TestNestedManyToManyPreload2(t *testing.T) {
1086 type (
1087 Level1 struct {
1088 ID uint
1089 Value string
1090 }
1091 Level2 struct {
1092 ID uint
1093 Value string
1094 Level1s []*Level1 `gorm:"many2many:level1_level2;"`
1095 }
1096 Level3 struct {
1097 ID uint
1098 Value string
1099 Level2ID sql.NullInt64
1100 Level2 *Level2
1101 }
1102 )
1103
1104 DB.DropTableIfExists(&Level1{})
1105 DB.DropTableIfExists(&Level2{})
1106 DB.DropTableIfExists(&Level3{})
1107 DB.DropTableIfExists("level1_level2")
1108
1109 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
1110 t.Error(err)
1111 }
1112
1113 want := Level3{
1114 Value: "Level3",
1115 Level2: &Level2{
1116 Value: "Bob",
1117 Level1s: []*Level1{
1118 {Value: "ru"},
1119 {Value: "en"},
1120 },
1121 },
1122 }
1123
1124 if err := DB.Save(&want).Error; err != nil {
1125 t.Error(err)
1126 }
1127
1128 var got Level3
1129 if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil {
1130 t.Error(err)
1131 }
1132
1133 if !reflect.DeepEqual(got, want) {
1134 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
1135 }
1136
1137 if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
1138 t.Error(err)
1139 }
1140 }
1141
1142 func TestNestedManyToManyPreload3(t *testing.T) {
1143 type (
1144 Level1 struct {
1145 ID uint
1146 Value string
1147 }
1148 Level2 struct {
1149 ID uint
1150 Value string
1151 Level1s []*Level1 `gorm:"many2many:level1_level2;"`
1152 }
1153 Level3 struct {
1154 ID uint
1155 Value string
1156 Level2ID sql.NullInt64
1157 Level2 *Level2
1158 }
1159 )
1160
1161 DB.DropTableIfExists(&Level1{})
1162 DB.DropTableIfExists(&Level2{})
1163 DB.DropTableIfExists(&Level3{})
1164 DB.DropTableIfExists("level1_level2")
1165
1166 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
1167 t.Error(err)
1168 }
1169
1170 level1Zh := &Level1{Value: "zh"}
1171 level1Ru := &Level1{Value: "ru"}
1172 level1En := &Level1{Value: "en"}
1173
1174 level21 := &Level2{
1175 Value: "Level2-1",
1176 Level1s: []*Level1{level1Zh, level1Ru},
1177 }
1178
1179 level22 := &Level2{
1180 Value: "Level2-2",
1181 Level1s: []*Level1{level1Zh, level1En},
1182 }
1183
1184 wants := []*Level3{
1185 {
1186 Value: "Level3-1",
1187 Level2: level21,
1188 },
1189 {
1190 Value: "Level3-2",
1191 Level2: level22,
1192 },
1193 {
1194 Value: "Level3-3",
1195 Level2: level21,
1196 },
1197 }
1198
1199 for _, want := range wants {
1200 if err := DB.Save(&want).Error; err != nil {
1201 t.Error(err)
1202 }
1203 }
1204
1205 var gots []*Level3
1206 if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB {
1207 return db.Order("level1.id ASC")
1208 }).Find(&gots).Error; err != nil {
1209 t.Error(err)
1210 }
1211
1212 if !reflect.DeepEqual(gots, wants) {
1213 t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants))
1214 }
1215 }
1216
1217 func TestNestedManyToManyPreload3ForStruct(t *testing.T) {
1218 type (
1219 Level1 struct {
1220 ID uint
1221 Value string
1222 }
1223 Level2 struct {
1224 ID uint
1225 Value string
1226 Level1s []Level1 `gorm:"many2many:level1_level2;"`
1227 }
1228 Level3 struct {
1229 ID uint
1230 Value string
1231 Level2ID sql.NullInt64
1232 Level2 Level2
1233 }
1234 )
1235
1236 DB.DropTableIfExists(&Level1{})
1237 DB.DropTableIfExists(&Level2{})
1238 DB.DropTableIfExists(&Level3{})
1239 DB.DropTableIfExists("level1_level2")
1240
1241 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
1242 t.Error(err)
1243 }
1244
1245 level1Zh := Level1{Value: "zh"}
1246 level1Ru := Level1{Value: "ru"}
1247 level1En := Level1{Value: "en"}
1248
1249 level21 := Level2{
1250 Value: "Level2-1",
1251 Level1s: []Level1{level1Zh, level1Ru},
1252 }
1253
1254 level22 := Level2{
1255 Value: "Level2-2",
1256 Level1s: []Level1{level1Zh, level1En},
1257 }
1258
1259 wants := []*Level3{
1260 {
1261 Value: "Level3-1",
1262 Level2: level21,
1263 },
1264 {
1265 Value: "Level3-2",
1266 Level2: level22,
1267 },
1268 {
1269 Value: "Level3-3",
1270 Level2: level21,
1271 },
1272 }
1273
1274 for _, want := range wants {
1275 if err := DB.Save(&want).Error; err != nil {
1276 t.Error(err)
1277 }
1278 }
1279
1280 var gots []*Level3
1281 if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB {
1282 return db.Order("level1.id ASC")
1283 }).Find(&gots).Error; err != nil {
1284 t.Error(err)
1285 }
1286
1287 if !reflect.DeepEqual(gots, wants) {
1288 t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants))
1289 }
1290 }
1291
1292 func TestNestedManyToManyPreload4(t *testing.T) {
1293 type (
1294 Level4 struct {
1295 ID uint
1296 Value string
1297 Level3ID uint
1298 }
1299 Level3 struct {
1300 ID uint
1301 Value string
1302 Level4s []*Level4
1303 }
1304 Level2 struct {
1305 ID uint
1306 Value string
1307 Level3s []*Level3 `gorm:"many2many:level2_level3;"`
1308 }
1309 Level1 struct {
1310 ID uint
1311 Value string
1312 Level2s []*Level2 `gorm:"many2many:level1_level2;"`
1313 }
1314 )
1315
1316 DB.DropTableIfExists(&Level1{})
1317 DB.DropTableIfExists(&Level2{})
1318 DB.DropTableIfExists(&Level3{})
1319 DB.DropTableIfExists(&Level4{})
1320 DB.DropTableIfExists("level1_level2")
1321 DB.DropTableIfExists("level2_level3")
1322
1323 dummy := Level1{
1324 Value: "Level1",
1325 Level2s: []*Level2{{
1326 Value: "Level2",
1327 Level3s: []*Level3{{
1328 Value: "Level3",
1329 Level4s: []*Level4{{
1330 Value: "Level4",
1331 }},
1332 }},
1333 }},
1334 }
1335
1336 if err := DB.AutoMigrate(&Level4{}, &Level3{}, &Level2{}, &Level1{}).Error; err != nil {
1337 t.Error(err)
1338 }
1339
1340 if err := DB.Save(&dummy).Error; err != nil {
1341 t.Error(err)
1342 }
1343
1344 var level1 Level1
1345 if err := DB.Preload("Level2s").Preload("Level2s.Level3s").Preload("Level2s.Level3s.Level4s").First(&level1).Error; err != nil {
1346 t.Error(err)
1347 }
1348 }
1349
1350 func TestManyToManyPreloadForPointer(t *testing.T) {
1351 type (
1352 Level1 struct {
1353 ID uint
1354 Value string
1355 }
1356 Level2 struct {
1357 ID uint
1358 Value string
1359 Level1s []*Level1 `gorm:"many2many:levels;"`
1360 }
1361 )
1362
1363 DB.DropTableIfExists(&Level2{})
1364 DB.DropTableIfExists(&Level1{})
1365 DB.DropTableIfExists("levels")
7091366
7101367 if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
711 panic(err)
1368 t.Error(err)
7121369 }
7131370
7141371 want := Level2{Value: "Bob", Level1s: []*Level1{
7161373 {Value: "en"},
7171374 }}
7181375 if err := DB.Save(&want).Error; err != nil {
719 panic(err)
1376 t.Error(err)
7201377 }
7211378
7221379 want2 := Level2{Value: "Tom", Level1s: []*Level1{
7241381 {Value: "de"},
7251382 }}
7261383 if err := DB.Save(&want2).Error; err != nil {
727 panic(err)
1384 t.Error(err)
7281385 }
7291386
7301387 var got Level2
7311388 if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
732 panic(err)
1389 t.Error(err)
7331390 }
7341391
7351392 if !reflect.DeepEqual(got, want) {
7381395
7391396 var got2 Level2
7401397 if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
741 panic(err)
1398 t.Error(err)
7421399 }
7431400
7441401 if !reflect.DeepEqual(got2, want2) {
7471404
7481405 var got3 []Level2
7491406 if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
750 panic(err)
1407 t.Error(err)
7511408 }
7521409
7531410 if !reflect.DeepEqual(got3, []Level2{got, got2}) {
7561413
7571414 var got4 []Level2
7581415 if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
759 panic(err)
760 }
1416 t.Error(err)
1417 }
1418
1419 var got5 Level2
1420 DB.Preload("Level1s").First(&got5, "value = ?", "bogus")
7611421
7621422 var ruLevel1 Level1
7631423 var zhLevel1 Level1
7741434 func TestNilPointerSlice(t *testing.T) {
7751435 type (
7761436 Level3 struct {
777 ID uint `gorm:"primary_key;"`
1437 ID uint
7781438 Value string
7791439 }
7801440 Level2 struct {
781 ID uint `gorm:"primary_key;"`
1441 ID uint
7821442 Value string
7831443 Level3ID uint
7841444 Level3 *Level3
7851445 }
7861446 Level1 struct {
787 ID uint `gorm:"primary_key;"`
1447 ID uint
7881448 Value string
7891449 Level2ID uint
7901450 Level2 *Level2
7961456 DB.DropTableIfExists(&Level1{})
7971457
7981458 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
799 panic(err)
800 }
801
802 want := Level1{Value: "Bob", Level2: &Level2{
803 Value: "en",
804 Level3: &Level3{
805 Value: "native",
806 },
807 }}
1459 t.Error(err)
1460 }
1461
1462 want := Level1{
1463 Value: "Bob",
1464 Level2: &Level2{
1465 Value: "en",
1466 Level3: &Level3{
1467 Value: "native",
1468 },
1469 },
1470 }
8081471 if err := DB.Save(&want).Error; err != nil {
809 panic(err)
810 }
811
812 want2 := Level1{Value: "Tom", Level2: nil}
1472 t.Error(err)
1473 }
1474
1475 want2 := Level1{
1476 Value: "Tom",
1477 Level2: nil,
1478 }
8131479 if err := DB.Save(&want2).Error; err != nil {
814 panic(err)
1480 t.Error(err)
8151481 }
8161482
8171483 var got []Level1
8181484 if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).Error; err != nil {
819 panic(err)
1485 t.Error(err)
8201486 }
8211487
8221488 if len(got) != 2 {
823 t.Fatalf("got %v items, expected 2", len(got))
1489 t.Errorf("got %v items, expected 2", len(got))
8241490 }
8251491
8261492 if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) {
8291495
8301496 if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) {
8311497 t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want2))
1498 }
1499 }
1500
1501 func TestNilPointerSlice2(t *testing.T) {
1502 type (
1503 Level4 struct {
1504 ID uint
1505 }
1506 Level3 struct {
1507 ID uint
1508 Level4ID sql.NullInt64 `sql:"index"`
1509 Level4 *Level4
1510 }
1511 Level2 struct {
1512 ID uint
1513 Level3s []*Level3 `gorm:"many2many:level2_level3s"`
1514 }
1515 Level1 struct {
1516 ID uint
1517 Level2ID sql.NullInt64 `sql:"index"`
1518 Level2 *Level2
1519 }
1520 )
1521
1522 DB.DropTableIfExists(new(Level4))
1523 DB.DropTableIfExists(new(Level3))
1524 DB.DropTableIfExists(new(Level2))
1525 DB.DropTableIfExists(new(Level1))
1526
1527 if err := DB.AutoMigrate(new(Level4), new(Level3), new(Level2), new(Level1)).Error; err != nil {
1528 t.Error(err)
1529 }
1530
1531 want := new(Level1)
1532 if err := DB.Save(want).Error; err != nil {
1533 t.Error(err)
1534 }
1535
1536 got := new(Level1)
1537 err := DB.Preload("Level2.Level3s.Level4").Last(&got).Error
1538 if err != nil {
1539 t.Error(err)
1540 }
1541
1542 if !reflect.DeepEqual(got, want) {
1543 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
1544 }
1545 }
1546
1547 func TestPrefixedPreloadDuplication(t *testing.T) {
1548 type (
1549 Level4 struct {
1550 ID uint
1551 Name string
1552 Level3ID uint
1553 }
1554 Level3 struct {
1555 ID uint
1556 Name string
1557 Level4s []*Level4
1558 }
1559 Level2 struct {
1560 ID uint
1561 Name string
1562 Level3ID sql.NullInt64 `sql:"index"`
1563 Level3 *Level3
1564 }
1565 Level1 struct {
1566 ID uint
1567 Name string
1568 Level2ID sql.NullInt64 `sql:"index"`
1569 Level2 *Level2
1570 }
1571 )
1572
1573 DB.DropTableIfExists(new(Level3))
1574 DB.DropTableIfExists(new(Level4))
1575 DB.DropTableIfExists(new(Level2))
1576 DB.DropTableIfExists(new(Level1))
1577
1578 if err := DB.AutoMigrate(new(Level3), new(Level4), new(Level2), new(Level1)).Error; err != nil {
1579 t.Error(err)
1580 }
1581
1582 lvl := &Level3{}
1583 if err := DB.Save(lvl).Error; err != nil {
1584 t.Error(err)
1585 }
1586
1587 sublvl1 := &Level4{Level3ID: lvl.ID}
1588 if err := DB.Save(sublvl1).Error; err != nil {
1589 t.Error(err)
1590 }
1591 sublvl2 := &Level4{Level3ID: lvl.ID}
1592 if err := DB.Save(sublvl2).Error; err != nil {
1593 t.Error(err)
1594 }
1595
1596 lvl.Level4s = []*Level4{sublvl1, sublvl2}
1597
1598 want1 := Level1{
1599 Level2: &Level2{
1600 Level3: lvl,
1601 },
1602 }
1603 if err := DB.Save(&want1).Error; err != nil {
1604 t.Error(err)
1605 }
1606
1607 want2 := Level1{
1608 Level2: &Level2{
1609 Level3: lvl,
1610 },
1611 }
1612 if err := DB.Save(&want2).Error; err != nil {
1613 t.Error(err)
1614 }
1615
1616 want := []Level1{want1, want2}
1617
1618 var got []Level1
1619 err := DB.Preload("Level2.Level3.Level4s").Find(&got).Error
1620 if err != nil {
1621 t.Error(err)
1622 }
1623
1624 if !reflect.DeepEqual(got, want) {
1625 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
1626 }
1627 }
1628
1629 func TestPreloadManyToManyCallbacks(t *testing.T) {
1630 type (
1631 Level2 struct {
1632 ID uint
1633 Name string
1634 }
1635 Level1 struct {
1636 ID uint
1637 Name string
1638 Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"`
1639 }
1640 )
1641
1642 DB.DropTableIfExists("level1_level2s")
1643 DB.DropTableIfExists(new(Level1))
1644 DB.DropTableIfExists(new(Level2))
1645
1646 if err := DB.AutoMigrate(new(Level1), new(Level2)).Error; err != nil {
1647 t.Error(err)
1648 }
1649
1650 lvl := Level1{
1651 Name: "l1",
1652 Level2s: []Level2{
1653 Level2{Name: "l2-1"}, Level2{Name: "l2-2"},
1654 },
1655 }
1656 DB.Save(&lvl)
1657
1658 called := 0
1659
1660 DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) {
1661 called = called + 1
1662 })
1663
1664 DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID)
1665
1666 if called != 3 {
1667 t.Errorf("Wanted callback to be called 3 times but got %d", called)
8321668 }
8331669 }
8341670
44 "reflect"
55
66 "github.com/jinzhu/gorm"
7 "github.com/jinzhu/now"
87
98 "testing"
109 "time"
1817 DB.First(&user1)
1918 DB.Order("id").Limit(1).Find(&user2)
2019
21 DB.Last(&user3)
20 ptrOfUser3 := &user3
21 DB.Last(&ptrOfUser3)
2222 DB.Order("id desc").Limit(1).Find(&user4)
2323 if user1.Id != user2.Id || user3.Id != user4.Id {
2424 t.Errorf("First and Last should by order by primary key")
3030 t.Errorf("Find first record as slice")
3131 }
3232
33 if DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).Error != nil {
33 var user User
34 if DB.Joins("left join emails on emails.user_id = users.id").First(&user).Error != nil {
3435 t.Errorf("Should not raise any error when order with Join table")
36 }
37
38 if user.Email != "" {
39 t.Errorf("User's Email should be blank as no one set it")
3540 }
3641 }
3742
4752 DB.Order("counter desc").Limit(1).Find(&animal4)
4853 if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter {
4954 t.Errorf("First and Last should work correctly")
55 }
56 }
57
58 func TestFirstAndLastWithRaw(t *testing.T) {
59 user1 := User{Name: "user", Emails: []Email{{Email: "user1@example.com"}}}
60 user2 := User{Name: "user", Emails: []Email{{Email: "user2@example.com"}}}
61 DB.Save(&user1)
62 DB.Save(&user2)
63
64 var user3, user4 User
65 DB.Raw("select * from users WHERE name = ?", "user").First(&user3)
66 if user3.Id != user1.Id {
67 t.Errorf("Find first record with raw")
68 }
69
70 DB.Raw("select * from users WHERE name = ?", "user").Last(&user4)
71 if user4.Id != user2.Id {
72 t.Errorf("Find last record with raw")
5073 }
5174 }
5275
6386 }
6487 }
6588
89 func TestCustomizedTypePrimaryKey(t *testing.T) {
90 type ID uint
91 type CustomizedTypePrimaryKey struct {
92 ID ID
93 Name string
94 }
95
96 DB.AutoMigrate(&CustomizedTypePrimaryKey{})
97
98 p1 := CustomizedTypePrimaryKey{Name: "p1"}
99 p2 := CustomizedTypePrimaryKey{Name: "p2"}
100 p3 := CustomizedTypePrimaryKey{Name: "p3"}
101 DB.Create(&p1)
102 DB.Create(&p2)
103 DB.Create(&p3)
104
105 var p CustomizedTypePrimaryKey
106
107 if err := DB.First(&p, p2.ID).Error; err == nil {
108 t.Errorf("Should return error for invalid query condition")
109 }
110
111 if err := DB.First(&p, "id = ?", p2.ID).Error; err != nil {
112 t.Errorf("No error should happen when querying with customized type for primary key, got err %v", err)
113 }
114
115 if p.Name != "p2" {
116 t.Errorf("Should find correct value when querying with customized type for primary key")
117 }
118 }
119
120 func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) {
121 type AddressByZipCode struct {
122 ZipCode string `gorm:"primary_key"`
123 Address string
124 }
125
126 DB.AutoMigrate(&AddressByZipCode{})
127 DB.Create(&AddressByZipCode{ZipCode: "00501", Address: "Holtsville"})
128
129 var address AddressByZipCode
130 DB.First(&address, "00501")
131 if address.ZipCode != "00501" {
132 t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode)
133 }
134 }
135
66136 func TestFindAsSliceOfPointers(t *testing.T) {
67137 DB.Save(&User{Name: "user"})
68138
78148 }
79149
80150 func TestSearchWithPlainSQL(t *testing.T) {
81 user1 := User{Name: "PlainSqlUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
82 user2 := User{Name: "PlainSqlUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
83 user3 := User{Name: "PlainSqlUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
151 user1 := User{Name: "PlainSqlUser1", Age: 1, Birthday: parseTime("2000-1-1")}
152 user2 := User{Name: "PlainSqlUser2", Age: 10, Birthday: parseTime("2010-1-1")}
153 user3 := User{Name: "PlainSqlUser3", Age: 20, Birthday: parseTime("2020-1-1")}
84154 DB.Save(&user1).Save(&user2).Save(&user3)
85155 scopedb := DB.Where("name LIKE ?", "%PlainSqlUser%")
86156
108178 t.Errorf("Should found 2 users age != 20, but got %v", len(users))
109179 }
110180
111 scopedb.Where("birthday > ?", now.MustParse("2000-1-1")).Find(&users)
181 scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users)
112182 if len(users) != 2 {
113183 t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users))
114184 }
138208 t.Errorf("Should found 1 users, but got %v", len(users))
139209 }
140210
211 if err := DB.Where("id IN (?)", []string{}).Find(&users).Error; err != nil {
212 t.Error("no error should happen when query with empty slice, but got: ", err)
213 }
214
215 if err := DB.Not("id IN (?)", []string{}).Find(&users).Error; err != nil {
216 t.Error("no error should happen when query with empty slice, but got: ", err)
217 }
218
141219 if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() {
142220 t.Errorf("Should not get RecordNotFound error when looking for none existing records")
143221 }
144222 }
145223
224 func TestSearchWithTwoDimensionalArray(t *testing.T) {
225 var users []User
226 user1 := User{Name: "2DSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")}
227 user2 := User{Name: "2DSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")}
228 user3 := User{Name: "2DSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")}
229 DB.Create(&user1)
230 DB.Create(&user2)
231 DB.Create(&user3)
232
233 if dialect := DB.Dialect().GetName(); dialect == "mysql" || dialect == "postgres" {
234 if err := DB.Where("(name, age) IN (?)", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil {
235 t.Errorf("No error should happen when query with 2D array, but got %v", err)
236
237 if len(users) != 2 {
238 t.Errorf("Should find 2 users with 2D array, but got %v", len(users))
239 }
240 }
241 }
242
243 if dialect := DB.Dialect().GetName(); dialect == "mssql" {
244 if err := DB.Joins("JOIN (VALUES ?) AS x (col1, col2) ON x.col1 = name AND x.col2 = age", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil {
245 t.Errorf("No error should happen when query with 2D array, but got %v", err)
246
247 if len(users) != 2 {
248 t.Errorf("Should find 2 users with 2D array, but got %v", len(users))
249 }
250 }
251 }
252 }
253
146254 func TestSearchWithStruct(t *testing.T) {
147 user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
148 user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
149 user3 := User{Name: "StructSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
255 user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")}
256 user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")}
257 user3 := User{Name: "StructSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")}
150258 DB.Save(&user1).Save(&user2).Save(&user3)
151259
152260 if DB.Where(user1.Id).First(&User{}).RecordNotFound() {
174282 }
175283
176284 DB.First(&user, User{Name: user1.Name})
177 if user.Id == 0 || user.Name != user.Name {
285 if user.Id == 0 || user.Name != user1.Name {
178286 t.Errorf("Search first record with inline struct")
179287 }
180288
190298 }
191299
192300 func TestSearchWithMap(t *testing.T) {
193 user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
194 user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
195 user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
196 DB.Save(&user1).Save(&user2).Save(&user3)
301 companyID := 1
302 user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")}
303 user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")}
304 user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")}
305 user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: parseTime("2020-1-1"), CompanyID: &companyID}
306 DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4)
197307
198308 var user User
199309 DB.First(&user, map[string]interface{}{"name": user1.Name})
217327 if len(users) != 1 {
218328 t.Errorf("Search all records with inline map")
219329 }
330
331 DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": nil})
332 if len(users) != 0 {
333 t.Errorf("Search all records with inline map containing null value finding 0 records")
334 }
335
336 DB.Find(&users, map[string]interface{}{"name": user1.Name, "company_id": nil})
337 if len(users) != 1 {
338 t.Errorf("Search all records with inline map containing null value finding 1 record")
339 }
340
341 DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": companyID})
342 if len(users) != 1 {
343 t.Errorf("Search all records with inline multiple value map")
344 }
220345 }
221346
222347 func TestSearchWithEmptyChain(t *testing.T) {
223 user1 := User{Name: "ChainSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
224 user2 := User{Name: "ChainearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
225 user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
348 user1 := User{Name: "ChainSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")}
349 user2 := User{Name: "ChainearchUser2", Age: 10, Birthday: parseTime("2010-1-1")}
350 user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: parseTime("2020-1-1")}
226351 DB.Save(&user1).Save(&user2).Save(&user3)
227352
228353 if DB.Where("").Where("").First(&User{}).Error != nil {
259384 user3 := User{Name: "OrderPluckUser3", Age: 20}
260385 DB.Save(&user1).Save(&user2).Save(&user3)
261386 scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%")
387
388 var user User
389 scopedb.Order(gorm.Expr("case when name = ? then 0 else 1 end", "OrderPluckUser2")).First(&user)
390 if user.Name != "OrderPluckUser2" {
391 t.Errorf("Order with sql expression")
392 }
262393
263394 var ages []int64
264395 scopedb.Order("age desc").Pluck("age", &ages)
289420 t.Errorf("Order with multiple orders")
290421 }
291422
423 var ages6 []int64
424 if err := scopedb.Order("").Pluck("age", &ages6).Error; err != nil {
425 t.Errorf("An empty string as order clause produces invalid queries")
426 }
427
292428 DB.Model(User{}).Select("name, age").Find(&[]User{})
293429 }
294430
313449 DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)})
314450 }
315451 var users1, users2, users3, users4 []User
316 DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
452 DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
317453
318454 if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
319455 t.Errorf("Offset should work")
354490 if count1 != 1 || count2 != 3 {
355491 t.Errorf("Multiple count in chain")
356492 }
493
494 var count3 int
495 if err := DB.Model(&User{}).Where("name in (?)", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil {
496 t.Errorf("Not error should happen, but got %v", err)
497 }
498
499 if count3 != 2 {
500 t.Errorf("Should get correct count, but got %v", count3)
501 }
357502 }
358503
359504 func TestNot(t *testing.T) {
360505 DB.Create(getPreparedUser("user1", "not"))
361506 DB.Create(getPreparedUser("user2", "not"))
362507 DB.Create(getPreparedUser("user3", "not"))
363 DB.Create(getPreparedUser("user4", "not"))
508
509 user4 := getPreparedUser("user4", "not")
510 user4.Company = Company{}
511 DB.Create(user4)
512
364513 DB := DB.Where("role = ?", "not")
365514
366 var users1, users2, users3, users4, users5, users6, users7, users8 []User
515 var users1, users2, users3, users4, users5, users6, users7, users8, users9 []User
367516 if DB.Find(&users1).RowsAffected != 4 {
368517 t.Errorf("should find 4 not users")
369518 }
406555 t.Errorf("Should find all users's name not equal 3")
407556 }
408557
409 DB.Not("name", []string{"user3"}).Find(&users7)
410 if len(users1)-len(users7) != int(name3Count) {
558 DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7)
559 if len(users1)-len(users7) != 2 { // not user3 or user4
560 t.Errorf("Should find all user's name not equal to 3 who do not have company id")
561 }
562
563 DB.Not("name", []string{"user3"}).Find(&users8)
564 if len(users1)-len(users8) != int(name3Count) {
411565 t.Errorf("Should find all users's name not equal 3")
412566 }
413567
414568 var name2Count int64
415569 DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
416 DB.Not("name", []string{"user3", "user2"}).Find(&users8)
417 if len(users1)-len(users8) != (int(name3Count) + int(name2Count)) {
570 DB.Not("name", []string{"user3", "user2"}).Find(&users9)
571 if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) {
418572 t.Errorf("Should find all users's name not equal 3")
419573 }
420574 }
566720 t.Errorf("Should only contains one column")
567721 }
568722 }
723
724 rows.Close()
569725 }
570726
571727 func TestSelectWithArrayInput(t *testing.T) {
579735 }
580736 }
581737
582 func TestCurrentDatabase(t *testing.T) {
583 databaseName := DB.CurrentDatabase()
584 if err := DB.Error; err != nil {
585 t.Errorf("Problem getting current db name: %s", err)
586 }
587 if databaseName == "" {
588 t.Errorf("Current db name returned empty; this should never happen!")
589 }
590 t.Logf("Got current db name: %v", databaseName)
591 }
738 func TestPluckWithSelect(t *testing.T) {
739 var (
740 user = User{Name: "matematik7_pluck_with_select", Age: 25}
741 combinedName = fmt.Sprintf("%v%v", user.Name, user.Age)
742 combineUserAgeSQL = fmt.Sprintf("concat(%v, %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age"))
743 )
744
745 if dialect := DB.Dialect().GetName(); dialect == "sqlite3" {
746 combineUserAgeSQL = fmt.Sprintf("(%v || %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age"))
747 }
748
749 DB.Save(&user)
750
751 selectStr := combineUserAgeSQL + " as user_age"
752 var userAges []string
753 err := DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error
754 if err != nil {
755 t.Error(err)
756 }
757
758 if len(userAges) != 1 || userAges[0] != combinedName {
759 t.Errorf("Should correctly pluck with select, got: %s", userAges)
760 }
761
762 selectStr = combineUserAgeSQL + fmt.Sprintf(" as %v", DB.Dialect().Quote("user_age"))
763 userAges = userAges[:0]
764 err = DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error
765 if err != nil {
766 t.Error(err)
767 }
768
769 if len(userAges) != 1 || userAges[0] != combinedName {
770 t.Errorf("Should correctly pluck with select, got: %s", userAges)
771 }
772 }
0 package gorm_test
1
2 import (
3 "database/sql/driver"
4 "encoding/json"
5 "errors"
6 "testing"
7
8 "github.com/jinzhu/gorm"
9 )
10
11 func TestScannableSlices(t *testing.T) {
12 if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil {
13 t.Errorf("Should create table with slice values correctly: %s", err)
14 }
15
16 r1 := RecordWithSlice{
17 Strings: ExampleStringSlice{"a", "b", "c"},
18 Structs: ExampleStructSlice{
19 {"name1", "value1"},
20 {"name2", "value2"},
21 },
22 }
23
24 if err := DB.Save(&r1).Error; err != nil {
25 t.Errorf("Should save record with slice values")
26 }
27
28 var r2 RecordWithSlice
29
30 if err := DB.Find(&r2).Error; err != nil {
31 t.Errorf("Should fetch record with slice values")
32 }
33
34 if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" {
35 t.Errorf("Should have serialised and deserialised a string array")
36 }
37
38 if len(r2.Structs) != 2 || r2.Structs[0].Name != "name1" || r2.Structs[0].Value != "value1" || r2.Structs[1].Name != "name2" || r2.Structs[1].Value != "value2" {
39 t.Errorf("Should have serialised and deserialised a struct array")
40 }
41 }
42
43 type RecordWithSlice struct {
44 ID uint64
45 Strings ExampleStringSlice `sql:"type:text"`
46 Structs ExampleStructSlice `sql:"type:text"`
47 }
48
49 type ExampleStringSlice []string
50
51 func (l ExampleStringSlice) Value() (driver.Value, error) {
52 bytes, err := json.Marshal(l)
53 return string(bytes), err
54 }
55
56 func (l *ExampleStringSlice) Scan(input interface{}) error {
57 switch value := input.(type) {
58 case string:
59 return json.Unmarshal([]byte(value), l)
60 case []byte:
61 return json.Unmarshal(value, l)
62 default:
63 return errors.New("not supported")
64 }
65 }
66
67 type ExampleStruct struct {
68 Name string
69 Value string
70 }
71
72 type ExampleStructSlice []ExampleStruct
73
74 func (l ExampleStructSlice) Value() (driver.Value, error) {
75 bytes, err := json.Marshal(l)
76 return string(bytes), err
77 }
78
79 func (l *ExampleStructSlice) Scan(input interface{}) error {
80 switch value := input.(type) {
81 case string:
82 return json.Unmarshal([]byte(value), l)
83 case []byte:
84 return json.Unmarshal(value, l)
85 default:
86 return errors.New("not supported")
87 }
88 }
89
90 type ScannerDataType struct {
91 Street string `sql:"TYPE:varchar(24)"`
92 }
93
94 func (ScannerDataType) Value() (driver.Value, error) {
95 return nil, nil
96 }
97
98 func (*ScannerDataType) Scan(input interface{}) error {
99 return nil
100 }
101
102 type ScannerDataTypeTestStruct struct {
103 Field1 int
104 ScannerDataType *ScannerDataType `sql:"TYPE:json"`
105 }
106
107 type ScannerDataType2 struct {
108 Street string `sql:"TYPE:varchar(24)"`
109 }
110
111 func (ScannerDataType2) Value() (driver.Value, error) {
112 return nil, nil
113 }
114
115 func (*ScannerDataType2) Scan(input interface{}) error {
116 return nil
117 }
118
119 type ScannerDataTypeTestStruct2 struct {
120 Field1 int
121 ScannerDataType *ScannerDataType2
122 }
123
124 func TestScannerDataType(t *testing.T) {
125 scope := gorm.Scope{Value: &ScannerDataTypeTestStruct{}}
126 if field, ok := scope.FieldByName("ScannerDataType"); ok {
127 if DB.Dialect().DataTypeOf(field.StructField) != "json" {
128 t.Errorf("data type for scanner is wrong")
129 }
130 }
131
132 scope = gorm.Scope{Value: &ScannerDataTypeTestStruct2{}}
133 if field, ok := scope.FieldByName("ScannerDataType"); ok {
134 if DB.Dialect().DataTypeOf(field.StructField) != "varchar(24)" {
135 t.Errorf("data type for scanner is wrong")
136 }
137 }
138 }
+1218
-300
scope.go less more
00 package gorm
11
22 import (
3 "bytes"
4 "database/sql"
5 "database/sql/driver"
36 "errors"
47 "fmt"
8 "reflect"
59 "regexp"
610 "strings"
711 "time"
8
9 "reflect"
1012 )
1113
14 // Scope contain current operation's information when you perform any operation on the database
1215 type Scope struct {
1316 Search *search
1417 Value interface{}
15 Sql string
16 SqlVars []interface{}
18 SQL string
19 SQLVars []interface{}
1720 db *DB
18 indirectValue *reflect.Value
19 instanceId string
21 instanceID string
2022 primaryKeyField *Field
2123 skipLeft bool
22 fields map[string]*Field
24 fields *[]*Field
2325 selectAttrs *[]string
2426 }
2527
28 // IndirectValue return scope's reflect value's indirect value
2629 func (scope *Scope) IndirectValue() reflect.Value {
27 if scope.indirectValue == nil {
28 value := reflect.Indirect(reflect.ValueOf(scope.Value))
29 if value.Kind() == reflect.Ptr {
30 value = value.Elem()
31 }
32 scope.indirectValue = &value
33 }
34 return *scope.indirectValue
35 }
36
37 func (scope *Scope) NeedPtr() *Scope {
38 reflectKind := reflect.ValueOf(scope.Value).Kind()
39 if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
40 err := fmt.Errorf("%v %v\n", fileWithLineNum(), "using unaddressable value")
41 scope.Err(err)
42 fmt.Printf(err.Error())
43 }
44 return scope
30 return indirect(reflect.ValueOf(scope.Value))
4531 }
4632
4733 // New create a new Scope without search information
4834 func (scope *Scope) New(value interface{}) *Scope {
4935 return &Scope{db: scope.NewDB(), Search: &search{}, Value: value}
36 }
37
38 ////////////////////////////////////////////////////////////////////////////////
39 // Scope DB
40 ////////////////////////////////////////////////////////////////////////////////
41
42 // DB return scope's DB connection
43 func (scope *Scope) DB() *DB {
44 return scope.db
5045 }
5146
5247 // NewDB create a new DB without search information
6055 return nil
6156 }
6257
63 func (scope *Scope) DB() *DB {
64 return scope.db
65 }
66
67 // SqlDB return *sql.DB
68 func (scope *Scope) SqlDB() sqlCommon {
58 // SQLDB return *sql.DB
59 func (scope *Scope) SQLDB() SQLCommon {
6960 return scope.db.db
7061 }
7162
72 // SkipLeft skip remaining callbacks
73 func (scope *Scope) SkipLeft() {
74 scope.skipLeft = true
75 }
76
77 // Quote used to quote database column name according to database dialect
63 // Dialect get dialect
64 func (scope *Scope) Dialect() Dialect {
65 return scope.db.parent.dialect
66 }
67
68 // Quote used to quote string to escape them for database
7869 func (scope *Scope) Quote(str string) string {
7970 if strings.Index(str, ".") != -1 {
8071 newStrs := []string{}
8273 newStrs = append(newStrs, scope.Dialect().Quote(str))
8374 }
8475 return strings.Join(newStrs, ".")
85 } else {
86 return scope.Dialect().Quote(str)
87 }
88 }
89
90 func (scope *Scope) QuoteIfPossible(str string) string {
91 if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) {
92 return scope.Quote(str)
93 }
94 return str
95 }
96
97 // Dialect get dialect
98 func (scope *Scope) Dialect() Dialect {
99 return scope.db.parent.dialect
100 }
101
102 // Err write error
76 }
77
78 return scope.Dialect().Quote(str)
79 }
80
81 // Err add error to Scope
10382 func (scope *Scope) Err(err error) error {
10483 if err != nil {
10584 scope.db.AddError(err)
10786 return err
10887 }
10988
89 // HasError check if there are any error
90 func (scope *Scope) HasError() bool {
91 return scope.db.Error != nil
92 }
93
11094 // Log print log message
11195 func (scope *Scope) Log(v ...interface{}) {
11296 scope.db.log(v...)
11397 }
11498
115 // HasError check if there are any error
116 func (scope *Scope) HasError() bool {
117 return scope.db.Error != nil
118 }
119
120 func (scope *Scope) PrimaryFields() []*Field {
121 var fields = []*Field{}
122 for _, field := range scope.GetModelStruct().PrimaryFields {
123 fields = append(fields, scope.Fields()[field.DBName])
99 // SkipLeft skip remaining callbacks
100 func (scope *Scope) SkipLeft() {
101 scope.skipLeft = true
102 }
103
104 // Fields get value's fields
105 func (scope *Scope) Fields() []*Field {
106 if scope.fields == nil {
107 var (
108 fields []*Field
109 indirectScopeValue = scope.IndirectValue()
110 isStruct = indirectScopeValue.Kind() == reflect.Struct
111 )
112
113 for _, structField := range scope.GetModelStruct().StructFields {
114 if isStruct {
115 fieldValue := indirectScopeValue
116 for _, name := range structField.Names {
117 if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
118 fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
119 }
120 fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
121 }
122 fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)})
123 } else {
124 fields = append(fields, &Field{StructField: structField, IsBlank: true})
125 }
126 }
127 scope.fields = &fields
128 }
129
130 return *scope.fields
131 }
132
133 // FieldByName find `gorm.Field` with field name or db name
134 func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
135 var (
136 dbName = ToDBName(name)
137 mostMatchedField *Field
138 )
139
140 for _, field := range scope.Fields() {
141 if field.Name == name || field.DBName == name {
142 return field, true
143 }
144 if field.DBName == dbName {
145 mostMatchedField = field
146 }
147 }
148 return mostMatchedField, mostMatchedField != nil
149 }
150
151 // PrimaryFields return scope's primary fields
152 func (scope *Scope) PrimaryFields() (fields []*Field) {
153 for _, field := range scope.Fields() {
154 if field.IsPrimaryKey {
155 fields = append(fields, field)
156 }
124157 }
125158 return fields
126159 }
127160
161 // PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one
128162 func (scope *Scope) PrimaryField() *Field {
129163 if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
130164 if len(primaryFields) > 1 {
131 if field, ok := scope.Fields()["id"]; ok {
165 if field, ok := scope.FieldByName("id"); ok {
132166 return field
133167 }
134168 }
135 return scope.Fields()[primaryFields[0].DBName]
169 return scope.PrimaryFields()[0]
136170 }
137171 return nil
138172 }
139173
140 // PrimaryKey get the primary key's column name
174 // PrimaryKey get main primary field's db name
141175 func (scope *Scope) PrimaryKey() string {
142176 if field := scope.PrimaryField(); field != nil {
143177 return field.DBName
145179 return ""
146180 }
147181
148 // PrimaryKeyZero check the primary key is blank or not
182 // PrimaryKeyZero check main primary field's value is blank or not
149183 func (scope *Scope) PrimaryKeyZero() bool {
150184 field := scope.PrimaryField()
151185 return field == nil || field.IsBlank
169203 return false
170204 }
171205
172 // SetColumn to set the column's value
206 // SetColumn to set the column's value, column could be field or field's name/dbname
173207 func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
208 var updateAttrs = map[string]interface{}{}
209 if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
210 updateAttrs = attrs.(map[string]interface{})
211 defer scope.InstanceSet("gorm:update_attrs", updateAttrs)
212 }
213
174214 if field, ok := column.(*Field); ok {
215 updateAttrs[field.DBName] = value
175216 return field.Set(value)
176217 } else if name, ok := column.(string); ok {
177
178 if field, ok := scope.Fields()[name]; ok {
179 return field.Set(value)
180 }
181
182 dbName := ToDBName(name)
183 if field, ok := scope.Fields()[dbName]; ok {
184 return field.Set(value)
185 }
186
187 if field, ok := scope.FieldByName(name); ok {
188 return field.Set(value)
218 var (
219 dbName = ToDBName(name)
220 mostMatchedField *Field
221 )
222 for _, field := range scope.Fields() {
223 if field.DBName == value {
224 updateAttrs[field.DBName] = value
225 return field.Set(value)
226 }
227 if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) {
228 mostMatchedField = field
229 }
230 }
231
232 if mostMatchedField != nil {
233 updateAttrs[mostMatchedField.DBName] = value
234 return mostMatchedField.Set(value)
189235 }
190236 }
191237 return errors.New("could not convert column to field")
192238 }
193239
194 func (scope *Scope) CallMethod(name string, checkError bool) {
195 if scope.Value == nil || (checkError && scope.HasError()) {
240 // CallMethod call scope value's method, if it is a slice, will call its element's method one by one
241 func (scope *Scope) CallMethod(methodName string) {
242 if scope.Value == nil {
196243 return
197244 }
198245
199 call := func(value interface{}) {
200 if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() {
201 switch f := fm.Interface().(type) {
202 case func():
203 f()
204 case func(s *Scope):
205 f(scope)
206 case func(s *DB):
207 newDB := scope.NewDB()
208 f(newDB)
209 scope.Err(newDB.Error)
210 case func() error:
211 scope.Err(f())
212 case func(s *Scope) error:
213 scope.Err(f(scope))
214 case func(s *DB) error:
215 newDB := scope.NewDB()
216 scope.Err(f(newDB))
217 scope.Err(newDB.Error)
218 default:
219 scope.Err(fmt.Errorf("unsupported function %v", name))
220 }
221 }
222 }
223
224 if values := scope.IndirectValue(); values.Kind() == reflect.Slice {
225 for i := 0; i < values.Len(); i++ {
226 value := values.Index(i).Addr().Interface()
227 if values.Index(i).Kind() == reflect.Ptr {
228 value = values.Index(i).Interface()
229 }
230 call(value)
246 if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice {
247 for i := 0; i < indirectScopeValue.Len(); i++ {
248 scope.callMethod(methodName, indirectScopeValue.Index(i))
231249 }
232250 } else {
233 if scope.IndirectValue().CanAddr() {
234 call(scope.IndirectValue().Addr().Interface())
235 } else {
236 call(scope.IndirectValue().Interface())
237 }
238 }
239 }
240
241 func (scope *Scope) CallMethodWithErrorCheck(name string) {
242 scope.CallMethod(name, true)
243 }
244
245 // AddToVars add value as sql's vars, gorm will escape them
251 scope.callMethod(methodName, indirectScopeValue)
252 }
253 }
254
255 // AddToVars add value as sql's vars, used to prevent SQL injection
246256 func (scope *Scope) AddToVars(value interface{}) string {
257 _, skipBindVar := scope.InstanceGet("skip_bindvar")
258
247259 if expr, ok := value.(*expr); ok {
248260 exp := expr.expr
249261 for _, arg := range expr.args {
250 exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
262 if skipBindVar {
263 scope.AddToVars(arg)
264 } else {
265 exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
266 }
251267 }
252268 return exp
253 } else {
254 scope.SqlVars = append(scope.SqlVars, value)
255 return scope.Dialect().BinVar(len(scope.SqlVars))
256 }
257 }
258
259 type tabler interface {
260 TableName() string
261 }
262
263 type dbTabler interface {
264 TableName(*DB) string
265 }
266
267 // TableName get table name
268 func (scope *Scope) TableName() string {
269 if scope.Search != nil && len(scope.Search.tableName) > 0 {
270 return scope.Search.tableName
271 }
272
273 if tabler, ok := scope.Value.(tabler); ok {
274 return tabler.TableName()
275 }
276
277 if tabler, ok := scope.Value.(dbTabler); ok {
278 return tabler.TableName(scope.db)
279 }
280
281 return scope.GetModelStruct().TableName(scope.db.Model(scope.Value))
282 }
283
284 func (scope *Scope) QuotedTableName() (name string) {
285 if scope.Search != nil && len(scope.Search.tableName) > 0 {
286 if strings.Index(scope.Search.tableName, " ") != -1 {
287 return scope.Search.tableName
288 }
289 return scope.Quote(scope.Search.tableName)
290 } else {
291 return scope.Quote(scope.TableName())
292 }
293 }
294
295 // CombinedConditionSql get combined condition sql
296 func (scope *Scope) CombinedConditionSql() string {
297 return scope.joinsSql() + scope.whereSql() + scope.groupSql() +
298 scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql()
299 }
300
301 func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
302 for _, field := range scope.Fields() {
303 if field.Name == name || field.DBName == name {
304 return field, true
305 }
306 }
307 return nil, false
308 }
309
310 // Raw set sql
311 func (scope *Scope) Raw(sql string) *Scope {
312 scope.Sql = strings.Replace(sql, "$$", "?", -1)
313 return scope
314 }
315
316 // Exec invoke sql
317 func (scope *Scope) Exec() *Scope {
318 defer scope.Trace(NowFunc())
319
320 if !scope.HasError() {
321 if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
322 if count, err := result.RowsAffected(); scope.Err(err) == nil {
323 scope.db.RowsAffected = count
324 }
325 }
326 }
327 return scope
328 }
329
330 // Set set value by name
331 func (scope *Scope) Set(name string, value interface{}) *Scope {
332 scope.db.InstantSet(name, value)
333 return scope
334 }
335
336 // Get get value by name
337 func (scope *Scope) Get(name string) (interface{}, bool) {
338 return scope.db.Get(name)
339 }
340
341 // InstanceId get InstanceId for scope
342 func (scope *Scope) InstanceId() string {
343 if scope.instanceId == "" {
344 scope.instanceId = fmt.Sprintf("%v%v", &scope, &scope.db)
345 }
346 return scope.instanceId
347 }
348
349 func (scope *Scope) InstanceSet(name string, value interface{}) *Scope {
350 return scope.Set(name+scope.InstanceId(), value)
351 }
352
353 func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
354 return scope.Get(name + scope.InstanceId())
355 }
356
357 // Trace print sql log
358 func (scope *Scope) Trace(t time.Time) {
359 if len(scope.Sql) > 0 {
360 scope.db.slog(scope.Sql, t, scope.SqlVars...)
361 }
362 }
363
364 // Begin start a transaction
365 func (scope *Scope) Begin() *Scope {
366 if db, ok := scope.SqlDB().(sqlDb); ok {
367 if tx, err := db.Begin(); err == nil {
368 scope.db.db = interface{}(tx).(sqlCommon)
369 scope.InstanceSet("gorm:started_transaction", true)
370 }
371 }
372 return scope
373 }
374
375 // CommitOrRollback commit current transaction if there is no error, otherwise rollback it
376 func (scope *Scope) CommitOrRollback() *Scope {
377 if _, ok := scope.InstanceGet("gorm:started_transaction"); ok {
378 if db, ok := scope.db.db.(sqlTx); ok {
379 if scope.HasError() {
380 db.Rollback()
381 } else {
382 db.Commit()
383 }
384 scope.db.db = scope.db.parent.db
385 }
386 }
387 return scope
388 }
389
269 }
270
271 scope.SQLVars = append(scope.SQLVars, value)
272
273 if skipBindVar {
274 return "?"
275 }
276 return scope.Dialect().BindVar(len(scope.SQLVars))
277 }
278
279 // SelectAttrs return selected attributes
390280 func (scope *Scope) SelectAttrs() []string {
391281 if scope.selectAttrs == nil {
392282 attrs := []string{}
406296 return *scope.selectAttrs
407297 }
408298
299 // OmitAttrs return omitted attributes
409300 func (scope *Scope) OmitAttrs() []string {
410301 return scope.Search.omits
411302 }
412303
413 func (scope *Scope) changeableDBColumn(column string) bool {
414 selectAttrs := scope.SelectAttrs()
415 omitAttrs := scope.OmitAttrs()
416
417 if len(selectAttrs) > 0 {
418 for _, attr := range selectAttrs {
419 if column == ToDBName(attr) {
420 return true
421 }
422 }
423 return false
424 }
425
426 for _, attr := range omitAttrs {
427 if column == ToDBName(attr) {
428 return false
429 }
430 }
431 return true
304 type tabler interface {
305 TableName() string
306 }
307
308 type dbTabler interface {
309 TableName(*DB) string
310 }
311
312 // TableName return table name
313 func (scope *Scope) TableName() string {
314 if scope.Search != nil && len(scope.Search.tableName) > 0 {
315 return scope.Search.tableName
316 }
317
318 if tabler, ok := scope.Value.(tabler); ok {
319 return tabler.TableName()
320 }
321
322 if tabler, ok := scope.Value.(dbTabler); ok {
323 return tabler.TableName(scope.db)
324 }
325
326 return scope.GetModelStruct().TableName(scope.db.Model(scope.Value))
327 }
328
329 // QuotedTableName return quoted table name
330 func (scope *Scope) QuotedTableName() (name string) {
331 if scope.Search != nil && len(scope.Search.tableName) > 0 {
332 if strings.Index(scope.Search.tableName, " ") != -1 {
333 return scope.Search.tableName
334 }
335 return scope.Quote(scope.Search.tableName)
336 }
337
338 return scope.Quote(scope.TableName())
339 }
340
341 // CombinedConditionSql return combined condition sql
342 func (scope *Scope) CombinedConditionSql() string {
343 joinSQL := scope.joinsSQL()
344 whereSQL := scope.whereSQL()
345 if scope.Search.raw {
346 whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")")
347 }
348 return joinSQL + whereSQL + scope.groupSQL() +
349 scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL()
350 }
351
352 // Raw set raw sql
353 func (scope *Scope) Raw(sql string) *Scope {
354 scope.SQL = strings.Replace(sql, "$$$", "?", -1)
355 return scope
356 }
357
358 // Exec perform generated SQL
359 func (scope *Scope) Exec() *Scope {
360 defer scope.trace(NowFunc())
361
362 if !scope.HasError() {
363 if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
364 if count, err := result.RowsAffected(); scope.Err(err) == nil {
365 scope.db.RowsAffected = count
366 }
367 }
368 }
369 return scope
370 }
371
372 // Set set value by name
373 func (scope *Scope) Set(name string, value interface{}) *Scope {
374 scope.db.InstantSet(name, value)
375 return scope
376 }
377
378 // Get get setting by name
379 func (scope *Scope) Get(name string) (interface{}, bool) {
380 return scope.db.Get(name)
381 }
382
383 // InstanceID get InstanceID for scope
384 func (scope *Scope) InstanceID() string {
385 if scope.instanceID == "" {
386 scope.instanceID = fmt.Sprintf("%v%v", &scope, &scope.db)
387 }
388 return scope.instanceID
389 }
390
391 // InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback
392 func (scope *Scope) InstanceSet(name string, value interface{}) *Scope {
393 return scope.Set(name+scope.InstanceID(), value)
394 }
395
396 // InstanceGet get instance setting from current operation
397 func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
398 return scope.Get(name + scope.InstanceID())
399 }
400
401 // Begin start a transaction
402 func (scope *Scope) Begin() *Scope {
403 if db, ok := scope.SQLDB().(sqlDb); ok {
404 if tx, err := db.Begin(); err == nil {
405 scope.db.db = interface{}(tx).(SQLCommon)
406 scope.InstanceSet("gorm:started_transaction", true)
407 }
408 }
409 return scope
410 }
411
412 // CommitOrRollback commit current transaction if no error happened, otherwise will rollback it
413 func (scope *Scope) CommitOrRollback() *Scope {
414 if _, ok := scope.InstanceGet("gorm:started_transaction"); ok {
415 if db, ok := scope.db.db.(sqlTx); ok {
416 if scope.HasError() {
417 db.Rollback()
418 } else {
419 scope.Err(db.Commit())
420 }
421 scope.db.db = scope.db.parent.db
422 }
423 }
424 return scope
425 }
426
427 ////////////////////////////////////////////////////////////////////////////////
428 // Private Methods For *gorm.Scope
429 ////////////////////////////////////////////////////////////////////////////////
430
431 func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
432 // Only get address from non-pointer
433 if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr {
434 reflectValue = reflectValue.Addr()
435 }
436
437 if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() {
438 switch method := methodValue.Interface().(type) {
439 case func():
440 method()
441 case func(*Scope):
442 method(scope)
443 case func(*DB):
444 newDB := scope.NewDB()
445 method(newDB)
446 scope.Err(newDB.Error)
447 case func() error:
448 scope.Err(method())
449 case func(*Scope) error:
450 scope.Err(method(scope))
451 case func(*DB) error:
452 newDB := scope.NewDB()
453 scope.Err(method(newDB))
454 scope.Err(newDB.Error)
455 default:
456 scope.Err(fmt.Errorf("unsupported function %v", methodName))
457 }
458 }
459 }
460
461 var (
462 columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name`
463 isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number
464 comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ")
465 countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$")
466 )
467
468 func (scope *Scope) quoteIfPossible(str string) string {
469 if columnRegexp.MatchString(str) {
470 return scope.Quote(str)
471 }
472 return str
473 }
474
475 func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
476 var (
477 ignored interface{}
478 values = make([]interface{}, len(columns))
479 selectFields []*Field
480 selectedColumnsMap = map[string]int{}
481 resetFields = map[int]*Field{}
482 )
483
484 for index, column := range columns {
485 values[index] = &ignored
486
487 selectFields = fields
488 if idx, ok := selectedColumnsMap[column]; ok {
489 selectFields = selectFields[idx+1:]
490 }
491
492 for fieldIndex, field := range selectFields {
493 if field.DBName == column {
494 if field.Field.Kind() == reflect.Ptr {
495 values[index] = field.Field.Addr().Interface()
496 } else {
497 reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
498 reflectValue.Elem().Set(field.Field.Addr())
499 values[index] = reflectValue.Interface()
500 resetFields[index] = field
501 }
502
503 selectedColumnsMap[column] = fieldIndex
504
505 if field.IsNormal {
506 break
507 }
508 }
509 }
510 }
511
512 scope.Err(rows.Scan(values...))
513
514 for index, field := range resetFields {
515 if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
516 field.Field.Set(v)
517 }
518 }
519 }
520
521 func (scope *Scope) primaryCondition(value interface{}) string {
522 return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value)
523 }
524
525 func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) {
526 var (
527 quotedTableName = scope.QuotedTableName()
528 quotedPrimaryKey = scope.Quote(scope.PrimaryKey())
529 equalSQL = "="
530 inSQL = "IN"
531 )
532
533 // If building not conditions
534 if !include {
535 equalSQL = "<>"
536 inSQL = "NOT IN"
537 }
538
539 switch value := clause["query"].(type) {
540 case sql.NullInt64:
541 return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64)
542 case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
543 return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value)
544 case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
545 if !include && reflect.ValueOf(value).Len() == 0 {
546 return
547 }
548 str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL)
549 clause["args"] = []interface{}{value}
550 case string:
551 if isNumberRegexp.MatchString(value) {
552 return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value))
553 }
554
555 if value != "" {
556 if !include {
557 if comparisonRegexp.MatchString(value) {
558 str = fmt.Sprintf("NOT (%v)", value)
559 } else {
560 str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value))
561 }
562 } else {
563 str = fmt.Sprintf("(%v)", value)
564 }
565 }
566 case map[string]interface{}:
567 var sqls []string
568 for key, value := range value {
569 if value != nil {
570 sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value)))
571 } else {
572 if !include {
573 sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key)))
574 } else {
575 sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key)))
576 }
577 }
578 }
579 return strings.Join(sqls, " AND ")
580 case interface{}:
581 var sqls []string
582 newScope := scope.New(value)
583
584 if len(newScope.Fields()) == 0 {
585 scope.Err(fmt.Errorf("invalid query condition: %v", value))
586 return
587 }
588
589 for _, field := range newScope.Fields() {
590 if !field.IsIgnored && !field.IsBlank {
591 sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
592 }
593 }
594 return strings.Join(sqls, " AND ")
595 default:
596 scope.Err(fmt.Errorf("invalid query condition: %v", value))
597 return
598 }
599
600 replacements := []string{}
601 args := clause["args"].([]interface{})
602 for _, arg := range args {
603 var err error
604 switch reflect.ValueOf(arg).Kind() {
605 case reflect.Slice: // For where("id in (?)", []int64{1,2})
606 if scanner, ok := interface{}(arg).(driver.Valuer); ok {
607 arg, err = scanner.Value()
608 replacements = append(replacements, scope.AddToVars(arg))
609 } else if b, ok := arg.([]byte); ok {
610 replacements = append(replacements, scope.AddToVars(b))
611 } else if as, ok := arg.([][]interface{}); ok {
612 var tempMarks []string
613 for _, a := range as {
614 var arrayMarks []string
615 for _, v := range a {
616 arrayMarks = append(arrayMarks, scope.AddToVars(v))
617 }
618
619 if len(arrayMarks) > 0 {
620 tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ",")))
621 }
622 }
623
624 if len(tempMarks) > 0 {
625 replacements = append(replacements, strings.Join(tempMarks, ","))
626 }
627 } else if values := reflect.ValueOf(arg); values.Len() > 0 {
628 var tempMarks []string
629 for i := 0; i < values.Len(); i++ {
630 tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
631 }
632 replacements = append(replacements, strings.Join(tempMarks, ","))
633 } else {
634 replacements = append(replacements, scope.AddToVars(Expr("NULL")))
635 }
636 default:
637 if valuer, ok := interface{}(arg).(driver.Valuer); ok {
638 arg, err = valuer.Value()
639 }
640
641 replacements = append(replacements, scope.AddToVars(arg))
642 }
643
644 if err != nil {
645 scope.Err(err)
646 }
647 }
648
649 buff := bytes.NewBuffer([]byte{})
650 i := 0
651 for _, s := range str {
652 if s == '?' {
653 buff.WriteString(replacements[i])
654 i++
655 } else {
656 buff.WriteRune(s)
657 }
658 }
659
660 str = buff.String()
661
662 return
663 }
664
665 func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) {
666 switch value := clause["query"].(type) {
667 case string:
668 str = value
669 case []string:
670 str = strings.Join(value, ", ")
671 }
672
673 args := clause["args"].([]interface{})
674 replacements := []string{}
675 for _, arg := range args {
676 switch reflect.ValueOf(arg).Kind() {
677 case reflect.Slice:
678 values := reflect.ValueOf(arg)
679 var tempMarks []string
680 for i := 0; i < values.Len(); i++ {
681 tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
682 }
683 replacements = append(replacements, strings.Join(tempMarks, ","))
684 default:
685 if valuer, ok := interface{}(arg).(driver.Valuer); ok {
686 arg, _ = valuer.Value()
687 }
688 replacements = append(replacements, scope.AddToVars(arg))
689 }
690 }
691
692 buff := bytes.NewBuffer([]byte{})
693 i := 0
694 for pos := range str {
695 if str[pos] == '?' {
696 buff.WriteString(replacements[i])
697 i++
698 } else {
699 buff.WriteByte(str[pos])
700 }
701 }
702
703 str = buff.String()
704
705 return
706 }
707
708 func (scope *Scope) whereSQL() (sql string) {
709 var (
710 quotedTableName = scope.QuotedTableName()
711 deletedAtField, hasDeletedAtField = scope.FieldByName("DeletedAt")
712 primaryConditions, andConditions, orConditions []string
713 )
714
715 if !scope.Search.Unscoped && hasDeletedAtField {
716 sql := fmt.Sprintf("%v.%v IS NULL", quotedTableName, scope.Quote(deletedAtField.DBName))
717 primaryConditions = append(primaryConditions, sql)
718 }
719
720 if !scope.PrimaryKeyZero() {
721 for _, field := range scope.PrimaryFields() {
722 sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
723 primaryConditions = append(primaryConditions, sql)
724 }
725 }
726
727 for _, clause := range scope.Search.whereConditions {
728 if sql := scope.buildCondition(clause, true); sql != "" {
729 andConditions = append(andConditions, sql)
730 }
731 }
732
733 for _, clause := range scope.Search.orConditions {
734 if sql := scope.buildCondition(clause, true); sql != "" {
735 orConditions = append(orConditions, sql)
736 }
737 }
738
739 for _, clause := range scope.Search.notConditions {
740 if sql := scope.buildCondition(clause, false); sql != "" {
741 andConditions = append(andConditions, sql)
742 }
743 }
744
745 orSQL := strings.Join(orConditions, " OR ")
746 combinedSQL := strings.Join(andConditions, " AND ")
747 if len(combinedSQL) > 0 {
748 if len(orSQL) > 0 {
749 combinedSQL = combinedSQL + " OR " + orSQL
750 }
751 } else {
752 combinedSQL = orSQL
753 }
754
755 if len(primaryConditions) > 0 {
756 sql = "WHERE " + strings.Join(primaryConditions, " AND ")
757 if len(combinedSQL) > 0 {
758 sql = sql + " AND (" + combinedSQL + ")"
759 }
760 } else if len(combinedSQL) > 0 {
761 sql = "WHERE " + combinedSQL
762 }
763 return
764 }
765
766 func (scope *Scope) selectSQL() string {
767 if len(scope.Search.selects) == 0 {
768 if len(scope.Search.joinConditions) > 0 {
769 return fmt.Sprintf("%v.*", scope.QuotedTableName())
770 }
771 return "*"
772 }
773 return scope.buildSelectQuery(scope.Search.selects)
774 }
775
776 func (scope *Scope) orderSQL() string {
777 if len(scope.Search.orders) == 0 || scope.Search.ignoreOrderQuery {
778 return ""
779 }
780
781 var orders []string
782 for _, order := range scope.Search.orders {
783 if str, ok := order.(string); ok {
784 orders = append(orders, scope.quoteIfPossible(str))
785 } else if expr, ok := order.(*expr); ok {
786 exp := expr.expr
787 for _, arg := range expr.args {
788 exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
789 }
790 orders = append(orders, exp)
791 }
792 }
793 return " ORDER BY " + strings.Join(orders, ",")
794 }
795
796 func (scope *Scope) limitAndOffsetSQL() string {
797 return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
798 }
799
800 func (scope *Scope) groupSQL() string {
801 if len(scope.Search.group) == 0 {
802 return ""
803 }
804 return " GROUP BY " + scope.Search.group
805 }
806
807 func (scope *Scope) havingSQL() string {
808 if len(scope.Search.havingConditions) == 0 {
809 return ""
810 }
811
812 var andConditions []string
813 for _, clause := range scope.Search.havingConditions {
814 if sql := scope.buildCondition(clause, true); sql != "" {
815 andConditions = append(andConditions, sql)
816 }
817 }
818
819 combinedSQL := strings.Join(andConditions, " AND ")
820 if len(combinedSQL) == 0 {
821 return ""
822 }
823
824 return " HAVING " + combinedSQL
825 }
826
827 func (scope *Scope) joinsSQL() string {
828 var joinConditions []string
829 for _, clause := range scope.Search.joinConditions {
830 if sql := scope.buildCondition(clause, true); sql != "" {
831 joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
832 }
833 }
834
835 return strings.Join(joinConditions, " ") + " "
836 }
837
838 func (scope *Scope) prepareQuerySQL() {
839 if scope.Search.raw {
840 scope.Raw(scope.CombinedConditionSql())
841 } else {
842 scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
843 }
844 return
845 }
846
847 func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
848 if len(values) > 0 {
849 scope.Search.Where(values[0], values[1:]...)
850 }
851 return scope
852 }
853
854 func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
855 for _, f := range funcs {
856 (*f)(scope)
857 if scope.skipLeft {
858 break
859 }
860 }
861 return scope
862 }
863
864 func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string]interface{} {
865 var attrs = map[string]interface{}{}
866
867 switch value := values.(type) {
868 case map[string]interface{}:
869 return value
870 case []interface{}:
871 for _, v := range value {
872 for key, value := range convertInterfaceToMap(v, withIgnoredField) {
873 attrs[key] = value
874 }
875 }
876 case interface{}:
877 reflectValue := reflect.ValueOf(values)
878
879 switch reflectValue.Kind() {
880 case reflect.Map:
881 for _, key := range reflectValue.MapKeys() {
882 attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
883 }
884 default:
885 for _, field := range (&Scope{Value: values}).Fields() {
886 if !field.IsBlank && (withIgnoredField || !field.IsIgnored) {
887 attrs[field.DBName] = field.Field.Interface()
888 }
889 }
890 }
891 }
892 return attrs
893 }
894
895 func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) {
896 if scope.IndirectValue().Kind() != reflect.Struct {
897 return convertInterfaceToMap(value, false), true
898 }
899
900 results = map[string]interface{}{}
901
902 for key, value := range convertInterfaceToMap(value, true) {
903 if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
904 if _, ok := value.(*expr); ok {
905 hasUpdate = true
906 results[field.DBName] = value
907 } else {
908 err := field.Set(value)
909 if field.IsNormal {
910 hasUpdate = true
911 if err == ErrUnaddressable {
912 results[field.DBName] = value
913 } else {
914 results[field.DBName] = field.Field.Interface()
915 }
916 }
917 }
918 }
919 }
920 return
921 }
922
923 func (scope *Scope) row() *sql.Row {
924 defer scope.trace(NowFunc())
925
926 result := &RowQueryResult{}
927 scope.InstanceSet("row_query_result", result)
928 scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
929
930 return result.Row
931 }
932
933 func (scope *Scope) rows() (*sql.Rows, error) {
934 defer scope.trace(NowFunc())
935
936 result := &RowsQueryResult{}
937 scope.InstanceSet("row_query_result", result)
938 scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
939
940 return result.Rows, result.Error
941 }
942
943 func (scope *Scope) initialize() *Scope {
944 for _, clause := range scope.Search.whereConditions {
945 scope.updatedAttrsWithValues(clause["query"])
946 }
947 scope.updatedAttrsWithValues(scope.Search.initAttrs)
948 scope.updatedAttrsWithValues(scope.Search.assignAttrs)
949 return scope
950 }
951
952 func (scope *Scope) isQueryForColumn(query interface{}, column string) bool {
953 queryStr := strings.ToLower(fmt.Sprint(query))
954 if queryStr == column {
955 return true
956 }
957
958 if strings.HasSuffix(queryStr, "as "+column) {
959 return true
960 }
961
962 if strings.HasSuffix(queryStr, "as "+scope.Quote(column)) {
963 return true
964 }
965
966 return false
967 }
968
969 func (scope *Scope) pluck(column string, value interface{}) *Scope {
970 dest := reflect.Indirect(reflect.ValueOf(value))
971 if dest.Kind() != reflect.Slice {
972 scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
973 return scope
974 }
975
976 if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) {
977 scope.Search.Select(column)
978 }
979
980 rows, err := scope.rows()
981 if scope.Err(err) == nil {
982 defer rows.Close()
983 for rows.Next() {
984 elem := reflect.New(dest.Type().Elem()).Interface()
985 scope.Err(rows.Scan(elem))
986 dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem()))
987 }
988
989 if err := rows.Err(); err != nil {
990 scope.Err(err)
991 }
992 }
993 return scope
994 }
995
996 func (scope *Scope) count(value interface{}) *Scope {
997 if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) {
998 if len(scope.Search.group) != 0 {
999 scope.Search.Select("count(*) FROM ( SELECT count(*) as name ")
1000 scope.Search.group += " ) AS count_table"
1001 } else {
1002 scope.Search.Select("count(*)")
1003 }
1004 }
1005 scope.Search.ignoreOrderQuery = true
1006 scope.Err(scope.row().Scan(value))
1007 return scope
1008 }
1009
1010 func (scope *Scope) typeName() string {
1011 typ := scope.IndirectValue().Type()
1012
1013 for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr {
1014 typ = typ.Elem()
1015 }
1016
1017 return typ.Name()
1018 }
1019
1020 // trace print sql log
1021 func (scope *Scope) trace(t time.Time) {
1022 if len(scope.SQL) > 0 {
1023 scope.db.slog(scope.SQL, t, scope.SQLVars...)
1024 }
4321025 }
4331026
4341027 func (scope *Scope) changeableField(field *Field) bool {
435 selectAttrs := scope.SelectAttrs()
436 omitAttrs := scope.OmitAttrs()
437
438 if len(selectAttrs) > 0 {
1028 if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 {
4391029 for _, attr := range selectAttrs {
4401030 if field.Name == attr || field.DBName == attr {
4411031 return true
4441034 return false
4451035 }
4461036
447 for _, attr := range omitAttrs {
1037 for _, attr := range scope.OmitAttrs() {
4481038 if field.Name == attr || field.DBName == attr {
4491039 return false
4501040 }
4511041 }
4521042
453 return !field.IsIgnored
454 }
455
456 func (scope *Scope) shouldSaveAssociations() bool {
457 saveAssociations, ok := scope.Get("gorm:save_associations")
458 if ok && !saveAssociations.(bool) {
459 return false
460 }
461 return true && !scope.HasError()
462 }
1043 return true
1044 }
1045
1046 func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
1047 toScope := scope.db.NewScope(value)
1048 tx := scope.db.Set("gorm:association:source", scope.Value)
1049
1050 for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
1051 fromField, _ := scope.FieldByName(foreignKey)
1052 toField, _ := toScope.FieldByName(foreignKey)
1053
1054 if fromField != nil {
1055 if relationship := fromField.Relationship; relationship != nil {
1056 if relationship.Kind == "many_to_many" {
1057 joinTableHandler := relationship.JoinTableHandler
1058 scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Find(value).Error)
1059 } else if relationship.Kind == "belongs_to" {
1060 for idx, foreignKey := range relationship.ForeignDBNames {
1061 if field, ok := scope.FieldByName(foreignKey); ok {
1062 tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
1063 }
1064 }
1065 scope.Err(tx.Find(value).Error)
1066 } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
1067 for idx, foreignKey := range relationship.ForeignDBNames {
1068 if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
1069 tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
1070 }
1071 }
1072
1073 if relationship.PolymorphicType != "" {
1074 tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue)
1075 }
1076 scope.Err(tx.Find(value).Error)
1077 }
1078 } else {
1079 sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
1080 scope.Err(tx.Where(sql, fromField.Field.Interface()).Find(value).Error)
1081 }
1082 return scope
1083 } else if toField != nil {
1084 sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
1085 scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
1086 return scope
1087 }
1088 }
1089
1090 scope.Err(fmt.Errorf("invalid association %v", foreignKeys))
1091 return scope
1092 }
1093
1094 // getTableOptions return the table options string or an empty string if the table options does not exist
1095 func (scope *Scope) getTableOptions() string {
1096 tableOptions, ok := scope.Get("gorm:table_options")
1097 if !ok {
1098 return ""
1099 }
1100 return " " + tableOptions.(string)
1101 }
1102
1103 func (scope *Scope) createJoinTable(field *StructField) {
1104 if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
1105 joinTableHandler := relationship.JoinTableHandler
1106 joinTable := joinTableHandler.Table(scope.db)
1107 if !scope.Dialect().HasTable(joinTable) {
1108 toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
1109
1110 var sqlTypes, primaryKeys []string
1111 for idx, fieldName := range relationship.ForeignFieldNames {
1112 if field, ok := scope.FieldByName(fieldName); ok {
1113 foreignKeyStruct := field.clone()
1114 foreignKeyStruct.IsPrimaryKey = false
1115 foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
1116 delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT")
1117 sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
1118 primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
1119 }
1120 }
1121
1122 for idx, fieldName := range relationship.AssociationForeignFieldNames {
1123 if field, ok := toScope.FieldByName(fieldName); ok {
1124 foreignKeyStruct := field.clone()
1125 foreignKeyStruct.IsPrimaryKey = false
1126 foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
1127 delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT")
1128 sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
1129 primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
1130 }
1131 }
1132
1133 scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error)
1134 }
1135 scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
1136 }
1137 }
1138
1139 func (scope *Scope) createTable() *Scope {
1140 var tags []string
1141 var primaryKeys []string
1142 var primaryKeyInColumnType = false
1143 for _, field := range scope.GetModelStruct().StructFields {
1144 if field.IsNormal {
1145 sqlTag := scope.Dialect().DataTypeOf(field)
1146
1147 // Check if the primary key constraint was specified as
1148 // part of the column type. If so, we can only support
1149 // one column as the primary key.
1150 if strings.Contains(strings.ToLower(sqlTag), "primary key") {
1151 primaryKeyInColumnType = true
1152 }
1153
1154 tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
1155 }
1156
1157 if field.IsPrimaryKey {
1158 primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
1159 }
1160 scope.createJoinTable(field)
1161 }
1162
1163 var primaryKeyStr string
1164 if len(primaryKeys) > 0 && !primaryKeyInColumnType {
1165 primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
1166 }
1167
1168 scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
1169
1170 scope.autoIndex()
1171 return scope
1172 }
1173
1174 func (scope *Scope) dropTable() *Scope {
1175 scope.Raw(fmt.Sprintf("DROP TABLE %v%s", scope.QuotedTableName(), scope.getTableOptions())).Exec()
1176 return scope
1177 }
1178
1179 func (scope *Scope) modifyColumn(column string, typ string) {
1180 scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ))
1181 }
1182
1183 func (scope *Scope) dropColumn(column string) {
1184 scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec()
1185 }
1186
1187 func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
1188 if scope.Dialect().HasIndex(scope.TableName(), indexName) {
1189 return
1190 }
1191
1192 var columns []string
1193 for _, name := range column {
1194 columns = append(columns, scope.quoteIfPossible(name))
1195 }
1196
1197 sqlCreate := "CREATE INDEX"
1198 if unique {
1199 sqlCreate = "CREATE UNIQUE INDEX"
1200 }
1201
1202 scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
1203 }
1204
1205 func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
1206 // Compatible with old generated key
1207 keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
1208
1209 if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
1210 return
1211 }
1212 var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
1213 scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
1214 }
1215
1216 func (scope *Scope) removeForeignKey(field string, dest string) {
1217 keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest)
1218
1219 if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
1220 return
1221 }
1222 var query = `ALTER TABLE %s DROP CONSTRAINT %s;`
1223 scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
1224 }
1225
1226 func (scope *Scope) removeIndex(indexName string) {
1227 scope.Dialect().RemoveIndex(scope.TableName(), indexName)
1228 }
1229
1230 func (scope *Scope) autoMigrate() *Scope {
1231 tableName := scope.TableName()
1232 quotedTableName := scope.QuotedTableName()
1233
1234 if !scope.Dialect().HasTable(tableName) {
1235 scope.createTable()
1236 } else {
1237 for _, field := range scope.GetModelStruct().StructFields {
1238 if !scope.Dialect().HasColumn(tableName, field.DBName) {
1239 if field.IsNormal {
1240 sqlTag := scope.Dialect().DataTypeOf(field)
1241 scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
1242 }
1243 }
1244 scope.createJoinTable(field)
1245 }
1246 scope.autoIndex()
1247 }
1248 return scope
1249 }
1250
1251 func (scope *Scope) autoIndex() *Scope {
1252 var indexes = map[string][]string{}
1253 var uniqueIndexes = map[string][]string{}
1254
1255 for _, field := range scope.GetStructFields() {
1256 if name, ok := field.TagSettings["INDEX"]; ok {
1257 names := strings.Split(name, ",")
1258
1259 for _, name := range names {
1260 if name == "INDEX" || name == "" {
1261 name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName)
1262 }
1263 indexes[name] = append(indexes[name], field.DBName)
1264 }
1265 }
1266
1267 if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok {
1268 names := strings.Split(name, ",")
1269
1270 for _, name := range names {
1271 if name == "UNIQUE_INDEX" || name == "" {
1272 name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName)
1273 }
1274 uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
1275 }
1276 }
1277 }
1278
1279 for name, columns := range indexes {
1280 if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddIndex(name, columns...); db.Error != nil {
1281 scope.db.AddError(db.Error)
1282 }
1283 }
1284
1285 for name, columns := range uniqueIndexes {
1286 if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil {
1287 scope.db.AddError(db.Error)
1288 }
1289 }
1290
1291 return scope
1292 }
1293
1294 func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
1295 for _, value := range values {
1296 indirectValue := indirect(reflect.ValueOf(value))
1297
1298 switch indirectValue.Kind() {
1299 case reflect.Slice:
1300 for i := 0; i < indirectValue.Len(); i++ {
1301 var result []interface{}
1302 var object = indirect(indirectValue.Index(i))
1303 var hasValue = false
1304 for _, column := range columns {
1305 field := object.FieldByName(column)
1306 if hasValue || !isBlank(field) {
1307 hasValue = true
1308 }
1309 result = append(result, field.Interface())
1310 }
1311
1312 if hasValue {
1313 results = append(results, result)
1314 }
1315 }
1316 case reflect.Struct:
1317 var result []interface{}
1318 var hasValue = false
1319 for _, column := range columns {
1320 field := indirectValue.FieldByName(column)
1321 if hasValue || !isBlank(field) {
1322 hasValue = true
1323 }
1324 result = append(result, field.Interface())
1325 }
1326
1327 if hasValue {
1328 results = append(results, result)
1329 }
1330 }
1331 }
1332
1333 return
1334 }
1335
1336 func (scope *Scope) getColumnAsScope(column string) *Scope {
1337 indirectScopeValue := scope.IndirectValue()
1338
1339 switch indirectScopeValue.Kind() {
1340 case reflect.Slice:
1341 if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok {
1342 fieldType := fieldStruct.Type
1343 if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr {
1344 fieldType = fieldType.Elem()
1345 }
1346
1347 resultsMap := map[interface{}]bool{}
1348 results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
1349
1350 for i := 0; i < indirectScopeValue.Len(); i++ {
1351 result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column))
1352
1353 if result.Kind() == reflect.Slice {
1354 for j := 0; j < result.Len(); j++ {
1355 if elem := result.Index(j); elem.CanAddr() && resultsMap[elem.Addr()] != true {
1356 resultsMap[elem.Addr()] = true
1357 results = reflect.Append(results, elem.Addr())
1358 }
1359 }
1360 } else if result.CanAddr() && resultsMap[result.Addr()] != true {
1361 resultsMap[result.Addr()] = true
1362 results = reflect.Append(results, result.Addr())
1363 }
1364 }
1365 return scope.New(results.Interface())
1366 }
1367 case reflect.Struct:
1368 if field := indirectScopeValue.FieldByName(column); field.CanAddr() {
1369 return scope.New(field.Addr().Interface())
1370 }
1371 }
1372 return nil
1373 }
1374
1375 func (scope *Scope) hasConditions() bool {
1376 return !scope.PrimaryKeyZero() ||
1377 len(scope.Search.whereConditions) > 0 ||
1378 len(scope.Search.orConditions) > 0 ||
1379 len(scope.Search.notConditions) > 0
1380 }
+0
-654
scope_private.go less more
0 package gorm
1
2 import (
3 "database/sql"
4 "database/sql/driver"
5 "fmt"
6 "reflect"
7 "regexp"
8 "strconv"
9 "strings"
10 )
11
12 func (scope *Scope) primaryCondition(value interface{}) string {
13 return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value)
14 }
15
16 func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) {
17 switch value := clause["query"].(type) {
18 case string:
19 // if string is number
20 if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
21 id, _ := strconv.Atoi(value)
22 return scope.primaryCondition(scope.AddToVars(id))
23 } else if value != "" {
24 str = fmt.Sprintf("(%v)", value)
25 }
26 case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
27 return scope.primaryCondition(scope.AddToVars(value))
28 case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
29 str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey()))
30 clause["args"] = []interface{}{value}
31 case map[string]interface{}:
32 var sqls []string
33 for key, value := range value {
34 sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value)))
35 }
36 return strings.Join(sqls, " AND ")
37 case interface{}:
38 var sqls []string
39 for _, field := range scope.New(value).Fields() {
40 if !field.IsIgnored && !field.IsBlank {
41 sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
42 }
43 }
44 return strings.Join(sqls, " AND ")
45 }
46
47 args := clause["args"].([]interface{})
48 for _, arg := range args {
49 switch reflect.ValueOf(arg).Kind() {
50 case reflect.Slice: // For where("id in (?)", []int64{1,2})
51 values := reflect.ValueOf(arg)
52 var tempMarks []string
53 for i := 0; i < values.Len(); i++ {
54 tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
55 }
56 str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
57 default:
58 if valuer, ok := interface{}(arg).(driver.Valuer); ok {
59 arg, _ = valuer.Value()
60 }
61
62 str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
63 }
64 }
65 return
66 }
67
68 func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
69 var notEqualSql string
70 var primaryKey = scope.PrimaryKey()
71
72 switch value := clause["query"].(type) {
73 case string:
74 // is number
75 if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
76 id, _ := strconv.Atoi(value)
77 return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
78 } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) {
79 str = fmt.Sprintf(" NOT (%v) ", value)
80 notEqualSql = fmt.Sprintf("NOT (%v)", value)
81 } else {
82 str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
83 notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
84 }
85 case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
86 return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
87 case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
88 if reflect.ValueOf(value).Len() > 0 {
89 str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey))
90 clause["args"] = []interface{}{value}
91 }
92 return ""
93 case map[string]interface{}:
94 var sqls []string
95 for key, value := range value {
96 sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value)))
97 }
98 return strings.Join(sqls, " AND ")
99 case interface{}:
100 var sqls []string
101 for _, field := range scope.New(value).Fields() {
102 if !field.IsBlank {
103 sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
104 }
105 }
106 return strings.Join(sqls, " AND ")
107 }
108
109 args := clause["args"].([]interface{})
110 for _, arg := range args {
111 switch reflect.ValueOf(arg).Kind() {
112 case reflect.Slice: // For where("id in (?)", []int64{1,2})
113 values := reflect.ValueOf(arg)
114 var tempMarks []string
115 for i := 0; i < values.Len(); i++ {
116 tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
117 }
118 str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
119 default:
120 if scanner, ok := interface{}(arg).(driver.Valuer); ok {
121 arg, _ = scanner.Value()
122 }
123 str = strings.Replace(notEqualSql, "?", scope.AddToVars(arg), 1)
124 }
125 }
126 return
127 }
128
129 func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) {
130 switch value := clause["query"].(type) {
131 case string:
132 str = value
133 case []string:
134 str = strings.Join(value, ", ")
135 }
136
137 args := clause["args"].([]interface{})
138 for _, arg := range args {
139 switch reflect.ValueOf(arg).Kind() {
140 case reflect.Slice:
141 values := reflect.ValueOf(arg)
142 var tempMarks []string
143 for i := 0; i < values.Len(); i++ {
144 tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
145 }
146 str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
147 default:
148 if valuer, ok := interface{}(arg).(driver.Valuer); ok {
149 arg, _ = valuer.Value()
150 }
151 str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
152 }
153 }
154 return
155 }
156
157 func (scope *Scope) whereSql() (sql string) {
158 var primaryConditions, andConditions, orConditions []string
159
160 if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil {
161 sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
162 primaryConditions = append(primaryConditions, sql)
163 }
164
165 if !scope.PrimaryKeyZero() {
166 for _, field := range scope.PrimaryFields() {
167 sql := fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
168 primaryConditions = append(primaryConditions, sql)
169 }
170 }
171
172 for _, clause := range scope.Search.whereConditions {
173 if sql := scope.buildWhereCondition(clause); sql != "" {
174 andConditions = append(andConditions, sql)
175 }
176 }
177
178 for _, clause := range scope.Search.orConditions {
179 if sql := scope.buildWhereCondition(clause); sql != "" {
180 orConditions = append(orConditions, sql)
181 }
182 }
183
184 for _, clause := range scope.Search.notConditions {
185 if sql := scope.buildNotCondition(clause); sql != "" {
186 andConditions = append(andConditions, sql)
187 }
188 }
189
190 orSql := strings.Join(orConditions, " OR ")
191 combinedSql := strings.Join(andConditions, " AND ")
192 if len(combinedSql) > 0 {
193 if len(orSql) > 0 {
194 combinedSql = combinedSql + " OR " + orSql
195 }
196 } else {
197 combinedSql = orSql
198 }
199
200 if len(primaryConditions) > 0 {
201 sql = "WHERE " + strings.Join(primaryConditions, " AND ")
202 if len(combinedSql) > 0 {
203 sql = sql + " AND (" + combinedSql + ")"
204 }
205 } else if len(combinedSql) > 0 {
206 sql = "WHERE " + combinedSql
207 }
208 return
209 }
210
211 var hasCountRegexp = regexp.MustCompile(`(?i)count(.+)`)
212
213 func (scope *Scope) selectSql() string {
214 if len(scope.Search.selects) == 0 {
215 if scope.Search.joins != "" {
216 return fmt.Sprintf("%v.*", scope.QuotedTableName())
217 }
218 return "*"
219 }
220 sql := scope.buildSelectQuery(scope.Search.selects)
221 scope.Search.countingQuery = (len(scope.Search.group) == 0) && hasCountRegexp.MatchString(sql)
222 return sql
223 }
224
225 func (scope *Scope) orderSql() string {
226 if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
227 return ""
228 }
229 return " ORDER BY " + strings.Join(scope.Search.orders, ",")
230 }
231
232 func (scope *Scope) limitSql() string {
233 if !scope.Dialect().HasTop() {
234 if len(scope.Search.limit) == 0 {
235 return ""
236 }
237 return " LIMIT " + scope.Search.limit
238 }
239
240 return ""
241 }
242
243 func (scope *Scope) topSql() string {
244 if scope.Dialect().HasTop() && len(scope.Search.offset) == 0 {
245 if len(scope.Search.limit) == 0 {
246 return ""
247 }
248 return " TOP(" + scope.Search.limit + ")"
249 }
250
251 return ""
252 }
253
254 func (scope *Scope) offsetSql() string {
255 if len(scope.Search.offset) == 0 {
256 return ""
257 }
258
259 if scope.Dialect().HasTop() {
260 sql := " OFFSET " + scope.Search.offset + " ROW "
261 if len(scope.Search.limit) > 0 {
262 sql += "FETCH NEXT " + scope.Search.limit + " ROWS ONLY"
263 }
264 return sql
265 }
266 return " OFFSET " + scope.Search.offset
267 }
268
269 func (scope *Scope) groupSql() string {
270 if len(scope.Search.group) == 0 {
271 return ""
272 }
273 return " GROUP BY " + scope.Search.group
274 }
275
276 func (scope *Scope) havingSql() string {
277 if scope.Search.havingConditions == nil {
278 return ""
279 }
280
281 var andConditions []string
282
283 for _, clause := range scope.Search.havingConditions {
284 if sql := scope.buildWhereCondition(clause); sql != "" {
285 andConditions = append(andConditions, sql)
286 }
287 }
288
289 combinedSql := strings.Join(andConditions, " AND ")
290 if len(combinedSql) == 0 {
291 return ""
292 }
293
294 return " HAVING " + combinedSql
295 }
296
297 func (scope *Scope) joinsSql() string {
298 return scope.Search.joins + " "
299 }
300
301 func (scope *Scope) prepareQuerySql() {
302 if scope.Search.raw {
303 scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
304 } else {
305 scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql()))
306 }
307 return
308 }
309
310 func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
311 if len(values) > 0 {
312 scope.Search.Where(values[0], values[1:]...)
313 }
314 return scope
315 }
316
317 func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
318 for _, f := range funcs {
319 (*f)(scope)
320 if scope.skipLeft {
321 break
322 }
323 }
324 return scope
325 }
326
327 func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) {
328 if !scope.IndirectValue().CanAddr() {
329 return values, true
330 }
331
332 var hasExpr bool
333 fields := scope.Fields()
334 for key, value := range values {
335 if field, ok := fields[ToDBName(key)]; ok && field.Field.IsValid() {
336 if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
337 if _, ok := value.(*expr); ok {
338 hasExpr = true
339 } else if !equalAsString(field.Field.Interface(), value) {
340 hasUpdate = true
341 field.Set(value)
342 }
343 }
344 }
345 }
346 if hasExpr {
347 var updateMap = map[string]interface{}{}
348 for key, value := range fields {
349 if v, ok := values[key]; ok {
350 updateMap[key] = v
351 } else {
352 updateMap[key] = value.Field.Interface()
353 }
354 }
355 return updateMap, true
356 }
357 return
358 }
359
360 func (scope *Scope) row() *sql.Row {
361 defer scope.Trace(NowFunc())
362 scope.callCallbacks(scope.db.parent.callback.rowQueries)
363 scope.prepareQuerySql()
364 return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
365 }
366
367 func (scope *Scope) rows() (*sql.Rows, error) {
368 defer scope.Trace(NowFunc())
369 scope.callCallbacks(scope.db.parent.callback.rowQueries)
370 scope.prepareQuerySql()
371 return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
372 }
373
374 func (scope *Scope) initialize() *Scope {
375 for _, clause := range scope.Search.whereConditions {
376 scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false)
377 }
378 scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false)
379 scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false)
380 return scope
381 }
382
383 func (scope *Scope) pluck(column string, value interface{}) *Scope {
384 dest := reflect.Indirect(reflect.ValueOf(value))
385 scope.Search.Select(column)
386 if dest.Kind() != reflect.Slice {
387 scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
388 return scope
389 }
390
391 rows, err := scope.rows()
392 if scope.Err(err) == nil {
393 defer rows.Close()
394 for rows.Next() {
395 elem := reflect.New(dest.Type().Elem()).Interface()
396 scope.Err(rows.Scan(elem))
397 dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem()))
398 }
399 }
400 return scope
401 }
402
403 func (scope *Scope) count(value interface{}) *Scope {
404 scope.Search.Select("count(*)")
405 scope.Err(scope.row().Scan(value))
406 return scope
407 }
408
409 func (scope *Scope) typeName() string {
410 value := scope.IndirectValue()
411 if value.Kind() == reflect.Slice {
412 return value.Type().Elem().Name()
413 }
414
415 return value.Type().Name()
416 }
417
418 func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
419 toScope := scope.db.NewScope(value)
420 fromFields := scope.Fields()
421 toFields := toScope.Fields()
422 for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
423 var fromField, toField *Field
424 if field, ok := scope.FieldByName(foreignKey); ok {
425 fromField = field
426 } else {
427 fromField = fromFields[ToDBName(foreignKey)]
428 }
429 if field, ok := toScope.FieldByName(foreignKey); ok {
430 toField = field
431 } else {
432 toField = toFields[ToDBName(foreignKey)]
433 }
434
435 if fromField != nil {
436 if relationship := fromField.Relationship; relationship != nil {
437 if relationship.Kind == "many_to_many" {
438 joinTableHandler := relationship.JoinTableHandler
439 scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error)
440 } else if relationship.Kind == "belongs_to" {
441 query := toScope.db
442 for idx, foreignKey := range relationship.ForeignDBNames {
443 if field, ok := scope.FieldByName(foreignKey); ok {
444 query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
445 }
446 }
447 scope.Err(query.Find(value).Error)
448 } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
449 query := toScope.db
450 for idx, foreignKey := range relationship.ForeignDBNames {
451 if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
452 query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
453 }
454 }
455
456 if relationship.PolymorphicType != "" {
457 query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
458 }
459 scope.Err(query.Find(value).Error)
460 }
461 } else {
462 sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
463 scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error)
464 }
465 return scope
466 } else if toField != nil {
467 sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
468 scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
469 return scope
470 }
471 }
472
473 scope.Err(fmt.Errorf("invalid association %v", foreignKeys))
474 return scope
475 }
476
477 /**
478 Return the table options string or an empty string if the table options does not exist
479 */
480 func (scope *Scope) getTableOptions() string {
481 tableOptions, ok := scope.Get("gorm:table_options")
482 if !ok {
483 return ""
484 }
485 return tableOptions.(string)
486 }
487
488 func (scope *Scope) createJoinTable(field *StructField) {
489 if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
490 joinTableHandler := relationship.JoinTableHandler
491 joinTable := joinTableHandler.Table(scope.db)
492 if !scope.Dialect().HasTable(scope, joinTable) {
493 toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
494
495 var sqlTypes []string
496 for idx, fieldName := range relationship.ForeignFieldNames {
497 if field, ok := scope.Fields()[fieldName]; ok {
498 value := reflect.Indirect(reflect.New(field.Struct.Type))
499 primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false)
500 sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType)
501 }
502 }
503
504 for idx, fieldName := range relationship.AssociationForeignFieldNames {
505 if field, ok := toScope.Fields()[fieldName]; ok {
506 value := reflect.Indirect(reflect.New(field.Struct.Type))
507 primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false)
508 sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType)
509 }
510 }
511 scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), scope.getTableOptions())).Error)
512 }
513 scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
514 }
515 }
516
517 func (scope *Scope) createTable() *Scope {
518 var tags []string
519 var primaryKeys []string
520 var primaryKeyInColumnType bool = false
521 for _, field := range scope.GetStructFields() {
522 if field.IsNormal {
523 sqlTag := scope.generateSqlTag(field)
524
525 // Check if the primary key constraint was specified as
526 // part of the column type. If so, we can only support
527 // one column as the primary key.
528 if strings.Contains(strings.ToLower(sqlTag), "primary key") {
529 primaryKeyInColumnType = true
530 }
531
532 tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
533 }
534
535 if field.IsPrimaryKey {
536 primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
537 }
538 scope.createJoinTable(field)
539 }
540
541 var primaryKeyStr string
542 if len(primaryKeys) > 0 && !primaryKeyInColumnType {
543 primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
544 }
545 scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
546 return scope
547 }
548
549 func (scope *Scope) dropTable() *Scope {
550 scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
551 return scope
552 }
553
554 func (scope *Scope) dropTableIfExists() *Scope {
555 if scope.Dialect().HasTable(scope, scope.TableName()) {
556 scope.dropTable()
557 }
558 return scope
559 }
560
561 func (scope *Scope) modifyColumn(column string, typ string) {
562 scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
563 }
564
565 func (scope *Scope) dropColumn(column string) {
566 scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec()
567 }
568
569 func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
570 if scope.Dialect().HasIndex(scope, scope.TableName(), indexName) {
571 return
572 }
573
574 var columns []string
575 for _, name := range column {
576 columns = append(columns, scope.QuoteIfPossible(name))
577 }
578
579 sqlCreate := "CREATE INDEX"
580 if unique {
581 sqlCreate = "CREATE UNIQUE INDEX"
582 }
583
584 scope.Search.Unscoped = true
585 scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSql())).Exec()
586 scope.Search.Unscoped = false
587 }
588
589 func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
590 var table = scope.TableName()
591 var keyName = fmt.Sprintf("%s_%s_%s_foreign", table, field, regexp.MustCompile("[^a-zA-Z]").ReplaceAllString(dest, "_"))
592 keyName = regexp.MustCompile("_+").ReplaceAllString(keyName, "_")
593 var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
594 scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.QuoteIfPossible(keyName), scope.QuoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
595 }
596
597 func (scope *Scope) removeIndex(indexName string) {
598 scope.Dialect().RemoveIndex(scope, indexName)
599 }
600
601 func (scope *Scope) autoMigrate() *Scope {
602 tableName := scope.TableName()
603 quotedTableName := scope.QuotedTableName()
604
605 if !scope.Dialect().HasTable(scope, tableName) {
606 scope.createTable()
607 } else {
608 for _, field := range scope.GetStructFields() {
609 if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
610 if field.IsNormal {
611 sqlTag := scope.generateSqlTag(field)
612 scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
613 }
614 }
615 scope.createJoinTable(field)
616 }
617 }
618
619 scope.autoIndex()
620 return scope
621 }
622
623 func (scope *Scope) autoIndex() *Scope {
624 var indexes = map[string][]string{}
625 var uniqueIndexes = map[string][]string{}
626
627 for _, field := range scope.GetStructFields() {
628 sqlSettings := parseTagSetting(field.Tag.Get("sql"))
629 if name, ok := sqlSettings["INDEX"]; ok {
630 if name == "INDEX" {
631 name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName)
632 }
633 indexes[name] = append(indexes[name], field.DBName)
634 }
635
636 if name, ok := sqlSettings["UNIQUE_INDEX"]; ok {
637 if name == "UNIQUE_INDEX" {
638 name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
639 }
640 uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
641 }
642 }
643
644 for name, columns := range indexes {
645 scope.addIndex(false, name, columns...)
646 }
647
648 for name, columns := range uniqueIndexes {
649 scope.addIndex(true, name, columns...)
650 }
651
652 return scope
653 }
00 package gorm_test
11
22 import (
3 "encoding/hex"
4 "math/rand"
5 "strings"
6 "testing"
7
38 "github.com/jinzhu/gorm"
4 "testing"
59 )
610
711 func NameIn1And2(d *gorm.DB) *gorm.DB {
4044 t.Errorf("Should found two users's name in 1, 3")
4145 }
4246 }
47
48 func randName() string {
49 data := make([]byte, 8)
50 rand.Read(data)
51
52 return "n-" + hex.EncodeToString(data)
53 }
54
55 func TestValuer(t *testing.T) {
56 name := randName()
57
58 origUser := User{Name: name, Age: 1, Password: EncryptedData("pass1"), PasswordHash: []byte("abc")}
59 if err := DB.Save(&origUser).Error; err != nil {
60 t.Errorf("No error should happen when saving user, but got %v", err)
61 }
62
63 var user2 User
64 if err := DB.Where("name = ? AND password = ? AND password_hash = ?", name, EncryptedData("pass1"), []byte("abc")).First(&user2).Error; err != nil {
65 t.Errorf("No error should happen when querying user with valuer, but got %v", err)
66 }
67 }
68
69 func TestFailedValuer(t *testing.T) {
70 name := randName()
71
72 err := DB.Exec("INSERT INTO users(name, password) VALUES(?, ?)", name, EncryptedData("xpass1")).Error
73
74 if err == nil {
75 t.Errorf("There should be an error should happen when insert data")
76 } else if !strings.HasPrefix(err.Error(), "Should not start with") {
77 t.Errorf("The error should be returned from Valuer, but get %v", err)
78 }
79 }
00 package gorm
11
2 import "fmt"
2 import (
3 "fmt"
4 )
35
46 type search struct {
57 db *DB
79 orConditions []map[string]interface{}
810 notConditions []map[string]interface{}
911 havingConditions []map[string]interface{}
12 joinConditions []map[string]interface{}
1013 initAttrs []interface{}
1114 assignAttrs []interface{}
1215 selects map[string]interface{}
1316 omits []string
14 orders []string
15 joins string
17 orders []interface{}
1618 preload []searchPreload
17 offset string
18 limit string
19 offset interface{}
20 limit interface{}
1921 group string
2022 tableName string
2123 raw bool
2224 Unscoped bool
23 countingQuery bool
25 ignoreOrderQuery bool
2426 }
2527
2628 type searchPreload struct {
5860 return s
5961 }
6062
61 func (s *search) Order(value string, reorder ...bool) *search {
63 func (s *search) Order(value interface{}, reorder ...bool) *search {
6264 if len(reorder) > 0 && reorder[0] {
63 if value != "" {
64 s.orders = []string{value}
65 } else {
66 s.orders = []string{}
67 }
68 } else if value != "" {
65 s.orders = []interface{}{}
66 }
67
68 if value != nil && value != "" {
6969 s.orders = append(s.orders, value)
7070 }
7171 return s
8181 return s
8282 }
8383
84 func (s *search) Limit(value interface{}) *search {
85 s.limit = s.getInterfaceAsSql(value)
84 func (s *search) Limit(limit interface{}) *search {
85 s.limit = limit
8686 return s
8787 }
8888
89 func (s *search) Offset(value interface{}) *search {
90 s.offset = s.getInterfaceAsSql(value)
89 func (s *search) Offset(offset interface{}) *search {
90 s.offset = offset
9191 return s
9292 }
9393
9494 func (s *search) Group(query string) *search {
95 s.group = s.getInterfaceAsSql(query)
95 s.group = s.getInterfaceAsSQL(query)
9696 return s
9797 }
9898
99 func (s *search) Having(query string, values ...interface{}) *search {
100 s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})
99 func (s *search) Having(query interface{}, values ...interface{}) *search {
100 if val, ok := query.(*expr); ok {
101 s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args})
102 } else {
103 s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})
104 }
101105 return s
102106 }
103107
104 func (s *search) Joins(query string) *search {
105 s.joins = query
108 func (s *search) Joins(query string, values ...interface{}) *search {
109 s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values})
106110 return s
107111 }
108112
133137 return s
134138 }
135139
136 func (s *search) getInterfaceAsSql(value interface{}) (str string) {
140 func (s *search) getInterfaceAsSQL(value interface{}) (str string) {
137141 switch value.(type) {
138142 case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
139143 str = fmt.Sprintf("%v", value)
140144 default:
141 s.db.AddError(InvalidSql)
145 s.db.AddError(ErrInvalidSQL)
142146 }
143147
144148 if str == "-1" {
+0
-70
slice_test.go less more
0 package gorm_test
1
2 import (
3 "database/sql/driver"
4 "encoding/json"
5 "testing"
6 )
7
8 func TestScannableSlices(t *testing.T) {
9 if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil {
10 t.Errorf("Should create table with slice values correctly: %s", err)
11 }
12
13 r1 := RecordWithSlice{
14 Strings: ExampleStringSlice{"a", "b", "c"},
15 Structs: ExampleStructSlice{
16 {"name1", "value1"},
17 {"name2", "value2"},
18 },
19 }
20
21 if err := DB.Save(&r1).Error; err != nil {
22 t.Errorf("Should save record with slice values")
23 }
24
25 var r2 RecordWithSlice
26
27 if err := DB.Find(&r2).Error; err != nil {
28 t.Errorf("Should fetch record with slice values")
29 }
30
31 if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" {
32 t.Errorf("Should have serialised and deserialised a string array")
33 }
34
35 if len(r2.Structs) != 2 || r2.Structs[0].Name != "name1" || r2.Structs[0].Value != "value1" || r2.Structs[1].Name != "name2" || r2.Structs[1].Value != "value2" {
36 t.Errorf("Should have serialised and deserialised a struct array")
37 }
38 }
39
40 type RecordWithSlice struct {
41 ID uint64
42 Strings ExampleStringSlice `sql:"type:text"`
43 Structs ExampleStructSlice `sql:"type:text"`
44 }
45
46 type ExampleStringSlice []string
47
48 func (l ExampleStringSlice) Value() (driver.Value, error) {
49 return json.Marshal(l)
50 }
51
52 func (l *ExampleStringSlice) Scan(input interface{}) error {
53 return json.Unmarshal(input.([]byte), l)
54 }
55
56 type ExampleStruct struct {
57 Name string
58 Value string
59 }
60
61 type ExampleStructSlice []ExampleStruct
62
63 func (l ExampleStructSlice) Value() (driver.Value, error) {
64 return json.Marshal(l)
65 }
66
67 func (l *ExampleStructSlice) Scan(input interface{}) error {
68 return json.Unmarshal(input.([]byte), l)
69 }
+0
-84
sqlite3.go less more
0 package gorm
1
2 import (
3 "fmt"
4 "reflect"
5 "time"
6 )
7
8 type sqlite3 struct {
9 commonDialect
10 }
11
12 func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
13 switch value.Kind() {
14 case reflect.Bool:
15 return "bool"
16 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
17 if autoIncrease {
18 return "integer primary key autoincrement"
19 }
20 return "integer"
21 case reflect.Int64, reflect.Uint64:
22 if autoIncrease {
23 return "integer primary key autoincrement"
24 }
25 return "bigint"
26 case reflect.Float32, reflect.Float64:
27 return "real"
28 case reflect.String:
29 if size > 0 && size < 65532 {
30 return fmt.Sprintf("varchar(%d)", size)
31 }
32 return "text"
33 case reflect.Struct:
34 if _, ok := value.Interface().(time.Time); ok {
35 return "datetime"
36 }
37 default:
38 if _, ok := value.Interface().([]byte); ok {
39 return "blob"
40 }
41 }
42 panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
43 }
44
45 func (s sqlite3) HasTable(scope *Scope, tableName string) bool {
46 var count int
47 s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName)
48 return count > 0
49 }
50
51 func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
52 var count int
53 s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName)
54 return count > 0
55 }
56
57 func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
58 var count int
59 s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
60 return count > 0
61 }
62
63 func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
64 scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
65 }
66
67 func (sqlite3) CurrentDatabase(scope *Scope) (name string) {
68 var (
69 ifaces = make([]interface{}, 3)
70 pointers = make([]*string, 3)
71 i int
72 )
73 for i = 0; i < 3; i++ {
74 ifaces[i] = &pointers[i]
75 }
76 if err := scope.NewDB().Raw("PRAGMA database_list").Row().Scan(ifaces...); scope.Err(err) != nil {
77 return
78 }
79 if pointers[1] != nil {
80 name = *pointers[1]
81 }
82 return
83 }
+0
-219
structs_test.go less more
0 package gorm_test
1
2 import (
3 "database/sql"
4 "database/sql/driver"
5 "errors"
6 "fmt"
7
8 "reflect"
9 "time"
10 )
11
12 type User struct {
13 Id int64
14 Age int64
15 UserNum Num
16 Name string `sql:"size:255"`
17 Birthday time.Time // Time
18 CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
19 UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
20 Emails []Email // Embedded structs
21 BillingAddress Address // Embedded struct
22 BillingAddressID sql.NullInt64 // Embedded struct's foreign key
23 ShippingAddress Address // Embedded struct
24 ShippingAddressId int64 // Embedded struct's foreign key
25 CreditCard CreditCard
26 Latitude float64
27 Languages []Language `gorm:"many2many:user_languages;"`
28 CompanyID int64
29 Company Company
30 Role
31 PasswordHash []byte
32 IgnoreMe int64 `sql:"-"`
33 IgnoreStringSlice []string `sql:"-"`
34 Ignored struct{ Name string } `sql:"-"`
35 IgnoredPointer *User `sql:"-"`
36 }
37
38 type CreditCard struct {
39 ID int8
40 Number string
41 UserId sql.NullInt64
42 CreatedAt time.Time
43 UpdatedAt time.Time
44 DeletedAt time.Time
45 }
46
47 type Email struct {
48 Id int16
49 UserId int
50 Email string `sql:"type:varchar(100);"`
51 CreatedAt time.Time
52 UpdatedAt time.Time
53 }
54
55 type Address struct {
56 ID int
57 Address1 string
58 Address2 string
59 Post string
60 CreatedAt time.Time
61 UpdatedAt time.Time
62 DeletedAt time.Time
63 }
64
65 type Language struct {
66 Id int
67 Name string
68 Users []User `gorm:"many2many:user_languages;"`
69 }
70
71 type Product struct {
72 Id int64
73 Code string
74 Price int64
75 CreatedAt time.Time
76 UpdatedAt time.Time
77 AfterFindCallTimes int64
78 BeforeCreateCallTimes int64
79 AfterCreateCallTimes int64
80 BeforeUpdateCallTimes int64
81 AfterUpdateCallTimes int64
82 BeforeSaveCallTimes int64
83 AfterSaveCallTimes int64
84 BeforeDeleteCallTimes int64
85 AfterDeleteCallTimes int64
86 }
87
88 type Company struct {
89 Id int64
90 Name string
91 Owner *User `sql:"-"`
92 }
93
94 type Role struct {
95 Name string
96 }
97
98 func (role *Role) Scan(value interface{}) error {
99 if b, ok := value.([]uint8); ok {
100 role.Name = string(b)
101 } else {
102 role.Name = value.(string)
103 }
104 return nil
105 }
106
107 func (role Role) Value() (driver.Value, error) {
108 return role.Name, nil
109 }
110
111 func (role Role) IsAdmin() bool {
112 return role.Name == "admin"
113 }
114
115 type Num int64
116
117 func (i *Num) Scan(src interface{}) error {
118 switch s := src.(type) {
119 case []byte:
120 case int64:
121 *i = Num(s)
122 default:
123 return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
124 }
125 return nil
126 }
127
128 type Animal struct {
129 Counter uint64 `gorm:"primary_key:yes"`
130 Name string `sql:"DEFAULT:'galeone'"`
131 From string //test reserved sql keyword as field name
132 Age time.Time `sql:"DEFAULT:current_timestamp"`
133 unexported string // unexported value
134 CreatedAt time.Time
135 UpdatedAt time.Time
136 }
137
138 type JoinTable struct {
139 From uint64
140 To uint64
141 Time time.Time `sql:"default: null"`
142 }
143
144 type Post struct {
145 Id int64
146 CategoryId sql.NullInt64
147 MainCategoryId int64
148 Title string
149 Body string
150 Comments []*Comment
151 Category Category
152 MainCategory Category
153 }
154
155 type Category struct {
156 Id int64
157 Name string
158 }
159
160 type Comment struct {
161 Id int64
162 PostId int64
163 Content string
164 Post Post
165 }
166
167 // Scanner
168 type NullValue struct {
169 Id int64
170 Name sql.NullString `sql:"not null"`
171 Age sql.NullInt64
172 Male sql.NullBool
173 Height sql.NullFloat64
174 AddedAt NullTime
175 }
176
177 type NullTime struct {
178 Time time.Time
179 Valid bool
180 }
181
182 func (nt *NullTime) Scan(value interface{}) error {
183 if value == nil {
184 nt.Valid = false
185 return nil
186 }
187 nt.Time, nt.Valid = value.(time.Time), true
188 return nil
189 }
190
191 func (nt NullTime) Value() (driver.Value, error) {
192 if !nt.Valid {
193 return nil, nil
194 }
195 return nt.Time, nil
196 }
197
198 func getPreparedUser(name string, role string) *User {
199 var company Company
200 DB.Where(Company{Name: role}).FirstOrCreate(&company)
201
202 return &User{
203 Name: name,
204 Age: 20,
205 Role: Role{role},
206 BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
207 ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
208 CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
209 Emails: []Email{
210 {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
211 },
212 Company: company,
213 Languages: []Language{
214 {Name: fmt.Sprintf("lang_1_%v", name)},
215 {Name: fmt.Sprintf("lang_2_%v", name)},
216 },
217 }
218 }
0 dialects=("postgres" "mysql" "sqlite")
0 dialects=("postgres" "mysql" "mssql" "sqlite")
11
22 for dialect in "${dialects[@]}" ; do
3 GORM_DIALECT=${dialect} go test
3 DEBUG=false GORM_DIALECT=${dialect} go test
44 done
1919 DB.First(&product1, product1.Id)
2020 DB.First(&product2, product2.Id)
2121 updatedAt1 := product1.UpdatedAt
22 updatedAt2 := product2.UpdatedAt
23
24 var product3 Product
25 DB.First(&product3, product2.Id).Update("code", "product2newcode")
26 if updatedAt2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
27 t.Errorf("updatedAt should not be updated if nothing changed")
28 }
2922
3023 if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() {
3124 t.Errorf("Product1 should not be updated")
7063 }
7164
7265 DB.First(&product4, product4.Id)
66 updatedAt4 := product4.UpdatedAt
7367 DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
7468 var product5 Product
7569 DB.First(&product5, product4.Id)
7670 if product5.Price != product4.Price+100-50 {
7771 t.Errorf("Update with expression")
7872 }
79 if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
73 if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
8074 t.Errorf("Update with expression should update UpdatedAt")
8175 }
8276 }
10296 DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched
10397 DB.First(&animal, animal.Counter)
10498 if animal.Name != "galeone" {
105 t.Errorf("Name fiels shouldn't be changed if untouched, but got %v", animal.Name)
99 t.Errorf("Name fields shouldn't be changed if untouched, but got %v", animal.Name)
106100 }
107101
108102 // When changing a field with a default value, the change must occur
133127
134128 DB.First(&product1, product1.Id)
135129 DB.First(&product2, product2.Id)
136 updatedAt1 := product1.UpdatedAt
137130 updatedAt2 := product2.UpdatedAt
138
139 var product3 Product
140 DB.First(&product3, product1.Id).Updates(Product{Code: "product1newcode", Price: 100})
141 if product3.Code != "product1newcode" || product3.Price != 100 {
142 t.Errorf("Record should be updated with struct")
143 }
144
145 if updatedAt1.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
146 t.Errorf("updatedAt should not be updated if nothing changed")
147 }
148131
149132 if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() {
150133 t.Errorf("Product2 should not be updated")
169152 t.Errorf("product2's code should be updated")
170153 }
171154
155 updatedAt4 := product4.UpdatedAt
172156 DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
173157 var product5 Product
174158 DB.First(&product5, product4.Id)
175159 if product5.Price != product4.Price+100 {
176160 t.Errorf("Updates with expression")
177161 }
178 if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
162 // product4's UpdatedAt will be reset when updating
163 if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
179164 t.Errorf("Updates with expression should update UpdatedAt")
180165 }
181166 }
314299 queryUser.ShippingAddressId == user.ShippingAddressId ||
315300 queryUser.CreditCard.ID != user.CreditCard.ID ||
316301 len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id {
317 t.Errorf("Should only update relationships that not omited")
302 t.Errorf("Should only update relationships that not omitted")
318303 }
319304 }
320305
350335 queryUser.ShippingAddressId == user.ShippingAddressId ||
351336 queryUser.CreditCard.ID != user.CreditCard.ID ||
352337 len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id {
353 t.Errorf("Should only update relationships not omited")
338 t.Errorf("Should only update relationships not omitted")
354339 }
355340 }
356341
418403 t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1)
419404 }
420405 }
406
407 func TestUpdatesWithBlankValues(t *testing.T) {
408 product := Product{Code: "product1", Price: 10}
409 DB.Save(&product)
410
411 DB.Model(&Product{Id: product.Id}).Updates(&Product{Price: 100})
412
413 var product1 Product
414 DB.First(&product1, product.Id)
415
416 if product1.Code != "product1" || product1.Price != 100 {
417 t.Errorf("product's code should not be updated")
418 }
419 }
420
421 type ElementWithIgnoredField struct {
422 Id int64
423 Value string
424 IgnoredField int64 `sql:"-"`
425 }
426
427 func (e ElementWithIgnoredField) TableName() string {
428 return "element_with_ignored_field"
429 }
430
431 func TestUpdatesTableWithIgnoredValues(t *testing.T) {
432 elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10}
433 DB.Save(&elem)
434
435 DB.Table(elem.TableName()).
436 Where("id = ?", elem.Id).
437 // DB.Model(&ElementWithIgnoredField{Id: elem.Id}).
438 Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100})
439
440 var elem1 ElementWithIgnoredField
441 err := DB.First(&elem1, elem.Id).Error
442 if err != nil {
443 t.Errorf("error getting an element from database: %s", err.Error())
444 }
445
446 if elem1.IgnoredField != 0 {
447 t.Errorf("element's ignored field should not be updated")
448 }
449 }
450
451 func TestUpdateDecodeVirtualAttributes(t *testing.T) {
452 var user = User{
453 Name: "jinzhu",
454 IgnoreMe: 88,
455 }
456
457 DB.Save(&user)
458
459 DB.Model(&user).Updates(User{Name: "jinzhu2", IgnoreMe: 100})
460
461 if user.IgnoreMe != 100 {
462 t.Errorf("should decode virtual attributes to struct, so it could be used in callbacks")
463 }
464 }
11
22 import (
33 "bytes"
4 "database/sql/driver"
5 "fmt"
6 "reflect"
7 "regexp"
8 "runtime"
49 "strings"
510 "sync"
11 "time"
612 )
713
14 // NowFunc returns current time, this function is exported in order to be able
15 // to give the flexibility to the developer to customize it according to their
16 // needs, e.g:
17 // gorm.NowFunc = func() time.Time {
18 // return time.Now().UTC()
19 // }
20 var NowFunc = func() time.Time {
21 return time.Now()
22 }
23
824 // Copied from golint
9 var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
25 var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
1026 var commonInitialismsReplacer *strings.Replacer
27
28 var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)
29 var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`)
1130
1231 func init() {
1332 var commonInitialismsForReplacer []string
4059
4160 var smap = newSafeMap()
4261
62 type strCase bool
63
64 const (
65 lower strCase = false
66 upper strCase = true
67 )
68
69 // ToDBName convert string to db name
4370 func ToDBName(name string) string {
4471 if v := smap.Get(name); v != "" {
4572 return v
4673 }
4774
48 value := commonInitialismsReplacer.Replace(name)
49 buf := bytes.NewBufferString("")
50 for i, v := range value {
51 if i > 0 && v >= 'A' && v <= 'Z' {
52 buf.WriteRune('_')
53 }
54 buf.WriteRune(v)
55 }
75 if name == "" {
76 return ""
77 }
78
79 var (
80 value = commonInitialismsReplacer.Replace(name)
81 buf = bytes.NewBufferString("")
82 lastCase, currCase, nextCase strCase
83 )
84
85 for i, v := range value[:len(value)-1] {
86 nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z')
87 if i > 0 {
88 if currCase == upper {
89 if lastCase == upper && nextCase == upper {
90 buf.WriteRune(v)
91 } else {
92 if value[i-1] != '_' && value[i+1] != '_' {
93 buf.WriteRune('_')
94 }
95 buf.WriteRune(v)
96 }
97 } else {
98 buf.WriteRune(v)
99 if i == len(value)-2 && nextCase == upper {
100 buf.WriteRune('_')
101 }
102 }
103 } else {
104 currCase = upper
105 buf.WriteRune(v)
106 }
107 lastCase = currCase
108 currCase = nextCase
109 }
110
111 buf.WriteByte(value[len(value)-1])
56112
57113 s := strings.ToLower(buf.String())
58114 smap.Set(name, s)
59115 return s
60116 }
61117
118 // SQL expression
62119 type expr struct {
63120 expr string
64121 args []interface{}
65122 }
66123
124 // Expr generate raw SQL expression, for example:
125 // DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
67126 func Expr(expression string, args ...interface{}) *expr {
68127 return &expr{expr: expression, args: args}
69128 }
129
130 func indirect(reflectValue reflect.Value) reflect.Value {
131 for reflectValue.Kind() == reflect.Ptr {
132 reflectValue = reflectValue.Elem()
133 }
134 return reflectValue
135 }
136
137 func toQueryMarks(primaryValues [][]interface{}) string {
138 var results []string
139
140 for _, primaryValue := range primaryValues {
141 var marks []string
142 for range primaryValue {
143 marks = append(marks, "?")
144 }
145
146 if len(marks) > 1 {
147 results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
148 } else {
149 results = append(results, strings.Join(marks, ""))
150 }
151 }
152 return strings.Join(results, ",")
153 }
154
155 func toQueryCondition(scope *Scope, columns []string) string {
156 var newColumns []string
157 for _, column := range columns {
158 newColumns = append(newColumns, scope.Quote(column))
159 }
160
161 if len(columns) > 1 {
162 return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
163 }
164 return strings.Join(newColumns, ",")
165 }
166
167 func toQueryValues(values [][]interface{}) (results []interface{}) {
168 for _, value := range values {
169 for _, v := range value {
170 results = append(results, v)
171 }
172 }
173 return
174 }
175
176 func fileWithLineNum() string {
177 for i := 2; i < 15; i++ {
178 _, file, line, ok := runtime.Caller(i)
179 if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) {
180 return fmt.Sprintf("%v:%v", file, line)
181 }
182 }
183 return ""
184 }
185
186 func isBlank(value reflect.Value) bool {
187 switch value.Kind() {
188 case reflect.String:
189 return value.Len() == 0
190 case reflect.Bool:
191 return !value.Bool()
192 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
193 return value.Int() == 0
194 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
195 return value.Uint() == 0
196 case reflect.Float32, reflect.Float64:
197 return value.Float() == 0
198 case reflect.Interface, reflect.Ptr:
199 return value.IsNil()
200 }
201
202 return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
203 }
204
205 func toSearchableMap(attrs ...interface{}) (result interface{}) {
206 if len(attrs) > 1 {
207 if str, ok := attrs[0].(string); ok {
208 result = map[string]interface{}{str: attrs[1]}
209 }
210 } else if len(attrs) == 1 {
211 if attr, ok := attrs[0].(map[string]interface{}); ok {
212 result = attr
213 }
214
215 if attr, ok := attrs[0].(interface{}); ok {
216 result = attr
217 }
218 }
219 return
220 }
221
222 func equalAsString(a interface{}, b interface{}) bool {
223 return toString(a) == toString(b)
224 }
225
226 func toString(str interface{}) string {
227 if values, ok := str.([]interface{}); ok {
228 var results []string
229 for _, value := range values {
230 results = append(results, toString(value))
231 }
232 return strings.Join(results, "_")
233 } else if bytes, ok := str.([]byte); ok {
234 return string(bytes)
235 } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
236 return fmt.Sprintf("%v", reflectValue.Interface())
237 }
238 return ""
239 }
240
241 func makeSlice(elemType reflect.Type) interface{} {
242 if elemType.Kind() == reflect.Slice {
243 elemType = elemType.Elem()
244 }
245 sliceType := reflect.SliceOf(elemType)
246 slice := reflect.New(sliceType)
247 slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
248 return slice.Interface()
249 }
250
251 func strInSlice(a string, list []string) bool {
252 for _, b := range list {
253 if b == a {
254 return true
255 }
256 }
257 return false
258 }
259
260 // getValueFromFields return given fields's value
261 func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) {
262 // If value is a nil pointer, Indirect returns a zero Value!
263 // Therefor we need to check for a zero value,
264 // as FieldByName could panic
265 if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
266 for _, fieldName := range fieldNames {
267 if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() {
268 result := fieldValue.Interface()
269 if r, ok := result.(driver.Valuer); ok {
270 result, _ = r.Value()
271 }
272 results = append(results, result)
273 }
274 }
275 }
276 return
277 }
278
279 func addExtraSpaceIfExist(str string) string {
280 if str != "" {
281 return " " + str
282 }
283 return ""
284 }
+0
-97
utils_private.go less more
0 package gorm
1
2 import (
3 "fmt"
4 "reflect"
5 "regexp"
6 "runtime"
7 "strings"
8 )
9
10 func fileWithLineNum() string {
11 for i := 2; i < 15; i++ {
12 _, file, line, ok := runtime.Caller(i)
13 if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
14 return fmt.Sprintf("%v:%v", file, line)
15 }
16 }
17 return ""
18 }
19
20 func isBlank(value reflect.Value) bool {
21 return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
22 }
23
24 func toSearchableMap(attrs ...interface{}) (result interface{}) {
25 if len(attrs) > 1 {
26 if str, ok := attrs[0].(string); ok {
27 result = map[string]interface{}{str: attrs[1]}
28 }
29 } else if len(attrs) == 1 {
30 if attr, ok := attrs[0].(map[string]interface{}); ok {
31 result = attr
32 }
33
34 if attr, ok := attrs[0].(interface{}); ok {
35 result = attr
36 }
37 }
38 return
39 }
40
41 func convertInterfaceToMap(values interface{}) map[string]interface{} {
42 attrs := map[string]interface{}{}
43
44 switch value := values.(type) {
45 case map[string]interface{}:
46 for k, v := range value {
47 attrs[ToDBName(k)] = v
48 }
49 case []interface{}:
50 for _, v := range value {
51 for key, value := range convertInterfaceToMap(v) {
52 attrs[key] = value
53 }
54 }
55 case interface{}:
56 reflectValue := reflect.ValueOf(values)
57
58 switch reflectValue.Kind() {
59 case reflect.Map:
60 for _, key := range reflectValue.MapKeys() {
61 attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
62 }
63 default:
64 scope := Scope{Value: values}
65 for _, field := range scope.Fields() {
66 if !field.IsBlank && !field.IsIgnored {
67 attrs[field.DBName] = field.Field.Interface()
68 }
69 }
70 }
71 }
72 return attrs
73 }
74
75 func toString(str interface{}) string {
76 if values, ok := str.([]interface{}); ok {
77 var results []string
78 for _, value := range values {
79 results = append(results, toString(value))
80 }
81 return strings.Join(results, "_")
82 } else if bytes, ok := str.([]byte); ok {
83 return string(bytes)
84 } else {
85 return fmt.Sprintf("%v", str)
86 }
87 }
88
89 func strInSlice(a string, list []string) bool {
90 for _, b := range list {
91 if b == a {
92 return true
93 }
94 }
95 return false
96 }
0 package gorm_test
1
2 import (
3 "testing"
4
5 "github.com/jinzhu/gorm"
6 )
7
8 func TestToDBNameGenerateFriendlyName(t *testing.T) {
9 var maps = map[string]string{
10 "": "",
11 "X": "x",
12 "ThisIsATest": "this_is_a_test",
13 "PFAndESI": "pf_and_esi",
14 "AbcAndJkl": "abc_and_jkl",
15 "EmployeeID": "employee_id",
16 "SKU_ID": "sku_id",
17 "FieldX": "field_x",
18 "HTTPAndSMTP": "http_and_smtp",
19 "HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
20 "UUID": "uuid",
21 "HTTPURL": "http_url",
22 "HTTP_URL": "http_url",
23 "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
24 }
25
26 for key, value := range maps {
27 if gorm.ToDBName(key) != value {
28 t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key))
29 }
30 }
31 }
0 # use the default golang container from Docker Hub
1 box: golang
2
3 services:
4 - name: mariadb
5 id: mariadb:latest
6 env:
7 MYSQL_DATABASE: gorm
8 MYSQL_USER: gorm
9 MYSQL_PASSWORD: gorm
10 MYSQL_RANDOM_ROOT_PASSWORD: "yes"
11 - name: mysql
12 id: mysql:8
13 env:
14 MYSQL_DATABASE: gorm
15 MYSQL_USER: gorm
16 MYSQL_PASSWORD: gorm
17 MYSQL_RANDOM_ROOT_PASSWORD: "yes"
18 - name: mysql57
19 id: mysql:5.7
20 env:
21 MYSQL_DATABASE: gorm
22 MYSQL_USER: gorm
23 MYSQL_PASSWORD: gorm
24 MYSQL_RANDOM_ROOT_PASSWORD: "yes"
25 - name: mysql56
26 id: mysql:5.6
27 env:
28 MYSQL_DATABASE: gorm
29 MYSQL_USER: gorm
30 MYSQL_PASSWORD: gorm
31 MYSQL_RANDOM_ROOT_PASSWORD: "yes"
32 - name: mysql55
33 id: mysql:5.5
34 env:
35 MYSQL_DATABASE: gorm
36 MYSQL_USER: gorm
37 MYSQL_PASSWORD: gorm
38 MYSQL_RANDOM_ROOT_PASSWORD: "yes"
39 - name: postgres
40 id: postgres:latest
41 env:
42 POSTGRES_USER: gorm
43 POSTGRES_PASSWORD: gorm
44 POSTGRES_DB: gorm
45 - name: postgres96
46 id: postgres:9.6
47 env:
48 POSTGRES_USER: gorm
49 POSTGRES_PASSWORD: gorm
50 POSTGRES_DB: gorm
51 - name: postgres95
52 id: postgres:9.5
53 env:
54 POSTGRES_USER: gorm
55 POSTGRES_PASSWORD: gorm
56 POSTGRES_DB: gorm
57 - name: postgres94
58 id: postgres:9.4
59 env:
60 POSTGRES_USER: gorm
61 POSTGRES_PASSWORD: gorm
62 POSTGRES_DB: gorm
63 - name: postgres93
64 id: postgres:9.3
65 env:
66 POSTGRES_USER: gorm
67 POSTGRES_PASSWORD: gorm
68 POSTGRES_DB: gorm
69 - name: mssql
70 id: mcmoe/mssqldocker:latest
71 env:
72 ACCEPT_EULA: Y
73 SA_PASSWORD: LoremIpsum86
74 MSSQL_DB: gorm
75 MSSQL_USER: gorm
76 MSSQL_PASSWORD: LoremIpsum86
77
78 # The steps that will be executed in the build pipeline
79 build:
80 # The steps that will be executed on build
81 steps:
82 # Sets the go workspace and places you package
83 # at the right place in the workspace tree
84 - setup-go-workspace
85
86 # Gets the dependencies
87 - script:
88 name: go get
89 code: |
90 cd $WERCKER_SOURCE_DIR
91 go version
92 go get -t ./...
93
94 # Build the project
95 - script:
96 name: go build
97 code: |
98 go build ./...
99
100 # Test the project
101 - script:
102 name: test sqlite
103 code: |
104 go test ./...
105
106 - script:
107 name: test mariadb
108 code: |
109 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./...
110
111 - script:
112 name: test mysql
113 code: |
114 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test ./...
115
116 - script:
117 name: test mysql5.7
118 code: |
119 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test ./...
120
121 - script:
122 name: test mysql5.6
123 code: |
124 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test ./...
125
126 - script:
127 name: test mysql5.5
128 code: |
129 GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test ./...
130
131 - script:
132 name: test postgres
133 code: |
134 GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
135
136 - script:
137 name: test postgres96
138 code: |
139 GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
140
141 - script:
142 name: test postgres95
143 code: |
144 GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
145
146 - script:
147 name: test postgres94
148 code: |
149 GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
150
151 - script:
152 name: test postgres93
153 code: |
154 GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
155
156 - script:
157 name: test mssql
158 code: |
159 GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test ./...