Codebase list golang-github-jinzhu-gorm / bb86fc1
Imported Upstream version 0.0~git20151012.0.20e37a0 Tianon Gravi 8 years ago
57 changed file(s) with 10423 addition(s) and 0 deletion(s). Raw diff Collapse all Expand all
0 ---
1 engines:
2 gofmt:
3 enabled: true
4 govet:
5 enabled: true
6 golint:
7 enabled: true
8 ratings:
9 paths:
10 - "**.go"
0 The MIT License (MIT)
1
2 Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
3
4 Permission is hereby granted, free of charge, to any person obtaining a copy
5 of this software and associated documentation files (the "Software"), to deal
6 in the Software without restriction, including without limitation the rights
7 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 copies of the Software, and to permit persons to whom the Software is
9 furnished to do so, subject to the following conditions:
10
11 The above copyright notice and this permission notice shall be included in
12 all copies or substantial portions of the Software.
13
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 THE SOFTWARE.
0 # GORM
1
2 [![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
3
4 The fantastic ORM library for Golang, aims to be developer friendly.
5
6 [![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921)
7
8 ## Overview
9
10 * Full-Featured ORM (almost)
11 * Chainable API
12 * Auto Migrations
13 * Relations (Has One, Has Many, Belongs To, Many To Many, [Polymorphism](#polymorphism))
14 * Callbacks (Before/After Create/Save/Update/Delete/Find)
15 * Preloading (eager loading)
16 * Transactions
17 * Embed Anonymous Struct
18 * Soft Deletes
19 * Customizable Logger
20 * Iteration Support via [Rows](#row--rows)
21 * Every feature comes with tests
22 * Developer Friendly
23
24 # Getting Started
25
26 ## Install
27
28 ```
29 go get -u github.com/jinzhu/gorm
30 ```
31
32 ## Documentation
33
34 [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm)
35
36 `go doc` format documentation for this project can be viewed online without
37 installing the package by using the GoDoc page at:
38 http://godoc.org/github.com/jinzhu/gorm
39
40 ## Table of Contents
41
42 - [Define Models (Structs)](#define-models-structs)
43 - [Conventions](#conventions)
44 - [Initialize Database](#initialize-database)
45 - [Migration](#migration)
46 - [Basic CRUD](#basic-crud)
47 - [Create](#create-record)
48 - [Query](#query)
49 - [Query With Where (Plain SQL)](#query-with-where-plain-sql)
50 - [Query With Where (Struct & Map)](#query-with-where-struct--map)
51 - [Query With Not](#query-with-not)
52 - [Query With Inline Condition](#query-with-inline-condition)
53 - [Query With Or](#query-with-or)
54 - [Query Chains](#query-chains)
55 - [Preloading (Eager loading)](#preloading-eager-loading)
56 - [Update](#update)
57 - [Update Without Callbacks](#update-without-callbacks)
58 - [Batch Updates](#batch-updates)
59 - [Update with SQL Expression](#update-with-sql-expression)
60 - [Delete](#delete)
61 - [Batch Delete](#batch-delete)
62 - [Soft Delete](#soft-delete)
63 - [Associations](#associations)
64 - [Has One](#has-one)
65 - [Belongs To](#belongs-to)
66 - [Has Many](#has-many)
67 - [Many To Many](#many-to-many)
68 - [Polymorphism](#polymorphism)
69 - [Advanced Usage](#advanced-usage)
70 - [FirstOrInit](#firstorinit)
71 - [FirstOrCreate](#firstorcreate)
72 - [Select](#select)
73 - [Order](#order)
74 - [Limit](#limit)
75 - [Offset](#offset)
76 - [Count](#count)
77 - [Pluck](#pluck)
78 - [Raw SQL](#raw-sql)
79 - [Row & Rows](#row--rows)
80 - [Scan](#scan)
81 - [Group & Having](#group--having)
82 - [Joins](#joins)
83 - [Transactions](#transactions)
84 - [Scopes](#scopes)
85 - [Callbacks](#callbacks)
86 - [Specifying The Table Name](#specifying-the-table-name)
87 - [Error Handling](#error-handling)
88 - [Logger](#logger)
89 - [Existing Schema](#existing-schema)
90 - [Composite Primary Key](#composite-primary-key)
91 - [Database Indexes & Foreign Key](#database-indexes--foreign-key)
92 - [Default values](#default-values)
93 - [More examples with query chain](#more-examples-with-query-chain)
94
95 ## Define Models (Structs)
96
97 ```go
98 type User struct {
99 ID int
100 Birthday time.Time
101 Age int
102 Name string `sql:"size:255"` // Default size for string is 255, you could reset it with this tag
103 Num int `sql:"AUTO_INCREMENT"`
104 CreatedAt time.Time
105 UpdatedAt time.Time
106 DeletedAt *time.Time
107
108 Emails []Email // One-To-Many relationship (has many)
109 BillingAddress Address // One-To-One relationship (has one)
110 BillingAddressID sql.NullInt64 // Foreign key of BillingAddress
111 ShippingAddress Address // One-To-One relationship (has one)
112 ShippingAddressID int // Foreign key of ShippingAddress
113 IgnoreMe int `sql:"-"` // Ignore this field
114 Languages []Language `gorm:"many2many:user_languages;"` // Many-To-Many relationship, 'user_languages' is join table
115 }
116
117 type Email struct {
118 ID int
119 UserID int `sql:"index"` // Foreign key (belongs to), tag `index` will create index for this field when using AutoMigrate
120 Email string `sql:"type:varchar(100);unique_index"` // Set field's sql type, tag `unique_index` will create unique index
121 Subscribed bool
122 }
123
124 type Address struct {
125 ID int
126 Address1 string `sql:"not null;unique"` // Set field as not nullable and unique
127 Address2 string `sql:"type:varchar(100);unique"`
128 Post sql.NullString `sql:"not null"`
129 }
130
131 type Language struct {
132 ID int
133 Name string `sql:"index:idx_name_code"` // Create index with name, and will create combined index if find other fields defined same name
134 Code string `sql:"index:idx_name_code"` // `unique_index` also works
135 }
136 ```
137
138 ## Conventions
139
140 * Table name is the plural of struct name's snake case, you can disable pluralization with `db.SingularTable(true)`, or [Specifying The Table Name For A Struct Permanently With TableName](#specifying-the-table-name-for-a-struct-permanently-with-tablename)
141
142 ```go
143 type User struct{} // struct User's database table name is "users" by default, will be "user" if you disabled pluralisation
144 ```
145
146 * Column name is the snake case of field's name
147 * Use `ID` field as primary key
148 * Use `CreatedAt` to store record's created time if field exists
149 * Use `UpdatedAt` to store record's updated time if field exists
150 * Use `DeletedAt` to store record's deleted time if field exists [Soft Delete](#soft-delete)
151 * Gorm provide a default model struct, you could embed it in your struct
152
153 ```go
154 type Model struct {
155 ID uint `gorm:"primary_key"`
156 CreatedAt time.Time
157 UpdatedAt time.Time
158 DeletedAt *time.Time
159 }
160
161 type User struct {
162 gorm.Model
163 Name string
164 }
165 ```
166
167 ## Initialize Database
168
169 ```go
170 import (
171 "github.com/jinzhu/gorm"
172 _ "github.com/lib/pq"
173 _ "github.com/go-sql-driver/mysql"
174 _ "github.com/mattn/go-sqlite3"
175 )
176
177 db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
178 // db, err := gorm.Open("foundation", "dbname=gorm") // FoundationDB.
179 // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
180 // db, err := gorm.Open("sqlite3", "/tmp/gorm.db")
181
182 // You can also use an existing database connection handle
183 // dbSql, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
184 // db, _ := gorm.Open("postgres", dbSql)
185
186 // Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB)
187 db.DB()
188
189 // Then you could invoke `*sql.DB`'s functions with it
190 db.DB().Ping()
191 db.DB().SetMaxIdleConns(10)
192 db.DB().SetMaxOpenConns(100)
193
194 // Disable table name's pluralization
195 db.SingularTable(true)
196 ```
197
198 ## Migration
199
200 ```go
201 // Create table
202 db.CreateTable(&User{})
203 db.Set("gorm:table_options", "ENGINE=InnoDB").CreateTable(&User{})
204
205 // Drop table
206 db.DropTable(&User{})
207
208 // ModifyColumn
209 db.Model(&User{}).ModifyColumn("description", "text")
210
211 // DropColumn
212 db.Model(&User{}).DropColumn("description")
213
214 // Automating Migration
215 db.AutoMigrate(&User{})
216 db.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(&User{})
217 db.AutoMigrate(&User{}, &Product{}, &Order{})
218 // Feel free to change your struct, AutoMigrate will keep your database up-to-date.
219 // AutoMigrate will ONLY add *new columns* and *new indexes*,
220 // WON'T update current column's type or delete unused columns, to protect your data.
221 // If the table is not existing, AutoMigrate will create the table automatically.
222 ```
223
224 # Basic CRUD
225
226 ## Create Record
227
228 ```go
229 user := User{Name: "Jinzhu", Age: 18, Birthday: time.Now()}
230
231 db.NewRecord(user) // => returns `true` if primary key is blank
232
233 db.Create(&user)
234
235 db.NewRecord(user) // => return `false` after `user` created
236
237 // Associations will be inserted automatically when save the record
238 user := User{
239 Name: "jinzhu",
240 BillingAddress: Address{Address1: "Billing Address - Address 1"},
241 ShippingAddress: Address{Address1: "Shipping Address - Address 1"},
242 Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}},
243 Languages: []Language{{Name: "ZH"}, {Name: "EN"}},
244 }
245
246 db.Create(&user)
247 //// BEGIN TRANSACTION;
248 //// INSERT INTO "addresses" (address1) VALUES ("Billing Address - Address 1");
249 //// INSERT INTO "addresses" (address1) VALUES ("Shipping Address - Address 1");
250 //// INSERT INTO "users" (name,billing_address_id,shipping_address_id) VALUES ("jinzhu", 1, 2);
251 //// INSERT INTO "emails" (user_id,email) VALUES (111, "jinzhu@example.com");
252 //// INSERT INTO "emails" (user_id,email) VALUES (111, "jinzhu-2@example.com");
253 //// INSERT INTO "languages" ("name") VALUES ('ZH');
254 //// INSERT INTO user_languages ("user_id","language_id") VALUES (111, 1);
255 //// INSERT INTO "languages" ("name") VALUES ('EN');
256 //// INSERT INTO user_languages ("user_id","language_id") VALUES (111, 2);
257 //// COMMIT;
258 ```
259
260 Refer [Associations](#associations) for more details
261
262 ## Query
263
264 ```go
265 // Get the first record
266 db.First(&user)
267 //// SELECT * FROM users ORDER BY id LIMIT 1;
268
269 // Get the last record
270 db.Last(&user)
271 //// SELECT * FROM users ORDER BY id DESC LIMIT 1;
272
273 // Get all records
274 db.Find(&users)
275 //// SELECT * FROM users;
276
277 // Get record with primary key
278 db.First(&user, 10)
279 //// SELECT * FROM users WHERE id = 10;
280 ```
281
282 ### Query With Where (Plain SQL)
283
284 ```go
285 // Get the first matched record
286 db.Where("name = ?", "jinzhu").First(&user)
287 //// SELECT * FROM users WHERE name = 'jinzhu' limit 1;
288
289 // Get all matched records
290 db.Where("name = ?", "jinzhu").Find(&users)
291 //// SELECT * FROM users WHERE name = 'jinzhu';
292
293 db.Where("name <> ?", "jinzhu").Find(&users)
294
295 // IN
296 db.Where("name in (?)", []string{"jinzhu", "jinzhu 2"}).Find(&users)
297
298 // LIKE
299 db.Where("name LIKE ?", "%jin%").Find(&users)
300
301 // AND
302 db.Where("name = ? and age >= ?", "jinzhu", "22").Find(&users)
303
304 // Time
305 db.Where("updated_at > ?", lastWeek).Find(&users)
306
307 db.Where("created_at BETWEEN ? AND ?", lastWeek, today).Find(&users)
308 ```
309
310 ### Query With Where (Struct & Map)
311
312 ```go
313 // Struct
314 db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
315 //// SELECT * FROM users WHERE name = "jinzhu" AND age = 20 LIMIT 1;
316
317 // Map
318 db.Where(map[string]interface{}{"name": "jinzhu", "age": 20}).Find(&users)
319 //// SELECT * FROM users WHERE name = "jinzhu" AND age = 20;
320
321 // Slice of primary keys
322 db.Where([]int64{20, 21, 22}).Find(&users)
323 //// SELECT * FROM users WHERE id IN (20, 21, 22);
324 ```
325
326 ### Query With Not
327
328 ```go
329 db.Not("name", "jinzhu").First(&user)
330 //// SELECT * FROM users WHERE name <> "jinzhu" LIMIT 1;
331
332 // Not In
333 db.Not("name", []string{"jinzhu", "jinzhu 2"}).Find(&users)
334 //// SELECT * FROM users WHERE name NOT IN ("jinzhu", "jinzhu 2");
335
336 // Not In slice of primary keys
337 db.Not([]int64{1,2,3}).First(&user)
338 //// SELECT * FROM users WHERE id NOT IN (1,2,3);
339
340 db.Not([]int64{}).First(&user)
341 //// SELECT * FROM users;
342
343 // Plain SQL
344 db.Not("name = ?", "jinzhu").First(&user)
345 //// SELECT * FROM users WHERE NOT(name = "jinzhu");
346
347 // Struct
348 db.Not(User{Name: "jinzhu"}).First(&user)
349 //// SELECT * FROM users WHERE name <> "jinzhu";
350 ```
351
352 ### Query With Inline Condition
353
354 ```go
355 // Get by primary key
356 db.First(&user, 23)
357 //// SELECT * FROM users WHERE id = 23 LIMIT 1;
358
359 // Plain SQL
360 db.Find(&user, "name = ?", "jinzhu")
361 //// SELECT * FROM users WHERE name = "jinzhu";
362
363 db.Find(&users, "name <> ? AND age > ?", "jinzhu", 20)
364 //// SELECT * FROM users WHERE name <> "jinzhu" AND age > 20;
365
366 // Struct
367 db.Find(&users, User{Age: 20})
368 //// SELECT * FROM users WHERE age = 20;
369
370 // Map
371 db.Find(&users, map[string]interface{}{"age": 20})
372 //// SELECT * FROM users WHERE age = 20;
373 ```
374
375 ### Query With Or
376
377 ```go
378 db.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&users)
379 //// SELECT * FROM users WHERE role = 'admin' OR role = 'super_admin';
380
381 // Struct
382 db.Where("name = 'jinzhu'").Or(User{Name: "jinzhu 2"}).Find(&users)
383 //// SELECT * FROM users WHERE name = 'jinzhu' OR name = 'jinzhu 2';
384
385 // Map
386 db.Where("name = 'jinzhu'").Or(map[string]interface{}{"name": "jinzhu 2"}).Find(&users)
387 ```
388
389 ### Query Chains
390
391 Gorm has a chainable API, you could use it like this
392
393 ```go
394 db.Where("name <> ?","jinzhu").Where("age >= ? and role <> ?",20,"admin").Find(&users)
395 //// SELECT * FROM users WHERE name <> 'jinzhu' AND age >= 20 AND role <> 'admin';
396
397 db.Where("role = ?", "admin").Or("role = ?", "super_admin").Not("name = ?", "jinzhu").Find(&users)
398 ```
399
400 ### Preloading (Eager loading)
401
402 ```go
403 db.Preload("Orders").Find(&users)
404 //// SELECT * FROM users;
405 //// SELECT * FROM orders WHERE user_id IN (1,2,3,4);
406
407 db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
408 //// SELECT * FROM users;
409 //// SELECT * FROM orders WHERE user_id IN (1,2,3,4) AND state NOT IN ('cancelled');
410
411 db.Where("state = ?", "active").Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
412 //// SELECT * FROM users WHERE state = 'active';
413 //// SELECT * FROM orders WHERE user_id IN (1,2) AND state NOT IN ('cancelled');
414
415 db.Preload("Orders").Preload("Profile").Preload("Role").Find(&users)
416 //// SELECT * FROM users;
417 //// SELECT * FROM orders WHERE user_id IN (1,2,3,4); // has many
418 //// SELECT * FROM profiles WHERE user_id IN (1,2,3,4); // has one
419 //// SELECT * FROM roles WHERE id IN (4,5,6); // belongs to
420 ```
421
422 #### Nested Preloading
423
424 ```go
425 db.Preload("Orders.OrderItems").Find(&users)
426 db.Preload("Orders", "state = ?", "paid").Preload("Orders.OrderItems").Find(&users)
427 ```
428
429 ## Update
430
431 ```go
432 // Update an existing struct
433 db.First(&user)
434 user.Name = "jinzhu 2"
435 user.Age = 100
436 db.Save(&user)
437 //// UPDATE users SET name='jinzhu 2', age=100, updated_at = '2013-11-17 21:34:10' WHERE id=111;
438
439 db.Where("active = ?", true).Save(&user)
440 //// UPDATE users SET name='jinzhu 2', age=100, updated_at = '2013-11-17 21:34:10' WHERE id=111 AND active = true;
441
442 // Update an attribute if it is changed
443 db.Model(&user).Update("name", "hello")
444 //// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111;
445
446 db.Model(&user).Where("active = ?", true).Update("name", "hello")
447 //// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111 AND active = true;
448
449 db.First(&user, 111).Update("name", "hello")
450 //// SELECT * FROM users LIMIT 1;
451 //// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111;
452
453 // Update multiple attributes if they are changed
454 db.Model(&user).Updates(map[string]interface{}{"name": "hello", "age": 18, "actived": false})
455
456 // Update multiple attributes if they are changed (update with struct only works with none zero values)
457 db.Model(&user).Updates(User{Name: "hello", Age: 18})
458 //// UPDATE users SET name='hello', age=18, updated_at = '2013-11-17 21:34:10' WHERE id = 111;
459 ```
460
461 ### Update Without Callbacks
462
463 By default, update will call BeforeUpdate, AfterUpdate callbacks, if you want to update w/o callbacks and w/o saving associations:
464
465 ```go
466 db.Model(&user).UpdateColumn("name", "hello")
467 //// UPDATE users SET name='hello' WHERE id = 111;
468
469 // Update with struct only works with none zero values, or use map[string]interface{}
470 db.Model(&user).UpdateColumns(User{Name: "hello", Age: 18})
471 //// UPDATE users SET name='hello', age=18 WHERE id = 111;
472 ```
473
474 ### Batch Updates
475
476 ```go
477 db.Table("users").Where("id = ?", 10).Updates(map[string]interface{}{"name": "hello", "age": 18})
478 //// UPDATE users SET name='hello', age=18 WHERE id = 10;
479
480 // Update with struct only works with none zero values, or use map[string]interface{}
481 db.Model(User{}).Updates(User{Name: "hello", Age: 18})
482 //// UPDATE users SET name='hello', age=18;
483
484 // Callbacks won't run when do batch updates
485
486 // Use `RowsAffected` to get the count of affected records
487 db.Model(User{}).Updates(User{Name: "hello", Age: 18}).RowsAffected
488 ```
489
490 ### Update with SQL Expression
491
492 ```go
493 DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
494 //// UPDATE "products" SET "code" = 'L1212', "price" = price * '2' + '100', "updated_at" = '2013-11-17 21:34:10' WHERE "id" = '2';
495
496 DB.Model(&product).Updates(map[string]interface{}{"price": gorm.Expr("price * ? + ?", 2, 100)})
497 //// UPDATE "products" SET "code" = 'L1212', "price" = price * '2' + '100', "updated_at" = '2013-11-17 21:34:10' WHERE "id" = '2';
498
499 DB.Model(&product).UpdateColumn("quantity", gorm.Expr("quantity - ?", 1))
500 //// UPDATE "products" SET "quantity" = quantity - 1 WHERE "id" = '2';
501
502 DB.Model(&product).Where("quantity > 1").UpdateColumn("quantity", gorm.Expr("quantity - ?", 1))
503 //// UPDATE "products" SET "quantity" = quantity - 1 WHERE "id" = '2' AND quantity > 1;
504 ```
505
506 ## Delete
507
508 ```go
509 // Delete an existing record
510 db.Delete(&email)
511 //// DELETE from emails where id=10;
512 ```
513
514 ### Batch Delete
515
516 ```go
517 db.Where("email LIKE ?", "%jinzhu%").Delete(Email{})
518 //// DELETE from emails where email LIKE "%jinhu%";
519 ```
520
521 ### Soft Delete
522
523 If struct has `DeletedAt` field, it will get soft delete ability automatically!
524 Then it won't be deleted from database permanently when call `Delete`.
525
526 ```go
527 db.Delete(&user)
528 //// UPDATE users SET deleted_at="2013-10-29 10:23" WHERE id = 111;
529
530 // Batch Delete
531 db.Where("age = ?", 20).Delete(&User{})
532 //// UPDATE users SET deleted_at="2013-10-29 10:23" WHERE age = 20;
533
534 // Soft deleted records will be ignored when query them
535 db.Where("age = 20").Find(&user)
536 //// SELECT * FROM users WHERE age = 20 AND (deleted_at IS NULL OR deleted_at <= '0001-01-02');
537
538 // Find soft deleted records with Unscoped
539 db.Unscoped().Where("age = 20").Find(&users)
540 //// SELECT * FROM users WHERE age = 20;
541
542 // Delete record permanently with Unscoped
543 db.Unscoped().Delete(&order)
544 //// DELETE FROM orders WHERE id=10;
545 ```
546
547 ## Associations
548
549 ### Has One
550
551 ```go
552 // User has one address
553 db.Model(&user).Related(&address)
554 //// SELECT * FROM addresses WHERE id = 123; // 123 is user's foreign key AddressId
555
556 // Specify the foreign key
557 db.Model(&user).Related(&address1, "BillingAddressId")
558 //// SELECT * FROM addresses WHERE id = 123; // 123 is user's foreign key BillingAddressId
559 ```
560
561 ### Belongs To
562
563 ```go
564 // Email belongs to user
565 db.Model(&email).Related(&user)
566 //// SELECT * FROM users WHERE id = 111; // 111 is email's foreign key UserId
567
568 // Specify the foreign key
569 db.Model(&email).Related(&user, "ProfileId")
570 //// SELECT * FROM users WHERE id = 111; // 111 is email's foreign key ProfileId
571 ```
572
573 ### Has Many
574
575 ```go
576 // User has many emails
577 db.Model(&user).Related(&emails)
578 //// SELECT * FROM emails WHERE user_id = 111;
579 // user_id is the foreign key, 111 is user's primary key's value
580
581 // Specify the foreign key
582 db.Model(&user).Related(&emails, "ProfileId")
583 //// SELECT * FROM emails WHERE profile_id = 111;
584 // profile_id is the foreign key, 111 is user's primary key's value
585 ```
586
587 ### Many To Many
588
589 ```go
590 // User has many languages and belongs to many languages
591 db.Model(&user).Related(&languages, "Languages")
592 //// SELECT * FROM "languages" INNER JOIN "user_languages" ON "user_languages"."language_id" = "languages"."id" WHERE "user_languages"."user_id" = 111
593 // `Languages` is user's column name, this column's tag defined join table like this `gorm:"many2many:user_languages;"`
594 ```
595
596 There is also a mode used to handle many to many relations easily
597
598 ```go
599 // Query
600 db.Model(&user).Association("Languages").Find(&languages)
601 // same as `db.Model(&user).Related(&languages, "Languages")`
602
603 db.Where("name = ?", "ZH").First(&languageZH)
604 db.Where("name = ?", "EN").First(&languageEN)
605
606 // Append
607 db.Model(&user).Association("Languages").Append([]Language{languageZH, languageEN})
608 db.Model(&user).Association("Languages").Append([]Language{{Name: "DE"}})
609 db.Model(&user).Association("Languages").Append(Language{Name: "DE"})
610
611 // Delete
612 db.Model(&user).Association("Languages").Delete([]Language{languageZH, languageEN})
613 db.Model(&user).Association("Languages").Delete(languageZH, languageEN)
614
615 // Replace
616 db.Model(&user).Association("Languages").Replace([]Language{languageZH, languageEN})
617 db.Model(&user).Association("Languages").Replace(Language{Name: "DE"}, languageEN)
618
619 // Count
620 db.Model(&user).Association("Languages").Count()
621 // Return the count of languages the user has
622
623 // Clear
624 db.Model(&user).Association("Languages").Clear()
625 // Remove all relations between the user and languages
626 ```
627
628 ### Polymorphism
629
630 Supports polymorphic has-many and has-one associations.
631
632 ```go
633 type Cat struct {
634 Id int
635 Name string
636 Toy Toy `gorm:"polymorphic:Owner;"`
637 }
638
639 type Dog struct {
640 Id int
641 Name string
642 Toy Toy `gorm:"polymorphic:Owner;"`
643 }
644
645 type Toy struct {
646 Id int
647 Name string
648 OwnerId int
649 OwnerType string
650 }
651 ```
652 Note: polymorphic belongs-to and many-to-many are explicitly NOT supported, and will throw errors.
653
654 ## Advanced Usage
655
656 ## FirstOrInit
657
658 Get the first matched record, or initialize a record with search conditions.
659
660 ```go
661 // Unfound
662 db.FirstOrInit(&user, User{Name: "non_existing"})
663 //// user -> User{Name: "non_existing"}
664
665 // Found
666 db.Where(User{Name: "Jinzhu"}).FirstOrInit(&user)
667 //// user -> User{Id: 111, Name: "Jinzhu", Age: 20}
668 db.FirstOrInit(&user, map[string]interface{}{"name": "jinzhu"})
669 //// user -> User{Id: 111, Name: "Jinzhu", Age: 20}
670 ```
671
672 ### Attrs
673
674 Ignore some values when searching, but use them to initialize the struct if record is not found.
675
676 ```go
677 // Unfound
678 db.Where(User{Name: "non_existing"}).Attrs(User{Age: 20}).FirstOrInit(&user)
679 //// SELECT * FROM USERS WHERE name = 'non_existing';
680 //// user -> User{Name: "non_existing", Age: 20}
681
682 db.Where(User{Name: "noexisting_user"}).Attrs("age", 20).FirstOrInit(&user)
683 //// SELECT * FROM USERS WHERE name = 'non_existing';
684 //// user -> User{Name: "non_existing", Age: 20}
685
686 // Found
687 db.Where(User{Name: "Jinzhu"}).Attrs(User{Age: 30}).FirstOrInit(&user)
688 //// SELECT * FROM USERS WHERE name = jinzhu';
689 //// user -> User{Id: 111, Name: "Jinzhu", Age: 20}
690 ```
691
692 ### Assign
693
694 Ignore some values when searching, but assign it to the result regardless it is found or not.
695
696 ```go
697 // Unfound
698 db.Where(User{Name: "non_existing"}).Assign(User{Age: 20}).FirstOrInit(&user)
699 //// user -> User{Name: "non_existing", Age: 20}
700
701 // Found
702 db.Where(User{Name: "Jinzhu"}).Assign(User{Age: 30}).FirstOrInit(&user)
703 //// SELECT * FROM USERS WHERE name = jinzhu';
704 //// user -> User{Id: 111, Name: "Jinzhu", Age: 30}
705 ```
706
707 ## FirstOrCreate
708
709 Get the first matched record, or create with search conditions.
710
711 ```go
712 // Unfound
713 db.FirstOrCreate(&user, User{Name: "non_existing"})
714 //// INSERT INTO "users" (name) VALUES ("non_existing");
715 //// user -> User{Id: 112, Name: "non_existing"}
716
717 // Found
718 db.Where(User{Name: "Jinzhu"}).FirstOrCreate(&user)
719 //// user -> User{Id: 111, Name: "Jinzhu"}
720 ```
721
722 ### Attrs
723
724 Ignore some values when searching, but use them to create the struct if record is not found. like `FirstOrInit`
725
726 ```go
727 // Unfound
728 db.Where(User{Name: "non_existing"}).Attrs(User{Age: 20}).FirstOrCreate(&user)
729 //// SELECT * FROM users WHERE name = 'non_existing';
730 //// INSERT INTO "users" (name, age) VALUES ("non_existing", 20);
731 //// user -> User{Id: 112, Name: "non_existing", Age: 20}
732
733 // Found
734 db.Where(User{Name: "jinzhu"}).Attrs(User{Age: 30}).FirstOrCreate(&user)
735 //// SELECT * FROM users WHERE name = 'jinzhu';
736 //// user -> User{Id: 111, Name: "jinzhu", Age: 20}
737 ```
738
739 ### Assign
740
741 Ignore some values when searching, but assign it to the record regardless it is found or not, then save back to database. like `FirstOrInit`
742
743 ```go
744 // Unfound
745 db.Where(User{Name: "non_existing"}).Assign(User{Age: 20}).FirstOrCreate(&user)
746 //// SELECT * FROM users WHERE name = 'non_existing';
747 //// INSERT INTO "users" (name, age) VALUES ("non_existing", 20);
748 //// user -> User{Id: 112, Name: "non_existing", Age: 20}
749
750 // Found
751 db.Where(User{Name: "jinzhu"}).Assign(User{Age: 30}).FirstOrCreate(&user)
752 //// SELECT * FROM users WHERE name = 'jinzhu';
753 //// UPDATE users SET age=30 WHERE id = 111;
754 //// user -> User{Id: 111, Name: "jinzhu", Age: 30}
755 ```
756
757 ## Select
758
759 ```go
760 db.Select("name, age").Find(&users)
761 //// SELECT name, age FROM users;
762
763 db.Select([]string{"name", "age"}).Find(&users)
764 //// SELECT name, age FROM users;
765
766 db.Table("users").Select("COALESCE(age,?)", 42).Rows()
767 //// SELECT COALESCE(age,'42') FROM users;
768 ```
769
770 ## Order
771
772 ```go
773 db.Order("age desc, name").Find(&users)
774 //// SELECT * FROM users ORDER BY age desc, name;
775
776 // Multiple orders
777 db.Order("age desc").Order("name").Find(&users)
778 //// SELECT * FROM users ORDER BY age desc, name;
779
780 // ReOrder
781 db.Order("age desc").Find(&users1).Order("age", true).Find(&users2)
782 //// SELECT * FROM users ORDER BY age desc; (users1)
783 //// SELECT * FROM users ORDER BY age; (users2)
784 ```
785
786 ## Limit
787
788 ```go
789 db.Limit(3).Find(&users)
790 //// SELECT * FROM users LIMIT 3;
791
792 // Cancel limit condition with -1
793 db.Limit(10).Find(&users1).Limit(-1).Find(&users2)
794 //// SELECT * FROM users LIMIT 10; (users1)
795 //// SELECT * FROM users; (users2)
796 ```
797
798 ## Offset
799
800 ```go
801 db.Offset(3).Find(&users)
802 //// SELECT * FROM users OFFSET 3;
803
804 // Cancel offset condition with -1
805 db.Offset(10).Find(&users1).Offset(-1).Find(&users2)
806 //// SELECT * FROM users OFFSET 10; (users1)
807 //// SELECT * FROM users; (users2)
808 ```
809
810 ## Count
811
812 ```go
813 db.Where("name = ?", "jinzhu").Or("name = ?", "jinzhu 2").Find(&users).Count(&count)
814 //// SELECT * from USERS WHERE name = 'jinzhu' OR name = 'jinzhu 2'; (users)
815 //// SELECT count(*) FROM users WHERE name = 'jinzhu' OR name = 'jinzhu 2'; (count)
816
817 db.Model(User{}).Where("name = ?", "jinzhu").Count(&count)
818 //// SELECT count(*) FROM users WHERE name = 'jinzhu'; (count)
819
820 db.Table("deleted_users").Count(&count)
821 //// SELECT count(*) FROM deleted_users;
822 ```
823
824 ## Pluck
825
826 Get selected attributes as map
827
828 ```go
829 var ages []int64
830 db.Find(&users).Pluck("age", &ages)
831
832 var names []string
833 db.Model(&User{}).Pluck("name", &names)
834
835 db.Table("deleted_users").Pluck("name", &names)
836
837 // Requesting more than one column? Do it like this:
838 db.Select("name, age").Find(&users)
839 ```
840
841 ## Raw SQL
842
843 ```go
844 db.Exec("DROP TABLE users;")
845 db.Exec("UPDATE orders SET shipped_at=? WHERE id IN (?)", time.Now, []int64{11,22,33})
846 ```
847
848 ## Row & Rows
849
850 It is even possible to get query result as `*sql.Row` or `*sql.Rows`
851
852 ```go
853 row := db.Table("users").Where("name = ?", "jinzhu").Select("name, age").Row() // (*sql.Row)
854 row.Scan(&name, &age)
855
856 rows, err := db.Model(User{}).Where("name = ?", "jinzhu").Select("name, age, email").Rows() // (*sql.Rows, error)
857 defer rows.Close()
858 for rows.Next() {
859 ...
860 rows.Scan(&name, &age, &email)
861 ...
862 }
863
864 // Raw SQL
865 rows, err := db.Raw("select name, age, email from users where name = ?", "jinzhu").Rows() // (*sql.Rows, error)
866 defer rows.Close()
867 for rows.Next() {
868 ...
869 rows.Scan(&name, &age, &email)
870 ...
871 }
872 ```
873
874 ## Scan
875
876 Scan results into another struct.
877
878 ```go
879 type Result struct {
880 Name string
881 Age int
882 }
883
884 var result Result
885 db.Table("users").Select("name, age").Where("name = ?", 3).Scan(&result)
886
887 // Raw SQL
888 db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result)
889 ```
890
891 ## Group & Having
892
893 ```go
894 rows, err := db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Rows()
895 for rows.Next() {
896 ...
897 }
898
899 rows, err := db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Having("sum(amount) > ?", 100).Rows()
900 for rows.Next() {
901 ...
902 }
903
904 type Result struct {
905 Date time.Time
906 Total int64
907 }
908 db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Having("sum(amount) > ?", 100).Scan(&results)
909 ```
910
911 ## Joins
912
913 ```go
914 rows, err := db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Rows()
915 for rows.Next() {
916 ...
917 }
918
919 db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Scan(&results)
920
921 // find a user by email address
922 db.Joins("inner join emails on emails.user_id = users.id").Where("emails.email = ?", "x@example.org").Find(&user)
923
924 // find all email addresses for a user
925 db.Joins("left join users on users.id = emails.user_id").Where("users.name = ?", "jinzhu").Find(&emails)
926 ```
927
928 ## Transactions
929
930 To perform a set of operations within a transaction, the general flow is as below.
931 The database handle returned from ``` db.Begin() ``` should be used for all operations within the transaction.
932 (Note that all individual save and delete operations are run in a transaction by default.)
933
934 ```go
935 // begin
936 tx := db.Begin()
937
938 // do some database operations (use 'tx' from this point, not 'db')
939 tx.Create(...)
940 ...
941
942 // rollback in case of error
943 tx.Rollback()
944
945 // Or commit if all is ok
946 tx.Commit()
947 ```
948
949 ### A Specific Example
950 ```
951 func CreateAnimals(db *gorm.DB) err {
952 tx := db.Begin()
953 // Note the use of tx as the database handle once you are within a transaction
954
955 if err := tx.Create(&Animal{Name: "Giraffe"}).Error; err != nil {
956 tx.Rollback()
957 return err
958 }
959
960 if err := tx.Create(&Animal{Name: "Lion"}).Error; err != nil {
961 tx.Rollback()
962 return err
963 }
964
965 tx.Commit()
966 return nil
967 }
968 ```
969
970 ## Scopes
971
972 ```go
973 func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
974 return db.Where("amount > ?", 1000)
975 }
976
977 func PaidWithCreditCard(db *gorm.DB) *gorm.DB {
978 return db.Where("pay_mode_sign = ?", "C")
979 }
980
981 func PaidWithCod(db *gorm.DB) *gorm.DB {
982 return db.Where("pay_mode_sign = ?", "C")
983 }
984
985 func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
986 return func (db *gorm.DB) *gorm.DB {
987 return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
988 }
989 }
990
991 db.Scopes(AmountGreaterThan1000, PaidWithCreditCard).Find(&orders)
992 // Find all credit card orders and amount greater than 1000
993
994 db.Scopes(AmountGreaterThan1000, PaidWithCod).Find(&orders)
995 // Find all COD orders and amount greater than 1000
996
997 db.Scopes(OrderStatus([]string{"paid", "shipped"})).Find(&orders)
998 // Find all paid, shipped orders
999 ```
1000
1001 ## Callbacks
1002
1003 Callbacks are methods defined on the pointer of struct.
1004 If any callback returns an error, gorm will stop future operations and rollback all changes.
1005
1006 Here is the list of all available callbacks:
1007 (listed in the same order in which they will get called during the respective operations)
1008
1009 ### Creating An Object
1010
1011 ```go
1012 BeforeSave
1013 BeforeCreate
1014 // save before associations
1015 // save self
1016 // save after associations
1017 AfterCreate
1018 AfterSave
1019 ```
1020 ### Updating An Object
1021
1022 ```go
1023 BeforeSave
1024 BeforeUpdate
1025 // save before associations
1026 // save self
1027 // save after associations
1028 AfterUpdate
1029 AfterSave
1030 ```
1031
1032 ### Destroying An Object
1033
1034 ```go
1035 BeforeDelete
1036 // delete self
1037 AfterDelete
1038 ```
1039
1040 ### After Find
1041
1042 ```go
1043 // load data from database
1044 AfterFind
1045 ```
1046
1047 ### Example
1048
1049 ```go
1050 func (u *User) BeforeUpdate() (err error) {
1051 if u.readonly() {
1052 err = errors.New("read only user")
1053 }
1054 return
1055 }
1056
1057 // Rollback the insertion if user's id greater than 1000
1058 func (u *User) AfterCreate() (err error) {
1059 if (u.Id > 1000) {
1060 err = errors.New("user id is already greater than 1000")
1061 }
1062 return
1063 }
1064 ```
1065
1066 As you know, save/delete operations in gorm are running in a transaction,
1067 This is means if changes made in the transaction is not visiable unless it is commited,
1068 So if you want to use those changes in your callbacks, you need to run SQL in same transaction.
1069 Fortunately, gorm support pass transaction to callbacks as you needed, you could do it like this:
1070
1071 ```go
1072 func (u *User) AfterCreate(tx *gorm.DB) (err error) {
1073 tx.Model(u).Update("role", "admin")
1074 return
1075 }
1076 ```
1077
1078 ## Specifying The Table Name
1079
1080 ```go
1081 // Create `deleted_users` table with struct User's definition
1082 db.Table("deleted_users").CreateTable(&User{})
1083
1084 var deleted_users []User
1085 db.Table("deleted_users").Find(&deleted_users)
1086 //// SELECT * FROM deleted_users;
1087
1088 db.Table("deleted_users").Where("name = ?", "jinzhu").Delete()
1089 //// DELETE FROM deleted_users WHERE name = 'jinzhu';
1090 ```
1091
1092 ### Specifying The Table Name For A Struct Permanently with TableName
1093
1094 ```go
1095 type Cart struct {
1096 }
1097
1098 func (c Cart) TableName() string {
1099 return "shopping_cart"
1100 }
1101
1102 func (u User) TableName() string {
1103 if u.Role == "admin" {
1104 return "admin_users"
1105 } else {
1106 return "users"
1107 }
1108 }
1109 ```
1110
1111 ## Error Handling
1112
1113 ```go
1114 query := db.Where("name = ?", "jinzhu").First(&user)
1115 query := db.First(&user).Limit(10).Find(&users)
1116 // query.Error will return the last happened error
1117
1118 // So you could do error handing in your application like this:
1119 if err := db.Where("name = ?", "jinzhu").First(&user).Error; err != nil {
1120 // error handling...
1121 }
1122
1123 // RecordNotFound
1124 // If no record found when you query data, gorm will return RecordNotFound error, you could check it like this:
1125 db.Where("name = ?", "hello world").First(&User{}).Error == gorm.RecordNotFound
1126 // Or use the shortcut method
1127 db.Where("name = ?", "hello world").First(&user).RecordNotFound()
1128
1129 if db.Model(&user).Related(&credit_card).RecordNotFound() {
1130 // no credit card found error handling
1131 }
1132 ```
1133
1134 ## Logger
1135
1136 Gorm has built-in logger support
1137
1138 ```go
1139 // Enable Logger
1140 db.LogMode(true)
1141
1142 // Diable Logger
1143 db.LogMode(false)
1144
1145 // Debug a single operation
1146 db.Debug().Where("name = ?", "jinzhu").First(&User{})
1147 ```
1148
1149 ![logger](https://raw.github.com/jinzhu/gorm/master/images/logger.png)
1150
1151 ### Customize Logger
1152
1153 ```go
1154 // Refer gorm's default logger for how to: https://github.com/jinzhu/gorm/blob/master/logger.go#files
1155 db.SetLogger(gorm.Logger{revel.TRACE})
1156 db.SetLogger(log.New(os.Stdout, "\r\n", 0))
1157 ```
1158
1159 ## Existing Schema
1160
1161 If you have an existing database schema, and the primary key field is different from `id`, you can add a tag to the field structure to specify that this field is a primary key.
1162
1163 ```go
1164 type Animal struct {
1165 AnimalId int64 `gorm:"primary_key"`
1166 Birthday time.Time `sql:"DEFAULT:current_timestamp"`
1167 Name string `sql:"default:'galeone'"`
1168 Age int64
1169 }
1170 ```
1171
1172 If your column names differ from the struct fields, you can specify them like this:
1173
1174 ```go
1175 type Animal struct {
1176 AnimalId int64 `gorm:"column:beast_id;primary_key"`
1177 Birthday time.Time `gorm:"column:day_of_the_beast"`
1178 Age int64 `gorm:"column:age_of_the_beast"`
1179 }
1180 ```
1181
1182 ## Composite Primary Key
1183
1184 ```go
1185 type Product struct {
1186 ID string `gorm:"primary_key"`
1187 LanguageCode string `gorm:"primary_key"`
1188 }
1189 ```
1190
1191 ## Database Indexes & Foreign Key
1192
1193 ```go
1194 // Add foreign key
1195 // 1st param : foreignkey field
1196 // 2nd param : destination table(id)
1197 // 3rd param : ONDELETE
1198 // 4th param : ONUPDATE
1199 db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
1200
1201 // Add index
1202 db.Model(&User{}).AddIndex("idx_user_name", "name")
1203
1204 // Multiple column index
1205 db.Model(&User{}).AddIndex("idx_user_name_age", "name", "age")
1206
1207 // Add unique index
1208 db.Model(&User{}).AddUniqueIndex("idx_user_name", "name")
1209
1210 // Multiple column unique index
1211 db.Model(&User{}).AddUniqueIndex("idx_user_name_age", "name", "age")
1212
1213 // Remove index
1214 db.Model(&User{}).RemoveIndex("idx_user_name")
1215 ```
1216
1217 ## Default values
1218
1219 ```go
1220 type Animal struct {
1221 ID int64
1222 Name string `sql:"default:'galeone'"`
1223 Age int64
1224 }
1225 ```
1226
1227 If you have defined a default value in the `sql` tag, the generated create SQl will ignore these fields if it is blank.
1228
1229 Eg.
1230
1231 ```go
1232 db.Create(&Animal{Age: 99, Name: ""})
1233 ```
1234
1235 The generated SQL will be:
1236
1237 ```sql
1238 INSERT INTO animals("age") values('99');
1239 ```
1240
1241 The same thing occurs in update statements.
1242
1243 ## More examples with query chain
1244
1245 ```go
1246 db.First(&first_article).Count(&total_count).Limit(10).Find(&first_page_articles).Offset(10).Find(&second_page_articles)
1247 //// SELECT * FROM articles LIMIT 1; (first_article)
1248 //// SELECT count(*) FROM articles; (total_count)
1249 //// SELECT * FROM articles LIMIT 10; (first_page_articles)
1250 //// SELECT * FROM articles LIMIT 10 OFFSET 10; (second_page_articles)
1251
1252
1253 db.Where("created_at > ?", "2013-10-10").Find(&cancelled_orders, "state = ?", "cancelled").Find(&shipped_orders, "state = ?", "shipped")
1254 //// SELECT * FROM orders WHERE created_at > '2013/10/10' AND state = 'cancelled'; (cancelled_orders)
1255 //// SELECT * FROM orders WHERE created_at > '2013/10/10' AND state = 'shipped'; (shipped_orders)
1256
1257
1258 // Use variables to keep query chain
1259 todays_orders := db.Where("created_at > ?", "2013-10-29")
1260 cancelled_orders := todays_orders.Where("state = ?", "cancelled")
1261 shipped_orders := todays_orders.Where("state = ?", "shipped")
1262
1263
1264 // Search with shared conditions for different tables
1265 db.Where("product_name = ?", "fancy_product").Find(&orders).Find(&shopping_carts)
1266 //// SELECT * FROM orders WHERE product_name = 'fancy_product'; (orders)
1267 //// SELECT * FROM carts WHERE product_name = 'fancy_product'; (shopping_carts)
1268
1269
1270 // Search with shared conditions from different tables with specified table
1271 db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").Find(&users2)
1272 //// SELECT * FROM users WHERE mail_type = 'TEXT'; (users1)
1273 //// SELECT * FROM deleted_users WHERE mail_type = 'TEXT'; (users2)
1274
1275
1276 // FirstOrCreate example
1277 db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111"}).FirstOrCreate(&user)
1278 //// SELECT * FROM users WHERE email = 'x@example.org';
1279 //// INSERT INTO "users" (email,registered_ip) VALUES ("x@example.org", "111.111.111.111") // if record not found
1280 ```
1281
1282 ## TODO
1283 * Github Pages
1284
1285 # Author
1286
1287 **jinzhu**
1288
1289 * <http://github.com/jinzhu>
1290 * <wosmvp@gmail.com>
1291 * <http://twitter.com/zhangjinzhu>
1292
1293 ## License
1294
1295 Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License).
0 package gorm
1
2 import (
3 "errors"
4 "fmt"
5 "reflect"
6 "strings"
7 )
8
9 type Association struct {
10 Scope *Scope
11 Column string
12 Error error
13 Field *Field
14 }
15
16 func (association *Association) setErr(err error) *Association {
17 if err != nil {
18 association.Error = err
19 }
20 return association
21 }
22
23 func (association *Association) Find(value interface{}) *Association {
24 association.Scope.related(value, association.Column)
25 return association.setErr(association.Scope.db.Error)
26 }
27
28 func (association *Association) Append(values ...interface{}) *Association {
29 scope := association.Scope
30 field := association.Field
31
32 createJoinTable := func(reflectValue reflect.Value) {
33 var value = reflectValue.Interface()
34 if reflectValue.Kind() != reflect.Ptr {
35 reflectPtr := reflect.New(reflectValue.Type())
36 reflectPtr.Elem().Set(reflectValue)
37 value = reflectPtr.Interface()
38 }
39
40 if scope.New(value).PrimaryKeyZero() {
41 scope.NewDB().Save(value)
42 }
43
44 relationship := association.Field.Relationship
45 association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, value))
46
47 result := reflect.ValueOf(value)
48 fieldElemType := field.Field.Type().Elem()
49 if result.Type().AssignableTo(fieldElemType) {
50 field.Set(reflect.Append(field.Field, result))
51 } else if result.Type().Elem().AssignableTo(fieldElemType) {
52 field.Set(reflect.Append(field.Field, result.Elem()))
53 }
54 }
55
56 for _, value := range values {
57 reflectValue := reflect.Indirect(reflect.ValueOf(value))
58
59 if reflectValue.Kind() == reflect.Struct {
60 createJoinTable(reflectValue)
61 } else if reflectValue.Kind() == reflect.Slice {
62 for i := 0; i < reflectValue.Len(); i++ {
63 createJoinTable(reflectValue.Index(i))
64 }
65 } else {
66 association.setErr(errors.New("invalid association type"))
67 }
68 }
69 return association
70 }
71
72 func (association *Association) Delete(values ...interface{}) *Association {
73 scope := association.Scope
74 relationship := association.Field.Relationship
75
76 // many to many
77 if relationship.Kind == "many_to_many" {
78 query := scope.NewDB()
79 for idx, foreignKey := range relationship.ForeignDBNames {
80 if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
81 query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
82 }
83 }
84
85 primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...)
86 sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys))
87 query = query.Where(sql, toQueryValues(primaryKeys)...)
88
89 if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
90 leftValues := reflect.Zero(association.Field.Field.Type())
91 for i := 0; i < association.Field.Field.Len(); i++ {
92 reflectValue := association.Field.Field.Index(i)
93 primaryKey := association.getPrimaryKeys(relationship.ForeignFieldNames, reflectValue.Interface())[0]
94 var included = false
95 for _, pk := range primaryKeys {
96 if equalAsString(primaryKey, pk) {
97 included = true
98 }
99 }
100 if !included {
101 leftValues = reflect.Append(leftValues, reflectValue)
102 }
103 }
104 association.Field.Set(leftValues)
105 }
106 } else {
107 association.setErr(errors.New("delete only support many to many"))
108 }
109 return association
110 }
111
112 func (association *Association) Replace(values ...interface{}) *Association {
113 relationship := association.Field.Relationship
114 scope := association.Scope
115 if relationship.Kind == "many_to_many" {
116 field := association.Field.Field
117
118 oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface())
119 association.Field.Set(reflect.Zero(association.Field.Field.Type()))
120 association.Append(values...)
121 newPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface())
122
123 var addedPrimaryKeys = [][]interface{}{}
124 for _, newKey := range newPrimaryKeys {
125 hasEqual := false
126 for _, oldKey := range oldPrimaryKeys {
127 if equalAsString(newKey, oldKey) {
128 hasEqual = true
129 break
130 }
131 }
132 if !hasEqual {
133 addedPrimaryKeys = append(addedPrimaryKeys, newKey)
134 }
135 }
136
137 for _, primaryKey := range association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) {
138 addedPrimaryKeys = append(addedPrimaryKeys, primaryKey)
139 }
140
141 query := scope.NewDB()
142 for idx, foreignKey := range relationship.ForeignDBNames {
143 if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
144 query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
145 }
146 }
147
148 if len(addedPrimaryKeys) > 0 {
149 sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys))
150 query = query.Where(sql, toQueryValues(addedPrimaryKeys)...)
151 }
152 association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship))
153 } else {
154 association.setErr(errors.New("replace only support many to many"))
155 }
156 return association
157 }
158
159 func (association *Association) Clear() *Association {
160 relationship := association.Field.Relationship
161 scope := association.Scope
162 if relationship.Kind == "many_to_many" {
163 query := scope.NewDB()
164 for idx, foreignKey := range relationship.ForeignDBNames {
165 if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
166 query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
167 }
168 }
169
170 if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
171 association.Field.Set(reflect.Zero(association.Field.Field.Type()))
172 } else {
173 association.setErr(err)
174 }
175 } else {
176 association.setErr(errors.New("clear only support many to many"))
177 }
178 return association
179 }
180
181 func (association *Association) Count() int {
182 count := -1
183 relationship := association.Field.Relationship
184 scope := association.Scope
185 newScope := scope.New(association.Field.Field.Interface())
186
187 if relationship.Kind == "many_to_many" {
188 relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count)
189 } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
190 query := scope.DB()
191 for idx, foreignKey := range relationship.ForeignDBNames {
192 if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
193 query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)),
194 field.Field.Interface())
195 }
196 }
197
198 if relationship.PolymorphicType != "" {
199 query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName())
200 }
201 query.Table(newScope.TableName()).Count(&count)
202 } else if relationship.Kind == "belongs_to" {
203 query := scope.DB()
204 for idx, foreignKey := range relationship.ForeignDBNames {
205 if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
206 query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)),
207 field.Field.Interface())
208 }
209 }
210 query.Table(newScope.TableName()).Count(&count)
211 }
212
213 return count
214 }
215
216 func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) [][]interface{} {
217 results := [][]interface{}{}
218 scope := association.Scope
219
220 for _, value := range values {
221 reflectValue := reflect.Indirect(reflect.ValueOf(value))
222 if reflectValue.Kind() == reflect.Slice {
223 for i := 0; i < reflectValue.Len(); i++ {
224 primaryKeys := []interface{}{}
225 newScope := scope.New(reflectValue.Index(i).Interface())
226 for _, column := range columns {
227 if field, ok := newScope.FieldByName(column); ok {
228 primaryKeys = append(primaryKeys, field.Field.Interface())
229 } else {
230 primaryKeys = append(primaryKeys, "")
231 }
232 }
233 results = append(results, primaryKeys)
234 }
235 } else if reflectValue.Kind() == reflect.Struct {
236 newScope := scope.New(value)
237 var primaryKeys []interface{}
238 for _, column := range columns {
239 if field, ok := newScope.FieldByName(column); ok {
240 primaryKeys = append(primaryKeys, field.Field.Interface())
241 } else {
242 primaryKeys = append(primaryKeys, "")
243 }
244 }
245
246 results = append(results, primaryKeys)
247 }
248 }
249 return results
250 }
251
252 func toQueryMarks(primaryValues [][]interface{}) string {
253 var results []string
254
255 for _, primaryValue := range primaryValues {
256 var marks []string
257 for _, _ = range primaryValue {
258 marks = append(marks, "?")
259 }
260
261 if len(marks) > 1 {
262 results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
263 } else {
264 results = append(results, strings.Join(marks, ""))
265 }
266 }
267 return strings.Join(results, ",")
268 }
269
270 func toQueryCondition(scope *Scope, columns []string) string {
271 var newColumns []string
272 for _, column := range columns {
273 newColumns = append(newColumns, scope.Quote(column))
274 }
275
276 if len(columns) > 1 {
277 return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
278 } else {
279 return strings.Join(columns, ",")
280 }
281 }
282
283 func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
284 for _, primaryValue := range primaryValues {
285 for _, value := range primaryValue {
286 values = append(values, value)
287 }
288 }
289 return values
290 }
0 package gorm_test
1
2 import (
3 "fmt"
4 "testing"
5 )
6
7 func TestHasOneAndHasManyAssociation(t *testing.T) {
8 DB.DropTable(Category{}, Post{}, Comment{})
9 DB.CreateTable(Category{}, Post{}, Comment{})
10
11 post := Post{
12 Title: "post 1",
13 Body: "body 1",
14 Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}},
15 Category: Category{Name: "Category 1"},
16 MainCategory: Category{Name: "Main Category 1"},
17 }
18
19 if err := DB.Save(&post).Error; err != nil {
20 t.Errorf("Got errors when save post", err.Error())
21 }
22
23 if err := DB.First(&Category{}, "name = ?", "Category 1").Error; err != nil {
24 t.Errorf("Category should be saved", err.Error())
25 }
26
27 var p Post
28 DB.First(&p, post.Id)
29
30 if post.CategoryId.Int64 == 0 || p.CategoryId.Int64 == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 {
31 t.Errorf("Category Id should exist")
32 }
33
34 if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil {
35 t.Errorf("Comment 1 should be saved")
36 }
37 if post.Comments[0].PostId == 0 {
38 t.Errorf("Comment Should have post id")
39 }
40
41 var comment Comment
42 if DB.First(&comment, "content = ?", "Comment 2").Error != nil {
43 t.Errorf("Comment 2 should be saved")
44 }
45
46 if comment.PostId == 0 {
47 t.Errorf("Comment 2 Should have post id")
48 }
49
50 comment3 := Comment{Content: "Comment 3", Post: Post{Title: "Title 3", Body: "Body 3"}}
51 DB.Save(&comment3)
52 }
53
54 func TestRelated(t *testing.T) {
55 user := User{
56 Name: "jinzhu",
57 BillingAddress: Address{Address1: "Billing Address - Address 1"},
58 ShippingAddress: Address{Address1: "Shipping Address - Address 1"},
59 Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}},
60 CreditCard: CreditCard{Number: "1234567890"},
61 Company: Company{Name: "company1"},
62 }
63
64 DB.Save(&user)
65
66 if user.CreditCard.ID == 0 {
67 t.Errorf("After user save, credit card should have id")
68 }
69
70 if user.BillingAddress.ID == 0 {
71 t.Errorf("After user save, billing address should have id")
72 }
73
74 if user.Emails[0].Id == 0 {
75 t.Errorf("After user save, billing address should have id")
76 }
77
78 var emails []Email
79 DB.Model(&user).Related(&emails)
80 if len(emails) != 2 {
81 t.Errorf("Should have two emails")
82 }
83
84 var emails2 []Email
85 DB.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2)
86 if len(emails2) != 1 {
87 t.Errorf("Should have two emails")
88 }
89
90 var user1 User
91 DB.Model(&user).Related(&user1.Emails)
92 if len(user1.Emails) != 2 {
93 t.Errorf("Should have only one email match related condition")
94 }
95
96 var address1 Address
97 DB.Model(&user).Related(&address1, "BillingAddressId")
98 if address1.Address1 != "Billing Address - Address 1" {
99 t.Errorf("Should get billing address from user correctly")
100 }
101
102 user1 = User{}
103 DB.Model(&address1).Related(&user1, "BillingAddressId")
104 if DB.NewRecord(user1) {
105 t.Errorf("Should get user from address correctly")
106 }
107
108 var user2 User
109 DB.Model(&emails[0]).Related(&user2)
110 if user2.Id != user.Id || user2.Name != user.Name {
111 t.Errorf("Should get user from email correctly")
112 }
113
114 var creditcard CreditCard
115 var user3 User
116 DB.First(&creditcard, "number = ?", "1234567890")
117 DB.Model(&creditcard).Related(&user3)
118 if user3.Id != user.Id || user3.Name != user.Name {
119 t.Errorf("Should get user from credit card correctly")
120 }
121
122 if !DB.Model(&CreditCard{}).Related(&User{}).RecordNotFound() {
123 t.Errorf("RecordNotFound for Related")
124 }
125
126 var company Company
127 if DB.Model(&user).Related(&company, "Company").RecordNotFound() || company.Name != "company1" {
128 t.Errorf("RecordNotFound for Related")
129 }
130 }
131
132 func TestManyToMany(t *testing.T) {
133 DB.Raw("delete from languages")
134 var languages = []Language{{Name: "ZH"}, {Name: "EN"}}
135 user := User{Name: "Many2Many", Languages: languages}
136 DB.Save(&user)
137
138 // Query
139 var newLanguages []Language
140 DB.Model(&user).Related(&newLanguages, "Languages")
141 if len(newLanguages) != len([]string{"ZH", "EN"}) {
142 t.Errorf("Query many to many relations")
143 }
144
145 DB.Model(&user).Association("Languages").Find(&newLanguages)
146 if len(newLanguages) != len([]string{"ZH", "EN"}) {
147 t.Errorf("Should be able to find many to many relations")
148 }
149
150 if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) {
151 t.Errorf("Count should return correct result")
152 }
153
154 // Append
155 DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"})
156 if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() {
157 t.Errorf("New record should be saved when append")
158 }
159
160 languageA := Language{Name: "AA"}
161 DB.Save(&languageA)
162 DB.Model(&User{Id: user.Id}).Association("Languages").Append(&languageA)
163
164 languageC := Language{Name: "CC"}
165 DB.Save(&languageC)
166 DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC})
167
168 DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}})
169
170 totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"}
171
172 if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) {
173 t.Errorf("All appended languages should be saved")
174 }
175
176 // Delete
177 user.Languages = []Language{}
178 DB.Model(&user).Association("Languages").Find(&user.Languages)
179
180 var language Language
181 DB.Where("name = ?", "EE").First(&language)
182 DB.Model(&user).Association("Languages").Delete(language, &language)
183
184 if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 {
185 t.Errorf("Relations should be deleted with Delete")
186 }
187 if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() {
188 t.Errorf("Language EE should not be deleted")
189 }
190
191 DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages)
192
193 user2 := User{Name: "Many2Many_User2", Languages: languages}
194 DB.Save(&user2)
195
196 DB.Model(&user).Association("Languages").Delete(languages, &languages)
197 if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 {
198 t.Errorf("Relations should be deleted with Delete")
199 }
200
201 if DB.Model(&user2).Association("Languages").Count() == 0 {
202 t.Errorf("Other user's relations should not be deleted")
203 }
204
205 // Replace
206 var languageB Language
207 DB.Where("name = ?", "BB").First(&languageB)
208 DB.Model(&user).Association("Languages").Replace(languageB)
209 if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 {
210 t.Errorf("Relations should be replaced")
211 }
212
213 DB.Model(&user).Association("Languages").Replace()
214 if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 {
215 t.Errorf("Relations should be replaced with empty")
216 }
217
218 DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}})
219 if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) {
220 t.Errorf("Relations should be replaced")
221 }
222
223 // Clear
224 DB.Model(&user).Association("Languages").Clear()
225 if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 {
226 t.Errorf("Relations should be cleared")
227 }
228 }
229
230 func TestForeignKey(t *testing.T) {
231 for _, structField := range DB.NewScope(&User{}).GetStructFields() {
232 for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} {
233 if structField.Name == foreignKey && !structField.IsForeignKey {
234 t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
235 }
236 }
237 }
238
239 for _, structField := range DB.NewScope(&Email{}).GetStructFields() {
240 for _, foreignKey := range []string{"UserId"} {
241 if structField.Name == foreignKey && !structField.IsForeignKey {
242 t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
243 }
244 }
245 }
246
247 for _, structField := range DB.NewScope(&Post{}).GetStructFields() {
248 for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} {
249 if structField.Name == foreignKey && !structField.IsForeignKey {
250 t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
251 }
252 }
253 }
254
255 for _, structField := range DB.NewScope(&Comment{}).GetStructFields() {
256 for _, foreignKey := range []string{"PostId"} {
257 if structField.Name == foreignKey && !structField.IsForeignKey {
258 t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
259 }
260 }
261 }
262 }
0 package gorm
1
2 import (
3 "fmt"
4 )
5
6 type callback struct {
7 creates []*func(scope *Scope)
8 updates []*func(scope *Scope)
9 deletes []*func(scope *Scope)
10 queries []*func(scope *Scope)
11 rowQueries []*func(scope *Scope)
12 processors []*callbackProcessor
13 }
14
15 type callbackProcessor struct {
16 name string
17 before string
18 after string
19 replace bool
20 remove bool
21 typ string
22 processor *func(scope *Scope)
23 callback *callback
24 }
25
26 func (c *callback) addProcessor(typ string) *callbackProcessor {
27 cp := &callbackProcessor{typ: typ, callback: c}
28 c.processors = append(c.processors, cp)
29 return cp
30 }
31
32 func (c *callback) clone() *callback {
33 return &callback{
34 creates: c.creates,
35 updates: c.updates,
36 deletes: c.deletes,
37 queries: c.queries,
38 processors: c.processors,
39 }
40 }
41
42 func (c *callback) Create() *callbackProcessor {
43 return c.addProcessor("create")
44 }
45
46 func (c *callback) Update() *callbackProcessor {
47 return c.addProcessor("update")
48 }
49
50 func (c *callback) Delete() *callbackProcessor {
51 return c.addProcessor("delete")
52 }
53
54 func (c *callback) Query() *callbackProcessor {
55 return c.addProcessor("query")
56 }
57
58 func (c *callback) RowQuery() *callbackProcessor {
59 return c.addProcessor("row_query")
60 }
61
62 func (cp *callbackProcessor) Before(name string) *callbackProcessor {
63 cp.before = name
64 return cp
65 }
66
67 func (cp *callbackProcessor) After(name string) *callbackProcessor {
68 cp.after = name
69 return cp
70 }
71
72 func (cp *callbackProcessor) Register(name string, fc func(scope *Scope)) {
73 cp.name = name
74 cp.processor = &fc
75 cp.callback.sort()
76 }
77
78 func (cp *callbackProcessor) Remove(name string) {
79 fmt.Printf("[info] removing callback `%v` from %v\n", name, fileWithLineNum())
80 cp.name = name
81 cp.remove = true
82 cp.callback.sort()
83 }
84
85 func (cp *callbackProcessor) Replace(name string, fc func(scope *Scope)) {
86 fmt.Printf("[info] replacing callback `%v` from %v\n", name, fileWithLineNum())
87 cp.name = name
88 cp.processor = &fc
89 cp.replace = true
90 cp.callback.sort()
91 }
92
93 func getRIndex(strs []string, str string) int {
94 for i := len(strs) - 1; i >= 0; i-- {
95 if strs[i] == str {
96 return i
97 }
98 }
99 return -1
100 }
101
102 func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
103 var sortCallbackProcessor func(c *callbackProcessor)
104 var names, sortedNames = []string{}, []string{}
105
106 for _, cp := range cps {
107 if index := getRIndex(names, cp.name); index > -1 {
108 if !cp.replace && !cp.remove {
109 fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
110 }
111 }
112 names = append(names, cp.name)
113 }
114
115 sortCallbackProcessor = func(c *callbackProcessor) {
116 if getRIndex(sortedNames, c.name) > -1 {
117 return
118 }
119
120 if len(c.before) > 0 {
121 if index := getRIndex(sortedNames, c.before); index > -1 {
122 sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
123 } else if index := getRIndex(names, c.before); index > -1 {
124 sortedNames = append(sortedNames, c.name)
125 sortCallbackProcessor(cps[index])
126 } else {
127 sortedNames = append(sortedNames, c.name)
128 }
129 }
130
131 if len(c.after) > 0 {
132 if index := getRIndex(sortedNames, c.after); index > -1 {
133 sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
134 } else if index := getRIndex(names, c.after); index > -1 {
135 cp := cps[index]
136 if len(cp.before) == 0 {
137 cp.before = c.name
138 }
139 sortCallbackProcessor(cp)
140 } else {
141 sortedNames = append(sortedNames, c.name)
142 }
143 }
144
145 if getRIndex(sortedNames, c.name) == -1 {
146 sortedNames = append(sortedNames, c.name)
147 }
148 }
149
150 for _, cp := range cps {
151 sortCallbackProcessor(cp)
152 }
153
154 var funcs = []*func(scope *Scope){}
155 var sortedFuncs = []*func(scope *Scope){}
156 for _, name := range sortedNames {
157 index := getRIndex(names, name)
158 if !cps[index].remove {
159 sortedFuncs = append(sortedFuncs, cps[index].processor)
160 }
161 }
162
163 for _, cp := range cps {
164 if sindex := getRIndex(sortedNames, cp.name); sindex == -1 {
165 if !cp.remove {
166 funcs = append(funcs, cp.processor)
167 }
168 }
169 }
170
171 return append(sortedFuncs, funcs...)
172 }
173
174 func (c *callback) sort() {
175 var creates, updates, deletes, queries, rowQueries []*callbackProcessor
176
177 for _, processor := range c.processors {
178 switch processor.typ {
179 case "create":
180 creates = append(creates, processor)
181 case "update":
182 updates = append(updates, processor)
183 case "delete":
184 deletes = append(deletes, processor)
185 case "query":
186 queries = append(queries, processor)
187 case "row_query":
188 rowQueries = append(rowQueries, processor)
189 }
190 }
191
192 c.creates = sortProcessors(creates)
193 c.updates = sortProcessors(updates)
194 c.deletes = sortProcessors(deletes)
195 c.queries = sortProcessors(queries)
196 c.rowQueries = sortProcessors(rowQueries)
197 }
198
199 var DefaultCallback = &callback{processors: []*callbackProcessor{}}
0 package gorm
1
2 import (
3 "fmt"
4 "strings"
5 )
6
7 func BeforeCreate(scope *Scope) {
8 scope.CallMethodWithErrorCheck("BeforeSave")
9 scope.CallMethodWithErrorCheck("BeforeCreate")
10 }
11
12 func UpdateTimeStampWhenCreate(scope *Scope) {
13 if !scope.HasError() {
14 now := NowFunc()
15 scope.SetColumn("CreatedAt", now)
16 scope.SetColumn("UpdatedAt", now)
17 }
18 }
19
20 func Create(scope *Scope) {
21 defer scope.Trace(NowFunc())
22
23 if !scope.HasError() {
24 // set create sql
25 var sqls, columns []string
26 fields := scope.Fields()
27 for _, field := range fields {
28 if scope.changeableField(field) {
29 if field.IsNormal {
30 if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
31 if !field.IsBlank || !field.HasDefaultValue {
32 columns = append(columns, scope.Quote(field.DBName))
33 sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
34 } else if field.HasDefaultValue {
35 scope.InstanceSet("gorm:force_reload_after_create", true)
36 }
37 }
38 } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
39 for _, dbName := range relationship.ForeignDBNames {
40 if relationField := fields[dbName]; !scope.changeableField(relationField) {
41 columns = append(columns, scope.Quote(relationField.DBName))
42 sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
43 }
44 }
45 }
46 }
47 }
48
49 returningKey := "*"
50 primaryField := scope.PrimaryField()
51 if primaryField != nil {
52 returningKey = scope.Quote(primaryField.DBName)
53 }
54
55 if len(columns) == 0 {
56 scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v",
57 scope.QuotedTableName(),
58 scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
59 ))
60 } else {
61 scope.Raw(fmt.Sprintf(
62 "INSERT INTO %v (%v) VALUES (%v) %v",
63 scope.QuotedTableName(),
64 strings.Join(columns, ","),
65 strings.Join(sqls, ","),
66 scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
67 ))
68 }
69
70 // execute create sql
71 if scope.Dialect().SupportLastInsertId() {
72 if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
73 id, err := result.LastInsertId()
74 if scope.Err(err) == nil {
75 scope.db.RowsAffected, _ = result.RowsAffected()
76 if primaryField != nil && primaryField.IsBlank {
77 scope.Err(scope.SetColumn(primaryField, id))
78 }
79 }
80 }
81 } else {
82 if primaryField == nil {
83 if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
84 scope.db.RowsAffected, _ = results.RowsAffected()
85 } else {
86 scope.Err(err)
87 }
88 } else {
89 if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
90 scope.db.RowsAffected = 1
91 } else {
92 scope.Err(err)
93 }
94 }
95 }
96 }
97 }
98
99 func ForceReloadAfterCreate(scope *Scope) {
100 if _, ok := scope.InstanceGet("gorm:force_reload_after_create"); ok {
101 scope.DB().New().First(scope.Value)
102 }
103 }
104
105 func AfterCreate(scope *Scope) {
106 scope.CallMethodWithErrorCheck("AfterCreate")
107 scope.CallMethodWithErrorCheck("AfterSave")
108 }
109
110 func init() {
111 DefaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction)
112 DefaultCallback.Create().Register("gorm:before_create", BeforeCreate)
113 DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
114 DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
115 DefaultCallback.Create().Register("gorm:create", Create)
116 DefaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate)
117 DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
118 DefaultCallback.Create().Register("gorm:after_create", AfterCreate)
119 DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
120 }
0 package gorm
1
2 import "fmt"
3
4 func BeforeDelete(scope *Scope) {
5 scope.CallMethodWithErrorCheck("BeforeDelete")
6 }
7
8 func Delete(scope *Scope) {
9 if !scope.HasError() {
10 if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") {
11 scope.Raw(
12 fmt.Sprintf("UPDATE %v SET deleted_at=%v %v",
13 scope.QuotedTableName(),
14 scope.AddToVars(NowFunc()),
15 scope.CombinedConditionSql(),
16 ))
17 } else {
18 scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.QuotedTableName(), scope.CombinedConditionSql()))
19 }
20
21 scope.Exec()
22 }
23 }
24
25 func AfterDelete(scope *Scope) {
26 scope.CallMethodWithErrorCheck("AfterDelete")
27 }
28
29 func init() {
30 DefaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction)
31 DefaultCallback.Delete().Register("gorm:before_delete", BeforeDelete)
32 DefaultCallback.Delete().Register("gorm:delete", Delete)
33 DefaultCallback.Delete().Register("gorm:after_delete", AfterDelete)
34 DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
35 }
0 package gorm
1
2 import (
3 "errors"
4 "fmt"
5 "reflect"
6 )
7
8 func Query(scope *Scope) {
9 defer scope.Trace(NowFunc())
10
11 var (
12 isSlice bool
13 isPtr bool
14 anyRecordFound bool
15 destType reflect.Type
16 )
17
18 if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
19 if primaryKey := scope.PrimaryKey(); primaryKey != "" {
20 scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryKey), orderBy))
21 }
22 }
23
24 var dest = scope.IndirectValue()
25 if value, ok := scope.Get("gorm:query_destination"); ok {
26 dest = reflect.Indirect(reflect.ValueOf(value))
27 }
28
29 if kind := dest.Kind(); kind == reflect.Slice {
30 isSlice = true
31 destType = dest.Type().Elem()
32 dest.Set(reflect.MakeSlice(dest.Type(), 0, 0))
33
34 if destType.Kind() == reflect.Ptr {
35 isPtr = true
36 destType = destType.Elem()
37 }
38 } else if kind != reflect.Struct {
39 scope.Err(errors.New("unsupported destination, should be slice or struct"))
40 return
41 }
42
43 scope.prepareQuerySql()
44
45 if !scope.HasError() {
46 rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
47 scope.db.RowsAffected = 0
48
49 if scope.Err(err) != nil {
50 return
51 }
52 defer rows.Close()
53
54 columns, _ := rows.Columns()
55 for rows.Next() {
56 scope.db.RowsAffected++
57
58 anyRecordFound = true
59 elem := dest
60 if isSlice {
61 elem = reflect.New(destType).Elem()
62 }
63
64 var values = make([]interface{}, len(columns))
65
66 fields := scope.New(elem.Addr().Interface()).Fields()
67
68 for index, column := range columns {
69 if field, ok := fields[column]; ok {
70 if field.Field.Kind() == reflect.Ptr {
71 values[index] = field.Field.Addr().Interface()
72 } else {
73 values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
74 }
75 } else {
76 var value interface{}
77 values[index] = &value
78 }
79 }
80
81 scope.Err(rows.Scan(values...))
82
83 for index, column := range columns {
84 value := values[index]
85 if field, ok := fields[column]; ok {
86 if field.Field.Kind() == reflect.Ptr {
87 field.Field.Set(reflect.ValueOf(value).Elem())
88 } else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
89 field.Field.Set(v)
90 }
91 }
92 }
93
94 if isSlice {
95 if isPtr {
96 dest.Set(reflect.Append(dest, elem.Addr()))
97 } else {
98 dest.Set(reflect.Append(dest, elem))
99 }
100 }
101 }
102
103 if !anyRecordFound && !isSlice {
104 scope.Err(RecordNotFound)
105 }
106 }
107 }
108
109 func AfterQuery(scope *Scope) {
110 scope.CallMethodWithErrorCheck("AfterFind")
111 }
112
113 func init() {
114 DefaultCallback.Query().Register("gorm:query", Query)
115 DefaultCallback.Query().Register("gorm:after_query", AfterQuery)
116 DefaultCallback.Query().Register("gorm:preload", Preload)
117 }
0 package gorm
1
2 import "reflect"
3
4 func BeginTransaction(scope *Scope) {
5 scope.Begin()
6 }
7
8 func CommitOrRollbackTransaction(scope *Scope) {
9 scope.CommitOrRollback()
10 }
11
12 func SaveBeforeAssociations(scope *Scope) {
13 if !scope.shouldSaveAssociations() {
14 return
15 }
16 for _, field := range scope.Fields() {
17 if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
18 if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
19 value := field.Field
20 scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
21 if len(relationship.ForeignFieldNames) != 0 {
22 for idx, fieldName := range relationship.ForeignFieldNames {
23 associationForeignName := relationship.AssociationForeignDBNames[idx]
24 if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok {
25 scope.Err(scope.SetColumn(fieldName, f.Field.Interface()))
26 }
27 }
28 }
29 }
30 }
31 }
32 }
33
34 func SaveAfterAssociations(scope *Scope) {
35 if !scope.shouldSaveAssociations() {
36 return
37 }
38 for _, field := range scope.Fields() {
39 if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
40 if relationship := field.Relationship; relationship != nil &&
41 (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
42 value := field.Field
43
44 switch value.Kind() {
45 case reflect.Slice:
46 for i := 0; i < value.Len(); i++ {
47 newDB := scope.NewDB()
48 elem := value.Index(i).Addr().Interface()
49 newScope := newDB.NewScope(elem)
50
51 if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
52 for idx, fieldName := range relationship.ForeignFieldNames {
53 associationForeignName := relationship.AssociationForeignDBNames[idx]
54 if f, ok := scope.FieldByName(associationForeignName); ok {
55 scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
56 }
57 }
58 }
59
60 if relationship.PolymorphicType != "" {
61 scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName()))
62 }
63
64 scope.Err(newDB.Save(elem).Error)
65
66 if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
67 scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value))
68 }
69 }
70 default:
71 elem := value.Addr().Interface()
72 newScope := scope.New(elem)
73 if len(relationship.ForeignFieldNames) != 0 {
74 for idx, fieldName := range relationship.ForeignFieldNames {
75 associationForeignName := relationship.AssociationForeignDBNames[idx]
76 if f, ok := scope.FieldByName(associationForeignName); ok {
77 scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
78 }
79 }
80 }
81
82 if relationship.PolymorphicType != "" {
83 scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName()))
84 }
85 scope.Err(scope.NewDB().Save(elem).Error)
86 }
87 }
88 }
89 }
90 }
0 package gorm
1
2 import (
3 "reflect"
4 "runtime"
5 "strings"
6 "testing"
7 )
8
9 func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
10 var names []string
11 for _, f := range funcs {
12 fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
13 names = append(names, fnames[len(fnames)-1])
14 }
15 return reflect.DeepEqual(names, fnames)
16 }
17
18 func create(s *Scope) {}
19 func beforeCreate1(s *Scope) {}
20 func beforeCreate2(s *Scope) {}
21 func afterCreate1(s *Scope) {}
22 func afterCreate2(s *Scope) {}
23
24 func TestRegisterCallback(t *testing.T) {
25 var callback = &callback{processors: []*callbackProcessor{}}
26
27 callback.Create().Register("before_create1", beforeCreate1)
28 callback.Create().Register("before_create2", beforeCreate2)
29 callback.Create().Register("create", create)
30 callback.Create().Register("after_create1", afterCreate1)
31 callback.Create().Register("after_create2", afterCreate2)
32
33 if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
34 t.Errorf("register callback")
35 }
36 }
37
38 func TestRegisterCallbackWithOrder(t *testing.T) {
39 var callback1 = &callback{processors: []*callbackProcessor{}}
40 callback1.Create().Register("before_create1", beforeCreate1)
41 callback1.Create().Register("create", create)
42 callback1.Create().Register("after_create1", afterCreate1)
43 callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
44 if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
45 t.Errorf("register callback with order")
46 }
47
48 var callback2 = &callback{processors: []*callbackProcessor{}}
49
50 callback2.Update().Register("create", create)
51 callback2.Update().Before("create").Register("before_create1", beforeCreate1)
52 callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
53 callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
54 callback2.Update().Register("after_create2", afterCreate2)
55
56 if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
57 t.Errorf("register callback with order")
58 }
59 }
60
61 func TestRegisterCallbackWithComplexOrder(t *testing.T) {
62 var callback1 = &callback{processors: []*callbackProcessor{}}
63
64 callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
65 callback1.Query().Register("before_create1", beforeCreate1)
66 callback1.Query().Register("after_create1", afterCreate1)
67
68 if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
69 t.Errorf("register callback with order")
70 }
71
72 var callback2 = &callback{processors: []*callbackProcessor{}}
73
74 callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
75 callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
76 callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
77 callback2.Delete().Register("after_create1", afterCreate1)
78 callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
79
80 if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
81 t.Errorf("register callback with order")
82 }
83 }
84
85 func replaceCreate(s *Scope) {}
86
87 func TestReplaceCallback(t *testing.T) {
88 var callback = &callback{processors: []*callbackProcessor{}}
89
90 callback.Create().Before("after_create1").After("before_create1").Register("create", create)
91 callback.Create().Register("before_create1", beforeCreate1)
92 callback.Create().Register("after_create1", afterCreate1)
93 callback.Create().Replace("create", replaceCreate)
94
95 if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
96 t.Errorf("replace callback")
97 }
98 }
99
100 func TestRemoveCallback(t *testing.T) {
101 var callback = &callback{processors: []*callbackProcessor{}}
102
103 callback.Create().Before("after_create1").After("before_create1").Register("create", create)
104 callback.Create().Register("before_create1", beforeCreate1)
105 callback.Create().Register("after_create1", afterCreate1)
106 callback.Create().Remove("create")
107
108 if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
109 t.Errorf("remove callback")
110 }
111 }
0 package gorm
1
2 import (
3 "fmt"
4 "strings"
5 )
6
7 func AssignUpdateAttributes(scope *Scope) {
8 if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
9 if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
10 protected, ok := scope.Get("gorm:ignore_protected_attrs")
11 _, updateColumn := scope.Get("gorm:update_column")
12 updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool))
13
14 if updateColumn {
15 scope.InstanceSet("gorm:update_attrs", maps)
16 } else if len(updateAttrs) > 0 {
17 scope.InstanceSet("gorm:update_attrs", updateAttrs)
18 } else if !hasUpdate {
19 scope.SkipLeft()
20 return
21 }
22 }
23 }
24 }
25
26 func BeforeUpdate(scope *Scope) {
27 if _, ok := scope.Get("gorm:update_column"); !ok {
28 scope.CallMethodWithErrorCheck("BeforeSave")
29 scope.CallMethodWithErrorCheck("BeforeUpdate")
30 }
31 }
32
33 func UpdateTimeStampWhenUpdate(scope *Scope) {
34 if _, ok := scope.Get("gorm:update_column"); !ok {
35 scope.SetColumn("UpdatedAt", NowFunc())
36 }
37 }
38
39 func Update(scope *Scope) {
40 if !scope.HasError() {
41 var sqls []string
42
43 if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
44 for key, value := range updateAttrs.(map[string]interface{}) {
45 if scope.changeableDBColumn(key) {
46 sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
47 }
48 }
49 } else {
50 fields := scope.Fields()
51 for _, field := range fields {
52 if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal {
53 sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
54 } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
55 for _, dbName := range relationship.ForeignDBNames {
56 if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank {
57 sql := fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface()))
58 sqls = append(sqls, sql)
59 }
60 }
61 }
62 }
63 }
64
65 if len(sqls) > 0 {
66 scope.Raw(fmt.Sprintf(
67 "UPDATE %v SET %v %v",
68 scope.QuotedTableName(),
69 strings.Join(sqls, ", "),
70 scope.CombinedConditionSql(),
71 ))
72 scope.Exec()
73 }
74 }
75 }
76
77 func AfterUpdate(scope *Scope) {
78 if _, ok := scope.Get("gorm:update_column"); !ok {
79 scope.CallMethodWithErrorCheck("AfterUpdate")
80 scope.CallMethodWithErrorCheck("AfterSave")
81 }
82 }
83
84 func init() {
85 DefaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes)
86 DefaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction)
87 DefaultCallback.Update().Register("gorm:before_update", BeforeUpdate)
88 DefaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations)
89 DefaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate)
90 DefaultCallback.Update().Register("gorm:update", Update)
91 DefaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations)
92 DefaultCallback.Update().Register("gorm:after_update", AfterUpdate)
93 DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
94 }
0 package gorm_test
1
2 import (
3 "errors"
4
5 "github.com/jinzhu/gorm"
6
7 "reflect"
8 "testing"
9 )
10
11 func (s *Product) BeforeCreate() (err error) {
12 if s.Code == "Invalid" {
13 err = errors.New("invalid product")
14 }
15 s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
16 return
17 }
18
19 func (s *Product) BeforeUpdate() (err error) {
20 if s.Code == "dont_update" {
21 err = errors.New("can't update")
22 }
23 s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
24 return
25 }
26
27 func (s *Product) BeforeSave() (err error) {
28 if s.Code == "dont_save" {
29 err = errors.New("can't save")
30 }
31 s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
32 return
33 }
34
35 func (s *Product) AfterFind() {
36 s.AfterFindCallTimes = s.AfterFindCallTimes + 1
37 }
38
39 func (s *Product) AfterCreate(tx *gorm.DB) {
40 tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1})
41 }
42
43 func (s *Product) AfterUpdate() {
44 s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
45 }
46
47 func (s *Product) AfterSave() (err error) {
48 if s.Code == "after_save_error" {
49 err = errors.New("can't save")
50 }
51 s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
52 return
53 }
54
55 func (s *Product) BeforeDelete() (err error) {
56 if s.Code == "dont_delete" {
57 err = errors.New("can't delete")
58 }
59 s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
60 return
61 }
62
63 func (s *Product) AfterDelete() (err error) {
64 if s.Code == "after_delete_error" {
65 err = errors.New("can't delete")
66 }
67 s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
68 return
69 }
70
71 func (s *Product) GetCallTimes() []int64 {
72 return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
73 }
74
75 func TestRunCallbacks(t *testing.T) {
76 p := Product{Code: "unique_code", Price: 100}
77 DB.Save(&p)
78
79 if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
80 t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
81 }
82
83 DB.Where("Code = ?", "unique_code").First(&p)
84 if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
85 t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes())
86 }
87
88 p.Price = 200
89 DB.Save(&p)
90 if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
91 t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
92 }
93
94 var products []Product
95 DB.Find(&products, "code = ?", "unique_code")
96 if products[0].AfterFindCallTimes != 2 {
97 t.Errorf("AfterFind callbacks should work with slice")
98 }
99
100 DB.Where("Code = ?", "unique_code").First(&p)
101 if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
102 t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes())
103 }
104
105 DB.Delete(&p)
106 if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
107 t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
108 }
109
110 if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
111 t.Errorf("Can't find a deleted record")
112 }
113 }
114
115 func TestCallbacksWithErrors(t *testing.T) {
116 p := Product{Code: "Invalid", Price: 100}
117 if DB.Save(&p).Error == nil {
118 t.Errorf("An error from before create callbacks happened when create with invalid value")
119 }
120
121 if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
122 t.Errorf("Should not save record that have errors")
123 }
124
125 if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
126 t.Errorf("An error from after create callbacks happened when create with invalid value")
127 }
128
129 p2 := Product{Code: "update_callback", Price: 100}
130 DB.Save(&p2)
131
132 p2.Code = "dont_update"
133 if DB.Save(&p2).Error == nil {
134 t.Errorf("An error from before update callbacks happened when update with invalid value")
135 }
136
137 if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
138 t.Errorf("Record Should not be updated due to errors happened in before update callback")
139 }
140
141 if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
142 t.Errorf("Record Should not be updated due to errors happened in before update callback")
143 }
144
145 p2.Code = "dont_save"
146 if DB.Save(&p2).Error == nil {
147 t.Errorf("An error from before save callbacks happened when update with invalid value")
148 }
149
150 p3 := Product{Code: "dont_delete", Price: 100}
151 DB.Save(&p3)
152 if DB.Delete(&p3).Error == nil {
153 t.Errorf("An error from before delete callbacks happened when delete")
154 }
155
156 if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
157 t.Errorf("An error from before delete callbacks happened")
158 }
159
160 p4 := Product{Code: "after_save_error", Price: 100}
161 DB.Save(&p4)
162 if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
163 t.Errorf("Record should be reverted if get an error in after save callback")
164 }
165
166 p5 := Product{Code: "after_delete_error", Price: 100}
167 DB.Save(&p5)
168 if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
169 t.Errorf("Record should be found")
170 }
171
172 DB.Delete(&p5)
173 if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
174 t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
175 }
176 }
0 package gorm
1
2 import (
3 "fmt"
4 "reflect"
5 "time"
6 )
7
8 type commonDialect struct{}
9
10 func (commonDialect) BinVar(i int) string {
11 return "$$" // ?
12 }
13
14 func (commonDialect) SupportLastInsertId() bool {
15 return true
16 }
17
18 func (commonDialect) HasTop() bool {
19 return false
20 }
21
22 func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
23 switch value.Kind() {
24 case reflect.Bool:
25 return "BOOLEAN"
26 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
27 if autoIncrease {
28 return "INTEGER AUTO_INCREMENT"
29 }
30 return "INTEGER"
31 case reflect.Int64, reflect.Uint64:
32 if autoIncrease {
33 return "BIGINT AUTO_INCREMENT"
34 }
35 return "BIGINT"
36 case reflect.Float32, reflect.Float64:
37 return "FLOAT"
38 case reflect.String:
39 if size > 0 && size < 65532 {
40 return fmt.Sprintf("VARCHAR(%d)", size)
41 }
42 return "VARCHAR(65532)"
43 case reflect.Struct:
44 if _, ok := value.Interface().(time.Time); ok {
45 return "TIMESTAMP"
46 }
47 default:
48 if _, ok := value.Interface().([]byte); ok {
49 if size > 0 && size < 65532 {
50 return fmt.Sprintf("BINARY(%d)", size)
51 }
52 return "BINARY(65532)"
53 }
54 }
55 panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
56 }
57
58 func (commonDialect) ReturningStr(tableName, key string) string {
59 return ""
60 }
61
62 func (commonDialect) SelectFromDummyTable() string {
63 return ""
64 }
65
66 func (commonDialect) Quote(key string) string {
67 return fmt.Sprintf(`"%s"`, key)
68 }
69
70 func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
71 var (
72 count int
73 databaseName = c.CurrentDatabase(scope)
74 )
75 c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", databaseName, tableName)
76 return count > 0
77 }
78
79 func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
80 var (
81 count int
82 databaseName = c.CurrentDatabase(scope)
83 )
84 c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
85 return count > 0
86 }
87
88 func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
89 var (
90 count int
91 databaseName = c.CurrentDatabase(scope)
92 )
93 c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
94 return count > 0
95 }
96
97 func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
98 scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error)
99 }
100
101 // RawScanInt scans the first column of the first row into the `scan' int pointer.
102 // This function captures raw query errors and propagates them to the original scope.
103 func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) {
104 scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
105 }
106
107 // RawScanString scans the first column of the first row into the `scan' string pointer.
108 // This function captures raw query errors and propagates them to the original scope.
109 func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string, args ...interface{}) {
110 scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
111 }
112
113 func (commonDialect) CurrentDatabase(scope *Scope) (name string) {
114 scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
115 return
116 }
0 package gorm_test
1
2 import (
3 "reflect"
4 "testing"
5 "time"
6 )
7
8 func TestCreate(t *testing.T) {
9 float := 35.03554004971999
10 user := User{Name: "CreateUser", Age: 18, Birthday: time.Now(), UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float}
11
12 if !DB.NewRecord(user) || !DB.NewRecord(&user) {
13 t.Error("User should be new record before create")
14 }
15
16 if count := DB.Save(&user).RowsAffected; count != 1 {
17 t.Error("There should be one record be affected when create record")
18 }
19
20 if DB.NewRecord(user) || DB.NewRecord(&user) {
21 t.Error("User should not new record after save")
22 }
23
24 var newUser User
25 DB.First(&newUser, user.Id)
26
27 if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) {
28 t.Errorf("User's PasswordHash should be saved ([]byte)")
29 }
30
31 if newUser.Age != 18 {
32 t.Errorf("User's Age should be saved (int)")
33 }
34
35 if newUser.UserNum != Num(111) {
36 t.Errorf("User's UserNum should be saved (custom type)")
37 }
38
39 if newUser.Latitude != float {
40 t.Errorf("Float64 should not be changed after save")
41 }
42
43 if user.CreatedAt.IsZero() {
44 t.Errorf("Should have created_at after create")
45 }
46
47 if newUser.CreatedAt.IsZero() {
48 t.Errorf("Should have created_at after create")
49 }
50
51 DB.Model(user).Update("name", "create_user_new_name")
52 DB.First(&user, user.Id)
53 if user.CreatedAt != newUser.CreatedAt {
54 t.Errorf("CreatedAt should not be changed after update")
55 }
56 }
57
58 func TestCreateWithNoGORMPrimayKey(t *testing.T) {
59 jt := JoinTable{From: 1, To: 2}
60 err := DB.Create(&jt).Error
61 if err != nil {
62 t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err)
63 }
64 }
65
66 func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
67 animal := Animal{Name: "Ferdinand"}
68 if DB.Save(&animal).Error != nil {
69 t.Errorf("No error should happen when create a record without std primary key")
70 }
71
72 if animal.Counter == 0 {
73 t.Errorf("No std primary key should be filled value after create")
74 }
75
76 if animal.Name != "Ferdinand" {
77 t.Errorf("Default value should be overrided")
78 }
79
80 // Test create with default value not overrided
81 an := Animal{From: "nerdz"}
82
83 if DB.Save(&an).Error != nil {
84 t.Errorf("No error should happen when create an record without std primary key")
85 }
86
87 // We must fetch the value again, to have the default fields updated
88 // (We can't do this in the update statements, since sql default can be expressions
89 // And be different from the fields' type (eg. a time.Time fiels has a default value of "now()"
90 DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an)
91
92 if an.Name != "galeone" {
93 t.Errorf("Default value should fill the field. But got %v", an.Name)
94 }
95 }
96
97 func TestAnonymousScanner(t *testing.T) {
98 user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}}
99 DB.Save(&user)
100
101 var user2 User
102 DB.First(&user2, "name = ?", "anonymous_scanner")
103 if user2.Role.Name != "admin" {
104 t.Errorf("Should be able to get anonymous scanner")
105 }
106
107 if !user2.IsAdmin() {
108 t.Errorf("Should be able to get anonymous scanner")
109 }
110 }
111
112 func TestAnonymousField(t *testing.T) {
113 user := User{Name: "anonymous_field", Company: Company{Name: "company"}}
114 DB.Save(&user)
115
116 var user2 User
117 DB.First(&user2, "name = ?", "anonymous_field")
118 DB.Model(&user2).Related(&user2.Company)
119 if user2.Company.Name != "company" {
120 t.Errorf("Should be able to get anonymous field")
121 }
122 }
123
124 func TestSelectWithCreate(t *testing.T) {
125 user := getPreparedUser("select_user", "select_with_create")
126 DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user)
127
128 var queryuser User
129 DB.Preload("BillingAddress").Preload("ShippingAddress").
130 Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id)
131
132 if queryuser.Name != user.Name || queryuser.Age == user.Age {
133 t.Errorf("Should only create users with name column")
134 }
135
136 if queryuser.BillingAddressID.Int64 == 0 || queryuser.ShippingAddressId != 0 ||
137 queryuser.CreditCard.ID == 0 || len(queryuser.Emails) == 0 {
138 t.Errorf("Should only create selected relationships")
139 }
140 }
141
142 func TestOmitWithCreate(t *testing.T) {
143 user := getPreparedUser("omit_user", "omit_with_create")
144 DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user)
145
146 var queryuser User
147 DB.Preload("BillingAddress").Preload("ShippingAddress").
148 Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id)
149
150 if queryuser.Name == user.Name || queryuser.Age != user.Age {
151 t.Errorf("Should only create users with age column")
152 }
153
154 if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 ||
155 queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 {
156 t.Errorf("Should not create omited relationships")
157 }
158 }
0 package gorm_test
1
2 import (
3 "testing"
4 "time"
5 )
6
7 type CustomizeColumn struct {
8 ID int64 `gorm:"column:mapped_id; primary_key:yes"`
9 Name string `gorm:"column:mapped_name"`
10 Date time.Time `gorm:"column:mapped_time"`
11 }
12
13 // Make sure an ignored field does not interfere with another field's custom
14 // column name that matches the ignored field.
15 type CustomColumnAndIgnoredFieldClash struct {
16 Body string `sql:"-"`
17 RawBody string `gorm:"column:body"`
18 }
19
20 func TestCustomizeColumn(t *testing.T) {
21 col := "mapped_name"
22 DB.DropTable(&CustomizeColumn{})
23 DB.AutoMigrate(&CustomizeColumn{})
24
25 scope := DB.NewScope(&CustomizeColumn{})
26 if !scope.Dialect().HasColumn(scope, scope.TableName(), col) {
27 t.Errorf("CustomizeColumn should have column %s", col)
28 }
29
30 col = "mapped_id"
31 if scope.PrimaryKey() != col {
32 t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey())
33 }
34
35 expected := "foo"
36 cc := CustomizeColumn{ID: 666, Name: expected, Date: time.Now()}
37
38 if count := DB.Create(&cc).RowsAffected; count != 1 {
39 t.Error("There should be one record be affected when create record")
40 }
41
42 var cc1 CustomizeColumn
43 DB.First(&cc1, 666)
44
45 if cc1.Name != expected {
46 t.Errorf("Failed to query CustomizeColumn")
47 }
48
49 cc.Name = "bar"
50 DB.Save(&cc)
51
52 var cc2 CustomizeColumn
53 DB.First(&cc2, 666)
54 if cc2.Name != "bar" {
55 t.Errorf("Failed to query CustomizeColumn")
56 }
57 }
58
59 func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
60 DB.DropTable(&CustomColumnAndIgnoredFieldClash{})
61 if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil {
62 t.Errorf("Should not raise error: %s", err)
63 }
64 }
0 package gorm_test
1
2 import (
3 "testing"
4 )
5
6 func TestDdlErrors(t *testing.T) {
7 var err error
8
9 if err = DB.Close(); err != nil {
10 t.Errorf("Closing DDL test db connection err=%s", err)
11 }
12 defer func() {
13 // Reopen DB connection.
14 if DB, err = OpenTestConnection(); err != nil {
15 t.Fatalf("Failed re-opening db connection: %s", err)
16 }
17 }()
18
19 DB.HasTable("foobarbaz")
20 if DB.Error == nil {
21 t.Errorf("Expected operation on closed db to produce an error, but err was nil")
22 }
23 }
0 package gorm_test
1
2 import (
3 "testing"
4 "time"
5 )
6
7 func TestDelete(t *testing.T) {
8 user1, user2 := User{Name: "delete1"}, User{Name: "delete2"}
9 DB.Save(&user1)
10 DB.Save(&user2)
11
12 if err := DB.Delete(&user1).Error; err != nil {
13 t.Errorf("No error should happen when delete a record, err=%s", err)
14 }
15
16 if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
17 t.Errorf("User can't be found after delete")
18 }
19
20 if DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() {
21 t.Errorf("Other users that not deleted should be found-able")
22 }
23 }
24
25 func TestInlineDelete(t *testing.T) {
26 user1, user2 := User{Name: "inline_delete1"}, User{Name: "inline_delete2"}
27 DB.Save(&user1)
28 DB.Save(&user2)
29
30 if DB.Delete(&User{}, user1.Id).Error != nil {
31 t.Errorf("No error should happen when delete a record")
32 } else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
33 t.Errorf("User can't be found after delete")
34 }
35
36 if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil {
37 t.Errorf("No error should happen when delete a record, err=%s", err)
38 } else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() {
39 t.Errorf("User can't be found after delete")
40 }
41 }
42
43 func TestSoftDelete(t *testing.T) {
44 type User struct {
45 Id int64
46 Name string
47 DeletedAt time.Time
48 }
49 DB.AutoMigrate(&User{})
50
51 user := User{Name: "soft_delete"}
52 DB.Save(&user)
53 DB.Delete(&user)
54
55 if DB.First(&User{}, "name = ?", user.Name).Error == nil {
56 t.Errorf("Can't find a soft deleted record")
57 }
58
59 if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil {
60 t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err)
61 }
62
63 DB.Unscoped().Delete(&user)
64 if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() {
65 t.Errorf("Can't find permanently deleted record")
66 }
67 }
0 package gorm
1
2 import (
3 "fmt"
4 "reflect"
5 )
6
7 type Dialect interface {
8 BinVar(i int) string
9 SupportLastInsertId() bool
10 HasTop() bool
11 SqlTag(value reflect.Value, size int, autoIncrease bool) string
12 ReturningStr(tableName, key string) string
13 SelectFromDummyTable() string
14 Quote(key string) string
15 HasTable(scope *Scope, tableName string) bool
16 HasColumn(scope *Scope, tableName string, columnName string) bool
17 HasIndex(scope *Scope, tableName string, indexName string) bool
18 RemoveIndex(scope *Scope, indexName string)
19 CurrentDatabase(scope *Scope) string
20 }
21
22 func NewDialect(driver string) Dialect {
23 var d Dialect
24 switch driver {
25 case "postgres":
26 d = &postgres{}
27 case "foundation":
28 d = &foundation{}
29 case "mysql":
30 d = &mysql{}
31 case "sqlite3":
32 d = &sqlite3{}
33 case "mssql":
34 d = &mssql{}
35 default:
36 fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver)
37 d = &commonDialect{}
38 }
39 return d
40 }
0 # Gorm Development
1
2 ## Architecture
3
4 The most notable component of Gorm is`gorm.DB`, which hold database connection. It could be initialized like this:
5
6 db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
7
8 Gorm has chainable API, `gorm.DB` is the bridge of chains, it save related information and pass it to the next chain.
9
10 Lets use below code to explain how it works:
11
12 db.Where("name = ?", "jinzhu").Find(&users)
13
14 // equivalent code
15 newdb := db.Where("name =?", "jinzhu")
16 newdb.Find(&user)
17
18 `newdb` is `db`'s clone, in addition, it contains search conditions from the `Where` method.
19 `Find` is a query method, it creates a `Scope` instance, and pass it as argument to query callbacks.
20
21 There are four kinds of callbacks corresponds to sql's CURD: create callbacks, update callbacks, query callbacks, delete callbacks.
22
23 ## Callbacks
24
25 ### Register a new callback
26
27 func updateCreated(scope *Scope) {
28 if scope.HasColumn("Created") {
29 scope.SetColumn("Created", NowFunc())
30 }
31 }
32
33 db.Callback().Create().Register("update_created_at", updateCreated)
34 // register a callback for Create process
35
36 ### Delete an existing callback
37
38 db.Callback().Create().Remove("gorm:create")
39 // delete callback `gorm:create` from Create callbacks
40
41 ### Replace an existing callback
42
43 db.Callback().Create().Replace("gorm:create", newCreateFunction)
44 // replace callback `gorm:create` with new function `newCreateFunction` for Create process
45
46 ### Register callback orders
47
48 db.Callback().Create().Before("gorm:create").Register("update_created_at", updateCreated)
49 db.Callback().Create().After("gorm:create").Register("update_created_at", updateCreated)
50 db.Callback().Query().After("gorm:query").Register("my_plugin:after_query", afterQuery)
51 db.Callback().Delete().After("gorm:delete").Register("my_plugin:after_delete", afterDelete)
52 db.Callback().Update().Before("gorm:update").Register("my_plugin:before_update", beforeUpdate)
53 db.Callback().Create().Before("gorm:create").After("gorm:before_create").Register("my_plugin:before_create", beforeCreate)
54
55 ### Callback API
56
57 Gorm is powered by callbacks, so you could refer below links to learn how to write callbacks
58
59 [Create callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go)
60
61 [Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go)
62
63 [Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go)
64
65 [Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go)
66
67 View [https://github.com/jinzhu/gorm/blob/master/scope.go](https://github.com/jinzhu/gorm/blob/master/scope.go) for all available API
0 package gorm_test
1
2 import "testing"
3
4 type BasePost struct {
5 Id int64
6 Title string
7 URL string
8 }
9
10 type HNPost struct {
11 BasePost
12 Upvotes int32
13 }
14
15 type EngadgetPost struct {
16 BasePost BasePost `gorm:"embedded"`
17 ImageUrl string
18 }
19
20 func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
21 DB.Save(&HNPost{BasePost: BasePost{Title: "news"}})
22 DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
23 var news HNPost
24 if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
25 t.Errorf("no error should happen when query with embedded struct, but got %v", err)
26 } else if news.Title != "hn_news" {
27 t.Errorf("embedded struct's value should be scanned correctly")
28 }
29
30 DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
31 var egNews EngadgetPost
32 if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
33 t.Errorf("no error should happen when query with embedded struct, but got %v", err)
34 } else if egNews.BasePost.Title != "engadget_news" {
35 t.Errorf("embedded struct's value should be scanned correctly")
36 }
37
38 if DB.NewScope(&HNPost{}).PrimaryField() == nil {
39 t.Errorf("primary key with embedded struct should works")
40 }
41
42 for _, field := range DB.NewScope(&HNPost{}).Fields() {
43 if field.Name == "BasePost" {
44 t.Errorf("scope Fields should not contain embedded struct")
45 }
46 }
47 }
0 package gorm
1
2 import (
3 "errors"
4 "strings"
5 )
6
7 var (
8 RecordNotFound = errors.New("record not found")
9 InvalidSql = errors.New("invalid sql")
10 NoNewAttrs = errors.New("no new attributes")
11 NoValidTransaction = errors.New("no valid transaction")
12 CantStartTransaction = errors.New("can't start transaction")
13 )
14
15 type errorsInterface interface {
16 GetErrors() []error
17 }
18
19 type Errors struct {
20 errors []error
21 }
22
23 func (errs Errors) GetErrors() []error {
24 return errs.errors
25 }
26
27 func (errs *Errors) Add(err error) {
28 if errors, ok := err.(errorsInterface); ok {
29 for _, err := range errors.GetErrors() {
30 errs.Add(err)
31 }
32 } else {
33 for _, e := range errs.errors {
34 if err == e {
35 return
36 }
37 }
38 errs.errors = append(errs.errors, err)
39 }
40 }
41
42 func (errs Errors) Error() string {
43 var errors = []string{}
44 for _, e := range errs.errors {
45 errors = append(errors, e.Error())
46 }
47 return strings.Join(errors, "; ")
48 }
0 package gorm
1
2 import (
3 "database/sql"
4 "errors"
5 "reflect"
6 )
7
8 type Field struct {
9 *StructField
10 IsBlank bool
11 Field reflect.Value
12 }
13
14 func (field *Field) Set(value interface{}) error {
15 if !field.Field.IsValid() {
16 return errors.New("field value not valid")
17 }
18
19 if !field.Field.CanAddr() {
20 return errors.New("unaddressable value")
21 }
22
23 if rvalue, ok := value.(reflect.Value); ok {
24 value = rvalue.Interface()
25 }
26
27 if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
28 if v, ok := value.(reflect.Value); ok {
29 if err := scanner.Scan(v.Interface()); err != nil {
30 return err
31 }
32 } else {
33 if err := scanner.Scan(value); err != nil {
34 return err
35 }
36 }
37 } else {
38 reflectValue, ok := value.(reflect.Value)
39 if !ok {
40 reflectValue = reflect.ValueOf(value)
41 }
42
43 if reflectValue.Type().ConvertibleTo(field.Field.Type()) {
44 field.Field.Set(reflectValue.Convert(field.Field.Type()))
45 } else {
46 return errors.New("could not convert argument")
47 }
48 }
49
50 field.IsBlank = isBlank(field.Field)
51 return nil
52 }
53
54 // Fields get value's fields
55 func (scope *Scope) Fields() map[string]*Field {
56 if scope.fields == nil {
57 fields := map[string]*Field{}
58 modelStruct := scope.GetModelStruct()
59
60 indirectValue := scope.IndirectValue()
61 isStruct := indirectValue.Kind() == reflect.Struct
62 for _, structField := range modelStruct.StructFields {
63 if isStruct {
64 fields[structField.DBName] = getField(indirectValue, structField)
65 } else {
66 fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
67 }
68 }
69
70 if modelStruct.cached {
71 scope.fields = fields
72 }
73 return fields
74 }
75 return scope.fields
76 }
77
78 func getField(indirectValue reflect.Value, structField *StructField) *Field {
79 field := &Field{StructField: structField}
80 for _, name := range structField.Names {
81 indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
82 }
83 field.Field = indirectValue
84 field.IsBlank = isBlank(indirectValue)
85 return field
86 }
0 package gorm_test
1
2 import (
3 "testing"
4
5 "github.com/jinzhu/gorm"
6 )
7
8 type CalculateField struct {
9 gorm.Model
10 Name string
11 Children []CalculateFieldChild
12 Category CalculateFieldCategory
13 }
14
15 type CalculateFieldChild struct {
16 gorm.Model
17 CalculateFieldID uint
18 Name string
19 }
20
21 type CalculateFieldCategory struct {
22 gorm.Model
23 CalculateFieldID uint
24 Name string
25 }
26
27 func TestCalculateField(t *testing.T) {
28 var field CalculateField
29 fields := DB.NewScope(&field).Fields()
30 if fields["children"].Relationship == nil || fields["category"].Relationship == nil {
31 t.Errorf("Should calculate fields correctly for the first time")
32 }
33 }
0 package gorm
1
2 import (
3 "fmt"
4 "reflect"
5 "time"
6 )
7
8 type foundation struct {
9 commonDialect
10 }
11
12 func (foundation) BinVar(i int) string {
13 return fmt.Sprintf("$%v", i)
14 }
15
16 func (foundation) SupportLastInsertId() bool {
17 return false
18 }
19
20 func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
21 switch value.Kind() {
22 case reflect.Bool:
23 return "boolean"
24 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
25 if autoIncrease {
26 return "serial"
27 }
28 return "int"
29 case reflect.Int64, reflect.Uint64:
30 if autoIncrease {
31 return "bigserial"
32 }
33 return "bigint"
34 case reflect.Float32, reflect.Float64:
35 return "double"
36 case reflect.String:
37 if size > 0 && size < 65532 {
38 return fmt.Sprintf("varchar(%d)", size)
39 }
40 return "clob"
41 case reflect.Struct:
42 if _, ok := value.Interface().(time.Time); ok {
43 return "datetime"
44 }
45 default:
46 if _, ok := value.Interface().([]byte); ok {
47 return "blob"
48 }
49 }
50 panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String()))
51 }
52
53 func (s foundation) ReturningStr(tableName, key string) string {
54 return fmt.Sprintf("RETURNING %v.%v", tableName, key)
55 }
56
57 func (s foundation) HasTable(scope *Scope, tableName string) bool {
58 var count int
59 s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName)
60 return count > 0
61 }
62
63 func (s foundation) HasColumn(scope *Scope, tableName string, columnName string) bool {
64 var count int
65 s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName)
66 return count > 0
67 }
68
69 func (s foundation) RemoveIndex(scope *Scope, indexName string) {
70 scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", s.Quote(indexName)))
71 }
72
73 func (s foundation) HasIndex(scope *Scope, tableName string, indexName string) bool {
74 var count int
75 s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName)
76 return count > 0
77 }
78
79 func (s foundation) CurrentDatabase(scope *Scope) (name string) {
80 s.RawScanString(scope, &name, "SELECT CURRENT_SCHEMA")
81 return
82 }
Binary diff not shown
0 package gorm
1
2 import "database/sql"
3
4 type sqlCommon interface {
5 Exec(query string, args ...interface{}) (sql.Result, error)
6 Prepare(query string) (*sql.Stmt, error)
7 Query(query string, args ...interface{}) (*sql.Rows, error)
8 QueryRow(query string, args ...interface{}) *sql.Row
9 }
10
11 type sqlDb interface {
12 Begin() (*sql.Tx, error)
13 }
14
15 type sqlTx interface {
16 Commit() error
17 Rollback() error
18 }
0 package gorm
1
2 import (
3 "errors"
4 "fmt"
5 "reflect"
6 "strings"
7 )
8
9 type JoinTableHandlerInterface interface {
10 Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
11 Table(db *DB) string
12 Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
13 Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
14 JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
15 SourceForeignKeys() []JoinTableForeignKey
16 DestinationForeignKeys() []JoinTableForeignKey
17 }
18
19 type JoinTableForeignKey struct {
20 DBName string
21 AssociationDBName string
22 }
23
24 type JoinTableSource struct {
25 ModelType reflect.Type
26 ForeignKeys []JoinTableForeignKey
27 }
28
29 type JoinTableHandler struct {
30 TableName string `sql:"-"`
31 Source JoinTableSource `sql:"-"`
32 Destination JoinTableSource `sql:"-"`
33 }
34
35 func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
36 return s.Source.ForeignKeys
37 }
38
39 func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
40 return s.Destination.ForeignKeys
41 }
42
43 func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
44 s.TableName = tableName
45
46 s.Source = JoinTableSource{ModelType: source}
47 for idx, dbName := range relationship.ForeignFieldNames {
48 s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
49 DBName: relationship.ForeignDBNames[idx],
50 AssociationDBName: dbName,
51 })
52 }
53
54 s.Destination = JoinTableSource{ModelType: destination}
55 for idx, dbName := range relationship.AssociationForeignFieldNames {
56 s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
57 DBName: relationship.AssociationForeignDBNames[idx],
58 AssociationDBName: dbName,
59 })
60 }
61 }
62
63 func (s JoinTableHandler) Table(db *DB) string {
64 return s.TableName
65 }
66
67 func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
68 values := map[string]interface{}{}
69
70 for _, source := range sources {
71 scope := db.NewScope(source)
72 modelType := scope.GetModelStruct().ModelType
73
74 if s.Source.ModelType == modelType {
75 for _, foreignKey := range s.Source.ForeignKeys {
76 values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
77 }
78 } else if s.Destination.ModelType == modelType {
79 for _, foreignKey := range s.Destination.ForeignKeys {
80 values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
81 }
82 }
83 }
84 return values
85 }
86
87 func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error {
88 scope := db.NewScope("")
89 searchMap := s.GetSearchMap(db, source1, source2)
90
91 var assignColumns, binVars, conditions []string
92 var values []interface{}
93 for key, value := range searchMap {
94 assignColumns = append(assignColumns, key)
95 binVars = append(binVars, `?`)
96 conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
97 values = append(values, value)
98 }
99
100 for _, value := range values {
101 values = append(values, value)
102 }
103
104 quotedTable := handler.Table(db)
105 sql := fmt.Sprintf(
106 "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
107 quotedTable,
108 strings.Join(assignColumns, ","),
109 strings.Join(binVars, ","),
110 scope.Dialect().SelectFromDummyTable(),
111 quotedTable,
112 strings.Join(conditions, " AND "),
113 )
114
115 return db.Exec(sql, values...).Error
116 }
117
118 func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
119 var conditions []string
120 var values []interface{}
121
122 for key, value := range s.GetSearchMap(db, sources...) {
123 conditions = append(conditions, fmt.Sprintf("%v = ?", key))
124 values = append(values, value)
125 }
126
127 return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
128 }
129
130 func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
131 quotedTable := handler.Table(db)
132
133 scope := db.NewScope(source)
134 modelType := scope.GetModelStruct().ModelType
135 var joinConditions []string
136 var queryConditions []string
137 var values []interface{}
138 if s.Source.ModelType == modelType {
139 destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
140 for _, foreignKey := range s.Destination.ForeignKeys {
141 joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
142 }
143
144 var foreignDBNames []string
145 var foreignFieldNames []string
146
147 for _, foreignKey := range s.Source.ForeignKeys {
148 foreignDBNames = append(foreignDBNames, foreignKey.DBName)
149 foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name)
150 }
151
152 foreignFieldValues := scope.getColumnAsArray(foreignFieldNames)
153
154 condString := fmt.Sprintf("%v in (%v)", toQueryCondition(scope, foreignDBNames), toQueryMarks(foreignFieldValues))
155
156 keys := scope.getColumnAsArray(foreignFieldNames)
157 values = append(values, toQueryValues(keys))
158
159 queryConditions = append(queryConditions, condString)
160
161 return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
162 Where(condString, toQueryValues(foreignFieldValues)...)
163 } else {
164 db.Error = errors.New("wrong source type for join table handler")
165 return db
166 }
167 }
0 package gorm_test
1
2 import (
3 "fmt"
4 "testing"
5 "time"
6
7 "github.com/jinzhu/gorm"
8 )
9
10 type Person struct {
11 Id int
12 Name string
13 Addresses []*Address `gorm:"many2many:person_addresses;"`
14 }
15
16 type PersonAddress struct {
17 gorm.JoinTableHandler
18 PersonID int
19 AddressID int
20 DeletedAt time.Time
21 CreatedAt time.Time
22 }
23
24 func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error {
25 return db.Where(map[string]interface{}{
26 "person_id": db.NewScope(foreignValue).PrimaryKeyValue(),
27 "address_id": db.NewScope(associationValue).PrimaryKeyValue(),
28 }).Assign(map[string]interface{}{
29 "person_id": foreignValue,
30 "address_id": associationValue,
31 "deleted_at": gorm.Expr("NULL"),
32 }).FirstOrCreate(&PersonAddress{}).Error
33 }
34
35 func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error {
36 return db.Delete(&PersonAddress{}).Error
37 }
38
39 func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB {
40 table := pa.Table(db)
41 return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
42 }
43
44 func TestJoinTable(t *testing.T) {
45 DB.Exec("drop table person_addresses;")
46 DB.AutoMigrate(&Person{})
47 DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{})
48
49 address1 := &Address{Address1: "address 1"}
50 address2 := &Address{Address1: "address 2"}
51 person := &Person{Name: "person", Addresses: []*Address{address1, address2}}
52 DB.Save(person)
53
54 DB.Model(person).Association("Addresses").Delete(address1)
55
56 if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 {
57 t.Errorf("Should found one address")
58 }
59
60 if DB.Model(person).Association("Addresses").Count() != 1 {
61 t.Errorf("Should found one address")
62 }
63
64 if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 {
65 t.Errorf("Found two addresses with Unscoped")
66 }
67
68 if DB.Model(person).Association("Addresses").Clear(); DB.Model(person).Association("Addresses").Count() != 0 {
69 t.Errorf("Should deleted all addresses")
70 }
71 }
0 package gorm
1
2 import (
3 "database/sql/driver"
4 "fmt"
5 "log"
6 "os"
7 "reflect"
8 "regexp"
9 "time"
10 )
11
12 type logger interface {
13 Print(v ...interface{})
14 }
15
16 type LogWriter interface {
17 Println(v ...interface{})
18 }
19
20 type Logger struct {
21 LogWriter
22 }
23
24 var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
25
26 // Format log
27 var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
28
29 func (logger Logger) Print(values ...interface{}) {
30 if len(values) > 1 {
31 level := values[0]
32 currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m"
33 source := fmt.Sprintf("\033[35m(%v)\033[0m", values[1])
34 messages := []interface{}{source, currentTime}
35
36 if level == "sql" {
37 // duration
38 messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
39 // sql
40 var formatedValues []interface{}
41 for _, value := range values[4].([]interface{}) {
42 indirectValue := reflect.Indirect(reflect.ValueOf(value))
43 if indirectValue.IsValid() {
44 value = indirectValue.Interface()
45 if t, ok := value.(time.Time); ok {
46 formatedValues = append(formatedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339)))
47 } else if b, ok := value.([]byte); ok {
48 formatedValues = append(formatedValues, fmt.Sprintf("'%v'", string(b)))
49 } else if r, ok := value.(driver.Valuer); ok {
50 if value, err := r.Value(); err == nil && value != nil {
51 formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
52 } else {
53 formatedValues = append(formatedValues, "NULL")
54 }
55 } else {
56 formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
57 }
58 } else {
59 formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
60 }
61 }
62 messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(values[3].(string), "%v"), formatedValues...))
63 } else {
64 messages = append(messages, "\033[31;1m")
65 messages = append(messages, values[2:]...)
66 messages = append(messages, "\033[0m")
67 }
68 logger.Println(messages...)
69 }
70 }
0 package gorm
1
2 import (
3 "database/sql"
4 "errors"
5 "fmt"
6 "reflect"
7 "strings"
8 "time"
9 )
10
11 // NowFunc returns current time, this function is exported in order to be able
12 // to give the flexibility to the developer to customize it according to their
13 // needs
14 //
15 // e.g: return time.Now().UTC()
16 //
17 var NowFunc = func() time.Time {
18 return time.Now()
19 }
20
21 type DB struct {
22 Value interface{}
23 Error error
24 RowsAffected int64
25 callback *callback
26 db sqlCommon
27 parent *DB
28 search *search
29 logMode int
30 logger logger
31 dialect Dialect
32 singularTable bool
33 source string
34 values map[string]interface{}
35 joinTableHandlers map[string]JoinTableHandler
36 }
37
38 func Open(dialect string, args ...interface{}) (DB, error) {
39 var db DB
40 var err error
41
42 if len(args) == 0 {
43 err = errors.New("invalid database source")
44 } else {
45 var source string
46 var dbSql sqlCommon
47
48 switch value := args[0].(type) {
49 case string:
50 var driver = dialect
51 if len(args) == 1 {
52 source = value
53 } else if len(args) >= 2 {
54 driver = value
55 source = args[1].(string)
56 }
57 if driver == "foundation" {
58 driver = "postgres" // FoundationDB speaks a postgres-compatible protocol.
59 }
60 dbSql, err = sql.Open(driver, source)
61 case sqlCommon:
62 source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
63 dbSql = value
64 }
65
66 db = DB{
67 dialect: NewDialect(dialect),
68 logger: defaultLogger,
69 callback: DefaultCallback,
70 source: source,
71 values: map[string]interface{}{},
72 db: dbSql,
73 }
74 db.parent = &db
75
76 if err == nil {
77 err = db.DB().Ping() // Send a ping to make sure the database connection is alive.
78 }
79 }
80
81 return db, err
82 }
83
84 func (s *DB) Close() error {
85 return s.parent.db.(*sql.DB).Close()
86 }
87
88 func (s *DB) DB() *sql.DB {
89 return s.db.(*sql.DB)
90 }
91
92 func (s *DB) New() *DB {
93 clone := s.clone()
94 clone.search = nil
95 clone.Value = nil
96 return clone
97 }
98
99 // NewScope create scope for callbacks, including DB's search information
100 func (db *DB) NewScope(value interface{}) *Scope {
101 dbClone := db.clone()
102 dbClone.Value = value
103 return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
104 }
105
106 // CommonDB Return the underlying sql.DB or sql.Tx instance.
107 // Use of this method is discouraged. It's mainly intended to allow
108 // coexistence with legacy non-GORM code.
109 func (s *DB) CommonDB() sqlCommon {
110 return s.db
111 }
112
113 func (s *DB) Callback() *callback {
114 s.parent.callback = s.parent.callback.clone()
115 return s.parent.callback
116 }
117
118 func (s *DB) SetLogger(l logger) {
119 s.logger = l
120 }
121
122 func (s *DB) LogMode(enable bool) *DB {
123 if enable {
124 s.logMode = 2
125 } else {
126 s.logMode = 1
127 }
128 return s
129 }
130
131 func (s *DB) SingularTable(enable bool) {
132 modelStructsMap = newModelStructsMap()
133 s.parent.singularTable = enable
134 }
135
136 func (s *DB) Where(query interface{}, args ...interface{}) *DB {
137 return s.clone().search.Where(query, args...).db
138 }
139
140 func (s *DB) Or(query interface{}, args ...interface{}) *DB {
141 return s.clone().search.Or(query, args...).db
142 }
143
144 func (s *DB) Not(query interface{}, args ...interface{}) *DB {
145 return s.clone().search.Not(query, args...).db
146 }
147
148 func (s *DB) Limit(value interface{}) *DB {
149 return s.clone().search.Limit(value).db
150 }
151
152 func (s *DB) Offset(value interface{}) *DB {
153 return s.clone().search.Offset(value).db
154 }
155
156 func (s *DB) Order(value string, reorder ...bool) *DB {
157 return s.clone().search.Order(value, reorder...).db
158 }
159
160 func (s *DB) Select(query interface{}, args ...interface{}) *DB {
161 return s.clone().search.Select(query, args...).db
162 }
163
164 func (s *DB) Omit(columns ...string) *DB {
165 return s.clone().search.Omit(columns...).db
166 }
167
168 func (s *DB) Group(query string) *DB {
169 return s.clone().search.Group(query).db
170 }
171
172 func (s *DB) Having(query string, values ...interface{}) *DB {
173 return s.clone().search.Having(query, values...).db
174 }
175
176 func (s *DB) Joins(query string) *DB {
177 return s.clone().search.Joins(query).db
178 }
179
180 func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
181 for _, f := range funcs {
182 s = f(s)
183 }
184 return s
185 }
186
187 func (s *DB) Unscoped() *DB {
188 return s.clone().search.unscoped().db
189 }
190
191 func (s *DB) Attrs(attrs ...interface{}) *DB {
192 return s.clone().search.Attrs(attrs...).db
193 }
194
195 func (s *DB) Assign(attrs ...interface{}) *DB {
196 return s.clone().search.Assign(attrs...).db
197 }
198
199 func (s *DB) First(out interface{}, where ...interface{}) *DB {
200 newScope := s.clone().NewScope(out)
201 newScope.Search.Limit(1)
202 return newScope.Set("gorm:order_by_primary_key", "ASC").
203 inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
204 }
205
206 func (s *DB) Last(out interface{}, where ...interface{}) *DB {
207 newScope := s.clone().NewScope(out)
208 newScope.Search.Limit(1)
209 return newScope.Set("gorm:order_by_primary_key", "DESC").
210 inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
211 }
212
213 func (s *DB) Find(out interface{}, where ...interface{}) *DB {
214 return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
215 }
216
217 func (s *DB) Scan(dest interface{}) *DB {
218 return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db
219 }
220
221 func (s *DB) Row() *sql.Row {
222 return s.NewScope(s.Value).row()
223 }
224
225 func (s *DB) Rows() (*sql.Rows, error) {
226 return s.NewScope(s.Value).rows()
227 }
228
229 func (s *DB) Pluck(column string, value interface{}) *DB {
230 return s.NewScope(s.Value).pluck(column, value).db
231 }
232
233 func (s *DB) Count(value interface{}) *DB {
234 return s.NewScope(s.Value).count(value).db
235 }
236
237 func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
238 return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
239 }
240
241 func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
242 c := s.clone()
243 if result := c.First(out, where...); result.Error != nil {
244 if !result.RecordNotFound() {
245 return result
246 }
247 c.NewScope(out).inlineCondition(where...).initialize()
248 } else {
249 c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs), false)
250 }
251 return c
252 }
253
254 func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
255 c := s.clone()
256 if result := c.First(out, where...); result.Error != nil {
257 if !result.RecordNotFound() {
258 return result
259 }
260 c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callback.creates).db.Error)
261 } else if len(c.search.assignAttrs) > 0 {
262 c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callback.updates).db.Error)
263 }
264 return c
265 }
266
267 func (s *DB) Update(attrs ...interface{}) *DB {
268 return s.Updates(toSearchableMap(attrs...), true)
269 }
270
271 func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
272 return s.clone().NewScope(s.Value).
273 Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
274 InstanceSet("gorm:update_interface", values).
275 callCallbacks(s.parent.callback.updates).db
276 }
277
278 func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
279 return s.UpdateColumns(toSearchableMap(attrs...))
280 }
281
282 func (s *DB) UpdateColumns(values interface{}) *DB {
283 return s.clone().NewScope(s.Value).
284 Set("gorm:update_column", true).
285 Set("gorm:save_associations", false).
286 InstanceSet("gorm:update_interface", values).
287 callCallbacks(s.parent.callback.updates).db
288 }
289
290 func (s *DB) Save(value interface{}) *DB {
291 scope := s.clone().NewScope(value)
292 if scope.PrimaryKeyZero() {
293 return scope.callCallbacks(s.parent.callback.creates).db
294 }
295 return scope.callCallbacks(s.parent.callback.updates).db
296 }
297
298 func (s *DB) Create(value interface{}) *DB {
299 scope := s.clone().NewScope(value)
300 return scope.callCallbacks(s.parent.callback.creates).db
301 }
302
303 func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
304 return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callback.deletes).db
305 }
306
307 func (s *DB) Raw(sql string, values ...interface{}) *DB {
308 return s.clone().search.Raw(true).Where(sql, values...).db
309 }
310
311 func (s *DB) Exec(sql string, values ...interface{}) *DB {
312 scope := s.clone().NewScope(nil)
313 generatedSql := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
314 generatedSql = strings.TrimSuffix(strings.TrimPrefix(generatedSql, "("), ")")
315 scope.Raw(generatedSql)
316 return scope.Exec().db
317 }
318
319 func (s *DB) Model(value interface{}) *DB {
320 c := s.clone()
321 c.Value = value
322 return c
323 }
324
325 func (s *DB) Table(name string) *DB {
326 clone := s.clone()
327 clone.search.Table(name)
328 clone.Value = nil
329 return clone
330 }
331
332 func (s *DB) Debug() *DB {
333 return s.clone().LogMode(true)
334 }
335
336 func (s *DB) Begin() *DB {
337 c := s.clone()
338 if db, ok := c.db.(sqlDb); ok {
339 tx, err := db.Begin()
340 c.db = interface{}(tx).(sqlCommon)
341 c.AddError(err)
342 } else {
343 c.AddError(CantStartTransaction)
344 }
345 return c
346 }
347
348 func (s *DB) Commit() *DB {
349 if db, ok := s.db.(sqlTx); ok {
350 s.AddError(db.Commit())
351 } else {
352 s.AddError(NoValidTransaction)
353 }
354 return s
355 }
356
357 func (s *DB) Rollback() *DB {
358 if db, ok := s.db.(sqlTx); ok {
359 s.AddError(db.Rollback())
360 } else {
361 s.AddError(NoValidTransaction)
362 }
363 return s
364 }
365
366 func (s *DB) NewRecord(value interface{}) bool {
367 return s.clone().NewScope(value).PrimaryKeyZero()
368 }
369
370 func (s *DB) RecordNotFound() bool {
371 return s.Error == RecordNotFound
372 }
373
374 // Migrations
375 func (s *DB) CreateTable(values ...interface{}) *DB {
376 db := s.clone()
377 for _, value := range values {
378 db = db.NewScope(value).createTable().db
379 }
380 return db
381 }
382
383 func (s *DB) DropTable(values ...interface{}) *DB {
384 db := s.clone()
385 for _, value := range values {
386 db = db.NewScope(value).dropTable().db
387 }
388 return db
389 }
390
391 func (s *DB) DropTableIfExists(values ...interface{}) *DB {
392 db := s.clone()
393 for _, value := range values {
394 db = db.NewScope(value).dropTableIfExists().db
395 }
396 return db
397 }
398
399 func (s *DB) HasTable(value interface{}) bool {
400 scope := s.clone().NewScope(value)
401 tableName := scope.TableName()
402 has := scope.Dialect().HasTable(scope, tableName)
403 s.AddError(scope.db.Error)
404 return has
405 }
406
407 func (s *DB) AutoMigrate(values ...interface{}) *DB {
408 db := s.clone()
409 for _, value := range values {
410 db = db.NewScope(value).NeedPtr().autoMigrate().db
411 }
412 return db
413 }
414
415 func (s *DB) ModifyColumn(column string, typ string) *DB {
416 scope := s.clone().NewScope(s.Value)
417 scope.modifyColumn(column, typ)
418 return scope.db
419 }
420
421 func (s *DB) DropColumn(column string) *DB {
422 scope := s.clone().NewScope(s.Value)
423 scope.dropColumn(column)
424 return scope.db
425 }
426
427 func (s *DB) AddIndex(indexName string, column ...string) *DB {
428 scope := s.clone().NewScope(s.Value)
429 scope.addIndex(false, indexName, column...)
430 return scope.db
431 }
432
433 func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB {
434 scope := s.clone().NewScope(s.Value)
435 scope.addIndex(true, indexName, column...)
436 return scope.db
437 }
438
439 func (s *DB) RemoveIndex(indexName string) *DB {
440 scope := s.clone().NewScope(s.Value)
441 scope.removeIndex(indexName)
442 return scope.db
443 }
444
445 func (s *DB) CurrentDatabase() string {
446 var (
447 scope = s.clone().NewScope(s.Value)
448 name = s.dialect.CurrentDatabase(scope)
449 )
450 return name
451 }
452
453 /*
454 Add foreign key to the given scope
455
456 Example:
457 db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
458 */
459 func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
460 scope := s.clone().NewScope(s.Value)
461 scope.addForeignKey(field, dest, onDelete, onUpdate)
462 return scope.db
463 }
464
465 func (s *DB) Association(column string) *Association {
466 var err error
467 scope := s.clone().NewScope(s.Value)
468
469 if primaryField := scope.PrimaryField(); primaryField.IsBlank {
470 err = errors.New("primary key can't be nil")
471 } else {
472 if field, ok := scope.FieldByName(column); ok {
473 if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
474 err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
475 } else {
476 return &Association{Scope: scope, Column: column, Field: field}
477 }
478 } else {
479 err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
480 }
481 }
482
483 return &Association{Error: err}
484 }
485
486 func (s *DB) Preload(column string, conditions ...interface{}) *DB {
487 return s.clone().search.Preload(column, conditions...).db
488 }
489
490 // Set set value by name
491 func (s *DB) Set(name string, value interface{}) *DB {
492 return s.clone().InstantSet(name, value)
493 }
494
495 func (s *DB) InstantSet(name string, value interface{}) *DB {
496 s.values[name] = value
497 return s
498 }
499
500 // Get get value by name
501 func (s *DB) Get(name string) (value interface{}, ok bool) {
502 value, ok = s.values[name]
503 return
504 }
505
506 func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
507 scope := s.NewScope(source)
508 for _, field := range scope.GetModelStruct().StructFields {
509 if field.Name == column || field.DBName == column {
510 if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" {
511 source := (&Scope{Value: source}).GetModelStruct().ModelType
512 destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
513 handler.Setup(field.Relationship, many2many, source, destination)
514 field.Relationship.JoinTableHandler = handler
515 if table := handler.Table(s); scope.Dialect().HasTable(scope, table) {
516 s.Table(table).AutoMigrate(handler)
517 }
518 }
519 }
520 }
521 }
522
523 func (s *DB) AddError(err error) error {
524 if err != nil {
525 if err != RecordNotFound {
526 if s.logMode == 0 {
527 go s.print(fileWithLineNum(), err)
528 } else {
529 s.log(err)
530 }
531
532 errors := Errors{errors: s.GetErrors()}
533 errors.Add(err)
534 if len(errors.GetErrors()) > 1 {
535 err = errors
536 }
537 }
538
539 s.Error = err
540 }
541 return err
542 }
543
544 func (s *DB) GetErrors() (errors []error) {
545 if errs, ok := s.Error.(errorsInterface); ok {
546 return errs.GetErrors()
547 } else if s.Error != nil {
548 return []error{s.Error}
549 }
550 return
551 }
0 package gorm
1
2 import "time"
3
4 func (s *DB) clone() *DB {
5 db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error}
6
7 for key, value := range s.values {
8 db.values[key] = value
9 }
10
11 if s.search == nil {
12 db.search = &search{}
13 } else {
14 db.search = s.search.clone()
15 }
16
17 db.search.db = &db
18 return &db
19 }
20
21 func (s *DB) print(v ...interface{}) {
22 s.logger.(logger).Print(v...)
23 }
24
25 func (s *DB) log(v ...interface{}) {
26 if s != nil && s.logMode == 2 {
27 s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
28 }
29 }
30
31 func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
32 if s.logMode == 2 {
33 s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars)
34 }
35 }
0 package gorm_test
1
2 import (
3 "database/sql"
4 "database/sql/driver"
5 "fmt"
6 "strconv"
7
8 _ "github.com/denisenkom/go-mssqldb"
9 testdb "github.com/erikstmartin/go-testdb"
10 _ "github.com/go-sql-driver/mysql"
11 "github.com/jinzhu/gorm"
12 "github.com/jinzhu/now"
13 _ "github.com/lib/pq"
14 _ "github.com/mattn/go-sqlite3"
15
16 "os"
17 "testing"
18 "time"
19 )
20
21 var (
22 DB gorm.DB
23 t1, t2, t3, t4, t5 time.Time
24 )
25
26 func init() {
27 var err error
28
29 if DB, err = OpenTestConnection(); err != nil {
30 panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err))
31 }
32
33 // DB.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)})
34 // DB.SetLogger(log.New(os.Stdout, "\r\n", 0))
35 // DB.LogMode(true)
36 DB.LogMode(false)
37
38 DB.DB().SetMaxIdleConns(10)
39
40 runMigration()
41 }
42
43 func OpenTestConnection() (db gorm.DB, err error) {
44 switch os.Getenv("GORM_DIALECT") {
45 case "mysql":
46 // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm';
47 // CREATE DATABASE gorm;
48 // GRANT ALL ON gorm.* TO 'gorm'@'localhost';
49 fmt.Println("testing mysql...")
50 db, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True")
51 case "postgres":
52 fmt.Println("testing postgres...")
53 db, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable")
54 case "foundation":
55 fmt.Println("testing foundation...")
56 db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable")
57 case "mssql":
58 fmt.Println("testing mssql...")
59 db, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433")
60 default:
61 fmt.Println("testing sqlite3...")
62 db, err = gorm.Open("sqlite3", "/tmp/gorm.db")
63 }
64 return
65 }
66
67 func TestStringPrimaryKey(t *testing.T) {
68 type UUIDStruct struct {
69 ID string `gorm:"primary_key"`
70 Name string
71 }
72 DB.AutoMigrate(&UUIDStruct{})
73
74 data := UUIDStruct{ID: "uuid", Name: "hello"}
75 if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" {
76 t.Errorf("string primary key should not be populated")
77 }
78 }
79
80 func TestExceptionsWithInvalidSql(t *testing.T) {
81 var columns []string
82 if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
83 t.Errorf("Should got error with invalid SQL")
84 }
85
86 if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
87 t.Errorf("Should got error with invalid SQL")
88 }
89
90 if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil {
91 t.Errorf("Should got error with invalid SQL")
92 }
93
94 var count1, count2 int64
95 DB.Model(&User{}).Count(&count1)
96 if count1 <= 0 {
97 t.Errorf("Should find some users")
98 }
99
100 if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil {
101 t.Errorf("Should got error with invalid SQL")
102 }
103
104 DB.Model(&User{}).Count(&count2)
105 if count1 != count2 {
106 t.Errorf("No user should not be deleted by invalid SQL")
107 }
108 }
109
110 func TestSetTable(t *testing.T) {
111 DB.Create(getPreparedUser("pluck_user1", "pluck_user"))
112 DB.Create(getPreparedUser("pluck_user2", "pluck_user"))
113 DB.Create(getPreparedUser("pluck_user3", "pluck_user"))
114
115 if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil {
116 t.Errorf("No errors should happen if set table for pluck", err.Error())
117 }
118
119 var users []User
120 if DB.Table("users").Find(&[]User{}).Error != nil {
121 t.Errorf("No errors should happen if set table for find")
122 }
123
124 if DB.Table("invalid_table").Find(&users).Error == nil {
125 t.Errorf("Should got error when table is set to an invalid table")
126 }
127
128 DB.Exec("drop table deleted_users;")
129 if DB.Table("deleted_users").CreateTable(&User{}).Error != nil {
130 t.Errorf("Create table with specified table")
131 }
132
133 DB.Table("deleted_users").Save(&User{Name: "DeletedUser"})
134
135 var deletedUsers []User
136 DB.Table("deleted_users").Find(&deletedUsers)
137 if len(deletedUsers) != 1 {
138 t.Errorf("Query from specified table")
139 }
140
141 DB.Save(getPreparedUser("normal_user", "reset_table"))
142 DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table"))
143 var user1, user2, user3 User
144 DB.Where("role = ?", "reset_table").First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3)
145 if (user1.Name != "normal_user") || (user2.Name != "deleted_user") || (user3.Name != "normal_user") {
146 t.Errorf("unset specified table with blank string")
147 }
148 }
149
150 type Order struct {
151 }
152
153 type Cart struct {
154 }
155
156 func (c Cart) TableName() string {
157 return "shopping_cart"
158 }
159
160 func TestHasTable(t *testing.T) {
161 type Foo struct {
162 Id int
163 Stuff string
164 }
165 DB.DropTable(&Foo{})
166 if ok := DB.HasTable(&Foo{}); ok {
167 t.Errorf("Table should not exist, but does")
168 }
169 if err := DB.CreateTable(&Foo{}).Error; err != nil {
170 t.Errorf("Table should be created")
171 }
172 if ok := DB.HasTable(&Foo{}); !ok {
173 t.Errorf("Table should exist, but HasTable informs it does not")
174 }
175 }
176
177 func TestTableName(t *testing.T) {
178 DB := DB.Model("")
179 if DB.NewScope(Order{}).TableName() != "orders" {
180 t.Errorf("Order's table name should be orders")
181 }
182
183 if DB.NewScope(&Order{}).TableName() != "orders" {
184 t.Errorf("&Order's table name should be orders")
185 }
186
187 if DB.NewScope([]Order{}).TableName() != "orders" {
188 t.Errorf("[]Order's table name should be orders")
189 }
190
191 if DB.NewScope(&[]Order{}).TableName() != "orders" {
192 t.Errorf("&[]Order's table name should be orders")
193 }
194
195 DB.SingularTable(true)
196 if DB.NewScope(Order{}).TableName() != "order" {
197 t.Errorf("Order's singular table name should be order")
198 }
199
200 if DB.NewScope(&Order{}).TableName() != "order" {
201 t.Errorf("&Order's singular table name should be order")
202 }
203
204 if DB.NewScope([]Order{}).TableName() != "order" {
205 t.Errorf("[]Order's singular table name should be order")
206 }
207
208 if DB.NewScope(&[]Order{}).TableName() != "order" {
209 t.Errorf("&[]Order's singular table name should be order")
210 }
211
212 if DB.NewScope(&Cart{}).TableName() != "shopping_cart" {
213 t.Errorf("&Cart's singular table name should be shopping_cart")
214 }
215
216 if DB.NewScope(Cart{}).TableName() != "shopping_cart" {
217 t.Errorf("Cart's singular table name should be shopping_cart")
218 }
219
220 if DB.NewScope(&[]Cart{}).TableName() != "shopping_cart" {
221 t.Errorf("&[]Cart's singular table name should be shopping_cart")
222 }
223
224 if DB.NewScope([]Cart{}).TableName() != "shopping_cart" {
225 t.Errorf("[]Cart's singular table name should be shopping_cart")
226 }
227 DB.SingularTable(false)
228 }
229
230 func TestSqlNullValue(t *testing.T) {
231 DB.DropTable(&NullValue{})
232 DB.AutoMigrate(&NullValue{})
233
234 if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello", Valid: true},
235 Age: sql.NullInt64{Int64: 18, Valid: true},
236 Male: sql.NullBool{Bool: true, Valid: true},
237 Height: sql.NullFloat64{Float64: 100.11, Valid: true},
238 AddedAt: NullTime{Time: time.Now(), Valid: true},
239 }).Error; err != nil {
240 t.Errorf("Not error should raise when test null value")
241 }
242
243 var nv NullValue
244 DB.First(&nv, "name = ?", "hello")
245
246 if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
247 t.Errorf("Should be able to fetch null value")
248 }
249
250 if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello-2", Valid: true},
251 Age: sql.NullInt64{Int64: 18, Valid: false},
252 Male: sql.NullBool{Bool: true, Valid: true},
253 Height: sql.NullFloat64{Float64: 100.11, Valid: true},
254 AddedAt: NullTime{Time: time.Now(), Valid: false},
255 }).Error; err != nil {
256 t.Errorf("Not error should raise when test null value")
257 }
258
259 var nv2 NullValue
260 DB.First(&nv2, "name = ?", "hello-2")
261 if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
262 t.Errorf("Should be able to fetch null value")
263 }
264
265 if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello-3", Valid: false},
266 Age: sql.NullInt64{Int64: 18, Valid: false},
267 Male: sql.NullBool{Bool: true, Valid: true},
268 Height: sql.NullFloat64{Float64: 100.11, Valid: true},
269 AddedAt: NullTime{Time: time.Now(), Valid: false},
270 }).Error; err == nil {
271 t.Errorf("Can't save because of name can't be null")
272 }
273 }
274
275 func TestTransaction(t *testing.T) {
276 tx := DB.Begin()
277 u := User{Name: "transcation"}
278 if err := tx.Save(&u).Error; err != nil {
279 t.Errorf("No error should raise")
280 }
281
282 if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
283 t.Errorf("Should find saved record")
284 }
285
286 if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil {
287 t.Errorf("Should return the underlying sql.Tx")
288 }
289
290 tx.Rollback()
291
292 if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {
293 t.Errorf("Should not find record after rollback")
294 }
295
296 tx2 := DB.Begin()
297 u2 := User{Name: "transcation-2"}
298 if err := tx2.Save(&u2).Error; err != nil {
299 t.Errorf("No error should raise")
300 }
301
302 if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
303 t.Errorf("Should find saved record")
304 }
305
306 tx2.Commit()
307
308 if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
309 t.Errorf("Should be able to find committed record")
310 }
311 }
312
313 func TestRow(t *testing.T) {
314 user1 := User{Name: "RowUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
315 user2 := User{Name: "RowUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
316 user3 := User{Name: "RowUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
317 DB.Save(&user1).Save(&user2).Save(&user3)
318
319 row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row()
320 var age int64
321 row.Scan(&age)
322 if age != 10 {
323 t.Errorf("Scan with Row")
324 }
325 }
326
327 func TestRows(t *testing.T) {
328 user1 := User{Name: "RowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
329 user2 := User{Name: "RowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
330 user3 := User{Name: "RowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
331 DB.Save(&user1).Save(&user2).Save(&user3)
332
333 rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
334 if err != nil {
335 t.Errorf("Not error should happen, but got")
336 }
337
338 count := 0
339 for rows.Next() {
340 var name string
341 var age int64
342 rows.Scan(&name, &age)
343 count++
344 }
345 if count != 2 {
346 t.Errorf("Should found two records with name 3")
347 }
348 }
349
350 func TestScan(t *testing.T) {
351 user1 := User{Name: "ScanUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
352 user2 := User{Name: "ScanUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
353 user3 := User{Name: "ScanUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
354 DB.Save(&user1).Save(&user2).Save(&user3)
355
356 type result struct {
357 Name string
358 Age int
359 }
360
361 var res result
362 DB.Table("users").Select("name, age").Where("name = ?", user3.Name).Scan(&res)
363 if res.Name != user3.Name {
364 t.Errorf("Scan into struct should work")
365 }
366
367 var doubleAgeRes result
368 DB.Table("users").Select("age + age as age").Where("name = ?", user3.Name).Scan(&doubleAgeRes)
369 if doubleAgeRes.Age != res.Age*2 {
370 t.Errorf("Scan double age as age")
371 }
372
373 var ress []result
374 DB.Table("users").Select("name, age").Where("name in (?)", []string{user2.Name, user3.Name}).Scan(&ress)
375 if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name {
376 t.Errorf("Scan into struct map")
377 }
378 }
379
380 func TestRaw(t *testing.T) {
381 user1 := User{Name: "ExecRawSqlUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
382 user2 := User{Name: "ExecRawSqlUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
383 user3 := User{Name: "ExecRawSqlUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
384 DB.Save(&user1).Save(&user2).Save(&user3)
385
386 type result struct {
387 Name string
388 Email string
389 }
390
391 var ress []result
392 DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&ress)
393 if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name {
394 t.Errorf("Raw with scan")
395 }
396
397 rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows()
398 count := 0
399 for rows.Next() {
400 count++
401 }
402 if count != 1 {
403 t.Errorf("Raw with Rows should find one record with name 3")
404 }
405
406 DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
407 if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound {
408 t.Error("Raw sql to update records")
409 }
410 }
411
412 func TestGroup(t *testing.T) {
413 rows, err := DB.Select("name").Table("users").Group("name").Rows()
414
415 if err == nil {
416 defer rows.Close()
417 for rows.Next() {
418 var name string
419 rows.Scan(&name)
420 }
421 } else {
422 t.Errorf("Should not raise any error")
423 }
424 }
425
426 func TestJoins(t *testing.T) {
427 var user = User{
428 Name: "joins",
429 Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
430 }
431 DB.Save(&user)
432
433 var result User
434 DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result)
435 if result.Name != "joins" || result.Id != user.Id {
436 t.Errorf("Should find all two emails with Join")
437 }
438 }
439
440 func TestJoinsWithSelect(t *testing.T) {
441 type result struct {
442 Name string
443 Email string
444 }
445
446 user := User{
447 Name: "joins_with_select",
448 Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
449 }
450 DB.Save(&user)
451
452 var results []result
453 DB.Table("users").Select("name, email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results)
454 if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" {
455 t.Errorf("Should find all two emails with Join select")
456 }
457 }
458
459 func TestHaving(t *testing.T) {
460 rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows()
461
462 if err == nil {
463 defer rows.Close()
464 for rows.Next() {
465 var name string
466 var total int64
467 rows.Scan(&name, &total)
468
469 if name == "2" && total != 1 {
470 t.Errorf("Should have one user having name 2")
471 }
472 if name == "3" && total != 2 {
473 t.Errorf("Should have two users having name 3")
474 }
475 }
476 } else {
477 t.Errorf("Should not raise any error")
478 }
479 }
480
481 func DialectHasTzSupport() bool {
482 // NB: mssql and FoundationDB do not support time zones.
483 if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" {
484 return false
485 }
486 return true
487 }
488
489 func TestTimeWithZone(t *testing.T) {
490 var format = "2006-01-02 15:04:05 -0700"
491 var times []time.Time
492 GMT8, _ := time.LoadLocation("Asia/Shanghai")
493 times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8))
494 times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC))
495
496 for index, vtime := range times {
497 name := "time_with_zone_" + strconv.Itoa(index)
498 user := User{Name: name, Birthday: vtime}
499
500 if !DialectHasTzSupport() {
501 // If our driver dialect doesn't support TZ's, just use UTC for everything here.
502 user.Birthday = vtime.UTC()
503 }
504
505 DB.Save(&user)
506 expectedBirthday := "2013-02-18 17:51:49 +0000"
507 foundBirthday := user.Birthday.UTC().Format(format)
508 if foundBirthday != expectedBirthday {
509 t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday)
510 }
511
512 var findUser, findUser2, findUser3 User
513 DB.First(&findUser, "name = ?", name)
514 foundBirthday = findUser.Birthday.UTC().Format(format)
515 if foundBirthday != expectedBirthday {
516 t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v or %+v", name, expectedBirthday, foundBirthday)
517 }
518
519 if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
520 t.Errorf("User should be found")
521 }
522
523 if !DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() {
524 t.Errorf("User should not be found")
525 }
526 }
527 }
528
529 func TestHstore(t *testing.T) {
530 type Details struct {
531 Id int64
532 Bulk gorm.Hstore
533 }
534
535 if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
536 t.Skip()
537 }
538
539 if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil {
540 fmt.Println("\033[31mHINT: Must be superuser to create hstore extension (ALTER USER gorm WITH SUPERUSER;)\033[0m")
541 panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err))
542 }
543
544 DB.Exec("drop table details")
545
546 if err := DB.CreateTable(&Details{}).Error; err != nil {
547 panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
548 }
549
550 bankAccountId, phoneNumber, opinion := "123456", "14151321232", "sharkbait"
551 bulk := map[string]*string{
552 "bankAccountId": &bankAccountId,
553 "phoneNumber": &phoneNumber,
554 "opinion": &opinion,
555 }
556 d := Details{Bulk: bulk}
557 DB.Save(&d)
558
559 var d2 Details
560 if err := DB.First(&d2).Error; err != nil {
561 t.Errorf("Got error when tried to fetch details: %+v", err)
562 }
563
564 for k := range bulk {
565 if r, ok := d2.Bulk[k]; ok {
566 if res, _ := bulk[k]; *res != *r {
567 t.Errorf("Details should be equal")
568 }
569 } else {
570 t.Errorf("Details should be existed")
571 }
572 }
573 }
574
575 func TestSetAndGet(t *testing.T) {
576 if value, ok := DB.Set("hello", "world").Get("hello"); !ok {
577 t.Errorf("Should be able to get setting after set")
578 } else {
579 if value.(string) != "world" {
580 t.Errorf("Setted value should not be changed")
581 }
582 }
583
584 if _, ok := DB.Get("non_existing"); ok {
585 t.Errorf("Get non existing key should return error")
586 }
587 }
588
589 func TestCompatibilityMode(t *testing.T) {
590 DB, _ := gorm.Open("testdb", "")
591 testdb.SetQueryFunc(func(query string) (driver.Rows, error) {
592 columns := []string{"id", "name", "age"}
593 result := `
594 1,Tim,20
595 2,Joe,25
596 3,Bob,30
597 `
598 return testdb.RowsFromCSVString(columns, result), nil
599 })
600
601 var users []User
602 DB.Find(&users)
603 if (users[0].Name != "Tim") || len(users) != 3 {
604 t.Errorf("Unexcepted result returned")
605 }
606 }
607
608 func TestOpenExistingDB(t *testing.T) {
609 DB.Save(&User{Name: "jnfeinstein"})
610 dialect := os.Getenv("GORM_DIALECT")
611
612 db, err := gorm.Open(dialect, DB.DB())
613 if err != nil {
614 t.Errorf("Should have wrapped the existing DB connection")
615 }
616
617 var user User
618 if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound {
619 t.Errorf("Should have found existing record")
620 }
621 }
622
623 func BenchmarkGorm(b *testing.B) {
624 b.N = 2000
625 for x := 0; x < b.N; x++ {
626 e := strconv.Itoa(x) + "benchmark@example.org"
627 email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
628 // Insert
629 DB.Save(&email)
630 // Query
631 DB.First(&BigEmail{}, "email = ?", e)
632 // Update
633 DB.Model(&email).UpdateColumn("email", "new-"+e)
634 // Delete
635 DB.Delete(&email)
636 }
637 }
638
639 func BenchmarkRawSql(b *testing.B) {
640 DB, _ := sql.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable")
641 DB.SetMaxIdleConns(10)
642 insertSql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
643 querySql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
644 updateSql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
645 deleteSql := "DELETE FROM orders WHERE id = $1"
646
647 b.N = 2000
648 for x := 0; x < b.N; x++ {
649 var id int64
650 e := strconv.Itoa(x) + "benchmark@example.org"
651 email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
652 // Insert
653 DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
654 // Query
655 rows, _ := DB.Query(querySql, email.Email)
656 rows.Close()
657 // Update
658 DB.Exec(updateSql, "new-"+e, time.Now(), id)
659 // Delete
660 DB.Exec(deleteSql, id)
661 }
662 }
0 package gorm_test
1
2 import (
3 "fmt"
4 "testing"
5 "time"
6 )
7
8 func runMigration() {
9 if err := DB.DropTableIfExists(&User{}).Error; err != nil {
10 fmt.Printf("Got error when try to delete table users, %+v\n", err)
11 }
12
13 for _, table := range []string{"animals", "user_languages"} {
14 DB.Exec(fmt.Sprintf("drop table %v;", table))
15 }
16
17 values := []interface{}{&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}}
18 for _, value := range values {
19 DB.DropTable(value)
20 }
21
22 if err := DB.AutoMigrate(values...).Error; err != nil {
23 panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
24 }
25 }
26
27 func TestIndexes(t *testing.T) {
28 if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil {
29 t.Errorf("Got error when tried to create index: %+v", err)
30 }
31
32 scope := DB.NewScope(&Email{})
33 if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
34 t.Errorf("Email should have index idx_email_email")
35 }
36
37 if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil {
38 t.Errorf("Got error when tried to remove index: %+v", err)
39 }
40
41 if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
42 t.Errorf("Email's index idx_email_email should be deleted")
43 }
44
45 if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
46 t.Errorf("Got error when tried to create index: %+v", err)
47 }
48
49 if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
50 t.Errorf("Email should have index idx_email_email_and_user_id")
51 }
52
53 if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
54 t.Errorf("Got error when tried to remove index: %+v", err)
55 }
56
57 if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
58 t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
59 }
60
61 if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
62 t.Errorf("Got error when tried to create index: %+v", err)
63 }
64
65 if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
66 t.Errorf("Email should have index idx_email_email_and_user_id")
67 }
68
69 if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil {
70 t.Errorf("Should get to create duplicate record when having unique index")
71 }
72
73 if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
74 t.Errorf("Got error when tried to remove index: %+v", err)
75 }
76
77 if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
78 t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
79 }
80
81 if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil {
82 t.Errorf("Should be able to create duplicated emails after remove unique index")
83 }
84 }
85
86 type BigEmail struct {
87 Id int64
88 UserId int64
89 Email string `sql:"index:idx_email_agent"`
90 UserAgent string `sql:"index:idx_email_agent"`
91 RegisteredAt time.Time `sql:"unique_index"`
92 CreatedAt time.Time
93 UpdatedAt time.Time
94 }
95
96 func (b BigEmail) TableName() string {
97 return "emails"
98 }
99
100 func TestAutoMigration(t *testing.T) {
101 DB.AutoMigrate(&Address{})
102 if err := DB.Table("emails").AutoMigrate(&BigEmail{}).Error; err != nil {
103 t.Errorf("Auto Migrate should not raise any error")
104 }
105
106 DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()})
107
108 scope := DB.NewScope(&BigEmail{})
109 if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_agent") {
110 t.Errorf("Failed to create index")
111 }
112
113 if !scope.Dialect().HasIndex(scope, scope.TableName(), "uix_emails_registered_at") {
114 t.Errorf("Failed to create index")
115 }
116
117 var bigemail BigEmail
118 DB.First(&bigemail, "user_agent = ?", "pc")
119 if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() {
120 t.Error("Big Emails should be saved and fetched correctly")
121 }
122 }
0 package gorm
1
2 import "time"
3
4 type Model struct {
5 ID uint `gorm:"primary_key"`
6 CreatedAt time.Time
7 UpdatedAt time.Time
8 DeletedAt *time.Time `sql:"index"`
9 }
0 package gorm
1
2 import (
3 "database/sql"
4 "fmt"
5 "go/ast"
6 "reflect"
7 "strconv"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/qor/inflection"
13 )
14
15 var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
16 return defaultTableName
17 }
18
19 type safeModelStructsMap struct {
20 m map[reflect.Type]*ModelStruct
21 l *sync.RWMutex
22 }
23
24 func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) {
25 s.l.Lock()
26 defer s.l.Unlock()
27 s.m[key] = value
28 }
29
30 func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct {
31 s.l.RLock()
32 defer s.l.RUnlock()
33 return s.m[key]
34 }
35
36 func newModelStructsMap() *safeModelStructsMap {
37 return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)}
38 }
39
40 var modelStructsMap = newModelStructsMap()
41
42 type ModelStruct struct {
43 PrimaryFields []*StructField
44 StructFields []*StructField
45 ModelType reflect.Type
46 defaultTableName string
47 cached bool
48 }
49
50 func (s ModelStruct) TableName(db *DB) string {
51 return DefaultTableNameHandler(db, s.defaultTableName)
52 }
53
54 type StructField struct {
55 DBName string
56 Name string
57 Names []string
58 IsPrimaryKey bool
59 IsNormal bool
60 IsIgnored bool
61 IsScanner bool
62 HasDefaultValue bool
63 Tag reflect.StructTag
64 Struct reflect.StructField
65 IsForeignKey bool
66 Relationship *Relationship
67 }
68
69 func (structField *StructField) clone() *StructField {
70 return &StructField{
71 DBName: structField.DBName,
72 Name: structField.Name,
73 Names: structField.Names,
74 IsPrimaryKey: structField.IsPrimaryKey,
75 IsNormal: structField.IsNormal,
76 IsIgnored: structField.IsIgnored,
77 IsScanner: structField.IsScanner,
78 HasDefaultValue: structField.HasDefaultValue,
79 Tag: structField.Tag,
80 Struct: structField.Struct,
81 IsForeignKey: structField.IsForeignKey,
82 Relationship: structField.Relationship,
83 }
84 }
85
86 type Relationship struct {
87 Kind string
88 PolymorphicType string
89 PolymorphicDBName string
90 ForeignFieldNames []string
91 ForeignDBNames []string
92 AssociationForeignFieldNames []string
93 AssociationForeignStructFieldNames []string
94 AssociationForeignDBNames []string
95 JoinTableHandler JoinTableHandlerInterface
96 }
97
98 func (scope *Scope) GetModelStruct() *ModelStruct {
99 var modelStruct ModelStruct
100
101 reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
102 if !reflectValue.IsValid() {
103 return &modelStruct
104 }
105
106 if reflectValue.Kind() == reflect.Slice {
107 reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem()))
108 }
109
110 scopeType := reflectValue.Type()
111
112 if scopeType.Kind() == reflect.Ptr {
113 scopeType = scopeType.Elem()
114 }
115
116 if value := modelStructsMap.Get(scopeType); value != nil {
117 return value
118 }
119
120 modelStruct.ModelType = scopeType
121 if scopeType.Kind() != reflect.Struct {
122 return &modelStruct
123 }
124
125 if tabler, ok := reflect.New(scopeType).Interface().(interface {
126 TableName() string
127 }); ok {
128 modelStruct.defaultTableName = tabler.TableName()
129 } else {
130 name := ToDBName(scopeType.Name())
131 if scope.db == nil || !scope.db.parent.singularTable {
132 name = inflection.Plural(name)
133 }
134
135 modelStruct.defaultTableName = name
136 }
137
138 // Get all fields
139 fields := []*StructField{}
140 for i := 0; i < scopeType.NumField(); i++ {
141 if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) {
142 field := &StructField{
143 Struct: fieldStruct,
144 Name: fieldStruct.Name,
145 Names: []string{fieldStruct.Name},
146 Tag: fieldStruct.Tag,
147 }
148
149 if fieldStruct.Tag.Get("sql") == "-" {
150 field.IsIgnored = true
151 } else {
152 sqlSettings := parseTagSetting(field.Tag.Get("sql"))
153 gormSettings := parseTagSetting(field.Tag.Get("gorm"))
154 if _, ok := gormSettings["PRIMARY_KEY"]; ok {
155 field.IsPrimaryKey = true
156 modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
157 }
158
159 if _, ok := sqlSettings["DEFAULT"]; ok {
160 field.HasDefaultValue = true
161 }
162
163 if value, ok := gormSettings["COLUMN"]; ok {
164 field.DBName = value
165 } else {
166 field.DBName = ToDBName(fieldStruct.Name)
167 }
168 }
169 fields = append(fields, field)
170 }
171 }
172
173 var finished = make(chan bool)
174 go func(finished chan bool) {
175 for _, field := range fields {
176 if !field.IsIgnored {
177 fieldStruct := field.Struct
178 indirectType := fieldStruct.Type
179 if indirectType.Kind() == reflect.Ptr {
180 indirectType = indirectType.Elem()
181 }
182
183 if _, isScanner := reflect.New(indirectType).Interface().(sql.Scanner); isScanner {
184 field.IsScanner, field.IsNormal = true, true
185 }
186
187 if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime {
188 field.IsNormal = true
189 }
190
191 if !field.IsNormal {
192 gormSettings := parseTagSetting(field.Tag.Get("gorm"))
193 toScope := scope.New(reflect.New(fieldStruct.Type).Interface())
194
195 getForeignField := func(column string, fields []*StructField) *StructField {
196 for _, field := range fields {
197 if field.Name == column || field.DBName == ToDBName(column) {
198 return field
199 }
200 }
201 return nil
202 }
203
204 var relationship = &Relationship{}
205
206 if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" {
207 if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil {
208 if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil {
209 relationship.ForeignFieldNames = []string{polymorphicField.Name}
210 relationship.ForeignDBNames = []string{polymorphicField.DBName}
211 relationship.AssociationForeignFieldNames = []string{scope.PrimaryField().Name}
212 relationship.AssociationForeignDBNames = []string{scope.PrimaryField().DBName}
213 relationship.PolymorphicType = polymorphicType.Name
214 relationship.PolymorphicDBName = polymorphicType.DBName
215 polymorphicType.IsForeignKey = true
216 polymorphicField.IsForeignKey = true
217 }
218 }
219 }
220
221 var foreignKeys []string
222 if foreignKey, ok := gormSettings["FOREIGNKEY"]; ok {
223 foreignKeys = append(foreignKeys, foreignKey)
224 }
225 switch indirectType.Kind() {
226 case reflect.Slice:
227 elemType := indirectType.Elem()
228 if elemType.Kind() == reflect.Ptr {
229 elemType = elemType.Elem()
230 }
231
232 if elemType.Kind() == reflect.Struct {
233 if many2many := gormSettings["MANY2MANY"]; many2many != "" {
234 relationship.Kind = "many_to_many"
235
236 // foreign keys
237 if len(foreignKeys) == 0 {
238 for _, field := range scope.PrimaryFields() {
239 foreignKeys = append(foreignKeys, field.DBName)
240 }
241 }
242
243 for _, foreignKey := range foreignKeys {
244 if field, ok := scope.FieldByName(foreignKey); ok {
245 relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName)
246 joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName
247 relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
248 }
249 }
250
251 // association foreign keys
252 var associationForeignKeys []string
253 if foreignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
254 associationForeignKeys = []string{gormSettings["ASSOCIATIONFOREIGNKEY"]}
255 } else {
256 for _, field := range toScope.PrimaryFields() {
257 associationForeignKeys = append(associationForeignKeys, field.DBName)
258 }
259 }
260
261 for _, name := range associationForeignKeys {
262 if field, ok := toScope.FieldByName(name); ok {
263 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
264 relationship.AssociationForeignStructFieldNames = append(relationship.AssociationForeignFieldNames, field.Name)
265 joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
266 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
267 }
268 }
269
270 joinTableHandler := JoinTableHandler{}
271 joinTableHandler.Setup(relationship, many2many, scopeType, elemType)
272 relationship.JoinTableHandler = &joinTableHandler
273 field.Relationship = relationship
274 } else {
275 relationship.Kind = "has_many"
276
277 if len(foreignKeys) == 0 {
278 for _, field := range scope.PrimaryFields() {
279 if foreignField := getForeignField(scopeType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil {
280 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name)
281 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName)
282 relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
283 relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
284 foreignField.IsForeignKey = true
285 }
286 }
287 } else {
288 for _, foreignKey := range foreignKeys {
289 if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
290 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name)
291 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName)
292 relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
293 relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
294 foreignField.IsForeignKey = true
295 }
296 }
297 }
298
299 if len(relationship.ForeignFieldNames) != 0 {
300 field.Relationship = relationship
301 }
302 }
303 } else {
304 field.IsNormal = true
305 }
306 case reflect.Struct:
307 if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
308 for _, toField := range toScope.GetStructFields() {
309 toField = toField.clone()
310 toField.Names = append([]string{fieldStruct.Name}, toField.Names...)
311 modelStruct.StructFields = append(modelStruct.StructFields, toField)
312 if toField.IsPrimaryKey {
313 modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField)
314 }
315 }
316 continue
317 } else {
318 if len(foreignKeys) == 0 {
319 for _, f := range scope.PrimaryFields() {
320 if foreignField := getForeignField(modelStruct.ModelType.Name()+f.Name, toScope.GetStructFields()); foreignField != nil {
321 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name)
322 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName)
323 relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
324 relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
325 foreignField.IsForeignKey = true
326 }
327 }
328 } else {
329 for _, foreignKey := range foreignKeys {
330 if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
331 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name)
332 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName)
333 relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
334 relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
335 foreignField.IsForeignKey = true
336 }
337 }
338 }
339
340 if len(relationship.ForeignFieldNames) != 0 {
341 relationship.Kind = "has_one"
342 field.Relationship = relationship
343 } else {
344 if len(foreignKeys) == 0 {
345 for _, f := range toScope.PrimaryFields() {
346 if foreignField := getForeignField(field.Name+f.Name, fields); foreignField != nil {
347 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name)
348 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName)
349 relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
350 relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
351 foreignField.IsForeignKey = true
352 }
353 }
354 } else {
355 for _, foreignKey := range foreignKeys {
356 if foreignField := getForeignField(foreignKey, fields); foreignField != nil {
357 relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name)
358 relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName)
359 relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
360 relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
361 foreignField.IsForeignKey = true
362 }
363 }
364 }
365
366 if len(relationship.ForeignFieldNames) != 0 {
367 relationship.Kind = "belongs_to"
368 field.Relationship = relationship
369 }
370 }
371 }
372 default:
373 field.IsNormal = true
374 }
375 }
376
377 if field.IsNormal {
378 if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" {
379 field.IsPrimaryKey = true
380 modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
381 }
382 }
383 }
384 modelStruct.StructFields = append(modelStruct.StructFields, field)
385 }
386 finished <- true
387 }(finished)
388
389 modelStructsMap.Set(scopeType, &modelStruct)
390
391 <-finished
392 modelStruct.cached = true
393
394 return &modelStruct
395 }
396
397 func (scope *Scope) GetStructFields() (fields []*StructField) {
398 return scope.GetModelStruct().StructFields
399 }
400
401 func (scope *Scope) generateSqlTag(field *StructField) string {
402 var sqlType string
403 structType := field.Struct.Type
404 if structType.Kind() == reflect.Ptr {
405 structType = structType.Elem()
406 }
407 reflectValue := reflect.Indirect(reflect.New(structType))
408 sqlSettings := parseTagSetting(field.Tag.Get("sql"))
409
410 if value, ok := sqlSettings["TYPE"]; ok {
411 sqlType = value
412 }
413
414 additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"]
415 if value, ok := sqlSettings["DEFAULT"]; ok {
416 additionalType = additionalType + " DEFAULT " + value
417 }
418
419 if field.IsScanner {
420 var getScannerValue func(reflect.Value)
421 getScannerValue = func(value reflect.Value) {
422 reflectValue = value
423 if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
424 getScannerValue(reflectValue.Field(0))
425 }
426 }
427 getScannerValue(reflectValue)
428 }
429
430 if sqlType == "" {
431 var size = 255
432
433 if value, ok := sqlSettings["SIZE"]; ok {
434 size, _ = strconv.Atoi(value)
435 }
436
437 _, autoIncrease := sqlSettings["AUTO_INCREMENT"]
438 if field.IsPrimaryKey {
439 autoIncrease = true
440 }
441
442 sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
443 }
444
445 if strings.TrimSpace(additionalType) == "" {
446 return sqlType
447 } else {
448 return fmt.Sprintf("%v %v", sqlType, additionalType)
449 }
450 }
451
452 func parseTagSetting(str string) map[string]string {
453 tags := strings.Split(str, ";")
454 setting := map[string]string{}
455 for _, value := range tags {
456 v := strings.Split(value, ":")
457 k := strings.TrimSpace(strings.ToUpper(v[0]))
458 if len(v) >= 2 {
459 setting[k] = strings.Join(v[1:], ":")
460 } else {
461 setting[k] = k
462 }
463 }
464 return setting
465 }
0 package gorm
1
2 import (
3 "fmt"
4 "reflect"
5 "time"
6 )
7
8 type mssql struct {
9 commonDialect
10 }
11
12 func (mssql) HasTop() bool {
13 return true
14 }
15
16 func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
17 switch value.Kind() {
18 case reflect.Bool:
19 return "bit"
20 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
21 if autoIncrease {
22 return "int IDENTITY(1,1)"
23 }
24 return "int"
25 case reflect.Int64, reflect.Uint64:
26 if autoIncrease {
27 return "bigint IDENTITY(1,1)"
28 }
29 return "bigint"
30 case reflect.Float32, reflect.Float64:
31 return "float"
32 case reflect.String:
33 if size > 0 && size < 65532 {
34 return fmt.Sprintf("nvarchar(%d)", size)
35 }
36 return "text"
37 case reflect.Struct:
38 if _, ok := value.Interface().(time.Time); ok {
39 return "datetime2"
40 }
41 default:
42 if _, ok := value.Interface().([]byte); ok {
43 if size > 0 && size < 65532 {
44 return fmt.Sprintf("varchar(%d)", size)
45 }
46 return "text"
47 }
48 }
49 panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
50 }
51
52 func (s mssql) HasTable(scope *Scope, tableName string) bool {
53 var (
54 count int
55 databaseName = s.CurrentDatabase(scope)
56 )
57 s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName)
58 return count > 0
59 }
60
61 func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
62 var (
63 count int
64 databaseName = s.CurrentDatabase(scope)
65 )
66 s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
67 return count > 0
68 }
69
70 func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
71 var count int
72 s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName)
73 return count > 0
74 }
75
76 func (s mssql) CurrentDatabase(scope *Scope) (name string) {
77 s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]")
78 return
79 }
0 package gorm_test
1
2 import (
3 "fmt"
4 "os"
5 "testing"
6 )
7
8 type Blog struct {
9 ID uint `gorm:"primary_key"`
10 Locale string `gorm:"primary_key"`
11 Subject string
12 Body string
13 Tags []Tag `gorm:"many2many:blog_tags;"`
14 }
15
16 type Tag struct {
17 ID uint `gorm:"primary_key"`
18 Locale string `gorm:"primary_key"`
19 Value string
20 }
21
22 func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
23 if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" {
24 DB.Exec(fmt.Sprintf("drop table blog_tags;"))
25 DB.AutoMigrate(&Blog{}, &Tag{})
26 blog := Blog{
27 Locale: "ZH",
28 Subject: "subject",
29 Body: "body",
30 Tags: []Tag{
31 {Locale: "ZH", Value: "tag1"},
32 {Locale: "ZH", Value: "tag2"},
33 },
34 }
35
36 DB.Save(&blog)
37 DB.Model(&blog).Association("Tags").Append([]Tag{{Locale: "ZH", Value: "tag3"}})
38
39 var tags []Tag
40 DB.Model(&blog).Related(&tags, "Tags")
41 if len(tags) != 3 {
42 t.Errorf("should found 3 tags with blog")
43 }
44 }
45 }
0 package gorm
1
2 import (
3 "fmt"
4 "reflect"
5 "time"
6 )
7
8 type mysql struct {
9 commonDialect
10 }
11
12 func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
13 switch value.Kind() {
14 case reflect.Bool:
15 return "boolean"
16 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
17 if autoIncrease {
18 return "int AUTO_INCREMENT"
19 }
20 return "int"
21 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
22 if autoIncrease {
23 return "int unsigned AUTO_INCREMENT"
24 }
25 return "int unsigned"
26 case reflect.Int64:
27 if autoIncrease {
28 return "bigint AUTO_INCREMENT"
29 }
30 return "bigint"
31 case reflect.Uint64:
32 if autoIncrease {
33 return "bigint unsigned AUTO_INCREMENT"
34 }
35 return "bigint unsigned"
36 case reflect.Float32, reflect.Float64:
37 return "double"
38 case reflect.String:
39 if size > 0 && size < 65532 {
40 return fmt.Sprintf("varchar(%d)", size)
41 }
42 return "longtext"
43 case reflect.Struct:
44 if _, ok := value.Interface().(time.Time); ok {
45 return "timestamp NULL"
46 }
47 default:
48 if _, ok := value.Interface().([]byte); ok {
49 if size > 0 && size < 65532 {
50 return fmt.Sprintf("varbinary(%d)", size)
51 }
52 return "longblob"
53 }
54 }
55 panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
56 }
57
58 func (mysql) Quote(key string) string {
59 return fmt.Sprintf("`%s`", key)
60 }
61
62 func (mysql) SelectFromDummyTable() string {
63 return "FROM DUAL"
64 }
65
66 func (s mysql) CurrentDatabase(scope *Scope) (name string) {
67 s.RawScanString(scope, &name, "SELECT DATABASE()")
68 return
69 }
0 package gorm_test
1
2 import "testing"
3
4 type PointerStruct struct {
5 ID int64
6 Name *string
7 Num *int
8 }
9
10 type NormalStruct struct {
11 ID int64
12 Name string
13 Num int
14 }
15
16 func TestPointerFields(t *testing.T) {
17 DB.DropTable(&PointerStruct{})
18 DB.AutoMigrate(&PointerStruct{})
19 var name = "pointer struct 1"
20 var num = 100
21 pointerStruct := PointerStruct{Name: &name, Num: &num}
22 if DB.Create(&pointerStruct).Error != nil {
23 t.Errorf("Failed to save pointer struct")
24 }
25
26 var pointerStructResult PointerStruct
27 if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num {
28 t.Errorf("Failed to query saved pointer struct")
29 }
30
31 var tableName = DB.NewScope(&PointerStruct{}).TableName()
32
33 var normalStruct NormalStruct
34 DB.Table(tableName).First(&normalStruct)
35 if normalStruct.Name != name || normalStruct.Num != num {
36 t.Errorf("Failed to query saved Normal struct")
37 }
38
39 var nilPointerStruct = PointerStruct{}
40 if err := DB.Create(&nilPointerStruct).Error; err != nil {
41 t.Errorf("Failed to save nil pointer struct", err)
42 }
43
44 var pointerStruct2 PointerStruct
45 if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
46 t.Errorf("Failed to query saved nil pointer struct", err)
47 }
48
49 var normalStruct2 NormalStruct
50 if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
51 t.Errorf("Failed to query saved nil pointer struct", err)
52 }
53
54 var partialNilPointerStruct1 = PointerStruct{Num: &num}
55 if err := DB.Create(&partialNilPointerStruct1).Error; err != nil {
56 t.Errorf("Failed to save partial nil pointer struct", err)
57 }
58
59 var pointerStruct3 PointerStruct
60 if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num {
61 t.Errorf("Failed to query saved partial nil pointer struct", err)
62 }
63
64 var normalStruct3 NormalStruct
65 if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num {
66 t.Errorf("Failed to query saved partial pointer struct", err)
67 }
68
69 var partialNilPointerStruct2 = PointerStruct{Name: &name}
70 if err := DB.Create(&partialNilPointerStruct2).Error; err != nil {
71 t.Errorf("Failed to save partial nil pointer struct", err)
72 }
73
74 var pointerStruct4 PointerStruct
75 if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name {
76 t.Errorf("Failed to query saved partial nil pointer struct", err)
77 }
78
79 var normalStruct4 NormalStruct
80 if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name {
81 t.Errorf("Failed to query saved partial pointer struct", err)
82 }
83 }
0 package gorm_test
1
2 import "testing"
3
4 type Cat struct {
5 Id int
6 Name string
7 Toy Toy `gorm:"polymorphic:Owner;"`
8 }
9
10 type Dog struct {
11 Id int
12 Name string
13 Toys []Toy `gorm:"polymorphic:Owner;"`
14 }
15
16 type Toy struct {
17 Id int
18 Name string
19 OwnerId int
20 OwnerType string
21 }
22
23 func TestPolymorphic(t *testing.T) {
24 DB.AutoMigrate(&Cat{})
25 DB.AutoMigrate(&Dog{})
26 DB.AutoMigrate(&Toy{})
27
28 cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat nip"}}
29 dog := Dog{Name: "Pluto", Toys: []Toy{Toy{Name: "orange ball"}, Toy{Name: "yellow ball"}}}
30 DB.Save(&cat).Save(&dog)
31
32 var catToys []Toy
33 if DB.Model(&cat).Related(&catToys, "Toy").RecordNotFound() {
34 t.Errorf("Did not find any has one polymorphic association")
35 } else if len(catToys) != 1 {
36 t.Errorf("Should have found only one polymorphic has one association")
37 } else if catToys[0].Name != cat.Toy.Name {
38 t.Errorf("Should have found the proper has one polymorphic association")
39 }
40
41 var dogToys []Toy
42 if DB.Model(&dog).Related(&dogToys, "Toys").RecordNotFound() {
43 t.Errorf("Did not find any polymorphic has many associations")
44 } else if len(dogToys) != len(dog.Toys) {
45 t.Errorf("Should have found all polymorphic has many associations")
46 }
47
48 if DB.Model(&cat).Association("Toy").Count() != 1 {
49 t.Errorf("Should return one polymorphic has one association")
50 }
51
52 if DB.Model(&dog).Association("Toys").Count() != 2 {
53 t.Errorf("Should return two polymorphic has many associations")
54 }
55 }
0 package gorm
1
2 import (
3 "database/sql"
4 "database/sql/driver"
5 "fmt"
6 "reflect"
7 "time"
8
9 "github.com/lib/pq/hstore"
10 )
11
12 type postgres struct {
13 commonDialect
14 }
15
16 func (postgres) BinVar(i int) string {
17 return fmt.Sprintf("$%v", i)
18 }
19
20 func (postgres) SupportLastInsertId() bool {
21 return false
22 }
23
24 func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
25 switch value.Kind() {
26 case reflect.Bool:
27 return "boolean"
28 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
29 if autoIncrease {
30 return "serial"
31 }
32 return "integer"
33 case reflect.Int64, reflect.Uint64:
34 if autoIncrease {
35 return "bigserial"
36 }
37 return "bigint"
38 case reflect.Float32, reflect.Float64:
39 return "numeric"
40 case reflect.String:
41 if size > 0 && size < 65532 {
42 return fmt.Sprintf("varchar(%d)", size)
43 }
44 return "text"
45 case reflect.Struct:
46 if _, ok := value.Interface().(time.Time); ok {
47 return "timestamp with time zone"
48 }
49 case reflect.Map:
50 if value.Type() == hstoreType {
51 return "hstore"
52 }
53 default:
54 if _, ok := value.Interface().([]byte); ok {
55 return "bytea"
56 }
57 }
58 panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
59 }
60
61 func (s postgres) ReturningStr(tableName, key string) string {
62 return fmt.Sprintf("RETURNING %v.%v", tableName, key)
63 }
64
65 func (s postgres) HasTable(scope *Scope, tableName string) bool {
66 var count int
67 s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName)
68 return count > 0
69 }
70
71 func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
72 var count int
73 s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName)
74 return count > 0
75 }
76
77 func (postgres) RemoveIndex(scope *Scope, indexName string) {
78 scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
79 }
80
81 func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
82 var count int
83 s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName)
84 return count > 0
85 }
86
87 func (s postgres) CurrentDatabase(scope *Scope) (name string) {
88 s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()")
89 return
90 }
91
92 var hstoreType = reflect.TypeOf(Hstore{})
93
94 type Hstore map[string]*string
95
96 func (h Hstore) Value() (driver.Value, error) {
97 hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
98 if len(h) == 0 {
99 return nil, nil
100 }
101
102 for key, value := range h {
103 var s sql.NullString
104 if value != nil {
105 s.String = *value
106 s.Valid = true
107 }
108 hstore.Map[key] = s
109 }
110 return hstore.Value()
111 }
112
113 func (h *Hstore) Scan(value interface{}) error {
114 hstore := hstore.Hstore{}
115
116 if err := hstore.Scan(value); err != nil {
117 return err
118 }
119
120 if len(hstore.Map) == 0 {
121 return nil
122 }
123
124 *h = Hstore{}
125 for k := range hstore.Map {
126 if hstore.Map[k].Valid {
127 s := hstore.Map[k].String
128 (*h)[k] = &s
129 } else {
130 (*h)[k] = nil
131 }
132 }
133
134 return nil
135 }
0 package gorm
1
2 import (
3 "database/sql/driver"
4 "errors"
5 "fmt"
6 "reflect"
7 "strings"
8 )
9
10 func getRealValue(value reflect.Value, columns []string) (results []interface{}) {
11 for _, column := range columns {
12 if reflect.Indirect(value).FieldByName(column).IsValid() {
13 result := reflect.Indirect(value).FieldByName(column).Interface()
14 if r, ok := result.(driver.Valuer); ok {
15 result, _ = r.Value()
16 }
17 results = append(results, result)
18 }
19 }
20 return
21 }
22
23 func equalAsString(a interface{}, b interface{}) bool {
24 return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b)
25 }
26
27 func Preload(scope *Scope) {
28 if scope.Search.preload == nil {
29 return
30 }
31
32 preloadMap := map[string]bool{}
33 fields := scope.Fields()
34 for _, preload := range scope.Search.preload {
35 schema, conditions := preload.schema, preload.conditions
36 keys := strings.Split(schema, ".")
37 currentScope := scope
38 currentFields := fields
39 originalConditions := conditions
40 conditions = []interface{}{}
41 for i, key := range keys {
42 var found bool
43 if preloadMap[strings.Join(keys[:i+1], ".")] {
44 goto nextLoop
45 }
46
47 if i == len(keys)-1 {
48 conditions = originalConditions
49 }
50
51 for _, field := range currentFields {
52 if field.Name != key || field.Relationship == nil {
53 continue
54 }
55
56 found = true
57 switch field.Relationship.Kind {
58 case "has_one":
59 currentScope.handleHasOnePreload(field, conditions)
60 case "has_many":
61 currentScope.handleHasManyPreload(field, conditions)
62 case "belongs_to":
63 currentScope.handleBelongsToPreload(field, conditions)
64 case "many_to_many":
65 currentScope.handleHasManyToManyPreload(field, conditions)
66 default:
67 currentScope.Err(errors.New("not supported relation"))
68 }
69 break
70 }
71
72 if !found {
73 value := reflect.ValueOf(currentScope.Value)
74 if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface {
75 value = value.Index(0).Elem()
76 }
77 scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type()))
78 return
79 }
80
81 preloadMap[strings.Join(keys[:i+1], ".")] = true
82
83 nextLoop:
84 if i < len(keys)-1 {
85 currentScope = currentScope.getColumnsAsScope(key)
86 currentFields = currentScope.Fields()
87 }
88 }
89 }
90
91 }
92
93 func makeSlice(typ reflect.Type) interface{} {
94 if typ.Kind() == reflect.Slice {
95 typ = typ.Elem()
96 }
97 sliceType := reflect.SliceOf(typ)
98 slice := reflect.New(sliceType)
99 slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
100 return slice.Interface()
101 }
102
103 func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
104 relation := field.Relationship
105
106 primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
107 if len(primaryKeys) == 0 {
108 return
109 }
110
111 results := makeSlice(field.Struct.Type)
112 scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
113 resultValues := reflect.Indirect(reflect.ValueOf(results))
114
115 for i := 0; i < resultValues.Len(); i++ {
116 result := resultValues.Index(i)
117 if scope.IndirectValue().Kind() == reflect.Slice {
118 value := getRealValue(result, relation.ForeignFieldNames)
119 objects := scope.IndirectValue()
120 for j := 0; j < objects.Len(); j++ {
121 if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) {
122 reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
123 break
124 }
125 }
126 } else {
127 if err := scope.SetColumn(field, result); err != nil {
128 scope.Err(err)
129 return
130 }
131 }
132 }
133 }
134
135 func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
136 relation := field.Relationship
137 primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
138 if len(primaryKeys) == 0 {
139 return
140 }
141
142 results := makeSlice(field.Struct.Type)
143 scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
144 resultValues := reflect.Indirect(reflect.ValueOf(results))
145
146 if scope.IndirectValue().Kind() == reflect.Slice {
147 for i := 0; i < resultValues.Len(); i++ {
148 result := resultValues.Index(i)
149 value := getRealValue(result, relation.ForeignFieldNames)
150 objects := scope.IndirectValue()
151 for j := 0; j < objects.Len(); j++ {
152 object := reflect.Indirect(objects.Index(j))
153 if equalAsString(getRealValue(object, relation.AssociationForeignFieldNames), value) {
154 f := object.FieldByName(field.Name)
155 f.Set(reflect.Append(f, result))
156 break
157 }
158 }
159 }
160 } else {
161 scope.SetColumn(field, resultValues)
162 }
163 }
164
165 func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
166 relation := field.Relationship
167 primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames)
168 if len(primaryKeys) == 0 {
169 return
170 }
171
172 results := makeSlice(field.Struct.Type)
173 scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
174 resultValues := reflect.Indirect(reflect.ValueOf(results))
175
176 for i := 0; i < resultValues.Len(); i++ {
177 result := resultValues.Index(i)
178 if scope.IndirectValue().Kind() == reflect.Slice {
179 value := getRealValue(result, relation.AssociationForeignFieldNames)
180 objects := scope.IndirectValue()
181 for j := 0; j < objects.Len(); j++ {
182 object := reflect.Indirect(objects.Index(j))
183 if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) {
184 object.FieldByName(field.Name).Set(result)
185 }
186 }
187 } else {
188 scope.SetColumn(field, result)
189 }
190 }
191 }
192
193 func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interface{}) {
194 relation := field.Relationship
195
196 joinTableHandler := relation.JoinTableHandler
197 destType := field.StructField.Struct.Type.Elem()
198 var isPtr bool
199 if destType.Kind() == reflect.Ptr {
200 isPtr = true
201 destType = destType.Elem()
202 }
203
204 var sourceKeys []string
205 var linkHash = make(map[string][]reflect.Value)
206
207 for _, key := range joinTableHandler.SourceForeignKeys() {
208 sourceKeys = append(sourceKeys, key.DBName)
209 }
210
211 db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*")
212 preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value)
213 if len(conditions) > 0 {
214 preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...)
215 }
216 rows, err := preloadJoinDB.Rows()
217
218 if scope.Err(err) != nil {
219 return
220 }
221 defer rows.Close()
222
223 columns, _ := rows.Columns()
224 for rows.Next() {
225 elem := reflect.New(destType).Elem()
226 var values = make([]interface{}, len(columns))
227
228 fields := scope.New(elem.Addr().Interface()).Fields()
229
230 for index, column := range columns {
231 if field, ok := fields[column]; ok {
232 if field.Field.Kind() == reflect.Ptr {
233 values[index] = field.Field.Addr().Interface()
234 } else {
235 values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
236 }
237 } else {
238 var i interface{}
239 values[index] = &i
240 }
241 }
242
243 scope.Err(rows.Scan(values...))
244
245 var sourceKey []interface{}
246
247 for index, column := range columns {
248 value := values[index]
249 if field, ok := fields[column]; ok {
250 if field.Field.Kind() == reflect.Ptr {
251 field.Field.Set(reflect.ValueOf(value).Elem())
252 } else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
253 field.Field.Set(v)
254 }
255 } else if strInSlice(column, sourceKeys) {
256 sourceKey = append(sourceKey, *(value.(*interface{})))
257 }
258 }
259
260 if len(sourceKey) != 0 {
261 if isPtr {
262 linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem.Addr())
263 } else {
264 linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem)
265 }
266 }
267 }
268
269 var associationForeignStructFieldNames []string
270 for _, dbName := range relation.AssociationForeignFieldNames {
271 if field, ok := scope.FieldByName(dbName); ok {
272 associationForeignStructFieldNames = append(associationForeignStructFieldNames, field.Name)
273 }
274 }
275
276 if scope.IndirectValue().Kind() == reflect.Slice {
277 objects := scope.IndirectValue()
278 for j := 0; j < objects.Len(); j++ {
279 object := reflect.Indirect(objects.Index(j))
280 source := getRealValue(object, associationForeignStructFieldNames)
281 field := object.FieldByName(field.Name)
282 for _, link := range linkHash[toString(source)] {
283 field.Set(reflect.Append(field, link))
284 }
285 }
286 } else {
287 object := scope.IndirectValue()
288 source := getRealValue(object, associationForeignStructFieldNames)
289 field := object.FieldByName(field.Name)
290 for _, link := range linkHash[toString(source)] {
291 field.Set(reflect.Append(field, link))
292 }
293 }
294 }
295
296 func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) {
297 values := scope.IndirectValue()
298 switch values.Kind() {
299 case reflect.Slice:
300 for i := 0; i < values.Len(); i++ {
301 var result []interface{}
302 for _, column := range columns {
303 result = append(result, reflect.Indirect(values.Index(i)).FieldByName(column).Interface())
304 }
305 results = append(results, result)
306 }
307 case reflect.Struct:
308 var result []interface{}
309 for _, column := range columns {
310 result = append(result, values.FieldByName(column).Interface())
311 }
312 return [][]interface{}{result}
313 }
314 return
315 }
316
317 func (scope *Scope) getColumnsAsScope(column string) *Scope {
318 values := scope.IndirectValue()
319 switch values.Kind() {
320 case reflect.Slice:
321 modelType := values.Type().Elem()
322 if modelType.Kind() == reflect.Ptr {
323 modelType = modelType.Elem()
324 }
325 fieldStruct, _ := modelType.FieldByName(column)
326 var columns reflect.Value
327 if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr {
328 columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
329 } else {
330 columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
331 }
332 for i := 0; i < values.Len(); i++ {
333 column := reflect.Indirect(values.Index(i)).FieldByName(column)
334 if column.Kind() == reflect.Ptr {
335 column = column.Elem()
336 }
337 if column.Kind() == reflect.Slice {
338 for i := 0; i < column.Len(); i++ {
339 elem := column.Index(i)
340 if elem.CanAddr() {
341 columns = reflect.Append(columns, elem.Addr())
342 }
343 }
344 } else {
345 if column.CanAddr() {
346 columns = reflect.Append(columns, column.Addr())
347 }
348 }
349 }
350 return scope.New(columns.Interface())
351 case reflect.Struct:
352 field := values.FieldByName(column)
353 if !field.CanAddr() {
354 return nil
355 }
356 return scope.New(field.Addr().Interface())
357 }
358 return nil
359 }
0 package gorm_test
1
2 import (
3 "encoding/json"
4 "os"
5 "reflect"
6 "testing"
7 )
8
9 func getPreloadUser(name string) *User {
10 return getPreparedUser(name, "Preload")
11 }
12
13 func checkUserHasPreloadData(user User, t *testing.T) {
14 u := getPreloadUser(user.Name)
15 if user.BillingAddress.Address1 != u.BillingAddress.Address1 {
16 t.Error("Failed to preload user's BillingAddress")
17 }
18
19 if user.ShippingAddress.Address1 != u.ShippingAddress.Address1 {
20 t.Error("Failed to preload user's ShippingAddress")
21 }
22
23 if user.CreditCard.Number != u.CreditCard.Number {
24 t.Error("Failed to preload user's CreditCard")
25 }
26
27 if user.Company.Name != u.Company.Name {
28 t.Error("Failed to preload user's Company")
29 }
30
31 if len(user.Emails) != len(u.Emails) {
32 t.Error("Failed to preload user's Emails")
33 } else {
34 var found int
35 for _, e1 := range u.Emails {
36 for _, e2 := range user.Emails {
37 if e1.Email == e2.Email {
38 found++
39 break
40 }
41 }
42 }
43 if found != len(u.Emails) {
44 t.Error("Failed to preload user's email details")
45 }
46 }
47 }
48
49 func TestPreload(t *testing.T) {
50 user1 := getPreloadUser("user1")
51 DB.Save(user1)
52
53 preloadDB := DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress").
54 Preload("CreditCard").Preload("Emails").Preload("Company")
55 var user User
56 preloadDB.Find(&user)
57 checkUserHasPreloadData(user, t)
58
59 user2 := getPreloadUser("user2")
60 DB.Save(user2)
61
62 user3 := getPreloadUser("user3")
63 DB.Save(user3)
64
65 var users []User
66 preloadDB.Find(&users)
67
68 for _, user := range users {
69 checkUserHasPreloadData(user, t)
70 }
71
72 var users2 []*User
73 preloadDB.Find(&users2)
74
75 for _, user := range users2 {
76 checkUserHasPreloadData(*user, t)
77 }
78
79 var users3 []*User
80 preloadDB.Preload("Emails", "email = ?", user3.Emails[0].Email).Find(&users3)
81
82 for _, user := range users3 {
83 if user.Name == user3.Name {
84 if len(user.Emails) != 1 {
85 t.Errorf("should only preload one emails for user3 when with condition")
86 }
87 } else if len(user.Emails) != 0 {
88 t.Errorf("should not preload any emails for other users when with condition")
89 }
90 }
91 }
92
93 func TestNestedPreload1(t *testing.T) {
94 type (
95 Level1 struct {
96 ID uint
97 Value string
98 Level2ID uint
99 }
100 Level2 struct {
101 ID uint
102 Level1 Level1
103 Level3ID uint
104 }
105 Level3 struct {
106 ID uint
107 Name string
108 Level2 Level2
109 }
110 )
111 DB.DropTableIfExists(&Level3{})
112 DB.DropTableIfExists(&Level2{})
113 DB.DropTableIfExists(&Level1{})
114 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
115 panic(err)
116 }
117
118 want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}}
119 if err := DB.Create(&want).Error; err != nil {
120 panic(err)
121 }
122
123 var got Level3
124 if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil {
125 panic(err)
126 }
127
128 if !reflect.DeepEqual(got, want) {
129 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
130 }
131 }
132
133 func TestNestedPreload2(t *testing.T) {
134 type (
135 Level1 struct {
136 ID uint
137 Value string
138 Level2ID uint
139 }
140 Level2 struct {
141 ID uint
142 Level1s []*Level1
143 Level3ID uint
144 }
145 Level3 struct {
146 ID uint
147 Name string
148 Level2s []Level2
149 }
150 )
151 DB.DropTableIfExists(&Level3{})
152 DB.DropTableIfExists(&Level2{})
153 DB.DropTableIfExists(&Level1{})
154 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
155 panic(err)
156 }
157
158 want := Level3{
159 Level2s: []Level2{
160 {
161 Level1s: []*Level1{
162 &Level1{Value: "value1"},
163 &Level1{Value: "value2"},
164 },
165 },
166 {
167 Level1s: []*Level1{
168 &Level1{Value: "value3"},
169 },
170 },
171 },
172 }
173 if err := DB.Create(&want).Error; err != nil {
174 panic(err)
175 }
176
177 var got Level3
178 if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil {
179 panic(err)
180 }
181
182 if !reflect.DeepEqual(got, want) {
183 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
184 }
185 }
186
187 func TestNestedPreload3(t *testing.T) {
188 type (
189 Level1 struct {
190 ID uint
191 Value string
192 Level2ID uint
193 }
194 Level2 struct {
195 ID uint
196 Level1 Level1
197 Level3ID uint
198 }
199 Level3 struct {
200 Name string
201 ID uint
202 Level2s []Level2
203 }
204 )
205 DB.DropTableIfExists(&Level3{})
206 DB.DropTableIfExists(&Level2{})
207 DB.DropTableIfExists(&Level1{})
208 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
209 panic(err)
210 }
211
212 want := Level3{
213 Level2s: []Level2{
214 {Level1: Level1{Value: "value1"}},
215 {Level1: Level1{Value: "value2"}},
216 },
217 }
218 if err := DB.Create(&want).Error; err != nil {
219 panic(err)
220 }
221
222 var got Level3
223 if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil {
224 panic(err)
225 }
226
227 if !reflect.DeepEqual(got, want) {
228 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
229 }
230 }
231
232 func TestNestedPreload4(t *testing.T) {
233 type (
234 Level1 struct {
235 ID uint
236 Value string
237 Level2ID uint
238 }
239 Level2 struct {
240 ID uint
241 Level1s []Level1
242 Level3ID uint
243 }
244 Level3 struct {
245 ID uint
246 Name string
247 Level2 Level2
248 }
249 )
250 DB.DropTableIfExists(&Level3{})
251 DB.DropTableIfExists(&Level2{})
252 DB.DropTableIfExists(&Level1{})
253 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
254 panic(err)
255 }
256
257 want := Level3{
258 Level2: Level2{
259 Level1s: []Level1{
260 Level1{Value: "value1"},
261 Level1{Value: "value2"},
262 },
263 },
264 }
265 if err := DB.Create(&want).Error; err != nil {
266 panic(err)
267 }
268
269 var got Level3
270 if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil {
271 panic(err)
272 }
273
274 if !reflect.DeepEqual(got, want) {
275 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
276 }
277 }
278
279 // Slice: []Level3
280 func TestNestedPreload5(t *testing.T) {
281 type (
282 Level1 struct {
283 ID uint
284 Value string
285 Level2ID uint
286 }
287 Level2 struct {
288 ID uint
289 Level1 Level1
290 Level3ID uint
291 }
292 Level3 struct {
293 ID uint
294 Name string
295 Level2 Level2
296 }
297 )
298 DB.DropTableIfExists(&Level3{})
299 DB.DropTableIfExists(&Level2{})
300 DB.DropTableIfExists(&Level1{})
301 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
302 panic(err)
303 }
304
305 want := make([]Level3, 2)
306 want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}}
307 if err := DB.Create(&want[0]).Error; err != nil {
308 panic(err)
309 }
310 want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}}
311 if err := DB.Create(&want[1]).Error; err != nil {
312 panic(err)
313 }
314
315 var got []Level3
316 if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil {
317 panic(err)
318 }
319
320 if !reflect.DeepEqual(got, want) {
321 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
322 }
323 }
324
325 func TestNestedPreload6(t *testing.T) {
326 type (
327 Level1 struct {
328 ID uint
329 Value string
330 Level2ID uint
331 }
332 Level2 struct {
333 ID uint
334 Level1s []Level1
335 Level3ID uint
336 }
337 Level3 struct {
338 ID uint
339 Name string
340 Level2s []Level2
341 }
342 )
343 DB.DropTableIfExists(&Level3{})
344 DB.DropTableIfExists(&Level2{})
345 DB.DropTableIfExists(&Level1{})
346 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
347 panic(err)
348 }
349
350 want := make([]Level3, 2)
351 want[0] = Level3{
352 Level2s: []Level2{
353 {
354 Level1s: []Level1{
355 {Value: "value1"},
356 {Value: "value2"},
357 },
358 },
359 {
360 Level1s: []Level1{
361 {Value: "value3"},
362 },
363 },
364 },
365 }
366 if err := DB.Create(&want[0]).Error; err != nil {
367 panic(err)
368 }
369
370 want[1] = Level3{
371 Level2s: []Level2{
372 {
373 Level1s: []Level1{
374 {Value: "value3"},
375 {Value: "value4"},
376 },
377 },
378 {
379 Level1s: []Level1{
380 {Value: "value5"},
381 },
382 },
383 },
384 }
385 if err := DB.Create(&want[1]).Error; err != nil {
386 panic(err)
387 }
388
389 var got []Level3
390 if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil {
391 panic(err)
392 }
393
394 if !reflect.DeepEqual(got, want) {
395 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
396 }
397 }
398
399 func TestNestedPreload7(t *testing.T) {
400 type (
401 Level1 struct {
402 ID uint
403 Value string
404 Level2ID uint
405 }
406 Level2 struct {
407 ID uint
408 Level1 Level1
409 Level3ID uint
410 }
411 Level3 struct {
412 ID uint
413 Name string
414 Level2s []Level2
415 }
416 )
417 DB.DropTableIfExists(&Level3{})
418 DB.DropTableIfExists(&Level2{})
419 DB.DropTableIfExists(&Level1{})
420 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
421 panic(err)
422 }
423
424 want := make([]Level3, 2)
425 want[0] = Level3{
426 Level2s: []Level2{
427 {Level1: Level1{Value: "value1"}},
428 {Level1: Level1{Value: "value2"}},
429 },
430 }
431 if err := DB.Create(&want[0]).Error; err != nil {
432 panic(err)
433 }
434
435 want[1] = Level3{
436 Level2s: []Level2{
437 {Level1: Level1{Value: "value3"}},
438 {Level1: Level1{Value: "value4"}},
439 },
440 }
441 if err := DB.Create(&want[1]).Error; err != nil {
442 panic(err)
443 }
444
445 var got []Level3
446 if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil {
447 panic(err)
448 }
449
450 if !reflect.DeepEqual(got, want) {
451 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
452 }
453 }
454
455 func TestNestedPreload8(t *testing.T) {
456 type (
457 Level1 struct {
458 ID uint
459 Value string
460 Level2ID uint
461 }
462 Level2 struct {
463 ID uint
464 Level1s []Level1
465 Level3ID uint
466 }
467 Level3 struct {
468 ID uint
469 Name string
470 Level2 Level2
471 }
472 )
473 DB.DropTableIfExists(&Level3{})
474 DB.DropTableIfExists(&Level2{})
475 DB.DropTableIfExists(&Level1{})
476 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
477 panic(err)
478 }
479
480 want := make([]Level3, 2)
481 want[0] = Level3{
482 Level2: Level2{
483 Level1s: []Level1{
484 Level1{Value: "value1"},
485 Level1{Value: "value2"},
486 },
487 },
488 }
489 if err := DB.Create(&want[0]).Error; err != nil {
490 panic(err)
491 }
492 want[1] = Level3{
493 Level2: Level2{
494 Level1s: []Level1{
495 Level1{Value: "value3"},
496 Level1{Value: "value4"},
497 },
498 },
499 }
500 if err := DB.Create(&want[1]).Error; err != nil {
501 panic(err)
502 }
503
504 var got []Level3
505 if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil {
506 panic(err)
507 }
508
509 if !reflect.DeepEqual(got, want) {
510 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
511 }
512 }
513
514 func TestNestedPreload9(t *testing.T) {
515 type (
516 Level0 struct {
517 ID uint
518 Value string
519 Level1ID uint
520 }
521 Level1 struct {
522 ID uint
523 Value string
524 Level2ID uint
525 Level2_1ID uint
526 Level0s []Level0
527 }
528 Level2 struct {
529 ID uint
530 Level1s []Level1
531 Level3ID uint
532 }
533 Level2_1 struct {
534 ID uint
535 Level1s []Level1
536 Level3ID uint
537 }
538 Level3 struct {
539 ID uint
540 Name string
541 Level2 Level2
542 Level2_1 Level2_1
543 }
544 )
545 DB.DropTableIfExists(&Level3{})
546 DB.DropTableIfExists(&Level2{})
547 DB.DropTableIfExists(&Level2_1{})
548 DB.DropTableIfExists(&Level1{})
549 DB.DropTableIfExists(&Level0{})
550 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil {
551 panic(err)
552 }
553
554 want := make([]Level3, 2)
555 want[0] = Level3{
556 Level2: Level2{
557 Level1s: []Level1{
558 Level1{Value: "value1"},
559 Level1{Value: "value2"},
560 },
561 },
562 Level2_1: Level2_1{
563 Level1s: []Level1{
564 Level1{
565 Value: "value1-1",
566 Level0s: []Level0{{Value: "Level0-1"}},
567 },
568 Level1{
569 Value: "value2-2",
570 Level0s: []Level0{{Value: "Level0-2"}},
571 },
572 },
573 },
574 }
575 if err := DB.Create(&want[0]).Error; err != nil {
576 panic(err)
577 }
578 want[1] = Level3{
579 Level2: Level2{
580 Level1s: []Level1{
581 Level1{Value: "value3"},
582 Level1{Value: "value4"},
583 },
584 },
585 Level2_1: Level2_1{
586 Level1s: []Level1{
587 Level1{Value: "value3-3"},
588 Level1{Value: "value4-4"},
589 },
590 },
591 }
592 if err := DB.Create(&want[1]).Error; err != nil {
593 panic(err)
594 }
595
596 var got []Level3
597 if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil {
598 panic(err)
599 }
600
601 if !reflect.DeepEqual(got, want) {
602 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
603 }
604 }
605
606 func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
607 if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" {
608 return
609 }
610
611 type (
612 Level1 struct {
613 ID uint `gorm:"primary_key;"`
614 LanguageCode string `gorm:"primary_key"`
615 Value string
616 }
617 Level2 struct {
618 ID uint `gorm:"primary_key;"`
619 LanguageCode string `gorm:"primary_key"`
620 Value string
621 Level1s []Level1 `gorm:"many2many:levels;"`
622 }
623 )
624
625 DB.DropTableIfExists(&Level2{})
626 DB.DropTableIfExists(&Level1{})
627 DB.Table("levels").DropTableIfExists("levels")
628
629 if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
630 panic(err)
631 }
632
633 want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{
634 {Value: "ru", LanguageCode: "ru"},
635 {Value: "en", LanguageCode: "en"},
636 }}
637 if err := DB.Save(&want).Error; err != nil {
638 panic(err)
639 }
640
641 want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{
642 {Value: "zh", LanguageCode: "zh"},
643 {Value: "de", LanguageCode: "de"},
644 }}
645 if err := DB.Save(&want2).Error; err != nil {
646 panic(err)
647 }
648
649 var got Level2
650 if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
651 panic(err)
652 }
653
654 if !reflect.DeepEqual(got, want) {
655 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
656 }
657
658 var got2 Level2
659 if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
660 panic(err)
661 }
662
663 if !reflect.DeepEqual(got2, want2) {
664 t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2))
665 }
666
667 var got3 []Level2
668 if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
669 panic(err)
670 }
671
672 if !reflect.DeepEqual(got3, []Level2{got, got2}) {
673 t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2}))
674 }
675
676 var got4 []Level2
677 if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
678 panic(err)
679 }
680
681 var ruLevel1 Level1
682 var zhLevel1 Level1
683 DB.First(&ruLevel1, "value = ?", "ru")
684 DB.First(&zhLevel1, "value = ?", "zh")
685
686 got.Level1s = []Level1{ruLevel1}
687 got2.Level1s = []Level1{zhLevel1}
688 if !reflect.DeepEqual(got4, []Level2{got, got2}) {
689 t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2}))
690 }
691 }
692
693 func TestManyToManyPreloadForPointer(t *testing.T) {
694 type (
695 Level1 struct {
696 ID uint `gorm:"primary_key;"`
697 Value string
698 }
699 Level2 struct {
700 ID uint `gorm:"primary_key;"`
701 Value string
702 Level1s []*Level1 `gorm:"many2many:levels;"`
703 }
704 )
705
706 DB.DropTableIfExists(&Level2{})
707 DB.DropTableIfExists(&Level1{})
708 DB.Table("levels").DropTableIfExists("levels")
709
710 if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
711 panic(err)
712 }
713
714 want := Level2{Value: "Bob", Level1s: []*Level1{
715 {Value: "ru"},
716 {Value: "en"},
717 }}
718 if err := DB.Save(&want).Error; err != nil {
719 panic(err)
720 }
721
722 want2 := Level2{Value: "Tom", Level1s: []*Level1{
723 {Value: "zh"},
724 {Value: "de"},
725 }}
726 if err := DB.Save(&want2).Error; err != nil {
727 panic(err)
728 }
729
730 var got Level2
731 if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
732 panic(err)
733 }
734
735 if !reflect.DeepEqual(got, want) {
736 t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
737 }
738
739 var got2 Level2
740 if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
741 panic(err)
742 }
743
744 if !reflect.DeepEqual(got2, want2) {
745 t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2))
746 }
747
748 var got3 []Level2
749 if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
750 panic(err)
751 }
752
753 if !reflect.DeepEqual(got3, []Level2{got, got2}) {
754 t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2}))
755 }
756
757 var got4 []Level2
758 if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
759 panic(err)
760 }
761
762 var ruLevel1 Level1
763 var zhLevel1 Level1
764 DB.First(&ruLevel1, "value = ?", "ru")
765 DB.First(&zhLevel1, "value = ?", "zh")
766
767 got.Level1s = []*Level1{&ruLevel1}
768 got2.Level1s = []*Level1{&zhLevel1}
769 if !reflect.DeepEqual(got4, []Level2{got, got2}) {
770 t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2}))
771 }
772 }
773
774 func TestNilPointerSlice(t *testing.T) {
775 type (
776 Level3 struct {
777 ID uint `gorm:"primary_key;"`
778 Value string
779 }
780 Level2 struct {
781 ID uint `gorm:"primary_key;"`
782 Value string
783 Level3ID uint
784 Level3 *Level3
785 }
786 Level1 struct {
787 ID uint `gorm:"primary_key;"`
788 Value string
789 Level2ID uint
790 Level2 *Level2
791 }
792 )
793
794 DB.DropTableIfExists(&Level3{})
795 DB.DropTableIfExists(&Level2{})
796 DB.DropTableIfExists(&Level1{})
797
798 if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
799 panic(err)
800 }
801
802 want := Level1{Value: "Bob", Level2: &Level2{
803 Value: "en",
804 Level3: &Level3{
805 Value: "native",
806 },
807 }}
808 if err := DB.Save(&want).Error; err != nil {
809 panic(err)
810 }
811
812 want2 := Level1{Value: "Tom", Level2: nil}
813 if err := DB.Save(&want2).Error; err != nil {
814 panic(err)
815 }
816
817 var got []Level1
818 if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).Error; err != nil {
819 panic(err)
820 }
821
822 if len(got) != 2 {
823 t.Fatalf("got %v items, expected 2", len(got))
824 }
825
826 if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) {
827 t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want))
828 }
829
830 if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) {
831 t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want2))
832 }
833 }
834
835 func toJSONString(v interface{}) []byte {
836 r, _ := json.MarshalIndent(v, "", " ")
837 return r
838 }
0 package gorm_test
1
2 import (
3 "fmt"
4 "reflect"
5
6 "github.com/jinzhu/gorm"
7 "github.com/jinzhu/now"
8
9 "testing"
10 "time"
11 )
12
13 func TestFirstAndLast(t *testing.T) {
14 DB.Save(&User{Name: "user1", Emails: []Email{{Email: "user1@example.com"}}})
15 DB.Save(&User{Name: "user2", Emails: []Email{{Email: "user2@example.com"}}})
16
17 var user1, user2, user3, user4 User
18 DB.First(&user1)
19 DB.Order("id").Limit(1).Find(&user2)
20
21 DB.Last(&user3)
22 DB.Order("id desc").Limit(1).Find(&user4)
23 if user1.Id != user2.Id || user3.Id != user4.Id {
24 t.Errorf("First and Last should by order by primary key")
25 }
26
27 var users []User
28 DB.First(&users)
29 if len(users) != 1 {
30 t.Errorf("Find first record as slice")
31 }
32
33 if DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).Error != nil {
34 t.Errorf("Should not raise any error when order with Join table")
35 }
36 }
37
38 func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) {
39 DB.Save(&Animal{Name: "animal1"})
40 DB.Save(&Animal{Name: "animal2"})
41
42 var animal1, animal2, animal3, animal4 Animal
43 DB.First(&animal1)
44 DB.Order("counter").Limit(1).Find(&animal2)
45
46 DB.Last(&animal3)
47 DB.Order("counter desc").Limit(1).Find(&animal4)
48 if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter {
49 t.Errorf("First and Last should work correctly")
50 }
51 }
52
53 func TestUIntPrimaryKey(t *testing.T) {
54 var animal Animal
55 DB.First(&animal, uint64(1))
56 if animal.Counter != 1 {
57 t.Errorf("Fetch a record from with a non-int primary key should work, but failed")
58 }
59
60 DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal)
61 if animal.Counter != 2 {
62 t.Errorf("Fetch a record from with a non-int primary key should work, but failed")
63 }
64 }
65
66 func TestFindAsSliceOfPointers(t *testing.T) {
67 DB.Save(&User{Name: "user"})
68
69 var users []User
70 DB.Find(&users)
71
72 var userPointers []*User
73 DB.Find(&userPointers)
74
75 if len(users) == 0 || len(users) != len(userPointers) {
76 t.Errorf("Find slice of pointers")
77 }
78 }
79
80 func TestSearchWithPlainSQL(t *testing.T) {
81 user1 := User{Name: "PlainSqlUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
82 user2 := User{Name: "PlainSqlUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
83 user3 := User{Name: "PlainSqlUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
84 DB.Save(&user1).Save(&user2).Save(&user3)
85 scopedb := DB.Where("name LIKE ?", "%PlainSqlUser%")
86
87 if DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
88 t.Errorf("Search with plain SQL")
89 }
90
91 if DB.Where("name LIKE ?", "%"+user1.Name+"%").First(&User{}).RecordNotFound() {
92 t.Errorf("Search with plan SQL (regexp)")
93 }
94
95 var users []User
96 DB.Find(&users, "name LIKE ? and age > ?", "%PlainSqlUser%", 1)
97 if len(users) != 2 {
98 t.Errorf("Should found 2 users that age > 1, but got %v", len(users))
99 }
100
101 DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users)
102 if len(users) != 3 {
103 t.Errorf("Should found 3 users that age >= 1, but got %v", len(users))
104 }
105
106 scopedb.Where("age <> ?", 20).Find(&users)
107 if len(users) != 2 {
108 t.Errorf("Should found 2 users age != 20, but got %v", len(users))
109 }
110
111 scopedb.Where("birthday > ?", now.MustParse("2000-1-1")).Find(&users)
112 if len(users) != 2 {
113 t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users))
114 }
115
116 scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
117 if len(users) != 2 {
118 t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users))
119 }
120
121 scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
122 if len(users) != 1 {
123 t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
124 }
125
126 DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
127 if len(users) != 2 {
128 t.Errorf("Should found 2 users, but got %v", len(users))
129 }
130
131 DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users)
132 if len(users) != 3 {
133 t.Errorf("Should found 3 users, but got %v", len(users))
134 }
135
136 DB.Where("id in (?)", user1.Id).Find(&users)
137 if len(users) != 1 {
138 t.Errorf("Should found 1 users, but got %v", len(users))
139 }
140
141 if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() {
142 t.Errorf("Should not get RecordNotFound error when looking for none existing records")
143 }
144 }
145
146 func TestSearchWithStruct(t *testing.T) {
147 user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
148 user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
149 user3 := User{Name: "StructSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
150 DB.Save(&user1).Save(&user2).Save(&user3)
151
152 if DB.Where(user1.Id).First(&User{}).RecordNotFound() {
153 t.Errorf("Search with primary key")
154 }
155
156 if DB.First(&User{}, user1.Id).RecordNotFound() {
157 t.Errorf("Search with primary key as inline condition")
158 }
159
160 if DB.First(&User{}, fmt.Sprintf("%v", user1.Id)).RecordNotFound() {
161 t.Errorf("Search with primary key as inline condition")
162 }
163
164 var users []User
165 DB.Where([]int64{user1.Id, user2.Id, user3.Id}).Find(&users)
166 if len(users) != 3 {
167 t.Errorf("Should found 3 users when search with primary keys, but got %v", len(users))
168 }
169
170 var user User
171 DB.First(&user, &User{Name: user1.Name})
172 if user.Id == 0 || user.Name != user1.Name {
173 t.Errorf("Search first record with inline pointer of struct")
174 }
175
176 DB.First(&user, User{Name: user1.Name})
177 if user.Id == 0 || user.Name != user.Name {
178 t.Errorf("Search first record with inline struct")
179 }
180
181 DB.Where(&User{Name: user1.Name}).First(&user)
182 if user.Id == 0 || user.Name != user1.Name {
183 t.Errorf("Search first record with where struct")
184 }
185
186 DB.Find(&users, &User{Name: user2.Name})
187 if len(users) != 1 {
188 t.Errorf("Search all records with inline struct")
189 }
190 }
191
192 func TestSearchWithMap(t *testing.T) {
193 user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
194 user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
195 user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
196 DB.Save(&user1).Save(&user2).Save(&user3)
197
198 var user User
199 DB.First(&user, map[string]interface{}{"name": user1.Name})
200 if user.Id == 0 || user.Name != user1.Name {
201 t.Errorf("Search first record with inline map")
202 }
203
204 user = User{}
205 DB.Where(map[string]interface{}{"name": user2.Name}).First(&user)
206 if user.Id == 0 || user.Name != user2.Name {
207 t.Errorf("Search first record with where map")
208 }
209
210 var users []User
211 DB.Where(map[string]interface{}{"name": user3.Name}).Find(&users)
212 if len(users) != 1 {
213 t.Errorf("Search all records with inline map")
214 }
215
216 DB.Find(&users, map[string]interface{}{"name": user3.Name})
217 if len(users) != 1 {
218 t.Errorf("Search all records with inline map")
219 }
220 }
221
222 func TestSearchWithEmptyChain(t *testing.T) {
223 user1 := User{Name: "ChainSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
224 user2 := User{Name: "ChainearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
225 user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
226 DB.Save(&user1).Save(&user2).Save(&user3)
227
228 if DB.Where("").Where("").First(&User{}).Error != nil {
229 t.Errorf("Should not raise any error if searching with empty strings")
230 }
231
232 if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).Error != nil {
233 t.Errorf("Should not raise any error if searching with empty struct")
234 }
235
236 if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).Error != nil {
237 t.Errorf("Should not raise any error if searching with empty map")
238 }
239 }
240
241 func TestSelect(t *testing.T) {
242 user1 := User{Name: "SelectUser1"}
243 DB.Save(&user1)
244
245 var user User
246 DB.Where("name = ?", user1.Name).Select("name").Find(&user)
247 if user.Id != 0 {
248 t.Errorf("Should not have ID because only selected name, %+v", user.Id)
249 }
250
251 if user.Name != user1.Name {
252 t.Errorf("Should have user Name when selected it")
253 }
254 }
255
256 func TestOrderAndPluck(t *testing.T) {
257 user1 := User{Name: "OrderPluckUser1", Age: 1}
258 user2 := User{Name: "OrderPluckUser2", Age: 10}
259 user3 := User{Name: "OrderPluckUser3", Age: 20}
260 DB.Save(&user1).Save(&user2).Save(&user3)
261 scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%")
262
263 var ages []int64
264 scopedb.Order("age desc").Pluck("age", &ages)
265 if ages[0] != 20 {
266 t.Errorf("The first age should be 20 when order with age desc")
267 }
268
269 var ages1, ages2 []int64
270 scopedb.Order("age desc").Pluck("age", &ages1).Pluck("age", &ages2)
271 if !reflect.DeepEqual(ages1, ages2) {
272 t.Errorf("The first order is the primary order")
273 }
274
275 var ages3, ages4 []int64
276 scopedb.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4)
277 if reflect.DeepEqual(ages3, ages4) {
278 t.Errorf("Reorder should work")
279 }
280
281 var names []string
282 var ages5 []int64
283 scopedb.Model(User{}).Order("name").Order("age desc").Pluck("age", &ages5).Pluck("name", &names)
284 if names != nil && ages5 != nil {
285 if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) {
286 t.Errorf("Order with multiple orders")
287 }
288 } else {
289 t.Errorf("Order with multiple orders")
290 }
291
292 DB.Model(User{}).Select("name, age").Find(&[]User{})
293 }
294
295 func TestLimit(t *testing.T) {
296 user1 := User{Name: "LimitUser1", Age: 1}
297 user2 := User{Name: "LimitUser2", Age: 10}
298 user3 := User{Name: "LimitUser3", Age: 20}
299 user4 := User{Name: "LimitUser4", Age: 10}
300 user5 := User{Name: "LimitUser5", Age: 20}
301 DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5)
302
303 var users1, users2, users3 []User
304 DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3)
305
306 if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 {
307 t.Errorf("Limit should works")
308 }
309 }
310
311 func TestOffset(t *testing.T) {
312 for i := 0; i < 20; i++ {
313 DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)})
314 }
315 var users1, users2, users3, users4 []User
316 DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
317
318 if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
319 t.Errorf("Offset should work")
320 }
321 }
322
323 func TestOr(t *testing.T) {
324 user1 := User{Name: "OrUser1", Age: 1}
325 user2 := User{Name: "OrUser2", Age: 10}
326 user3 := User{Name: "OrUser3", Age: 20}
327 DB.Save(&user1).Save(&user2).Save(&user3)
328
329 var users []User
330 DB.Where("name = ?", user1.Name).Or("name = ?", user2.Name).Find(&users)
331 if len(users) != 2 {
332 t.Errorf("Find users with or")
333 }
334 }
335
336 func TestCount(t *testing.T) {
337 user1 := User{Name: "CountUser1", Age: 1}
338 user2 := User{Name: "CountUser2", Age: 10}
339 user3 := User{Name: "CountUser3", Age: 20}
340
341 DB.Save(&user1).Save(&user2).Save(&user3)
342 var count, count1, count2 int64
343 var users []User
344
345 if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil {
346 t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
347 }
348
349 if count != int64(len(users)) {
350 t.Errorf("Count() method should get correct value")
351 }
352
353 DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in (?)", []string{user2.Name, user3.Name}).Count(&count2)
354 if count1 != 1 || count2 != 3 {
355 t.Errorf("Multiple count in chain")
356 }
357 }
358
359 func TestNot(t *testing.T) {
360 DB.Create(getPreparedUser("user1", "not"))
361 DB.Create(getPreparedUser("user2", "not"))
362 DB.Create(getPreparedUser("user3", "not"))
363 DB.Create(getPreparedUser("user4", "not"))
364 DB := DB.Where("role = ?", "not")
365
366 var users1, users2, users3, users4, users5, users6, users7, users8 []User
367 if DB.Find(&users1).RowsAffected != 4 {
368 t.Errorf("should find 4 not users")
369 }
370 DB.Not(users1[0].Id).Find(&users2)
371
372 if len(users1)-len(users2) != 1 {
373 t.Errorf("Should ignore the first users with Not")
374 }
375
376 DB.Not([]int{}).Find(&users3)
377 if len(users1)-len(users3) != 0 {
378 t.Errorf("Should find all users with a blank condition")
379 }
380
381 var name3Count int64
382 DB.Table("users").Where("name = ?", "user3").Count(&name3Count)
383 DB.Not("name", "user3").Find(&users4)
384 if len(users1)-len(users4) != int(name3Count) {
385 t.Errorf("Should find all users's name not equal 3")
386 }
387
388 DB.Not("name = ?", "user3").Find(&users4)
389 if len(users1)-len(users4) != int(name3Count) {
390 t.Errorf("Should find all users's name not equal 3")
391 }
392
393 DB.Not("name <> ?", "user3").Find(&users4)
394 if len(users4) != int(name3Count) {
395 t.Errorf("Should find all users's name not equal 3")
396 }
397
398 DB.Not(User{Name: "user3"}).Find(&users5)
399
400 if len(users1)-len(users5) != int(name3Count) {
401 t.Errorf("Should find all users's name not equal 3")
402 }
403
404 DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6)
405 if len(users1)-len(users6) != int(name3Count) {
406 t.Errorf("Should find all users's name not equal 3")
407 }
408
409 DB.Not("name", []string{"user3"}).Find(&users7)
410 if len(users1)-len(users7) != int(name3Count) {
411 t.Errorf("Should find all users's name not equal 3")
412 }
413
414 var name2Count int64
415 DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
416 DB.Not("name", []string{"user3", "user2"}).Find(&users8)
417 if len(users1)-len(users8) != (int(name3Count) + int(name2Count)) {
418 t.Errorf("Should find all users's name not equal 3")
419 }
420 }
421
422 func TestFillSmallerStruct(t *testing.T) {
423 user1 := User{Name: "SmallerUser", Age: 100}
424 DB.Save(&user1)
425 type SimpleUser struct {
426 Name string
427 Id int64
428 UpdatedAt time.Time
429 CreatedAt time.Time
430 }
431
432 var simpleUser SimpleUser
433 DB.Table("users").Where("name = ?", user1.Name).First(&simpleUser)
434
435 if simpleUser.Id == 0 || simpleUser.Name == "" {
436 t.Errorf("Should fill data correctly into smaller struct")
437 }
438 }
439
440 func TestFindOrInitialize(t *testing.T) {
441 var user1, user2, user3, user4, user5, user6 User
442 DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1)
443 if user1.Name != "find or init" || user1.Id != 0 || user1.Age != 33 {
444 t.Errorf("user should be initialized with search value")
445 }
446
447 DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2)
448 if user2.Name != "find or init" || user2.Id != 0 || user2.Age != 33 {
449 t.Errorf("user should be initialized with search value")
450 }
451
452 DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"})
453 if user3.Name != "find or init 2" || user3.Id != 0 {
454 t.Errorf("user should be initialized with inline search value")
455 }
456
457 DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4)
458 if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 {
459 t.Errorf("user should be initialized with search value and attrs")
460 }
461
462 DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4)
463 if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 {
464 t.Errorf("user should be initialized with search value and assign attrs")
465 }
466
467 DB.Save(&User{Name: "find or init", Age: 33})
468 DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5)
469 if user5.Name != "find or init" || user5.Id == 0 || user5.Age != 33 {
470 t.Errorf("user should be found and not initialized by Attrs")
471 }
472
473 DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6)
474 if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 33 {
475 t.Errorf("user should be found with FirstOrInit")
476 }
477
478 DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6)
479 if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 44 {
480 t.Errorf("user should be found and updated with assigned attrs")
481 }
482 }
483
484 func TestFindOrCreate(t *testing.T) {
485 var user1, user2, user3, user4, user5, user6, user7, user8 User
486 DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1)
487 if user1.Name != "find or create" || user1.Id == 0 || user1.Age != 33 {
488 t.Errorf("user should be created with search value")
489 }
490
491 DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2)
492 if user1.Id != user2.Id || user2.Name != "find or create" || user2.Id == 0 || user2.Age != 33 {
493 t.Errorf("user should be created with search value")
494 }
495
496 DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"})
497 if user3.Name != "find or create 2" || user3.Id == 0 {
498 t.Errorf("user should be created with inline search value")
499 }
500
501 DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4)
502 if user4.Name != "find or create 3" || user4.Id == 0 || user4.Age != 44 {
503 t.Errorf("user should be created with search value and attrs")
504 }
505
506 updatedAt1 := user4.UpdatedAt
507 DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4)
508 if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) {
509 t.Errorf("UpdateAt should be changed when update values with assign")
510 }
511
512 DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4)
513 if user4.Name != "find or create 4" || user4.Id == 0 || user4.Age != 44 {
514 t.Errorf("user should be created with search value and assigned attrs")
515 }
516
517 DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5)
518 if user5.Name != "find or create" || user5.Id == 0 || user5.Age != 33 {
519 t.Errorf("user should be found and not initialized by Attrs")
520 }
521
522 DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6)
523 if user6.Name != "find or create" || user6.Id == 0 || user6.Age != 44 {
524 t.Errorf("user should be found and updated with assigned attrs")
525 }
526
527 DB.Where(&User{Name: "find or create"}).Find(&user7)
528 if user7.Name != "find or create" || user7.Id == 0 || user7.Age != 44 {
529 t.Errorf("user should be found and updated with assigned attrs")
530 }
531
532 DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, CreditCard: CreditCard{Number: "1231231231"}, Emails: []Email{{Email: "jinzhu@assign_embedded_struct.com"}, {Email: "jinzhu-2@assign_embedded_struct.com"}}}).FirstOrCreate(&user8)
533 if DB.Where("email = ?", "jinzhu-2@assign_embedded_struct.com").First(&Email{}).RecordNotFound() {
534 t.Errorf("embedded struct email should be saved")
535 }
536
537 if DB.Where("email = ?", "1231231231").First(&CreditCard{}).RecordNotFound() {
538 t.Errorf("embedded struct credit card should be saved")
539 }
540 }
541
542 func TestSelectWithEscapedFieldName(t *testing.T) {
543 user1 := User{Name: "EscapedFieldNameUser", Age: 1}
544 user2 := User{Name: "EscapedFieldNameUser", Age: 10}
545 user3 := User{Name: "EscapedFieldNameUser", Age: 20}
546 DB.Save(&user1).Save(&user2).Save(&user3)
547
548 var names []string
549 DB.Model(User{}).Where(&User{Name: "EscapedFieldNameUser"}).Pluck("\"name\"", &names)
550
551 if len(names) != 3 {
552 t.Errorf("Expected 3 name, but got: %d", len(names))
553 }
554 }
555
556 func TestSelectWithVariables(t *testing.T) {
557 DB.Save(&User{Name: "jinzhu"})
558
559 rows, _ := DB.Table("users").Select("? as fake", gorm.Expr("name")).Rows()
560
561 if !rows.Next() {
562 t.Errorf("Should have returned at least one row")
563 } else {
564 columns, _ := rows.Columns()
565 if !reflect.DeepEqual(columns, []string{"fake"}) {
566 t.Errorf("Should only contains one column")
567 }
568 }
569 }
570
571 func TestSelectWithArrayInput(t *testing.T) {
572 DB.Save(&User{Name: "jinzhu", Age: 42})
573
574 var user User
575 DB.Select([]string{"name", "age"}).Where("age = 42 AND name = 'jinzhu'").First(&user)
576
577 if user.Name != "jinzhu" || user.Age != 42 {
578 t.Errorf("Should have selected both age and name")
579 }
580 }
581
582 func TestCurrentDatabase(t *testing.T) {
583 databaseName := DB.CurrentDatabase()
584 if err := DB.Error; err != nil {
585 t.Errorf("Problem getting current db name: %s", err)
586 }
587 if databaseName == "" {
588 t.Errorf("Current db name returned empty; this should never happen!")
589 }
590 t.Logf("Got current db name: %v", databaseName)
591 }
0 package gorm
1
2 import (
3 "errors"
4 "fmt"
5 "regexp"
6 "strings"
7 "time"
8
9 "reflect"
10 )
11
12 type Scope struct {
13 Search *search
14 Value interface{}
15 Sql string
16 SqlVars []interface{}
17 db *DB
18 indirectValue *reflect.Value
19 instanceId string
20 primaryKeyField *Field
21 skipLeft bool
22 fields map[string]*Field
23 selectAttrs *[]string
24 }
25
26 func (scope *Scope) IndirectValue() reflect.Value {
27 if scope.indirectValue == nil {
28 value := reflect.Indirect(reflect.ValueOf(scope.Value))
29 if value.Kind() == reflect.Ptr {
30 value = value.Elem()
31 }
32 scope.indirectValue = &value
33 }
34 return *scope.indirectValue
35 }
36
37 func (scope *Scope) NeedPtr() *Scope {
38 reflectKind := reflect.ValueOf(scope.Value).Kind()
39 if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
40 err := fmt.Errorf("%v %v\n", fileWithLineNum(), "using unaddressable value")
41 scope.Err(err)
42 fmt.Printf(err.Error())
43 }
44 return scope
45 }
46
47 // New create a new Scope without search information
48 func (scope *Scope) New(value interface{}) *Scope {
49 return &Scope{db: scope.NewDB(), Search: &search{}, Value: value}
50 }
51
52 // NewDB create a new DB without search information
53 func (scope *Scope) NewDB() *DB {
54 if scope.db != nil {
55 db := scope.db.clone()
56 db.search = nil
57 db.Value = nil
58 return db
59 }
60 return nil
61 }
62
63 func (scope *Scope) DB() *DB {
64 return scope.db
65 }
66
67 // SqlDB return *sql.DB
68 func (scope *Scope) SqlDB() sqlCommon {
69 return scope.db.db
70 }
71
72 // SkipLeft skip remaining callbacks
73 func (scope *Scope) SkipLeft() {
74 scope.skipLeft = true
75 }
76
77 // Quote used to quote database column name according to database dialect
78 func (scope *Scope) Quote(str string) string {
79 if strings.Index(str, ".") != -1 {
80 newStrs := []string{}
81 for _, str := range strings.Split(str, ".") {
82 newStrs = append(newStrs, scope.Dialect().Quote(str))
83 }
84 return strings.Join(newStrs, ".")
85 } else {
86 return scope.Dialect().Quote(str)
87 }
88 }
89
90 func (scope *Scope) QuoteIfPossible(str string) string {
91 if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) {
92 return scope.Quote(str)
93 }
94 return str
95 }
96
97 // Dialect get dialect
98 func (scope *Scope) Dialect() Dialect {
99 return scope.db.parent.dialect
100 }
101
102 // Err write error
103 func (scope *Scope) Err(err error) error {
104 if err != nil {
105 scope.db.AddError(err)
106 }
107 return err
108 }
109
110 // Log print log message
111 func (scope *Scope) Log(v ...interface{}) {
112 scope.db.log(v...)
113 }
114
115 // HasError check if there are any error
116 func (scope *Scope) HasError() bool {
117 return scope.db.Error != nil
118 }
119
120 func (scope *Scope) PrimaryFields() []*Field {
121 var fields = []*Field{}
122 for _, field := range scope.GetModelStruct().PrimaryFields {
123 fields = append(fields, scope.Fields()[field.DBName])
124 }
125 return fields
126 }
127
128 func (scope *Scope) PrimaryField() *Field {
129 if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
130 if len(primaryFields) > 1 {
131 if field, ok := scope.Fields()["id"]; ok {
132 return field
133 }
134 }
135 return scope.Fields()[primaryFields[0].DBName]
136 }
137 return nil
138 }
139
140 // PrimaryKey get the primary key's column name
141 func (scope *Scope) PrimaryKey() string {
142 if field := scope.PrimaryField(); field != nil {
143 return field.DBName
144 }
145 return ""
146 }
147
148 // PrimaryKeyZero check the primary key is blank or not
149 func (scope *Scope) PrimaryKeyZero() bool {
150 field := scope.PrimaryField()
151 return field == nil || field.IsBlank
152 }
153
154 // PrimaryKeyValue get the primary key's value
155 func (scope *Scope) PrimaryKeyValue() interface{} {
156 if field := scope.PrimaryField(); field != nil && field.Field.IsValid() {
157 return field.Field.Interface()
158 }
159 return 0
160 }
161
162 // HasColumn to check if has column
163 func (scope *Scope) HasColumn(column string) bool {
164 for _, field := range scope.GetStructFields() {
165 if field.IsNormal && (field.Name == column || field.DBName == column) {
166 return true
167 }
168 }
169 return false
170 }
171
172 // SetColumn to set the column's value
173 func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
174 if field, ok := column.(*Field); ok {
175 return field.Set(value)
176 } else if name, ok := column.(string); ok {
177
178 if field, ok := scope.Fields()[name]; ok {
179 return field.Set(value)
180 }
181
182 dbName := ToDBName(name)
183 if field, ok := scope.Fields()[dbName]; ok {
184 return field.Set(value)
185 }
186
187 if field, ok := scope.FieldByName(name); ok {
188 return field.Set(value)
189 }
190 }
191 return errors.New("could not convert column to field")
192 }
193
194 func (scope *Scope) CallMethod(name string, checkError bool) {
195 if scope.Value == nil || (checkError && scope.HasError()) {
196 return
197 }
198
199 call := func(value interface{}) {
200 if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() {
201 switch f := fm.Interface().(type) {
202 case func():
203 f()
204 case func(s *Scope):
205 f(scope)
206 case func(s *DB):
207 newDB := scope.NewDB()
208 f(newDB)
209 scope.Err(newDB.Error)
210 case func() error:
211 scope.Err(f())
212 case func(s *Scope) error:
213 scope.Err(f(scope))
214 case func(s *DB) error:
215 newDB := scope.NewDB()
216 scope.Err(f(newDB))
217 scope.Err(newDB.Error)
218 default:
219 scope.Err(fmt.Errorf("unsupported function %v", name))
220 }
221 }
222 }
223
224 if values := scope.IndirectValue(); values.Kind() == reflect.Slice {
225 for i := 0; i < values.Len(); i++ {
226 value := values.Index(i).Addr().Interface()
227 if values.Index(i).Kind() == reflect.Ptr {
228 value = values.Index(i).Interface()
229 }
230 call(value)
231 }
232 } else {
233 if scope.IndirectValue().CanAddr() {
234 call(scope.IndirectValue().Addr().Interface())
235 } else {
236 call(scope.IndirectValue().Interface())
237 }
238 }
239 }
240
241 func (scope *Scope) CallMethodWithErrorCheck(name string) {
242 scope.CallMethod(name, true)
243 }
244
245 // AddToVars add value as sql's vars, gorm will escape them
246 func (scope *Scope) AddToVars(value interface{}) string {
247 if expr, ok := value.(*expr); ok {
248 exp := expr.expr
249 for _, arg := range expr.args {
250 exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
251 }
252 return exp
253 } else {
254 scope.SqlVars = append(scope.SqlVars, value)
255 return scope.Dialect().BinVar(len(scope.SqlVars))
256 }
257 }
258
259 type tabler interface {
260 TableName() string
261 }
262
263 type dbTabler interface {
264 TableName(*DB) string
265 }
266
267 // TableName get table name
268 func (scope *Scope) TableName() string {
269 if scope.Search != nil && len(scope.Search.tableName) > 0 {
270 return scope.Search.tableName
271 }
272
273 if tabler, ok := scope.Value.(tabler); ok {
274 return tabler.TableName()
275 }
276
277 if tabler, ok := scope.Value.(dbTabler); ok {
278 return tabler.TableName(scope.db)
279 }
280
281 return scope.GetModelStruct().TableName(scope.db.Model(scope.Value))
282 }
283
284 func (scope *Scope) QuotedTableName() (name string) {
285 if scope.Search != nil && len(scope.Search.tableName) > 0 {
286 if strings.Index(scope.Search.tableName, " ") != -1 {
287 return scope.Search.tableName
288 }
289 return scope.Quote(scope.Search.tableName)
290 } else {
291 return scope.Quote(scope.TableName())
292 }
293 }
294
295 // CombinedConditionSql get combined condition sql
296 func (scope *Scope) CombinedConditionSql() string {
297 return scope.joinsSql() + scope.whereSql() + scope.groupSql() +
298 scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql()
299 }
300
301 func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
302 for _, field := range scope.Fields() {
303 if field.Name == name || field.DBName == name {
304 return field, true
305 }
306 }
307 return nil, false
308 }
309
310 // Raw set sql
311 func (scope *Scope) Raw(sql string) *Scope {
312 scope.Sql = strings.Replace(sql, "$$", "?", -1)
313 return scope
314 }
315
316 // Exec invoke sql
317 func (scope *Scope) Exec() *Scope {
318 defer scope.Trace(NowFunc())
319
320 if !scope.HasError() {
321 if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
322 if count, err := result.RowsAffected(); scope.Err(err) == nil {
323 scope.db.RowsAffected = count
324 }
325 }
326 }
327 return scope
328 }
329
330 // Set set value by name
331 func (scope *Scope) Set(name string, value interface{}) *Scope {
332 scope.db.InstantSet(name, value)
333 return scope
334 }
335
336 // Get get value by name
337 func (scope *Scope) Get(name string) (interface{}, bool) {
338 return scope.db.Get(name)
339 }
340
341 // InstanceId get InstanceId for scope
342 func (scope *Scope) InstanceId() string {
343 if scope.instanceId == "" {
344 scope.instanceId = fmt.Sprintf("%v%v", &scope, &scope.db)
345 }
346 return scope.instanceId
347 }
348
349 func (scope *Scope) InstanceSet(name string, value interface{}) *Scope {
350 return scope.Set(name+scope.InstanceId(), value)
351 }
352
353 func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
354 return scope.Get(name + scope.InstanceId())
355 }
356
357 // Trace print sql log
358 func (scope *Scope) Trace(t time.Time) {
359 if len(scope.Sql) > 0 {
360 scope.db.slog(scope.Sql, t, scope.SqlVars...)
361 }
362 }
363
364 // Begin start a transaction
365 func (scope *Scope) Begin() *Scope {
366 if db, ok := scope.SqlDB().(sqlDb); ok {
367 if tx, err := db.Begin(); err == nil {
368 scope.db.db = interface{}(tx).(sqlCommon)
369 scope.InstanceSet("gorm:started_transaction", true)
370 }
371 }
372 return scope
373 }
374
375 // CommitOrRollback commit current transaction if there is no error, otherwise rollback it
376 func (scope *Scope) CommitOrRollback() *Scope {
377 if _, ok := scope.InstanceGet("gorm:started_transaction"); ok {
378 if db, ok := scope.db.db.(sqlTx); ok {
379 if scope.HasError() {
380 db.Rollback()
381 } else {
382 db.Commit()
383 }
384 scope.db.db = scope.db.parent.db
385 }
386 }
387 return scope
388 }
389
390 func (scope *Scope) SelectAttrs() []string {
391 if scope.selectAttrs == nil {
392 attrs := []string{}
393 for _, value := range scope.Search.selects {
394 if str, ok := value.(string); ok {
395 attrs = append(attrs, str)
396 } else if strs, ok := value.([]string); ok {
397 attrs = append(attrs, strs...)
398 } else if strs, ok := value.([]interface{}); ok {
399 for _, str := range strs {
400 attrs = append(attrs, fmt.Sprintf("%v", str))
401 }
402 }
403 }
404 scope.selectAttrs = &attrs
405 }
406 return *scope.selectAttrs
407 }
408
409 func (scope *Scope) OmitAttrs() []string {
410 return scope.Search.omits
411 }
412
413 func (scope *Scope) changeableDBColumn(column string) bool {
414 selectAttrs := scope.SelectAttrs()
415 omitAttrs := scope.OmitAttrs()
416
417 if len(selectAttrs) > 0 {
418 for _, attr := range selectAttrs {
419 if column == ToDBName(attr) {
420 return true
421 }
422 }
423 return false
424 }
425
426 for _, attr := range omitAttrs {
427 if column == ToDBName(attr) {
428 return false
429 }
430 }
431 return true
432 }
433
434 func (scope *Scope) changeableField(field *Field) bool {
435 selectAttrs := scope.SelectAttrs()
436 omitAttrs := scope.OmitAttrs()
437
438 if len(selectAttrs) > 0 {
439 for _, attr := range selectAttrs {
440 if field.Name == attr || field.DBName == attr {
441 return true
442 }
443 }
444 return false
445 }
446
447 for _, attr := range omitAttrs {
448 if field.Name == attr || field.DBName == attr {
449 return false
450 }
451 }
452
453 return !field.IsIgnored
454 }
455
456 func (scope *Scope) shouldSaveAssociations() bool {
457 saveAssociations, ok := scope.Get("gorm:save_associations")
458 if ok && !saveAssociations.(bool) {
459 return false
460 }
461 return true && !scope.HasError()
462 }
0 package gorm
1
2 import (
3 "database/sql"
4 "database/sql/driver"
5 "fmt"
6 "reflect"
7 "regexp"
8 "strconv"
9 "strings"
10 )
11
12 func (scope *Scope) primaryCondition(value interface{}) string {
13 return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value)
14 }
15
16 func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) {
17 switch value := clause["query"].(type) {
18 case string:
19 // if string is number
20 if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
21 id, _ := strconv.Atoi(value)
22 return scope.primaryCondition(scope.AddToVars(id))
23 } else if value != "" {
24 str = fmt.Sprintf("(%v)", value)
25 }
26 case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
27 return scope.primaryCondition(scope.AddToVars(value))
28 case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
29 str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey()))
30 clause["args"] = []interface{}{value}
31 case map[string]interface{}:
32 var sqls []string
33 for key, value := range value {
34 sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value)))
35 }
36 return strings.Join(sqls, " AND ")
37 case interface{}:
38 var sqls []string
39 for _, field := range scope.New(value).Fields() {
40 if !field.IsIgnored && !field.IsBlank {
41 sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
42 }
43 }
44 return strings.Join(sqls, " AND ")
45 }
46
47 args := clause["args"].([]interface{})
48 for _, arg := range args {
49 switch reflect.ValueOf(arg).Kind() {
50 case reflect.Slice: // For where("id in (?)", []int64{1,2})
51 values := reflect.ValueOf(arg)
52 var tempMarks []string
53 for i := 0; i < values.Len(); i++ {
54 tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
55 }
56 str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
57 default:
58 if valuer, ok := interface{}(arg).(driver.Valuer); ok {
59 arg, _ = valuer.Value()
60 }
61
62 str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
63 }
64 }
65 return
66 }
67
68 func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
69 var notEqualSql string
70 var primaryKey = scope.PrimaryKey()
71
72 switch value := clause["query"].(type) {
73 case string:
74 // is number
75 if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
76 id, _ := strconv.Atoi(value)
77 return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
78 } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) {
79 str = fmt.Sprintf(" NOT (%v) ", value)
80 notEqualSql = fmt.Sprintf("NOT (%v)", value)
81 } else {
82 str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
83 notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
84 }
85 case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
86 return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
87 case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
88 if reflect.ValueOf(value).Len() > 0 {
89 str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey))
90 clause["args"] = []interface{}{value}
91 }
92 return ""
93 case map[string]interface{}:
94 var sqls []string
95 for key, value := range value {
96 sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value)))
97 }
98 return strings.Join(sqls, " AND ")
99 case interface{}:
100 var sqls []string
101 for _, field := range scope.New(value).Fields() {
102 if !field.IsBlank {
103 sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
104 }
105 }
106 return strings.Join(sqls, " AND ")
107 }
108
109 args := clause["args"].([]interface{})
110 for _, arg := range args {
111 switch reflect.ValueOf(arg).Kind() {
112 case reflect.Slice: // For where("id in (?)", []int64{1,2})
113 values := reflect.ValueOf(arg)
114 var tempMarks []string
115 for i := 0; i < values.Len(); i++ {
116 tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
117 }
118 str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
119 default:
120 if scanner, ok := interface{}(arg).(driver.Valuer); ok {
121 arg, _ = scanner.Value()
122 }
123 str = strings.Replace(notEqualSql, "?", scope.AddToVars(arg), 1)
124 }
125 }
126 return
127 }
128
129 func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) {
130 switch value := clause["query"].(type) {
131 case string:
132 str = value
133 case []string:
134 str = strings.Join(value, ", ")
135 }
136
137 args := clause["args"].([]interface{})
138 for _, arg := range args {
139 switch reflect.ValueOf(arg).Kind() {
140 case reflect.Slice:
141 values := reflect.ValueOf(arg)
142 var tempMarks []string
143 for i := 0; i < values.Len(); i++ {
144 tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
145 }
146 str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
147 default:
148 if valuer, ok := interface{}(arg).(driver.Valuer); ok {
149 arg, _ = valuer.Value()
150 }
151 str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
152 }
153 }
154 return
155 }
156
157 func (scope *Scope) whereSql() (sql string) {
158 var primaryConditions, andConditions, orConditions []string
159
160 if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil {
161 sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
162 primaryConditions = append(primaryConditions, sql)
163 }
164
165 if !scope.PrimaryKeyZero() {
166 for _, field := range scope.PrimaryFields() {
167 sql := fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
168 primaryConditions = append(primaryConditions, sql)
169 }
170 }
171
172 for _, clause := range scope.Search.whereConditions {
173 if sql := scope.buildWhereCondition(clause); sql != "" {
174 andConditions = append(andConditions, sql)
175 }
176 }
177
178 for _, clause := range scope.Search.orConditions {
179 if sql := scope.buildWhereCondition(clause); sql != "" {
180 orConditions = append(orConditions, sql)
181 }
182 }
183
184 for _, clause := range scope.Search.notConditions {
185 if sql := scope.buildNotCondition(clause); sql != "" {
186 andConditions = append(andConditions, sql)
187 }
188 }
189
190 orSql := strings.Join(orConditions, " OR ")
191 combinedSql := strings.Join(andConditions, " AND ")
192 if len(combinedSql) > 0 {
193 if len(orSql) > 0 {
194 combinedSql = combinedSql + " OR " + orSql
195 }
196 } else {
197 combinedSql = orSql
198 }
199
200 if len(primaryConditions) > 0 {
201 sql = "WHERE " + strings.Join(primaryConditions, " AND ")
202 if len(combinedSql) > 0 {
203 sql = sql + " AND (" + combinedSql + ")"
204 }
205 } else if len(combinedSql) > 0 {
206 sql = "WHERE " + combinedSql
207 }
208 return
209 }
210
211 var hasCountRegexp = regexp.MustCompile(`(?i)count(.+)`)
212
213 func (scope *Scope) selectSql() string {
214 if len(scope.Search.selects) == 0 {
215 if scope.Search.joins != "" {
216 return fmt.Sprintf("%v.*", scope.QuotedTableName())
217 }
218 return "*"
219 }
220 sql := scope.buildSelectQuery(scope.Search.selects)
221 scope.Search.countingQuery = (len(scope.Search.group) == 0) && hasCountRegexp.MatchString(sql)
222 return sql
223 }
224
225 func (scope *Scope) orderSql() string {
226 if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
227 return ""
228 }
229 return " ORDER BY " + strings.Join(scope.Search.orders, ",")
230 }
231
232 func (scope *Scope) limitSql() string {
233 if !scope.Dialect().HasTop() {
234 if len(scope.Search.limit) == 0 {
235 return ""
236 }
237 return " LIMIT " + scope.Search.limit
238 }
239
240 return ""
241 }
242
243 func (scope *Scope) topSql() string {
244 if scope.Dialect().HasTop() && len(scope.Search.offset) == 0 {
245 if len(scope.Search.limit) == 0 {
246 return ""
247 }
248 return " TOP(" + scope.Search.limit + ")"
249 }
250
251 return ""
252 }
253
254 func (scope *Scope) offsetSql() string {
255 if len(scope.Search.offset) == 0 {
256 return ""
257 }
258
259 if scope.Dialect().HasTop() {
260 sql := " OFFSET " + scope.Search.offset + " ROW "
261 if len(scope.Search.limit) > 0 {
262 sql += "FETCH NEXT " + scope.Search.limit + " ROWS ONLY"
263 }
264 return sql
265 }
266 return " OFFSET " + scope.Search.offset
267 }
268
269 func (scope *Scope) groupSql() string {
270 if len(scope.Search.group) == 0 {
271 return ""
272 }
273 return " GROUP BY " + scope.Search.group
274 }
275
276 func (scope *Scope) havingSql() string {
277 if scope.Search.havingConditions == nil {
278 return ""
279 }
280
281 var andConditions []string
282
283 for _, clause := range scope.Search.havingConditions {
284 if sql := scope.buildWhereCondition(clause); sql != "" {
285 andConditions = append(andConditions, sql)
286 }
287 }
288
289 combinedSql := strings.Join(andConditions, " AND ")
290 if len(combinedSql) == 0 {
291 return ""
292 }
293
294 return " HAVING " + combinedSql
295 }
296
297 func (scope *Scope) joinsSql() string {
298 return scope.Search.joins + " "
299 }
300
301 func (scope *Scope) prepareQuerySql() {
302 if scope.Search.raw {
303 scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
304 } else {
305 scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql()))
306 }
307 return
308 }
309
310 func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
311 if len(values) > 0 {
312 scope.Search.Where(values[0], values[1:]...)
313 }
314 return scope
315 }
316
317 func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
318 for _, f := range funcs {
319 (*f)(scope)
320 if scope.skipLeft {
321 break
322 }
323 }
324 return scope
325 }
326
327 func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) {
328 if !scope.IndirectValue().CanAddr() {
329 return values, true
330 }
331
332 var hasExpr bool
333 fields := scope.Fields()
334 for key, value := range values {
335 if field, ok := fields[ToDBName(key)]; ok && field.Field.IsValid() {
336 if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
337 if _, ok := value.(*expr); ok {
338 hasExpr = true
339 } else if !equalAsString(field.Field.Interface(), value) {
340 hasUpdate = true
341 field.Set(value)
342 }
343 }
344 }
345 }
346 if hasExpr {
347 var updateMap = map[string]interface{}{}
348 for key, value := range fields {
349 if v, ok := values[key]; ok {
350 updateMap[key] = v
351 } else {
352 updateMap[key] = value.Field.Interface()
353 }
354 }
355 return updateMap, true
356 }
357 return
358 }
359
360 func (scope *Scope) row() *sql.Row {
361 defer scope.Trace(NowFunc())
362 scope.callCallbacks(scope.db.parent.callback.rowQueries)
363 scope.prepareQuerySql()
364 return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
365 }
366
367 func (scope *Scope) rows() (*sql.Rows, error) {
368 defer scope.Trace(NowFunc())
369 scope.callCallbacks(scope.db.parent.callback.rowQueries)
370 scope.prepareQuerySql()
371 return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
372 }
373
374 func (scope *Scope) initialize() *Scope {
375 for _, clause := range scope.Search.whereConditions {
376 scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false)
377 }
378 scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false)
379 scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false)
380 return scope
381 }
382
383 func (scope *Scope) pluck(column string, value interface{}) *Scope {
384 dest := reflect.Indirect(reflect.ValueOf(value))
385 scope.Search.Select(column)
386 if dest.Kind() != reflect.Slice {
387 scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
388 return scope
389 }
390
391 rows, err := scope.rows()
392 if scope.Err(err) == nil {
393 defer rows.Close()
394 for rows.Next() {
395 elem := reflect.New(dest.Type().Elem()).Interface()
396 scope.Err(rows.Scan(elem))
397 dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem()))
398 }
399 }
400 return scope
401 }
402
403 func (scope *Scope) count(value interface{}) *Scope {
404 scope.Search.Select("count(*)")
405 scope.Err(scope.row().Scan(value))
406 return scope
407 }
408
409 func (scope *Scope) typeName() string {
410 value := scope.IndirectValue()
411 if value.Kind() == reflect.Slice {
412 return value.Type().Elem().Name()
413 }
414
415 return value.Type().Name()
416 }
417
418 func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
419 toScope := scope.db.NewScope(value)
420 fromFields := scope.Fields()
421 toFields := toScope.Fields()
422 for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
423 var fromField, toField *Field
424 if field, ok := scope.FieldByName(foreignKey); ok {
425 fromField = field
426 } else {
427 fromField = fromFields[ToDBName(foreignKey)]
428 }
429 if field, ok := toScope.FieldByName(foreignKey); ok {
430 toField = field
431 } else {
432 toField = toFields[ToDBName(foreignKey)]
433 }
434
435 if fromField != nil {
436 if relationship := fromField.Relationship; relationship != nil {
437 if relationship.Kind == "many_to_many" {
438 joinTableHandler := relationship.JoinTableHandler
439 scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error)
440 } else if relationship.Kind == "belongs_to" {
441 query := toScope.db
442 for idx, foreignKey := range relationship.ForeignDBNames {
443 if field, ok := scope.FieldByName(foreignKey); ok {
444 query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
445 }
446 }
447 scope.Err(query.Find(value).Error)
448 } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
449 query := toScope.db
450 for idx, foreignKey := range relationship.ForeignDBNames {
451 if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
452 query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
453 }
454 }
455
456 if relationship.PolymorphicType != "" {
457 query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
458 }
459 scope.Err(query.Find(value).Error)
460 }
461 } else {
462 sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
463 scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error)
464 }
465 return scope
466 } else if toField != nil {
467 sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
468 scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
469 return scope
470 }
471 }
472
473 scope.Err(fmt.Errorf("invalid association %v", foreignKeys))
474 return scope
475 }
476
477 /**
478 Return the table options string or an empty string if the table options does not exist
479 */
480 func (scope *Scope) getTableOptions() string {
481 tableOptions, ok := scope.Get("gorm:table_options")
482 if !ok {
483 return ""
484 }
485 return tableOptions.(string)
486 }
487
488 func (scope *Scope) createJoinTable(field *StructField) {
489 if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
490 joinTableHandler := relationship.JoinTableHandler
491 joinTable := joinTableHandler.Table(scope.db)
492 if !scope.Dialect().HasTable(scope, joinTable) {
493 toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
494
495 var sqlTypes []string
496 for idx, fieldName := range relationship.ForeignFieldNames {
497 if field, ok := scope.Fields()[fieldName]; ok {
498 value := reflect.Indirect(reflect.New(field.Struct.Type))
499 primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false)
500 sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType)
501 }
502 }
503
504 for idx, fieldName := range relationship.AssociationForeignFieldNames {
505 if field, ok := toScope.Fields()[fieldName]; ok {
506 value := reflect.Indirect(reflect.New(field.Struct.Type))
507 primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false)
508 sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType)
509 }
510 }
511 scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), scope.getTableOptions())).Error)
512 }
513 scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
514 }
515 }
516
517 func (scope *Scope) createTable() *Scope {
518 var tags []string
519 var primaryKeys []string
520 var primaryKeyInColumnType bool = false
521 for _, field := range scope.GetStructFields() {
522 if field.IsNormal {
523 sqlTag := scope.generateSqlTag(field)
524
525 // Check if the primary key constraint was specified as
526 // part of the column type. If so, we can only support
527 // one column as the primary key.
528 if strings.Contains(strings.ToLower(sqlTag), "primary key") {
529 primaryKeyInColumnType = true
530 }
531
532 tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
533 }
534
535 if field.IsPrimaryKey {
536 primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
537 }
538 scope.createJoinTable(field)
539 }
540
541 var primaryKeyStr string
542 if len(primaryKeys) > 0 && !primaryKeyInColumnType {
543 primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
544 }
545 scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
546 return scope
547 }
548
549 func (scope *Scope) dropTable() *Scope {
550 scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
551 return scope
552 }
553
554 func (scope *Scope) dropTableIfExists() *Scope {
555 if scope.Dialect().HasTable(scope, scope.TableName()) {
556 scope.dropTable()
557 }
558 return scope
559 }
560
561 func (scope *Scope) modifyColumn(column string, typ string) {
562 scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
563 }
564
565 func (scope *Scope) dropColumn(column string) {
566 scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec()
567 }
568
569 func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
570 if scope.Dialect().HasIndex(scope, scope.TableName(), indexName) {
571 return
572 }
573
574 var columns []string
575 for _, name := range column {
576 columns = append(columns, scope.QuoteIfPossible(name))
577 }
578
579 sqlCreate := "CREATE INDEX"
580 if unique {
581 sqlCreate = "CREATE UNIQUE INDEX"
582 }
583
584 scope.Search.Unscoped = true
585 scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSql())).Exec()
586 scope.Search.Unscoped = false
587 }
588
589 func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
590 var table = scope.TableName()
591 var keyName = fmt.Sprintf("%s_%s_%s_foreign", table, field, regexp.MustCompile("[^a-zA-Z]").ReplaceAllString(dest, "_"))
592 keyName = regexp.MustCompile("_+").ReplaceAllString(keyName, "_")
593 var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
594 scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.QuoteIfPossible(keyName), scope.QuoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
595 }
596
597 func (scope *Scope) removeIndex(indexName string) {
598 scope.Dialect().RemoveIndex(scope, indexName)
599 }
600
601 func (scope *Scope) autoMigrate() *Scope {
602 tableName := scope.TableName()
603 quotedTableName := scope.QuotedTableName()
604
605 if !scope.Dialect().HasTable(scope, tableName) {
606 scope.createTable()
607 } else {
608 for _, field := range scope.GetStructFields() {
609 if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
610 if field.IsNormal {
611 sqlTag := scope.generateSqlTag(field)
612 scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
613 }
614 }
615 scope.createJoinTable(field)
616 }
617 }
618
619 scope.autoIndex()
620 return scope
621 }
622
623 func (scope *Scope) autoIndex() *Scope {
624 var indexes = map[string][]string{}
625 var uniqueIndexes = map[string][]string{}
626
627 for _, field := range scope.GetStructFields() {
628 sqlSettings := parseTagSetting(field.Tag.Get("sql"))
629 if name, ok := sqlSettings["INDEX"]; ok {
630 if name == "INDEX" {
631 name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName)
632 }
633 indexes[name] = append(indexes[name], field.DBName)
634 }
635
636 if name, ok := sqlSettings["UNIQUE_INDEX"]; ok {
637 if name == "UNIQUE_INDEX" {
638 name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
639 }
640 uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
641 }
642 }
643
644 for name, columns := range indexes {
645 scope.addIndex(false, name, columns...)
646 }
647
648 for name, columns := range uniqueIndexes {
649 scope.addIndex(true, name, columns...)
650 }
651
652 return scope
653 }
0 package gorm_test
1
2 import (
3 "github.com/jinzhu/gorm"
4 "testing"
5 )
6
7 func NameIn1And2(d *gorm.DB) *gorm.DB {
8 return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"})
9 }
10
11 func NameIn2And3(d *gorm.DB) *gorm.DB {
12 return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"})
13 }
14
15 func NameIn(names []string) func(d *gorm.DB) *gorm.DB {
16 return func(d *gorm.DB) *gorm.DB {
17 return d.Where("name in (?)", names)
18 }
19 }
20
21 func TestScopes(t *testing.T) {
22 user1 := User{Name: "ScopeUser1", Age: 1}
23 user2 := User{Name: "ScopeUser2", Age: 1}
24 user3 := User{Name: "ScopeUser3", Age: 2}
25 DB.Save(&user1).Save(&user2).Save(&user3)
26
27 var users1, users2, users3 []User
28 DB.Scopes(NameIn1And2).Find(&users1)
29 if len(users1) != 2 {
30 t.Errorf("Should found two users's name in 1, 2")
31 }
32
33 DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2)
34 if len(users2) != 1 {
35 t.Errorf("Should found one user's name is 2")
36 }
37
38 DB.Scopes(NameIn([]string{user1.Name, user3.Name})).Find(&users3)
39 if len(users3) != 2 {
40 t.Errorf("Should found two users's name in 1, 3")
41 }
42 }
0 package gorm
1
2 import "fmt"
3
4 type search struct {
5 db *DB
6 whereConditions []map[string]interface{}
7 orConditions []map[string]interface{}
8 notConditions []map[string]interface{}
9 havingConditions []map[string]interface{}
10 initAttrs []interface{}
11 assignAttrs []interface{}
12 selects map[string]interface{}
13 omits []string
14 orders []string
15 joins string
16 preload []searchPreload
17 offset string
18 limit string
19 group string
20 tableName string
21 raw bool
22 Unscoped bool
23 countingQuery bool
24 }
25
26 type searchPreload struct {
27 schema string
28 conditions []interface{}
29 }
30
31 func (s *search) clone() *search {
32 clone := *s
33 return &clone
34 }
35
36 func (s *search) Where(query interface{}, values ...interface{}) *search {
37 s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values})
38 return s
39 }
40
41 func (s *search) Not(query interface{}, values ...interface{}) *search {
42 s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values})
43 return s
44 }
45
46 func (s *search) Or(query interface{}, values ...interface{}) *search {
47 s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values})
48 return s
49 }
50
51 func (s *search) Attrs(attrs ...interface{}) *search {
52 s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...))
53 return s
54 }
55
56 func (s *search) Assign(attrs ...interface{}) *search {
57 s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...))
58 return s
59 }
60
61 func (s *search) Order(value string, reorder ...bool) *search {
62 if len(reorder) > 0 && reorder[0] {
63 if value != "" {
64 s.orders = []string{value}
65 } else {
66 s.orders = []string{}
67 }
68 } else if value != "" {
69 s.orders = append(s.orders, value)
70 }
71 return s
72 }
73
74 func (s *search) Select(query interface{}, args ...interface{}) *search {
75 s.selects = map[string]interface{}{"query": query, "args": args}
76 return s
77 }
78
79 func (s *search) Omit(columns ...string) *search {
80 s.omits = columns
81 return s
82 }
83
84 func (s *search) Limit(value interface{}) *search {
85 s.limit = s.getInterfaceAsSql(value)
86 return s
87 }
88
89 func (s *search) Offset(value interface{}) *search {
90 s.offset = s.getInterfaceAsSql(value)
91 return s
92 }
93
94 func (s *search) Group(query string) *search {
95 s.group = s.getInterfaceAsSql(query)
96 return s
97 }
98
99 func (s *search) Having(query string, values ...interface{}) *search {
100 s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})
101 return s
102 }
103
104 func (s *search) Joins(query string) *search {
105 s.joins = query
106 return s
107 }
108
109 func (s *search) Preload(schema string, values ...interface{}) *search {
110 var preloads []searchPreload
111 for _, preload := range s.preload {
112 if preload.schema != schema {
113 preloads = append(preloads, preload)
114 }
115 }
116 preloads = append(preloads, searchPreload{schema, values})
117 s.preload = preloads
118 return s
119 }
120
121 func (s *search) Raw(b bool) *search {
122 s.raw = b
123 return s
124 }
125
126 func (s *search) unscoped() *search {
127 s.Unscoped = true
128 return s
129 }
130
131 func (s *search) Table(name string) *search {
132 s.tableName = name
133 return s
134 }
135
136 func (s *search) getInterfaceAsSql(value interface{}) (str string) {
137 switch value.(type) {
138 case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
139 str = fmt.Sprintf("%v", value)
140 default:
141 s.db.AddError(InvalidSql)
142 }
143
144 if str == "-1" {
145 return ""
146 }
147 return
148 }
0 package gorm
1
2 import (
3 "reflect"
4 "testing"
5 )
6
7 func TestCloneSearch(t *testing.T) {
8 s := new(search)
9 s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Select("name, age")
10
11 s1 := s.clone()
12 s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Select("email")
13
14 if reflect.DeepEqual(s.whereConditions, s1.whereConditions) {
15 t.Errorf("Where should be copied")
16 }
17
18 if reflect.DeepEqual(s.orders, s1.orders) {
19 t.Errorf("Order should be copied")
20 }
21
22 if reflect.DeepEqual(s.initAttrs, s1.initAttrs) {
23 t.Errorf("InitAttrs should be copied")
24 }
25
26 if reflect.DeepEqual(s.Select, s1.Select) {
27 t.Errorf("selectStr should be copied")
28 }
29 }
0 package gorm_test
1
2 import (
3 "database/sql/driver"
4 "encoding/json"
5 "testing"
6 )
7
8 func TestScannableSlices(t *testing.T) {
9 if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil {
10 t.Errorf("Should create table with slice values correctly: %s", err)
11 }
12
13 r1 := RecordWithSlice{
14 Strings: ExampleStringSlice{"a", "b", "c"},
15 Structs: ExampleStructSlice{
16 {"name1", "value1"},
17 {"name2", "value2"},
18 },
19 }
20
21 if err := DB.Save(&r1).Error; err != nil {
22 t.Errorf("Should save record with slice values")
23 }
24
25 var r2 RecordWithSlice
26
27 if err := DB.Find(&r2).Error; err != nil {
28 t.Errorf("Should fetch record with slice values")
29 }
30
31 if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" {
32 t.Errorf("Should have serialised and deserialised a string array")
33 }
34
35 if len(r2.Structs) != 2 || r2.Structs[0].Name != "name1" || r2.Structs[0].Value != "value1" || r2.Structs[1].Name != "name2" || r2.Structs[1].Value != "value2" {
36 t.Errorf("Should have serialised and deserialised a struct array")
37 }
38 }
39
40 type RecordWithSlice struct {
41 ID uint64
42 Strings ExampleStringSlice `sql:"type:text"`
43 Structs ExampleStructSlice `sql:"type:text"`
44 }
45
46 type ExampleStringSlice []string
47
48 func (l ExampleStringSlice) Value() (driver.Value, error) {
49 return json.Marshal(l)
50 }
51
52 func (l *ExampleStringSlice) Scan(input interface{}) error {
53 return json.Unmarshal(input.([]byte), l)
54 }
55
56 type ExampleStruct struct {
57 Name string
58 Value string
59 }
60
61 type ExampleStructSlice []ExampleStruct
62
63 func (l ExampleStructSlice) Value() (driver.Value, error) {
64 return json.Marshal(l)
65 }
66
67 func (l *ExampleStructSlice) Scan(input interface{}) error {
68 return json.Unmarshal(input.([]byte), l)
69 }
0 package gorm
1
2 import (
3 "fmt"
4 "reflect"
5 "time"
6 )
7
8 type sqlite3 struct {
9 commonDialect
10 }
11
12 func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
13 switch value.Kind() {
14 case reflect.Bool:
15 return "bool"
16 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
17 if autoIncrease {
18 return "integer primary key autoincrement"
19 }
20 return "integer"
21 case reflect.Int64, reflect.Uint64:
22 if autoIncrease {
23 return "integer primary key autoincrement"
24 }
25 return "bigint"
26 case reflect.Float32, reflect.Float64:
27 return "real"
28 case reflect.String:
29 if size > 0 && size < 65532 {
30 return fmt.Sprintf("varchar(%d)", size)
31 }
32 return "text"
33 case reflect.Struct:
34 if _, ok := value.Interface().(time.Time); ok {
35 return "datetime"
36 }
37 default:
38 if _, ok := value.Interface().([]byte); ok {
39 return "blob"
40 }
41 }
42 panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
43 }
44
45 func (s sqlite3) HasTable(scope *Scope, tableName string) bool {
46 var count int
47 s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName)
48 return count > 0
49 }
50
51 func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
52 var count int
53 s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName)
54 return count > 0
55 }
56
57 func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
58 var count int
59 s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
60 return count > 0
61 }
62
63 func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
64 scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
65 }
66
67 func (sqlite3) CurrentDatabase(scope *Scope) (name string) {
68 var (
69 ifaces = make([]interface{}, 3)
70 pointers = make([]*string, 3)
71 i int
72 )
73 for i = 0; i < 3; i++ {
74 ifaces[i] = &pointers[i]
75 }
76 if err := scope.NewDB().Raw("PRAGMA database_list").Row().Scan(ifaces...); scope.Err(err) != nil {
77 return
78 }
79 if pointers[1] != nil {
80 name = *pointers[1]
81 }
82 return
83 }
0 package gorm_test
1
2 import (
3 "database/sql"
4 "database/sql/driver"
5 "errors"
6 "fmt"
7
8 "reflect"
9 "time"
10 )
11
12 type User struct {
13 Id int64
14 Age int64
15 UserNum Num
16 Name string `sql:"size:255"`
17 Birthday time.Time // Time
18 CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
19 UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
20 Emails []Email // Embedded structs
21 BillingAddress Address // Embedded struct
22 BillingAddressID sql.NullInt64 // Embedded struct's foreign key
23 ShippingAddress Address // Embedded struct
24 ShippingAddressId int64 // Embedded struct's foreign key
25 CreditCard CreditCard
26 Latitude float64
27 Languages []Language `gorm:"many2many:user_languages;"`
28 CompanyID int64
29 Company Company
30 Role
31 PasswordHash []byte
32 IgnoreMe int64 `sql:"-"`
33 IgnoreStringSlice []string `sql:"-"`
34 Ignored struct{ Name string } `sql:"-"`
35 IgnoredPointer *User `sql:"-"`
36 }
37
38 type CreditCard struct {
39 ID int8
40 Number string
41 UserId sql.NullInt64
42 CreatedAt time.Time
43 UpdatedAt time.Time
44 DeletedAt time.Time
45 }
46
47 type Email struct {
48 Id int16
49 UserId int
50 Email string `sql:"type:varchar(100);"`
51 CreatedAt time.Time
52 UpdatedAt time.Time
53 }
54
55 type Address struct {
56 ID int
57 Address1 string
58 Address2 string
59 Post string
60 CreatedAt time.Time
61 UpdatedAt time.Time
62 DeletedAt time.Time
63 }
64
65 type Language struct {
66 Id int
67 Name string
68 Users []User `gorm:"many2many:user_languages;"`
69 }
70
71 type Product struct {
72 Id int64
73 Code string
74 Price int64
75 CreatedAt time.Time
76 UpdatedAt time.Time
77 AfterFindCallTimes int64
78 BeforeCreateCallTimes int64
79 AfterCreateCallTimes int64
80 BeforeUpdateCallTimes int64
81 AfterUpdateCallTimes int64
82 BeforeSaveCallTimes int64
83 AfterSaveCallTimes int64
84 BeforeDeleteCallTimes int64
85 AfterDeleteCallTimes int64
86 }
87
88 type Company struct {
89 Id int64
90 Name string
91 Owner *User `sql:"-"`
92 }
93
94 type Role struct {
95 Name string
96 }
97
98 func (role *Role) Scan(value interface{}) error {
99 if b, ok := value.([]uint8); ok {
100 role.Name = string(b)
101 } else {
102 role.Name = value.(string)
103 }
104 return nil
105 }
106
107 func (role Role) Value() (driver.Value, error) {
108 return role.Name, nil
109 }
110
111 func (role Role) IsAdmin() bool {
112 return role.Name == "admin"
113 }
114
115 type Num int64
116
117 func (i *Num) Scan(src interface{}) error {
118 switch s := src.(type) {
119 case []byte:
120 case int64:
121 *i = Num(s)
122 default:
123 return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
124 }
125 return nil
126 }
127
128 type Animal struct {
129 Counter uint64 `gorm:"primary_key:yes"`
130 Name string `sql:"DEFAULT:'galeone'"`
131 From string //test reserved sql keyword as field name
132 Age time.Time `sql:"DEFAULT:current_timestamp"`
133 unexported string // unexported value
134 CreatedAt time.Time
135 UpdatedAt time.Time
136 }
137
138 type JoinTable struct {
139 From uint64
140 To uint64
141 Time time.Time `sql:"default: null"`
142 }
143
144 type Post struct {
145 Id int64
146 CategoryId sql.NullInt64
147 MainCategoryId int64
148 Title string
149 Body string
150 Comments []*Comment
151 Category Category
152 MainCategory Category
153 }
154
155 type Category struct {
156 Id int64
157 Name string
158 }
159
160 type Comment struct {
161 Id int64
162 PostId int64
163 Content string
164 Post Post
165 }
166
167 // Scanner
168 type NullValue struct {
169 Id int64
170 Name sql.NullString `sql:"not null"`
171 Age sql.NullInt64
172 Male sql.NullBool
173 Height sql.NullFloat64
174 AddedAt NullTime
175 }
176
177 type NullTime struct {
178 Time time.Time
179 Valid bool
180 }
181
182 func (nt *NullTime) Scan(value interface{}) error {
183 if value == nil {
184 nt.Valid = false
185 return nil
186 }
187 nt.Time, nt.Valid = value.(time.Time), true
188 return nil
189 }
190
191 func (nt NullTime) Value() (driver.Value, error) {
192 if !nt.Valid {
193 return nil, nil
194 }
195 return nt.Time, nil
196 }
197
198 func getPreparedUser(name string, role string) *User {
199 var company Company
200 DB.Where(Company{Name: role}).FirstOrCreate(&company)
201
202 return &User{
203 Name: name,
204 Age: 20,
205 Role: Role{role},
206 BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
207 ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
208 CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
209 Emails: []Email{
210 {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
211 },
212 Company: company,
213 Languages: []Language{
214 {Name: fmt.Sprintf("lang_1_%v", name)},
215 {Name: fmt.Sprintf("lang_2_%v", name)},
216 },
217 }
218 }
0 dialects=("postgres" "mysql" "sqlite")
1
2 for dialect in "${dialects[@]}" ; do
3 GORM_DIALECT=${dialect} go test
4 done
0 package gorm_test
1
2 import (
3 "testing"
4 "time"
5
6 "github.com/jinzhu/gorm"
7 )
8
9 func TestUpdate(t *testing.T) {
10 product1 := Product{Code: "product1code"}
11 product2 := Product{Code: "product2code"}
12
13 DB.Save(&product1).Save(&product2).Update("code", "product2newcode")
14
15 if product2.Code != "product2newcode" {
16 t.Errorf("Record should be updated")
17 }
18
19 DB.First(&product1, product1.Id)
20 DB.First(&product2, product2.Id)
21 updatedAt1 := product1.UpdatedAt
22 updatedAt2 := product2.UpdatedAt
23
24 var product3 Product
25 DB.First(&product3, product2.Id).Update("code", "product2newcode")
26 if updatedAt2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
27 t.Errorf("updatedAt should not be updated if nothing changed")
28 }
29
30 if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() {
31 t.Errorf("Product1 should not be updated")
32 }
33
34 if !DB.First(&Product{}, "code = ?", "product2code").RecordNotFound() {
35 t.Errorf("Product2's code should be updated")
36 }
37
38 if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() {
39 t.Errorf("Product2's code should be updated")
40 }
41
42 DB.Table("products").Where("code in (?)", []string{"product1code"}).Update("code", "product1newcode")
43
44 var product4 Product
45 DB.First(&product4, product1.Id)
46 if updatedAt1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
47 t.Errorf("updatedAt should be updated if something changed")
48 }
49
50 if !DB.First(&Product{}, "code = 'product1code'").RecordNotFound() {
51 t.Errorf("Product1's code should be updated")
52 }
53
54 if DB.First(&Product{}, "code = 'product1newcode'").RecordNotFound() {
55 t.Errorf("Product should not be changed to 789")
56 }
57
58 if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
59 t.Error("No error should raise when update with CamelCase")
60 }
61
62 if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
63 t.Error("No error should raise when update_column with CamelCase")
64 }
65
66 var products []Product
67 DB.Find(&products)
68 if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) {
69 t.Error("RowsAffected should be correct when do batch update")
70 }
71
72 DB.First(&product4, product4.Id)
73 DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
74 var product5 Product
75 DB.First(&product5, product4.Id)
76 if product5.Price != product4.Price+100-50 {
77 t.Errorf("Update with expression")
78 }
79 if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
80 t.Errorf("Update with expression should update UpdatedAt")
81 }
82 }
83
84 func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
85 animal := Animal{Name: "Ferdinand"}
86 DB.Save(&animal)
87 updatedAt1 := animal.UpdatedAt
88
89 DB.Save(&animal).Update("name", "Francis")
90
91 if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) {
92 t.Errorf("updatedAt should not be updated if nothing changed")
93 }
94
95 var animals []Animal
96 DB.Find(&animals)
97 if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) {
98 t.Error("RowsAffected should be correct when do batch update")
99 }
100
101 animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone)
102 DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched
103 DB.First(&animal, animal.Counter)
104 if animal.Name != "galeone" {
105 t.Errorf("Name fiels shouldn't be changed if untouched, but got %v", animal.Name)
106 }
107
108 // When changing a field with a default value, the change must occur
109 animal.Name = "amazing horse"
110 DB.Save(&animal)
111 DB.First(&animal, animal.Counter)
112 if animal.Name != "amazing horse" {
113 t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name)
114 }
115
116 // When changing a field with a default value with blank value
117 animal.Name = ""
118 DB.Save(&animal)
119 DB.First(&animal, animal.Counter)
120 if animal.Name != "" {
121 t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name)
122 }
123 }
124
125 func TestUpdates(t *testing.T) {
126 product1 := Product{Code: "product1code", Price: 10}
127 product2 := Product{Code: "product2code", Price: 10}
128 DB.Save(&product1).Save(&product2)
129 DB.Model(&product1).Updates(map[string]interface{}{"code": "product1newcode", "price": 100})
130 if product1.Code != "product1newcode" || product1.Price != 100 {
131 t.Errorf("Record should be updated also with map")
132 }
133
134 DB.First(&product1, product1.Id)
135 DB.First(&product2, product2.Id)
136 updatedAt1 := product1.UpdatedAt
137 updatedAt2 := product2.UpdatedAt
138
139 var product3 Product
140 DB.First(&product3, product1.Id).Updates(Product{Code: "product1newcode", Price: 100})
141 if product3.Code != "product1newcode" || product3.Price != 100 {
142 t.Errorf("Record should be updated with struct")
143 }
144
145 if updatedAt1.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
146 t.Errorf("updatedAt should not be updated if nothing changed")
147 }
148
149 if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() {
150 t.Errorf("Product2 should not be updated")
151 }
152
153 if DB.First(&Product{}, "code = ?", "product1newcode").RecordNotFound() {
154 t.Errorf("Product1 should be updated")
155 }
156
157 DB.Table("products").Where("code in (?)", []string{"product2code"}).Updates(Product{Code: "product2newcode"})
158 if !DB.First(&Product{}, "code = 'product2code'").RecordNotFound() {
159 t.Errorf("Product2's code should be updated")
160 }
161
162 var product4 Product
163 DB.First(&product4, product2.Id)
164 if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
165 t.Errorf("updatedAt should be updated if something changed")
166 }
167
168 if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() {
169 t.Errorf("product2's code should be updated")
170 }
171
172 DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
173 var product5 Product
174 DB.First(&product5, product4.Id)
175 if product5.Price != product4.Price+100 {
176 t.Errorf("Updates with expression")
177 }
178 if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
179 t.Errorf("Updates with expression should update UpdatedAt")
180 }
181 }
182
183 func TestUpdateColumn(t *testing.T) {
184 product1 := Product{Code: "product1code", Price: 10}
185 product2 := Product{Code: "product2code", Price: 20}
186 DB.Save(&product1).Save(&product2).UpdateColumn(map[string]interface{}{"code": "product2newcode", "price": 100})
187 if product2.Code != "product2newcode" || product2.Price != 100 {
188 t.Errorf("product 2 should be updated with update column")
189 }
190
191 var product3 Product
192 DB.First(&product3, product1.Id)
193 if product3.Code != "product1code" || product3.Price != 10 {
194 t.Errorf("product 1 should not be updated")
195 }
196
197 DB.First(&product2, product2.Id)
198 updatedAt2 := product2.UpdatedAt
199 DB.Model(product2).UpdateColumn("code", "update_column_new")
200 var product4 Product
201 DB.First(&product4, product2.Id)
202 if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
203 t.Errorf("updatedAt should not be updated with update column")
204 }
205
206 DB.Model(&product4).UpdateColumn("price", gorm.Expr("price + 100 - 50"))
207 var product5 Product
208 DB.First(&product5, product4.Id)
209 if product5.Price != product4.Price+100-50 {
210 t.Errorf("UpdateColumn with expression")
211 }
212 if product5.UpdatedAt.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
213 t.Errorf("UpdateColumn with expression should not update UpdatedAt")
214 }
215 }
216
217 func TestSelectWithUpdate(t *testing.T) {
218 user := getPreparedUser("select_user", "select_with_update")
219 DB.Create(user)
220
221 var reloadUser User
222 DB.First(&reloadUser, user.Id)
223 reloadUser.Name = "new_name"
224 reloadUser.Age = 50
225 reloadUser.BillingAddress = Address{Address1: "New Billing Address"}
226 reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"}
227 reloadUser.CreditCard = CreditCard{Number: "987654321"}
228 reloadUser.Emails = []Email{
229 {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"},
230 }
231 reloadUser.Company = Company{Name: "new company"}
232
233 DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser)
234
235 var queryUser User
236 DB.Preload("BillingAddress").Preload("ShippingAddress").
237 Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id)
238
239 if queryUser.Name == user.Name || queryUser.Age != user.Age {
240 t.Errorf("Should only update users with name column")
241 }
242
243 if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 ||
244 queryUser.ShippingAddressId != user.ShippingAddressId ||
245 queryUser.CreditCard.ID == user.CreditCard.ID ||
246 len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id {
247 t.Errorf("Should only update selected relationships")
248 }
249 }
250
251 func TestSelectWithUpdateWithMap(t *testing.T) {
252 user := getPreparedUser("select_user", "select_with_update_map")
253 DB.Create(user)
254
255 updateValues := map[string]interface{}{
256 "Name": "new_name",
257 "Age": 50,
258 "BillingAddress": Address{Address1: "New Billing Address"},
259 "ShippingAddress": Address{Address1: "New ShippingAddress Address"},
260 "CreditCard": CreditCard{Number: "987654321"},
261 "Emails": []Email{
262 {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"},
263 },
264 "Company": Company{Name: "new company"},
265 }
266
267 var reloadUser User
268 DB.First(&reloadUser, user.Id)
269 DB.Model(&reloadUser).Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues)
270
271 var queryUser User
272 DB.Preload("BillingAddress").Preload("ShippingAddress").
273 Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id)
274
275 if queryUser.Name == user.Name || queryUser.Age != user.Age {
276 t.Errorf("Should only update users with name column")
277 }
278
279 if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 ||
280 queryUser.ShippingAddressId != user.ShippingAddressId ||
281 queryUser.CreditCard.ID == user.CreditCard.ID ||
282 len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id {
283 t.Errorf("Should only update selected relationships")
284 }
285 }
286
287 func TestOmitWithUpdate(t *testing.T) {
288 user := getPreparedUser("omit_user", "omit_with_update")
289 DB.Create(user)
290
291 var reloadUser User
292 DB.First(&reloadUser, user.Id)
293 reloadUser.Name = "new_name"
294 reloadUser.Age = 50
295 reloadUser.BillingAddress = Address{Address1: "New Billing Address"}
296 reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"}
297 reloadUser.CreditCard = CreditCard{Number: "987654321"}
298 reloadUser.Emails = []Email{
299 {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"},
300 }
301 reloadUser.Company = Company{Name: "new company"}
302
303 DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser)
304
305 var queryUser User
306 DB.Preload("BillingAddress").Preload("ShippingAddress").
307 Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id)
308
309 if queryUser.Name != user.Name || queryUser.Age == user.Age {
310 t.Errorf("Should only update users with name column")
311 }
312
313 if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 ||
314 queryUser.ShippingAddressId == user.ShippingAddressId ||
315 queryUser.CreditCard.ID != user.CreditCard.ID ||
316 len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id {
317 t.Errorf("Should only update relationships that not omited")
318 }
319 }
320
321 func TestOmitWithUpdateWithMap(t *testing.T) {
322 user := getPreparedUser("select_user", "select_with_update_map")
323 DB.Create(user)
324
325 updateValues := map[string]interface{}{
326 "Name": "new_name",
327 "Age": 50,
328 "BillingAddress": Address{Address1: "New Billing Address"},
329 "ShippingAddress": Address{Address1: "New ShippingAddress Address"},
330 "CreditCard": CreditCard{Number: "987654321"},
331 "Emails": []Email{
332 {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"},
333 },
334 "Company": Company{Name: "new company"},
335 }
336
337 var reloadUser User
338 DB.First(&reloadUser, user.Id)
339 DB.Model(&reloadUser).Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues)
340
341 var queryUser User
342 DB.Preload("BillingAddress").Preload("ShippingAddress").
343 Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id)
344
345 if queryUser.Name != user.Name || queryUser.Age == user.Age {
346 t.Errorf("Should only update users with name column")
347 }
348
349 if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 ||
350 queryUser.ShippingAddressId == user.ShippingAddressId ||
351 queryUser.CreditCard.ID != user.CreditCard.ID ||
352 len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id {
353 t.Errorf("Should only update relationships not omited")
354 }
355 }
356
357 func TestSelectWithUpdateColumn(t *testing.T) {
358 user := getPreparedUser("select_user", "select_with_update_map")
359 DB.Create(user)
360
361 updateValues := map[string]interface{}{"Name": "new_name", "Age": 50}
362
363 var reloadUser User
364 DB.First(&reloadUser, user.Id)
365 DB.Model(&reloadUser).Select("Name").UpdateColumn(updateValues)
366
367 var queryUser User
368 DB.First(&queryUser, user.Id)
369
370 if queryUser.Name == user.Name || queryUser.Age != user.Age {
371 t.Errorf("Should only update users with name column")
372 }
373 }
374
375 func TestOmitWithUpdateColumn(t *testing.T) {
376 user := getPreparedUser("select_user", "select_with_update_map")
377 DB.Create(user)
378
379 updateValues := map[string]interface{}{"Name": "new_name", "Age": 50}
380
381 var reloadUser User
382 DB.First(&reloadUser, user.Id)
383 DB.Model(&reloadUser).Omit("Name").UpdateColumn(updateValues)
384
385 var queryUser User
386 DB.First(&queryUser, user.Id)
387
388 if queryUser.Name != user.Name || queryUser.Age == user.Age {
389 t.Errorf("Should omit name column when update user")
390 }
391 }
392
393 func TestUpdateColumnsSkipsAssociations(t *testing.T) {
394 user := getPreparedUser("update_columns_user", "special_role")
395 user.Age = 99
396 address1 := "first street"
397 user.BillingAddress = Address{Address1: address1}
398 DB.Save(user)
399
400 // Update a single field of the user and verify that the changed address is not stored.
401 newAge := int64(100)
402 user.BillingAddress.Address1 = "second street"
403 db := DB.Model(user).UpdateColumns(User{Age: newAge})
404 if db.RowsAffected != 1 {
405 t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected)
406 }
407
408 // Verify that Age now=`newAge`.
409 freshUser := &User{Id: user.Id}
410 DB.First(freshUser)
411 if freshUser.Age != newAge {
412 t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, freshUser.Age)
413 }
414
415 // Verify that user's BillingAddress.Address1 is not changed and is still "first street".
416 DB.First(&freshUser.BillingAddress, freshUser.BillingAddressID)
417 if freshUser.BillingAddress.Address1 != address1 {
418 t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1)
419 }
420 }
0 package gorm
1
2 import (
3 "bytes"
4 "strings"
5 "sync"
6 )
7
8 // Copied from golint
9 var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
10 var commonInitialismsReplacer *strings.Replacer
11
12 func init() {
13 var commonInitialismsForReplacer []string
14 for _, initialism := range commonInitialisms {
15 commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
16 }
17 commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
18 }
19
20 type safeMap struct {
21 m map[string]string
22 l *sync.RWMutex
23 }
24
25 func (s *safeMap) Set(key string, value string) {
26 s.l.Lock()
27 defer s.l.Unlock()
28 s.m[key] = value
29 }
30
31 func (s *safeMap) Get(key string) string {
32 s.l.RLock()
33 defer s.l.RUnlock()
34 return s.m[key]
35 }
36
37 func newSafeMap() *safeMap {
38 return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
39 }
40
41 var smap = newSafeMap()
42
43 func ToDBName(name string) string {
44 if v := smap.Get(name); v != "" {
45 return v
46 }
47
48 value := commonInitialismsReplacer.Replace(name)
49 buf := bytes.NewBufferString("")
50 for i, v := range value {
51 if i > 0 && v >= 'A' && v <= 'Z' {
52 buf.WriteRune('_')
53 }
54 buf.WriteRune(v)
55 }
56
57 s := strings.ToLower(buf.String())
58 smap.Set(name, s)
59 return s
60 }
61
62 type expr struct {
63 expr string
64 args []interface{}
65 }
66
67 func Expr(expression string, args ...interface{}) *expr {
68 return &expr{expr: expression, args: args}
69 }
0 package gorm
1
2 import (
3 "fmt"
4 "reflect"
5 "regexp"
6 "runtime"
7 "strings"
8 )
9
10 func fileWithLineNum() string {
11 for i := 2; i < 15; i++ {
12 _, file, line, ok := runtime.Caller(i)
13 if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
14 return fmt.Sprintf("%v:%v", file, line)
15 }
16 }
17 return ""
18 }
19
20 func isBlank(value reflect.Value) bool {
21 return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
22 }
23
24 func toSearchableMap(attrs ...interface{}) (result interface{}) {
25 if len(attrs) > 1 {
26 if str, ok := attrs[0].(string); ok {
27 result = map[string]interface{}{str: attrs[1]}
28 }
29 } else if len(attrs) == 1 {
30 if attr, ok := attrs[0].(map[string]interface{}); ok {
31 result = attr
32 }
33
34 if attr, ok := attrs[0].(interface{}); ok {
35 result = attr
36 }
37 }
38 return
39 }
40
41 func convertInterfaceToMap(values interface{}) map[string]interface{} {
42 attrs := map[string]interface{}{}
43
44 switch value := values.(type) {
45 case map[string]interface{}:
46 for k, v := range value {
47 attrs[ToDBName(k)] = v
48 }
49 case []interface{}:
50 for _, v := range value {
51 for key, value := range convertInterfaceToMap(v) {
52 attrs[key] = value
53 }
54 }
55 case interface{}:
56 reflectValue := reflect.ValueOf(values)
57
58 switch reflectValue.Kind() {
59 case reflect.Map:
60 for _, key := range reflectValue.MapKeys() {
61 attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
62 }
63 default:
64 scope := Scope{Value: values}
65 for _, field := range scope.Fields() {
66 if !field.IsBlank && !field.IsIgnored {
67 attrs[field.DBName] = field.Field.Interface()
68 }
69 }
70 }
71 }
72 return attrs
73 }
74
75 func toString(str interface{}) string {
76 if values, ok := str.([]interface{}); ok {
77 var results []string
78 for _, value := range values {
79 results = append(results, toString(value))
80 }
81 return strings.Join(results, "_")
82 } else if bytes, ok := str.([]byte); ok {
83 return string(bytes)
84 } else {
85 return fmt.Sprintf("%v", str)
86 }
87 }
88
89 func strInSlice(a string, list []string) bool {
90 for _, b := range list {
91 if b == a {
92 return true
93 }
94 }
95 return false
96 }