Imported Upstream version 0.0~git20151012.0.20e37a0
Tianon Gravi
8 years ago
0 | --- | |
1 | engines: | |
2 | gofmt: | |
3 | enabled: true | |
4 | govet: | |
5 | enabled: true | |
6 | golint: | |
7 | enabled: true | |
8 | ratings: | |
9 | paths: | |
10 | - "**.go" |
0 | The MIT License (MIT) | |
1 | ||
2 | Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com> | |
3 | ||
4 | Permission is hereby granted, free of charge, to any person obtaining a copy | |
5 | of this software and associated documentation files (the "Software"), to deal | |
6 | in the Software without restriction, including without limitation the rights | |
7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
8 | copies of the Software, and to permit persons to whom the Software is | |
9 | furnished to do so, subject to the following conditions: | |
10 | ||
11 | The above copyright notice and this permission notice shall be included in | |
12 | all copies or substantial portions of the Software. | |
13 | ||
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | |
20 | THE SOFTWARE. |
0 | # GORM | |
1 | ||
2 | [![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) | |
3 | ||
4 | The fantastic ORM library for Golang, aims to be developer friendly. | |
5 | ||
6 | [![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) | |
7 | ||
8 | ## Overview | |
9 | ||
10 | * Full-Featured ORM (almost) | |
11 | * Chainable API | |
12 | * Auto Migrations | |
13 | * Relations (Has One, Has Many, Belongs To, Many To Many, [Polymorphism](#polymorphism)) | |
14 | * Callbacks (Before/After Create/Save/Update/Delete/Find) | |
15 | * Preloading (eager loading) | |
16 | * Transactions | |
17 | * Embed Anonymous Struct | |
18 | * Soft Deletes | |
19 | * Customizable Logger | |
20 | * Iteration Support via [Rows](#row--rows) | |
21 | * Every feature comes with tests | |
22 | * Developer Friendly | |
23 | ||
24 | # Getting Started | |
25 | ||
26 | ## Install | |
27 | ||
28 | ``` | |
29 | go get -u github.com/jinzhu/gorm | |
30 | ``` | |
31 | ||
32 | ## Documentation | |
33 | ||
34 | [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) | |
35 | ||
36 | `go doc` format documentation for this project can be viewed online without | |
37 | installing the package by using the GoDoc page at: | |
38 | http://godoc.org/github.com/jinzhu/gorm | |
39 | ||
40 | ## Table of Contents | |
41 | ||
42 | - [Define Models (Structs)](#define-models-structs) | |
43 | - [Conventions](#conventions) | |
44 | - [Initialize Database](#initialize-database) | |
45 | - [Migration](#migration) | |
46 | - [Basic CRUD](#basic-crud) | |
47 | - [Create](#create-record) | |
48 | - [Query](#query) | |
49 | - [Query With Where (Plain SQL)](#query-with-where-plain-sql) | |
50 | - [Query With Where (Struct & Map)](#query-with-where-struct--map) | |
51 | - [Query With Not](#query-with-not) | |
52 | - [Query With Inline Condition](#query-with-inline-condition) | |
53 | - [Query With Or](#query-with-or) | |
54 | - [Query Chains](#query-chains) | |
55 | - [Preloading (Eager loading)](#preloading-eager-loading) | |
56 | - [Update](#update) | |
57 | - [Update Without Callbacks](#update-without-callbacks) | |
58 | - [Batch Updates](#batch-updates) | |
59 | - [Update with SQL Expression](#update-with-sql-expression) | |
60 | - [Delete](#delete) | |
61 | - [Batch Delete](#batch-delete) | |
62 | - [Soft Delete](#soft-delete) | |
63 | - [Associations](#associations) | |
64 | - [Has One](#has-one) | |
65 | - [Belongs To](#belongs-to) | |
66 | - [Has Many](#has-many) | |
67 | - [Many To Many](#many-to-many) | |
68 | - [Polymorphism](#polymorphism) | |
69 | - [Advanced Usage](#advanced-usage) | |
70 | - [FirstOrInit](#firstorinit) | |
71 | - [FirstOrCreate](#firstorcreate) | |
72 | - [Select](#select) | |
73 | - [Order](#order) | |
74 | - [Limit](#limit) | |
75 | - [Offset](#offset) | |
76 | - [Count](#count) | |
77 | - [Pluck](#pluck) | |
78 | - [Raw SQL](#raw-sql) | |
79 | - [Row & Rows](#row--rows) | |
80 | - [Scan](#scan) | |
81 | - [Group & Having](#group--having) | |
82 | - [Joins](#joins) | |
83 | - [Transactions](#transactions) | |
84 | - [Scopes](#scopes) | |
85 | - [Callbacks](#callbacks) | |
86 | - [Specifying The Table Name](#specifying-the-table-name) | |
87 | - [Error Handling](#error-handling) | |
88 | - [Logger](#logger) | |
89 | - [Existing Schema](#existing-schema) | |
90 | - [Composite Primary Key](#composite-primary-key) | |
91 | - [Database Indexes & Foreign Key](#database-indexes--foreign-key) | |
92 | - [Default values](#default-values) | |
93 | - [More examples with query chain](#more-examples-with-query-chain) | |
94 | ||
95 | ## Define Models (Structs) | |
96 | ||
97 | ```go | |
98 | type User struct { | |
99 | ID int | |
100 | Birthday time.Time | |
101 | Age int | |
102 | Name string `sql:"size:255"` // Default size for string is 255, you could reset it with this tag | |
103 | Num int `sql:"AUTO_INCREMENT"` | |
104 | CreatedAt time.Time | |
105 | UpdatedAt time.Time | |
106 | DeletedAt *time.Time | |
107 | ||
108 | Emails []Email // One-To-Many relationship (has many) | |
109 | BillingAddress Address // One-To-One relationship (has one) | |
110 | BillingAddressID sql.NullInt64 // Foreign key of BillingAddress | |
111 | ShippingAddress Address // One-To-One relationship (has one) | |
112 | ShippingAddressID int // Foreign key of ShippingAddress | |
113 | IgnoreMe int `sql:"-"` // Ignore this field | |
114 | Languages []Language `gorm:"many2many:user_languages;"` // Many-To-Many relationship, 'user_languages' is join table | |
115 | } | |
116 | ||
117 | type Email struct { | |
118 | ID int | |
119 | UserID int `sql:"index"` // Foreign key (belongs to), tag `index` will create index for this field when using AutoMigrate | |
120 | Email string `sql:"type:varchar(100);unique_index"` // Set field's sql type, tag `unique_index` will create unique index | |
121 | Subscribed bool | |
122 | } | |
123 | ||
124 | type Address struct { | |
125 | ID int | |
126 | Address1 string `sql:"not null;unique"` // Set field as not nullable and unique | |
127 | Address2 string `sql:"type:varchar(100);unique"` | |
128 | Post sql.NullString `sql:"not null"` | |
129 | } | |
130 | ||
131 | type Language struct { | |
132 | ID int | |
133 | Name string `sql:"index:idx_name_code"` // Create index with name, and will create combined index if find other fields defined same name | |
134 | Code string `sql:"index:idx_name_code"` // `unique_index` also works | |
135 | } | |
136 | ``` | |
137 | ||
138 | ## Conventions | |
139 | ||
140 | * Table name is the plural of struct name's snake case, you can disable pluralization with `db.SingularTable(true)`, or [Specifying The Table Name For A Struct Permanently With TableName](#specifying-the-table-name-for-a-struct-permanently-with-tablename) | |
141 | ||
142 | ```go | |
143 | type User struct{} // struct User's database table name is "users" by default, will be "user" if you disabled pluralisation | |
144 | ``` | |
145 | ||
146 | * Column name is the snake case of field's name | |
147 | * Use `ID` field as primary key | |
148 | * Use `CreatedAt` to store record's created time if field exists | |
149 | * Use `UpdatedAt` to store record's updated time if field exists | |
150 | * Use `DeletedAt` to store record's deleted time if field exists [Soft Delete](#soft-delete) | |
151 | * Gorm provide a default model struct, you could embed it in your struct | |
152 | ||
153 | ```go | |
154 | type Model struct { | |
155 | ID uint `gorm:"primary_key"` | |
156 | CreatedAt time.Time | |
157 | UpdatedAt time.Time | |
158 | DeletedAt *time.Time | |
159 | } | |
160 | ||
161 | type User struct { | |
162 | gorm.Model | |
163 | Name string | |
164 | } | |
165 | ``` | |
166 | ||
167 | ## Initialize Database | |
168 | ||
169 | ```go | |
170 | import ( | |
171 | "github.com/jinzhu/gorm" | |
172 | _ "github.com/lib/pq" | |
173 | _ "github.com/go-sql-driver/mysql" | |
174 | _ "github.com/mattn/go-sqlite3" | |
175 | ) | |
176 | ||
177 | db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") | |
178 | // db, err := gorm.Open("foundation", "dbname=gorm") // FoundationDB. | |
179 | // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") | |
180 | // db, err := gorm.Open("sqlite3", "/tmp/gorm.db") | |
181 | ||
182 | // You can also use an existing database connection handle | |
183 | // dbSql, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable") | |
184 | // db, _ := gorm.Open("postgres", dbSql) | |
185 | ||
186 | // Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB) | |
187 | db.DB() | |
188 | ||
189 | // Then you could invoke `*sql.DB`'s functions with it | |
190 | db.DB().Ping() | |
191 | db.DB().SetMaxIdleConns(10) | |
192 | db.DB().SetMaxOpenConns(100) | |
193 | ||
194 | // Disable table name's pluralization | |
195 | db.SingularTable(true) | |
196 | ``` | |
197 | ||
198 | ## Migration | |
199 | ||
200 | ```go | |
201 | // Create table | |
202 | db.CreateTable(&User{}) | |
203 | db.Set("gorm:table_options", "ENGINE=InnoDB").CreateTable(&User{}) | |
204 | ||
205 | // Drop table | |
206 | db.DropTable(&User{}) | |
207 | ||
208 | // ModifyColumn | |
209 | db.Model(&User{}).ModifyColumn("description", "text") | |
210 | ||
211 | // DropColumn | |
212 | db.Model(&User{}).DropColumn("description") | |
213 | ||
214 | // Automating Migration | |
215 | db.AutoMigrate(&User{}) | |
216 | db.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(&User{}) | |
217 | db.AutoMigrate(&User{}, &Product{}, &Order{}) | |
218 | // Feel free to change your struct, AutoMigrate will keep your database up-to-date. | |
219 | // AutoMigrate will ONLY add *new columns* and *new indexes*, | |
220 | // WON'T update current column's type or delete unused columns, to protect your data. | |
221 | // If the table is not existing, AutoMigrate will create the table automatically. | |
222 | ``` | |
223 | ||
224 | # Basic CRUD | |
225 | ||
226 | ## Create Record | |
227 | ||
228 | ```go | |
229 | user := User{Name: "Jinzhu", Age: 18, Birthday: time.Now()} | |
230 | ||
231 | db.NewRecord(user) // => returns `true` if primary key is blank | |
232 | ||
233 | db.Create(&user) | |
234 | ||
235 | db.NewRecord(user) // => return `false` after `user` created | |
236 | ||
237 | // Associations will be inserted automatically when save the record | |
238 | user := User{ | |
239 | Name: "jinzhu", | |
240 | BillingAddress: Address{Address1: "Billing Address - Address 1"}, | |
241 | ShippingAddress: Address{Address1: "Shipping Address - Address 1"}, | |
242 | Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}}, | |
243 | Languages: []Language{{Name: "ZH"}, {Name: "EN"}}, | |
244 | } | |
245 | ||
246 | db.Create(&user) | |
247 | //// BEGIN TRANSACTION; | |
248 | //// INSERT INTO "addresses" (address1) VALUES ("Billing Address - Address 1"); | |
249 | //// INSERT INTO "addresses" (address1) VALUES ("Shipping Address - Address 1"); | |
250 | //// INSERT INTO "users" (name,billing_address_id,shipping_address_id) VALUES ("jinzhu", 1, 2); | |
251 | //// INSERT INTO "emails" (user_id,email) VALUES (111, "jinzhu@example.com"); | |
252 | //// INSERT INTO "emails" (user_id,email) VALUES (111, "jinzhu-2@example.com"); | |
253 | //// INSERT INTO "languages" ("name") VALUES ('ZH'); | |
254 | //// INSERT INTO user_languages ("user_id","language_id") VALUES (111, 1); | |
255 | //// INSERT INTO "languages" ("name") VALUES ('EN'); | |
256 | //// INSERT INTO user_languages ("user_id","language_id") VALUES (111, 2); | |
257 | //// COMMIT; | |
258 | ``` | |
259 | ||
260 | Refer [Associations](#associations) for more details | |
261 | ||
262 | ## Query | |
263 | ||
264 | ```go | |
265 | // Get the first record | |
266 | db.First(&user) | |
267 | //// SELECT * FROM users ORDER BY id LIMIT 1; | |
268 | ||
269 | // Get the last record | |
270 | db.Last(&user) | |
271 | //// SELECT * FROM users ORDER BY id DESC LIMIT 1; | |
272 | ||
273 | // Get all records | |
274 | db.Find(&users) | |
275 | //// SELECT * FROM users; | |
276 | ||
277 | // Get record with primary key | |
278 | db.First(&user, 10) | |
279 | //// SELECT * FROM users WHERE id = 10; | |
280 | ``` | |
281 | ||
282 | ### Query With Where (Plain SQL) | |
283 | ||
284 | ```go | |
285 | // Get the first matched record | |
286 | db.Where("name = ?", "jinzhu").First(&user) | |
287 | //// SELECT * FROM users WHERE name = 'jinzhu' limit 1; | |
288 | ||
289 | // Get all matched records | |
290 | db.Where("name = ?", "jinzhu").Find(&users) | |
291 | //// SELECT * FROM users WHERE name = 'jinzhu'; | |
292 | ||
293 | db.Where("name <> ?", "jinzhu").Find(&users) | |
294 | ||
295 | // IN | |
296 | db.Where("name in (?)", []string{"jinzhu", "jinzhu 2"}).Find(&users) | |
297 | ||
298 | // LIKE | |
299 | db.Where("name LIKE ?", "%jin%").Find(&users) | |
300 | ||
301 | // AND | |
302 | db.Where("name = ? and age >= ?", "jinzhu", "22").Find(&users) | |
303 | ||
304 | // Time | |
305 | db.Where("updated_at > ?", lastWeek).Find(&users) | |
306 | ||
307 | db.Where("created_at BETWEEN ? AND ?", lastWeek, today).Find(&users) | |
308 | ``` | |
309 | ||
310 | ### Query With Where (Struct & Map) | |
311 | ||
312 | ```go | |
313 | // Struct | |
314 | db.Where(&User{Name: "jinzhu", Age: 20}).First(&user) | |
315 | //// SELECT * FROM users WHERE name = "jinzhu" AND age = 20 LIMIT 1; | |
316 | ||
317 | // Map | |
318 | db.Where(map[string]interface{}{"name": "jinzhu", "age": 20}).Find(&users) | |
319 | //// SELECT * FROM users WHERE name = "jinzhu" AND age = 20; | |
320 | ||
321 | // Slice of primary keys | |
322 | db.Where([]int64{20, 21, 22}).Find(&users) | |
323 | //// SELECT * FROM users WHERE id IN (20, 21, 22); | |
324 | ``` | |
325 | ||
326 | ### Query With Not | |
327 | ||
328 | ```go | |
329 | db.Not("name", "jinzhu").First(&user) | |
330 | //// SELECT * FROM users WHERE name <> "jinzhu" LIMIT 1; | |
331 | ||
332 | // Not In | |
333 | db.Not("name", []string{"jinzhu", "jinzhu 2"}).Find(&users) | |
334 | //// SELECT * FROM users WHERE name NOT IN ("jinzhu", "jinzhu 2"); | |
335 | ||
336 | // Not In slice of primary keys | |
337 | db.Not([]int64{1,2,3}).First(&user) | |
338 | //// SELECT * FROM users WHERE id NOT IN (1,2,3); | |
339 | ||
340 | db.Not([]int64{}).First(&user) | |
341 | //// SELECT * FROM users; | |
342 | ||
343 | // Plain SQL | |
344 | db.Not("name = ?", "jinzhu").First(&user) | |
345 | //// SELECT * FROM users WHERE NOT(name = "jinzhu"); | |
346 | ||
347 | // Struct | |
348 | db.Not(User{Name: "jinzhu"}).First(&user) | |
349 | //// SELECT * FROM users WHERE name <> "jinzhu"; | |
350 | ``` | |
351 | ||
352 | ### Query With Inline Condition | |
353 | ||
354 | ```go | |
355 | // Get by primary key | |
356 | db.First(&user, 23) | |
357 | //// SELECT * FROM users WHERE id = 23 LIMIT 1; | |
358 | ||
359 | // Plain SQL | |
360 | db.Find(&user, "name = ?", "jinzhu") | |
361 | //// SELECT * FROM users WHERE name = "jinzhu"; | |
362 | ||
363 | db.Find(&users, "name <> ? AND age > ?", "jinzhu", 20) | |
364 | //// SELECT * FROM users WHERE name <> "jinzhu" AND age > 20; | |
365 | ||
366 | // Struct | |
367 | db.Find(&users, User{Age: 20}) | |
368 | //// SELECT * FROM users WHERE age = 20; | |
369 | ||
370 | // Map | |
371 | db.Find(&users, map[string]interface{}{"age": 20}) | |
372 | //// SELECT * FROM users WHERE age = 20; | |
373 | ``` | |
374 | ||
375 | ### Query With Or | |
376 | ||
377 | ```go | |
378 | db.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&users) | |
379 | //// SELECT * FROM users WHERE role = 'admin' OR role = 'super_admin'; | |
380 | ||
381 | // Struct | |
382 | db.Where("name = 'jinzhu'").Or(User{Name: "jinzhu 2"}).Find(&users) | |
383 | //// SELECT * FROM users WHERE name = 'jinzhu' OR name = 'jinzhu 2'; | |
384 | ||
385 | // Map | |
386 | db.Where("name = 'jinzhu'").Or(map[string]interface{}{"name": "jinzhu 2"}).Find(&users) | |
387 | ``` | |
388 | ||
389 | ### Query Chains | |
390 | ||
391 | Gorm has a chainable API, you could use it like this | |
392 | ||
393 | ```go | |
394 | db.Where("name <> ?","jinzhu").Where("age >= ? and role <> ?",20,"admin").Find(&users) | |
395 | //// SELECT * FROM users WHERE name <> 'jinzhu' AND age >= 20 AND role <> 'admin'; | |
396 | ||
397 | db.Where("role = ?", "admin").Or("role = ?", "super_admin").Not("name = ?", "jinzhu").Find(&users) | |
398 | ``` | |
399 | ||
400 | ### Preloading (Eager loading) | |
401 | ||
402 | ```go | |
403 | db.Preload("Orders").Find(&users) | |
404 | //// SELECT * FROM users; | |
405 | //// SELECT * FROM orders WHERE user_id IN (1,2,3,4); | |
406 | ||
407 | db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) | |
408 | //// SELECT * FROM users; | |
409 | //// SELECT * FROM orders WHERE user_id IN (1,2,3,4) AND state NOT IN ('cancelled'); | |
410 | ||
411 | db.Where("state = ?", "active").Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) | |
412 | //// SELECT * FROM users WHERE state = 'active'; | |
413 | //// SELECT * FROM orders WHERE user_id IN (1,2) AND state NOT IN ('cancelled'); | |
414 | ||
415 | db.Preload("Orders").Preload("Profile").Preload("Role").Find(&users) | |
416 | //// SELECT * FROM users; | |
417 | //// SELECT * FROM orders WHERE user_id IN (1,2,3,4); // has many | |
418 | //// SELECT * FROM profiles WHERE user_id IN (1,2,3,4); // has one | |
419 | //// SELECT * FROM roles WHERE id IN (4,5,6); // belongs to | |
420 | ``` | |
421 | ||
422 | #### Nested Preloading | |
423 | ||
424 | ```go | |
425 | db.Preload("Orders.OrderItems").Find(&users) | |
426 | db.Preload("Orders", "state = ?", "paid").Preload("Orders.OrderItems").Find(&users) | |
427 | ``` | |
428 | ||
429 | ## Update | |
430 | ||
431 | ```go | |
432 | // Update an existing struct | |
433 | db.First(&user) | |
434 | user.Name = "jinzhu 2" | |
435 | user.Age = 100 | |
436 | db.Save(&user) | |
437 | //// UPDATE users SET name='jinzhu 2', age=100, updated_at = '2013-11-17 21:34:10' WHERE id=111; | |
438 | ||
439 | db.Where("active = ?", true).Save(&user) | |
440 | //// UPDATE users SET name='jinzhu 2', age=100, updated_at = '2013-11-17 21:34:10' WHERE id=111 AND active = true; | |
441 | ||
442 | // Update an attribute if it is changed | |
443 | db.Model(&user).Update("name", "hello") | |
444 | //// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111; | |
445 | ||
446 | db.Model(&user).Where("active = ?", true).Update("name", "hello") | |
447 | //// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111 AND active = true; | |
448 | ||
449 | db.First(&user, 111).Update("name", "hello") | |
450 | //// SELECT * FROM users LIMIT 1; | |
451 | //// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111; | |
452 | ||
453 | // Update multiple attributes if they are changed | |
454 | db.Model(&user).Updates(map[string]interface{}{"name": "hello", "age": 18, "actived": false}) | |
455 | ||
456 | // Update multiple attributes if they are changed (update with struct only works with none zero values) | |
457 | db.Model(&user).Updates(User{Name: "hello", Age: 18}) | |
458 | //// UPDATE users SET name='hello', age=18, updated_at = '2013-11-17 21:34:10' WHERE id = 111; | |
459 | ``` | |
460 | ||
461 | ### Update Without Callbacks | |
462 | ||
463 | By default, update will call BeforeUpdate, AfterUpdate callbacks, if you want to update w/o callbacks and w/o saving associations: | |
464 | ||
465 | ```go | |
466 | db.Model(&user).UpdateColumn("name", "hello") | |
467 | //// UPDATE users SET name='hello' WHERE id = 111; | |
468 | ||
469 | // Update with struct only works with none zero values, or use map[string]interface{} | |
470 | db.Model(&user).UpdateColumns(User{Name: "hello", Age: 18}) | |
471 | //// UPDATE users SET name='hello', age=18 WHERE id = 111; | |
472 | ``` | |
473 | ||
474 | ### Batch Updates | |
475 | ||
476 | ```go | |
477 | db.Table("users").Where("id = ?", 10).Updates(map[string]interface{}{"name": "hello", "age": 18}) | |
478 | //// UPDATE users SET name='hello', age=18 WHERE id = 10; | |
479 | ||
480 | // Update with struct only works with none zero values, or use map[string]interface{} | |
481 | db.Model(User{}).Updates(User{Name: "hello", Age: 18}) | |
482 | //// UPDATE users SET name='hello', age=18; | |
483 | ||
484 | // Callbacks won't run when do batch updates | |
485 | ||
486 | // Use `RowsAffected` to get the count of affected records | |
487 | db.Model(User{}).Updates(User{Name: "hello", Age: 18}).RowsAffected | |
488 | ``` | |
489 | ||
490 | ### Update with SQL Expression | |
491 | ||
492 | ```go | |
493 | DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) | |
494 | //// UPDATE "products" SET "code" = 'L1212', "price" = price * '2' + '100', "updated_at" = '2013-11-17 21:34:10' WHERE "id" = '2'; | |
495 | ||
496 | DB.Model(&product).Updates(map[string]interface{}{"price": gorm.Expr("price * ? + ?", 2, 100)}) | |
497 | //// UPDATE "products" SET "code" = 'L1212', "price" = price * '2' + '100', "updated_at" = '2013-11-17 21:34:10' WHERE "id" = '2'; | |
498 | ||
499 | DB.Model(&product).UpdateColumn("quantity", gorm.Expr("quantity - ?", 1)) | |
500 | //// UPDATE "products" SET "quantity" = quantity - 1 WHERE "id" = '2'; | |
501 | ||
502 | DB.Model(&product).Where("quantity > 1").UpdateColumn("quantity", gorm.Expr("quantity - ?", 1)) | |
503 | //// UPDATE "products" SET "quantity" = quantity - 1 WHERE "id" = '2' AND quantity > 1; | |
504 | ``` | |
505 | ||
506 | ## Delete | |
507 | ||
508 | ```go | |
509 | // Delete an existing record | |
510 | db.Delete(&email) | |
511 | //// DELETE from emails where id=10; | |
512 | ``` | |
513 | ||
514 | ### Batch Delete | |
515 | ||
516 | ```go | |
517 | db.Where("email LIKE ?", "%jinzhu%").Delete(Email{}) | |
518 | //// DELETE from emails where email LIKE "%jinhu%"; | |
519 | ``` | |
520 | ||
521 | ### Soft Delete | |
522 | ||
523 | If struct has `DeletedAt` field, it will get soft delete ability automatically! | |
524 | Then it won't be deleted from database permanently when call `Delete`. | |
525 | ||
526 | ```go | |
527 | db.Delete(&user) | |
528 | //// UPDATE users SET deleted_at="2013-10-29 10:23" WHERE id = 111; | |
529 | ||
530 | // Batch Delete | |
531 | db.Where("age = ?", 20).Delete(&User{}) | |
532 | //// UPDATE users SET deleted_at="2013-10-29 10:23" WHERE age = 20; | |
533 | ||
534 | // Soft deleted records will be ignored when query them | |
535 | db.Where("age = 20").Find(&user) | |
536 | //// SELECT * FROM users WHERE age = 20 AND (deleted_at IS NULL OR deleted_at <= '0001-01-02'); | |
537 | ||
538 | // Find soft deleted records with Unscoped | |
539 | db.Unscoped().Where("age = 20").Find(&users) | |
540 | //// SELECT * FROM users WHERE age = 20; | |
541 | ||
542 | // Delete record permanently with Unscoped | |
543 | db.Unscoped().Delete(&order) | |
544 | //// DELETE FROM orders WHERE id=10; | |
545 | ``` | |
546 | ||
547 | ## Associations | |
548 | ||
549 | ### Has One | |
550 | ||
551 | ```go | |
552 | // User has one address | |
553 | db.Model(&user).Related(&address) | |
554 | //// SELECT * FROM addresses WHERE id = 123; // 123 is user's foreign key AddressId | |
555 | ||
556 | // Specify the foreign key | |
557 | db.Model(&user).Related(&address1, "BillingAddressId") | |
558 | //// SELECT * FROM addresses WHERE id = 123; // 123 is user's foreign key BillingAddressId | |
559 | ``` | |
560 | ||
561 | ### Belongs To | |
562 | ||
563 | ```go | |
564 | // Email belongs to user | |
565 | db.Model(&email).Related(&user) | |
566 | //// SELECT * FROM users WHERE id = 111; // 111 is email's foreign key UserId | |
567 | ||
568 | // Specify the foreign key | |
569 | db.Model(&email).Related(&user, "ProfileId") | |
570 | //// SELECT * FROM users WHERE id = 111; // 111 is email's foreign key ProfileId | |
571 | ``` | |
572 | ||
573 | ### Has Many | |
574 | ||
575 | ```go | |
576 | // User has many emails | |
577 | db.Model(&user).Related(&emails) | |
578 | //// SELECT * FROM emails WHERE user_id = 111; | |
579 | // user_id is the foreign key, 111 is user's primary key's value | |
580 | ||
581 | // Specify the foreign key | |
582 | db.Model(&user).Related(&emails, "ProfileId") | |
583 | //// SELECT * FROM emails WHERE profile_id = 111; | |
584 | // profile_id is the foreign key, 111 is user's primary key's value | |
585 | ``` | |
586 | ||
587 | ### Many To Many | |
588 | ||
589 | ```go | |
590 | // User has many languages and belongs to many languages | |
591 | db.Model(&user).Related(&languages, "Languages") | |
592 | //// SELECT * FROM "languages" INNER JOIN "user_languages" ON "user_languages"."language_id" = "languages"."id" WHERE "user_languages"."user_id" = 111 | |
593 | // `Languages` is user's column name, this column's tag defined join table like this `gorm:"many2many:user_languages;"` | |
594 | ``` | |
595 | ||
596 | There is also a mode used to handle many to many relations easily | |
597 | ||
598 | ```go | |
599 | // Query | |
600 | db.Model(&user).Association("Languages").Find(&languages) | |
601 | // same as `db.Model(&user).Related(&languages, "Languages")` | |
602 | ||
603 | db.Where("name = ?", "ZH").First(&languageZH) | |
604 | db.Where("name = ?", "EN").First(&languageEN) | |
605 | ||
606 | // Append | |
607 | db.Model(&user).Association("Languages").Append([]Language{languageZH, languageEN}) | |
608 | db.Model(&user).Association("Languages").Append([]Language{{Name: "DE"}}) | |
609 | db.Model(&user).Association("Languages").Append(Language{Name: "DE"}) | |
610 | ||
611 | // Delete | |
612 | db.Model(&user).Association("Languages").Delete([]Language{languageZH, languageEN}) | |
613 | db.Model(&user).Association("Languages").Delete(languageZH, languageEN) | |
614 | ||
615 | // Replace | |
616 | db.Model(&user).Association("Languages").Replace([]Language{languageZH, languageEN}) | |
617 | db.Model(&user).Association("Languages").Replace(Language{Name: "DE"}, languageEN) | |
618 | ||
619 | // Count | |
620 | db.Model(&user).Association("Languages").Count() | |
621 | // Return the count of languages the user has | |
622 | ||
623 | // Clear | |
624 | db.Model(&user).Association("Languages").Clear() | |
625 | // Remove all relations between the user and languages | |
626 | ``` | |
627 | ||
628 | ### Polymorphism | |
629 | ||
630 | Supports polymorphic has-many and has-one associations. | |
631 | ||
632 | ```go | |
633 | type Cat struct { | |
634 | Id int | |
635 | Name string | |
636 | Toy Toy `gorm:"polymorphic:Owner;"` | |
637 | } | |
638 | ||
639 | type Dog struct { | |
640 | Id int | |
641 | Name string | |
642 | Toy Toy `gorm:"polymorphic:Owner;"` | |
643 | } | |
644 | ||
645 | type Toy struct { | |
646 | Id int | |
647 | Name string | |
648 | OwnerId int | |
649 | OwnerType string | |
650 | } | |
651 | ``` | |
652 | Note: polymorphic belongs-to and many-to-many are explicitly NOT supported, and will throw errors. | |
653 | ||
654 | ## Advanced Usage | |
655 | ||
656 | ## FirstOrInit | |
657 | ||
658 | Get the first matched record, or initialize a record with search conditions. | |
659 | ||
660 | ```go | |
661 | // Unfound | |
662 | db.FirstOrInit(&user, User{Name: "non_existing"}) | |
663 | //// user -> User{Name: "non_existing"} | |
664 | ||
665 | // Found | |
666 | db.Where(User{Name: "Jinzhu"}).FirstOrInit(&user) | |
667 | //// user -> User{Id: 111, Name: "Jinzhu", Age: 20} | |
668 | db.FirstOrInit(&user, map[string]interface{}{"name": "jinzhu"}) | |
669 | //// user -> User{Id: 111, Name: "Jinzhu", Age: 20} | |
670 | ``` | |
671 | ||
672 | ### Attrs | |
673 | ||
674 | Ignore some values when searching, but use them to initialize the struct if record is not found. | |
675 | ||
676 | ```go | |
677 | // Unfound | |
678 | db.Where(User{Name: "non_existing"}).Attrs(User{Age: 20}).FirstOrInit(&user) | |
679 | //// SELECT * FROM USERS WHERE name = 'non_existing'; | |
680 | //// user -> User{Name: "non_existing", Age: 20} | |
681 | ||
682 | db.Where(User{Name: "noexisting_user"}).Attrs("age", 20).FirstOrInit(&user) | |
683 | //// SELECT * FROM USERS WHERE name = 'non_existing'; | |
684 | //// user -> User{Name: "non_existing", Age: 20} | |
685 | ||
686 | // Found | |
687 | db.Where(User{Name: "Jinzhu"}).Attrs(User{Age: 30}).FirstOrInit(&user) | |
688 | //// SELECT * FROM USERS WHERE name = jinzhu'; | |
689 | //// user -> User{Id: 111, Name: "Jinzhu", Age: 20} | |
690 | ``` | |
691 | ||
692 | ### Assign | |
693 | ||
694 | Ignore some values when searching, but assign it to the result regardless it is found or not. | |
695 | ||
696 | ```go | |
697 | // Unfound | |
698 | db.Where(User{Name: "non_existing"}).Assign(User{Age: 20}).FirstOrInit(&user) | |
699 | //// user -> User{Name: "non_existing", Age: 20} | |
700 | ||
701 | // Found | |
702 | db.Where(User{Name: "Jinzhu"}).Assign(User{Age: 30}).FirstOrInit(&user) | |
703 | //// SELECT * FROM USERS WHERE name = jinzhu'; | |
704 | //// user -> User{Id: 111, Name: "Jinzhu", Age: 30} | |
705 | ``` | |
706 | ||
707 | ## FirstOrCreate | |
708 | ||
709 | Get the first matched record, or create with search conditions. | |
710 | ||
711 | ```go | |
712 | // Unfound | |
713 | db.FirstOrCreate(&user, User{Name: "non_existing"}) | |
714 | //// INSERT INTO "users" (name) VALUES ("non_existing"); | |
715 | //// user -> User{Id: 112, Name: "non_existing"} | |
716 | ||
717 | // Found | |
718 | db.Where(User{Name: "Jinzhu"}).FirstOrCreate(&user) | |
719 | //// user -> User{Id: 111, Name: "Jinzhu"} | |
720 | ``` | |
721 | ||
722 | ### Attrs | |
723 | ||
724 | Ignore some values when searching, but use them to create the struct if record is not found. like `FirstOrInit` | |
725 | ||
726 | ```go | |
727 | // Unfound | |
728 | db.Where(User{Name: "non_existing"}).Attrs(User{Age: 20}).FirstOrCreate(&user) | |
729 | //// SELECT * FROM users WHERE name = 'non_existing'; | |
730 | //// INSERT INTO "users" (name, age) VALUES ("non_existing", 20); | |
731 | //// user -> User{Id: 112, Name: "non_existing", Age: 20} | |
732 | ||
733 | // Found | |
734 | db.Where(User{Name: "jinzhu"}).Attrs(User{Age: 30}).FirstOrCreate(&user) | |
735 | //// SELECT * FROM users WHERE name = 'jinzhu'; | |
736 | //// user -> User{Id: 111, Name: "jinzhu", Age: 20} | |
737 | ``` | |
738 | ||
739 | ### Assign | |
740 | ||
741 | Ignore some values when searching, but assign it to the record regardless it is found or not, then save back to database. like `FirstOrInit` | |
742 | ||
743 | ```go | |
744 | // Unfound | |
745 | db.Where(User{Name: "non_existing"}).Assign(User{Age: 20}).FirstOrCreate(&user) | |
746 | //// SELECT * FROM users WHERE name = 'non_existing'; | |
747 | //// INSERT INTO "users" (name, age) VALUES ("non_existing", 20); | |
748 | //// user -> User{Id: 112, Name: "non_existing", Age: 20} | |
749 | ||
750 | // Found | |
751 | db.Where(User{Name: "jinzhu"}).Assign(User{Age: 30}).FirstOrCreate(&user) | |
752 | //// SELECT * FROM users WHERE name = 'jinzhu'; | |
753 | //// UPDATE users SET age=30 WHERE id = 111; | |
754 | //// user -> User{Id: 111, Name: "jinzhu", Age: 30} | |
755 | ``` | |
756 | ||
757 | ## Select | |
758 | ||
759 | ```go | |
760 | db.Select("name, age").Find(&users) | |
761 | //// SELECT name, age FROM users; | |
762 | ||
763 | db.Select([]string{"name", "age"}).Find(&users) | |
764 | //// SELECT name, age FROM users; | |
765 | ||
766 | db.Table("users").Select("COALESCE(age,?)", 42).Rows() | |
767 | //// SELECT COALESCE(age,'42') FROM users; | |
768 | ``` | |
769 | ||
770 | ## Order | |
771 | ||
772 | ```go | |
773 | db.Order("age desc, name").Find(&users) | |
774 | //// SELECT * FROM users ORDER BY age desc, name; | |
775 | ||
776 | // Multiple orders | |
777 | db.Order("age desc").Order("name").Find(&users) | |
778 | //// SELECT * FROM users ORDER BY age desc, name; | |
779 | ||
780 | // ReOrder | |
781 | db.Order("age desc").Find(&users1).Order("age", true).Find(&users2) | |
782 | //// SELECT * FROM users ORDER BY age desc; (users1) | |
783 | //// SELECT * FROM users ORDER BY age; (users2) | |
784 | ``` | |
785 | ||
786 | ## Limit | |
787 | ||
788 | ```go | |
789 | db.Limit(3).Find(&users) | |
790 | //// SELECT * FROM users LIMIT 3; | |
791 | ||
792 | // Cancel limit condition with -1 | |
793 | db.Limit(10).Find(&users1).Limit(-1).Find(&users2) | |
794 | //// SELECT * FROM users LIMIT 10; (users1) | |
795 | //// SELECT * FROM users; (users2) | |
796 | ``` | |
797 | ||
798 | ## Offset | |
799 | ||
800 | ```go | |
801 | db.Offset(3).Find(&users) | |
802 | //// SELECT * FROM users OFFSET 3; | |
803 | ||
804 | // Cancel offset condition with -1 | |
805 | db.Offset(10).Find(&users1).Offset(-1).Find(&users2) | |
806 | //// SELECT * FROM users OFFSET 10; (users1) | |
807 | //// SELECT * FROM users; (users2) | |
808 | ``` | |
809 | ||
810 | ## Count | |
811 | ||
812 | ```go | |
813 | db.Where("name = ?", "jinzhu").Or("name = ?", "jinzhu 2").Find(&users).Count(&count) | |
814 | //// SELECT * from USERS WHERE name = 'jinzhu' OR name = 'jinzhu 2'; (users) | |
815 | //// SELECT count(*) FROM users WHERE name = 'jinzhu' OR name = 'jinzhu 2'; (count) | |
816 | ||
817 | db.Model(User{}).Where("name = ?", "jinzhu").Count(&count) | |
818 | //// SELECT count(*) FROM users WHERE name = 'jinzhu'; (count) | |
819 | ||
820 | db.Table("deleted_users").Count(&count) | |
821 | //// SELECT count(*) FROM deleted_users; | |
822 | ``` | |
823 | ||
824 | ## Pluck | |
825 | ||
826 | Get selected attributes as map | |
827 | ||
828 | ```go | |
829 | var ages []int64 | |
830 | db.Find(&users).Pluck("age", &ages) | |
831 | ||
832 | var names []string | |
833 | db.Model(&User{}).Pluck("name", &names) | |
834 | ||
835 | db.Table("deleted_users").Pluck("name", &names) | |
836 | ||
837 | // Requesting more than one column? Do it like this: | |
838 | db.Select("name, age").Find(&users) | |
839 | ``` | |
840 | ||
841 | ## Raw SQL | |
842 | ||
843 | ```go | |
844 | db.Exec("DROP TABLE users;") | |
845 | db.Exec("UPDATE orders SET shipped_at=? WHERE id IN (?)", time.Now, []int64{11,22,33}) | |
846 | ``` | |
847 | ||
848 | ## Row & Rows | |
849 | ||
850 | It is even possible to get query result as `*sql.Row` or `*sql.Rows` | |
851 | ||
852 | ```go | |
853 | row := db.Table("users").Where("name = ?", "jinzhu").Select("name, age").Row() // (*sql.Row) | |
854 | row.Scan(&name, &age) | |
855 | ||
856 | rows, err := db.Model(User{}).Where("name = ?", "jinzhu").Select("name, age, email").Rows() // (*sql.Rows, error) | |
857 | defer rows.Close() | |
858 | for rows.Next() { | |
859 | ... | |
860 | rows.Scan(&name, &age, &email) | |
861 | ... | |
862 | } | |
863 | ||
864 | // Raw SQL | |
865 | rows, err := db.Raw("select name, age, email from users where name = ?", "jinzhu").Rows() // (*sql.Rows, error) | |
866 | defer rows.Close() | |
867 | for rows.Next() { | |
868 | ... | |
869 | rows.Scan(&name, &age, &email) | |
870 | ... | |
871 | } | |
872 | ``` | |
873 | ||
874 | ## Scan | |
875 | ||
876 | Scan results into another struct. | |
877 | ||
878 | ```go | |
879 | type Result struct { | |
880 | Name string | |
881 | Age int | |
882 | } | |
883 | ||
884 | var result Result | |
885 | db.Table("users").Select("name, age").Where("name = ?", 3).Scan(&result) | |
886 | ||
887 | // Raw SQL | |
888 | db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) | |
889 | ``` | |
890 | ||
891 | ## Group & Having | |
892 | ||
893 | ```go | |
894 | rows, err := db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Rows() | |
895 | for rows.Next() { | |
896 | ... | |
897 | } | |
898 | ||
899 | rows, err := db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Having("sum(amount) > ?", 100).Rows() | |
900 | for rows.Next() { | |
901 | ... | |
902 | } | |
903 | ||
904 | type Result struct { | |
905 | Date time.Time | |
906 | Total int64 | |
907 | } | |
908 | db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Having("sum(amount) > ?", 100).Scan(&results) | |
909 | ``` | |
910 | ||
911 | ## Joins | |
912 | ||
913 | ```go | |
914 | rows, err := db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Rows() | |
915 | for rows.Next() { | |
916 | ... | |
917 | } | |
918 | ||
919 | db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Scan(&results) | |
920 | ||
921 | // find a user by email address | |
922 | db.Joins("inner join emails on emails.user_id = users.id").Where("emails.email = ?", "x@example.org").Find(&user) | |
923 | ||
924 | // find all email addresses for a user | |
925 | db.Joins("left join users on users.id = emails.user_id").Where("users.name = ?", "jinzhu").Find(&emails) | |
926 | ``` | |
927 | ||
928 | ## Transactions | |
929 | ||
930 | To perform a set of operations within a transaction, the general flow is as below. | |
931 | The database handle returned from ``` db.Begin() ``` should be used for all operations within the transaction. | |
932 | (Note that all individual save and delete operations are run in a transaction by default.) | |
933 | ||
934 | ```go | |
935 | // begin | |
936 | tx := db.Begin() | |
937 | ||
938 | // do some database operations (use 'tx' from this point, not 'db') | |
939 | tx.Create(...) | |
940 | ... | |
941 | ||
942 | // rollback in case of error | |
943 | tx.Rollback() | |
944 | ||
945 | // Or commit if all is ok | |
946 | tx.Commit() | |
947 | ``` | |
948 | ||
949 | ### A Specific Example | |
950 | ``` | |
951 | func CreateAnimals(db *gorm.DB) err { | |
952 | tx := db.Begin() | |
953 | // Note the use of tx as the database handle once you are within a transaction | |
954 | ||
955 | if err := tx.Create(&Animal{Name: "Giraffe"}).Error; err != nil { | |
956 | tx.Rollback() | |
957 | return err | |
958 | } | |
959 | ||
960 | if err := tx.Create(&Animal{Name: "Lion"}).Error; err != nil { | |
961 | tx.Rollback() | |
962 | return err | |
963 | } | |
964 | ||
965 | tx.Commit() | |
966 | return nil | |
967 | } | |
968 | ``` | |
969 | ||
970 | ## Scopes | |
971 | ||
972 | ```go | |
973 | func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { | |
974 | return db.Where("amount > ?", 1000) | |
975 | } | |
976 | ||
977 | func PaidWithCreditCard(db *gorm.DB) *gorm.DB { | |
978 | return db.Where("pay_mode_sign = ?", "C") | |
979 | } | |
980 | ||
981 | func PaidWithCod(db *gorm.DB) *gorm.DB { | |
982 | return db.Where("pay_mode_sign = ?", "C") | |
983 | } | |
984 | ||
985 | func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { | |
986 | return func (db *gorm.DB) *gorm.DB { | |
987 | return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) | |
988 | } | |
989 | } | |
990 | ||
991 | db.Scopes(AmountGreaterThan1000, PaidWithCreditCard).Find(&orders) | |
992 | // Find all credit card orders and amount greater than 1000 | |
993 | ||
994 | db.Scopes(AmountGreaterThan1000, PaidWithCod).Find(&orders) | |
995 | // Find all COD orders and amount greater than 1000 | |
996 | ||
997 | db.Scopes(OrderStatus([]string{"paid", "shipped"})).Find(&orders) | |
998 | // Find all paid, shipped orders | |
999 | ``` | |
1000 | ||
1001 | ## Callbacks | |
1002 | ||
1003 | Callbacks are methods defined on the pointer of struct. | |
1004 | If any callback returns an error, gorm will stop future operations and rollback all changes. | |
1005 | ||
1006 | Here is the list of all available callbacks: | |
1007 | (listed in the same order in which they will get called during the respective operations) | |
1008 | ||
1009 | ### Creating An Object | |
1010 | ||
1011 | ```go | |
1012 | BeforeSave | |
1013 | BeforeCreate | |
1014 | // save before associations | |
1015 | // save self | |
1016 | // save after associations | |
1017 | AfterCreate | |
1018 | AfterSave | |
1019 | ``` | |
1020 | ### Updating An Object | |
1021 | ||
1022 | ```go | |
1023 | BeforeSave | |
1024 | BeforeUpdate | |
1025 | // save before associations | |
1026 | // save self | |
1027 | // save after associations | |
1028 | AfterUpdate | |
1029 | AfterSave | |
1030 | ``` | |
1031 | ||
1032 | ### Destroying An Object | |
1033 | ||
1034 | ```go | |
1035 | BeforeDelete | |
1036 | // delete self | |
1037 | AfterDelete | |
1038 | ``` | |
1039 | ||
1040 | ### After Find | |
1041 | ||
1042 | ```go | |
1043 | // load data from database | |
1044 | AfterFind | |
1045 | ``` | |
1046 | ||
1047 | ### Example | |
1048 | ||
1049 | ```go | |
1050 | func (u *User) BeforeUpdate() (err error) { | |
1051 | if u.readonly() { | |
1052 | err = errors.New("read only user") | |
1053 | } | |
1054 | return | |
1055 | } | |
1056 | ||
1057 | // Rollback the insertion if user's id greater than 1000 | |
1058 | func (u *User) AfterCreate() (err error) { | |
1059 | if (u.Id > 1000) { | |
1060 | err = errors.New("user id is already greater than 1000") | |
1061 | } | |
1062 | return | |
1063 | } | |
1064 | ``` | |
1065 | ||
1066 | As you know, save/delete operations in gorm are running in a transaction, | |
1067 | This is means if changes made in the transaction is not visiable unless it is commited, | |
1068 | So if you want to use those changes in your callbacks, you need to run SQL in same transaction. | |
1069 | Fortunately, gorm support pass transaction to callbacks as you needed, you could do it like this: | |
1070 | ||
1071 | ```go | |
1072 | func (u *User) AfterCreate(tx *gorm.DB) (err error) { | |
1073 | tx.Model(u).Update("role", "admin") | |
1074 | return | |
1075 | } | |
1076 | ``` | |
1077 | ||
1078 | ## Specifying The Table Name | |
1079 | ||
1080 | ```go | |
1081 | // Create `deleted_users` table with struct User's definition | |
1082 | db.Table("deleted_users").CreateTable(&User{}) | |
1083 | ||
1084 | var deleted_users []User | |
1085 | db.Table("deleted_users").Find(&deleted_users) | |
1086 | //// SELECT * FROM deleted_users; | |
1087 | ||
1088 | db.Table("deleted_users").Where("name = ?", "jinzhu").Delete() | |
1089 | //// DELETE FROM deleted_users WHERE name = 'jinzhu'; | |
1090 | ``` | |
1091 | ||
1092 | ### Specifying The Table Name For A Struct Permanently with TableName | |
1093 | ||
1094 | ```go | |
1095 | type Cart struct { | |
1096 | } | |
1097 | ||
1098 | func (c Cart) TableName() string { | |
1099 | return "shopping_cart" | |
1100 | } | |
1101 | ||
1102 | func (u User) TableName() string { | |
1103 | if u.Role == "admin" { | |
1104 | return "admin_users" | |
1105 | } else { | |
1106 | return "users" | |
1107 | } | |
1108 | } | |
1109 | ``` | |
1110 | ||
1111 | ## Error Handling | |
1112 | ||
1113 | ```go | |
1114 | query := db.Where("name = ?", "jinzhu").First(&user) | |
1115 | query := db.First(&user).Limit(10).Find(&users) | |
1116 | // query.Error will return the last happened error | |
1117 | ||
1118 | // So you could do error handing in your application like this: | |
1119 | if err := db.Where("name = ?", "jinzhu").First(&user).Error; err != nil { | |
1120 | // error handling... | |
1121 | } | |
1122 | ||
1123 | // RecordNotFound | |
1124 | // If no record found when you query data, gorm will return RecordNotFound error, you could check it like this: | |
1125 | db.Where("name = ?", "hello world").First(&User{}).Error == gorm.RecordNotFound | |
1126 | // Or use the shortcut method | |
1127 | db.Where("name = ?", "hello world").First(&user).RecordNotFound() | |
1128 | ||
1129 | if db.Model(&user).Related(&credit_card).RecordNotFound() { | |
1130 | // no credit card found error handling | |
1131 | } | |
1132 | ``` | |
1133 | ||
1134 | ## Logger | |
1135 | ||
1136 | Gorm has built-in logger support | |
1137 | ||
1138 | ```go | |
1139 | // Enable Logger | |
1140 | db.LogMode(true) | |
1141 | ||
1142 | // Diable Logger | |
1143 | db.LogMode(false) | |
1144 | ||
1145 | // Debug a single operation | |
1146 | db.Debug().Where("name = ?", "jinzhu").First(&User{}) | |
1147 | ``` | |
1148 | ||
1149 | ![logger](https://raw.github.com/jinzhu/gorm/master/images/logger.png) | |
1150 | ||
1151 | ### Customize Logger | |
1152 | ||
1153 | ```go | |
1154 | // Refer gorm's default logger for how to: https://github.com/jinzhu/gorm/blob/master/logger.go#files | |
1155 | db.SetLogger(gorm.Logger{revel.TRACE}) | |
1156 | db.SetLogger(log.New(os.Stdout, "\r\n", 0)) | |
1157 | ``` | |
1158 | ||
1159 | ## Existing Schema | |
1160 | ||
1161 | If you have an existing database schema, and the primary key field is different from `id`, you can add a tag to the field structure to specify that this field is a primary key. | |
1162 | ||
1163 | ```go | |
1164 | type Animal struct { | |
1165 | AnimalId int64 `gorm:"primary_key"` | |
1166 | Birthday time.Time `sql:"DEFAULT:current_timestamp"` | |
1167 | Name string `sql:"default:'galeone'"` | |
1168 | Age int64 | |
1169 | } | |
1170 | ``` | |
1171 | ||
1172 | If your column names differ from the struct fields, you can specify them like this: | |
1173 | ||
1174 | ```go | |
1175 | type Animal struct { | |
1176 | AnimalId int64 `gorm:"column:beast_id;primary_key"` | |
1177 | Birthday time.Time `gorm:"column:day_of_the_beast"` | |
1178 | Age int64 `gorm:"column:age_of_the_beast"` | |
1179 | } | |
1180 | ``` | |
1181 | ||
1182 | ## Composite Primary Key | |
1183 | ||
1184 | ```go | |
1185 | type Product struct { | |
1186 | ID string `gorm:"primary_key"` | |
1187 | LanguageCode string `gorm:"primary_key"` | |
1188 | } | |
1189 | ``` | |
1190 | ||
1191 | ## Database Indexes & Foreign Key | |
1192 | ||
1193 | ```go | |
1194 | // Add foreign key | |
1195 | // 1st param : foreignkey field | |
1196 | // 2nd param : destination table(id) | |
1197 | // 3rd param : ONDELETE | |
1198 | // 4th param : ONUPDATE | |
1199 | db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") | |
1200 | ||
1201 | // Add index | |
1202 | db.Model(&User{}).AddIndex("idx_user_name", "name") | |
1203 | ||
1204 | // Multiple column index | |
1205 | db.Model(&User{}).AddIndex("idx_user_name_age", "name", "age") | |
1206 | ||
1207 | // Add unique index | |
1208 | db.Model(&User{}).AddUniqueIndex("idx_user_name", "name") | |
1209 | ||
1210 | // Multiple column unique index | |
1211 | db.Model(&User{}).AddUniqueIndex("idx_user_name_age", "name", "age") | |
1212 | ||
1213 | // Remove index | |
1214 | db.Model(&User{}).RemoveIndex("idx_user_name") | |
1215 | ``` | |
1216 | ||
1217 | ## Default values | |
1218 | ||
1219 | ```go | |
1220 | type Animal struct { | |
1221 | ID int64 | |
1222 | Name string `sql:"default:'galeone'"` | |
1223 | Age int64 | |
1224 | } | |
1225 | ``` | |
1226 | ||
1227 | If you have defined a default value in the `sql` tag, the generated create SQl will ignore these fields if it is blank. | |
1228 | ||
1229 | Eg. | |
1230 | ||
1231 | ```go | |
1232 | db.Create(&Animal{Age: 99, Name: ""}) | |
1233 | ``` | |
1234 | ||
1235 | The generated SQL will be: | |
1236 | ||
1237 | ```sql | |
1238 | INSERT INTO animals("age") values('99'); | |
1239 | ``` | |
1240 | ||
1241 | The same thing occurs in update statements. | |
1242 | ||
1243 | ## More examples with query chain | |
1244 | ||
1245 | ```go | |
1246 | db.First(&first_article).Count(&total_count).Limit(10).Find(&first_page_articles).Offset(10).Find(&second_page_articles) | |
1247 | //// SELECT * FROM articles LIMIT 1; (first_article) | |
1248 | //// SELECT count(*) FROM articles; (total_count) | |
1249 | //// SELECT * FROM articles LIMIT 10; (first_page_articles) | |
1250 | //// SELECT * FROM articles LIMIT 10 OFFSET 10; (second_page_articles) | |
1251 | ||
1252 | ||
1253 | db.Where("created_at > ?", "2013-10-10").Find(&cancelled_orders, "state = ?", "cancelled").Find(&shipped_orders, "state = ?", "shipped") | |
1254 | //// SELECT * FROM orders WHERE created_at > '2013/10/10' AND state = 'cancelled'; (cancelled_orders) | |
1255 | //// SELECT * FROM orders WHERE created_at > '2013/10/10' AND state = 'shipped'; (shipped_orders) | |
1256 | ||
1257 | ||
1258 | // Use variables to keep query chain | |
1259 | todays_orders := db.Where("created_at > ?", "2013-10-29") | |
1260 | cancelled_orders := todays_orders.Where("state = ?", "cancelled") | |
1261 | shipped_orders := todays_orders.Where("state = ?", "shipped") | |
1262 | ||
1263 | ||
1264 | // Search with shared conditions for different tables | |
1265 | db.Where("product_name = ?", "fancy_product").Find(&orders).Find(&shopping_carts) | |
1266 | //// SELECT * FROM orders WHERE product_name = 'fancy_product'; (orders) | |
1267 | //// SELECT * FROM carts WHERE product_name = 'fancy_product'; (shopping_carts) | |
1268 | ||
1269 | ||
1270 | // Search with shared conditions from different tables with specified table | |
1271 | db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").Find(&users2) | |
1272 | //// SELECT * FROM users WHERE mail_type = 'TEXT'; (users1) | |
1273 | //// SELECT * FROM deleted_users WHERE mail_type = 'TEXT'; (users2) | |
1274 | ||
1275 | ||
1276 | // FirstOrCreate example | |
1277 | db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111"}).FirstOrCreate(&user) | |
1278 | //// SELECT * FROM users WHERE email = 'x@example.org'; | |
1279 | //// INSERT INTO "users" (email,registered_ip) VALUES ("x@example.org", "111.111.111.111") // if record not found | |
1280 | ``` | |
1281 | ||
1282 | ## TODO | |
1283 | * Github Pages | |
1284 | ||
1285 | # Author | |
1286 | ||
1287 | **jinzhu** | |
1288 | ||
1289 | * <http://github.com/jinzhu> | |
1290 | * <wosmvp@gmail.com> | |
1291 | * <http://twitter.com/zhangjinzhu> | |
1292 | ||
1293 | ## License | |
1294 | ||
1295 | Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License). |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "fmt" | |
5 | "reflect" | |
6 | "strings" | |
7 | ) | |
8 | ||
9 | type Association struct { | |
10 | Scope *Scope | |
11 | Column string | |
12 | Error error | |
13 | Field *Field | |
14 | } | |
15 | ||
16 | func (association *Association) setErr(err error) *Association { | |
17 | if err != nil { | |
18 | association.Error = err | |
19 | } | |
20 | return association | |
21 | } | |
22 | ||
23 | func (association *Association) Find(value interface{}) *Association { | |
24 | association.Scope.related(value, association.Column) | |
25 | return association.setErr(association.Scope.db.Error) | |
26 | } | |
27 | ||
28 | func (association *Association) Append(values ...interface{}) *Association { | |
29 | scope := association.Scope | |
30 | field := association.Field | |
31 | ||
32 | createJoinTable := func(reflectValue reflect.Value) { | |
33 | var value = reflectValue.Interface() | |
34 | if reflectValue.Kind() != reflect.Ptr { | |
35 | reflectPtr := reflect.New(reflectValue.Type()) | |
36 | reflectPtr.Elem().Set(reflectValue) | |
37 | value = reflectPtr.Interface() | |
38 | } | |
39 | ||
40 | if scope.New(value).PrimaryKeyZero() { | |
41 | scope.NewDB().Save(value) | |
42 | } | |
43 | ||
44 | relationship := association.Field.Relationship | |
45 | association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, value)) | |
46 | ||
47 | result := reflect.ValueOf(value) | |
48 | fieldElemType := field.Field.Type().Elem() | |
49 | if result.Type().AssignableTo(fieldElemType) { | |
50 | field.Set(reflect.Append(field.Field, result)) | |
51 | } else if result.Type().Elem().AssignableTo(fieldElemType) { | |
52 | field.Set(reflect.Append(field.Field, result.Elem())) | |
53 | } | |
54 | } | |
55 | ||
56 | for _, value := range values { | |
57 | reflectValue := reflect.Indirect(reflect.ValueOf(value)) | |
58 | ||
59 | if reflectValue.Kind() == reflect.Struct { | |
60 | createJoinTable(reflectValue) | |
61 | } else if reflectValue.Kind() == reflect.Slice { | |
62 | for i := 0; i < reflectValue.Len(); i++ { | |
63 | createJoinTable(reflectValue.Index(i)) | |
64 | } | |
65 | } else { | |
66 | association.setErr(errors.New("invalid association type")) | |
67 | } | |
68 | } | |
69 | return association | |
70 | } | |
71 | ||
72 | func (association *Association) Delete(values ...interface{}) *Association { | |
73 | scope := association.Scope | |
74 | relationship := association.Field.Relationship | |
75 | ||
76 | // many to many | |
77 | if relationship.Kind == "many_to_many" { | |
78 | query := scope.NewDB() | |
79 | for idx, foreignKey := range relationship.ForeignDBNames { | |
80 | if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { | |
81 | query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) | |
82 | } | |
83 | } | |
84 | ||
85 | primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) | |
86 | sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)) | |
87 | query = query.Where(sql, toQueryValues(primaryKeys)...) | |
88 | ||
89 | if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { | |
90 | leftValues := reflect.Zero(association.Field.Field.Type()) | |
91 | for i := 0; i < association.Field.Field.Len(); i++ { | |
92 | reflectValue := association.Field.Field.Index(i) | |
93 | primaryKey := association.getPrimaryKeys(relationship.ForeignFieldNames, reflectValue.Interface())[0] | |
94 | var included = false | |
95 | for _, pk := range primaryKeys { | |
96 | if equalAsString(primaryKey, pk) { | |
97 | included = true | |
98 | } | |
99 | } | |
100 | if !included { | |
101 | leftValues = reflect.Append(leftValues, reflectValue) | |
102 | } | |
103 | } | |
104 | association.Field.Set(leftValues) | |
105 | } | |
106 | } else { | |
107 | association.setErr(errors.New("delete only support many to many")) | |
108 | } | |
109 | return association | |
110 | } | |
111 | ||
112 | func (association *Association) Replace(values ...interface{}) *Association { | |
113 | relationship := association.Field.Relationship | |
114 | scope := association.Scope | |
115 | if relationship.Kind == "many_to_many" { | |
116 | field := association.Field.Field | |
117 | ||
118 | oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) | |
119 | association.Field.Set(reflect.Zero(association.Field.Field.Type())) | |
120 | association.Append(values...) | |
121 | newPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) | |
122 | ||
123 | var addedPrimaryKeys = [][]interface{}{} | |
124 | for _, newKey := range newPrimaryKeys { | |
125 | hasEqual := false | |
126 | for _, oldKey := range oldPrimaryKeys { | |
127 | if equalAsString(newKey, oldKey) { | |
128 | hasEqual = true | |
129 | break | |
130 | } | |
131 | } | |
132 | if !hasEqual { | |
133 | addedPrimaryKeys = append(addedPrimaryKeys, newKey) | |
134 | } | |
135 | } | |
136 | ||
137 | for _, primaryKey := range association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) { | |
138 | addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) | |
139 | } | |
140 | ||
141 | query := scope.NewDB() | |
142 | for idx, foreignKey := range relationship.ForeignDBNames { | |
143 | if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { | |
144 | query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) | |
145 | } | |
146 | } | |
147 | ||
148 | if len(addedPrimaryKeys) > 0 { | |
149 | sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys)) | |
150 | query = query.Where(sql, toQueryValues(addedPrimaryKeys)...) | |
151 | } | |
152 | association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship)) | |
153 | } else { | |
154 | association.setErr(errors.New("replace only support many to many")) | |
155 | } | |
156 | return association | |
157 | } | |
158 | ||
159 | func (association *Association) Clear() *Association { | |
160 | relationship := association.Field.Relationship | |
161 | scope := association.Scope | |
162 | if relationship.Kind == "many_to_many" { | |
163 | query := scope.NewDB() | |
164 | for idx, foreignKey := range relationship.ForeignDBNames { | |
165 | if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { | |
166 | query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) | |
167 | } | |
168 | } | |
169 | ||
170 | if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { | |
171 | association.Field.Set(reflect.Zero(association.Field.Field.Type())) | |
172 | } else { | |
173 | association.setErr(err) | |
174 | } | |
175 | } else { | |
176 | association.setErr(errors.New("clear only support many to many")) | |
177 | } | |
178 | return association | |
179 | } | |
180 | ||
181 | func (association *Association) Count() int { | |
182 | count := -1 | |
183 | relationship := association.Field.Relationship | |
184 | scope := association.Scope | |
185 | newScope := scope.New(association.Field.Field.Interface()) | |
186 | ||
187 | if relationship.Kind == "many_to_many" { | |
188 | relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) | |
189 | } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { | |
190 | query := scope.DB() | |
191 | for idx, foreignKey := range relationship.ForeignDBNames { | |
192 | if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { | |
193 | query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), | |
194 | field.Field.Interface()) | |
195 | } | |
196 | } | |
197 | ||
198 | if relationship.PolymorphicType != "" { | |
199 | query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) | |
200 | } | |
201 | query.Table(newScope.TableName()).Count(&count) | |
202 | } else if relationship.Kind == "belongs_to" { | |
203 | query := scope.DB() | |
204 | for idx, foreignKey := range relationship.ForeignDBNames { | |
205 | if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { | |
206 | query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), | |
207 | field.Field.Interface()) | |
208 | } | |
209 | } | |
210 | query.Table(newScope.TableName()).Count(&count) | |
211 | } | |
212 | ||
213 | return count | |
214 | } | |
215 | ||
216 | func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) [][]interface{} { | |
217 | results := [][]interface{}{} | |
218 | scope := association.Scope | |
219 | ||
220 | for _, value := range values { | |
221 | reflectValue := reflect.Indirect(reflect.ValueOf(value)) | |
222 | if reflectValue.Kind() == reflect.Slice { | |
223 | for i := 0; i < reflectValue.Len(); i++ { | |
224 | primaryKeys := []interface{}{} | |
225 | newScope := scope.New(reflectValue.Index(i).Interface()) | |
226 | for _, column := range columns { | |
227 | if field, ok := newScope.FieldByName(column); ok { | |
228 | primaryKeys = append(primaryKeys, field.Field.Interface()) | |
229 | } else { | |
230 | primaryKeys = append(primaryKeys, "") | |
231 | } | |
232 | } | |
233 | results = append(results, primaryKeys) | |
234 | } | |
235 | } else if reflectValue.Kind() == reflect.Struct { | |
236 | newScope := scope.New(value) | |
237 | var primaryKeys []interface{} | |
238 | for _, column := range columns { | |
239 | if field, ok := newScope.FieldByName(column); ok { | |
240 | primaryKeys = append(primaryKeys, field.Field.Interface()) | |
241 | } else { | |
242 | primaryKeys = append(primaryKeys, "") | |
243 | } | |
244 | } | |
245 | ||
246 | results = append(results, primaryKeys) | |
247 | } | |
248 | } | |
249 | return results | |
250 | } | |
251 | ||
252 | func toQueryMarks(primaryValues [][]interface{}) string { | |
253 | var results []string | |
254 | ||
255 | for _, primaryValue := range primaryValues { | |
256 | var marks []string | |
257 | for _, _ = range primaryValue { | |
258 | marks = append(marks, "?") | |
259 | } | |
260 | ||
261 | if len(marks) > 1 { | |
262 | results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) | |
263 | } else { | |
264 | results = append(results, strings.Join(marks, "")) | |
265 | } | |
266 | } | |
267 | return strings.Join(results, ",") | |
268 | } | |
269 | ||
270 | func toQueryCondition(scope *Scope, columns []string) string { | |
271 | var newColumns []string | |
272 | for _, column := range columns { | |
273 | newColumns = append(newColumns, scope.Quote(column)) | |
274 | } | |
275 | ||
276 | if len(columns) > 1 { | |
277 | return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) | |
278 | } else { | |
279 | return strings.Join(columns, ",") | |
280 | } | |
281 | } | |
282 | ||
283 | func toQueryValues(primaryValues [][]interface{}) (values []interface{}) { | |
284 | for _, primaryValue := range primaryValues { | |
285 | for _, value := range primaryValue { | |
286 | values = append(values, value) | |
287 | } | |
288 | } | |
289 | return values | |
290 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "testing" | |
5 | ) | |
6 | ||
7 | func TestHasOneAndHasManyAssociation(t *testing.T) { | |
8 | DB.DropTable(Category{}, Post{}, Comment{}) | |
9 | DB.CreateTable(Category{}, Post{}, Comment{}) | |
10 | ||
11 | post := Post{ | |
12 | Title: "post 1", | |
13 | Body: "body 1", | |
14 | Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}}, | |
15 | Category: Category{Name: "Category 1"}, | |
16 | MainCategory: Category{Name: "Main Category 1"}, | |
17 | } | |
18 | ||
19 | if err := DB.Save(&post).Error; err != nil { | |
20 | t.Errorf("Got errors when save post", err.Error()) | |
21 | } | |
22 | ||
23 | if err := DB.First(&Category{}, "name = ?", "Category 1").Error; err != nil { | |
24 | t.Errorf("Category should be saved", err.Error()) | |
25 | } | |
26 | ||
27 | var p Post | |
28 | DB.First(&p, post.Id) | |
29 | ||
30 | if post.CategoryId.Int64 == 0 || p.CategoryId.Int64 == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 { | |
31 | t.Errorf("Category Id should exist") | |
32 | } | |
33 | ||
34 | if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil { | |
35 | t.Errorf("Comment 1 should be saved") | |
36 | } | |
37 | if post.Comments[0].PostId == 0 { | |
38 | t.Errorf("Comment Should have post id") | |
39 | } | |
40 | ||
41 | var comment Comment | |
42 | if DB.First(&comment, "content = ?", "Comment 2").Error != nil { | |
43 | t.Errorf("Comment 2 should be saved") | |
44 | } | |
45 | ||
46 | if comment.PostId == 0 { | |
47 | t.Errorf("Comment 2 Should have post id") | |
48 | } | |
49 | ||
50 | comment3 := Comment{Content: "Comment 3", Post: Post{Title: "Title 3", Body: "Body 3"}} | |
51 | DB.Save(&comment3) | |
52 | } | |
53 | ||
54 | func TestRelated(t *testing.T) { | |
55 | user := User{ | |
56 | Name: "jinzhu", | |
57 | BillingAddress: Address{Address1: "Billing Address - Address 1"}, | |
58 | ShippingAddress: Address{Address1: "Shipping Address - Address 1"}, | |
59 | Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}}, | |
60 | CreditCard: CreditCard{Number: "1234567890"}, | |
61 | Company: Company{Name: "company1"}, | |
62 | } | |
63 | ||
64 | DB.Save(&user) | |
65 | ||
66 | if user.CreditCard.ID == 0 { | |
67 | t.Errorf("After user save, credit card should have id") | |
68 | } | |
69 | ||
70 | if user.BillingAddress.ID == 0 { | |
71 | t.Errorf("After user save, billing address should have id") | |
72 | } | |
73 | ||
74 | if user.Emails[0].Id == 0 { | |
75 | t.Errorf("After user save, billing address should have id") | |
76 | } | |
77 | ||
78 | var emails []Email | |
79 | DB.Model(&user).Related(&emails) | |
80 | if len(emails) != 2 { | |
81 | t.Errorf("Should have two emails") | |
82 | } | |
83 | ||
84 | var emails2 []Email | |
85 | DB.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2) | |
86 | if len(emails2) != 1 { | |
87 | t.Errorf("Should have two emails") | |
88 | } | |
89 | ||
90 | var user1 User | |
91 | DB.Model(&user).Related(&user1.Emails) | |
92 | if len(user1.Emails) != 2 { | |
93 | t.Errorf("Should have only one email match related condition") | |
94 | } | |
95 | ||
96 | var address1 Address | |
97 | DB.Model(&user).Related(&address1, "BillingAddressId") | |
98 | if address1.Address1 != "Billing Address - Address 1" { | |
99 | t.Errorf("Should get billing address from user correctly") | |
100 | } | |
101 | ||
102 | user1 = User{} | |
103 | DB.Model(&address1).Related(&user1, "BillingAddressId") | |
104 | if DB.NewRecord(user1) { | |
105 | t.Errorf("Should get user from address correctly") | |
106 | } | |
107 | ||
108 | var user2 User | |
109 | DB.Model(&emails[0]).Related(&user2) | |
110 | if user2.Id != user.Id || user2.Name != user.Name { | |
111 | t.Errorf("Should get user from email correctly") | |
112 | } | |
113 | ||
114 | var creditcard CreditCard | |
115 | var user3 User | |
116 | DB.First(&creditcard, "number = ?", "1234567890") | |
117 | DB.Model(&creditcard).Related(&user3) | |
118 | if user3.Id != user.Id || user3.Name != user.Name { | |
119 | t.Errorf("Should get user from credit card correctly") | |
120 | } | |
121 | ||
122 | if !DB.Model(&CreditCard{}).Related(&User{}).RecordNotFound() { | |
123 | t.Errorf("RecordNotFound for Related") | |
124 | } | |
125 | ||
126 | var company Company | |
127 | if DB.Model(&user).Related(&company, "Company").RecordNotFound() || company.Name != "company1" { | |
128 | t.Errorf("RecordNotFound for Related") | |
129 | } | |
130 | } | |
131 | ||
132 | func TestManyToMany(t *testing.T) { | |
133 | DB.Raw("delete from languages") | |
134 | var languages = []Language{{Name: "ZH"}, {Name: "EN"}} | |
135 | user := User{Name: "Many2Many", Languages: languages} | |
136 | DB.Save(&user) | |
137 | ||
138 | // Query | |
139 | var newLanguages []Language | |
140 | DB.Model(&user).Related(&newLanguages, "Languages") | |
141 | if len(newLanguages) != len([]string{"ZH", "EN"}) { | |
142 | t.Errorf("Query many to many relations") | |
143 | } | |
144 | ||
145 | DB.Model(&user).Association("Languages").Find(&newLanguages) | |
146 | if len(newLanguages) != len([]string{"ZH", "EN"}) { | |
147 | t.Errorf("Should be able to find many to many relations") | |
148 | } | |
149 | ||
150 | if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) { | |
151 | t.Errorf("Count should return correct result") | |
152 | } | |
153 | ||
154 | // Append | |
155 | DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"}) | |
156 | if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() { | |
157 | t.Errorf("New record should be saved when append") | |
158 | } | |
159 | ||
160 | languageA := Language{Name: "AA"} | |
161 | DB.Save(&languageA) | |
162 | DB.Model(&User{Id: user.Id}).Association("Languages").Append(&languageA) | |
163 | ||
164 | languageC := Language{Name: "CC"} | |
165 | DB.Save(&languageC) | |
166 | DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC}) | |
167 | ||
168 | DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}}) | |
169 | ||
170 | totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"} | |
171 | ||
172 | if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) { | |
173 | t.Errorf("All appended languages should be saved") | |
174 | } | |
175 | ||
176 | // Delete | |
177 | user.Languages = []Language{} | |
178 | DB.Model(&user).Association("Languages").Find(&user.Languages) | |
179 | ||
180 | var language Language | |
181 | DB.Where("name = ?", "EE").First(&language) | |
182 | DB.Model(&user).Association("Languages").Delete(language, &language) | |
183 | ||
184 | if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 { | |
185 | t.Errorf("Relations should be deleted with Delete") | |
186 | } | |
187 | if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() { | |
188 | t.Errorf("Language EE should not be deleted") | |
189 | } | |
190 | ||
191 | DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages) | |
192 | ||
193 | user2 := User{Name: "Many2Many_User2", Languages: languages} | |
194 | DB.Save(&user2) | |
195 | ||
196 | DB.Model(&user).Association("Languages").Delete(languages, &languages) | |
197 | if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 { | |
198 | t.Errorf("Relations should be deleted with Delete") | |
199 | } | |
200 | ||
201 | if DB.Model(&user2).Association("Languages").Count() == 0 { | |
202 | t.Errorf("Other user's relations should not be deleted") | |
203 | } | |
204 | ||
205 | // Replace | |
206 | var languageB Language | |
207 | DB.Where("name = ?", "BB").First(&languageB) | |
208 | DB.Model(&user).Association("Languages").Replace(languageB) | |
209 | if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 { | |
210 | t.Errorf("Relations should be replaced") | |
211 | } | |
212 | ||
213 | DB.Model(&user).Association("Languages").Replace() | |
214 | if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { | |
215 | t.Errorf("Relations should be replaced with empty") | |
216 | } | |
217 | ||
218 | DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}}) | |
219 | if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) { | |
220 | t.Errorf("Relations should be replaced") | |
221 | } | |
222 | ||
223 | // Clear | |
224 | DB.Model(&user).Association("Languages").Clear() | |
225 | if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { | |
226 | t.Errorf("Relations should be cleared") | |
227 | } | |
228 | } | |
229 | ||
230 | func TestForeignKey(t *testing.T) { | |
231 | for _, structField := range DB.NewScope(&User{}).GetStructFields() { | |
232 | for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} { | |
233 | if structField.Name == foreignKey && !structField.IsForeignKey { | |
234 | t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) | |
235 | } | |
236 | } | |
237 | } | |
238 | ||
239 | for _, structField := range DB.NewScope(&Email{}).GetStructFields() { | |
240 | for _, foreignKey := range []string{"UserId"} { | |
241 | if structField.Name == foreignKey && !structField.IsForeignKey { | |
242 | t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) | |
243 | } | |
244 | } | |
245 | } | |
246 | ||
247 | for _, structField := range DB.NewScope(&Post{}).GetStructFields() { | |
248 | for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} { | |
249 | if structField.Name == foreignKey && !structField.IsForeignKey { | |
250 | t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) | |
251 | } | |
252 | } | |
253 | } | |
254 | ||
255 | for _, structField := range DB.NewScope(&Comment{}).GetStructFields() { | |
256 | for _, foreignKey := range []string{"PostId"} { | |
257 | if structField.Name == foreignKey && !structField.IsForeignKey { | |
258 | t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) | |
259 | } | |
260 | } | |
261 | } | |
262 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | ) | |
5 | ||
6 | type callback struct { | |
7 | creates []*func(scope *Scope) | |
8 | updates []*func(scope *Scope) | |
9 | deletes []*func(scope *Scope) | |
10 | queries []*func(scope *Scope) | |
11 | rowQueries []*func(scope *Scope) | |
12 | processors []*callbackProcessor | |
13 | } | |
14 | ||
15 | type callbackProcessor struct { | |
16 | name string | |
17 | before string | |
18 | after string | |
19 | replace bool | |
20 | remove bool | |
21 | typ string | |
22 | processor *func(scope *Scope) | |
23 | callback *callback | |
24 | } | |
25 | ||
26 | func (c *callback) addProcessor(typ string) *callbackProcessor { | |
27 | cp := &callbackProcessor{typ: typ, callback: c} | |
28 | c.processors = append(c.processors, cp) | |
29 | return cp | |
30 | } | |
31 | ||
32 | func (c *callback) clone() *callback { | |
33 | return &callback{ | |
34 | creates: c.creates, | |
35 | updates: c.updates, | |
36 | deletes: c.deletes, | |
37 | queries: c.queries, | |
38 | processors: c.processors, | |
39 | } | |
40 | } | |
41 | ||
42 | func (c *callback) Create() *callbackProcessor { | |
43 | return c.addProcessor("create") | |
44 | } | |
45 | ||
46 | func (c *callback) Update() *callbackProcessor { | |
47 | return c.addProcessor("update") | |
48 | } | |
49 | ||
50 | func (c *callback) Delete() *callbackProcessor { | |
51 | return c.addProcessor("delete") | |
52 | } | |
53 | ||
54 | func (c *callback) Query() *callbackProcessor { | |
55 | return c.addProcessor("query") | |
56 | } | |
57 | ||
58 | func (c *callback) RowQuery() *callbackProcessor { | |
59 | return c.addProcessor("row_query") | |
60 | } | |
61 | ||
62 | func (cp *callbackProcessor) Before(name string) *callbackProcessor { | |
63 | cp.before = name | |
64 | return cp | |
65 | } | |
66 | ||
67 | func (cp *callbackProcessor) After(name string) *callbackProcessor { | |
68 | cp.after = name | |
69 | return cp | |
70 | } | |
71 | ||
72 | func (cp *callbackProcessor) Register(name string, fc func(scope *Scope)) { | |
73 | cp.name = name | |
74 | cp.processor = &fc | |
75 | cp.callback.sort() | |
76 | } | |
77 | ||
78 | func (cp *callbackProcessor) Remove(name string) { | |
79 | fmt.Printf("[info] removing callback `%v` from %v\n", name, fileWithLineNum()) | |
80 | cp.name = name | |
81 | cp.remove = true | |
82 | cp.callback.sort() | |
83 | } | |
84 | ||
85 | func (cp *callbackProcessor) Replace(name string, fc func(scope *Scope)) { | |
86 | fmt.Printf("[info] replacing callback `%v` from %v\n", name, fileWithLineNum()) | |
87 | cp.name = name | |
88 | cp.processor = &fc | |
89 | cp.replace = true | |
90 | cp.callback.sort() | |
91 | } | |
92 | ||
93 | func getRIndex(strs []string, str string) int { | |
94 | for i := len(strs) - 1; i >= 0; i-- { | |
95 | if strs[i] == str { | |
96 | return i | |
97 | } | |
98 | } | |
99 | return -1 | |
100 | } | |
101 | ||
102 | func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) { | |
103 | var sortCallbackProcessor func(c *callbackProcessor) | |
104 | var names, sortedNames = []string{}, []string{} | |
105 | ||
106 | for _, cp := range cps { | |
107 | if index := getRIndex(names, cp.name); index > -1 { | |
108 | if !cp.replace && !cp.remove { | |
109 | fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) | |
110 | } | |
111 | } | |
112 | names = append(names, cp.name) | |
113 | } | |
114 | ||
115 | sortCallbackProcessor = func(c *callbackProcessor) { | |
116 | if getRIndex(sortedNames, c.name) > -1 { | |
117 | return | |
118 | } | |
119 | ||
120 | if len(c.before) > 0 { | |
121 | if index := getRIndex(sortedNames, c.before); index > -1 { | |
122 | sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) | |
123 | } else if index := getRIndex(names, c.before); index > -1 { | |
124 | sortedNames = append(sortedNames, c.name) | |
125 | sortCallbackProcessor(cps[index]) | |
126 | } else { | |
127 | sortedNames = append(sortedNames, c.name) | |
128 | } | |
129 | } | |
130 | ||
131 | if len(c.after) > 0 { | |
132 | if index := getRIndex(sortedNames, c.after); index > -1 { | |
133 | sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) | |
134 | } else if index := getRIndex(names, c.after); index > -1 { | |
135 | cp := cps[index] | |
136 | if len(cp.before) == 0 { | |
137 | cp.before = c.name | |
138 | } | |
139 | sortCallbackProcessor(cp) | |
140 | } else { | |
141 | sortedNames = append(sortedNames, c.name) | |
142 | } | |
143 | } | |
144 | ||
145 | if getRIndex(sortedNames, c.name) == -1 { | |
146 | sortedNames = append(sortedNames, c.name) | |
147 | } | |
148 | } | |
149 | ||
150 | for _, cp := range cps { | |
151 | sortCallbackProcessor(cp) | |
152 | } | |
153 | ||
154 | var funcs = []*func(scope *Scope){} | |
155 | var sortedFuncs = []*func(scope *Scope){} | |
156 | for _, name := range sortedNames { | |
157 | index := getRIndex(names, name) | |
158 | if !cps[index].remove { | |
159 | sortedFuncs = append(sortedFuncs, cps[index].processor) | |
160 | } | |
161 | } | |
162 | ||
163 | for _, cp := range cps { | |
164 | if sindex := getRIndex(sortedNames, cp.name); sindex == -1 { | |
165 | if !cp.remove { | |
166 | funcs = append(funcs, cp.processor) | |
167 | } | |
168 | } | |
169 | } | |
170 | ||
171 | return append(sortedFuncs, funcs...) | |
172 | } | |
173 | ||
174 | func (c *callback) sort() { | |
175 | var creates, updates, deletes, queries, rowQueries []*callbackProcessor | |
176 | ||
177 | for _, processor := range c.processors { | |
178 | switch processor.typ { | |
179 | case "create": | |
180 | creates = append(creates, processor) | |
181 | case "update": | |
182 | updates = append(updates, processor) | |
183 | case "delete": | |
184 | deletes = append(deletes, processor) | |
185 | case "query": | |
186 | queries = append(queries, processor) | |
187 | case "row_query": | |
188 | rowQueries = append(rowQueries, processor) | |
189 | } | |
190 | } | |
191 | ||
192 | c.creates = sortProcessors(creates) | |
193 | c.updates = sortProcessors(updates) | |
194 | c.deletes = sortProcessors(deletes) | |
195 | c.queries = sortProcessors(queries) | |
196 | c.rowQueries = sortProcessors(rowQueries) | |
197 | } | |
198 | ||
199 | var DefaultCallback = &callback{processors: []*callbackProcessor{}} |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "strings" | |
5 | ) | |
6 | ||
7 | func BeforeCreate(scope *Scope) { | |
8 | scope.CallMethodWithErrorCheck("BeforeSave") | |
9 | scope.CallMethodWithErrorCheck("BeforeCreate") | |
10 | } | |
11 | ||
12 | func UpdateTimeStampWhenCreate(scope *Scope) { | |
13 | if !scope.HasError() { | |
14 | now := NowFunc() | |
15 | scope.SetColumn("CreatedAt", now) | |
16 | scope.SetColumn("UpdatedAt", now) | |
17 | } | |
18 | } | |
19 | ||
20 | func Create(scope *Scope) { | |
21 | defer scope.Trace(NowFunc()) | |
22 | ||
23 | if !scope.HasError() { | |
24 | // set create sql | |
25 | var sqls, columns []string | |
26 | fields := scope.Fields() | |
27 | for _, field := range fields { | |
28 | if scope.changeableField(field) { | |
29 | if field.IsNormal { | |
30 | if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) { | |
31 | if !field.IsBlank || !field.HasDefaultValue { | |
32 | columns = append(columns, scope.Quote(field.DBName)) | |
33 | sqls = append(sqls, scope.AddToVars(field.Field.Interface())) | |
34 | } else if field.HasDefaultValue { | |
35 | scope.InstanceSet("gorm:force_reload_after_create", true) | |
36 | } | |
37 | } | |
38 | } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { | |
39 | for _, dbName := range relationship.ForeignDBNames { | |
40 | if relationField := fields[dbName]; !scope.changeableField(relationField) { | |
41 | columns = append(columns, scope.Quote(relationField.DBName)) | |
42 | sqls = append(sqls, scope.AddToVars(relationField.Field.Interface())) | |
43 | } | |
44 | } | |
45 | } | |
46 | } | |
47 | } | |
48 | ||
49 | returningKey := "*" | |
50 | primaryField := scope.PrimaryField() | |
51 | if primaryField != nil { | |
52 | returningKey = scope.Quote(primaryField.DBName) | |
53 | } | |
54 | ||
55 | if len(columns) == 0 { | |
56 | scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v", | |
57 | scope.QuotedTableName(), | |
58 | scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey), | |
59 | )) | |
60 | } else { | |
61 | scope.Raw(fmt.Sprintf( | |
62 | "INSERT INTO %v (%v) VALUES (%v) %v", | |
63 | scope.QuotedTableName(), | |
64 | strings.Join(columns, ","), | |
65 | strings.Join(sqls, ","), | |
66 | scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey), | |
67 | )) | |
68 | } | |
69 | ||
70 | // execute create sql | |
71 | if scope.Dialect().SupportLastInsertId() { | |
72 | if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { | |
73 | id, err := result.LastInsertId() | |
74 | if scope.Err(err) == nil { | |
75 | scope.db.RowsAffected, _ = result.RowsAffected() | |
76 | if primaryField != nil && primaryField.IsBlank { | |
77 | scope.Err(scope.SetColumn(primaryField, id)) | |
78 | } | |
79 | } | |
80 | } | |
81 | } else { | |
82 | if primaryField == nil { | |
83 | if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil { | |
84 | scope.db.RowsAffected, _ = results.RowsAffected() | |
85 | } else { | |
86 | scope.Err(err) | |
87 | } | |
88 | } else { | |
89 | if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil { | |
90 | scope.db.RowsAffected = 1 | |
91 | } else { | |
92 | scope.Err(err) | |
93 | } | |
94 | } | |
95 | } | |
96 | } | |
97 | } | |
98 | ||
99 | func ForceReloadAfterCreate(scope *Scope) { | |
100 | if _, ok := scope.InstanceGet("gorm:force_reload_after_create"); ok { | |
101 | scope.DB().New().First(scope.Value) | |
102 | } | |
103 | } | |
104 | ||
105 | func AfterCreate(scope *Scope) { | |
106 | scope.CallMethodWithErrorCheck("AfterCreate") | |
107 | scope.CallMethodWithErrorCheck("AfterSave") | |
108 | } | |
109 | ||
110 | func init() { | |
111 | DefaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction) | |
112 | DefaultCallback.Create().Register("gorm:before_create", BeforeCreate) | |
113 | DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations) | |
114 | DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate) | |
115 | DefaultCallback.Create().Register("gorm:create", Create) | |
116 | DefaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate) | |
117 | DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations) | |
118 | DefaultCallback.Create().Register("gorm:after_create", AfterCreate) | |
119 | DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | |
120 | } |
0 | package gorm | |
1 | ||
2 | import "fmt" | |
3 | ||
4 | func BeforeDelete(scope *Scope) { | |
5 | scope.CallMethodWithErrorCheck("BeforeDelete") | |
6 | } | |
7 | ||
8 | func Delete(scope *Scope) { | |
9 | if !scope.HasError() { | |
10 | if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") { | |
11 | scope.Raw( | |
12 | fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", | |
13 | scope.QuotedTableName(), | |
14 | scope.AddToVars(NowFunc()), | |
15 | scope.CombinedConditionSql(), | |
16 | )) | |
17 | } else { | |
18 | scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.QuotedTableName(), scope.CombinedConditionSql())) | |
19 | } | |
20 | ||
21 | scope.Exec() | |
22 | } | |
23 | } | |
24 | ||
25 | func AfterDelete(scope *Scope) { | |
26 | scope.CallMethodWithErrorCheck("AfterDelete") | |
27 | } | |
28 | ||
29 | func init() { | |
30 | DefaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction) | |
31 | DefaultCallback.Delete().Register("gorm:before_delete", BeforeDelete) | |
32 | DefaultCallback.Delete().Register("gorm:delete", Delete) | |
33 | DefaultCallback.Delete().Register("gorm:after_delete", AfterDelete) | |
34 | DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | |
35 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "fmt" | |
5 | "reflect" | |
6 | ) | |
7 | ||
8 | func Query(scope *Scope) { | |
9 | defer scope.Trace(NowFunc()) | |
10 | ||
11 | var ( | |
12 | isSlice bool | |
13 | isPtr bool | |
14 | anyRecordFound bool | |
15 | destType reflect.Type | |
16 | ) | |
17 | ||
18 | if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { | |
19 | if primaryKey := scope.PrimaryKey(); primaryKey != "" { | |
20 | scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryKey), orderBy)) | |
21 | } | |
22 | } | |
23 | ||
24 | var dest = scope.IndirectValue() | |
25 | if value, ok := scope.Get("gorm:query_destination"); ok { | |
26 | dest = reflect.Indirect(reflect.ValueOf(value)) | |
27 | } | |
28 | ||
29 | if kind := dest.Kind(); kind == reflect.Slice { | |
30 | isSlice = true | |
31 | destType = dest.Type().Elem() | |
32 | dest.Set(reflect.MakeSlice(dest.Type(), 0, 0)) | |
33 | ||
34 | if destType.Kind() == reflect.Ptr { | |
35 | isPtr = true | |
36 | destType = destType.Elem() | |
37 | } | |
38 | } else if kind != reflect.Struct { | |
39 | scope.Err(errors.New("unsupported destination, should be slice or struct")) | |
40 | return | |
41 | } | |
42 | ||
43 | scope.prepareQuerySql() | |
44 | ||
45 | if !scope.HasError() { | |
46 | rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...) | |
47 | scope.db.RowsAffected = 0 | |
48 | ||
49 | if scope.Err(err) != nil { | |
50 | return | |
51 | } | |
52 | defer rows.Close() | |
53 | ||
54 | columns, _ := rows.Columns() | |
55 | for rows.Next() { | |
56 | scope.db.RowsAffected++ | |
57 | ||
58 | anyRecordFound = true | |
59 | elem := dest | |
60 | if isSlice { | |
61 | elem = reflect.New(destType).Elem() | |
62 | } | |
63 | ||
64 | var values = make([]interface{}, len(columns)) | |
65 | ||
66 | fields := scope.New(elem.Addr().Interface()).Fields() | |
67 | ||
68 | for index, column := range columns { | |
69 | if field, ok := fields[column]; ok { | |
70 | if field.Field.Kind() == reflect.Ptr { | |
71 | values[index] = field.Field.Addr().Interface() | |
72 | } else { | |
73 | values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface() | |
74 | } | |
75 | } else { | |
76 | var value interface{} | |
77 | values[index] = &value | |
78 | } | |
79 | } | |
80 | ||
81 | scope.Err(rows.Scan(values...)) | |
82 | ||
83 | for index, column := range columns { | |
84 | value := values[index] | |
85 | if field, ok := fields[column]; ok { | |
86 | if field.Field.Kind() == reflect.Ptr { | |
87 | field.Field.Set(reflect.ValueOf(value).Elem()) | |
88 | } else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() { | |
89 | field.Field.Set(v) | |
90 | } | |
91 | } | |
92 | } | |
93 | ||
94 | if isSlice { | |
95 | if isPtr { | |
96 | dest.Set(reflect.Append(dest, elem.Addr())) | |
97 | } else { | |
98 | dest.Set(reflect.Append(dest, elem)) | |
99 | } | |
100 | } | |
101 | } | |
102 | ||
103 | if !anyRecordFound && !isSlice { | |
104 | scope.Err(RecordNotFound) | |
105 | } | |
106 | } | |
107 | } | |
108 | ||
109 | func AfterQuery(scope *Scope) { | |
110 | scope.CallMethodWithErrorCheck("AfterFind") | |
111 | } | |
112 | ||
113 | func init() { | |
114 | DefaultCallback.Query().Register("gorm:query", Query) | |
115 | DefaultCallback.Query().Register("gorm:after_query", AfterQuery) | |
116 | DefaultCallback.Query().Register("gorm:preload", Preload) | |
117 | } |
0 | package gorm | |
1 | ||
2 | import "reflect" | |
3 | ||
4 | func BeginTransaction(scope *Scope) { | |
5 | scope.Begin() | |
6 | } | |
7 | ||
8 | func CommitOrRollbackTransaction(scope *Scope) { | |
9 | scope.CommitOrRollback() | |
10 | } | |
11 | ||
12 | func SaveBeforeAssociations(scope *Scope) { | |
13 | if !scope.shouldSaveAssociations() { | |
14 | return | |
15 | } | |
16 | for _, field := range scope.Fields() { | |
17 | if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { | |
18 | if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { | |
19 | value := field.Field | |
20 | scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error) | |
21 | if len(relationship.ForeignFieldNames) != 0 { | |
22 | for idx, fieldName := range relationship.ForeignFieldNames { | |
23 | associationForeignName := relationship.AssociationForeignDBNames[idx] | |
24 | if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok { | |
25 | scope.Err(scope.SetColumn(fieldName, f.Field.Interface())) | |
26 | } | |
27 | } | |
28 | } | |
29 | } | |
30 | } | |
31 | } | |
32 | } | |
33 | ||
34 | func SaveAfterAssociations(scope *Scope) { | |
35 | if !scope.shouldSaveAssociations() { | |
36 | return | |
37 | } | |
38 | for _, field := range scope.Fields() { | |
39 | if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { | |
40 | if relationship := field.Relationship; relationship != nil && | |
41 | (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { | |
42 | value := field.Field | |
43 | ||
44 | switch value.Kind() { | |
45 | case reflect.Slice: | |
46 | for i := 0; i < value.Len(); i++ { | |
47 | newDB := scope.NewDB() | |
48 | elem := value.Index(i).Addr().Interface() | |
49 | newScope := newDB.NewScope(elem) | |
50 | ||
51 | if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { | |
52 | for idx, fieldName := range relationship.ForeignFieldNames { | |
53 | associationForeignName := relationship.AssociationForeignDBNames[idx] | |
54 | if f, ok := scope.FieldByName(associationForeignName); ok { | |
55 | scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) | |
56 | } | |
57 | } | |
58 | } | |
59 | ||
60 | if relationship.PolymorphicType != "" { | |
61 | scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) | |
62 | } | |
63 | ||
64 | scope.Err(newDB.Save(elem).Error) | |
65 | ||
66 | if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { | |
67 | scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value)) | |
68 | } | |
69 | } | |
70 | default: | |
71 | elem := value.Addr().Interface() | |
72 | newScope := scope.New(elem) | |
73 | if len(relationship.ForeignFieldNames) != 0 { | |
74 | for idx, fieldName := range relationship.ForeignFieldNames { | |
75 | associationForeignName := relationship.AssociationForeignDBNames[idx] | |
76 | if f, ok := scope.FieldByName(associationForeignName); ok { | |
77 | scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) | |
78 | } | |
79 | } | |
80 | } | |
81 | ||
82 | if relationship.PolymorphicType != "" { | |
83 | scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) | |
84 | } | |
85 | scope.Err(scope.NewDB().Save(elem).Error) | |
86 | } | |
87 | } | |
88 | } | |
89 | } | |
90 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "reflect" | |
4 | "runtime" | |
5 | "strings" | |
6 | "testing" | |
7 | ) | |
8 | ||
9 | func equalFuncs(funcs []*func(s *Scope), fnames []string) bool { | |
10 | var names []string | |
11 | for _, f := range funcs { | |
12 | fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") | |
13 | names = append(names, fnames[len(fnames)-1]) | |
14 | } | |
15 | return reflect.DeepEqual(names, fnames) | |
16 | } | |
17 | ||
18 | func create(s *Scope) {} | |
19 | func beforeCreate1(s *Scope) {} | |
20 | func beforeCreate2(s *Scope) {} | |
21 | func afterCreate1(s *Scope) {} | |
22 | func afterCreate2(s *Scope) {} | |
23 | ||
24 | func TestRegisterCallback(t *testing.T) { | |
25 | var callback = &callback{processors: []*callbackProcessor{}} | |
26 | ||
27 | callback.Create().Register("before_create1", beforeCreate1) | |
28 | callback.Create().Register("before_create2", beforeCreate2) | |
29 | callback.Create().Register("create", create) | |
30 | callback.Create().Register("after_create1", afterCreate1) | |
31 | callback.Create().Register("after_create2", afterCreate2) | |
32 | ||
33 | if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { | |
34 | t.Errorf("register callback") | |
35 | } | |
36 | } | |
37 | ||
38 | func TestRegisterCallbackWithOrder(t *testing.T) { | |
39 | var callback1 = &callback{processors: []*callbackProcessor{}} | |
40 | callback1.Create().Register("before_create1", beforeCreate1) | |
41 | callback1.Create().Register("create", create) | |
42 | callback1.Create().Register("after_create1", afterCreate1) | |
43 | callback1.Create().Before("after_create1").Register("after_create2", afterCreate2) | |
44 | if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { | |
45 | t.Errorf("register callback with order") | |
46 | } | |
47 | ||
48 | var callback2 = &callback{processors: []*callbackProcessor{}} | |
49 | ||
50 | callback2.Update().Register("create", create) | |
51 | callback2.Update().Before("create").Register("before_create1", beforeCreate1) | |
52 | callback2.Update().After("after_create2").Register("after_create1", afterCreate1) | |
53 | callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2) | |
54 | callback2.Update().Register("after_create2", afterCreate2) | |
55 | ||
56 | if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { | |
57 | t.Errorf("register callback with order") | |
58 | } | |
59 | } | |
60 | ||
61 | func TestRegisterCallbackWithComplexOrder(t *testing.T) { | |
62 | var callback1 = &callback{processors: []*callbackProcessor{}} | |
63 | ||
64 | callback1.Query().Before("after_create1").After("before_create1").Register("create", create) | |
65 | callback1.Query().Register("before_create1", beforeCreate1) | |
66 | callback1.Query().Register("after_create1", afterCreate1) | |
67 | ||
68 | if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) { | |
69 | t.Errorf("register callback with order") | |
70 | } | |
71 | ||
72 | var callback2 = &callback{processors: []*callbackProcessor{}} | |
73 | ||
74 | callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) | |
75 | callback2.Delete().Before("create").Register("before_create1", beforeCreate1) | |
76 | callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) | |
77 | callback2.Delete().Register("after_create1", afterCreate1) | |
78 | callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) | |
79 | ||
80 | if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { | |
81 | t.Errorf("register callback with order") | |
82 | } | |
83 | } | |
84 | ||
85 | func replaceCreate(s *Scope) {} | |
86 | ||
87 | func TestReplaceCallback(t *testing.T) { | |
88 | var callback = &callback{processors: []*callbackProcessor{}} | |
89 | ||
90 | callback.Create().Before("after_create1").After("before_create1").Register("create", create) | |
91 | callback.Create().Register("before_create1", beforeCreate1) | |
92 | callback.Create().Register("after_create1", afterCreate1) | |
93 | callback.Create().Replace("create", replaceCreate) | |
94 | ||
95 | if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) { | |
96 | t.Errorf("replace callback") | |
97 | } | |
98 | } | |
99 | ||
100 | func TestRemoveCallback(t *testing.T) { | |
101 | var callback = &callback{processors: []*callbackProcessor{}} | |
102 | ||
103 | callback.Create().Before("after_create1").After("before_create1").Register("create", create) | |
104 | callback.Create().Register("before_create1", beforeCreate1) | |
105 | callback.Create().Register("after_create1", afterCreate1) | |
106 | callback.Create().Remove("create") | |
107 | ||
108 | if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) { | |
109 | t.Errorf("remove callback") | |
110 | } | |
111 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "strings" | |
5 | ) | |
6 | ||
7 | func AssignUpdateAttributes(scope *Scope) { | |
8 | if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { | |
9 | if maps := convertInterfaceToMap(attrs); len(maps) > 0 { | |
10 | protected, ok := scope.Get("gorm:ignore_protected_attrs") | |
11 | _, updateColumn := scope.Get("gorm:update_column") | |
12 | updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool)) | |
13 | ||
14 | if updateColumn { | |
15 | scope.InstanceSet("gorm:update_attrs", maps) | |
16 | } else if len(updateAttrs) > 0 { | |
17 | scope.InstanceSet("gorm:update_attrs", updateAttrs) | |
18 | } else if !hasUpdate { | |
19 | scope.SkipLeft() | |
20 | return | |
21 | } | |
22 | } | |
23 | } | |
24 | } | |
25 | ||
26 | func BeforeUpdate(scope *Scope) { | |
27 | if _, ok := scope.Get("gorm:update_column"); !ok { | |
28 | scope.CallMethodWithErrorCheck("BeforeSave") | |
29 | scope.CallMethodWithErrorCheck("BeforeUpdate") | |
30 | } | |
31 | } | |
32 | ||
33 | func UpdateTimeStampWhenUpdate(scope *Scope) { | |
34 | if _, ok := scope.Get("gorm:update_column"); !ok { | |
35 | scope.SetColumn("UpdatedAt", NowFunc()) | |
36 | } | |
37 | } | |
38 | ||
39 | func Update(scope *Scope) { | |
40 | if !scope.HasError() { | |
41 | var sqls []string | |
42 | ||
43 | if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { | |
44 | for key, value := range updateAttrs.(map[string]interface{}) { | |
45 | if scope.changeableDBColumn(key) { | |
46 | sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value))) | |
47 | } | |
48 | } | |
49 | } else { | |
50 | fields := scope.Fields() | |
51 | for _, field := range fields { | |
52 | if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal { | |
53 | sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) | |
54 | } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { | |
55 | for _, dbName := range relationship.ForeignDBNames { | |
56 | if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank { | |
57 | sql := fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface())) | |
58 | sqls = append(sqls, sql) | |
59 | } | |
60 | } | |
61 | } | |
62 | } | |
63 | } | |
64 | ||
65 | if len(sqls) > 0 { | |
66 | scope.Raw(fmt.Sprintf( | |
67 | "UPDATE %v SET %v %v", | |
68 | scope.QuotedTableName(), | |
69 | strings.Join(sqls, ", "), | |
70 | scope.CombinedConditionSql(), | |
71 | )) | |
72 | scope.Exec() | |
73 | } | |
74 | } | |
75 | } | |
76 | ||
77 | func AfterUpdate(scope *Scope) { | |
78 | if _, ok := scope.Get("gorm:update_column"); !ok { | |
79 | scope.CallMethodWithErrorCheck("AfterUpdate") | |
80 | scope.CallMethodWithErrorCheck("AfterSave") | |
81 | } | |
82 | } | |
83 | ||
84 | func init() { | |
85 | DefaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes) | |
86 | DefaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction) | |
87 | DefaultCallback.Update().Register("gorm:before_update", BeforeUpdate) | |
88 | DefaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations) | |
89 | DefaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate) | |
90 | DefaultCallback.Update().Register("gorm:update", Update) | |
91 | DefaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations) | |
92 | DefaultCallback.Update().Register("gorm:after_update", AfterUpdate) | |
93 | DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | |
94 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | ||
5 | "github.com/jinzhu/gorm" | |
6 | ||
7 | "reflect" | |
8 | "testing" | |
9 | ) | |
10 | ||
11 | func (s *Product) BeforeCreate() (err error) { | |
12 | if s.Code == "Invalid" { | |
13 | err = errors.New("invalid product") | |
14 | } | |
15 | s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 | |
16 | return | |
17 | } | |
18 | ||
19 | func (s *Product) BeforeUpdate() (err error) { | |
20 | if s.Code == "dont_update" { | |
21 | err = errors.New("can't update") | |
22 | } | |
23 | s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 | |
24 | return | |
25 | } | |
26 | ||
27 | func (s *Product) BeforeSave() (err error) { | |
28 | if s.Code == "dont_save" { | |
29 | err = errors.New("can't save") | |
30 | } | |
31 | s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 | |
32 | return | |
33 | } | |
34 | ||
35 | func (s *Product) AfterFind() { | |
36 | s.AfterFindCallTimes = s.AfterFindCallTimes + 1 | |
37 | } | |
38 | ||
39 | func (s *Product) AfterCreate(tx *gorm.DB) { | |
40 | tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1}) | |
41 | } | |
42 | ||
43 | func (s *Product) AfterUpdate() { | |
44 | s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 | |
45 | } | |
46 | ||
47 | func (s *Product) AfterSave() (err error) { | |
48 | if s.Code == "after_save_error" { | |
49 | err = errors.New("can't save") | |
50 | } | |
51 | s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 | |
52 | return | |
53 | } | |
54 | ||
55 | func (s *Product) BeforeDelete() (err error) { | |
56 | if s.Code == "dont_delete" { | |
57 | err = errors.New("can't delete") | |
58 | } | |
59 | s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 | |
60 | return | |
61 | } | |
62 | ||
63 | func (s *Product) AfterDelete() (err error) { | |
64 | if s.Code == "after_delete_error" { | |
65 | err = errors.New("can't delete") | |
66 | } | |
67 | s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 | |
68 | return | |
69 | } | |
70 | ||
71 | func (s *Product) GetCallTimes() []int64 { | |
72 | return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} | |
73 | } | |
74 | ||
75 | func TestRunCallbacks(t *testing.T) { | |
76 | p := Product{Code: "unique_code", Price: 100} | |
77 | DB.Save(&p) | |
78 | ||
79 | if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { | |
80 | t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) | |
81 | } | |
82 | ||
83 | DB.Where("Code = ?", "unique_code").First(&p) | |
84 | if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { | |
85 | t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes()) | |
86 | } | |
87 | ||
88 | p.Price = 200 | |
89 | DB.Save(&p) | |
90 | if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { | |
91 | t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) | |
92 | } | |
93 | ||
94 | var products []Product | |
95 | DB.Find(&products, "code = ?", "unique_code") | |
96 | if products[0].AfterFindCallTimes != 2 { | |
97 | t.Errorf("AfterFind callbacks should work with slice") | |
98 | } | |
99 | ||
100 | DB.Where("Code = ?", "unique_code").First(&p) | |
101 | if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { | |
102 | t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes()) | |
103 | } | |
104 | ||
105 | DB.Delete(&p) | |
106 | if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { | |
107 | t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) | |
108 | } | |
109 | ||
110 | if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { | |
111 | t.Errorf("Can't find a deleted record") | |
112 | } | |
113 | } | |
114 | ||
115 | func TestCallbacksWithErrors(t *testing.T) { | |
116 | p := Product{Code: "Invalid", Price: 100} | |
117 | if DB.Save(&p).Error == nil { | |
118 | t.Errorf("An error from before create callbacks happened when create with invalid value") | |
119 | } | |
120 | ||
121 | if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { | |
122 | t.Errorf("Should not save record that have errors") | |
123 | } | |
124 | ||
125 | if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { | |
126 | t.Errorf("An error from after create callbacks happened when create with invalid value") | |
127 | } | |
128 | ||
129 | p2 := Product{Code: "update_callback", Price: 100} | |
130 | DB.Save(&p2) | |
131 | ||
132 | p2.Code = "dont_update" | |
133 | if DB.Save(&p2).Error == nil { | |
134 | t.Errorf("An error from before update callbacks happened when update with invalid value") | |
135 | } | |
136 | ||
137 | if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { | |
138 | t.Errorf("Record Should not be updated due to errors happened in before update callback") | |
139 | } | |
140 | ||
141 | if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { | |
142 | t.Errorf("Record Should not be updated due to errors happened in before update callback") | |
143 | } | |
144 | ||
145 | p2.Code = "dont_save" | |
146 | if DB.Save(&p2).Error == nil { | |
147 | t.Errorf("An error from before save callbacks happened when update with invalid value") | |
148 | } | |
149 | ||
150 | p3 := Product{Code: "dont_delete", Price: 100} | |
151 | DB.Save(&p3) | |
152 | if DB.Delete(&p3).Error == nil { | |
153 | t.Errorf("An error from before delete callbacks happened when delete") | |
154 | } | |
155 | ||
156 | if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { | |
157 | t.Errorf("An error from before delete callbacks happened") | |
158 | } | |
159 | ||
160 | p4 := Product{Code: "after_save_error", Price: 100} | |
161 | DB.Save(&p4) | |
162 | if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { | |
163 | t.Errorf("Record should be reverted if get an error in after save callback") | |
164 | } | |
165 | ||
166 | p5 := Product{Code: "after_delete_error", Price: 100} | |
167 | DB.Save(&p5) | |
168 | if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { | |
169 | t.Errorf("Record should be found") | |
170 | } | |
171 | ||
172 | DB.Delete(&p5) | |
173 | if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { | |
174 | t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") | |
175 | } | |
176 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | "time" | |
6 | ) | |
7 | ||
8 | type commonDialect struct{} | |
9 | ||
10 | func (commonDialect) BinVar(i int) string { | |
11 | return "$$" // ? | |
12 | } | |
13 | ||
14 | func (commonDialect) SupportLastInsertId() bool { | |
15 | return true | |
16 | } | |
17 | ||
18 | func (commonDialect) HasTop() bool { | |
19 | return false | |
20 | } | |
21 | ||
22 | func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string { | |
23 | switch value.Kind() { | |
24 | case reflect.Bool: | |
25 | return "BOOLEAN" | |
26 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | |
27 | if autoIncrease { | |
28 | return "INTEGER AUTO_INCREMENT" | |
29 | } | |
30 | return "INTEGER" | |
31 | case reflect.Int64, reflect.Uint64: | |
32 | if autoIncrease { | |
33 | return "BIGINT AUTO_INCREMENT" | |
34 | } | |
35 | return "BIGINT" | |
36 | case reflect.Float32, reflect.Float64: | |
37 | return "FLOAT" | |
38 | case reflect.String: | |
39 | if size > 0 && size < 65532 { | |
40 | return fmt.Sprintf("VARCHAR(%d)", size) | |
41 | } | |
42 | return "VARCHAR(65532)" | |
43 | case reflect.Struct: | |
44 | if _, ok := value.Interface().(time.Time); ok { | |
45 | return "TIMESTAMP" | |
46 | } | |
47 | default: | |
48 | if _, ok := value.Interface().([]byte); ok { | |
49 | if size > 0 && size < 65532 { | |
50 | return fmt.Sprintf("BINARY(%d)", size) | |
51 | } | |
52 | return "BINARY(65532)" | |
53 | } | |
54 | } | |
55 | panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String())) | |
56 | } | |
57 | ||
58 | func (commonDialect) ReturningStr(tableName, key string) string { | |
59 | return "" | |
60 | } | |
61 | ||
62 | func (commonDialect) SelectFromDummyTable() string { | |
63 | return "" | |
64 | } | |
65 | ||
66 | func (commonDialect) Quote(key string) string { | |
67 | return fmt.Sprintf(`"%s"`, key) | |
68 | } | |
69 | ||
70 | func (c commonDialect) HasTable(scope *Scope, tableName string) bool { | |
71 | var ( | |
72 | count int | |
73 | databaseName = c.CurrentDatabase(scope) | |
74 | ) | |
75 | c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", databaseName, tableName) | |
76 | return count > 0 | |
77 | } | |
78 | ||
79 | func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { | |
80 | var ( | |
81 | count int | |
82 | databaseName = c.CurrentDatabase(scope) | |
83 | ) | |
84 | c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName) | |
85 | return count > 0 | |
86 | } | |
87 | ||
88 | func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { | |
89 | var ( | |
90 | count int | |
91 | databaseName = c.CurrentDatabase(scope) | |
92 | ) | |
93 | c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName) | |
94 | return count > 0 | |
95 | } | |
96 | ||
97 | func (commonDialect) RemoveIndex(scope *Scope, indexName string) { | |
98 | scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error) | |
99 | } | |
100 | ||
101 | // RawScanInt scans the first column of the first row into the `scan' int pointer. | |
102 | // This function captures raw query errors and propagates them to the original scope. | |
103 | func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) { | |
104 | scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr)) | |
105 | } | |
106 | ||
107 | // RawScanString scans the first column of the first row into the `scan' string pointer. | |
108 | // This function captures raw query errors and propagates them to the original scope. | |
109 | func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string, args ...interface{}) { | |
110 | scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr)) | |
111 | } | |
112 | ||
113 | func (commonDialect) CurrentDatabase(scope *Scope) (name string) { | |
114 | scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name)) | |
115 | return | |
116 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "reflect" | |
4 | "testing" | |
5 | "time" | |
6 | ) | |
7 | ||
8 | func TestCreate(t *testing.T) { | |
9 | float := 35.03554004971999 | |
10 | user := User{Name: "CreateUser", Age: 18, Birthday: time.Now(), UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float} | |
11 | ||
12 | if !DB.NewRecord(user) || !DB.NewRecord(&user) { | |
13 | t.Error("User should be new record before create") | |
14 | } | |
15 | ||
16 | if count := DB.Save(&user).RowsAffected; count != 1 { | |
17 | t.Error("There should be one record be affected when create record") | |
18 | } | |
19 | ||
20 | if DB.NewRecord(user) || DB.NewRecord(&user) { | |
21 | t.Error("User should not new record after save") | |
22 | } | |
23 | ||
24 | var newUser User | |
25 | DB.First(&newUser, user.Id) | |
26 | ||
27 | if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) { | |
28 | t.Errorf("User's PasswordHash should be saved ([]byte)") | |
29 | } | |
30 | ||
31 | if newUser.Age != 18 { | |
32 | t.Errorf("User's Age should be saved (int)") | |
33 | } | |
34 | ||
35 | if newUser.UserNum != Num(111) { | |
36 | t.Errorf("User's UserNum should be saved (custom type)") | |
37 | } | |
38 | ||
39 | if newUser.Latitude != float { | |
40 | t.Errorf("Float64 should not be changed after save") | |
41 | } | |
42 | ||
43 | if user.CreatedAt.IsZero() { | |
44 | t.Errorf("Should have created_at after create") | |
45 | } | |
46 | ||
47 | if newUser.CreatedAt.IsZero() { | |
48 | t.Errorf("Should have created_at after create") | |
49 | } | |
50 | ||
51 | DB.Model(user).Update("name", "create_user_new_name") | |
52 | DB.First(&user, user.Id) | |
53 | if user.CreatedAt != newUser.CreatedAt { | |
54 | t.Errorf("CreatedAt should not be changed after update") | |
55 | } | |
56 | } | |
57 | ||
58 | func TestCreateWithNoGORMPrimayKey(t *testing.T) { | |
59 | jt := JoinTable{From: 1, To: 2} | |
60 | err := DB.Create(&jt).Error | |
61 | if err != nil { | |
62 | t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) | |
63 | } | |
64 | } | |
65 | ||
66 | func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { | |
67 | animal := Animal{Name: "Ferdinand"} | |
68 | if DB.Save(&animal).Error != nil { | |
69 | t.Errorf("No error should happen when create a record without std primary key") | |
70 | } | |
71 | ||
72 | if animal.Counter == 0 { | |
73 | t.Errorf("No std primary key should be filled value after create") | |
74 | } | |
75 | ||
76 | if animal.Name != "Ferdinand" { | |
77 | t.Errorf("Default value should be overrided") | |
78 | } | |
79 | ||
80 | // Test create with default value not overrided | |
81 | an := Animal{From: "nerdz"} | |
82 | ||
83 | if DB.Save(&an).Error != nil { | |
84 | t.Errorf("No error should happen when create an record without std primary key") | |
85 | } | |
86 | ||
87 | // We must fetch the value again, to have the default fields updated | |
88 | // (We can't do this in the update statements, since sql default can be expressions | |
89 | // And be different from the fields' type (eg. a time.Time fiels has a default value of "now()" | |
90 | DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an) | |
91 | ||
92 | if an.Name != "galeone" { | |
93 | t.Errorf("Default value should fill the field. But got %v", an.Name) | |
94 | } | |
95 | } | |
96 | ||
97 | func TestAnonymousScanner(t *testing.T) { | |
98 | user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}} | |
99 | DB.Save(&user) | |
100 | ||
101 | var user2 User | |
102 | DB.First(&user2, "name = ?", "anonymous_scanner") | |
103 | if user2.Role.Name != "admin" { | |
104 | t.Errorf("Should be able to get anonymous scanner") | |
105 | } | |
106 | ||
107 | if !user2.IsAdmin() { | |
108 | t.Errorf("Should be able to get anonymous scanner") | |
109 | } | |
110 | } | |
111 | ||
112 | func TestAnonymousField(t *testing.T) { | |
113 | user := User{Name: "anonymous_field", Company: Company{Name: "company"}} | |
114 | DB.Save(&user) | |
115 | ||
116 | var user2 User | |
117 | DB.First(&user2, "name = ?", "anonymous_field") | |
118 | DB.Model(&user2).Related(&user2.Company) | |
119 | if user2.Company.Name != "company" { | |
120 | t.Errorf("Should be able to get anonymous field") | |
121 | } | |
122 | } | |
123 | ||
124 | func TestSelectWithCreate(t *testing.T) { | |
125 | user := getPreparedUser("select_user", "select_with_create") | |
126 | DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) | |
127 | ||
128 | var queryuser User | |
129 | DB.Preload("BillingAddress").Preload("ShippingAddress"). | |
130 | Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id) | |
131 | ||
132 | if queryuser.Name != user.Name || queryuser.Age == user.Age { | |
133 | t.Errorf("Should only create users with name column") | |
134 | } | |
135 | ||
136 | if queryuser.BillingAddressID.Int64 == 0 || queryuser.ShippingAddressId != 0 || | |
137 | queryuser.CreditCard.ID == 0 || len(queryuser.Emails) == 0 { | |
138 | t.Errorf("Should only create selected relationships") | |
139 | } | |
140 | } | |
141 | ||
142 | func TestOmitWithCreate(t *testing.T) { | |
143 | user := getPreparedUser("omit_user", "omit_with_create") | |
144 | DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) | |
145 | ||
146 | var queryuser User | |
147 | DB.Preload("BillingAddress").Preload("ShippingAddress"). | |
148 | Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id) | |
149 | ||
150 | if queryuser.Name == user.Name || queryuser.Age != user.Age { | |
151 | t.Errorf("Should only create users with age column") | |
152 | } | |
153 | ||
154 | if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 || | |
155 | queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 { | |
156 | t.Errorf("Should not create omited relationships") | |
157 | } | |
158 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "testing" | |
4 | "time" | |
5 | ) | |
6 | ||
7 | type CustomizeColumn struct { | |
8 | ID int64 `gorm:"column:mapped_id; primary_key:yes"` | |
9 | Name string `gorm:"column:mapped_name"` | |
10 | Date time.Time `gorm:"column:mapped_time"` | |
11 | } | |
12 | ||
13 | // Make sure an ignored field does not interfere with another field's custom | |
14 | // column name that matches the ignored field. | |
15 | type CustomColumnAndIgnoredFieldClash struct { | |
16 | Body string `sql:"-"` | |
17 | RawBody string `gorm:"column:body"` | |
18 | } | |
19 | ||
20 | func TestCustomizeColumn(t *testing.T) { | |
21 | col := "mapped_name" | |
22 | DB.DropTable(&CustomizeColumn{}) | |
23 | DB.AutoMigrate(&CustomizeColumn{}) | |
24 | ||
25 | scope := DB.NewScope(&CustomizeColumn{}) | |
26 | if !scope.Dialect().HasColumn(scope, scope.TableName(), col) { | |
27 | t.Errorf("CustomizeColumn should have column %s", col) | |
28 | } | |
29 | ||
30 | col = "mapped_id" | |
31 | if scope.PrimaryKey() != col { | |
32 | t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey()) | |
33 | } | |
34 | ||
35 | expected := "foo" | |
36 | cc := CustomizeColumn{ID: 666, Name: expected, Date: time.Now()} | |
37 | ||
38 | if count := DB.Create(&cc).RowsAffected; count != 1 { | |
39 | t.Error("There should be one record be affected when create record") | |
40 | } | |
41 | ||
42 | var cc1 CustomizeColumn | |
43 | DB.First(&cc1, 666) | |
44 | ||
45 | if cc1.Name != expected { | |
46 | t.Errorf("Failed to query CustomizeColumn") | |
47 | } | |
48 | ||
49 | cc.Name = "bar" | |
50 | DB.Save(&cc) | |
51 | ||
52 | var cc2 CustomizeColumn | |
53 | DB.First(&cc2, 666) | |
54 | if cc2.Name != "bar" { | |
55 | t.Errorf("Failed to query CustomizeColumn") | |
56 | } | |
57 | } | |
58 | ||
59 | func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { | |
60 | DB.DropTable(&CustomColumnAndIgnoredFieldClash{}) | |
61 | if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil { | |
62 | t.Errorf("Should not raise error: %s", err) | |
63 | } | |
64 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "testing" | |
4 | ) | |
5 | ||
6 | func TestDdlErrors(t *testing.T) { | |
7 | var err error | |
8 | ||
9 | if err = DB.Close(); err != nil { | |
10 | t.Errorf("Closing DDL test db connection err=%s", err) | |
11 | } | |
12 | defer func() { | |
13 | // Reopen DB connection. | |
14 | if DB, err = OpenTestConnection(); err != nil { | |
15 | t.Fatalf("Failed re-opening db connection: %s", err) | |
16 | } | |
17 | }() | |
18 | ||
19 | DB.HasTable("foobarbaz") | |
20 | if DB.Error == nil { | |
21 | t.Errorf("Expected operation on closed db to produce an error, but err was nil") | |
22 | } | |
23 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "testing" | |
4 | "time" | |
5 | ) | |
6 | ||
7 | func TestDelete(t *testing.T) { | |
8 | user1, user2 := User{Name: "delete1"}, User{Name: "delete2"} | |
9 | DB.Save(&user1) | |
10 | DB.Save(&user2) | |
11 | ||
12 | if err := DB.Delete(&user1).Error; err != nil { | |
13 | t.Errorf("No error should happen when delete a record, err=%s", err) | |
14 | } | |
15 | ||
16 | if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { | |
17 | t.Errorf("User can't be found after delete") | |
18 | } | |
19 | ||
20 | if DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { | |
21 | t.Errorf("Other users that not deleted should be found-able") | |
22 | } | |
23 | } | |
24 | ||
25 | func TestInlineDelete(t *testing.T) { | |
26 | user1, user2 := User{Name: "inline_delete1"}, User{Name: "inline_delete2"} | |
27 | DB.Save(&user1) | |
28 | DB.Save(&user2) | |
29 | ||
30 | if DB.Delete(&User{}, user1.Id).Error != nil { | |
31 | t.Errorf("No error should happen when delete a record") | |
32 | } else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { | |
33 | t.Errorf("User can't be found after delete") | |
34 | } | |
35 | ||
36 | if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { | |
37 | t.Errorf("No error should happen when delete a record, err=%s", err) | |
38 | } else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { | |
39 | t.Errorf("User can't be found after delete") | |
40 | } | |
41 | } | |
42 | ||
43 | func TestSoftDelete(t *testing.T) { | |
44 | type User struct { | |
45 | Id int64 | |
46 | Name string | |
47 | DeletedAt time.Time | |
48 | } | |
49 | DB.AutoMigrate(&User{}) | |
50 | ||
51 | user := User{Name: "soft_delete"} | |
52 | DB.Save(&user) | |
53 | DB.Delete(&user) | |
54 | ||
55 | if DB.First(&User{}, "name = ?", user.Name).Error == nil { | |
56 | t.Errorf("Can't find a soft deleted record") | |
57 | } | |
58 | ||
59 | if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { | |
60 | t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err) | |
61 | } | |
62 | ||
63 | DB.Unscoped().Delete(&user) | |
64 | if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() { | |
65 | t.Errorf("Can't find permanently deleted record") | |
66 | } | |
67 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | ) | |
6 | ||
7 | type Dialect interface { | |
8 | BinVar(i int) string | |
9 | SupportLastInsertId() bool | |
10 | HasTop() bool | |
11 | SqlTag(value reflect.Value, size int, autoIncrease bool) string | |
12 | ReturningStr(tableName, key string) string | |
13 | SelectFromDummyTable() string | |
14 | Quote(key string) string | |
15 | HasTable(scope *Scope, tableName string) bool | |
16 | HasColumn(scope *Scope, tableName string, columnName string) bool | |
17 | HasIndex(scope *Scope, tableName string, indexName string) bool | |
18 | RemoveIndex(scope *Scope, indexName string) | |
19 | CurrentDatabase(scope *Scope) string | |
20 | } | |
21 | ||
22 | func NewDialect(driver string) Dialect { | |
23 | var d Dialect | |
24 | switch driver { | |
25 | case "postgres": | |
26 | d = &postgres{} | |
27 | case "foundation": | |
28 | d = &foundation{} | |
29 | case "mysql": | |
30 | d = &mysql{} | |
31 | case "sqlite3": | |
32 | d = &sqlite3{} | |
33 | case "mssql": | |
34 | d = &mssql{} | |
35 | default: | |
36 | fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver) | |
37 | d = &commonDialect{} | |
38 | } | |
39 | return d | |
40 | } |
0 | # Gorm Development | |
1 | ||
2 | ## Architecture | |
3 | ||
4 | The most notable component of Gorm is`gorm.DB`, which hold database connection. It could be initialized like this: | |
5 | ||
6 | db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") | |
7 | ||
8 | Gorm has chainable API, `gorm.DB` is the bridge of chains, it save related information and pass it to the next chain. | |
9 | ||
10 | Lets use below code to explain how it works: | |
11 | ||
12 | db.Where("name = ?", "jinzhu").Find(&users) | |
13 | ||
14 | // equivalent code | |
15 | newdb := db.Where("name =?", "jinzhu") | |
16 | newdb.Find(&user) | |
17 | ||
18 | `newdb` is `db`'s clone, in addition, it contains search conditions from the `Where` method. | |
19 | `Find` is a query method, it creates a `Scope` instance, and pass it as argument to query callbacks. | |
20 | ||
21 | There are four kinds of callbacks corresponds to sql's CURD: create callbacks, update callbacks, query callbacks, delete callbacks. | |
22 | ||
23 | ## Callbacks | |
24 | ||
25 | ### Register a new callback | |
26 | ||
27 | func updateCreated(scope *Scope) { | |
28 | if scope.HasColumn("Created") { | |
29 | scope.SetColumn("Created", NowFunc()) | |
30 | } | |
31 | } | |
32 | ||
33 | db.Callback().Create().Register("update_created_at", updateCreated) | |
34 | // register a callback for Create process | |
35 | ||
36 | ### Delete an existing callback | |
37 | ||
38 | db.Callback().Create().Remove("gorm:create") | |
39 | // delete callback `gorm:create` from Create callbacks | |
40 | ||
41 | ### Replace an existing callback | |
42 | ||
43 | db.Callback().Create().Replace("gorm:create", newCreateFunction) | |
44 | // replace callback `gorm:create` with new function `newCreateFunction` for Create process | |
45 | ||
46 | ### Register callback orders | |
47 | ||
48 | db.Callback().Create().Before("gorm:create").Register("update_created_at", updateCreated) | |
49 | db.Callback().Create().After("gorm:create").Register("update_created_at", updateCreated) | |
50 | db.Callback().Query().After("gorm:query").Register("my_plugin:after_query", afterQuery) | |
51 | db.Callback().Delete().After("gorm:delete").Register("my_plugin:after_delete", afterDelete) | |
52 | db.Callback().Update().Before("gorm:update").Register("my_plugin:before_update", beforeUpdate) | |
53 | db.Callback().Create().Before("gorm:create").After("gorm:before_create").Register("my_plugin:before_create", beforeCreate) | |
54 | ||
55 | ### Callback API | |
56 | ||
57 | Gorm is powered by callbacks, so you could refer below links to learn how to write callbacks | |
58 | ||
59 | [Create callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go) | |
60 | ||
61 | [Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go) | |
62 | ||
63 | [Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go) | |
64 | ||
65 | [Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go) | |
66 | ||
67 | View [https://github.com/jinzhu/gorm/blob/master/scope.go](https://github.com/jinzhu/gorm/blob/master/scope.go) for all available API |
0 | package gorm_test | |
1 | ||
2 | import "testing" | |
3 | ||
4 | type BasePost struct { | |
5 | Id int64 | |
6 | Title string | |
7 | URL string | |
8 | } | |
9 | ||
10 | type HNPost struct { | |
11 | BasePost | |
12 | Upvotes int32 | |
13 | } | |
14 | ||
15 | type EngadgetPost struct { | |
16 | BasePost BasePost `gorm:"embedded"` | |
17 | ImageUrl string | |
18 | } | |
19 | ||
20 | func TestSaveAndQueryEmbeddedStruct(t *testing.T) { | |
21 | DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}) | |
22 | DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) | |
23 | var news HNPost | |
24 | if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { | |
25 | t.Errorf("no error should happen when query with embedded struct, but got %v", err) | |
26 | } else if news.Title != "hn_news" { | |
27 | t.Errorf("embedded struct's value should be scanned correctly") | |
28 | } | |
29 | ||
30 | DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) | |
31 | var egNews EngadgetPost | |
32 | if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { | |
33 | t.Errorf("no error should happen when query with embedded struct, but got %v", err) | |
34 | } else if egNews.BasePost.Title != "engadget_news" { | |
35 | t.Errorf("embedded struct's value should be scanned correctly") | |
36 | } | |
37 | ||
38 | if DB.NewScope(&HNPost{}).PrimaryField() == nil { | |
39 | t.Errorf("primary key with embedded struct should works") | |
40 | } | |
41 | ||
42 | for _, field := range DB.NewScope(&HNPost{}).Fields() { | |
43 | if field.Name == "BasePost" { | |
44 | t.Errorf("scope Fields should not contain embedded struct") | |
45 | } | |
46 | } | |
47 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "strings" | |
5 | ) | |
6 | ||
7 | var ( | |
8 | RecordNotFound = errors.New("record not found") | |
9 | InvalidSql = errors.New("invalid sql") | |
10 | NoNewAttrs = errors.New("no new attributes") | |
11 | NoValidTransaction = errors.New("no valid transaction") | |
12 | CantStartTransaction = errors.New("can't start transaction") | |
13 | ) | |
14 | ||
15 | type errorsInterface interface { | |
16 | GetErrors() []error | |
17 | } | |
18 | ||
19 | type Errors struct { | |
20 | errors []error | |
21 | } | |
22 | ||
23 | func (errs Errors) GetErrors() []error { | |
24 | return errs.errors | |
25 | } | |
26 | ||
27 | func (errs *Errors) Add(err error) { | |
28 | if errors, ok := err.(errorsInterface); ok { | |
29 | for _, err := range errors.GetErrors() { | |
30 | errs.Add(err) | |
31 | } | |
32 | } else { | |
33 | for _, e := range errs.errors { | |
34 | if err == e { | |
35 | return | |
36 | } | |
37 | } | |
38 | errs.errors = append(errs.errors, err) | |
39 | } | |
40 | } | |
41 | ||
42 | func (errs Errors) Error() string { | |
43 | var errors = []string{} | |
44 | for _, e := range errs.errors { | |
45 | errors = append(errors, e.Error()) | |
46 | } | |
47 | return strings.Join(errors, "; ") | |
48 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "database/sql" | |
4 | "errors" | |
5 | "reflect" | |
6 | ) | |
7 | ||
8 | type Field struct { | |
9 | *StructField | |
10 | IsBlank bool | |
11 | Field reflect.Value | |
12 | } | |
13 | ||
14 | func (field *Field) Set(value interface{}) error { | |
15 | if !field.Field.IsValid() { | |
16 | return errors.New("field value not valid") | |
17 | } | |
18 | ||
19 | if !field.Field.CanAddr() { | |
20 | return errors.New("unaddressable value") | |
21 | } | |
22 | ||
23 | if rvalue, ok := value.(reflect.Value); ok { | |
24 | value = rvalue.Interface() | |
25 | } | |
26 | ||
27 | if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok { | |
28 | if v, ok := value.(reflect.Value); ok { | |
29 | if err := scanner.Scan(v.Interface()); err != nil { | |
30 | return err | |
31 | } | |
32 | } else { | |
33 | if err := scanner.Scan(value); err != nil { | |
34 | return err | |
35 | } | |
36 | } | |
37 | } else { | |
38 | reflectValue, ok := value.(reflect.Value) | |
39 | if !ok { | |
40 | reflectValue = reflect.ValueOf(value) | |
41 | } | |
42 | ||
43 | if reflectValue.Type().ConvertibleTo(field.Field.Type()) { | |
44 | field.Field.Set(reflectValue.Convert(field.Field.Type())) | |
45 | } else { | |
46 | return errors.New("could not convert argument") | |
47 | } | |
48 | } | |
49 | ||
50 | field.IsBlank = isBlank(field.Field) | |
51 | return nil | |
52 | } | |
53 | ||
54 | // Fields get value's fields | |
55 | func (scope *Scope) Fields() map[string]*Field { | |
56 | if scope.fields == nil { | |
57 | fields := map[string]*Field{} | |
58 | modelStruct := scope.GetModelStruct() | |
59 | ||
60 | indirectValue := scope.IndirectValue() | |
61 | isStruct := indirectValue.Kind() == reflect.Struct | |
62 | for _, structField := range modelStruct.StructFields { | |
63 | if isStruct { | |
64 | fields[structField.DBName] = getField(indirectValue, structField) | |
65 | } else { | |
66 | fields[structField.DBName] = &Field{StructField: structField, IsBlank: true} | |
67 | } | |
68 | } | |
69 | ||
70 | if modelStruct.cached { | |
71 | scope.fields = fields | |
72 | } | |
73 | return fields | |
74 | } | |
75 | return scope.fields | |
76 | } | |
77 | ||
78 | func getField(indirectValue reflect.Value, structField *StructField) *Field { | |
79 | field := &Field{StructField: structField} | |
80 | for _, name := range structField.Names { | |
81 | indirectValue = reflect.Indirect(indirectValue).FieldByName(name) | |
82 | } | |
83 | field.Field = indirectValue | |
84 | field.IsBlank = isBlank(indirectValue) | |
85 | return field | |
86 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "testing" | |
4 | ||
5 | "github.com/jinzhu/gorm" | |
6 | ) | |
7 | ||
8 | type CalculateField struct { | |
9 | gorm.Model | |
10 | Name string | |
11 | Children []CalculateFieldChild | |
12 | Category CalculateFieldCategory | |
13 | } | |
14 | ||
15 | type CalculateFieldChild struct { | |
16 | gorm.Model | |
17 | CalculateFieldID uint | |
18 | Name string | |
19 | } | |
20 | ||
21 | type CalculateFieldCategory struct { | |
22 | gorm.Model | |
23 | CalculateFieldID uint | |
24 | Name string | |
25 | } | |
26 | ||
27 | func TestCalculateField(t *testing.T) { | |
28 | var field CalculateField | |
29 | fields := DB.NewScope(&field).Fields() | |
30 | if fields["children"].Relationship == nil || fields["category"].Relationship == nil { | |
31 | t.Errorf("Should calculate fields correctly for the first time") | |
32 | } | |
33 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | "time" | |
6 | ) | |
7 | ||
8 | type foundation struct { | |
9 | commonDialect | |
10 | } | |
11 | ||
12 | func (foundation) BinVar(i int) string { | |
13 | return fmt.Sprintf("$%v", i) | |
14 | } | |
15 | ||
16 | func (foundation) SupportLastInsertId() bool { | |
17 | return false | |
18 | } | |
19 | ||
20 | func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) string { | |
21 | switch value.Kind() { | |
22 | case reflect.Bool: | |
23 | return "boolean" | |
24 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | |
25 | if autoIncrease { | |
26 | return "serial" | |
27 | } | |
28 | return "int" | |
29 | case reflect.Int64, reflect.Uint64: | |
30 | if autoIncrease { | |
31 | return "bigserial" | |
32 | } | |
33 | return "bigint" | |
34 | case reflect.Float32, reflect.Float64: | |
35 | return "double" | |
36 | case reflect.String: | |
37 | if size > 0 && size < 65532 { | |
38 | return fmt.Sprintf("varchar(%d)", size) | |
39 | } | |
40 | return "clob" | |
41 | case reflect.Struct: | |
42 | if _, ok := value.Interface().(time.Time); ok { | |
43 | return "datetime" | |
44 | } | |
45 | default: | |
46 | if _, ok := value.Interface().([]byte); ok { | |
47 | return "blob" | |
48 | } | |
49 | } | |
50 | panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String())) | |
51 | } | |
52 | ||
53 | func (s foundation) ReturningStr(tableName, key string) string { | |
54 | return fmt.Sprintf("RETURNING %v.%v", tableName, key) | |
55 | } | |
56 | ||
57 | func (s foundation) HasTable(scope *Scope, tableName string) bool { | |
58 | var count int | |
59 | s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName) | |
60 | return count > 0 | |
61 | } | |
62 | ||
63 | func (s foundation) HasColumn(scope *Scope, tableName string, columnName string) bool { | |
64 | var count int | |
65 | s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName) | |
66 | return count > 0 | |
67 | } | |
68 | ||
69 | func (s foundation) RemoveIndex(scope *Scope, indexName string) { | |
70 | scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", s.Quote(indexName))) | |
71 | } | |
72 | ||
73 | func (s foundation) HasIndex(scope *Scope, tableName string, indexName string) bool { | |
74 | var count int | |
75 | s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName) | |
76 | return count > 0 | |
77 | } | |
78 | ||
79 | func (s foundation) CurrentDatabase(scope *Scope) (name string) { | |
80 | s.RawScanString(scope, &name, "SELECT CURRENT_SCHEMA") | |
81 | return | |
82 | } |
Binary diff not shown
0 | package gorm | |
1 | ||
2 | import "database/sql" | |
3 | ||
4 | type sqlCommon interface { | |
5 | Exec(query string, args ...interface{}) (sql.Result, error) | |
6 | Prepare(query string) (*sql.Stmt, error) | |
7 | Query(query string, args ...interface{}) (*sql.Rows, error) | |
8 | QueryRow(query string, args ...interface{}) *sql.Row | |
9 | } | |
10 | ||
11 | type sqlDb interface { | |
12 | Begin() (*sql.Tx, error) | |
13 | } | |
14 | ||
15 | type sqlTx interface { | |
16 | Commit() error | |
17 | Rollback() error | |
18 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "fmt" | |
5 | "reflect" | |
6 | "strings" | |
7 | ) | |
8 | ||
9 | type JoinTableHandlerInterface interface { | |
10 | Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) | |
11 | Table(db *DB) string | |
12 | Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error | |
13 | Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error | |
14 | JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB | |
15 | SourceForeignKeys() []JoinTableForeignKey | |
16 | DestinationForeignKeys() []JoinTableForeignKey | |
17 | } | |
18 | ||
19 | type JoinTableForeignKey struct { | |
20 | DBName string | |
21 | AssociationDBName string | |
22 | } | |
23 | ||
24 | type JoinTableSource struct { | |
25 | ModelType reflect.Type | |
26 | ForeignKeys []JoinTableForeignKey | |
27 | } | |
28 | ||
29 | type JoinTableHandler struct { | |
30 | TableName string `sql:"-"` | |
31 | Source JoinTableSource `sql:"-"` | |
32 | Destination JoinTableSource `sql:"-"` | |
33 | } | |
34 | ||
35 | func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { | |
36 | return s.Source.ForeignKeys | |
37 | } | |
38 | ||
39 | func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { | |
40 | return s.Destination.ForeignKeys | |
41 | } | |
42 | ||
43 | func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { | |
44 | s.TableName = tableName | |
45 | ||
46 | s.Source = JoinTableSource{ModelType: source} | |
47 | for idx, dbName := range relationship.ForeignFieldNames { | |
48 | s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ | |
49 | DBName: relationship.ForeignDBNames[idx], | |
50 | AssociationDBName: dbName, | |
51 | }) | |
52 | } | |
53 | ||
54 | s.Destination = JoinTableSource{ModelType: destination} | |
55 | for idx, dbName := range relationship.AssociationForeignFieldNames { | |
56 | s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ | |
57 | DBName: relationship.AssociationForeignDBNames[idx], | |
58 | AssociationDBName: dbName, | |
59 | }) | |
60 | } | |
61 | } | |
62 | ||
63 | func (s JoinTableHandler) Table(db *DB) string { | |
64 | return s.TableName | |
65 | } | |
66 | ||
67 | func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} { | |
68 | values := map[string]interface{}{} | |
69 | ||
70 | for _, source := range sources { | |
71 | scope := db.NewScope(source) | |
72 | modelType := scope.GetModelStruct().ModelType | |
73 | ||
74 | if s.Source.ModelType == modelType { | |
75 | for _, foreignKey := range s.Source.ForeignKeys { | |
76 | values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface() | |
77 | } | |
78 | } else if s.Destination.ModelType == modelType { | |
79 | for _, foreignKey := range s.Destination.ForeignKeys { | |
80 | values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface() | |
81 | } | |
82 | } | |
83 | } | |
84 | return values | |
85 | } | |
86 | ||
87 | func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error { | |
88 | scope := db.NewScope("") | |
89 | searchMap := s.GetSearchMap(db, source1, source2) | |
90 | ||
91 | var assignColumns, binVars, conditions []string | |
92 | var values []interface{} | |
93 | for key, value := range searchMap { | |
94 | assignColumns = append(assignColumns, key) | |
95 | binVars = append(binVars, `?`) | |
96 | conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) | |
97 | values = append(values, value) | |
98 | } | |
99 | ||
100 | for _, value := range values { | |
101 | values = append(values, value) | |
102 | } | |
103 | ||
104 | quotedTable := handler.Table(db) | |
105 | sql := fmt.Sprintf( | |
106 | "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", | |
107 | quotedTable, | |
108 | strings.Join(assignColumns, ","), | |
109 | strings.Join(binVars, ","), | |
110 | scope.Dialect().SelectFromDummyTable(), | |
111 | quotedTable, | |
112 | strings.Join(conditions, " AND "), | |
113 | ) | |
114 | ||
115 | return db.Exec(sql, values...).Error | |
116 | } | |
117 | ||
118 | func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { | |
119 | var conditions []string | |
120 | var values []interface{} | |
121 | ||
122 | for key, value := range s.GetSearchMap(db, sources...) { | |
123 | conditions = append(conditions, fmt.Sprintf("%v = ?", key)) | |
124 | values = append(values, value) | |
125 | } | |
126 | ||
127 | return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error | |
128 | } | |
129 | ||
130 | func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { | |
131 | quotedTable := handler.Table(db) | |
132 | ||
133 | scope := db.NewScope(source) | |
134 | modelType := scope.GetModelStruct().ModelType | |
135 | var joinConditions []string | |
136 | var queryConditions []string | |
137 | var values []interface{} | |
138 | if s.Source.ModelType == modelType { | |
139 | destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() | |
140 | for _, foreignKey := range s.Destination.ForeignKeys { | |
141 | joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) | |
142 | } | |
143 | ||
144 | var foreignDBNames []string | |
145 | var foreignFieldNames []string | |
146 | ||
147 | for _, foreignKey := range s.Source.ForeignKeys { | |
148 | foreignDBNames = append(foreignDBNames, foreignKey.DBName) | |
149 | foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name) | |
150 | } | |
151 | ||
152 | foreignFieldValues := scope.getColumnAsArray(foreignFieldNames) | |
153 | ||
154 | condString := fmt.Sprintf("%v in (%v)", toQueryCondition(scope, foreignDBNames), toQueryMarks(foreignFieldValues)) | |
155 | ||
156 | keys := scope.getColumnAsArray(foreignFieldNames) | |
157 | values = append(values, toQueryValues(keys)) | |
158 | ||
159 | queryConditions = append(queryConditions, condString) | |
160 | ||
161 | return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). | |
162 | Where(condString, toQueryValues(foreignFieldValues)...) | |
163 | } else { | |
164 | db.Error = errors.New("wrong source type for join table handler") | |
165 | return db | |
166 | } | |
167 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "testing" | |
5 | "time" | |
6 | ||
7 | "github.com/jinzhu/gorm" | |
8 | ) | |
9 | ||
10 | type Person struct { | |
11 | Id int | |
12 | Name string | |
13 | Addresses []*Address `gorm:"many2many:person_addresses;"` | |
14 | } | |
15 | ||
16 | type PersonAddress struct { | |
17 | gorm.JoinTableHandler | |
18 | PersonID int | |
19 | AddressID int | |
20 | DeletedAt time.Time | |
21 | CreatedAt time.Time | |
22 | } | |
23 | ||
24 | func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { | |
25 | return db.Where(map[string]interface{}{ | |
26 | "person_id": db.NewScope(foreignValue).PrimaryKeyValue(), | |
27 | "address_id": db.NewScope(associationValue).PrimaryKeyValue(), | |
28 | }).Assign(map[string]interface{}{ | |
29 | "person_id": foreignValue, | |
30 | "address_id": associationValue, | |
31 | "deleted_at": gorm.Expr("NULL"), | |
32 | }).FirstOrCreate(&PersonAddress{}).Error | |
33 | } | |
34 | ||
35 | func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error { | |
36 | return db.Delete(&PersonAddress{}).Error | |
37 | } | |
38 | ||
39 | func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { | |
40 | table := pa.Table(db) | |
41 | return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) | |
42 | } | |
43 | ||
44 | func TestJoinTable(t *testing.T) { | |
45 | DB.Exec("drop table person_addresses;") | |
46 | DB.AutoMigrate(&Person{}) | |
47 | DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{}) | |
48 | ||
49 | address1 := &Address{Address1: "address 1"} | |
50 | address2 := &Address{Address1: "address 2"} | |
51 | person := &Person{Name: "person", Addresses: []*Address{address1, address2}} | |
52 | DB.Save(person) | |
53 | ||
54 | DB.Model(person).Association("Addresses").Delete(address1) | |
55 | ||
56 | if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 { | |
57 | t.Errorf("Should found one address") | |
58 | } | |
59 | ||
60 | if DB.Model(person).Association("Addresses").Count() != 1 { | |
61 | t.Errorf("Should found one address") | |
62 | } | |
63 | ||
64 | if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 { | |
65 | t.Errorf("Found two addresses with Unscoped") | |
66 | } | |
67 | ||
68 | if DB.Model(person).Association("Addresses").Clear(); DB.Model(person).Association("Addresses").Count() != 0 { | |
69 | t.Errorf("Should deleted all addresses") | |
70 | } | |
71 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "database/sql/driver" | |
4 | "fmt" | |
5 | "log" | |
6 | "os" | |
7 | "reflect" | |
8 | "regexp" | |
9 | "time" | |
10 | ) | |
11 | ||
12 | type logger interface { | |
13 | Print(v ...interface{}) | |
14 | } | |
15 | ||
16 | type LogWriter interface { | |
17 | Println(v ...interface{}) | |
18 | } | |
19 | ||
20 | type Logger struct { | |
21 | LogWriter | |
22 | } | |
23 | ||
24 | var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} | |
25 | ||
26 | // Format log | |
27 | var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`) | |
28 | ||
29 | func (logger Logger) Print(values ...interface{}) { | |
30 | if len(values) > 1 { | |
31 | level := values[0] | |
32 | currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" | |
33 | source := fmt.Sprintf("\033[35m(%v)\033[0m", values[1]) | |
34 | messages := []interface{}{source, currentTime} | |
35 | ||
36 | if level == "sql" { | |
37 | // duration | |
38 | messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) | |
39 | // sql | |
40 | var formatedValues []interface{} | |
41 | for _, value := range values[4].([]interface{}) { | |
42 | indirectValue := reflect.Indirect(reflect.ValueOf(value)) | |
43 | if indirectValue.IsValid() { | |
44 | value = indirectValue.Interface() | |
45 | if t, ok := value.(time.Time); ok { | |
46 | formatedValues = append(formatedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339))) | |
47 | } else if b, ok := value.([]byte); ok { | |
48 | formatedValues = append(formatedValues, fmt.Sprintf("'%v'", string(b))) | |
49 | } else if r, ok := value.(driver.Valuer); ok { | |
50 | if value, err := r.Value(); err == nil && value != nil { | |
51 | formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value)) | |
52 | } else { | |
53 | formatedValues = append(formatedValues, "NULL") | |
54 | } | |
55 | } else { | |
56 | formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value)) | |
57 | } | |
58 | } else { | |
59 | formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value)) | |
60 | } | |
61 | } | |
62 | messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(values[3].(string), "%v"), formatedValues...)) | |
63 | } else { | |
64 | messages = append(messages, "\033[31;1m") | |
65 | messages = append(messages, values[2:]...) | |
66 | messages = append(messages, "\033[0m") | |
67 | } | |
68 | logger.Println(messages...) | |
69 | } | |
70 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "database/sql" | |
4 | "errors" | |
5 | "fmt" | |
6 | "reflect" | |
7 | "strings" | |
8 | "time" | |
9 | ) | |
10 | ||
11 | // NowFunc returns current time, this function is exported in order to be able | |
12 | // to give the flexibility to the developer to customize it according to their | |
13 | // needs | |
14 | // | |
15 | // e.g: return time.Now().UTC() | |
16 | // | |
17 | var NowFunc = func() time.Time { | |
18 | return time.Now() | |
19 | } | |
20 | ||
21 | type DB struct { | |
22 | Value interface{} | |
23 | Error error | |
24 | RowsAffected int64 | |
25 | callback *callback | |
26 | db sqlCommon | |
27 | parent *DB | |
28 | search *search | |
29 | logMode int | |
30 | logger logger | |
31 | dialect Dialect | |
32 | singularTable bool | |
33 | source string | |
34 | values map[string]interface{} | |
35 | joinTableHandlers map[string]JoinTableHandler | |
36 | } | |
37 | ||
38 | func Open(dialect string, args ...interface{}) (DB, error) { | |
39 | var db DB | |
40 | var err error | |
41 | ||
42 | if len(args) == 0 { | |
43 | err = errors.New("invalid database source") | |
44 | } else { | |
45 | var source string | |
46 | var dbSql sqlCommon | |
47 | ||
48 | switch value := args[0].(type) { | |
49 | case string: | |
50 | var driver = dialect | |
51 | if len(args) == 1 { | |
52 | source = value | |
53 | } else if len(args) >= 2 { | |
54 | driver = value | |
55 | source = args[1].(string) | |
56 | } | |
57 | if driver == "foundation" { | |
58 | driver = "postgres" // FoundationDB speaks a postgres-compatible protocol. | |
59 | } | |
60 | dbSql, err = sql.Open(driver, source) | |
61 | case sqlCommon: | |
62 | source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() | |
63 | dbSql = value | |
64 | } | |
65 | ||
66 | db = DB{ | |
67 | dialect: NewDialect(dialect), | |
68 | logger: defaultLogger, | |
69 | callback: DefaultCallback, | |
70 | source: source, | |
71 | values: map[string]interface{}{}, | |
72 | db: dbSql, | |
73 | } | |
74 | db.parent = &db | |
75 | ||
76 | if err == nil { | |
77 | err = db.DB().Ping() // Send a ping to make sure the database connection is alive. | |
78 | } | |
79 | } | |
80 | ||
81 | return db, err | |
82 | } | |
83 | ||
84 | func (s *DB) Close() error { | |
85 | return s.parent.db.(*sql.DB).Close() | |
86 | } | |
87 | ||
88 | func (s *DB) DB() *sql.DB { | |
89 | return s.db.(*sql.DB) | |
90 | } | |
91 | ||
92 | func (s *DB) New() *DB { | |
93 | clone := s.clone() | |
94 | clone.search = nil | |
95 | clone.Value = nil | |
96 | return clone | |
97 | } | |
98 | ||
99 | // NewScope create scope for callbacks, including DB's search information | |
100 | func (db *DB) NewScope(value interface{}) *Scope { | |
101 | dbClone := db.clone() | |
102 | dbClone.Value = value | |
103 | return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} | |
104 | } | |
105 | ||
106 | // CommonDB Return the underlying sql.DB or sql.Tx instance. | |
107 | // Use of this method is discouraged. It's mainly intended to allow | |
108 | // coexistence with legacy non-GORM code. | |
109 | func (s *DB) CommonDB() sqlCommon { | |
110 | return s.db | |
111 | } | |
112 | ||
113 | func (s *DB) Callback() *callback { | |
114 | s.parent.callback = s.parent.callback.clone() | |
115 | return s.parent.callback | |
116 | } | |
117 | ||
118 | func (s *DB) SetLogger(l logger) { | |
119 | s.logger = l | |
120 | } | |
121 | ||
122 | func (s *DB) LogMode(enable bool) *DB { | |
123 | if enable { | |
124 | s.logMode = 2 | |
125 | } else { | |
126 | s.logMode = 1 | |
127 | } | |
128 | return s | |
129 | } | |
130 | ||
131 | func (s *DB) SingularTable(enable bool) { | |
132 | modelStructsMap = newModelStructsMap() | |
133 | s.parent.singularTable = enable | |
134 | } | |
135 | ||
136 | func (s *DB) Where(query interface{}, args ...interface{}) *DB { | |
137 | return s.clone().search.Where(query, args...).db | |
138 | } | |
139 | ||
140 | func (s *DB) Or(query interface{}, args ...interface{}) *DB { | |
141 | return s.clone().search.Or(query, args...).db | |
142 | } | |
143 | ||
144 | func (s *DB) Not(query interface{}, args ...interface{}) *DB { | |
145 | return s.clone().search.Not(query, args...).db | |
146 | } | |
147 | ||
148 | func (s *DB) Limit(value interface{}) *DB { | |
149 | return s.clone().search.Limit(value).db | |
150 | } | |
151 | ||
152 | func (s *DB) Offset(value interface{}) *DB { | |
153 | return s.clone().search.Offset(value).db | |
154 | } | |
155 | ||
156 | func (s *DB) Order(value string, reorder ...bool) *DB { | |
157 | return s.clone().search.Order(value, reorder...).db | |
158 | } | |
159 | ||
160 | func (s *DB) Select(query interface{}, args ...interface{}) *DB { | |
161 | return s.clone().search.Select(query, args...).db | |
162 | } | |
163 | ||
164 | func (s *DB) Omit(columns ...string) *DB { | |
165 | return s.clone().search.Omit(columns...).db | |
166 | } | |
167 | ||
168 | func (s *DB) Group(query string) *DB { | |
169 | return s.clone().search.Group(query).db | |
170 | } | |
171 | ||
172 | func (s *DB) Having(query string, values ...interface{}) *DB { | |
173 | return s.clone().search.Having(query, values...).db | |
174 | } | |
175 | ||
176 | func (s *DB) Joins(query string) *DB { | |
177 | return s.clone().search.Joins(query).db | |
178 | } | |
179 | ||
180 | func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { | |
181 | for _, f := range funcs { | |
182 | s = f(s) | |
183 | } | |
184 | return s | |
185 | } | |
186 | ||
187 | func (s *DB) Unscoped() *DB { | |
188 | return s.clone().search.unscoped().db | |
189 | } | |
190 | ||
191 | func (s *DB) Attrs(attrs ...interface{}) *DB { | |
192 | return s.clone().search.Attrs(attrs...).db | |
193 | } | |
194 | ||
195 | func (s *DB) Assign(attrs ...interface{}) *DB { | |
196 | return s.clone().search.Assign(attrs...).db | |
197 | } | |
198 | ||
199 | func (s *DB) First(out interface{}, where ...interface{}) *DB { | |
200 | newScope := s.clone().NewScope(out) | |
201 | newScope.Search.Limit(1) | |
202 | return newScope.Set("gorm:order_by_primary_key", "ASC"). | |
203 | inlineCondition(where...).callCallbacks(s.parent.callback.queries).db | |
204 | } | |
205 | ||
206 | func (s *DB) Last(out interface{}, where ...interface{}) *DB { | |
207 | newScope := s.clone().NewScope(out) | |
208 | newScope.Search.Limit(1) | |
209 | return newScope.Set("gorm:order_by_primary_key", "DESC"). | |
210 | inlineCondition(where...).callCallbacks(s.parent.callback.queries).db | |
211 | } | |
212 | ||
213 | func (s *DB) Find(out interface{}, where ...interface{}) *DB { | |
214 | return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db | |
215 | } | |
216 | ||
217 | func (s *DB) Scan(dest interface{}) *DB { | |
218 | return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db | |
219 | } | |
220 | ||
221 | func (s *DB) Row() *sql.Row { | |
222 | return s.NewScope(s.Value).row() | |
223 | } | |
224 | ||
225 | func (s *DB) Rows() (*sql.Rows, error) { | |
226 | return s.NewScope(s.Value).rows() | |
227 | } | |
228 | ||
229 | func (s *DB) Pluck(column string, value interface{}) *DB { | |
230 | return s.NewScope(s.Value).pluck(column, value).db | |
231 | } | |
232 | ||
233 | func (s *DB) Count(value interface{}) *DB { | |
234 | return s.NewScope(s.Value).count(value).db | |
235 | } | |
236 | ||
237 | func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { | |
238 | return s.clone().NewScope(s.Value).related(value, foreignKeys...).db | |
239 | } | |
240 | ||
241 | func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { | |
242 | c := s.clone() | |
243 | if result := c.First(out, where...); result.Error != nil { | |
244 | if !result.RecordNotFound() { | |
245 | return result | |
246 | } | |
247 | c.NewScope(out).inlineCondition(where...).initialize() | |
248 | } else { | |
249 | c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs), false) | |
250 | } | |
251 | return c | |
252 | } | |
253 | ||
254 | func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { | |
255 | c := s.clone() | |
256 | if result := c.First(out, where...); result.Error != nil { | |
257 | if !result.RecordNotFound() { | |
258 | return result | |
259 | } | |
260 | c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callback.creates).db.Error) | |
261 | } else if len(c.search.assignAttrs) > 0 { | |
262 | c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callback.updates).db.Error) | |
263 | } | |
264 | return c | |
265 | } | |
266 | ||
267 | func (s *DB) Update(attrs ...interface{}) *DB { | |
268 | return s.Updates(toSearchableMap(attrs...), true) | |
269 | } | |
270 | ||
271 | func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { | |
272 | return s.clone().NewScope(s.Value). | |
273 | Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). | |
274 | InstanceSet("gorm:update_interface", values). | |
275 | callCallbacks(s.parent.callback.updates).db | |
276 | } | |
277 | ||
278 | func (s *DB) UpdateColumn(attrs ...interface{}) *DB { | |
279 | return s.UpdateColumns(toSearchableMap(attrs...)) | |
280 | } | |
281 | ||
282 | func (s *DB) UpdateColumns(values interface{}) *DB { | |
283 | return s.clone().NewScope(s.Value). | |
284 | Set("gorm:update_column", true). | |
285 | Set("gorm:save_associations", false). | |
286 | InstanceSet("gorm:update_interface", values). | |
287 | callCallbacks(s.parent.callback.updates).db | |
288 | } | |
289 | ||
290 | func (s *DB) Save(value interface{}) *DB { | |
291 | scope := s.clone().NewScope(value) | |
292 | if scope.PrimaryKeyZero() { | |
293 | return scope.callCallbacks(s.parent.callback.creates).db | |
294 | } | |
295 | return scope.callCallbacks(s.parent.callback.updates).db | |
296 | } | |
297 | ||
298 | func (s *DB) Create(value interface{}) *DB { | |
299 | scope := s.clone().NewScope(value) | |
300 | return scope.callCallbacks(s.parent.callback.creates).db | |
301 | } | |
302 | ||
303 | func (s *DB) Delete(value interface{}, where ...interface{}) *DB { | |
304 | return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callback.deletes).db | |
305 | } | |
306 | ||
307 | func (s *DB) Raw(sql string, values ...interface{}) *DB { | |
308 | return s.clone().search.Raw(true).Where(sql, values...).db | |
309 | } | |
310 | ||
311 | func (s *DB) Exec(sql string, values ...interface{}) *DB { | |
312 | scope := s.clone().NewScope(nil) | |
313 | generatedSql := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) | |
314 | generatedSql = strings.TrimSuffix(strings.TrimPrefix(generatedSql, "("), ")") | |
315 | scope.Raw(generatedSql) | |
316 | return scope.Exec().db | |
317 | } | |
318 | ||
319 | func (s *DB) Model(value interface{}) *DB { | |
320 | c := s.clone() | |
321 | c.Value = value | |
322 | return c | |
323 | } | |
324 | ||
325 | func (s *DB) Table(name string) *DB { | |
326 | clone := s.clone() | |
327 | clone.search.Table(name) | |
328 | clone.Value = nil | |
329 | return clone | |
330 | } | |
331 | ||
332 | func (s *DB) Debug() *DB { | |
333 | return s.clone().LogMode(true) | |
334 | } | |
335 | ||
336 | func (s *DB) Begin() *DB { | |
337 | c := s.clone() | |
338 | if db, ok := c.db.(sqlDb); ok { | |
339 | tx, err := db.Begin() | |
340 | c.db = interface{}(tx).(sqlCommon) | |
341 | c.AddError(err) | |
342 | } else { | |
343 | c.AddError(CantStartTransaction) | |
344 | } | |
345 | return c | |
346 | } | |
347 | ||
348 | func (s *DB) Commit() *DB { | |
349 | if db, ok := s.db.(sqlTx); ok { | |
350 | s.AddError(db.Commit()) | |
351 | } else { | |
352 | s.AddError(NoValidTransaction) | |
353 | } | |
354 | return s | |
355 | } | |
356 | ||
357 | func (s *DB) Rollback() *DB { | |
358 | if db, ok := s.db.(sqlTx); ok { | |
359 | s.AddError(db.Rollback()) | |
360 | } else { | |
361 | s.AddError(NoValidTransaction) | |
362 | } | |
363 | return s | |
364 | } | |
365 | ||
366 | func (s *DB) NewRecord(value interface{}) bool { | |
367 | return s.clone().NewScope(value).PrimaryKeyZero() | |
368 | } | |
369 | ||
370 | func (s *DB) RecordNotFound() bool { | |
371 | return s.Error == RecordNotFound | |
372 | } | |
373 | ||
374 | // Migrations | |
375 | func (s *DB) CreateTable(values ...interface{}) *DB { | |
376 | db := s.clone() | |
377 | for _, value := range values { | |
378 | db = db.NewScope(value).createTable().db | |
379 | } | |
380 | return db | |
381 | } | |
382 | ||
383 | func (s *DB) DropTable(values ...interface{}) *DB { | |
384 | db := s.clone() | |
385 | for _, value := range values { | |
386 | db = db.NewScope(value).dropTable().db | |
387 | } | |
388 | return db | |
389 | } | |
390 | ||
391 | func (s *DB) DropTableIfExists(values ...interface{}) *DB { | |
392 | db := s.clone() | |
393 | for _, value := range values { | |
394 | db = db.NewScope(value).dropTableIfExists().db | |
395 | } | |
396 | return db | |
397 | } | |
398 | ||
399 | func (s *DB) HasTable(value interface{}) bool { | |
400 | scope := s.clone().NewScope(value) | |
401 | tableName := scope.TableName() | |
402 | has := scope.Dialect().HasTable(scope, tableName) | |
403 | s.AddError(scope.db.Error) | |
404 | return has | |
405 | } | |
406 | ||
407 | func (s *DB) AutoMigrate(values ...interface{}) *DB { | |
408 | db := s.clone() | |
409 | for _, value := range values { | |
410 | db = db.NewScope(value).NeedPtr().autoMigrate().db | |
411 | } | |
412 | return db | |
413 | } | |
414 | ||
415 | func (s *DB) ModifyColumn(column string, typ string) *DB { | |
416 | scope := s.clone().NewScope(s.Value) | |
417 | scope.modifyColumn(column, typ) | |
418 | return scope.db | |
419 | } | |
420 | ||
421 | func (s *DB) DropColumn(column string) *DB { | |
422 | scope := s.clone().NewScope(s.Value) | |
423 | scope.dropColumn(column) | |
424 | return scope.db | |
425 | } | |
426 | ||
427 | func (s *DB) AddIndex(indexName string, column ...string) *DB { | |
428 | scope := s.clone().NewScope(s.Value) | |
429 | scope.addIndex(false, indexName, column...) | |
430 | return scope.db | |
431 | } | |
432 | ||
433 | func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB { | |
434 | scope := s.clone().NewScope(s.Value) | |
435 | scope.addIndex(true, indexName, column...) | |
436 | return scope.db | |
437 | } | |
438 | ||
439 | func (s *DB) RemoveIndex(indexName string) *DB { | |
440 | scope := s.clone().NewScope(s.Value) | |
441 | scope.removeIndex(indexName) | |
442 | return scope.db | |
443 | } | |
444 | ||
445 | func (s *DB) CurrentDatabase() string { | |
446 | var ( | |
447 | scope = s.clone().NewScope(s.Value) | |
448 | name = s.dialect.CurrentDatabase(scope) | |
449 | ) | |
450 | return name | |
451 | } | |
452 | ||
453 | /* | |
454 | Add foreign key to the given scope | |
455 | ||
456 | Example: | |
457 | db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") | |
458 | */ | |
459 | func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { | |
460 | scope := s.clone().NewScope(s.Value) | |
461 | scope.addForeignKey(field, dest, onDelete, onUpdate) | |
462 | return scope.db | |
463 | } | |
464 | ||
465 | func (s *DB) Association(column string) *Association { | |
466 | var err error | |
467 | scope := s.clone().NewScope(s.Value) | |
468 | ||
469 | if primaryField := scope.PrimaryField(); primaryField.IsBlank { | |
470 | err = errors.New("primary key can't be nil") | |
471 | } else { | |
472 | if field, ok := scope.FieldByName(column); ok { | |
473 | if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { | |
474 | err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) | |
475 | } else { | |
476 | return &Association{Scope: scope, Column: column, Field: field} | |
477 | } | |
478 | } else { | |
479 | err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) | |
480 | } | |
481 | } | |
482 | ||
483 | return &Association{Error: err} | |
484 | } | |
485 | ||
486 | func (s *DB) Preload(column string, conditions ...interface{}) *DB { | |
487 | return s.clone().search.Preload(column, conditions...).db | |
488 | } | |
489 | ||
490 | // Set set value by name | |
491 | func (s *DB) Set(name string, value interface{}) *DB { | |
492 | return s.clone().InstantSet(name, value) | |
493 | } | |
494 | ||
495 | func (s *DB) InstantSet(name string, value interface{}) *DB { | |
496 | s.values[name] = value | |
497 | return s | |
498 | } | |
499 | ||
500 | // Get get value by name | |
501 | func (s *DB) Get(name string) (value interface{}, ok bool) { | |
502 | value, ok = s.values[name] | |
503 | return | |
504 | } | |
505 | ||
506 | func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { | |
507 | scope := s.NewScope(source) | |
508 | for _, field := range scope.GetModelStruct().StructFields { | |
509 | if field.Name == column || field.DBName == column { | |
510 | if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" { | |
511 | source := (&Scope{Value: source}).GetModelStruct().ModelType | |
512 | destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType | |
513 | handler.Setup(field.Relationship, many2many, source, destination) | |
514 | field.Relationship.JoinTableHandler = handler | |
515 | if table := handler.Table(s); scope.Dialect().HasTable(scope, table) { | |
516 | s.Table(table).AutoMigrate(handler) | |
517 | } | |
518 | } | |
519 | } | |
520 | } | |
521 | } | |
522 | ||
523 | func (s *DB) AddError(err error) error { | |
524 | if err != nil { | |
525 | if err != RecordNotFound { | |
526 | if s.logMode == 0 { | |
527 | go s.print(fileWithLineNum(), err) | |
528 | } else { | |
529 | s.log(err) | |
530 | } | |
531 | ||
532 | errors := Errors{errors: s.GetErrors()} | |
533 | errors.Add(err) | |
534 | if len(errors.GetErrors()) > 1 { | |
535 | err = errors | |
536 | } | |
537 | } | |
538 | ||
539 | s.Error = err | |
540 | } | |
541 | return err | |
542 | } | |
543 | ||
544 | func (s *DB) GetErrors() (errors []error) { | |
545 | if errs, ok := s.Error.(errorsInterface); ok { | |
546 | return errs.GetErrors() | |
547 | } else if s.Error != nil { | |
548 | return []error{s.Error} | |
549 | } | |
550 | return | |
551 | } |
0 | package gorm | |
1 | ||
2 | import "time" | |
3 | ||
4 | func (s *DB) clone() *DB { | |
5 | db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} | |
6 | ||
7 | for key, value := range s.values { | |
8 | db.values[key] = value | |
9 | } | |
10 | ||
11 | if s.search == nil { | |
12 | db.search = &search{} | |
13 | } else { | |
14 | db.search = s.search.clone() | |
15 | } | |
16 | ||
17 | db.search.db = &db | |
18 | return &db | |
19 | } | |
20 | ||
21 | func (s *DB) print(v ...interface{}) { | |
22 | s.logger.(logger).Print(v...) | |
23 | } | |
24 | ||
25 | func (s *DB) log(v ...interface{}) { | |
26 | if s != nil && s.logMode == 2 { | |
27 | s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) | |
28 | } | |
29 | } | |
30 | ||
31 | func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { | |
32 | if s.logMode == 2 { | |
33 | s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars) | |
34 | } | |
35 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "database/sql" | |
4 | "database/sql/driver" | |
5 | "fmt" | |
6 | "strconv" | |
7 | ||
8 | _ "github.com/denisenkom/go-mssqldb" | |
9 | testdb "github.com/erikstmartin/go-testdb" | |
10 | _ "github.com/go-sql-driver/mysql" | |
11 | "github.com/jinzhu/gorm" | |
12 | "github.com/jinzhu/now" | |
13 | _ "github.com/lib/pq" | |
14 | _ "github.com/mattn/go-sqlite3" | |
15 | ||
16 | "os" | |
17 | "testing" | |
18 | "time" | |
19 | ) | |
20 | ||
21 | var ( | |
22 | DB gorm.DB | |
23 | t1, t2, t3, t4, t5 time.Time | |
24 | ) | |
25 | ||
26 | func init() { | |
27 | var err error | |
28 | ||
29 | if DB, err = OpenTestConnection(); err != nil { | |
30 | panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err)) | |
31 | } | |
32 | ||
33 | // DB.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) | |
34 | // DB.SetLogger(log.New(os.Stdout, "\r\n", 0)) | |
35 | // DB.LogMode(true) | |
36 | DB.LogMode(false) | |
37 | ||
38 | DB.DB().SetMaxIdleConns(10) | |
39 | ||
40 | runMigration() | |
41 | } | |
42 | ||
43 | func OpenTestConnection() (db gorm.DB, err error) { | |
44 | switch os.Getenv("GORM_DIALECT") { | |
45 | case "mysql": | |
46 | // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; | |
47 | // CREATE DATABASE gorm; | |
48 | // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; | |
49 | fmt.Println("testing mysql...") | |
50 | db, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") | |
51 | case "postgres": | |
52 | fmt.Println("testing postgres...") | |
53 | db, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable") | |
54 | case "foundation": | |
55 | fmt.Println("testing foundation...") | |
56 | db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") | |
57 | case "mssql": | |
58 | fmt.Println("testing mssql...") | |
59 | db, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433") | |
60 | default: | |
61 | fmt.Println("testing sqlite3...") | |
62 | db, err = gorm.Open("sqlite3", "/tmp/gorm.db") | |
63 | } | |
64 | return | |
65 | } | |
66 | ||
67 | func TestStringPrimaryKey(t *testing.T) { | |
68 | type UUIDStruct struct { | |
69 | ID string `gorm:"primary_key"` | |
70 | Name string | |
71 | } | |
72 | DB.AutoMigrate(&UUIDStruct{}) | |
73 | ||
74 | data := UUIDStruct{ID: "uuid", Name: "hello"} | |
75 | if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" { | |
76 | t.Errorf("string primary key should not be populated") | |
77 | } | |
78 | } | |
79 | ||
80 | func TestExceptionsWithInvalidSql(t *testing.T) { | |
81 | var columns []string | |
82 | if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { | |
83 | t.Errorf("Should got error with invalid SQL") | |
84 | } | |
85 | ||
86 | if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { | |
87 | t.Errorf("Should got error with invalid SQL") | |
88 | } | |
89 | ||
90 | if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil { | |
91 | t.Errorf("Should got error with invalid SQL") | |
92 | } | |
93 | ||
94 | var count1, count2 int64 | |
95 | DB.Model(&User{}).Count(&count1) | |
96 | if count1 <= 0 { | |
97 | t.Errorf("Should find some users") | |
98 | } | |
99 | ||
100 | if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil { | |
101 | t.Errorf("Should got error with invalid SQL") | |
102 | } | |
103 | ||
104 | DB.Model(&User{}).Count(&count2) | |
105 | if count1 != count2 { | |
106 | t.Errorf("No user should not be deleted by invalid SQL") | |
107 | } | |
108 | } | |
109 | ||
110 | func TestSetTable(t *testing.T) { | |
111 | DB.Create(getPreparedUser("pluck_user1", "pluck_user")) | |
112 | DB.Create(getPreparedUser("pluck_user2", "pluck_user")) | |
113 | DB.Create(getPreparedUser("pluck_user3", "pluck_user")) | |
114 | ||
115 | if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil { | |
116 | t.Errorf("No errors should happen if set table for pluck", err.Error()) | |
117 | } | |
118 | ||
119 | var users []User | |
120 | if DB.Table("users").Find(&[]User{}).Error != nil { | |
121 | t.Errorf("No errors should happen if set table for find") | |
122 | } | |
123 | ||
124 | if DB.Table("invalid_table").Find(&users).Error == nil { | |
125 | t.Errorf("Should got error when table is set to an invalid table") | |
126 | } | |
127 | ||
128 | DB.Exec("drop table deleted_users;") | |
129 | if DB.Table("deleted_users").CreateTable(&User{}).Error != nil { | |
130 | t.Errorf("Create table with specified table") | |
131 | } | |
132 | ||
133 | DB.Table("deleted_users").Save(&User{Name: "DeletedUser"}) | |
134 | ||
135 | var deletedUsers []User | |
136 | DB.Table("deleted_users").Find(&deletedUsers) | |
137 | if len(deletedUsers) != 1 { | |
138 | t.Errorf("Query from specified table") | |
139 | } | |
140 | ||
141 | DB.Save(getPreparedUser("normal_user", "reset_table")) | |
142 | DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table")) | |
143 | var user1, user2, user3 User | |
144 | DB.Where("role = ?", "reset_table").First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3) | |
145 | if (user1.Name != "normal_user") || (user2.Name != "deleted_user") || (user3.Name != "normal_user") { | |
146 | t.Errorf("unset specified table with blank string") | |
147 | } | |
148 | } | |
149 | ||
150 | type Order struct { | |
151 | } | |
152 | ||
153 | type Cart struct { | |
154 | } | |
155 | ||
156 | func (c Cart) TableName() string { | |
157 | return "shopping_cart" | |
158 | } | |
159 | ||
160 | func TestHasTable(t *testing.T) { | |
161 | type Foo struct { | |
162 | Id int | |
163 | Stuff string | |
164 | } | |
165 | DB.DropTable(&Foo{}) | |
166 | if ok := DB.HasTable(&Foo{}); ok { | |
167 | t.Errorf("Table should not exist, but does") | |
168 | } | |
169 | if err := DB.CreateTable(&Foo{}).Error; err != nil { | |
170 | t.Errorf("Table should be created") | |
171 | } | |
172 | if ok := DB.HasTable(&Foo{}); !ok { | |
173 | t.Errorf("Table should exist, but HasTable informs it does not") | |
174 | } | |
175 | } | |
176 | ||
177 | func TestTableName(t *testing.T) { | |
178 | DB := DB.Model("") | |
179 | if DB.NewScope(Order{}).TableName() != "orders" { | |
180 | t.Errorf("Order's table name should be orders") | |
181 | } | |
182 | ||
183 | if DB.NewScope(&Order{}).TableName() != "orders" { | |
184 | t.Errorf("&Order's table name should be orders") | |
185 | } | |
186 | ||
187 | if DB.NewScope([]Order{}).TableName() != "orders" { | |
188 | t.Errorf("[]Order's table name should be orders") | |
189 | } | |
190 | ||
191 | if DB.NewScope(&[]Order{}).TableName() != "orders" { | |
192 | t.Errorf("&[]Order's table name should be orders") | |
193 | } | |
194 | ||
195 | DB.SingularTable(true) | |
196 | if DB.NewScope(Order{}).TableName() != "order" { | |
197 | t.Errorf("Order's singular table name should be order") | |
198 | } | |
199 | ||
200 | if DB.NewScope(&Order{}).TableName() != "order" { | |
201 | t.Errorf("&Order's singular table name should be order") | |
202 | } | |
203 | ||
204 | if DB.NewScope([]Order{}).TableName() != "order" { | |
205 | t.Errorf("[]Order's singular table name should be order") | |
206 | } | |
207 | ||
208 | if DB.NewScope(&[]Order{}).TableName() != "order" { | |
209 | t.Errorf("&[]Order's singular table name should be order") | |
210 | } | |
211 | ||
212 | if DB.NewScope(&Cart{}).TableName() != "shopping_cart" { | |
213 | t.Errorf("&Cart's singular table name should be shopping_cart") | |
214 | } | |
215 | ||
216 | if DB.NewScope(Cart{}).TableName() != "shopping_cart" { | |
217 | t.Errorf("Cart's singular table name should be shopping_cart") | |
218 | } | |
219 | ||
220 | if DB.NewScope(&[]Cart{}).TableName() != "shopping_cart" { | |
221 | t.Errorf("&[]Cart's singular table name should be shopping_cart") | |
222 | } | |
223 | ||
224 | if DB.NewScope([]Cart{}).TableName() != "shopping_cart" { | |
225 | t.Errorf("[]Cart's singular table name should be shopping_cart") | |
226 | } | |
227 | DB.SingularTable(false) | |
228 | } | |
229 | ||
230 | func TestSqlNullValue(t *testing.T) { | |
231 | DB.DropTable(&NullValue{}) | |
232 | DB.AutoMigrate(&NullValue{}) | |
233 | ||
234 | if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello", Valid: true}, | |
235 | Age: sql.NullInt64{Int64: 18, Valid: true}, | |
236 | Male: sql.NullBool{Bool: true, Valid: true}, | |
237 | Height: sql.NullFloat64{Float64: 100.11, Valid: true}, | |
238 | AddedAt: NullTime{Time: time.Now(), Valid: true}, | |
239 | }).Error; err != nil { | |
240 | t.Errorf("Not error should raise when test null value") | |
241 | } | |
242 | ||
243 | var nv NullValue | |
244 | DB.First(&nv, "name = ?", "hello") | |
245 | ||
246 | if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true { | |
247 | t.Errorf("Should be able to fetch null value") | |
248 | } | |
249 | ||
250 | if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello-2", Valid: true}, | |
251 | Age: sql.NullInt64{Int64: 18, Valid: false}, | |
252 | Male: sql.NullBool{Bool: true, Valid: true}, | |
253 | Height: sql.NullFloat64{Float64: 100.11, Valid: true}, | |
254 | AddedAt: NullTime{Time: time.Now(), Valid: false}, | |
255 | }).Error; err != nil { | |
256 | t.Errorf("Not error should raise when test null value") | |
257 | } | |
258 | ||
259 | var nv2 NullValue | |
260 | DB.First(&nv2, "name = ?", "hello-2") | |
261 | if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { | |
262 | t.Errorf("Should be able to fetch null value") | |
263 | } | |
264 | ||
265 | if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello-3", Valid: false}, | |
266 | Age: sql.NullInt64{Int64: 18, Valid: false}, | |
267 | Male: sql.NullBool{Bool: true, Valid: true}, | |
268 | Height: sql.NullFloat64{Float64: 100.11, Valid: true}, | |
269 | AddedAt: NullTime{Time: time.Now(), Valid: false}, | |
270 | }).Error; err == nil { | |
271 | t.Errorf("Can't save because of name can't be null") | |
272 | } | |
273 | } | |
274 | ||
275 | func TestTransaction(t *testing.T) { | |
276 | tx := DB.Begin() | |
277 | u := User{Name: "transcation"} | |
278 | if err := tx.Save(&u).Error; err != nil { | |
279 | t.Errorf("No error should raise") | |
280 | } | |
281 | ||
282 | if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { | |
283 | t.Errorf("Should find saved record") | |
284 | } | |
285 | ||
286 | if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil { | |
287 | t.Errorf("Should return the underlying sql.Tx") | |
288 | } | |
289 | ||
290 | tx.Rollback() | |
291 | ||
292 | if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { | |
293 | t.Errorf("Should not find record after rollback") | |
294 | } | |
295 | ||
296 | tx2 := DB.Begin() | |
297 | u2 := User{Name: "transcation-2"} | |
298 | if err := tx2.Save(&u2).Error; err != nil { | |
299 | t.Errorf("No error should raise") | |
300 | } | |
301 | ||
302 | if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { | |
303 | t.Errorf("Should find saved record") | |
304 | } | |
305 | ||
306 | tx2.Commit() | |
307 | ||
308 | if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { | |
309 | t.Errorf("Should be able to find committed record") | |
310 | } | |
311 | } | |
312 | ||
313 | func TestRow(t *testing.T) { | |
314 | user1 := User{Name: "RowUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} | |
315 | user2 := User{Name: "RowUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} | |
316 | user3 := User{Name: "RowUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} | |
317 | DB.Save(&user1).Save(&user2).Save(&user3) | |
318 | ||
319 | row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row() | |
320 | var age int64 | |
321 | row.Scan(&age) | |
322 | if age != 10 { | |
323 | t.Errorf("Scan with Row") | |
324 | } | |
325 | } | |
326 | ||
327 | func TestRows(t *testing.T) { | |
328 | user1 := User{Name: "RowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} | |
329 | user2 := User{Name: "RowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} | |
330 | user3 := User{Name: "RowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} | |
331 | DB.Save(&user1).Save(&user2).Save(&user3) | |
332 | ||
333 | rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() | |
334 | if err != nil { | |
335 | t.Errorf("Not error should happen, but got") | |
336 | } | |
337 | ||
338 | count := 0 | |
339 | for rows.Next() { | |
340 | var name string | |
341 | var age int64 | |
342 | rows.Scan(&name, &age) | |
343 | count++ | |
344 | } | |
345 | if count != 2 { | |
346 | t.Errorf("Should found two records with name 3") | |
347 | } | |
348 | } | |
349 | ||
350 | func TestScan(t *testing.T) { | |
351 | user1 := User{Name: "ScanUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} | |
352 | user2 := User{Name: "ScanUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} | |
353 | user3 := User{Name: "ScanUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} | |
354 | DB.Save(&user1).Save(&user2).Save(&user3) | |
355 | ||
356 | type result struct { | |
357 | Name string | |
358 | Age int | |
359 | } | |
360 | ||
361 | var res result | |
362 | DB.Table("users").Select("name, age").Where("name = ?", user3.Name).Scan(&res) | |
363 | if res.Name != user3.Name { | |
364 | t.Errorf("Scan into struct should work") | |
365 | } | |
366 | ||
367 | var doubleAgeRes result | |
368 | DB.Table("users").Select("age + age as age").Where("name = ?", user3.Name).Scan(&doubleAgeRes) | |
369 | if doubleAgeRes.Age != res.Age*2 { | |
370 | t.Errorf("Scan double age as age") | |
371 | } | |
372 | ||
373 | var ress []result | |
374 | DB.Table("users").Select("name, age").Where("name in (?)", []string{user2.Name, user3.Name}).Scan(&ress) | |
375 | if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { | |
376 | t.Errorf("Scan into struct map") | |
377 | } | |
378 | } | |
379 | ||
380 | func TestRaw(t *testing.T) { | |
381 | user1 := User{Name: "ExecRawSqlUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} | |
382 | user2 := User{Name: "ExecRawSqlUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} | |
383 | user3 := User{Name: "ExecRawSqlUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} | |
384 | DB.Save(&user1).Save(&user2).Save(&user3) | |
385 | ||
386 | type result struct { | |
387 | Name string | |
388 | Email string | |
389 | } | |
390 | ||
391 | var ress []result | |
392 | DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&ress) | |
393 | if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { | |
394 | t.Errorf("Raw with scan") | |
395 | } | |
396 | ||
397 | rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows() | |
398 | count := 0 | |
399 | for rows.Next() { | |
400 | count++ | |
401 | } | |
402 | if count != 1 { | |
403 | t.Errorf("Raw with Rows should find one record with name 3") | |
404 | } | |
405 | ||
406 | DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) | |
407 | if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound { | |
408 | t.Error("Raw sql to update records") | |
409 | } | |
410 | } | |
411 | ||
412 | func TestGroup(t *testing.T) { | |
413 | rows, err := DB.Select("name").Table("users").Group("name").Rows() | |
414 | ||
415 | if err == nil { | |
416 | defer rows.Close() | |
417 | for rows.Next() { | |
418 | var name string | |
419 | rows.Scan(&name) | |
420 | } | |
421 | } else { | |
422 | t.Errorf("Should not raise any error") | |
423 | } | |
424 | } | |
425 | ||
426 | func TestJoins(t *testing.T) { | |
427 | var user = User{ | |
428 | Name: "joins", | |
429 | Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, | |
430 | } | |
431 | DB.Save(&user) | |
432 | ||
433 | var result User | |
434 | DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result) | |
435 | if result.Name != "joins" || result.Id != user.Id { | |
436 | t.Errorf("Should find all two emails with Join") | |
437 | } | |
438 | } | |
439 | ||
440 | func TestJoinsWithSelect(t *testing.T) { | |
441 | type result struct { | |
442 | Name string | |
443 | Email string | |
444 | } | |
445 | ||
446 | user := User{ | |
447 | Name: "joins_with_select", | |
448 | Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, | |
449 | } | |
450 | DB.Save(&user) | |
451 | ||
452 | var results []result | |
453 | DB.Table("users").Select("name, email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results) | |
454 | if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" { | |
455 | t.Errorf("Should find all two emails with Join select") | |
456 | } | |
457 | } | |
458 | ||
459 | func TestHaving(t *testing.T) { | |
460 | rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows() | |
461 | ||
462 | if err == nil { | |
463 | defer rows.Close() | |
464 | for rows.Next() { | |
465 | var name string | |
466 | var total int64 | |
467 | rows.Scan(&name, &total) | |
468 | ||
469 | if name == "2" && total != 1 { | |
470 | t.Errorf("Should have one user having name 2") | |
471 | } | |
472 | if name == "3" && total != 2 { | |
473 | t.Errorf("Should have two users having name 3") | |
474 | } | |
475 | } | |
476 | } else { | |
477 | t.Errorf("Should not raise any error") | |
478 | } | |
479 | } | |
480 | ||
481 | func DialectHasTzSupport() bool { | |
482 | // NB: mssql and FoundationDB do not support time zones. | |
483 | if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" { | |
484 | return false | |
485 | } | |
486 | return true | |
487 | } | |
488 | ||
489 | func TestTimeWithZone(t *testing.T) { | |
490 | var format = "2006-01-02 15:04:05 -0700" | |
491 | var times []time.Time | |
492 | GMT8, _ := time.LoadLocation("Asia/Shanghai") | |
493 | times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8)) | |
494 | times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC)) | |
495 | ||
496 | for index, vtime := range times { | |
497 | name := "time_with_zone_" + strconv.Itoa(index) | |
498 | user := User{Name: name, Birthday: vtime} | |
499 | ||
500 | if !DialectHasTzSupport() { | |
501 | // If our driver dialect doesn't support TZ's, just use UTC for everything here. | |
502 | user.Birthday = vtime.UTC() | |
503 | } | |
504 | ||
505 | DB.Save(&user) | |
506 | expectedBirthday := "2013-02-18 17:51:49 +0000" | |
507 | foundBirthday := user.Birthday.UTC().Format(format) | |
508 | if foundBirthday != expectedBirthday { | |
509 | t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) | |
510 | } | |
511 | ||
512 | var findUser, findUser2, findUser3 User | |
513 | DB.First(&findUser, "name = ?", name) | |
514 | foundBirthday = findUser.Birthday.UTC().Format(format) | |
515 | if foundBirthday != expectedBirthday { | |
516 | t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v or %+v", name, expectedBirthday, foundBirthday) | |
517 | } | |
518 | ||
519 | if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() { | |
520 | t.Errorf("User should be found") | |
521 | } | |
522 | ||
523 | if !DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() { | |
524 | t.Errorf("User should not be found") | |
525 | } | |
526 | } | |
527 | } | |
528 | ||
529 | func TestHstore(t *testing.T) { | |
530 | type Details struct { | |
531 | Id int64 | |
532 | Bulk gorm.Hstore | |
533 | } | |
534 | ||
535 | if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { | |
536 | t.Skip() | |
537 | } | |
538 | ||
539 | if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil { | |
540 | fmt.Println("\033[31mHINT: Must be superuser to create hstore extension (ALTER USER gorm WITH SUPERUSER;)\033[0m") | |
541 | panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err)) | |
542 | } | |
543 | ||
544 | DB.Exec("drop table details") | |
545 | ||
546 | if err := DB.CreateTable(&Details{}).Error; err != nil { | |
547 | panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) | |
548 | } | |
549 | ||
550 | bankAccountId, phoneNumber, opinion := "123456", "14151321232", "sharkbait" | |
551 | bulk := map[string]*string{ | |
552 | "bankAccountId": &bankAccountId, | |
553 | "phoneNumber": &phoneNumber, | |
554 | "opinion": &opinion, | |
555 | } | |
556 | d := Details{Bulk: bulk} | |
557 | DB.Save(&d) | |
558 | ||
559 | var d2 Details | |
560 | if err := DB.First(&d2).Error; err != nil { | |
561 | t.Errorf("Got error when tried to fetch details: %+v", err) | |
562 | } | |
563 | ||
564 | for k := range bulk { | |
565 | if r, ok := d2.Bulk[k]; ok { | |
566 | if res, _ := bulk[k]; *res != *r { | |
567 | t.Errorf("Details should be equal") | |
568 | } | |
569 | } else { | |
570 | t.Errorf("Details should be existed") | |
571 | } | |
572 | } | |
573 | } | |
574 | ||
575 | func TestSetAndGet(t *testing.T) { | |
576 | if value, ok := DB.Set("hello", "world").Get("hello"); !ok { | |
577 | t.Errorf("Should be able to get setting after set") | |
578 | } else { | |
579 | if value.(string) != "world" { | |
580 | t.Errorf("Setted value should not be changed") | |
581 | } | |
582 | } | |
583 | ||
584 | if _, ok := DB.Get("non_existing"); ok { | |
585 | t.Errorf("Get non existing key should return error") | |
586 | } | |
587 | } | |
588 | ||
589 | func TestCompatibilityMode(t *testing.T) { | |
590 | DB, _ := gorm.Open("testdb", "") | |
591 | testdb.SetQueryFunc(func(query string) (driver.Rows, error) { | |
592 | columns := []string{"id", "name", "age"} | |
593 | result := ` | |
594 | 1,Tim,20 | |
595 | 2,Joe,25 | |
596 | 3,Bob,30 | |
597 | ` | |
598 | return testdb.RowsFromCSVString(columns, result), nil | |
599 | }) | |
600 | ||
601 | var users []User | |
602 | DB.Find(&users) | |
603 | if (users[0].Name != "Tim") || len(users) != 3 { | |
604 | t.Errorf("Unexcepted result returned") | |
605 | } | |
606 | } | |
607 | ||
608 | func TestOpenExistingDB(t *testing.T) { | |
609 | DB.Save(&User{Name: "jnfeinstein"}) | |
610 | dialect := os.Getenv("GORM_DIALECT") | |
611 | ||
612 | db, err := gorm.Open(dialect, DB.DB()) | |
613 | if err != nil { | |
614 | t.Errorf("Should have wrapped the existing DB connection") | |
615 | } | |
616 | ||
617 | var user User | |
618 | if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound { | |
619 | t.Errorf("Should have found existing record") | |
620 | } | |
621 | } | |
622 | ||
623 | func BenchmarkGorm(b *testing.B) { | |
624 | b.N = 2000 | |
625 | for x := 0; x < b.N; x++ { | |
626 | e := strconv.Itoa(x) + "benchmark@example.org" | |
627 | email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} | |
628 | // Insert | |
629 | DB.Save(&email) | |
630 | // Query | |
631 | DB.First(&BigEmail{}, "email = ?", e) | |
632 | // Update | |
633 | DB.Model(&email).UpdateColumn("email", "new-"+e) | |
634 | // Delete | |
635 | DB.Delete(&email) | |
636 | } | |
637 | } | |
638 | ||
639 | func BenchmarkRawSql(b *testing.B) { | |
640 | DB, _ := sql.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable") | |
641 | DB.SetMaxIdleConns(10) | |
642 | insertSql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id" | |
643 | querySql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1" | |
644 | updateSql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3" | |
645 | deleteSql := "DELETE FROM orders WHERE id = $1" | |
646 | ||
647 | b.N = 2000 | |
648 | for x := 0; x < b.N; x++ { | |
649 | var id int64 | |
650 | e := strconv.Itoa(x) + "benchmark@example.org" | |
651 | email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} | |
652 | // Insert | |
653 | DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) | |
654 | // Query | |
655 | rows, _ := DB.Query(querySql, email.Email) | |
656 | rows.Close() | |
657 | // Update | |
658 | DB.Exec(updateSql, "new-"+e, time.Now(), id) | |
659 | // Delete | |
660 | DB.Exec(deleteSql, id) | |
661 | } | |
662 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "testing" | |
5 | "time" | |
6 | ) | |
7 | ||
8 | func runMigration() { | |
9 | if err := DB.DropTableIfExists(&User{}).Error; err != nil { | |
10 | fmt.Printf("Got error when try to delete table users, %+v\n", err) | |
11 | } | |
12 | ||
13 | for _, table := range []string{"animals", "user_languages"} { | |
14 | DB.Exec(fmt.Sprintf("drop table %v;", table)) | |
15 | } | |
16 | ||
17 | values := []interface{}{&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}} | |
18 | for _, value := range values { | |
19 | DB.DropTable(value) | |
20 | } | |
21 | ||
22 | if err := DB.AutoMigrate(values...).Error; err != nil { | |
23 | panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) | |
24 | } | |
25 | } | |
26 | ||
27 | func TestIndexes(t *testing.T) { | |
28 | if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil { | |
29 | t.Errorf("Got error when tried to create index: %+v", err) | |
30 | } | |
31 | ||
32 | scope := DB.NewScope(&Email{}) | |
33 | if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") { | |
34 | t.Errorf("Email should have index idx_email_email") | |
35 | } | |
36 | ||
37 | if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil { | |
38 | t.Errorf("Got error when tried to remove index: %+v", err) | |
39 | } | |
40 | ||
41 | if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") { | |
42 | t.Errorf("Email's index idx_email_email should be deleted") | |
43 | } | |
44 | ||
45 | if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { | |
46 | t.Errorf("Got error when tried to create index: %+v", err) | |
47 | } | |
48 | ||
49 | if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { | |
50 | t.Errorf("Email should have index idx_email_email_and_user_id") | |
51 | } | |
52 | ||
53 | if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { | |
54 | t.Errorf("Got error when tried to remove index: %+v", err) | |
55 | } | |
56 | ||
57 | if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { | |
58 | t.Errorf("Email's index idx_email_email_and_user_id should be deleted") | |
59 | } | |
60 | ||
61 | if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { | |
62 | t.Errorf("Got error when tried to create index: %+v", err) | |
63 | } | |
64 | ||
65 | if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { | |
66 | t.Errorf("Email should have index idx_email_email_and_user_id") | |
67 | } | |
68 | ||
69 | if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil { | |
70 | t.Errorf("Should get to create duplicate record when having unique index") | |
71 | } | |
72 | ||
73 | if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { | |
74 | t.Errorf("Got error when tried to remove index: %+v", err) | |
75 | } | |
76 | ||
77 | if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { | |
78 | t.Errorf("Email's index idx_email_email_and_user_id should be deleted") | |
79 | } | |
80 | ||
81 | if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil { | |
82 | t.Errorf("Should be able to create duplicated emails after remove unique index") | |
83 | } | |
84 | } | |
85 | ||
86 | type BigEmail struct { | |
87 | Id int64 | |
88 | UserId int64 | |
89 | Email string `sql:"index:idx_email_agent"` | |
90 | UserAgent string `sql:"index:idx_email_agent"` | |
91 | RegisteredAt time.Time `sql:"unique_index"` | |
92 | CreatedAt time.Time | |
93 | UpdatedAt time.Time | |
94 | } | |
95 | ||
96 | func (b BigEmail) TableName() string { | |
97 | return "emails" | |
98 | } | |
99 | ||
100 | func TestAutoMigration(t *testing.T) { | |
101 | DB.AutoMigrate(&Address{}) | |
102 | if err := DB.Table("emails").AutoMigrate(&BigEmail{}).Error; err != nil { | |
103 | t.Errorf("Auto Migrate should not raise any error") | |
104 | } | |
105 | ||
106 | DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()}) | |
107 | ||
108 | scope := DB.NewScope(&BigEmail{}) | |
109 | if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_agent") { | |
110 | t.Errorf("Failed to create index") | |
111 | } | |
112 | ||
113 | if !scope.Dialect().HasIndex(scope, scope.TableName(), "uix_emails_registered_at") { | |
114 | t.Errorf("Failed to create index") | |
115 | } | |
116 | ||
117 | var bigemail BigEmail | |
118 | DB.First(&bigemail, "user_agent = ?", "pc") | |
119 | if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() { | |
120 | t.Error("Big Emails should be saved and fetched correctly") | |
121 | } | |
122 | } |
0 | package gorm | |
1 | ||
2 | import "time" | |
3 | ||
4 | type Model struct { | |
5 | ID uint `gorm:"primary_key"` | |
6 | CreatedAt time.Time | |
7 | UpdatedAt time.Time | |
8 | DeletedAt *time.Time `sql:"index"` | |
9 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "database/sql" | |
4 | "fmt" | |
5 | "go/ast" | |
6 | "reflect" | |
7 | "strconv" | |
8 | "strings" | |
9 | "sync" | |
10 | "time" | |
11 | ||
12 | "github.com/qor/inflection" | |
13 | ) | |
14 | ||
15 | var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { | |
16 | return defaultTableName | |
17 | } | |
18 | ||
19 | type safeModelStructsMap struct { | |
20 | m map[reflect.Type]*ModelStruct | |
21 | l *sync.RWMutex | |
22 | } | |
23 | ||
24 | func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) { | |
25 | s.l.Lock() | |
26 | defer s.l.Unlock() | |
27 | s.m[key] = value | |
28 | } | |
29 | ||
30 | func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct { | |
31 | s.l.RLock() | |
32 | defer s.l.RUnlock() | |
33 | return s.m[key] | |
34 | } | |
35 | ||
36 | func newModelStructsMap() *safeModelStructsMap { | |
37 | return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)} | |
38 | } | |
39 | ||
40 | var modelStructsMap = newModelStructsMap() | |
41 | ||
42 | type ModelStruct struct { | |
43 | PrimaryFields []*StructField | |
44 | StructFields []*StructField | |
45 | ModelType reflect.Type | |
46 | defaultTableName string | |
47 | cached bool | |
48 | } | |
49 | ||
50 | func (s ModelStruct) TableName(db *DB) string { | |
51 | return DefaultTableNameHandler(db, s.defaultTableName) | |
52 | } | |
53 | ||
54 | type StructField struct { | |
55 | DBName string | |
56 | Name string | |
57 | Names []string | |
58 | IsPrimaryKey bool | |
59 | IsNormal bool | |
60 | IsIgnored bool | |
61 | IsScanner bool | |
62 | HasDefaultValue bool | |
63 | Tag reflect.StructTag | |
64 | Struct reflect.StructField | |
65 | IsForeignKey bool | |
66 | Relationship *Relationship | |
67 | } | |
68 | ||
69 | func (structField *StructField) clone() *StructField { | |
70 | return &StructField{ | |
71 | DBName: structField.DBName, | |
72 | Name: structField.Name, | |
73 | Names: structField.Names, | |
74 | IsPrimaryKey: structField.IsPrimaryKey, | |
75 | IsNormal: structField.IsNormal, | |
76 | IsIgnored: structField.IsIgnored, | |
77 | IsScanner: structField.IsScanner, | |
78 | HasDefaultValue: structField.HasDefaultValue, | |
79 | Tag: structField.Tag, | |
80 | Struct: structField.Struct, | |
81 | IsForeignKey: structField.IsForeignKey, | |
82 | Relationship: structField.Relationship, | |
83 | } | |
84 | } | |
85 | ||
86 | type Relationship struct { | |
87 | Kind string | |
88 | PolymorphicType string | |
89 | PolymorphicDBName string | |
90 | ForeignFieldNames []string | |
91 | ForeignDBNames []string | |
92 | AssociationForeignFieldNames []string | |
93 | AssociationForeignStructFieldNames []string | |
94 | AssociationForeignDBNames []string | |
95 | JoinTableHandler JoinTableHandlerInterface | |
96 | } | |
97 | ||
98 | func (scope *Scope) GetModelStruct() *ModelStruct { | |
99 | var modelStruct ModelStruct | |
100 | ||
101 | reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) | |
102 | if !reflectValue.IsValid() { | |
103 | return &modelStruct | |
104 | } | |
105 | ||
106 | if reflectValue.Kind() == reflect.Slice { | |
107 | reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem())) | |
108 | } | |
109 | ||
110 | scopeType := reflectValue.Type() | |
111 | ||
112 | if scopeType.Kind() == reflect.Ptr { | |
113 | scopeType = scopeType.Elem() | |
114 | } | |
115 | ||
116 | if value := modelStructsMap.Get(scopeType); value != nil { | |
117 | return value | |
118 | } | |
119 | ||
120 | modelStruct.ModelType = scopeType | |
121 | if scopeType.Kind() != reflect.Struct { | |
122 | return &modelStruct | |
123 | } | |
124 | ||
125 | if tabler, ok := reflect.New(scopeType).Interface().(interface { | |
126 | TableName() string | |
127 | }); ok { | |
128 | modelStruct.defaultTableName = tabler.TableName() | |
129 | } else { | |
130 | name := ToDBName(scopeType.Name()) | |
131 | if scope.db == nil || !scope.db.parent.singularTable { | |
132 | name = inflection.Plural(name) | |
133 | } | |
134 | ||
135 | modelStruct.defaultTableName = name | |
136 | } | |
137 | ||
138 | // Get all fields | |
139 | fields := []*StructField{} | |
140 | for i := 0; i < scopeType.NumField(); i++ { | |
141 | if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) { | |
142 | field := &StructField{ | |
143 | Struct: fieldStruct, | |
144 | Name: fieldStruct.Name, | |
145 | Names: []string{fieldStruct.Name}, | |
146 | Tag: fieldStruct.Tag, | |
147 | } | |
148 | ||
149 | if fieldStruct.Tag.Get("sql") == "-" { | |
150 | field.IsIgnored = true | |
151 | } else { | |
152 | sqlSettings := parseTagSetting(field.Tag.Get("sql")) | |
153 | gormSettings := parseTagSetting(field.Tag.Get("gorm")) | |
154 | if _, ok := gormSettings["PRIMARY_KEY"]; ok { | |
155 | field.IsPrimaryKey = true | |
156 | modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) | |
157 | } | |
158 | ||
159 | if _, ok := sqlSettings["DEFAULT"]; ok { | |
160 | field.HasDefaultValue = true | |
161 | } | |
162 | ||
163 | if value, ok := gormSettings["COLUMN"]; ok { | |
164 | field.DBName = value | |
165 | } else { | |
166 | field.DBName = ToDBName(fieldStruct.Name) | |
167 | } | |
168 | } | |
169 | fields = append(fields, field) | |
170 | } | |
171 | } | |
172 | ||
173 | var finished = make(chan bool) | |
174 | go func(finished chan bool) { | |
175 | for _, field := range fields { | |
176 | if !field.IsIgnored { | |
177 | fieldStruct := field.Struct | |
178 | indirectType := fieldStruct.Type | |
179 | if indirectType.Kind() == reflect.Ptr { | |
180 | indirectType = indirectType.Elem() | |
181 | } | |
182 | ||
183 | if _, isScanner := reflect.New(indirectType).Interface().(sql.Scanner); isScanner { | |
184 | field.IsScanner, field.IsNormal = true, true | |
185 | } | |
186 | ||
187 | if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime { | |
188 | field.IsNormal = true | |
189 | } | |
190 | ||
191 | if !field.IsNormal { | |
192 | gormSettings := parseTagSetting(field.Tag.Get("gorm")) | |
193 | toScope := scope.New(reflect.New(fieldStruct.Type).Interface()) | |
194 | ||
195 | getForeignField := func(column string, fields []*StructField) *StructField { | |
196 | for _, field := range fields { | |
197 | if field.Name == column || field.DBName == ToDBName(column) { | |
198 | return field | |
199 | } | |
200 | } | |
201 | return nil | |
202 | } | |
203 | ||
204 | var relationship = &Relationship{} | |
205 | ||
206 | if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { | |
207 | if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { | |
208 | if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { | |
209 | relationship.ForeignFieldNames = []string{polymorphicField.Name} | |
210 | relationship.ForeignDBNames = []string{polymorphicField.DBName} | |
211 | relationship.AssociationForeignFieldNames = []string{scope.PrimaryField().Name} | |
212 | relationship.AssociationForeignDBNames = []string{scope.PrimaryField().DBName} | |
213 | relationship.PolymorphicType = polymorphicType.Name | |
214 | relationship.PolymorphicDBName = polymorphicType.DBName | |
215 | polymorphicType.IsForeignKey = true | |
216 | polymorphicField.IsForeignKey = true | |
217 | } | |
218 | } | |
219 | } | |
220 | ||
221 | var foreignKeys []string | |
222 | if foreignKey, ok := gormSettings["FOREIGNKEY"]; ok { | |
223 | foreignKeys = append(foreignKeys, foreignKey) | |
224 | } | |
225 | switch indirectType.Kind() { | |
226 | case reflect.Slice: | |
227 | elemType := indirectType.Elem() | |
228 | if elemType.Kind() == reflect.Ptr { | |
229 | elemType = elemType.Elem() | |
230 | } | |
231 | ||
232 | if elemType.Kind() == reflect.Struct { | |
233 | if many2many := gormSettings["MANY2MANY"]; many2many != "" { | |
234 | relationship.Kind = "many_to_many" | |
235 | ||
236 | // foreign keys | |
237 | if len(foreignKeys) == 0 { | |
238 | for _, field := range scope.PrimaryFields() { | |
239 | foreignKeys = append(foreignKeys, field.DBName) | |
240 | } | |
241 | } | |
242 | ||
243 | for _, foreignKey := range foreignKeys { | |
244 | if field, ok := scope.FieldByName(foreignKey); ok { | |
245 | relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) | |
246 | joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName | |
247 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) | |
248 | } | |
249 | } | |
250 | ||
251 | // association foreign keys | |
252 | var associationForeignKeys []string | |
253 | if foreignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { | |
254 | associationForeignKeys = []string{gormSettings["ASSOCIATIONFOREIGNKEY"]} | |
255 | } else { | |
256 | for _, field := range toScope.PrimaryFields() { | |
257 | associationForeignKeys = append(associationForeignKeys, field.DBName) | |
258 | } | |
259 | } | |
260 | ||
261 | for _, name := range associationForeignKeys { | |
262 | if field, ok := toScope.FieldByName(name); ok { | |
263 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) | |
264 | relationship.AssociationForeignStructFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) | |
265 | joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName | |
266 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) | |
267 | } | |
268 | } | |
269 | ||
270 | joinTableHandler := JoinTableHandler{} | |
271 | joinTableHandler.Setup(relationship, many2many, scopeType, elemType) | |
272 | relationship.JoinTableHandler = &joinTableHandler | |
273 | field.Relationship = relationship | |
274 | } else { | |
275 | relationship.Kind = "has_many" | |
276 | ||
277 | if len(foreignKeys) == 0 { | |
278 | for _, field := range scope.PrimaryFields() { | |
279 | if foreignField := getForeignField(scopeType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil { | |
280 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) | |
281 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) | |
282 | relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) | |
283 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) | |
284 | foreignField.IsForeignKey = true | |
285 | } | |
286 | } | |
287 | } else { | |
288 | for _, foreignKey := range foreignKeys { | |
289 | if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { | |
290 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) | |
291 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) | |
292 | relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) | |
293 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) | |
294 | foreignField.IsForeignKey = true | |
295 | } | |
296 | } | |
297 | } | |
298 | ||
299 | if len(relationship.ForeignFieldNames) != 0 { | |
300 | field.Relationship = relationship | |
301 | } | |
302 | } | |
303 | } else { | |
304 | field.IsNormal = true | |
305 | } | |
306 | case reflect.Struct: | |
307 | if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { | |
308 | for _, toField := range toScope.GetStructFields() { | |
309 | toField = toField.clone() | |
310 | toField.Names = append([]string{fieldStruct.Name}, toField.Names...) | |
311 | modelStruct.StructFields = append(modelStruct.StructFields, toField) | |
312 | if toField.IsPrimaryKey { | |
313 | modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField) | |
314 | } | |
315 | } | |
316 | continue | |
317 | } else { | |
318 | if len(foreignKeys) == 0 { | |
319 | for _, f := range scope.PrimaryFields() { | |
320 | if foreignField := getForeignField(modelStruct.ModelType.Name()+f.Name, toScope.GetStructFields()); foreignField != nil { | |
321 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name) | |
322 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName) | |
323 | relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) | |
324 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) | |
325 | foreignField.IsForeignKey = true | |
326 | } | |
327 | } | |
328 | } else { | |
329 | for _, foreignKey := range foreignKeys { | |
330 | if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { | |
331 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) | |
332 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) | |
333 | relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) | |
334 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) | |
335 | foreignField.IsForeignKey = true | |
336 | } | |
337 | } | |
338 | } | |
339 | ||
340 | if len(relationship.ForeignFieldNames) != 0 { | |
341 | relationship.Kind = "has_one" | |
342 | field.Relationship = relationship | |
343 | } else { | |
344 | if len(foreignKeys) == 0 { | |
345 | for _, f := range toScope.PrimaryFields() { | |
346 | if foreignField := getForeignField(field.Name+f.Name, fields); foreignField != nil { | |
347 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name) | |
348 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName) | |
349 | relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) | |
350 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) | |
351 | foreignField.IsForeignKey = true | |
352 | } | |
353 | } | |
354 | } else { | |
355 | for _, foreignKey := range foreignKeys { | |
356 | if foreignField := getForeignField(foreignKey, fields); foreignField != nil { | |
357 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name) | |
358 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName) | |
359 | relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) | |
360 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) | |
361 | foreignField.IsForeignKey = true | |
362 | } | |
363 | } | |
364 | } | |
365 | ||
366 | if len(relationship.ForeignFieldNames) != 0 { | |
367 | relationship.Kind = "belongs_to" | |
368 | field.Relationship = relationship | |
369 | } | |
370 | } | |
371 | } | |
372 | default: | |
373 | field.IsNormal = true | |
374 | } | |
375 | } | |
376 | ||
377 | if field.IsNormal { | |
378 | if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" { | |
379 | field.IsPrimaryKey = true | |
380 | modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) | |
381 | } | |
382 | } | |
383 | } | |
384 | modelStruct.StructFields = append(modelStruct.StructFields, field) | |
385 | } | |
386 | finished <- true | |
387 | }(finished) | |
388 | ||
389 | modelStructsMap.Set(scopeType, &modelStruct) | |
390 | ||
391 | <-finished | |
392 | modelStruct.cached = true | |
393 | ||
394 | return &modelStruct | |
395 | } | |
396 | ||
397 | func (scope *Scope) GetStructFields() (fields []*StructField) { | |
398 | return scope.GetModelStruct().StructFields | |
399 | } | |
400 | ||
401 | func (scope *Scope) generateSqlTag(field *StructField) string { | |
402 | var sqlType string | |
403 | structType := field.Struct.Type | |
404 | if structType.Kind() == reflect.Ptr { | |
405 | structType = structType.Elem() | |
406 | } | |
407 | reflectValue := reflect.Indirect(reflect.New(structType)) | |
408 | sqlSettings := parseTagSetting(field.Tag.Get("sql")) | |
409 | ||
410 | if value, ok := sqlSettings["TYPE"]; ok { | |
411 | sqlType = value | |
412 | } | |
413 | ||
414 | additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"] | |
415 | if value, ok := sqlSettings["DEFAULT"]; ok { | |
416 | additionalType = additionalType + " DEFAULT " + value | |
417 | } | |
418 | ||
419 | if field.IsScanner { | |
420 | var getScannerValue func(reflect.Value) | |
421 | getScannerValue = func(value reflect.Value) { | |
422 | reflectValue = value | |
423 | if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct { | |
424 | getScannerValue(reflectValue.Field(0)) | |
425 | } | |
426 | } | |
427 | getScannerValue(reflectValue) | |
428 | } | |
429 | ||
430 | if sqlType == "" { | |
431 | var size = 255 | |
432 | ||
433 | if value, ok := sqlSettings["SIZE"]; ok { | |
434 | size, _ = strconv.Atoi(value) | |
435 | } | |
436 | ||
437 | _, autoIncrease := sqlSettings["AUTO_INCREMENT"] | |
438 | if field.IsPrimaryKey { | |
439 | autoIncrease = true | |
440 | } | |
441 | ||
442 | sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease) | |
443 | } | |
444 | ||
445 | if strings.TrimSpace(additionalType) == "" { | |
446 | return sqlType | |
447 | } else { | |
448 | return fmt.Sprintf("%v %v", sqlType, additionalType) | |
449 | } | |
450 | } | |
451 | ||
452 | func parseTagSetting(str string) map[string]string { | |
453 | tags := strings.Split(str, ";") | |
454 | setting := map[string]string{} | |
455 | for _, value := range tags { | |
456 | v := strings.Split(value, ":") | |
457 | k := strings.TrimSpace(strings.ToUpper(v[0])) | |
458 | if len(v) >= 2 { | |
459 | setting[k] = strings.Join(v[1:], ":") | |
460 | } else { | |
461 | setting[k] = k | |
462 | } | |
463 | } | |
464 | return setting | |
465 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | "time" | |
6 | ) | |
7 | ||
8 | type mssql struct { | |
9 | commonDialect | |
10 | } | |
11 | ||
12 | func (mssql) HasTop() bool { | |
13 | return true | |
14 | } | |
15 | ||
16 | func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { | |
17 | switch value.Kind() { | |
18 | case reflect.Bool: | |
19 | return "bit" | |
20 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | |
21 | if autoIncrease { | |
22 | return "int IDENTITY(1,1)" | |
23 | } | |
24 | return "int" | |
25 | case reflect.Int64, reflect.Uint64: | |
26 | if autoIncrease { | |
27 | return "bigint IDENTITY(1,1)" | |
28 | } | |
29 | return "bigint" | |
30 | case reflect.Float32, reflect.Float64: | |
31 | return "float" | |
32 | case reflect.String: | |
33 | if size > 0 && size < 65532 { | |
34 | return fmt.Sprintf("nvarchar(%d)", size) | |
35 | } | |
36 | return "text" | |
37 | case reflect.Struct: | |
38 | if _, ok := value.Interface().(time.Time); ok { | |
39 | return "datetime2" | |
40 | } | |
41 | default: | |
42 | if _, ok := value.Interface().([]byte); ok { | |
43 | if size > 0 && size < 65532 { | |
44 | return fmt.Sprintf("varchar(%d)", size) | |
45 | } | |
46 | return "text" | |
47 | } | |
48 | } | |
49 | panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) | |
50 | } | |
51 | ||
52 | func (s mssql) HasTable(scope *Scope, tableName string) bool { | |
53 | var ( | |
54 | count int | |
55 | databaseName = s.CurrentDatabase(scope) | |
56 | ) | |
57 | s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName) | |
58 | return count > 0 | |
59 | } | |
60 | ||
61 | func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { | |
62 | var ( | |
63 | count int | |
64 | databaseName = s.CurrentDatabase(scope) | |
65 | ) | |
66 | s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName) | |
67 | return count > 0 | |
68 | } | |
69 | ||
70 | func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { | |
71 | var count int | |
72 | s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName) | |
73 | return count > 0 | |
74 | } | |
75 | ||
76 | func (s mssql) CurrentDatabase(scope *Scope) (name string) { | |
77 | s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]") | |
78 | return | |
79 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "os" | |
5 | "testing" | |
6 | ) | |
7 | ||
8 | type Blog struct { | |
9 | ID uint `gorm:"primary_key"` | |
10 | Locale string `gorm:"primary_key"` | |
11 | Subject string | |
12 | Body string | |
13 | Tags []Tag `gorm:"many2many:blog_tags;"` | |
14 | } | |
15 | ||
16 | type Tag struct { | |
17 | ID uint `gorm:"primary_key"` | |
18 | Locale string `gorm:"primary_key"` | |
19 | Value string | |
20 | } | |
21 | ||
22 | func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { | |
23 | if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { | |
24 | DB.Exec(fmt.Sprintf("drop table blog_tags;")) | |
25 | DB.AutoMigrate(&Blog{}, &Tag{}) | |
26 | blog := Blog{ | |
27 | Locale: "ZH", | |
28 | Subject: "subject", | |
29 | Body: "body", | |
30 | Tags: []Tag{ | |
31 | {Locale: "ZH", Value: "tag1"}, | |
32 | {Locale: "ZH", Value: "tag2"}, | |
33 | }, | |
34 | } | |
35 | ||
36 | DB.Save(&blog) | |
37 | DB.Model(&blog).Association("Tags").Append([]Tag{{Locale: "ZH", Value: "tag3"}}) | |
38 | ||
39 | var tags []Tag | |
40 | DB.Model(&blog).Related(&tags, "Tags") | |
41 | if len(tags) != 3 { | |
42 | t.Errorf("should found 3 tags with blog") | |
43 | } | |
44 | } | |
45 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | "time" | |
6 | ) | |
7 | ||
8 | type mysql struct { | |
9 | commonDialect | |
10 | } | |
11 | ||
12 | func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { | |
13 | switch value.Kind() { | |
14 | case reflect.Bool: | |
15 | return "boolean" | |
16 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: | |
17 | if autoIncrease { | |
18 | return "int AUTO_INCREMENT" | |
19 | } | |
20 | return "int" | |
21 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | |
22 | if autoIncrease { | |
23 | return "int unsigned AUTO_INCREMENT" | |
24 | } | |
25 | return "int unsigned" | |
26 | case reflect.Int64: | |
27 | if autoIncrease { | |
28 | return "bigint AUTO_INCREMENT" | |
29 | } | |
30 | return "bigint" | |
31 | case reflect.Uint64: | |
32 | if autoIncrease { | |
33 | return "bigint unsigned AUTO_INCREMENT" | |
34 | } | |
35 | return "bigint unsigned" | |
36 | case reflect.Float32, reflect.Float64: | |
37 | return "double" | |
38 | case reflect.String: | |
39 | if size > 0 && size < 65532 { | |
40 | return fmt.Sprintf("varchar(%d)", size) | |
41 | } | |
42 | return "longtext" | |
43 | case reflect.Struct: | |
44 | if _, ok := value.Interface().(time.Time); ok { | |
45 | return "timestamp NULL" | |
46 | } | |
47 | default: | |
48 | if _, ok := value.Interface().([]byte); ok { | |
49 | if size > 0 && size < 65532 { | |
50 | return fmt.Sprintf("varbinary(%d)", size) | |
51 | } | |
52 | return "longblob" | |
53 | } | |
54 | } | |
55 | panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) | |
56 | } | |
57 | ||
58 | func (mysql) Quote(key string) string { | |
59 | return fmt.Sprintf("`%s`", key) | |
60 | } | |
61 | ||
62 | func (mysql) SelectFromDummyTable() string { | |
63 | return "FROM DUAL" | |
64 | } | |
65 | ||
66 | func (s mysql) CurrentDatabase(scope *Scope) (name string) { | |
67 | s.RawScanString(scope, &name, "SELECT DATABASE()") | |
68 | return | |
69 | } |
0 | package gorm_test | |
1 | ||
2 | import "testing" | |
3 | ||
4 | type PointerStruct struct { | |
5 | ID int64 | |
6 | Name *string | |
7 | Num *int | |
8 | } | |
9 | ||
10 | type NormalStruct struct { | |
11 | ID int64 | |
12 | Name string | |
13 | Num int | |
14 | } | |
15 | ||
16 | func TestPointerFields(t *testing.T) { | |
17 | DB.DropTable(&PointerStruct{}) | |
18 | DB.AutoMigrate(&PointerStruct{}) | |
19 | var name = "pointer struct 1" | |
20 | var num = 100 | |
21 | pointerStruct := PointerStruct{Name: &name, Num: &num} | |
22 | if DB.Create(&pointerStruct).Error != nil { | |
23 | t.Errorf("Failed to save pointer struct") | |
24 | } | |
25 | ||
26 | var pointerStructResult PointerStruct | |
27 | if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num { | |
28 | t.Errorf("Failed to query saved pointer struct") | |
29 | } | |
30 | ||
31 | var tableName = DB.NewScope(&PointerStruct{}).TableName() | |
32 | ||
33 | var normalStruct NormalStruct | |
34 | DB.Table(tableName).First(&normalStruct) | |
35 | if normalStruct.Name != name || normalStruct.Num != num { | |
36 | t.Errorf("Failed to query saved Normal struct") | |
37 | } | |
38 | ||
39 | var nilPointerStruct = PointerStruct{} | |
40 | if err := DB.Create(&nilPointerStruct).Error; err != nil { | |
41 | t.Errorf("Failed to save nil pointer struct", err) | |
42 | } | |
43 | ||
44 | var pointerStruct2 PointerStruct | |
45 | if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { | |
46 | t.Errorf("Failed to query saved nil pointer struct", err) | |
47 | } | |
48 | ||
49 | var normalStruct2 NormalStruct | |
50 | if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { | |
51 | t.Errorf("Failed to query saved nil pointer struct", err) | |
52 | } | |
53 | ||
54 | var partialNilPointerStruct1 = PointerStruct{Num: &num} | |
55 | if err := DB.Create(&partialNilPointerStruct1).Error; err != nil { | |
56 | t.Errorf("Failed to save partial nil pointer struct", err) | |
57 | } | |
58 | ||
59 | var pointerStruct3 PointerStruct | |
60 | if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num { | |
61 | t.Errorf("Failed to query saved partial nil pointer struct", err) | |
62 | } | |
63 | ||
64 | var normalStruct3 NormalStruct | |
65 | if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num { | |
66 | t.Errorf("Failed to query saved partial pointer struct", err) | |
67 | } | |
68 | ||
69 | var partialNilPointerStruct2 = PointerStruct{Name: &name} | |
70 | if err := DB.Create(&partialNilPointerStruct2).Error; err != nil { | |
71 | t.Errorf("Failed to save partial nil pointer struct", err) | |
72 | } | |
73 | ||
74 | var pointerStruct4 PointerStruct | |
75 | if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name { | |
76 | t.Errorf("Failed to query saved partial nil pointer struct", err) | |
77 | } | |
78 | ||
79 | var normalStruct4 NormalStruct | |
80 | if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name { | |
81 | t.Errorf("Failed to query saved partial pointer struct", err) | |
82 | } | |
83 | } |
0 | package gorm_test | |
1 | ||
2 | import "testing" | |
3 | ||
4 | type Cat struct { | |
5 | Id int | |
6 | Name string | |
7 | Toy Toy `gorm:"polymorphic:Owner;"` | |
8 | } | |
9 | ||
10 | type Dog struct { | |
11 | Id int | |
12 | Name string | |
13 | Toys []Toy `gorm:"polymorphic:Owner;"` | |
14 | } | |
15 | ||
16 | type Toy struct { | |
17 | Id int | |
18 | Name string | |
19 | OwnerId int | |
20 | OwnerType string | |
21 | } | |
22 | ||
23 | func TestPolymorphic(t *testing.T) { | |
24 | DB.AutoMigrate(&Cat{}) | |
25 | DB.AutoMigrate(&Dog{}) | |
26 | DB.AutoMigrate(&Toy{}) | |
27 | ||
28 | cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat nip"}} | |
29 | dog := Dog{Name: "Pluto", Toys: []Toy{Toy{Name: "orange ball"}, Toy{Name: "yellow ball"}}} | |
30 | DB.Save(&cat).Save(&dog) | |
31 | ||
32 | var catToys []Toy | |
33 | if DB.Model(&cat).Related(&catToys, "Toy").RecordNotFound() { | |
34 | t.Errorf("Did not find any has one polymorphic association") | |
35 | } else if len(catToys) != 1 { | |
36 | t.Errorf("Should have found only one polymorphic has one association") | |
37 | } else if catToys[0].Name != cat.Toy.Name { | |
38 | t.Errorf("Should have found the proper has one polymorphic association") | |
39 | } | |
40 | ||
41 | var dogToys []Toy | |
42 | if DB.Model(&dog).Related(&dogToys, "Toys").RecordNotFound() { | |
43 | t.Errorf("Did not find any polymorphic has many associations") | |
44 | } else if len(dogToys) != len(dog.Toys) { | |
45 | t.Errorf("Should have found all polymorphic has many associations") | |
46 | } | |
47 | ||
48 | if DB.Model(&cat).Association("Toy").Count() != 1 { | |
49 | t.Errorf("Should return one polymorphic has one association") | |
50 | } | |
51 | ||
52 | if DB.Model(&dog).Association("Toys").Count() != 2 { | |
53 | t.Errorf("Should return two polymorphic has many associations") | |
54 | } | |
55 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "database/sql" | |
4 | "database/sql/driver" | |
5 | "fmt" | |
6 | "reflect" | |
7 | "time" | |
8 | ||
9 | "github.com/lib/pq/hstore" | |
10 | ) | |
11 | ||
12 | type postgres struct { | |
13 | commonDialect | |
14 | } | |
15 | ||
16 | func (postgres) BinVar(i int) string { | |
17 | return fmt.Sprintf("$%v", i) | |
18 | } | |
19 | ||
20 | func (postgres) SupportLastInsertId() bool { | |
21 | return false | |
22 | } | |
23 | ||
24 | func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string { | |
25 | switch value.Kind() { | |
26 | case reflect.Bool: | |
27 | return "boolean" | |
28 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | |
29 | if autoIncrease { | |
30 | return "serial" | |
31 | } | |
32 | return "integer" | |
33 | case reflect.Int64, reflect.Uint64: | |
34 | if autoIncrease { | |
35 | return "bigserial" | |
36 | } | |
37 | return "bigint" | |
38 | case reflect.Float32, reflect.Float64: | |
39 | return "numeric" | |
40 | case reflect.String: | |
41 | if size > 0 && size < 65532 { | |
42 | return fmt.Sprintf("varchar(%d)", size) | |
43 | } | |
44 | return "text" | |
45 | case reflect.Struct: | |
46 | if _, ok := value.Interface().(time.Time); ok { | |
47 | return "timestamp with time zone" | |
48 | } | |
49 | case reflect.Map: | |
50 | if value.Type() == hstoreType { | |
51 | return "hstore" | |
52 | } | |
53 | default: | |
54 | if _, ok := value.Interface().([]byte); ok { | |
55 | return "bytea" | |
56 | } | |
57 | } | |
58 | panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String())) | |
59 | } | |
60 | ||
61 | func (s postgres) ReturningStr(tableName, key string) string { | |
62 | return fmt.Sprintf("RETURNING %v.%v", tableName, key) | |
63 | } | |
64 | ||
65 | func (s postgres) HasTable(scope *Scope, tableName string) bool { | |
66 | var count int | |
67 | s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName) | |
68 | return count > 0 | |
69 | } | |
70 | ||
71 | func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) bool { | |
72 | var count int | |
73 | s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName) | |
74 | return count > 0 | |
75 | } | |
76 | ||
77 | func (postgres) RemoveIndex(scope *Scope, indexName string) { | |
78 | scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) | |
79 | } | |
80 | ||
81 | func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { | |
82 | var count int | |
83 | s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName) | |
84 | return count > 0 | |
85 | } | |
86 | ||
87 | func (s postgres) CurrentDatabase(scope *Scope) (name string) { | |
88 | s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()") | |
89 | return | |
90 | } | |
91 | ||
92 | var hstoreType = reflect.TypeOf(Hstore{}) | |
93 | ||
94 | type Hstore map[string]*string | |
95 | ||
96 | func (h Hstore) Value() (driver.Value, error) { | |
97 | hstore := hstore.Hstore{Map: map[string]sql.NullString{}} | |
98 | if len(h) == 0 { | |
99 | return nil, nil | |
100 | } | |
101 | ||
102 | for key, value := range h { | |
103 | var s sql.NullString | |
104 | if value != nil { | |
105 | s.String = *value | |
106 | s.Valid = true | |
107 | } | |
108 | hstore.Map[key] = s | |
109 | } | |
110 | return hstore.Value() | |
111 | } | |
112 | ||
113 | func (h *Hstore) Scan(value interface{}) error { | |
114 | hstore := hstore.Hstore{} | |
115 | ||
116 | if err := hstore.Scan(value); err != nil { | |
117 | return err | |
118 | } | |
119 | ||
120 | if len(hstore.Map) == 0 { | |
121 | return nil | |
122 | } | |
123 | ||
124 | *h = Hstore{} | |
125 | for k := range hstore.Map { | |
126 | if hstore.Map[k].Valid { | |
127 | s := hstore.Map[k].String | |
128 | (*h)[k] = &s | |
129 | } else { | |
130 | (*h)[k] = nil | |
131 | } | |
132 | } | |
133 | ||
134 | return nil | |
135 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "database/sql/driver" | |
4 | "errors" | |
5 | "fmt" | |
6 | "reflect" | |
7 | "strings" | |
8 | ) | |
9 | ||
10 | func getRealValue(value reflect.Value, columns []string) (results []interface{}) { | |
11 | for _, column := range columns { | |
12 | if reflect.Indirect(value).FieldByName(column).IsValid() { | |
13 | result := reflect.Indirect(value).FieldByName(column).Interface() | |
14 | if r, ok := result.(driver.Valuer); ok { | |
15 | result, _ = r.Value() | |
16 | } | |
17 | results = append(results, result) | |
18 | } | |
19 | } | |
20 | return | |
21 | } | |
22 | ||
23 | func equalAsString(a interface{}, b interface{}) bool { | |
24 | return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b) | |
25 | } | |
26 | ||
27 | func Preload(scope *Scope) { | |
28 | if scope.Search.preload == nil { | |
29 | return | |
30 | } | |
31 | ||
32 | preloadMap := map[string]bool{} | |
33 | fields := scope.Fields() | |
34 | for _, preload := range scope.Search.preload { | |
35 | schema, conditions := preload.schema, preload.conditions | |
36 | keys := strings.Split(schema, ".") | |
37 | currentScope := scope | |
38 | currentFields := fields | |
39 | originalConditions := conditions | |
40 | conditions = []interface{}{} | |
41 | for i, key := range keys { | |
42 | var found bool | |
43 | if preloadMap[strings.Join(keys[:i+1], ".")] { | |
44 | goto nextLoop | |
45 | } | |
46 | ||
47 | if i == len(keys)-1 { | |
48 | conditions = originalConditions | |
49 | } | |
50 | ||
51 | for _, field := range currentFields { | |
52 | if field.Name != key || field.Relationship == nil { | |
53 | continue | |
54 | } | |
55 | ||
56 | found = true | |
57 | switch field.Relationship.Kind { | |
58 | case "has_one": | |
59 | currentScope.handleHasOnePreload(field, conditions) | |
60 | case "has_many": | |
61 | currentScope.handleHasManyPreload(field, conditions) | |
62 | case "belongs_to": | |
63 | currentScope.handleBelongsToPreload(field, conditions) | |
64 | case "many_to_many": | |
65 | currentScope.handleHasManyToManyPreload(field, conditions) | |
66 | default: | |
67 | currentScope.Err(errors.New("not supported relation")) | |
68 | } | |
69 | break | |
70 | } | |
71 | ||
72 | if !found { | |
73 | value := reflect.ValueOf(currentScope.Value) | |
74 | if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface { | |
75 | value = value.Index(0).Elem() | |
76 | } | |
77 | scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type())) | |
78 | return | |
79 | } | |
80 | ||
81 | preloadMap[strings.Join(keys[:i+1], ".")] = true | |
82 | ||
83 | nextLoop: | |
84 | if i < len(keys)-1 { | |
85 | currentScope = currentScope.getColumnsAsScope(key) | |
86 | currentFields = currentScope.Fields() | |
87 | } | |
88 | } | |
89 | } | |
90 | ||
91 | } | |
92 | ||
93 | func makeSlice(typ reflect.Type) interface{} { | |
94 | if typ.Kind() == reflect.Slice { | |
95 | typ = typ.Elem() | |
96 | } | |
97 | sliceType := reflect.SliceOf(typ) | |
98 | slice := reflect.New(sliceType) | |
99 | slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) | |
100 | return slice.Interface() | |
101 | } | |
102 | ||
103 | func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { | |
104 | relation := field.Relationship | |
105 | ||
106 | primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) | |
107 | if len(primaryKeys) == 0 { | |
108 | return | |
109 | } | |
110 | ||
111 | results := makeSlice(field.Struct.Type) | |
112 | scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) | |
113 | resultValues := reflect.Indirect(reflect.ValueOf(results)) | |
114 | ||
115 | for i := 0; i < resultValues.Len(); i++ { | |
116 | result := resultValues.Index(i) | |
117 | if scope.IndirectValue().Kind() == reflect.Slice { | |
118 | value := getRealValue(result, relation.ForeignFieldNames) | |
119 | objects := scope.IndirectValue() | |
120 | for j := 0; j < objects.Len(); j++ { | |
121 | if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) { | |
122 | reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) | |
123 | break | |
124 | } | |
125 | } | |
126 | } else { | |
127 | if err := scope.SetColumn(field, result); err != nil { | |
128 | scope.Err(err) | |
129 | return | |
130 | } | |
131 | } | |
132 | } | |
133 | } | |
134 | ||
135 | func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { | |
136 | relation := field.Relationship | |
137 | primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) | |
138 | if len(primaryKeys) == 0 { | |
139 | return | |
140 | } | |
141 | ||
142 | results := makeSlice(field.Struct.Type) | |
143 | scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) | |
144 | resultValues := reflect.Indirect(reflect.ValueOf(results)) | |
145 | ||
146 | if scope.IndirectValue().Kind() == reflect.Slice { | |
147 | for i := 0; i < resultValues.Len(); i++ { | |
148 | result := resultValues.Index(i) | |
149 | value := getRealValue(result, relation.ForeignFieldNames) | |
150 | objects := scope.IndirectValue() | |
151 | for j := 0; j < objects.Len(); j++ { | |
152 | object := reflect.Indirect(objects.Index(j)) | |
153 | if equalAsString(getRealValue(object, relation.AssociationForeignFieldNames), value) { | |
154 | f := object.FieldByName(field.Name) | |
155 | f.Set(reflect.Append(f, result)) | |
156 | break | |
157 | } | |
158 | } | |
159 | } | |
160 | } else { | |
161 | scope.SetColumn(field, resultValues) | |
162 | } | |
163 | } | |
164 | ||
165 | func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { | |
166 | relation := field.Relationship | |
167 | primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames) | |
168 | if len(primaryKeys) == 0 { | |
169 | return | |
170 | } | |
171 | ||
172 | results := makeSlice(field.Struct.Type) | |
173 | scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) | |
174 | resultValues := reflect.Indirect(reflect.ValueOf(results)) | |
175 | ||
176 | for i := 0; i < resultValues.Len(); i++ { | |
177 | result := resultValues.Index(i) | |
178 | if scope.IndirectValue().Kind() == reflect.Slice { | |
179 | value := getRealValue(result, relation.AssociationForeignFieldNames) | |
180 | objects := scope.IndirectValue() | |
181 | for j := 0; j < objects.Len(); j++ { | |
182 | object := reflect.Indirect(objects.Index(j)) | |
183 | if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) { | |
184 | object.FieldByName(field.Name).Set(result) | |
185 | } | |
186 | } | |
187 | } else { | |
188 | scope.SetColumn(field, result) | |
189 | } | |
190 | } | |
191 | } | |
192 | ||
193 | func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interface{}) { | |
194 | relation := field.Relationship | |
195 | ||
196 | joinTableHandler := relation.JoinTableHandler | |
197 | destType := field.StructField.Struct.Type.Elem() | |
198 | var isPtr bool | |
199 | if destType.Kind() == reflect.Ptr { | |
200 | isPtr = true | |
201 | destType = destType.Elem() | |
202 | } | |
203 | ||
204 | var sourceKeys []string | |
205 | var linkHash = make(map[string][]reflect.Value) | |
206 | ||
207 | for _, key := range joinTableHandler.SourceForeignKeys() { | |
208 | sourceKeys = append(sourceKeys, key.DBName) | |
209 | } | |
210 | ||
211 | db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*") | |
212 | preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value) | |
213 | if len(conditions) > 0 { | |
214 | preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...) | |
215 | } | |
216 | rows, err := preloadJoinDB.Rows() | |
217 | ||
218 | if scope.Err(err) != nil { | |
219 | return | |
220 | } | |
221 | defer rows.Close() | |
222 | ||
223 | columns, _ := rows.Columns() | |
224 | for rows.Next() { | |
225 | elem := reflect.New(destType).Elem() | |
226 | var values = make([]interface{}, len(columns)) | |
227 | ||
228 | fields := scope.New(elem.Addr().Interface()).Fields() | |
229 | ||
230 | for index, column := range columns { | |
231 | if field, ok := fields[column]; ok { | |
232 | if field.Field.Kind() == reflect.Ptr { | |
233 | values[index] = field.Field.Addr().Interface() | |
234 | } else { | |
235 | values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface() | |
236 | } | |
237 | } else { | |
238 | var i interface{} | |
239 | values[index] = &i | |
240 | } | |
241 | } | |
242 | ||
243 | scope.Err(rows.Scan(values...)) | |
244 | ||
245 | var sourceKey []interface{} | |
246 | ||
247 | for index, column := range columns { | |
248 | value := values[index] | |
249 | if field, ok := fields[column]; ok { | |
250 | if field.Field.Kind() == reflect.Ptr { | |
251 | field.Field.Set(reflect.ValueOf(value).Elem()) | |
252 | } else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() { | |
253 | field.Field.Set(v) | |
254 | } | |
255 | } else if strInSlice(column, sourceKeys) { | |
256 | sourceKey = append(sourceKey, *(value.(*interface{}))) | |
257 | } | |
258 | } | |
259 | ||
260 | if len(sourceKey) != 0 { | |
261 | if isPtr { | |
262 | linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem.Addr()) | |
263 | } else { | |
264 | linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem) | |
265 | } | |
266 | } | |
267 | } | |
268 | ||
269 | var associationForeignStructFieldNames []string | |
270 | for _, dbName := range relation.AssociationForeignFieldNames { | |
271 | if field, ok := scope.FieldByName(dbName); ok { | |
272 | associationForeignStructFieldNames = append(associationForeignStructFieldNames, field.Name) | |
273 | } | |
274 | } | |
275 | ||
276 | if scope.IndirectValue().Kind() == reflect.Slice { | |
277 | objects := scope.IndirectValue() | |
278 | for j := 0; j < objects.Len(); j++ { | |
279 | object := reflect.Indirect(objects.Index(j)) | |
280 | source := getRealValue(object, associationForeignStructFieldNames) | |
281 | field := object.FieldByName(field.Name) | |
282 | for _, link := range linkHash[toString(source)] { | |
283 | field.Set(reflect.Append(field, link)) | |
284 | } | |
285 | } | |
286 | } else { | |
287 | object := scope.IndirectValue() | |
288 | source := getRealValue(object, associationForeignStructFieldNames) | |
289 | field := object.FieldByName(field.Name) | |
290 | for _, link := range linkHash[toString(source)] { | |
291 | field.Set(reflect.Append(field, link)) | |
292 | } | |
293 | } | |
294 | } | |
295 | ||
296 | func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) { | |
297 | values := scope.IndirectValue() | |
298 | switch values.Kind() { | |
299 | case reflect.Slice: | |
300 | for i := 0; i < values.Len(); i++ { | |
301 | var result []interface{} | |
302 | for _, column := range columns { | |
303 | result = append(result, reflect.Indirect(values.Index(i)).FieldByName(column).Interface()) | |
304 | } | |
305 | results = append(results, result) | |
306 | } | |
307 | case reflect.Struct: | |
308 | var result []interface{} | |
309 | for _, column := range columns { | |
310 | result = append(result, values.FieldByName(column).Interface()) | |
311 | } | |
312 | return [][]interface{}{result} | |
313 | } | |
314 | return | |
315 | } | |
316 | ||
317 | func (scope *Scope) getColumnsAsScope(column string) *Scope { | |
318 | values := scope.IndirectValue() | |
319 | switch values.Kind() { | |
320 | case reflect.Slice: | |
321 | modelType := values.Type().Elem() | |
322 | if modelType.Kind() == reflect.Ptr { | |
323 | modelType = modelType.Elem() | |
324 | } | |
325 | fieldStruct, _ := modelType.FieldByName(column) | |
326 | var columns reflect.Value | |
327 | if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr { | |
328 | columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem() | |
329 | } else { | |
330 | columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem() | |
331 | } | |
332 | for i := 0; i < values.Len(); i++ { | |
333 | column := reflect.Indirect(values.Index(i)).FieldByName(column) | |
334 | if column.Kind() == reflect.Ptr { | |
335 | column = column.Elem() | |
336 | } | |
337 | if column.Kind() == reflect.Slice { | |
338 | for i := 0; i < column.Len(); i++ { | |
339 | elem := column.Index(i) | |
340 | if elem.CanAddr() { | |
341 | columns = reflect.Append(columns, elem.Addr()) | |
342 | } | |
343 | } | |
344 | } else { | |
345 | if column.CanAddr() { | |
346 | columns = reflect.Append(columns, column.Addr()) | |
347 | } | |
348 | } | |
349 | } | |
350 | return scope.New(columns.Interface()) | |
351 | case reflect.Struct: | |
352 | field := values.FieldByName(column) | |
353 | if !field.CanAddr() { | |
354 | return nil | |
355 | } | |
356 | return scope.New(field.Addr().Interface()) | |
357 | } | |
358 | return nil | |
359 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "encoding/json" | |
4 | "os" | |
5 | "reflect" | |
6 | "testing" | |
7 | ) | |
8 | ||
9 | func getPreloadUser(name string) *User { | |
10 | return getPreparedUser(name, "Preload") | |
11 | } | |
12 | ||
13 | func checkUserHasPreloadData(user User, t *testing.T) { | |
14 | u := getPreloadUser(user.Name) | |
15 | if user.BillingAddress.Address1 != u.BillingAddress.Address1 { | |
16 | t.Error("Failed to preload user's BillingAddress") | |
17 | } | |
18 | ||
19 | if user.ShippingAddress.Address1 != u.ShippingAddress.Address1 { | |
20 | t.Error("Failed to preload user's ShippingAddress") | |
21 | } | |
22 | ||
23 | if user.CreditCard.Number != u.CreditCard.Number { | |
24 | t.Error("Failed to preload user's CreditCard") | |
25 | } | |
26 | ||
27 | if user.Company.Name != u.Company.Name { | |
28 | t.Error("Failed to preload user's Company") | |
29 | } | |
30 | ||
31 | if len(user.Emails) != len(u.Emails) { | |
32 | t.Error("Failed to preload user's Emails") | |
33 | } else { | |
34 | var found int | |
35 | for _, e1 := range u.Emails { | |
36 | for _, e2 := range user.Emails { | |
37 | if e1.Email == e2.Email { | |
38 | found++ | |
39 | break | |
40 | } | |
41 | } | |
42 | } | |
43 | if found != len(u.Emails) { | |
44 | t.Error("Failed to preload user's email details") | |
45 | } | |
46 | } | |
47 | } | |
48 | ||
49 | func TestPreload(t *testing.T) { | |
50 | user1 := getPreloadUser("user1") | |
51 | DB.Save(user1) | |
52 | ||
53 | preloadDB := DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress"). | |
54 | Preload("CreditCard").Preload("Emails").Preload("Company") | |
55 | var user User | |
56 | preloadDB.Find(&user) | |
57 | checkUserHasPreloadData(user, t) | |
58 | ||
59 | user2 := getPreloadUser("user2") | |
60 | DB.Save(user2) | |
61 | ||
62 | user3 := getPreloadUser("user3") | |
63 | DB.Save(user3) | |
64 | ||
65 | var users []User | |
66 | preloadDB.Find(&users) | |
67 | ||
68 | for _, user := range users { | |
69 | checkUserHasPreloadData(user, t) | |
70 | } | |
71 | ||
72 | var users2 []*User | |
73 | preloadDB.Find(&users2) | |
74 | ||
75 | for _, user := range users2 { | |
76 | checkUserHasPreloadData(*user, t) | |
77 | } | |
78 | ||
79 | var users3 []*User | |
80 | preloadDB.Preload("Emails", "email = ?", user3.Emails[0].Email).Find(&users3) | |
81 | ||
82 | for _, user := range users3 { | |
83 | if user.Name == user3.Name { | |
84 | if len(user.Emails) != 1 { | |
85 | t.Errorf("should only preload one emails for user3 when with condition") | |
86 | } | |
87 | } else if len(user.Emails) != 0 { | |
88 | t.Errorf("should not preload any emails for other users when with condition") | |
89 | } | |
90 | } | |
91 | } | |
92 | ||
93 | func TestNestedPreload1(t *testing.T) { | |
94 | type ( | |
95 | Level1 struct { | |
96 | ID uint | |
97 | Value string | |
98 | Level2ID uint | |
99 | } | |
100 | Level2 struct { | |
101 | ID uint | |
102 | Level1 Level1 | |
103 | Level3ID uint | |
104 | } | |
105 | Level3 struct { | |
106 | ID uint | |
107 | Name string | |
108 | Level2 Level2 | |
109 | } | |
110 | ) | |
111 | DB.DropTableIfExists(&Level3{}) | |
112 | DB.DropTableIfExists(&Level2{}) | |
113 | DB.DropTableIfExists(&Level1{}) | |
114 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { | |
115 | panic(err) | |
116 | } | |
117 | ||
118 | want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} | |
119 | if err := DB.Create(&want).Error; err != nil { | |
120 | panic(err) | |
121 | } | |
122 | ||
123 | var got Level3 | |
124 | if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { | |
125 | panic(err) | |
126 | } | |
127 | ||
128 | if !reflect.DeepEqual(got, want) { | |
129 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
130 | } | |
131 | } | |
132 | ||
133 | func TestNestedPreload2(t *testing.T) { | |
134 | type ( | |
135 | Level1 struct { | |
136 | ID uint | |
137 | Value string | |
138 | Level2ID uint | |
139 | } | |
140 | Level2 struct { | |
141 | ID uint | |
142 | Level1s []*Level1 | |
143 | Level3ID uint | |
144 | } | |
145 | Level3 struct { | |
146 | ID uint | |
147 | Name string | |
148 | Level2s []Level2 | |
149 | } | |
150 | ) | |
151 | DB.DropTableIfExists(&Level3{}) | |
152 | DB.DropTableIfExists(&Level2{}) | |
153 | DB.DropTableIfExists(&Level1{}) | |
154 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { | |
155 | panic(err) | |
156 | } | |
157 | ||
158 | want := Level3{ | |
159 | Level2s: []Level2{ | |
160 | { | |
161 | Level1s: []*Level1{ | |
162 | &Level1{Value: "value1"}, | |
163 | &Level1{Value: "value2"}, | |
164 | }, | |
165 | }, | |
166 | { | |
167 | Level1s: []*Level1{ | |
168 | &Level1{Value: "value3"}, | |
169 | }, | |
170 | }, | |
171 | }, | |
172 | } | |
173 | if err := DB.Create(&want).Error; err != nil { | |
174 | panic(err) | |
175 | } | |
176 | ||
177 | var got Level3 | |
178 | if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { | |
179 | panic(err) | |
180 | } | |
181 | ||
182 | if !reflect.DeepEqual(got, want) { | |
183 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
184 | } | |
185 | } | |
186 | ||
187 | func TestNestedPreload3(t *testing.T) { | |
188 | type ( | |
189 | Level1 struct { | |
190 | ID uint | |
191 | Value string | |
192 | Level2ID uint | |
193 | } | |
194 | Level2 struct { | |
195 | ID uint | |
196 | Level1 Level1 | |
197 | Level3ID uint | |
198 | } | |
199 | Level3 struct { | |
200 | Name string | |
201 | ID uint | |
202 | Level2s []Level2 | |
203 | } | |
204 | ) | |
205 | DB.DropTableIfExists(&Level3{}) | |
206 | DB.DropTableIfExists(&Level2{}) | |
207 | DB.DropTableIfExists(&Level1{}) | |
208 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { | |
209 | panic(err) | |
210 | } | |
211 | ||
212 | want := Level3{ | |
213 | Level2s: []Level2{ | |
214 | {Level1: Level1{Value: "value1"}}, | |
215 | {Level1: Level1{Value: "value2"}}, | |
216 | }, | |
217 | } | |
218 | if err := DB.Create(&want).Error; err != nil { | |
219 | panic(err) | |
220 | } | |
221 | ||
222 | var got Level3 | |
223 | if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { | |
224 | panic(err) | |
225 | } | |
226 | ||
227 | if !reflect.DeepEqual(got, want) { | |
228 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
229 | } | |
230 | } | |
231 | ||
232 | func TestNestedPreload4(t *testing.T) { | |
233 | type ( | |
234 | Level1 struct { | |
235 | ID uint | |
236 | Value string | |
237 | Level2ID uint | |
238 | } | |
239 | Level2 struct { | |
240 | ID uint | |
241 | Level1s []Level1 | |
242 | Level3ID uint | |
243 | } | |
244 | Level3 struct { | |
245 | ID uint | |
246 | Name string | |
247 | Level2 Level2 | |
248 | } | |
249 | ) | |
250 | DB.DropTableIfExists(&Level3{}) | |
251 | DB.DropTableIfExists(&Level2{}) | |
252 | DB.DropTableIfExists(&Level1{}) | |
253 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { | |
254 | panic(err) | |
255 | } | |
256 | ||
257 | want := Level3{ | |
258 | Level2: Level2{ | |
259 | Level1s: []Level1{ | |
260 | Level1{Value: "value1"}, | |
261 | Level1{Value: "value2"}, | |
262 | }, | |
263 | }, | |
264 | } | |
265 | if err := DB.Create(&want).Error; err != nil { | |
266 | panic(err) | |
267 | } | |
268 | ||
269 | var got Level3 | |
270 | if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { | |
271 | panic(err) | |
272 | } | |
273 | ||
274 | if !reflect.DeepEqual(got, want) { | |
275 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
276 | } | |
277 | } | |
278 | ||
279 | // Slice: []Level3 | |
280 | func TestNestedPreload5(t *testing.T) { | |
281 | type ( | |
282 | Level1 struct { | |
283 | ID uint | |
284 | Value string | |
285 | Level2ID uint | |
286 | } | |
287 | Level2 struct { | |
288 | ID uint | |
289 | Level1 Level1 | |
290 | Level3ID uint | |
291 | } | |
292 | Level3 struct { | |
293 | ID uint | |
294 | Name string | |
295 | Level2 Level2 | |
296 | } | |
297 | ) | |
298 | DB.DropTableIfExists(&Level3{}) | |
299 | DB.DropTableIfExists(&Level2{}) | |
300 | DB.DropTableIfExists(&Level1{}) | |
301 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { | |
302 | panic(err) | |
303 | } | |
304 | ||
305 | want := make([]Level3, 2) | |
306 | want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} | |
307 | if err := DB.Create(&want[0]).Error; err != nil { | |
308 | panic(err) | |
309 | } | |
310 | want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} | |
311 | if err := DB.Create(&want[1]).Error; err != nil { | |
312 | panic(err) | |
313 | } | |
314 | ||
315 | var got []Level3 | |
316 | if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { | |
317 | panic(err) | |
318 | } | |
319 | ||
320 | if !reflect.DeepEqual(got, want) { | |
321 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
322 | } | |
323 | } | |
324 | ||
325 | func TestNestedPreload6(t *testing.T) { | |
326 | type ( | |
327 | Level1 struct { | |
328 | ID uint | |
329 | Value string | |
330 | Level2ID uint | |
331 | } | |
332 | Level2 struct { | |
333 | ID uint | |
334 | Level1s []Level1 | |
335 | Level3ID uint | |
336 | } | |
337 | Level3 struct { | |
338 | ID uint | |
339 | Name string | |
340 | Level2s []Level2 | |
341 | } | |
342 | ) | |
343 | DB.DropTableIfExists(&Level3{}) | |
344 | DB.DropTableIfExists(&Level2{}) | |
345 | DB.DropTableIfExists(&Level1{}) | |
346 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { | |
347 | panic(err) | |
348 | } | |
349 | ||
350 | want := make([]Level3, 2) | |
351 | want[0] = Level3{ | |
352 | Level2s: []Level2{ | |
353 | { | |
354 | Level1s: []Level1{ | |
355 | {Value: "value1"}, | |
356 | {Value: "value2"}, | |
357 | }, | |
358 | }, | |
359 | { | |
360 | Level1s: []Level1{ | |
361 | {Value: "value3"}, | |
362 | }, | |
363 | }, | |
364 | }, | |
365 | } | |
366 | if err := DB.Create(&want[0]).Error; err != nil { | |
367 | panic(err) | |
368 | } | |
369 | ||
370 | want[1] = Level3{ | |
371 | Level2s: []Level2{ | |
372 | { | |
373 | Level1s: []Level1{ | |
374 | {Value: "value3"}, | |
375 | {Value: "value4"}, | |
376 | }, | |
377 | }, | |
378 | { | |
379 | Level1s: []Level1{ | |
380 | {Value: "value5"}, | |
381 | }, | |
382 | }, | |
383 | }, | |
384 | } | |
385 | if err := DB.Create(&want[1]).Error; err != nil { | |
386 | panic(err) | |
387 | } | |
388 | ||
389 | var got []Level3 | |
390 | if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { | |
391 | panic(err) | |
392 | } | |
393 | ||
394 | if !reflect.DeepEqual(got, want) { | |
395 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
396 | } | |
397 | } | |
398 | ||
399 | func TestNestedPreload7(t *testing.T) { | |
400 | type ( | |
401 | Level1 struct { | |
402 | ID uint | |
403 | Value string | |
404 | Level2ID uint | |
405 | } | |
406 | Level2 struct { | |
407 | ID uint | |
408 | Level1 Level1 | |
409 | Level3ID uint | |
410 | } | |
411 | Level3 struct { | |
412 | ID uint | |
413 | Name string | |
414 | Level2s []Level2 | |
415 | } | |
416 | ) | |
417 | DB.DropTableIfExists(&Level3{}) | |
418 | DB.DropTableIfExists(&Level2{}) | |
419 | DB.DropTableIfExists(&Level1{}) | |
420 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { | |
421 | panic(err) | |
422 | } | |
423 | ||
424 | want := make([]Level3, 2) | |
425 | want[0] = Level3{ | |
426 | Level2s: []Level2{ | |
427 | {Level1: Level1{Value: "value1"}}, | |
428 | {Level1: Level1{Value: "value2"}}, | |
429 | }, | |
430 | } | |
431 | if err := DB.Create(&want[0]).Error; err != nil { | |
432 | panic(err) | |
433 | } | |
434 | ||
435 | want[1] = Level3{ | |
436 | Level2s: []Level2{ | |
437 | {Level1: Level1{Value: "value3"}}, | |
438 | {Level1: Level1{Value: "value4"}}, | |
439 | }, | |
440 | } | |
441 | if err := DB.Create(&want[1]).Error; err != nil { | |
442 | panic(err) | |
443 | } | |
444 | ||
445 | var got []Level3 | |
446 | if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { | |
447 | panic(err) | |
448 | } | |
449 | ||
450 | if !reflect.DeepEqual(got, want) { | |
451 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
452 | } | |
453 | } | |
454 | ||
455 | func TestNestedPreload8(t *testing.T) { | |
456 | type ( | |
457 | Level1 struct { | |
458 | ID uint | |
459 | Value string | |
460 | Level2ID uint | |
461 | } | |
462 | Level2 struct { | |
463 | ID uint | |
464 | Level1s []Level1 | |
465 | Level3ID uint | |
466 | } | |
467 | Level3 struct { | |
468 | ID uint | |
469 | Name string | |
470 | Level2 Level2 | |
471 | } | |
472 | ) | |
473 | DB.DropTableIfExists(&Level3{}) | |
474 | DB.DropTableIfExists(&Level2{}) | |
475 | DB.DropTableIfExists(&Level1{}) | |
476 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { | |
477 | panic(err) | |
478 | } | |
479 | ||
480 | want := make([]Level3, 2) | |
481 | want[0] = Level3{ | |
482 | Level2: Level2{ | |
483 | Level1s: []Level1{ | |
484 | Level1{Value: "value1"}, | |
485 | Level1{Value: "value2"}, | |
486 | }, | |
487 | }, | |
488 | } | |
489 | if err := DB.Create(&want[0]).Error; err != nil { | |
490 | panic(err) | |
491 | } | |
492 | want[1] = Level3{ | |
493 | Level2: Level2{ | |
494 | Level1s: []Level1{ | |
495 | Level1{Value: "value3"}, | |
496 | Level1{Value: "value4"}, | |
497 | }, | |
498 | }, | |
499 | } | |
500 | if err := DB.Create(&want[1]).Error; err != nil { | |
501 | panic(err) | |
502 | } | |
503 | ||
504 | var got []Level3 | |
505 | if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { | |
506 | panic(err) | |
507 | } | |
508 | ||
509 | if !reflect.DeepEqual(got, want) { | |
510 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
511 | } | |
512 | } | |
513 | ||
514 | func TestNestedPreload9(t *testing.T) { | |
515 | type ( | |
516 | Level0 struct { | |
517 | ID uint | |
518 | Value string | |
519 | Level1ID uint | |
520 | } | |
521 | Level1 struct { | |
522 | ID uint | |
523 | Value string | |
524 | Level2ID uint | |
525 | Level2_1ID uint | |
526 | Level0s []Level0 | |
527 | } | |
528 | Level2 struct { | |
529 | ID uint | |
530 | Level1s []Level1 | |
531 | Level3ID uint | |
532 | } | |
533 | Level2_1 struct { | |
534 | ID uint | |
535 | Level1s []Level1 | |
536 | Level3ID uint | |
537 | } | |
538 | Level3 struct { | |
539 | ID uint | |
540 | Name string | |
541 | Level2 Level2 | |
542 | Level2_1 Level2_1 | |
543 | } | |
544 | ) | |
545 | DB.DropTableIfExists(&Level3{}) | |
546 | DB.DropTableIfExists(&Level2{}) | |
547 | DB.DropTableIfExists(&Level2_1{}) | |
548 | DB.DropTableIfExists(&Level1{}) | |
549 | DB.DropTableIfExists(&Level0{}) | |
550 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil { | |
551 | panic(err) | |
552 | } | |
553 | ||
554 | want := make([]Level3, 2) | |
555 | want[0] = Level3{ | |
556 | Level2: Level2{ | |
557 | Level1s: []Level1{ | |
558 | Level1{Value: "value1"}, | |
559 | Level1{Value: "value2"}, | |
560 | }, | |
561 | }, | |
562 | Level2_1: Level2_1{ | |
563 | Level1s: []Level1{ | |
564 | Level1{ | |
565 | Value: "value1-1", | |
566 | Level0s: []Level0{{Value: "Level0-1"}}, | |
567 | }, | |
568 | Level1{ | |
569 | Value: "value2-2", | |
570 | Level0s: []Level0{{Value: "Level0-2"}}, | |
571 | }, | |
572 | }, | |
573 | }, | |
574 | } | |
575 | if err := DB.Create(&want[0]).Error; err != nil { | |
576 | panic(err) | |
577 | } | |
578 | want[1] = Level3{ | |
579 | Level2: Level2{ | |
580 | Level1s: []Level1{ | |
581 | Level1{Value: "value3"}, | |
582 | Level1{Value: "value4"}, | |
583 | }, | |
584 | }, | |
585 | Level2_1: Level2_1{ | |
586 | Level1s: []Level1{ | |
587 | Level1{Value: "value3-3"}, | |
588 | Level1{Value: "value4-4"}, | |
589 | }, | |
590 | }, | |
591 | } | |
592 | if err := DB.Create(&want[1]).Error; err != nil { | |
593 | panic(err) | |
594 | } | |
595 | ||
596 | var got []Level3 | |
597 | if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil { | |
598 | panic(err) | |
599 | } | |
600 | ||
601 | if !reflect.DeepEqual(got, want) { | |
602 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
603 | } | |
604 | } | |
605 | ||
606 | func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { | |
607 | if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { | |
608 | return | |
609 | } | |
610 | ||
611 | type ( | |
612 | Level1 struct { | |
613 | ID uint `gorm:"primary_key;"` | |
614 | LanguageCode string `gorm:"primary_key"` | |
615 | Value string | |
616 | } | |
617 | Level2 struct { | |
618 | ID uint `gorm:"primary_key;"` | |
619 | LanguageCode string `gorm:"primary_key"` | |
620 | Value string | |
621 | Level1s []Level1 `gorm:"many2many:levels;"` | |
622 | } | |
623 | ) | |
624 | ||
625 | DB.DropTableIfExists(&Level2{}) | |
626 | DB.DropTableIfExists(&Level1{}) | |
627 | DB.Table("levels").DropTableIfExists("levels") | |
628 | ||
629 | if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { | |
630 | panic(err) | |
631 | } | |
632 | ||
633 | want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{ | |
634 | {Value: "ru", LanguageCode: "ru"}, | |
635 | {Value: "en", LanguageCode: "en"}, | |
636 | }} | |
637 | if err := DB.Save(&want).Error; err != nil { | |
638 | panic(err) | |
639 | } | |
640 | ||
641 | want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{ | |
642 | {Value: "zh", LanguageCode: "zh"}, | |
643 | {Value: "de", LanguageCode: "de"}, | |
644 | }} | |
645 | if err := DB.Save(&want2).Error; err != nil { | |
646 | panic(err) | |
647 | } | |
648 | ||
649 | var got Level2 | |
650 | if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { | |
651 | panic(err) | |
652 | } | |
653 | ||
654 | if !reflect.DeepEqual(got, want) { | |
655 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
656 | } | |
657 | ||
658 | var got2 Level2 | |
659 | if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { | |
660 | panic(err) | |
661 | } | |
662 | ||
663 | if !reflect.DeepEqual(got2, want2) { | |
664 | t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) | |
665 | } | |
666 | ||
667 | var got3 []Level2 | |
668 | if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { | |
669 | panic(err) | |
670 | } | |
671 | ||
672 | if !reflect.DeepEqual(got3, []Level2{got, got2}) { | |
673 | t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) | |
674 | } | |
675 | ||
676 | var got4 []Level2 | |
677 | if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { | |
678 | panic(err) | |
679 | } | |
680 | ||
681 | var ruLevel1 Level1 | |
682 | var zhLevel1 Level1 | |
683 | DB.First(&ruLevel1, "value = ?", "ru") | |
684 | DB.First(&zhLevel1, "value = ?", "zh") | |
685 | ||
686 | got.Level1s = []Level1{ruLevel1} | |
687 | got2.Level1s = []Level1{zhLevel1} | |
688 | if !reflect.DeepEqual(got4, []Level2{got, got2}) { | |
689 | t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) | |
690 | } | |
691 | } | |
692 | ||
693 | func TestManyToManyPreloadForPointer(t *testing.T) { | |
694 | type ( | |
695 | Level1 struct { | |
696 | ID uint `gorm:"primary_key;"` | |
697 | Value string | |
698 | } | |
699 | Level2 struct { | |
700 | ID uint `gorm:"primary_key;"` | |
701 | Value string | |
702 | Level1s []*Level1 `gorm:"many2many:levels;"` | |
703 | } | |
704 | ) | |
705 | ||
706 | DB.DropTableIfExists(&Level2{}) | |
707 | DB.DropTableIfExists(&Level1{}) | |
708 | DB.Table("levels").DropTableIfExists("levels") | |
709 | ||
710 | if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { | |
711 | panic(err) | |
712 | } | |
713 | ||
714 | want := Level2{Value: "Bob", Level1s: []*Level1{ | |
715 | {Value: "ru"}, | |
716 | {Value: "en"}, | |
717 | }} | |
718 | if err := DB.Save(&want).Error; err != nil { | |
719 | panic(err) | |
720 | } | |
721 | ||
722 | want2 := Level2{Value: "Tom", Level1s: []*Level1{ | |
723 | {Value: "zh"}, | |
724 | {Value: "de"}, | |
725 | }} | |
726 | if err := DB.Save(&want2).Error; err != nil { | |
727 | panic(err) | |
728 | } | |
729 | ||
730 | var got Level2 | |
731 | if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { | |
732 | panic(err) | |
733 | } | |
734 | ||
735 | if !reflect.DeepEqual(got, want) { | |
736 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
737 | } | |
738 | ||
739 | var got2 Level2 | |
740 | if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { | |
741 | panic(err) | |
742 | } | |
743 | ||
744 | if !reflect.DeepEqual(got2, want2) { | |
745 | t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) | |
746 | } | |
747 | ||
748 | var got3 []Level2 | |
749 | if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { | |
750 | panic(err) | |
751 | } | |
752 | ||
753 | if !reflect.DeepEqual(got3, []Level2{got, got2}) { | |
754 | t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) | |
755 | } | |
756 | ||
757 | var got4 []Level2 | |
758 | if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { | |
759 | panic(err) | |
760 | } | |
761 | ||
762 | var ruLevel1 Level1 | |
763 | var zhLevel1 Level1 | |
764 | DB.First(&ruLevel1, "value = ?", "ru") | |
765 | DB.First(&zhLevel1, "value = ?", "zh") | |
766 | ||
767 | got.Level1s = []*Level1{&ruLevel1} | |
768 | got2.Level1s = []*Level1{&zhLevel1} | |
769 | if !reflect.DeepEqual(got4, []Level2{got, got2}) { | |
770 | t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) | |
771 | } | |
772 | } | |
773 | ||
774 | func TestNilPointerSlice(t *testing.T) { | |
775 | type ( | |
776 | Level3 struct { | |
777 | ID uint `gorm:"primary_key;"` | |
778 | Value string | |
779 | } | |
780 | Level2 struct { | |
781 | ID uint `gorm:"primary_key;"` | |
782 | Value string | |
783 | Level3ID uint | |
784 | Level3 *Level3 | |
785 | } | |
786 | Level1 struct { | |
787 | ID uint `gorm:"primary_key;"` | |
788 | Value string | |
789 | Level2ID uint | |
790 | Level2 *Level2 | |
791 | } | |
792 | ) | |
793 | ||
794 | DB.DropTableIfExists(&Level3{}) | |
795 | DB.DropTableIfExists(&Level2{}) | |
796 | DB.DropTableIfExists(&Level1{}) | |
797 | ||
798 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { | |
799 | panic(err) | |
800 | } | |
801 | ||
802 | want := Level1{Value: "Bob", Level2: &Level2{ | |
803 | Value: "en", | |
804 | Level3: &Level3{ | |
805 | Value: "native", | |
806 | }, | |
807 | }} | |
808 | if err := DB.Save(&want).Error; err != nil { | |
809 | panic(err) | |
810 | } | |
811 | ||
812 | want2 := Level1{Value: "Tom", Level2: nil} | |
813 | if err := DB.Save(&want2).Error; err != nil { | |
814 | panic(err) | |
815 | } | |
816 | ||
817 | var got []Level1 | |
818 | if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).Error; err != nil { | |
819 | panic(err) | |
820 | } | |
821 | ||
822 | if len(got) != 2 { | |
823 | t.Fatalf("got %v items, expected 2", len(got)) | |
824 | } | |
825 | ||
826 | if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { | |
827 | t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) | |
828 | } | |
829 | ||
830 | if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { | |
831 | t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want2)) | |
832 | } | |
833 | } | |
834 | ||
835 | func toJSONString(v interface{}) []byte { | |
836 | r, _ := json.MarshalIndent(v, "", " ") | |
837 | return r | |
838 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | ||
6 | "github.com/jinzhu/gorm" | |
7 | "github.com/jinzhu/now" | |
8 | ||
9 | "testing" | |
10 | "time" | |
11 | ) | |
12 | ||
13 | func TestFirstAndLast(t *testing.T) { | |
14 | DB.Save(&User{Name: "user1", Emails: []Email{{Email: "user1@example.com"}}}) | |
15 | DB.Save(&User{Name: "user2", Emails: []Email{{Email: "user2@example.com"}}}) | |
16 | ||
17 | var user1, user2, user3, user4 User | |
18 | DB.First(&user1) | |
19 | DB.Order("id").Limit(1).Find(&user2) | |
20 | ||
21 | DB.Last(&user3) | |
22 | DB.Order("id desc").Limit(1).Find(&user4) | |
23 | if user1.Id != user2.Id || user3.Id != user4.Id { | |
24 | t.Errorf("First and Last should by order by primary key") | |
25 | } | |
26 | ||
27 | var users []User | |
28 | DB.First(&users) | |
29 | if len(users) != 1 { | |
30 | t.Errorf("Find first record as slice") | |
31 | } | |
32 | ||
33 | if DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).Error != nil { | |
34 | t.Errorf("Should not raise any error when order with Join table") | |
35 | } | |
36 | } | |
37 | ||
38 | func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) { | |
39 | DB.Save(&Animal{Name: "animal1"}) | |
40 | DB.Save(&Animal{Name: "animal2"}) | |
41 | ||
42 | var animal1, animal2, animal3, animal4 Animal | |
43 | DB.First(&animal1) | |
44 | DB.Order("counter").Limit(1).Find(&animal2) | |
45 | ||
46 | DB.Last(&animal3) | |
47 | DB.Order("counter desc").Limit(1).Find(&animal4) | |
48 | if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter { | |
49 | t.Errorf("First and Last should work correctly") | |
50 | } | |
51 | } | |
52 | ||
53 | func TestUIntPrimaryKey(t *testing.T) { | |
54 | var animal Animal | |
55 | DB.First(&animal, uint64(1)) | |
56 | if animal.Counter != 1 { | |
57 | t.Errorf("Fetch a record from with a non-int primary key should work, but failed") | |
58 | } | |
59 | ||
60 | DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal) | |
61 | if animal.Counter != 2 { | |
62 | t.Errorf("Fetch a record from with a non-int primary key should work, but failed") | |
63 | } | |
64 | } | |
65 | ||
66 | func TestFindAsSliceOfPointers(t *testing.T) { | |
67 | DB.Save(&User{Name: "user"}) | |
68 | ||
69 | var users []User | |
70 | DB.Find(&users) | |
71 | ||
72 | var userPointers []*User | |
73 | DB.Find(&userPointers) | |
74 | ||
75 | if len(users) == 0 || len(users) != len(userPointers) { | |
76 | t.Errorf("Find slice of pointers") | |
77 | } | |
78 | } | |
79 | ||
80 | func TestSearchWithPlainSQL(t *testing.T) { | |
81 | user1 := User{Name: "PlainSqlUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} | |
82 | user2 := User{Name: "PlainSqlUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} | |
83 | user3 := User{Name: "PlainSqlUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} | |
84 | DB.Save(&user1).Save(&user2).Save(&user3) | |
85 | scopedb := DB.Where("name LIKE ?", "%PlainSqlUser%") | |
86 | ||
87 | if DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { | |
88 | t.Errorf("Search with plain SQL") | |
89 | } | |
90 | ||
91 | if DB.Where("name LIKE ?", "%"+user1.Name+"%").First(&User{}).RecordNotFound() { | |
92 | t.Errorf("Search with plan SQL (regexp)") | |
93 | } | |
94 | ||
95 | var users []User | |
96 | DB.Find(&users, "name LIKE ? and age > ?", "%PlainSqlUser%", 1) | |
97 | if len(users) != 2 { | |
98 | t.Errorf("Should found 2 users that age > 1, but got %v", len(users)) | |
99 | } | |
100 | ||
101 | DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users) | |
102 | if len(users) != 3 { | |
103 | t.Errorf("Should found 3 users that age >= 1, but got %v", len(users)) | |
104 | } | |
105 | ||
106 | scopedb.Where("age <> ?", 20).Find(&users) | |
107 | if len(users) != 2 { | |
108 | t.Errorf("Should found 2 users age != 20, but got %v", len(users)) | |
109 | } | |
110 | ||
111 | scopedb.Where("birthday > ?", now.MustParse("2000-1-1")).Find(&users) | |
112 | if len(users) != 2 { | |
113 | t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) | |
114 | } | |
115 | ||
116 | scopedb.Where("birthday > ?", "2002-10-10").Find(&users) | |
117 | if len(users) != 2 { | |
118 | t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) | |
119 | } | |
120 | ||
121 | scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) | |
122 | if len(users) != 1 { | |
123 | t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) | |
124 | } | |
125 | ||
126 | DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) | |
127 | if len(users) != 2 { | |
128 | t.Errorf("Should found 2 users, but got %v", len(users)) | |
129 | } | |
130 | ||
131 | DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users) | |
132 | if len(users) != 3 { | |
133 | t.Errorf("Should found 3 users, but got %v", len(users)) | |
134 | } | |
135 | ||
136 | DB.Where("id in (?)", user1.Id).Find(&users) | |
137 | if len(users) != 1 { | |
138 | t.Errorf("Should found 1 users, but got %v", len(users)) | |
139 | } | |
140 | ||
141 | if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() { | |
142 | t.Errorf("Should not get RecordNotFound error when looking for none existing records") | |
143 | } | |
144 | } | |
145 | ||
146 | func TestSearchWithStruct(t *testing.T) { | |
147 | user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} | |
148 | user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} | |
149 | user3 := User{Name: "StructSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} | |
150 | DB.Save(&user1).Save(&user2).Save(&user3) | |
151 | ||
152 | if DB.Where(user1.Id).First(&User{}).RecordNotFound() { | |
153 | t.Errorf("Search with primary key") | |
154 | } | |
155 | ||
156 | if DB.First(&User{}, user1.Id).RecordNotFound() { | |
157 | t.Errorf("Search with primary key as inline condition") | |
158 | } | |
159 | ||
160 | if DB.First(&User{}, fmt.Sprintf("%v", user1.Id)).RecordNotFound() { | |
161 | t.Errorf("Search with primary key as inline condition") | |
162 | } | |
163 | ||
164 | var users []User | |
165 | DB.Where([]int64{user1.Id, user2.Id, user3.Id}).Find(&users) | |
166 | if len(users) != 3 { | |
167 | t.Errorf("Should found 3 users when search with primary keys, but got %v", len(users)) | |
168 | } | |
169 | ||
170 | var user User | |
171 | DB.First(&user, &User{Name: user1.Name}) | |
172 | if user.Id == 0 || user.Name != user1.Name { | |
173 | t.Errorf("Search first record with inline pointer of struct") | |
174 | } | |
175 | ||
176 | DB.First(&user, User{Name: user1.Name}) | |
177 | if user.Id == 0 || user.Name != user.Name { | |
178 | t.Errorf("Search first record with inline struct") | |
179 | } | |
180 | ||
181 | DB.Where(&User{Name: user1.Name}).First(&user) | |
182 | if user.Id == 0 || user.Name != user1.Name { | |
183 | t.Errorf("Search first record with where struct") | |
184 | } | |
185 | ||
186 | DB.Find(&users, &User{Name: user2.Name}) | |
187 | if len(users) != 1 { | |
188 | t.Errorf("Search all records with inline struct") | |
189 | } | |
190 | } | |
191 | ||
192 | func TestSearchWithMap(t *testing.T) { | |
193 | user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} | |
194 | user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} | |
195 | user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} | |
196 | DB.Save(&user1).Save(&user2).Save(&user3) | |
197 | ||
198 | var user User | |
199 | DB.First(&user, map[string]interface{}{"name": user1.Name}) | |
200 | if user.Id == 0 || user.Name != user1.Name { | |
201 | t.Errorf("Search first record with inline map") | |
202 | } | |
203 | ||
204 | user = User{} | |
205 | DB.Where(map[string]interface{}{"name": user2.Name}).First(&user) | |
206 | if user.Id == 0 || user.Name != user2.Name { | |
207 | t.Errorf("Search first record with where map") | |
208 | } | |
209 | ||
210 | var users []User | |
211 | DB.Where(map[string]interface{}{"name": user3.Name}).Find(&users) | |
212 | if len(users) != 1 { | |
213 | t.Errorf("Search all records with inline map") | |
214 | } | |
215 | ||
216 | DB.Find(&users, map[string]interface{}{"name": user3.Name}) | |
217 | if len(users) != 1 { | |
218 | t.Errorf("Search all records with inline map") | |
219 | } | |
220 | } | |
221 | ||
222 | func TestSearchWithEmptyChain(t *testing.T) { | |
223 | user1 := User{Name: "ChainSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} | |
224 | user2 := User{Name: "ChainearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} | |
225 | user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} | |
226 | DB.Save(&user1).Save(&user2).Save(&user3) | |
227 | ||
228 | if DB.Where("").Where("").First(&User{}).Error != nil { | |
229 | t.Errorf("Should not raise any error if searching with empty strings") | |
230 | } | |
231 | ||
232 | if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { | |
233 | t.Errorf("Should not raise any error if searching with empty struct") | |
234 | } | |
235 | ||
236 | if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { | |
237 | t.Errorf("Should not raise any error if searching with empty map") | |
238 | } | |
239 | } | |
240 | ||
241 | func TestSelect(t *testing.T) { | |
242 | user1 := User{Name: "SelectUser1"} | |
243 | DB.Save(&user1) | |
244 | ||
245 | var user User | |
246 | DB.Where("name = ?", user1.Name).Select("name").Find(&user) | |
247 | if user.Id != 0 { | |
248 | t.Errorf("Should not have ID because only selected name, %+v", user.Id) | |
249 | } | |
250 | ||
251 | if user.Name != user1.Name { | |
252 | t.Errorf("Should have user Name when selected it") | |
253 | } | |
254 | } | |
255 | ||
256 | func TestOrderAndPluck(t *testing.T) { | |
257 | user1 := User{Name: "OrderPluckUser1", Age: 1} | |
258 | user2 := User{Name: "OrderPluckUser2", Age: 10} | |
259 | user3 := User{Name: "OrderPluckUser3", Age: 20} | |
260 | DB.Save(&user1).Save(&user2).Save(&user3) | |
261 | scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%") | |
262 | ||
263 | var ages []int64 | |
264 | scopedb.Order("age desc").Pluck("age", &ages) | |
265 | if ages[0] != 20 { | |
266 | t.Errorf("The first age should be 20 when order with age desc") | |
267 | } | |
268 | ||
269 | var ages1, ages2 []int64 | |
270 | scopedb.Order("age desc").Pluck("age", &ages1).Pluck("age", &ages2) | |
271 | if !reflect.DeepEqual(ages1, ages2) { | |
272 | t.Errorf("The first order is the primary order") | |
273 | } | |
274 | ||
275 | var ages3, ages4 []int64 | |
276 | scopedb.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4) | |
277 | if reflect.DeepEqual(ages3, ages4) { | |
278 | t.Errorf("Reorder should work") | |
279 | } | |
280 | ||
281 | var names []string | |
282 | var ages5 []int64 | |
283 | scopedb.Model(User{}).Order("name").Order("age desc").Pluck("age", &ages5).Pluck("name", &names) | |
284 | if names != nil && ages5 != nil { | |
285 | if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) { | |
286 | t.Errorf("Order with multiple orders") | |
287 | } | |
288 | } else { | |
289 | t.Errorf("Order with multiple orders") | |
290 | } | |
291 | ||
292 | DB.Model(User{}).Select("name, age").Find(&[]User{}) | |
293 | } | |
294 | ||
295 | func TestLimit(t *testing.T) { | |
296 | user1 := User{Name: "LimitUser1", Age: 1} | |
297 | user2 := User{Name: "LimitUser2", Age: 10} | |
298 | user3 := User{Name: "LimitUser3", Age: 20} | |
299 | user4 := User{Name: "LimitUser4", Age: 10} | |
300 | user5 := User{Name: "LimitUser5", Age: 20} | |
301 | DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5) | |
302 | ||
303 | var users1, users2, users3 []User | |
304 | DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) | |
305 | ||
306 | if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { | |
307 | t.Errorf("Limit should works") | |
308 | } | |
309 | } | |
310 | ||
311 | func TestOffset(t *testing.T) { | |
312 | for i := 0; i < 20; i++ { | |
313 | DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) | |
314 | } | |
315 | var users1, users2, users3, users4 []User | |
316 | DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) | |
317 | ||
318 | if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { | |
319 | t.Errorf("Offset should work") | |
320 | } | |
321 | } | |
322 | ||
323 | func TestOr(t *testing.T) { | |
324 | user1 := User{Name: "OrUser1", Age: 1} | |
325 | user2 := User{Name: "OrUser2", Age: 10} | |
326 | user3 := User{Name: "OrUser3", Age: 20} | |
327 | DB.Save(&user1).Save(&user2).Save(&user3) | |
328 | ||
329 | var users []User | |
330 | DB.Where("name = ?", user1.Name).Or("name = ?", user2.Name).Find(&users) | |
331 | if len(users) != 2 { | |
332 | t.Errorf("Find users with or") | |
333 | } | |
334 | } | |
335 | ||
336 | func TestCount(t *testing.T) { | |
337 | user1 := User{Name: "CountUser1", Age: 1} | |
338 | user2 := User{Name: "CountUser2", Age: 10} | |
339 | user3 := User{Name: "CountUser3", Age: 20} | |
340 | ||
341 | DB.Save(&user1).Save(&user2).Save(&user3) | |
342 | var count, count1, count2 int64 | |
343 | var users []User | |
344 | ||
345 | if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { | |
346 | t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) | |
347 | } | |
348 | ||
349 | if count != int64(len(users)) { | |
350 | t.Errorf("Count() method should get correct value") | |
351 | } | |
352 | ||
353 | DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in (?)", []string{user2.Name, user3.Name}).Count(&count2) | |
354 | if count1 != 1 || count2 != 3 { | |
355 | t.Errorf("Multiple count in chain") | |
356 | } | |
357 | } | |
358 | ||
359 | func TestNot(t *testing.T) { | |
360 | DB.Create(getPreparedUser("user1", "not")) | |
361 | DB.Create(getPreparedUser("user2", "not")) | |
362 | DB.Create(getPreparedUser("user3", "not")) | |
363 | DB.Create(getPreparedUser("user4", "not")) | |
364 | DB := DB.Where("role = ?", "not") | |
365 | ||
366 | var users1, users2, users3, users4, users5, users6, users7, users8 []User | |
367 | if DB.Find(&users1).RowsAffected != 4 { | |
368 | t.Errorf("should find 4 not users") | |
369 | } | |
370 | DB.Not(users1[0].Id).Find(&users2) | |
371 | ||
372 | if len(users1)-len(users2) != 1 { | |
373 | t.Errorf("Should ignore the first users with Not") | |
374 | } | |
375 | ||
376 | DB.Not([]int{}).Find(&users3) | |
377 | if len(users1)-len(users3) != 0 { | |
378 | t.Errorf("Should find all users with a blank condition") | |
379 | } | |
380 | ||
381 | var name3Count int64 | |
382 | DB.Table("users").Where("name = ?", "user3").Count(&name3Count) | |
383 | DB.Not("name", "user3").Find(&users4) | |
384 | if len(users1)-len(users4) != int(name3Count) { | |
385 | t.Errorf("Should find all users's name not equal 3") | |
386 | } | |
387 | ||
388 | DB.Not("name = ?", "user3").Find(&users4) | |
389 | if len(users1)-len(users4) != int(name3Count) { | |
390 | t.Errorf("Should find all users's name not equal 3") | |
391 | } | |
392 | ||
393 | DB.Not("name <> ?", "user3").Find(&users4) | |
394 | if len(users4) != int(name3Count) { | |
395 | t.Errorf("Should find all users's name not equal 3") | |
396 | } | |
397 | ||
398 | DB.Not(User{Name: "user3"}).Find(&users5) | |
399 | ||
400 | if len(users1)-len(users5) != int(name3Count) { | |
401 | t.Errorf("Should find all users's name not equal 3") | |
402 | } | |
403 | ||
404 | DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) | |
405 | if len(users1)-len(users6) != int(name3Count) { | |
406 | t.Errorf("Should find all users's name not equal 3") | |
407 | } | |
408 | ||
409 | DB.Not("name", []string{"user3"}).Find(&users7) | |
410 | if len(users1)-len(users7) != int(name3Count) { | |
411 | t.Errorf("Should find all users's name not equal 3") | |
412 | } | |
413 | ||
414 | var name2Count int64 | |
415 | DB.Table("users").Where("name = ?", "user2").Count(&name2Count) | |
416 | DB.Not("name", []string{"user3", "user2"}).Find(&users8) | |
417 | if len(users1)-len(users8) != (int(name3Count) + int(name2Count)) { | |
418 | t.Errorf("Should find all users's name not equal 3") | |
419 | } | |
420 | } | |
421 | ||
422 | func TestFillSmallerStruct(t *testing.T) { | |
423 | user1 := User{Name: "SmallerUser", Age: 100} | |
424 | DB.Save(&user1) | |
425 | type SimpleUser struct { | |
426 | Name string | |
427 | Id int64 | |
428 | UpdatedAt time.Time | |
429 | CreatedAt time.Time | |
430 | } | |
431 | ||
432 | var simpleUser SimpleUser | |
433 | DB.Table("users").Where("name = ?", user1.Name).First(&simpleUser) | |
434 | ||
435 | if simpleUser.Id == 0 || simpleUser.Name == "" { | |
436 | t.Errorf("Should fill data correctly into smaller struct") | |
437 | } | |
438 | } | |
439 | ||
440 | func TestFindOrInitialize(t *testing.T) { | |
441 | var user1, user2, user3, user4, user5, user6 User | |
442 | DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1) | |
443 | if user1.Name != "find or init" || user1.Id != 0 || user1.Age != 33 { | |
444 | t.Errorf("user should be initialized with search value") | |
445 | } | |
446 | ||
447 | DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2) | |
448 | if user2.Name != "find or init" || user2.Id != 0 || user2.Age != 33 { | |
449 | t.Errorf("user should be initialized with search value") | |
450 | } | |
451 | ||
452 | DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"}) | |
453 | if user3.Name != "find or init 2" || user3.Id != 0 { | |
454 | t.Errorf("user should be initialized with inline search value") | |
455 | } | |
456 | ||
457 | DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4) | |
458 | if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { | |
459 | t.Errorf("user should be initialized with search value and attrs") | |
460 | } | |
461 | ||
462 | DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4) | |
463 | if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { | |
464 | t.Errorf("user should be initialized with search value and assign attrs") | |
465 | } | |
466 | ||
467 | DB.Save(&User{Name: "find or init", Age: 33}) | |
468 | DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5) | |
469 | if user5.Name != "find or init" || user5.Id == 0 || user5.Age != 33 { | |
470 | t.Errorf("user should be found and not initialized by Attrs") | |
471 | } | |
472 | ||
473 | DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6) | |
474 | if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 33 { | |
475 | t.Errorf("user should be found with FirstOrInit") | |
476 | } | |
477 | ||
478 | DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6) | |
479 | if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 44 { | |
480 | t.Errorf("user should be found and updated with assigned attrs") | |
481 | } | |
482 | } | |
483 | ||
484 | func TestFindOrCreate(t *testing.T) { | |
485 | var user1, user2, user3, user4, user5, user6, user7, user8 User | |
486 | DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1) | |
487 | if user1.Name != "find or create" || user1.Id == 0 || user1.Age != 33 { | |
488 | t.Errorf("user should be created with search value") | |
489 | } | |
490 | ||
491 | DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2) | |
492 | if user1.Id != user2.Id || user2.Name != "find or create" || user2.Id == 0 || user2.Age != 33 { | |
493 | t.Errorf("user should be created with search value") | |
494 | } | |
495 | ||
496 | DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"}) | |
497 | if user3.Name != "find or create 2" || user3.Id == 0 { | |
498 | t.Errorf("user should be created with inline search value") | |
499 | } | |
500 | ||
501 | DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4) | |
502 | if user4.Name != "find or create 3" || user4.Id == 0 || user4.Age != 44 { | |
503 | t.Errorf("user should be created with search value and attrs") | |
504 | } | |
505 | ||
506 | updatedAt1 := user4.UpdatedAt | |
507 | DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) | |
508 | if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { | |
509 | t.Errorf("UpdateAt should be changed when update values with assign") | |
510 | } | |
511 | ||
512 | DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4) | |
513 | if user4.Name != "find or create 4" || user4.Id == 0 || user4.Age != 44 { | |
514 | t.Errorf("user should be created with search value and assigned attrs") | |
515 | } | |
516 | ||
517 | DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5) | |
518 | if user5.Name != "find or create" || user5.Id == 0 || user5.Age != 33 { | |
519 | t.Errorf("user should be found and not initialized by Attrs") | |
520 | } | |
521 | ||
522 | DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6) | |
523 | if user6.Name != "find or create" || user6.Id == 0 || user6.Age != 44 { | |
524 | t.Errorf("user should be found and updated with assigned attrs") | |
525 | } | |
526 | ||
527 | DB.Where(&User{Name: "find or create"}).Find(&user7) | |
528 | if user7.Name != "find or create" || user7.Id == 0 || user7.Age != 44 { | |
529 | t.Errorf("user should be found and updated with assigned attrs") | |
530 | } | |
531 | ||
532 | DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, CreditCard: CreditCard{Number: "1231231231"}, Emails: []Email{{Email: "jinzhu@assign_embedded_struct.com"}, {Email: "jinzhu-2@assign_embedded_struct.com"}}}).FirstOrCreate(&user8) | |
533 | if DB.Where("email = ?", "jinzhu-2@assign_embedded_struct.com").First(&Email{}).RecordNotFound() { | |
534 | t.Errorf("embedded struct email should be saved") | |
535 | } | |
536 | ||
537 | if DB.Where("email = ?", "1231231231").First(&CreditCard{}).RecordNotFound() { | |
538 | t.Errorf("embedded struct credit card should be saved") | |
539 | } | |
540 | } | |
541 | ||
542 | func TestSelectWithEscapedFieldName(t *testing.T) { | |
543 | user1 := User{Name: "EscapedFieldNameUser", Age: 1} | |
544 | user2 := User{Name: "EscapedFieldNameUser", Age: 10} | |
545 | user3 := User{Name: "EscapedFieldNameUser", Age: 20} | |
546 | DB.Save(&user1).Save(&user2).Save(&user3) | |
547 | ||
548 | var names []string | |
549 | DB.Model(User{}).Where(&User{Name: "EscapedFieldNameUser"}).Pluck("\"name\"", &names) | |
550 | ||
551 | if len(names) != 3 { | |
552 | t.Errorf("Expected 3 name, but got: %d", len(names)) | |
553 | } | |
554 | } | |
555 | ||
556 | func TestSelectWithVariables(t *testing.T) { | |
557 | DB.Save(&User{Name: "jinzhu"}) | |
558 | ||
559 | rows, _ := DB.Table("users").Select("? as fake", gorm.Expr("name")).Rows() | |
560 | ||
561 | if !rows.Next() { | |
562 | t.Errorf("Should have returned at least one row") | |
563 | } else { | |
564 | columns, _ := rows.Columns() | |
565 | if !reflect.DeepEqual(columns, []string{"fake"}) { | |
566 | t.Errorf("Should only contains one column") | |
567 | } | |
568 | } | |
569 | } | |
570 | ||
571 | func TestSelectWithArrayInput(t *testing.T) { | |
572 | DB.Save(&User{Name: "jinzhu", Age: 42}) | |
573 | ||
574 | var user User | |
575 | DB.Select([]string{"name", "age"}).Where("age = 42 AND name = 'jinzhu'").First(&user) | |
576 | ||
577 | if user.Name != "jinzhu" || user.Age != 42 { | |
578 | t.Errorf("Should have selected both age and name") | |
579 | } | |
580 | } | |
581 | ||
582 | func TestCurrentDatabase(t *testing.T) { | |
583 | databaseName := DB.CurrentDatabase() | |
584 | if err := DB.Error; err != nil { | |
585 | t.Errorf("Problem getting current db name: %s", err) | |
586 | } | |
587 | if databaseName == "" { | |
588 | t.Errorf("Current db name returned empty; this should never happen!") | |
589 | } | |
590 | t.Logf("Got current db name: %v", databaseName) | |
591 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "fmt" | |
5 | "regexp" | |
6 | "strings" | |
7 | "time" | |
8 | ||
9 | "reflect" | |
10 | ) | |
11 | ||
12 | type Scope struct { | |
13 | Search *search | |
14 | Value interface{} | |
15 | Sql string | |
16 | SqlVars []interface{} | |
17 | db *DB | |
18 | indirectValue *reflect.Value | |
19 | instanceId string | |
20 | primaryKeyField *Field | |
21 | skipLeft bool | |
22 | fields map[string]*Field | |
23 | selectAttrs *[]string | |
24 | } | |
25 | ||
26 | func (scope *Scope) IndirectValue() reflect.Value { | |
27 | if scope.indirectValue == nil { | |
28 | value := reflect.Indirect(reflect.ValueOf(scope.Value)) | |
29 | if value.Kind() == reflect.Ptr { | |
30 | value = value.Elem() | |
31 | } | |
32 | scope.indirectValue = &value | |
33 | } | |
34 | return *scope.indirectValue | |
35 | } | |
36 | ||
37 | func (scope *Scope) NeedPtr() *Scope { | |
38 | reflectKind := reflect.ValueOf(scope.Value).Kind() | |
39 | if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) { | |
40 | err := fmt.Errorf("%v %v\n", fileWithLineNum(), "using unaddressable value") | |
41 | scope.Err(err) | |
42 | fmt.Printf(err.Error()) | |
43 | } | |
44 | return scope | |
45 | } | |
46 | ||
47 | // New create a new Scope without search information | |
48 | func (scope *Scope) New(value interface{}) *Scope { | |
49 | return &Scope{db: scope.NewDB(), Search: &search{}, Value: value} | |
50 | } | |
51 | ||
52 | // NewDB create a new DB without search information | |
53 | func (scope *Scope) NewDB() *DB { | |
54 | if scope.db != nil { | |
55 | db := scope.db.clone() | |
56 | db.search = nil | |
57 | db.Value = nil | |
58 | return db | |
59 | } | |
60 | return nil | |
61 | } | |
62 | ||
63 | func (scope *Scope) DB() *DB { | |
64 | return scope.db | |
65 | } | |
66 | ||
67 | // SqlDB return *sql.DB | |
68 | func (scope *Scope) SqlDB() sqlCommon { | |
69 | return scope.db.db | |
70 | } | |
71 | ||
72 | // SkipLeft skip remaining callbacks | |
73 | func (scope *Scope) SkipLeft() { | |
74 | scope.skipLeft = true | |
75 | } | |
76 | ||
77 | // Quote used to quote database column name according to database dialect | |
78 | func (scope *Scope) Quote(str string) string { | |
79 | if strings.Index(str, ".") != -1 { | |
80 | newStrs := []string{} | |
81 | for _, str := range strings.Split(str, ".") { | |
82 | newStrs = append(newStrs, scope.Dialect().Quote(str)) | |
83 | } | |
84 | return strings.Join(newStrs, ".") | |
85 | } else { | |
86 | return scope.Dialect().Quote(str) | |
87 | } | |
88 | } | |
89 | ||
90 | func (scope *Scope) QuoteIfPossible(str string) string { | |
91 | if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) { | |
92 | return scope.Quote(str) | |
93 | } | |
94 | return str | |
95 | } | |
96 | ||
97 | // Dialect get dialect | |
98 | func (scope *Scope) Dialect() Dialect { | |
99 | return scope.db.parent.dialect | |
100 | } | |
101 | ||
102 | // Err write error | |
103 | func (scope *Scope) Err(err error) error { | |
104 | if err != nil { | |
105 | scope.db.AddError(err) | |
106 | } | |
107 | return err | |
108 | } | |
109 | ||
110 | // Log print log message | |
111 | func (scope *Scope) Log(v ...interface{}) { | |
112 | scope.db.log(v...) | |
113 | } | |
114 | ||
115 | // HasError check if there are any error | |
116 | func (scope *Scope) HasError() bool { | |
117 | return scope.db.Error != nil | |
118 | } | |
119 | ||
120 | func (scope *Scope) PrimaryFields() []*Field { | |
121 | var fields = []*Field{} | |
122 | for _, field := range scope.GetModelStruct().PrimaryFields { | |
123 | fields = append(fields, scope.Fields()[field.DBName]) | |
124 | } | |
125 | return fields | |
126 | } | |
127 | ||
128 | func (scope *Scope) PrimaryField() *Field { | |
129 | if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { | |
130 | if len(primaryFields) > 1 { | |
131 | if field, ok := scope.Fields()["id"]; ok { | |
132 | return field | |
133 | } | |
134 | } | |
135 | return scope.Fields()[primaryFields[0].DBName] | |
136 | } | |
137 | return nil | |
138 | } | |
139 | ||
140 | // PrimaryKey get the primary key's column name | |
141 | func (scope *Scope) PrimaryKey() string { | |
142 | if field := scope.PrimaryField(); field != nil { | |
143 | return field.DBName | |
144 | } | |
145 | return "" | |
146 | } | |
147 | ||
148 | // PrimaryKeyZero check the primary key is blank or not | |
149 | func (scope *Scope) PrimaryKeyZero() bool { | |
150 | field := scope.PrimaryField() | |
151 | return field == nil || field.IsBlank | |
152 | } | |
153 | ||
154 | // PrimaryKeyValue get the primary key's value | |
155 | func (scope *Scope) PrimaryKeyValue() interface{} { | |
156 | if field := scope.PrimaryField(); field != nil && field.Field.IsValid() { | |
157 | return field.Field.Interface() | |
158 | } | |
159 | return 0 | |
160 | } | |
161 | ||
162 | // HasColumn to check if has column | |
163 | func (scope *Scope) HasColumn(column string) bool { | |
164 | for _, field := range scope.GetStructFields() { | |
165 | if field.IsNormal && (field.Name == column || field.DBName == column) { | |
166 | return true | |
167 | } | |
168 | } | |
169 | return false | |
170 | } | |
171 | ||
172 | // SetColumn to set the column's value | |
173 | func (scope *Scope) SetColumn(column interface{}, value interface{}) error { | |
174 | if field, ok := column.(*Field); ok { | |
175 | return field.Set(value) | |
176 | } else if name, ok := column.(string); ok { | |
177 | ||
178 | if field, ok := scope.Fields()[name]; ok { | |
179 | return field.Set(value) | |
180 | } | |
181 | ||
182 | dbName := ToDBName(name) | |
183 | if field, ok := scope.Fields()[dbName]; ok { | |
184 | return field.Set(value) | |
185 | } | |
186 | ||
187 | if field, ok := scope.FieldByName(name); ok { | |
188 | return field.Set(value) | |
189 | } | |
190 | } | |
191 | return errors.New("could not convert column to field") | |
192 | } | |
193 | ||
194 | func (scope *Scope) CallMethod(name string, checkError bool) { | |
195 | if scope.Value == nil || (checkError && scope.HasError()) { | |
196 | return | |
197 | } | |
198 | ||
199 | call := func(value interface{}) { | |
200 | if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() { | |
201 | switch f := fm.Interface().(type) { | |
202 | case func(): | |
203 | f() | |
204 | case func(s *Scope): | |
205 | f(scope) | |
206 | case func(s *DB): | |
207 | newDB := scope.NewDB() | |
208 | f(newDB) | |
209 | scope.Err(newDB.Error) | |
210 | case func() error: | |
211 | scope.Err(f()) | |
212 | case func(s *Scope) error: | |
213 | scope.Err(f(scope)) | |
214 | case func(s *DB) error: | |
215 | newDB := scope.NewDB() | |
216 | scope.Err(f(newDB)) | |
217 | scope.Err(newDB.Error) | |
218 | default: | |
219 | scope.Err(fmt.Errorf("unsupported function %v", name)) | |
220 | } | |
221 | } | |
222 | } | |
223 | ||
224 | if values := scope.IndirectValue(); values.Kind() == reflect.Slice { | |
225 | for i := 0; i < values.Len(); i++ { | |
226 | value := values.Index(i).Addr().Interface() | |
227 | if values.Index(i).Kind() == reflect.Ptr { | |
228 | value = values.Index(i).Interface() | |
229 | } | |
230 | call(value) | |
231 | } | |
232 | } else { | |
233 | if scope.IndirectValue().CanAddr() { | |
234 | call(scope.IndirectValue().Addr().Interface()) | |
235 | } else { | |
236 | call(scope.IndirectValue().Interface()) | |
237 | } | |
238 | } | |
239 | } | |
240 | ||
241 | func (scope *Scope) CallMethodWithErrorCheck(name string) { | |
242 | scope.CallMethod(name, true) | |
243 | } | |
244 | ||
245 | // AddToVars add value as sql's vars, gorm will escape them | |
246 | func (scope *Scope) AddToVars(value interface{}) string { | |
247 | if expr, ok := value.(*expr); ok { | |
248 | exp := expr.expr | |
249 | for _, arg := range expr.args { | |
250 | exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) | |
251 | } | |
252 | return exp | |
253 | } else { | |
254 | scope.SqlVars = append(scope.SqlVars, value) | |
255 | return scope.Dialect().BinVar(len(scope.SqlVars)) | |
256 | } | |
257 | } | |
258 | ||
259 | type tabler interface { | |
260 | TableName() string | |
261 | } | |
262 | ||
263 | type dbTabler interface { | |
264 | TableName(*DB) string | |
265 | } | |
266 | ||
267 | // TableName get table name | |
268 | func (scope *Scope) TableName() string { | |
269 | if scope.Search != nil && len(scope.Search.tableName) > 0 { | |
270 | return scope.Search.tableName | |
271 | } | |
272 | ||
273 | if tabler, ok := scope.Value.(tabler); ok { | |
274 | return tabler.TableName() | |
275 | } | |
276 | ||
277 | if tabler, ok := scope.Value.(dbTabler); ok { | |
278 | return tabler.TableName(scope.db) | |
279 | } | |
280 | ||
281 | return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) | |
282 | } | |
283 | ||
284 | func (scope *Scope) QuotedTableName() (name string) { | |
285 | if scope.Search != nil && len(scope.Search.tableName) > 0 { | |
286 | if strings.Index(scope.Search.tableName, " ") != -1 { | |
287 | return scope.Search.tableName | |
288 | } | |
289 | return scope.Quote(scope.Search.tableName) | |
290 | } else { | |
291 | return scope.Quote(scope.TableName()) | |
292 | } | |
293 | } | |
294 | ||
295 | // CombinedConditionSql get combined condition sql | |
296 | func (scope *Scope) CombinedConditionSql() string { | |
297 | return scope.joinsSql() + scope.whereSql() + scope.groupSql() + | |
298 | scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql() | |
299 | } | |
300 | ||
301 | func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { | |
302 | for _, field := range scope.Fields() { | |
303 | if field.Name == name || field.DBName == name { | |
304 | return field, true | |
305 | } | |
306 | } | |
307 | return nil, false | |
308 | } | |
309 | ||
310 | // Raw set sql | |
311 | func (scope *Scope) Raw(sql string) *Scope { | |
312 | scope.Sql = strings.Replace(sql, "$$", "?", -1) | |
313 | return scope | |
314 | } | |
315 | ||
316 | // Exec invoke sql | |
317 | func (scope *Scope) Exec() *Scope { | |
318 | defer scope.Trace(NowFunc()) | |
319 | ||
320 | if !scope.HasError() { | |
321 | if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { | |
322 | if count, err := result.RowsAffected(); scope.Err(err) == nil { | |
323 | scope.db.RowsAffected = count | |
324 | } | |
325 | } | |
326 | } | |
327 | return scope | |
328 | } | |
329 | ||
330 | // Set set value by name | |
331 | func (scope *Scope) Set(name string, value interface{}) *Scope { | |
332 | scope.db.InstantSet(name, value) | |
333 | return scope | |
334 | } | |
335 | ||
336 | // Get get value by name | |
337 | func (scope *Scope) Get(name string) (interface{}, bool) { | |
338 | return scope.db.Get(name) | |
339 | } | |
340 | ||
341 | // InstanceId get InstanceId for scope | |
342 | func (scope *Scope) InstanceId() string { | |
343 | if scope.instanceId == "" { | |
344 | scope.instanceId = fmt.Sprintf("%v%v", &scope, &scope.db) | |
345 | } | |
346 | return scope.instanceId | |
347 | } | |
348 | ||
349 | func (scope *Scope) InstanceSet(name string, value interface{}) *Scope { | |
350 | return scope.Set(name+scope.InstanceId(), value) | |
351 | } | |
352 | ||
353 | func (scope *Scope) InstanceGet(name string) (interface{}, bool) { | |
354 | return scope.Get(name + scope.InstanceId()) | |
355 | } | |
356 | ||
357 | // Trace print sql log | |
358 | func (scope *Scope) Trace(t time.Time) { | |
359 | if len(scope.Sql) > 0 { | |
360 | scope.db.slog(scope.Sql, t, scope.SqlVars...) | |
361 | } | |
362 | } | |
363 | ||
364 | // Begin start a transaction | |
365 | func (scope *Scope) Begin() *Scope { | |
366 | if db, ok := scope.SqlDB().(sqlDb); ok { | |
367 | if tx, err := db.Begin(); err == nil { | |
368 | scope.db.db = interface{}(tx).(sqlCommon) | |
369 | scope.InstanceSet("gorm:started_transaction", true) | |
370 | } | |
371 | } | |
372 | return scope | |
373 | } | |
374 | ||
375 | // CommitOrRollback commit current transaction if there is no error, otherwise rollback it | |
376 | func (scope *Scope) CommitOrRollback() *Scope { | |
377 | if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { | |
378 | if db, ok := scope.db.db.(sqlTx); ok { | |
379 | if scope.HasError() { | |
380 | db.Rollback() | |
381 | } else { | |
382 | db.Commit() | |
383 | } | |
384 | scope.db.db = scope.db.parent.db | |
385 | } | |
386 | } | |
387 | return scope | |
388 | } | |
389 | ||
390 | func (scope *Scope) SelectAttrs() []string { | |
391 | if scope.selectAttrs == nil { | |
392 | attrs := []string{} | |
393 | for _, value := range scope.Search.selects { | |
394 | if str, ok := value.(string); ok { | |
395 | attrs = append(attrs, str) | |
396 | } else if strs, ok := value.([]string); ok { | |
397 | attrs = append(attrs, strs...) | |
398 | } else if strs, ok := value.([]interface{}); ok { | |
399 | for _, str := range strs { | |
400 | attrs = append(attrs, fmt.Sprintf("%v", str)) | |
401 | } | |
402 | } | |
403 | } | |
404 | scope.selectAttrs = &attrs | |
405 | } | |
406 | return *scope.selectAttrs | |
407 | } | |
408 | ||
409 | func (scope *Scope) OmitAttrs() []string { | |
410 | return scope.Search.omits | |
411 | } | |
412 | ||
413 | func (scope *Scope) changeableDBColumn(column string) bool { | |
414 | selectAttrs := scope.SelectAttrs() | |
415 | omitAttrs := scope.OmitAttrs() | |
416 | ||
417 | if len(selectAttrs) > 0 { | |
418 | for _, attr := range selectAttrs { | |
419 | if column == ToDBName(attr) { | |
420 | return true | |
421 | } | |
422 | } | |
423 | return false | |
424 | } | |
425 | ||
426 | for _, attr := range omitAttrs { | |
427 | if column == ToDBName(attr) { | |
428 | return false | |
429 | } | |
430 | } | |
431 | return true | |
432 | } | |
433 | ||
434 | func (scope *Scope) changeableField(field *Field) bool { | |
435 | selectAttrs := scope.SelectAttrs() | |
436 | omitAttrs := scope.OmitAttrs() | |
437 | ||
438 | if len(selectAttrs) > 0 { | |
439 | for _, attr := range selectAttrs { | |
440 | if field.Name == attr || field.DBName == attr { | |
441 | return true | |
442 | } | |
443 | } | |
444 | return false | |
445 | } | |
446 | ||
447 | for _, attr := range omitAttrs { | |
448 | if field.Name == attr || field.DBName == attr { | |
449 | return false | |
450 | } | |
451 | } | |
452 | ||
453 | return !field.IsIgnored | |
454 | } | |
455 | ||
456 | func (scope *Scope) shouldSaveAssociations() bool { | |
457 | saveAssociations, ok := scope.Get("gorm:save_associations") | |
458 | if ok && !saveAssociations.(bool) { | |
459 | return false | |
460 | } | |
461 | return true && !scope.HasError() | |
462 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "database/sql" | |
4 | "database/sql/driver" | |
5 | "fmt" | |
6 | "reflect" | |
7 | "regexp" | |
8 | "strconv" | |
9 | "strings" | |
10 | ) | |
11 | ||
12 | func (scope *Scope) primaryCondition(value interface{}) string { | |
13 | return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value) | |
14 | } | |
15 | ||
16 | func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { | |
17 | switch value := clause["query"].(type) { | |
18 | case string: | |
19 | // if string is number | |
20 | if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { | |
21 | id, _ := strconv.Atoi(value) | |
22 | return scope.primaryCondition(scope.AddToVars(id)) | |
23 | } else if value != "" { | |
24 | str = fmt.Sprintf("(%v)", value) | |
25 | } | |
26 | case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: | |
27 | return scope.primaryCondition(scope.AddToVars(value)) | |
28 | case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: | |
29 | str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey())) | |
30 | clause["args"] = []interface{}{value} | |
31 | case map[string]interface{}: | |
32 | var sqls []string | |
33 | for key, value := range value { | |
34 | sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value))) | |
35 | } | |
36 | return strings.Join(sqls, " AND ") | |
37 | case interface{}: | |
38 | var sqls []string | |
39 | for _, field := range scope.New(value).Fields() { | |
40 | if !field.IsIgnored && !field.IsBlank { | |
41 | sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) | |
42 | } | |
43 | } | |
44 | return strings.Join(sqls, " AND ") | |
45 | } | |
46 | ||
47 | args := clause["args"].([]interface{}) | |
48 | for _, arg := range args { | |
49 | switch reflect.ValueOf(arg).Kind() { | |
50 | case reflect.Slice: // For where("id in (?)", []int64{1,2}) | |
51 | values := reflect.ValueOf(arg) | |
52 | var tempMarks []string | |
53 | for i := 0; i < values.Len(); i++ { | |
54 | tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) | |
55 | } | |
56 | str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) | |
57 | default: | |
58 | if valuer, ok := interface{}(arg).(driver.Valuer); ok { | |
59 | arg, _ = valuer.Value() | |
60 | } | |
61 | ||
62 | str = strings.Replace(str, "?", scope.AddToVars(arg), 1) | |
63 | } | |
64 | } | |
65 | return | |
66 | } | |
67 | ||
68 | func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { | |
69 | var notEqualSql string | |
70 | var primaryKey = scope.PrimaryKey() | |
71 | ||
72 | switch value := clause["query"].(type) { | |
73 | case string: | |
74 | // is number | |
75 | if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { | |
76 | id, _ := strconv.Atoi(value) | |
77 | return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) | |
78 | } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) { | |
79 | str = fmt.Sprintf(" NOT (%v) ", value) | |
80 | notEqualSql = fmt.Sprintf("NOT (%v)", value) | |
81 | } else { | |
82 | str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) | |
83 | notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) | |
84 | } | |
85 | case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: | |
86 | return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value) | |
87 | case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: | |
88 | if reflect.ValueOf(value).Len() > 0 { | |
89 | str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey)) | |
90 | clause["args"] = []interface{}{value} | |
91 | } | |
92 | return "" | |
93 | case map[string]interface{}: | |
94 | var sqls []string | |
95 | for key, value := range value { | |
96 | sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value))) | |
97 | } | |
98 | return strings.Join(sqls, " AND ") | |
99 | case interface{}: | |
100 | var sqls []string | |
101 | for _, field := range scope.New(value).Fields() { | |
102 | if !field.IsBlank { | |
103 | sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) | |
104 | } | |
105 | } | |
106 | return strings.Join(sqls, " AND ") | |
107 | } | |
108 | ||
109 | args := clause["args"].([]interface{}) | |
110 | for _, arg := range args { | |
111 | switch reflect.ValueOf(arg).Kind() { | |
112 | case reflect.Slice: // For where("id in (?)", []int64{1,2}) | |
113 | values := reflect.ValueOf(arg) | |
114 | var tempMarks []string | |
115 | for i := 0; i < values.Len(); i++ { | |
116 | tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) | |
117 | } | |
118 | str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) | |
119 | default: | |
120 | if scanner, ok := interface{}(arg).(driver.Valuer); ok { | |
121 | arg, _ = scanner.Value() | |
122 | } | |
123 | str = strings.Replace(notEqualSql, "?", scope.AddToVars(arg), 1) | |
124 | } | |
125 | } | |
126 | return | |
127 | } | |
128 | ||
129 | func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { | |
130 | switch value := clause["query"].(type) { | |
131 | case string: | |
132 | str = value | |
133 | case []string: | |
134 | str = strings.Join(value, ", ") | |
135 | } | |
136 | ||
137 | args := clause["args"].([]interface{}) | |
138 | for _, arg := range args { | |
139 | switch reflect.ValueOf(arg).Kind() { | |
140 | case reflect.Slice: | |
141 | values := reflect.ValueOf(arg) | |
142 | var tempMarks []string | |
143 | for i := 0; i < values.Len(); i++ { | |
144 | tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) | |
145 | } | |
146 | str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) | |
147 | default: | |
148 | if valuer, ok := interface{}(arg).(driver.Valuer); ok { | |
149 | arg, _ = valuer.Value() | |
150 | } | |
151 | str = strings.Replace(str, "?", scope.AddToVars(arg), 1) | |
152 | } | |
153 | } | |
154 | return | |
155 | } | |
156 | ||
157 | func (scope *Scope) whereSql() (sql string) { | |
158 | var primaryConditions, andConditions, orConditions []string | |
159 | ||
160 | if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil { | |
161 | sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName()) | |
162 | primaryConditions = append(primaryConditions, sql) | |
163 | } | |
164 | ||
165 | if !scope.PrimaryKeyZero() { | |
166 | for _, field := range scope.PrimaryFields() { | |
167 | sql := fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) | |
168 | primaryConditions = append(primaryConditions, sql) | |
169 | } | |
170 | } | |
171 | ||
172 | for _, clause := range scope.Search.whereConditions { | |
173 | if sql := scope.buildWhereCondition(clause); sql != "" { | |
174 | andConditions = append(andConditions, sql) | |
175 | } | |
176 | } | |
177 | ||
178 | for _, clause := range scope.Search.orConditions { | |
179 | if sql := scope.buildWhereCondition(clause); sql != "" { | |
180 | orConditions = append(orConditions, sql) | |
181 | } | |
182 | } | |
183 | ||
184 | for _, clause := range scope.Search.notConditions { | |
185 | if sql := scope.buildNotCondition(clause); sql != "" { | |
186 | andConditions = append(andConditions, sql) | |
187 | } | |
188 | } | |
189 | ||
190 | orSql := strings.Join(orConditions, " OR ") | |
191 | combinedSql := strings.Join(andConditions, " AND ") | |
192 | if len(combinedSql) > 0 { | |
193 | if len(orSql) > 0 { | |
194 | combinedSql = combinedSql + " OR " + orSql | |
195 | } | |
196 | } else { | |
197 | combinedSql = orSql | |
198 | } | |
199 | ||
200 | if len(primaryConditions) > 0 { | |
201 | sql = "WHERE " + strings.Join(primaryConditions, " AND ") | |
202 | if len(combinedSql) > 0 { | |
203 | sql = sql + " AND (" + combinedSql + ")" | |
204 | } | |
205 | } else if len(combinedSql) > 0 { | |
206 | sql = "WHERE " + combinedSql | |
207 | } | |
208 | return | |
209 | } | |
210 | ||
211 | var hasCountRegexp = regexp.MustCompile(`(?i)count(.+)`) | |
212 | ||
213 | func (scope *Scope) selectSql() string { | |
214 | if len(scope.Search.selects) == 0 { | |
215 | if scope.Search.joins != "" { | |
216 | return fmt.Sprintf("%v.*", scope.QuotedTableName()) | |
217 | } | |
218 | return "*" | |
219 | } | |
220 | sql := scope.buildSelectQuery(scope.Search.selects) | |
221 | scope.Search.countingQuery = (len(scope.Search.group) == 0) && hasCountRegexp.MatchString(sql) | |
222 | return sql | |
223 | } | |
224 | ||
225 | func (scope *Scope) orderSql() string { | |
226 | if len(scope.Search.orders) == 0 || scope.Search.countingQuery { | |
227 | return "" | |
228 | } | |
229 | return " ORDER BY " + strings.Join(scope.Search.orders, ",") | |
230 | } | |
231 | ||
232 | func (scope *Scope) limitSql() string { | |
233 | if !scope.Dialect().HasTop() { | |
234 | if len(scope.Search.limit) == 0 { | |
235 | return "" | |
236 | } | |
237 | return " LIMIT " + scope.Search.limit | |
238 | } | |
239 | ||
240 | return "" | |
241 | } | |
242 | ||
243 | func (scope *Scope) topSql() string { | |
244 | if scope.Dialect().HasTop() && len(scope.Search.offset) == 0 { | |
245 | if len(scope.Search.limit) == 0 { | |
246 | return "" | |
247 | } | |
248 | return " TOP(" + scope.Search.limit + ")" | |
249 | } | |
250 | ||
251 | return "" | |
252 | } | |
253 | ||
254 | func (scope *Scope) offsetSql() string { | |
255 | if len(scope.Search.offset) == 0 { | |
256 | return "" | |
257 | } | |
258 | ||
259 | if scope.Dialect().HasTop() { | |
260 | sql := " OFFSET " + scope.Search.offset + " ROW " | |
261 | if len(scope.Search.limit) > 0 { | |
262 | sql += "FETCH NEXT " + scope.Search.limit + " ROWS ONLY" | |
263 | } | |
264 | return sql | |
265 | } | |
266 | return " OFFSET " + scope.Search.offset | |
267 | } | |
268 | ||
269 | func (scope *Scope) groupSql() string { | |
270 | if len(scope.Search.group) == 0 { | |
271 | return "" | |
272 | } | |
273 | return " GROUP BY " + scope.Search.group | |
274 | } | |
275 | ||
276 | func (scope *Scope) havingSql() string { | |
277 | if scope.Search.havingConditions == nil { | |
278 | return "" | |
279 | } | |
280 | ||
281 | var andConditions []string | |
282 | ||
283 | for _, clause := range scope.Search.havingConditions { | |
284 | if sql := scope.buildWhereCondition(clause); sql != "" { | |
285 | andConditions = append(andConditions, sql) | |
286 | } | |
287 | } | |
288 | ||
289 | combinedSql := strings.Join(andConditions, " AND ") | |
290 | if len(combinedSql) == 0 { | |
291 | return "" | |
292 | } | |
293 | ||
294 | return " HAVING " + combinedSql | |
295 | } | |
296 | ||
297 | func (scope *Scope) joinsSql() string { | |
298 | return scope.Search.joins + " " | |
299 | } | |
300 | ||
301 | func (scope *Scope) prepareQuerySql() { | |
302 | if scope.Search.raw { | |
303 | scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")")) | |
304 | } else { | |
305 | scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql())) | |
306 | } | |
307 | return | |
308 | } | |
309 | ||
310 | func (scope *Scope) inlineCondition(values ...interface{}) *Scope { | |
311 | if len(values) > 0 { | |
312 | scope.Search.Where(values[0], values[1:]...) | |
313 | } | |
314 | return scope | |
315 | } | |
316 | ||
317 | func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { | |
318 | for _, f := range funcs { | |
319 | (*f)(scope) | |
320 | if scope.skipLeft { | |
321 | break | |
322 | } | |
323 | } | |
324 | return scope | |
325 | } | |
326 | ||
327 | func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) { | |
328 | if !scope.IndirectValue().CanAddr() { | |
329 | return values, true | |
330 | } | |
331 | ||
332 | var hasExpr bool | |
333 | fields := scope.Fields() | |
334 | for key, value := range values { | |
335 | if field, ok := fields[ToDBName(key)]; ok && field.Field.IsValid() { | |
336 | if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) { | |
337 | if _, ok := value.(*expr); ok { | |
338 | hasExpr = true | |
339 | } else if !equalAsString(field.Field.Interface(), value) { | |
340 | hasUpdate = true | |
341 | field.Set(value) | |
342 | } | |
343 | } | |
344 | } | |
345 | } | |
346 | if hasExpr { | |
347 | var updateMap = map[string]interface{}{} | |
348 | for key, value := range fields { | |
349 | if v, ok := values[key]; ok { | |
350 | updateMap[key] = v | |
351 | } else { | |
352 | updateMap[key] = value.Field.Interface() | |
353 | } | |
354 | } | |
355 | return updateMap, true | |
356 | } | |
357 | return | |
358 | } | |
359 | ||
360 | func (scope *Scope) row() *sql.Row { | |
361 | defer scope.Trace(NowFunc()) | |
362 | scope.callCallbacks(scope.db.parent.callback.rowQueries) | |
363 | scope.prepareQuerySql() | |
364 | return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...) | |
365 | } | |
366 | ||
367 | func (scope *Scope) rows() (*sql.Rows, error) { | |
368 | defer scope.Trace(NowFunc()) | |
369 | scope.callCallbacks(scope.db.parent.callback.rowQueries) | |
370 | scope.prepareQuerySql() | |
371 | return scope.SqlDB().Query(scope.Sql, scope.SqlVars...) | |
372 | } | |
373 | ||
374 | func (scope *Scope) initialize() *Scope { | |
375 | for _, clause := range scope.Search.whereConditions { | |
376 | scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false) | |
377 | } | |
378 | scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false) | |
379 | scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false) | |
380 | return scope | |
381 | } | |
382 | ||
383 | func (scope *Scope) pluck(column string, value interface{}) *Scope { | |
384 | dest := reflect.Indirect(reflect.ValueOf(value)) | |
385 | scope.Search.Select(column) | |
386 | if dest.Kind() != reflect.Slice { | |
387 | scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) | |
388 | return scope | |
389 | } | |
390 | ||
391 | rows, err := scope.rows() | |
392 | if scope.Err(err) == nil { | |
393 | defer rows.Close() | |
394 | for rows.Next() { | |
395 | elem := reflect.New(dest.Type().Elem()).Interface() | |
396 | scope.Err(rows.Scan(elem)) | |
397 | dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) | |
398 | } | |
399 | } | |
400 | return scope | |
401 | } | |
402 | ||
403 | func (scope *Scope) count(value interface{}) *Scope { | |
404 | scope.Search.Select("count(*)") | |
405 | scope.Err(scope.row().Scan(value)) | |
406 | return scope | |
407 | } | |
408 | ||
409 | func (scope *Scope) typeName() string { | |
410 | value := scope.IndirectValue() | |
411 | if value.Kind() == reflect.Slice { | |
412 | return value.Type().Elem().Name() | |
413 | } | |
414 | ||
415 | return value.Type().Name() | |
416 | } | |
417 | ||
418 | func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { | |
419 | toScope := scope.db.NewScope(value) | |
420 | fromFields := scope.Fields() | |
421 | toFields := toScope.Fields() | |
422 | for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { | |
423 | var fromField, toField *Field | |
424 | if field, ok := scope.FieldByName(foreignKey); ok { | |
425 | fromField = field | |
426 | } else { | |
427 | fromField = fromFields[ToDBName(foreignKey)] | |
428 | } | |
429 | if field, ok := toScope.FieldByName(foreignKey); ok { | |
430 | toField = field | |
431 | } else { | |
432 | toField = toFields[ToDBName(foreignKey)] | |
433 | } | |
434 | ||
435 | if fromField != nil { | |
436 | if relationship := fromField.Relationship; relationship != nil { | |
437 | if relationship.Kind == "many_to_many" { | |
438 | joinTableHandler := relationship.JoinTableHandler | |
439 | scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) | |
440 | } else if relationship.Kind == "belongs_to" { | |
441 | query := toScope.db | |
442 | for idx, foreignKey := range relationship.ForeignDBNames { | |
443 | if field, ok := scope.FieldByName(foreignKey); ok { | |
444 | query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) | |
445 | } | |
446 | } | |
447 | scope.Err(query.Find(value).Error) | |
448 | } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { | |
449 | query := toScope.db | |
450 | for idx, foreignKey := range relationship.ForeignDBNames { | |
451 | if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { | |
452 | query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) | |
453 | } | |
454 | } | |
455 | ||
456 | if relationship.PolymorphicType != "" { | |
457 | query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) | |
458 | } | |
459 | scope.Err(query.Find(value).Error) | |
460 | } | |
461 | } else { | |
462 | sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) | |
463 | scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error) | |
464 | } | |
465 | return scope | |
466 | } else if toField != nil { | |
467 | sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) | |
468 | scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) | |
469 | return scope | |
470 | } | |
471 | } | |
472 | ||
473 | scope.Err(fmt.Errorf("invalid association %v", foreignKeys)) | |
474 | return scope | |
475 | } | |
476 | ||
477 | /** | |
478 | Return the table options string or an empty string if the table options does not exist | |
479 | */ | |
480 | func (scope *Scope) getTableOptions() string { | |
481 | tableOptions, ok := scope.Get("gorm:table_options") | |
482 | if !ok { | |
483 | return "" | |
484 | } | |
485 | return tableOptions.(string) | |
486 | } | |
487 | ||
488 | func (scope *Scope) createJoinTable(field *StructField) { | |
489 | if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { | |
490 | joinTableHandler := relationship.JoinTableHandler | |
491 | joinTable := joinTableHandler.Table(scope.db) | |
492 | if !scope.Dialect().HasTable(scope, joinTable) { | |
493 | toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} | |
494 | ||
495 | var sqlTypes []string | |
496 | for idx, fieldName := range relationship.ForeignFieldNames { | |
497 | if field, ok := scope.Fields()[fieldName]; ok { | |
498 | value := reflect.Indirect(reflect.New(field.Struct.Type)) | |
499 | primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false) | |
500 | sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType) | |
501 | } | |
502 | } | |
503 | ||
504 | for idx, fieldName := range relationship.AssociationForeignFieldNames { | |
505 | if field, ok := toScope.Fields()[fieldName]; ok { | |
506 | value := reflect.Indirect(reflect.New(field.Struct.Type)) | |
507 | primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false) | |
508 | sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType) | |
509 | } | |
510 | } | |
511 | scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), scope.getTableOptions())).Error) | |
512 | } | |
513 | scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) | |
514 | } | |
515 | } | |
516 | ||
517 | func (scope *Scope) createTable() *Scope { | |
518 | var tags []string | |
519 | var primaryKeys []string | |
520 | var primaryKeyInColumnType bool = false | |
521 | for _, field := range scope.GetStructFields() { | |
522 | if field.IsNormal { | |
523 | sqlTag := scope.generateSqlTag(field) | |
524 | ||
525 | // Check if the primary key constraint was specified as | |
526 | // part of the column type. If so, we can only support | |
527 | // one column as the primary key. | |
528 | if strings.Contains(strings.ToLower(sqlTag), "primary key") { | |
529 | primaryKeyInColumnType = true | |
530 | } | |
531 | ||
532 | tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag) | |
533 | } | |
534 | ||
535 | if field.IsPrimaryKey { | |
536 | primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) | |
537 | } | |
538 | scope.createJoinTable(field) | |
539 | } | |
540 | ||
541 | var primaryKeyStr string | |
542 | if len(primaryKeys) > 0 && !primaryKeyInColumnType { | |
543 | primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) | |
544 | } | |
545 | scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() | |
546 | return scope | |
547 | } | |
548 | ||
549 | func (scope *Scope) dropTable() *Scope { | |
550 | scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() | |
551 | return scope | |
552 | } | |
553 | ||
554 | func (scope *Scope) dropTableIfExists() *Scope { | |
555 | if scope.Dialect().HasTable(scope, scope.TableName()) { | |
556 | scope.dropTable() | |
557 | } | |
558 | return scope | |
559 | } | |
560 | ||
561 | func (scope *Scope) modifyColumn(column string, typ string) { | |
562 | scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() | |
563 | } | |
564 | ||
565 | func (scope *Scope) dropColumn(column string) { | |
566 | scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() | |
567 | } | |
568 | ||
569 | func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { | |
570 | if scope.Dialect().HasIndex(scope, scope.TableName(), indexName) { | |
571 | return | |
572 | } | |
573 | ||
574 | var columns []string | |
575 | for _, name := range column { | |
576 | columns = append(columns, scope.QuoteIfPossible(name)) | |
577 | } | |
578 | ||
579 | sqlCreate := "CREATE INDEX" | |
580 | if unique { | |
581 | sqlCreate = "CREATE UNIQUE INDEX" | |
582 | } | |
583 | ||
584 | scope.Search.Unscoped = true | |
585 | scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSql())).Exec() | |
586 | scope.Search.Unscoped = false | |
587 | } | |
588 | ||
589 | func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { | |
590 | var table = scope.TableName() | |
591 | var keyName = fmt.Sprintf("%s_%s_%s_foreign", table, field, regexp.MustCompile("[^a-zA-Z]").ReplaceAllString(dest, "_")) | |
592 | keyName = regexp.MustCompile("_+").ReplaceAllString(keyName, "_") | |
593 | var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` | |
594 | scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.QuoteIfPossible(keyName), scope.QuoteIfPossible(field), dest, onDelete, onUpdate)).Exec() | |
595 | } | |
596 | ||
597 | func (scope *Scope) removeIndex(indexName string) { | |
598 | scope.Dialect().RemoveIndex(scope, indexName) | |
599 | } | |
600 | ||
601 | func (scope *Scope) autoMigrate() *Scope { | |
602 | tableName := scope.TableName() | |
603 | quotedTableName := scope.QuotedTableName() | |
604 | ||
605 | if !scope.Dialect().HasTable(scope, tableName) { | |
606 | scope.createTable() | |
607 | } else { | |
608 | for _, field := range scope.GetStructFields() { | |
609 | if !scope.Dialect().HasColumn(scope, tableName, field.DBName) { | |
610 | if field.IsNormal { | |
611 | sqlTag := scope.generateSqlTag(field) | |
612 | scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() | |
613 | } | |
614 | } | |
615 | scope.createJoinTable(field) | |
616 | } | |
617 | } | |
618 | ||
619 | scope.autoIndex() | |
620 | return scope | |
621 | } | |
622 | ||
623 | func (scope *Scope) autoIndex() *Scope { | |
624 | var indexes = map[string][]string{} | |
625 | var uniqueIndexes = map[string][]string{} | |
626 | ||
627 | for _, field := range scope.GetStructFields() { | |
628 | sqlSettings := parseTagSetting(field.Tag.Get("sql")) | |
629 | if name, ok := sqlSettings["INDEX"]; ok { | |
630 | if name == "INDEX" { | |
631 | name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) | |
632 | } | |
633 | indexes[name] = append(indexes[name], field.DBName) | |
634 | } | |
635 | ||
636 | if name, ok := sqlSettings["UNIQUE_INDEX"]; ok { | |
637 | if name == "UNIQUE_INDEX" { | |
638 | name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) | |
639 | } | |
640 | uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) | |
641 | } | |
642 | } | |
643 | ||
644 | for name, columns := range indexes { | |
645 | scope.addIndex(false, name, columns...) | |
646 | } | |
647 | ||
648 | for name, columns := range uniqueIndexes { | |
649 | scope.addIndex(true, name, columns...) | |
650 | } | |
651 | ||
652 | return scope | |
653 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "github.com/jinzhu/gorm" | |
4 | "testing" | |
5 | ) | |
6 | ||
7 | func NameIn1And2(d *gorm.DB) *gorm.DB { | |
8 | return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) | |
9 | } | |
10 | ||
11 | func NameIn2And3(d *gorm.DB) *gorm.DB { | |
12 | return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}) | |
13 | } | |
14 | ||
15 | func NameIn(names []string) func(d *gorm.DB) *gorm.DB { | |
16 | return func(d *gorm.DB) *gorm.DB { | |
17 | return d.Where("name in (?)", names) | |
18 | } | |
19 | } | |
20 | ||
21 | func TestScopes(t *testing.T) { | |
22 | user1 := User{Name: "ScopeUser1", Age: 1} | |
23 | user2 := User{Name: "ScopeUser2", Age: 1} | |
24 | user3 := User{Name: "ScopeUser3", Age: 2} | |
25 | DB.Save(&user1).Save(&user2).Save(&user3) | |
26 | ||
27 | var users1, users2, users3 []User | |
28 | DB.Scopes(NameIn1And2).Find(&users1) | |
29 | if len(users1) != 2 { | |
30 | t.Errorf("Should found two users's name in 1, 2") | |
31 | } | |
32 | ||
33 | DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2) | |
34 | if len(users2) != 1 { | |
35 | t.Errorf("Should found one user's name is 2") | |
36 | } | |
37 | ||
38 | DB.Scopes(NameIn([]string{user1.Name, user3.Name})).Find(&users3) | |
39 | if len(users3) != 2 { | |
40 | t.Errorf("Should found two users's name in 1, 3") | |
41 | } | |
42 | } |
0 | package gorm | |
1 | ||
2 | import "fmt" | |
3 | ||
4 | type search struct { | |
5 | db *DB | |
6 | whereConditions []map[string]interface{} | |
7 | orConditions []map[string]interface{} | |
8 | notConditions []map[string]interface{} | |
9 | havingConditions []map[string]interface{} | |
10 | initAttrs []interface{} | |
11 | assignAttrs []interface{} | |
12 | selects map[string]interface{} | |
13 | omits []string | |
14 | orders []string | |
15 | joins string | |
16 | preload []searchPreload | |
17 | offset string | |
18 | limit string | |
19 | group string | |
20 | tableName string | |
21 | raw bool | |
22 | Unscoped bool | |
23 | countingQuery bool | |
24 | } | |
25 | ||
26 | type searchPreload struct { | |
27 | schema string | |
28 | conditions []interface{} | |
29 | } | |
30 | ||
31 | func (s *search) clone() *search { | |
32 | clone := *s | |
33 | return &clone | |
34 | } | |
35 | ||
36 | func (s *search) Where(query interface{}, values ...interface{}) *search { | |
37 | s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values}) | |
38 | return s | |
39 | } | |
40 | ||
41 | func (s *search) Not(query interface{}, values ...interface{}) *search { | |
42 | s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values}) | |
43 | return s | |
44 | } | |
45 | ||
46 | func (s *search) Or(query interface{}, values ...interface{}) *search { | |
47 | s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values}) | |
48 | return s | |
49 | } | |
50 | ||
51 | func (s *search) Attrs(attrs ...interface{}) *search { | |
52 | s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...)) | |
53 | return s | |
54 | } | |
55 | ||
56 | func (s *search) Assign(attrs ...interface{}) *search { | |
57 | s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...)) | |
58 | return s | |
59 | } | |
60 | ||
61 | func (s *search) Order(value string, reorder ...bool) *search { | |
62 | if len(reorder) > 0 && reorder[0] { | |
63 | if value != "" { | |
64 | s.orders = []string{value} | |
65 | } else { | |
66 | s.orders = []string{} | |
67 | } | |
68 | } else if value != "" { | |
69 | s.orders = append(s.orders, value) | |
70 | } | |
71 | return s | |
72 | } | |
73 | ||
74 | func (s *search) Select(query interface{}, args ...interface{}) *search { | |
75 | s.selects = map[string]interface{}{"query": query, "args": args} | |
76 | return s | |
77 | } | |
78 | ||
79 | func (s *search) Omit(columns ...string) *search { | |
80 | s.omits = columns | |
81 | return s | |
82 | } | |
83 | ||
84 | func (s *search) Limit(value interface{}) *search { | |
85 | s.limit = s.getInterfaceAsSql(value) | |
86 | return s | |
87 | } | |
88 | ||
89 | func (s *search) Offset(value interface{}) *search { | |
90 | s.offset = s.getInterfaceAsSql(value) | |
91 | return s | |
92 | } | |
93 | ||
94 | func (s *search) Group(query string) *search { | |
95 | s.group = s.getInterfaceAsSql(query) | |
96 | return s | |
97 | } | |
98 | ||
99 | func (s *search) Having(query string, values ...interface{}) *search { | |
100 | s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) | |
101 | return s | |
102 | } | |
103 | ||
104 | func (s *search) Joins(query string) *search { | |
105 | s.joins = query | |
106 | return s | |
107 | } | |
108 | ||
109 | func (s *search) Preload(schema string, values ...interface{}) *search { | |
110 | var preloads []searchPreload | |
111 | for _, preload := range s.preload { | |
112 | if preload.schema != schema { | |
113 | preloads = append(preloads, preload) | |
114 | } | |
115 | } | |
116 | preloads = append(preloads, searchPreload{schema, values}) | |
117 | s.preload = preloads | |
118 | return s | |
119 | } | |
120 | ||
121 | func (s *search) Raw(b bool) *search { | |
122 | s.raw = b | |
123 | return s | |
124 | } | |
125 | ||
126 | func (s *search) unscoped() *search { | |
127 | s.Unscoped = true | |
128 | return s | |
129 | } | |
130 | ||
131 | func (s *search) Table(name string) *search { | |
132 | s.tableName = name | |
133 | return s | |
134 | } | |
135 | ||
136 | func (s *search) getInterfaceAsSql(value interface{}) (str string) { | |
137 | switch value.(type) { | |
138 | case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: | |
139 | str = fmt.Sprintf("%v", value) | |
140 | default: | |
141 | s.db.AddError(InvalidSql) | |
142 | } | |
143 | ||
144 | if str == "-1" { | |
145 | return "" | |
146 | } | |
147 | return | |
148 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "reflect" | |
4 | "testing" | |
5 | ) | |
6 | ||
7 | func TestCloneSearch(t *testing.T) { | |
8 | s := new(search) | |
9 | s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Select("name, age") | |
10 | ||
11 | s1 := s.clone() | |
12 | s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Select("email") | |
13 | ||
14 | if reflect.DeepEqual(s.whereConditions, s1.whereConditions) { | |
15 | t.Errorf("Where should be copied") | |
16 | } | |
17 | ||
18 | if reflect.DeepEqual(s.orders, s1.orders) { | |
19 | t.Errorf("Order should be copied") | |
20 | } | |
21 | ||
22 | if reflect.DeepEqual(s.initAttrs, s1.initAttrs) { | |
23 | t.Errorf("InitAttrs should be copied") | |
24 | } | |
25 | ||
26 | if reflect.DeepEqual(s.Select, s1.Select) { | |
27 | t.Errorf("selectStr should be copied") | |
28 | } | |
29 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "database/sql/driver" | |
4 | "encoding/json" | |
5 | "testing" | |
6 | ) | |
7 | ||
8 | func TestScannableSlices(t *testing.T) { | |
9 | if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil { | |
10 | t.Errorf("Should create table with slice values correctly: %s", err) | |
11 | } | |
12 | ||
13 | r1 := RecordWithSlice{ | |
14 | Strings: ExampleStringSlice{"a", "b", "c"}, | |
15 | Structs: ExampleStructSlice{ | |
16 | {"name1", "value1"}, | |
17 | {"name2", "value2"}, | |
18 | }, | |
19 | } | |
20 | ||
21 | if err := DB.Save(&r1).Error; err != nil { | |
22 | t.Errorf("Should save record with slice values") | |
23 | } | |
24 | ||
25 | var r2 RecordWithSlice | |
26 | ||
27 | if err := DB.Find(&r2).Error; err != nil { | |
28 | t.Errorf("Should fetch record with slice values") | |
29 | } | |
30 | ||
31 | if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" { | |
32 | t.Errorf("Should have serialised and deserialised a string array") | |
33 | } | |
34 | ||
35 | if len(r2.Structs) != 2 || r2.Structs[0].Name != "name1" || r2.Structs[0].Value != "value1" || r2.Structs[1].Name != "name2" || r2.Structs[1].Value != "value2" { | |
36 | t.Errorf("Should have serialised and deserialised a struct array") | |
37 | } | |
38 | } | |
39 | ||
40 | type RecordWithSlice struct { | |
41 | ID uint64 | |
42 | Strings ExampleStringSlice `sql:"type:text"` | |
43 | Structs ExampleStructSlice `sql:"type:text"` | |
44 | } | |
45 | ||
46 | type ExampleStringSlice []string | |
47 | ||
48 | func (l ExampleStringSlice) Value() (driver.Value, error) { | |
49 | return json.Marshal(l) | |
50 | } | |
51 | ||
52 | func (l *ExampleStringSlice) Scan(input interface{}) error { | |
53 | return json.Unmarshal(input.([]byte), l) | |
54 | } | |
55 | ||
56 | type ExampleStruct struct { | |
57 | Name string | |
58 | Value string | |
59 | } | |
60 | ||
61 | type ExampleStructSlice []ExampleStruct | |
62 | ||
63 | func (l ExampleStructSlice) Value() (driver.Value, error) { | |
64 | return json.Marshal(l) | |
65 | } | |
66 | ||
67 | func (l *ExampleStructSlice) Scan(input interface{}) error { | |
68 | return json.Unmarshal(input.([]byte), l) | |
69 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | "time" | |
6 | ) | |
7 | ||
8 | type sqlite3 struct { | |
9 | commonDialect | |
10 | } | |
11 | ||
12 | func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string { | |
13 | switch value.Kind() { | |
14 | case reflect.Bool: | |
15 | return "bool" | |
16 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | |
17 | if autoIncrease { | |
18 | return "integer primary key autoincrement" | |
19 | } | |
20 | return "integer" | |
21 | case reflect.Int64, reflect.Uint64: | |
22 | if autoIncrease { | |
23 | return "integer primary key autoincrement" | |
24 | } | |
25 | return "bigint" | |
26 | case reflect.Float32, reflect.Float64: | |
27 | return "real" | |
28 | case reflect.String: | |
29 | if size > 0 && size < 65532 { | |
30 | return fmt.Sprintf("varchar(%d)", size) | |
31 | } | |
32 | return "text" | |
33 | case reflect.Struct: | |
34 | if _, ok := value.Interface().(time.Time); ok { | |
35 | return "datetime" | |
36 | } | |
37 | default: | |
38 | if _, ok := value.Interface().([]byte); ok { | |
39 | return "blob" | |
40 | } | |
41 | } | |
42 | panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String())) | |
43 | } | |
44 | ||
45 | func (s sqlite3) HasTable(scope *Scope, tableName string) bool { | |
46 | var count int | |
47 | s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName) | |
48 | return count > 0 | |
49 | } | |
50 | ||
51 | func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { | |
52 | var count int | |
53 | s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName) | |
54 | return count > 0 | |
55 | } | |
56 | ||
57 | func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { | |
58 | var count int | |
59 | s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName) | |
60 | return count > 0 | |
61 | } | |
62 | ||
63 | func (sqlite3) RemoveIndex(scope *Scope, indexName string) { | |
64 | scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) | |
65 | } | |
66 | ||
67 | func (sqlite3) CurrentDatabase(scope *Scope) (name string) { | |
68 | var ( | |
69 | ifaces = make([]interface{}, 3) | |
70 | pointers = make([]*string, 3) | |
71 | i int | |
72 | ) | |
73 | for i = 0; i < 3; i++ { | |
74 | ifaces[i] = &pointers[i] | |
75 | } | |
76 | if err := scope.NewDB().Raw("PRAGMA database_list").Row().Scan(ifaces...); scope.Err(err) != nil { | |
77 | return | |
78 | } | |
79 | if pointers[1] != nil { | |
80 | name = *pointers[1] | |
81 | } | |
82 | return | |
83 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "database/sql" | |
4 | "database/sql/driver" | |
5 | "errors" | |
6 | "fmt" | |
7 | ||
8 | "reflect" | |
9 | "time" | |
10 | ) | |
11 | ||
12 | type User struct { | |
13 | Id int64 | |
14 | Age int64 | |
15 | UserNum Num | |
16 | Name string `sql:"size:255"` | |
17 | Birthday time.Time // Time | |
18 | CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically | |
19 | UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically | |
20 | Emails []Email // Embedded structs | |
21 | BillingAddress Address // Embedded struct | |
22 | BillingAddressID sql.NullInt64 // Embedded struct's foreign key | |
23 | ShippingAddress Address // Embedded struct | |
24 | ShippingAddressId int64 // Embedded struct's foreign key | |
25 | CreditCard CreditCard | |
26 | Latitude float64 | |
27 | Languages []Language `gorm:"many2many:user_languages;"` | |
28 | CompanyID int64 | |
29 | Company Company | |
30 | Role | |
31 | PasswordHash []byte | |
32 | IgnoreMe int64 `sql:"-"` | |
33 | IgnoreStringSlice []string `sql:"-"` | |
34 | Ignored struct{ Name string } `sql:"-"` | |
35 | IgnoredPointer *User `sql:"-"` | |
36 | } | |
37 | ||
38 | type CreditCard struct { | |
39 | ID int8 | |
40 | Number string | |
41 | UserId sql.NullInt64 | |
42 | CreatedAt time.Time | |
43 | UpdatedAt time.Time | |
44 | DeletedAt time.Time | |
45 | } | |
46 | ||
47 | type Email struct { | |
48 | Id int16 | |
49 | UserId int | |
50 | Email string `sql:"type:varchar(100);"` | |
51 | CreatedAt time.Time | |
52 | UpdatedAt time.Time | |
53 | } | |
54 | ||
55 | type Address struct { | |
56 | ID int | |
57 | Address1 string | |
58 | Address2 string | |
59 | Post string | |
60 | CreatedAt time.Time | |
61 | UpdatedAt time.Time | |
62 | DeletedAt time.Time | |
63 | } | |
64 | ||
65 | type Language struct { | |
66 | Id int | |
67 | Name string | |
68 | Users []User `gorm:"many2many:user_languages;"` | |
69 | } | |
70 | ||
71 | type Product struct { | |
72 | Id int64 | |
73 | Code string | |
74 | Price int64 | |
75 | CreatedAt time.Time | |
76 | UpdatedAt time.Time | |
77 | AfterFindCallTimes int64 | |
78 | BeforeCreateCallTimes int64 | |
79 | AfterCreateCallTimes int64 | |
80 | BeforeUpdateCallTimes int64 | |
81 | AfterUpdateCallTimes int64 | |
82 | BeforeSaveCallTimes int64 | |
83 | AfterSaveCallTimes int64 | |
84 | BeforeDeleteCallTimes int64 | |
85 | AfterDeleteCallTimes int64 | |
86 | } | |
87 | ||
88 | type Company struct { | |
89 | Id int64 | |
90 | Name string | |
91 | Owner *User `sql:"-"` | |
92 | } | |
93 | ||
94 | type Role struct { | |
95 | Name string | |
96 | } | |
97 | ||
98 | func (role *Role) Scan(value interface{}) error { | |
99 | if b, ok := value.([]uint8); ok { | |
100 | role.Name = string(b) | |
101 | } else { | |
102 | role.Name = value.(string) | |
103 | } | |
104 | return nil | |
105 | } | |
106 | ||
107 | func (role Role) Value() (driver.Value, error) { | |
108 | return role.Name, nil | |
109 | } | |
110 | ||
111 | func (role Role) IsAdmin() bool { | |
112 | return role.Name == "admin" | |
113 | } | |
114 | ||
115 | type Num int64 | |
116 | ||
117 | func (i *Num) Scan(src interface{}) error { | |
118 | switch s := src.(type) { | |
119 | case []byte: | |
120 | case int64: | |
121 | *i = Num(s) | |
122 | default: | |
123 | return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) | |
124 | } | |
125 | return nil | |
126 | } | |
127 | ||
128 | type Animal struct { | |
129 | Counter uint64 `gorm:"primary_key:yes"` | |
130 | Name string `sql:"DEFAULT:'galeone'"` | |
131 | From string //test reserved sql keyword as field name | |
132 | Age time.Time `sql:"DEFAULT:current_timestamp"` | |
133 | unexported string // unexported value | |
134 | CreatedAt time.Time | |
135 | UpdatedAt time.Time | |
136 | } | |
137 | ||
138 | type JoinTable struct { | |
139 | From uint64 | |
140 | To uint64 | |
141 | Time time.Time `sql:"default: null"` | |
142 | } | |
143 | ||
144 | type Post struct { | |
145 | Id int64 | |
146 | CategoryId sql.NullInt64 | |
147 | MainCategoryId int64 | |
148 | Title string | |
149 | Body string | |
150 | Comments []*Comment | |
151 | Category Category | |
152 | MainCategory Category | |
153 | } | |
154 | ||
155 | type Category struct { | |
156 | Id int64 | |
157 | Name string | |
158 | } | |
159 | ||
160 | type Comment struct { | |
161 | Id int64 | |
162 | PostId int64 | |
163 | Content string | |
164 | Post Post | |
165 | } | |
166 | ||
167 | // Scanner | |
168 | type NullValue struct { | |
169 | Id int64 | |
170 | Name sql.NullString `sql:"not null"` | |
171 | Age sql.NullInt64 | |
172 | Male sql.NullBool | |
173 | Height sql.NullFloat64 | |
174 | AddedAt NullTime | |
175 | } | |
176 | ||
177 | type NullTime struct { | |
178 | Time time.Time | |
179 | Valid bool | |
180 | } | |
181 | ||
182 | func (nt *NullTime) Scan(value interface{}) error { | |
183 | if value == nil { | |
184 | nt.Valid = false | |
185 | return nil | |
186 | } | |
187 | nt.Time, nt.Valid = value.(time.Time), true | |
188 | return nil | |
189 | } | |
190 | ||
191 | func (nt NullTime) Value() (driver.Value, error) { | |
192 | if !nt.Valid { | |
193 | return nil, nil | |
194 | } | |
195 | return nt.Time, nil | |
196 | } | |
197 | ||
198 | func getPreparedUser(name string, role string) *User { | |
199 | var company Company | |
200 | DB.Where(Company{Name: role}).FirstOrCreate(&company) | |
201 | ||
202 | return &User{ | |
203 | Name: name, | |
204 | Age: 20, | |
205 | Role: Role{role}, | |
206 | BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, | |
207 | ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, | |
208 | CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, | |
209 | Emails: []Email{ | |
210 | {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, | |
211 | }, | |
212 | Company: company, | |
213 | Languages: []Language{ | |
214 | {Name: fmt.Sprintf("lang_1_%v", name)}, | |
215 | {Name: fmt.Sprintf("lang_2_%v", name)}, | |
216 | }, | |
217 | } | |
218 | } |
0 | dialects=("postgres" "mysql" "sqlite") | |
1 | ||
2 | for dialect in "${dialects[@]}" ; do | |
3 | GORM_DIALECT=${dialect} go test | |
4 | done |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "testing" | |
4 | "time" | |
5 | ||
6 | "github.com/jinzhu/gorm" | |
7 | ) | |
8 | ||
9 | func TestUpdate(t *testing.T) { | |
10 | product1 := Product{Code: "product1code"} | |
11 | product2 := Product{Code: "product2code"} | |
12 | ||
13 | DB.Save(&product1).Save(&product2).Update("code", "product2newcode") | |
14 | ||
15 | if product2.Code != "product2newcode" { | |
16 | t.Errorf("Record should be updated") | |
17 | } | |
18 | ||
19 | DB.First(&product1, product1.Id) | |
20 | DB.First(&product2, product2.Id) | |
21 | updatedAt1 := product1.UpdatedAt | |
22 | updatedAt2 := product2.UpdatedAt | |
23 | ||
24 | var product3 Product | |
25 | DB.First(&product3, product2.Id).Update("code", "product2newcode") | |
26 | if updatedAt2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) { | |
27 | t.Errorf("updatedAt should not be updated if nothing changed") | |
28 | } | |
29 | ||
30 | if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() { | |
31 | t.Errorf("Product1 should not be updated") | |
32 | } | |
33 | ||
34 | if !DB.First(&Product{}, "code = ?", "product2code").RecordNotFound() { | |
35 | t.Errorf("Product2's code should be updated") | |
36 | } | |
37 | ||
38 | if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { | |
39 | t.Errorf("Product2's code should be updated") | |
40 | } | |
41 | ||
42 | DB.Table("products").Where("code in (?)", []string{"product1code"}).Update("code", "product1newcode") | |
43 | ||
44 | var product4 Product | |
45 | DB.First(&product4, product1.Id) | |
46 | if updatedAt1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { | |
47 | t.Errorf("updatedAt should be updated if something changed") | |
48 | } | |
49 | ||
50 | if !DB.First(&Product{}, "code = 'product1code'").RecordNotFound() { | |
51 | t.Errorf("Product1's code should be updated") | |
52 | } | |
53 | ||
54 | if DB.First(&Product{}, "code = 'product1newcode'").RecordNotFound() { | |
55 | t.Errorf("Product should not be changed to 789") | |
56 | } | |
57 | ||
58 | if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil { | |
59 | t.Error("No error should raise when update with CamelCase") | |
60 | } | |
61 | ||
62 | if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { | |
63 | t.Error("No error should raise when update_column with CamelCase") | |
64 | } | |
65 | ||
66 | var products []Product | |
67 | DB.Find(&products) | |
68 | if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) { | |
69 | t.Error("RowsAffected should be correct when do batch update") | |
70 | } | |
71 | ||
72 | DB.First(&product4, product4.Id) | |
73 | DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50)) | |
74 | var product5 Product | |
75 | DB.First(&product5, product4.Id) | |
76 | if product5.Price != product4.Price+100-50 { | |
77 | t.Errorf("Update with expression") | |
78 | } | |
79 | if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) { | |
80 | t.Errorf("Update with expression should update UpdatedAt") | |
81 | } | |
82 | } | |
83 | ||
84 | func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { | |
85 | animal := Animal{Name: "Ferdinand"} | |
86 | DB.Save(&animal) | |
87 | updatedAt1 := animal.UpdatedAt | |
88 | ||
89 | DB.Save(&animal).Update("name", "Francis") | |
90 | ||
91 | if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) { | |
92 | t.Errorf("updatedAt should not be updated if nothing changed") | |
93 | } | |
94 | ||
95 | var animals []Animal | |
96 | DB.Find(&animals) | |
97 | if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { | |
98 | t.Error("RowsAffected should be correct when do batch update") | |
99 | } | |
100 | ||
101 | animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone) | |
102 | DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched | |
103 | DB.First(&animal, animal.Counter) | |
104 | if animal.Name != "galeone" { | |
105 | t.Errorf("Name fiels shouldn't be changed if untouched, but got %v", animal.Name) | |
106 | } | |
107 | ||
108 | // When changing a field with a default value, the change must occur | |
109 | animal.Name = "amazing horse" | |
110 | DB.Save(&animal) | |
111 | DB.First(&animal, animal.Counter) | |
112 | if animal.Name != "amazing horse" { | |
113 | t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name) | |
114 | } | |
115 | ||
116 | // When changing a field with a default value with blank value | |
117 | animal.Name = "" | |
118 | DB.Save(&animal) | |
119 | DB.First(&animal, animal.Counter) | |
120 | if animal.Name != "" { | |
121 | t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name) | |
122 | } | |
123 | } | |
124 | ||
125 | func TestUpdates(t *testing.T) { | |
126 | product1 := Product{Code: "product1code", Price: 10} | |
127 | product2 := Product{Code: "product2code", Price: 10} | |
128 | DB.Save(&product1).Save(&product2) | |
129 | DB.Model(&product1).Updates(map[string]interface{}{"code": "product1newcode", "price": 100}) | |
130 | if product1.Code != "product1newcode" || product1.Price != 100 { | |
131 | t.Errorf("Record should be updated also with map") | |
132 | } | |
133 | ||
134 | DB.First(&product1, product1.Id) | |
135 | DB.First(&product2, product2.Id) | |
136 | updatedAt1 := product1.UpdatedAt | |
137 | updatedAt2 := product2.UpdatedAt | |
138 | ||
139 | var product3 Product | |
140 | DB.First(&product3, product1.Id).Updates(Product{Code: "product1newcode", Price: 100}) | |
141 | if product3.Code != "product1newcode" || product3.Price != 100 { | |
142 | t.Errorf("Record should be updated with struct") | |
143 | } | |
144 | ||
145 | if updatedAt1.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) { | |
146 | t.Errorf("updatedAt should not be updated if nothing changed") | |
147 | } | |
148 | ||
149 | if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() { | |
150 | t.Errorf("Product2 should not be updated") | |
151 | } | |
152 | ||
153 | if DB.First(&Product{}, "code = ?", "product1newcode").RecordNotFound() { | |
154 | t.Errorf("Product1 should be updated") | |
155 | } | |
156 | ||
157 | DB.Table("products").Where("code in (?)", []string{"product2code"}).Updates(Product{Code: "product2newcode"}) | |
158 | if !DB.First(&Product{}, "code = 'product2code'").RecordNotFound() { | |
159 | t.Errorf("Product2's code should be updated") | |
160 | } | |
161 | ||
162 | var product4 Product | |
163 | DB.First(&product4, product2.Id) | |
164 | if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { | |
165 | t.Errorf("updatedAt should be updated if something changed") | |
166 | } | |
167 | ||
168 | if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { | |
169 | t.Errorf("product2's code should be updated") | |
170 | } | |
171 | ||
172 | DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)}) | |
173 | var product5 Product | |
174 | DB.First(&product5, product4.Id) | |
175 | if product5.Price != product4.Price+100 { | |
176 | t.Errorf("Updates with expression") | |
177 | } | |
178 | if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) { | |
179 | t.Errorf("Updates with expression should update UpdatedAt") | |
180 | } | |
181 | } | |
182 | ||
183 | func TestUpdateColumn(t *testing.T) { | |
184 | product1 := Product{Code: "product1code", Price: 10} | |
185 | product2 := Product{Code: "product2code", Price: 20} | |
186 | DB.Save(&product1).Save(&product2).UpdateColumn(map[string]interface{}{"code": "product2newcode", "price": 100}) | |
187 | if product2.Code != "product2newcode" || product2.Price != 100 { | |
188 | t.Errorf("product 2 should be updated with update column") | |
189 | } | |
190 | ||
191 | var product3 Product | |
192 | DB.First(&product3, product1.Id) | |
193 | if product3.Code != "product1code" || product3.Price != 10 { | |
194 | t.Errorf("product 1 should not be updated") | |
195 | } | |
196 | ||
197 | DB.First(&product2, product2.Id) | |
198 | updatedAt2 := product2.UpdatedAt | |
199 | DB.Model(product2).UpdateColumn("code", "update_column_new") | |
200 | var product4 Product | |
201 | DB.First(&product4, product2.Id) | |
202 | if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { | |
203 | t.Errorf("updatedAt should not be updated with update column") | |
204 | } | |
205 | ||
206 | DB.Model(&product4).UpdateColumn("price", gorm.Expr("price + 100 - 50")) | |
207 | var product5 Product | |
208 | DB.First(&product5, product4.Id) | |
209 | if product5.Price != product4.Price+100-50 { | |
210 | t.Errorf("UpdateColumn with expression") | |
211 | } | |
212 | if product5.UpdatedAt.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { | |
213 | t.Errorf("UpdateColumn with expression should not update UpdatedAt") | |
214 | } | |
215 | } | |
216 | ||
217 | func TestSelectWithUpdate(t *testing.T) { | |
218 | user := getPreparedUser("select_user", "select_with_update") | |
219 | DB.Create(user) | |
220 | ||
221 | var reloadUser User | |
222 | DB.First(&reloadUser, user.Id) | |
223 | reloadUser.Name = "new_name" | |
224 | reloadUser.Age = 50 | |
225 | reloadUser.BillingAddress = Address{Address1: "New Billing Address"} | |
226 | reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"} | |
227 | reloadUser.CreditCard = CreditCard{Number: "987654321"} | |
228 | reloadUser.Emails = []Email{ | |
229 | {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, | |
230 | } | |
231 | reloadUser.Company = Company{Name: "new company"} | |
232 | ||
233 | DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) | |
234 | ||
235 | var queryUser User | |
236 | DB.Preload("BillingAddress").Preload("ShippingAddress"). | |
237 | Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) | |
238 | ||
239 | if queryUser.Name == user.Name || queryUser.Age != user.Age { | |
240 | t.Errorf("Should only update users with name column") | |
241 | } | |
242 | ||
243 | if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || | |
244 | queryUser.ShippingAddressId != user.ShippingAddressId || | |
245 | queryUser.CreditCard.ID == user.CreditCard.ID || | |
246 | len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { | |
247 | t.Errorf("Should only update selected relationships") | |
248 | } | |
249 | } | |
250 | ||
251 | func TestSelectWithUpdateWithMap(t *testing.T) { | |
252 | user := getPreparedUser("select_user", "select_with_update_map") | |
253 | DB.Create(user) | |
254 | ||
255 | updateValues := map[string]interface{}{ | |
256 | "Name": "new_name", | |
257 | "Age": 50, | |
258 | "BillingAddress": Address{Address1: "New Billing Address"}, | |
259 | "ShippingAddress": Address{Address1: "New ShippingAddress Address"}, | |
260 | "CreditCard": CreditCard{Number: "987654321"}, | |
261 | "Emails": []Email{ | |
262 | {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, | |
263 | }, | |
264 | "Company": Company{Name: "new company"}, | |
265 | } | |
266 | ||
267 | var reloadUser User | |
268 | DB.First(&reloadUser, user.Id) | |
269 | DB.Model(&reloadUser).Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) | |
270 | ||
271 | var queryUser User | |
272 | DB.Preload("BillingAddress").Preload("ShippingAddress"). | |
273 | Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) | |
274 | ||
275 | if queryUser.Name == user.Name || queryUser.Age != user.Age { | |
276 | t.Errorf("Should only update users with name column") | |
277 | } | |
278 | ||
279 | if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || | |
280 | queryUser.ShippingAddressId != user.ShippingAddressId || | |
281 | queryUser.CreditCard.ID == user.CreditCard.ID || | |
282 | len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { | |
283 | t.Errorf("Should only update selected relationships") | |
284 | } | |
285 | } | |
286 | ||
287 | func TestOmitWithUpdate(t *testing.T) { | |
288 | user := getPreparedUser("omit_user", "omit_with_update") | |
289 | DB.Create(user) | |
290 | ||
291 | var reloadUser User | |
292 | DB.First(&reloadUser, user.Id) | |
293 | reloadUser.Name = "new_name" | |
294 | reloadUser.Age = 50 | |
295 | reloadUser.BillingAddress = Address{Address1: "New Billing Address"} | |
296 | reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"} | |
297 | reloadUser.CreditCard = CreditCard{Number: "987654321"} | |
298 | reloadUser.Emails = []Email{ | |
299 | {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, | |
300 | } | |
301 | reloadUser.Company = Company{Name: "new company"} | |
302 | ||
303 | DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) | |
304 | ||
305 | var queryUser User | |
306 | DB.Preload("BillingAddress").Preload("ShippingAddress"). | |
307 | Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) | |
308 | ||
309 | if queryUser.Name != user.Name || queryUser.Age == user.Age { | |
310 | t.Errorf("Should only update users with name column") | |
311 | } | |
312 | ||
313 | if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || | |
314 | queryUser.ShippingAddressId == user.ShippingAddressId || | |
315 | queryUser.CreditCard.ID != user.CreditCard.ID || | |
316 | len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { | |
317 | t.Errorf("Should only update relationships that not omited") | |
318 | } | |
319 | } | |
320 | ||
321 | func TestOmitWithUpdateWithMap(t *testing.T) { | |
322 | user := getPreparedUser("select_user", "select_with_update_map") | |
323 | DB.Create(user) | |
324 | ||
325 | updateValues := map[string]interface{}{ | |
326 | "Name": "new_name", | |
327 | "Age": 50, | |
328 | "BillingAddress": Address{Address1: "New Billing Address"}, | |
329 | "ShippingAddress": Address{Address1: "New ShippingAddress Address"}, | |
330 | "CreditCard": CreditCard{Number: "987654321"}, | |
331 | "Emails": []Email{ | |
332 | {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, | |
333 | }, | |
334 | "Company": Company{Name: "new company"}, | |
335 | } | |
336 | ||
337 | var reloadUser User | |
338 | DB.First(&reloadUser, user.Id) | |
339 | DB.Model(&reloadUser).Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) | |
340 | ||
341 | var queryUser User | |
342 | DB.Preload("BillingAddress").Preload("ShippingAddress"). | |
343 | Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) | |
344 | ||
345 | if queryUser.Name != user.Name || queryUser.Age == user.Age { | |
346 | t.Errorf("Should only update users with name column") | |
347 | } | |
348 | ||
349 | if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || | |
350 | queryUser.ShippingAddressId == user.ShippingAddressId || | |
351 | queryUser.CreditCard.ID != user.CreditCard.ID || | |
352 | len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { | |
353 | t.Errorf("Should only update relationships not omited") | |
354 | } | |
355 | } | |
356 | ||
357 | func TestSelectWithUpdateColumn(t *testing.T) { | |
358 | user := getPreparedUser("select_user", "select_with_update_map") | |
359 | DB.Create(user) | |
360 | ||
361 | updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} | |
362 | ||
363 | var reloadUser User | |
364 | DB.First(&reloadUser, user.Id) | |
365 | DB.Model(&reloadUser).Select("Name").UpdateColumn(updateValues) | |
366 | ||
367 | var queryUser User | |
368 | DB.First(&queryUser, user.Id) | |
369 | ||
370 | if queryUser.Name == user.Name || queryUser.Age != user.Age { | |
371 | t.Errorf("Should only update users with name column") | |
372 | } | |
373 | } | |
374 | ||
375 | func TestOmitWithUpdateColumn(t *testing.T) { | |
376 | user := getPreparedUser("select_user", "select_with_update_map") | |
377 | DB.Create(user) | |
378 | ||
379 | updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} | |
380 | ||
381 | var reloadUser User | |
382 | DB.First(&reloadUser, user.Id) | |
383 | DB.Model(&reloadUser).Omit("Name").UpdateColumn(updateValues) | |
384 | ||
385 | var queryUser User | |
386 | DB.First(&queryUser, user.Id) | |
387 | ||
388 | if queryUser.Name != user.Name || queryUser.Age == user.Age { | |
389 | t.Errorf("Should omit name column when update user") | |
390 | } | |
391 | } | |
392 | ||
393 | func TestUpdateColumnsSkipsAssociations(t *testing.T) { | |
394 | user := getPreparedUser("update_columns_user", "special_role") | |
395 | user.Age = 99 | |
396 | address1 := "first street" | |
397 | user.BillingAddress = Address{Address1: address1} | |
398 | DB.Save(user) | |
399 | ||
400 | // Update a single field of the user and verify that the changed address is not stored. | |
401 | newAge := int64(100) | |
402 | user.BillingAddress.Address1 = "second street" | |
403 | db := DB.Model(user).UpdateColumns(User{Age: newAge}) | |
404 | if db.RowsAffected != 1 { | |
405 | t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected) | |
406 | } | |
407 | ||
408 | // Verify that Age now=`newAge`. | |
409 | freshUser := &User{Id: user.Id} | |
410 | DB.First(freshUser) | |
411 | if freshUser.Age != newAge { | |
412 | t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, freshUser.Age) | |
413 | } | |
414 | ||
415 | // Verify that user's BillingAddress.Address1 is not changed and is still "first street". | |
416 | DB.First(&freshUser.BillingAddress, freshUser.BillingAddressID) | |
417 | if freshUser.BillingAddress.Address1 != address1 { | |
418 | t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1) | |
419 | } | |
420 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "bytes" | |
4 | "strings" | |
5 | "sync" | |
6 | ) | |
7 | ||
8 | // Copied from golint | |
9 | var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} | |
10 | var commonInitialismsReplacer *strings.Replacer | |
11 | ||
12 | func init() { | |
13 | var commonInitialismsForReplacer []string | |
14 | for _, initialism := range commonInitialisms { | |
15 | commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) | |
16 | } | |
17 | commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) | |
18 | } | |
19 | ||
20 | type safeMap struct { | |
21 | m map[string]string | |
22 | l *sync.RWMutex | |
23 | } | |
24 | ||
25 | func (s *safeMap) Set(key string, value string) { | |
26 | s.l.Lock() | |
27 | defer s.l.Unlock() | |
28 | s.m[key] = value | |
29 | } | |
30 | ||
31 | func (s *safeMap) Get(key string) string { | |
32 | s.l.RLock() | |
33 | defer s.l.RUnlock() | |
34 | return s.m[key] | |
35 | } | |
36 | ||
37 | func newSafeMap() *safeMap { | |
38 | return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} | |
39 | } | |
40 | ||
41 | var smap = newSafeMap() | |
42 | ||
43 | func ToDBName(name string) string { | |
44 | if v := smap.Get(name); v != "" { | |
45 | return v | |
46 | } | |
47 | ||
48 | value := commonInitialismsReplacer.Replace(name) | |
49 | buf := bytes.NewBufferString("") | |
50 | for i, v := range value { | |
51 | if i > 0 && v >= 'A' && v <= 'Z' { | |
52 | buf.WriteRune('_') | |
53 | } | |
54 | buf.WriteRune(v) | |
55 | } | |
56 | ||
57 | s := strings.ToLower(buf.String()) | |
58 | smap.Set(name, s) | |
59 | return s | |
60 | } | |
61 | ||
62 | type expr struct { | |
63 | expr string | |
64 | args []interface{} | |
65 | } | |
66 | ||
67 | func Expr(expression string, args ...interface{}) *expr { | |
68 | return &expr{expr: expression, args: args} | |
69 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | "regexp" | |
6 | "runtime" | |
7 | "strings" | |
8 | ) | |
9 | ||
10 | func fileWithLineNum() string { | |
11 | for i := 2; i < 15; i++ { | |
12 | _, file, line, ok := runtime.Caller(i) | |
13 | if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { | |
14 | return fmt.Sprintf("%v:%v", file, line) | |
15 | } | |
16 | } | |
17 | return "" | |
18 | } | |
19 | ||
20 | func isBlank(value reflect.Value) bool { | |
21 | return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) | |
22 | } | |
23 | ||
24 | func toSearchableMap(attrs ...interface{}) (result interface{}) { | |
25 | if len(attrs) > 1 { | |
26 | if str, ok := attrs[0].(string); ok { | |
27 | result = map[string]interface{}{str: attrs[1]} | |
28 | } | |
29 | } else if len(attrs) == 1 { | |
30 | if attr, ok := attrs[0].(map[string]interface{}); ok { | |
31 | result = attr | |
32 | } | |
33 | ||
34 | if attr, ok := attrs[0].(interface{}); ok { | |
35 | result = attr | |
36 | } | |
37 | } | |
38 | return | |
39 | } | |
40 | ||
41 | func convertInterfaceToMap(values interface{}) map[string]interface{} { | |
42 | attrs := map[string]interface{}{} | |
43 | ||
44 | switch value := values.(type) { | |
45 | case map[string]interface{}: | |
46 | for k, v := range value { | |
47 | attrs[ToDBName(k)] = v | |
48 | } | |
49 | case []interface{}: | |
50 | for _, v := range value { | |
51 | for key, value := range convertInterfaceToMap(v) { | |
52 | attrs[key] = value | |
53 | } | |
54 | } | |
55 | case interface{}: | |
56 | reflectValue := reflect.ValueOf(values) | |
57 | ||
58 | switch reflectValue.Kind() { | |
59 | case reflect.Map: | |
60 | for _, key := range reflectValue.MapKeys() { | |
61 | attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() | |
62 | } | |
63 | default: | |
64 | scope := Scope{Value: values} | |
65 | for _, field := range scope.Fields() { | |
66 | if !field.IsBlank && !field.IsIgnored { | |
67 | attrs[field.DBName] = field.Field.Interface() | |
68 | } | |
69 | } | |
70 | } | |
71 | } | |
72 | return attrs | |
73 | } | |
74 | ||
75 | func toString(str interface{}) string { | |
76 | if values, ok := str.([]interface{}); ok { | |
77 | var results []string | |
78 | for _, value := range values { | |
79 | results = append(results, toString(value)) | |
80 | } | |
81 | return strings.Join(results, "_") | |
82 | } else if bytes, ok := str.([]byte); ok { | |
83 | return string(bytes) | |
84 | } else { | |
85 | return fmt.Sprintf("%v", str) | |
86 | } | |
87 | } | |
88 | ||
89 | func strInSlice(a string, list []string) bool { | |
90 | for _, b := range list { | |
91 | if b == a { | |
92 | return true | |
93 | } | |
94 | } | |
95 | return false | |
96 | } |