New upstream version 1.0+git20180218.58e3472
Michael Stapelberg
6 years ago
0 | Your issue may already be reported! Please search on the [issue track](https://github.com/jinzhu/gorm/issues) before creating one. | |
1 | ||
2 | ### What version of Go are you using (`go version`)? | |
3 | ||
4 | ||
5 | ### Which database and its version are you using? | |
6 | ||
7 | ||
8 | ### Please provide a complete runnable program to reproduce your issue. **IMPORTANT** | |
9 | ||
10 | Need to runnable with [GORM's docker compose config](https://github.com/jinzhu/gorm/blob/master/docker-compose.yml) or please provides your config. | |
11 | ||
12 | ```go | |
13 | package main | |
14 | ||
15 | import ( | |
16 | "github.com/jinzhu/gorm" | |
17 | _ "github.com/jinzhu/gorm/dialects/mssql" | |
18 | _ "github.com/jinzhu/gorm/dialects/mysql" | |
19 | _ "github.com/jinzhu/gorm/dialects/postgres" | |
20 | _ "github.com/jinzhu/gorm/dialects/sqlite" | |
21 | ) | |
22 | ||
23 | var db *gorm.DB | |
24 | ||
25 | func init() { | |
26 | var err error | |
27 | db, err = gorm.Open("sqlite3", "test.db") | |
28 | // db, err = gorm.Open("postgres", "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable") | |
29 | // db, err = gorm.Open("mysql", "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True") | |
30 | // db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm") | |
31 | if err != nil { | |
32 | panic(err) | |
33 | } | |
34 | db.LogMode(true) | |
35 | } | |
36 | ||
37 | func main() { | |
38 | if /* failure condition */ { | |
39 | fmt.Println("failed") | |
40 | } else { | |
41 | fmt.Println("success") | |
42 | } | |
43 | } | |
44 | ``` |
0 | Make sure these boxes checked before submitting your pull request. | |
1 | ||
2 | - [] Do only one thing | |
3 | - [] No API-breaking changes | |
4 | - [] New code/logic commented & tested | |
5 | ||
6 | For significant changes like big bug fixes, new features, please open an issue to make an agreement on an implementation design/plan first before starting it. | |
7 | ||
8 | ### What did this pull request do? |
0 | 0 | # GORM |
1 | ||
2 | [![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) | |
3 | 1 | |
4 | 2 | The fantastic ORM library for Golang, aims to be developer friendly. |
5 | 3 | |
6 | [![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) | |
4 | [![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm) | |
5 | [![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b) | |
6 | [![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) | |
7 | [![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) | |
8 | [![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm) | |
9 | [![MIT license](http://img.shields.io/badge/license-MIT-brightgreen.svg)](http://opensource.org/licenses/MIT) | |
10 | [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) | |
7 | 11 | |
8 | 12 | ## Overview |
9 | 13 | |
10 | 14 | * Full-Featured ORM (almost) |
11 | * Chainable API | |
12 | * Auto Migrations | |
13 | * Relations (Has One, Has Many, Belongs To, Many To Many, [Polymorphism](#polymorphism)) | |
14 | * Callbacks (Before/After Create/Save/Update/Delete/Find) | |
15 | * Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism) | |
16 | * Hooks (Before/After Create/Save/Update/Delete/Find) | |
15 | 17 | * Preloading (eager loading) |
16 | 18 | * Transactions |
17 | * Embed Anonymous Struct | |
18 | * Soft Deletes | |
19 | * Customizable Logger | |
20 | * Iteration Support via [Rows](#row--rows) | |
19 | * Composite Primary Key | |
20 | * SQL Builder | |
21 | * Auto Migrations | |
22 | * Logger | |
23 | * Extendable, write Plugins based on GORM callbacks | |
21 | 24 | * Every feature comes with tests |
22 | 25 | * Developer Friendly |
23 | 26 | |
24 | # Getting Started | |
27 | ## Getting Started | |
25 | 28 | |
26 | ## Install | |
29 | * GORM Guides [http://gorm.io](http://gorm.io) | |
27 | 30 | |
28 | ``` | |
29 | go get -u github.com/jinzhu/gorm | |
30 | ``` | |
31 | ## Contributing | |
31 | 32 | |
32 | ## Documentation | |
33 | ||
34 | [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) | |
35 | ||
36 | `go doc` format documentation for this project can be viewed online without | |
37 | installing the package by using the GoDoc page at: | |
38 | http://godoc.org/github.com/jinzhu/gorm | |
39 | ||
40 | ## Table of Contents | |
41 | ||
42 | - [Define Models (Structs)](#define-models-structs) | |
43 | - [Conventions](#conventions) | |
44 | - [Initialize Database](#initialize-database) | |
45 | - [Migration](#migration) | |
46 | - [Basic CRUD](#basic-crud) | |
47 | - [Create](#create-record) | |
48 | - [Query](#query) | |
49 | - [Query With Where (Plain SQL)](#query-with-where-plain-sql) | |
50 | - [Query With Where (Struct & Map)](#query-with-where-struct--map) | |
51 | - [Query With Not](#query-with-not) | |
52 | - [Query With Inline Condition](#query-with-inline-condition) | |
53 | - [Query With Or](#query-with-or) | |
54 | - [Query Chains](#query-chains) | |
55 | - [Preloading (Eager loading)](#preloading-eager-loading) | |
56 | - [Update](#update) | |
57 | - [Update Without Callbacks](#update-without-callbacks) | |
58 | - [Batch Updates](#batch-updates) | |
59 | - [Update with SQL Expression](#update-with-sql-expression) | |
60 | - [Delete](#delete) | |
61 | - [Batch Delete](#batch-delete) | |
62 | - [Soft Delete](#soft-delete) | |
63 | - [Associations](#associations) | |
64 | - [Has One](#has-one) | |
65 | - [Belongs To](#belongs-to) | |
66 | - [Has Many](#has-many) | |
67 | - [Many To Many](#many-to-many) | |
68 | - [Polymorphism](#polymorphism) | |
69 | - [Advanced Usage](#advanced-usage) | |
70 | - [FirstOrInit](#firstorinit) | |
71 | - [FirstOrCreate](#firstorcreate) | |
72 | - [Select](#select) | |
73 | - [Order](#order) | |
74 | - [Limit](#limit) | |
75 | - [Offset](#offset) | |
76 | - [Count](#count) | |
77 | - [Pluck](#pluck) | |
78 | - [Raw SQL](#raw-sql) | |
79 | - [Row & Rows](#row--rows) | |
80 | - [Scan](#scan) | |
81 | - [Group & Having](#group--having) | |
82 | - [Joins](#joins) | |
83 | - [Transactions](#transactions) | |
84 | - [Scopes](#scopes) | |
85 | - [Callbacks](#callbacks) | |
86 | - [Specifying The Table Name](#specifying-the-table-name) | |
87 | - [Error Handling](#error-handling) | |
88 | - [Logger](#logger) | |
89 | - [Existing Schema](#existing-schema) | |
90 | - [Composite Primary Key](#composite-primary-key) | |
91 | - [Database Indexes & Foreign Key](#database-indexes--foreign-key) | |
92 | - [Default values](#default-values) | |
93 | - [More examples with query chain](#more-examples-with-query-chain) | |
94 | ||
95 | ## Define Models (Structs) | |
96 | ||
97 | ```go | |
98 | type User struct { | |
99 | ID int | |
100 | Birthday time.Time | |
101 | Age int | |
102 | Name string `sql:"size:255"` // Default size for string is 255, you could reset it with this tag | |
103 | Num int `sql:"AUTO_INCREMENT"` | |
104 | CreatedAt time.Time | |
105 | UpdatedAt time.Time | |
106 | DeletedAt *time.Time | |
107 | ||
108 | Emails []Email // One-To-Many relationship (has many) | |
109 | BillingAddress Address // One-To-One relationship (has one) | |
110 | BillingAddressID sql.NullInt64 // Foreign key of BillingAddress | |
111 | ShippingAddress Address // One-To-One relationship (has one) | |
112 | ShippingAddressID int // Foreign key of ShippingAddress | |
113 | IgnoreMe int `sql:"-"` // Ignore this field | |
114 | Languages []Language `gorm:"many2many:user_languages;"` // Many-To-Many relationship, 'user_languages' is join table | |
115 | } | |
116 | ||
117 | type Email struct { | |
118 | ID int | |
119 | UserID int `sql:"index"` // Foreign key (belongs to), tag `index` will create index for this field when using AutoMigrate | |
120 | Email string `sql:"type:varchar(100);unique_index"` // Set field's sql type, tag `unique_index` will create unique index | |
121 | Subscribed bool | |
122 | } | |
123 | ||
124 | type Address struct { | |
125 | ID int | |
126 | Address1 string `sql:"not null;unique"` // Set field as not nullable and unique | |
127 | Address2 string `sql:"type:varchar(100);unique"` | |
128 | Post sql.NullString `sql:"not null"` | |
129 | } | |
130 | ||
131 | type Language struct { | |
132 | ID int | |
133 | Name string `sql:"index:idx_name_code"` // Create index with name, and will create combined index if find other fields defined same name | |
134 | Code string `sql:"index:idx_name_code"` // `unique_index` also works | |
135 | } | |
136 | ``` | |
137 | ||
138 | ## Conventions | |
139 | ||
140 | * Table name is the plural of struct name's snake case, you can disable pluralization with `db.SingularTable(true)`, or [Specifying The Table Name For A Struct Permanently With TableName](#specifying-the-table-name-for-a-struct-permanently-with-tablename) | |
141 | ||
142 | ```go | |
143 | type User struct{} // struct User's database table name is "users" by default, will be "user" if you disabled pluralisation | |
144 | ``` | |
145 | ||
146 | * Column name is the snake case of field's name | |
147 | * Use `ID` field as primary key | |
148 | * Use `CreatedAt` to store record's created time if field exists | |
149 | * Use `UpdatedAt` to store record's updated time if field exists | |
150 | * Use `DeletedAt` to store record's deleted time if field exists [Soft Delete](#soft-delete) | |
151 | * Gorm provide a default model struct, you could embed it in your struct | |
152 | ||
153 | ```go | |
154 | type Model struct { | |
155 | ID uint `gorm:"primary_key"` | |
156 | CreatedAt time.Time | |
157 | UpdatedAt time.Time | |
158 | DeletedAt *time.Time | |
159 | } | |
160 | ||
161 | type User struct { | |
162 | gorm.Model | |
163 | Name string | |
164 | } | |
165 | ``` | |
166 | ||
167 | ## Initialize Database | |
168 | ||
169 | ```go | |
170 | import ( | |
171 | "github.com/jinzhu/gorm" | |
172 | _ "github.com/lib/pq" | |
173 | _ "github.com/go-sql-driver/mysql" | |
174 | _ "github.com/mattn/go-sqlite3" | |
175 | ) | |
176 | ||
177 | db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") | |
178 | // db, err := gorm.Open("foundation", "dbname=gorm") // FoundationDB. | |
179 | // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") | |
180 | // db, err := gorm.Open("sqlite3", "/tmp/gorm.db") | |
181 | ||
182 | // You can also use an existing database connection handle | |
183 | // dbSql, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable") | |
184 | // db, _ := gorm.Open("postgres", dbSql) | |
185 | ||
186 | // Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB) | |
187 | db.DB() | |
188 | ||
189 | // Then you could invoke `*sql.DB`'s functions with it | |
190 | db.DB().Ping() | |
191 | db.DB().SetMaxIdleConns(10) | |
192 | db.DB().SetMaxOpenConns(100) | |
193 | ||
194 | // Disable table name's pluralization | |
195 | db.SingularTable(true) | |
196 | ``` | |
197 | ||
198 | ## Migration | |
199 | ||
200 | ```go | |
201 | // Create table | |
202 | db.CreateTable(&User{}) | |
203 | db.Set("gorm:table_options", "ENGINE=InnoDB").CreateTable(&User{}) | |
204 | ||
205 | // Drop table | |
206 | db.DropTable(&User{}) | |
207 | ||
208 | // ModifyColumn | |
209 | db.Model(&User{}).ModifyColumn("description", "text") | |
210 | ||
211 | // DropColumn | |
212 | db.Model(&User{}).DropColumn("description") | |
213 | ||
214 | // Automating Migration | |
215 | db.AutoMigrate(&User{}) | |
216 | db.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(&User{}) | |
217 | db.AutoMigrate(&User{}, &Product{}, &Order{}) | |
218 | // Feel free to change your struct, AutoMigrate will keep your database up-to-date. | |
219 | // AutoMigrate will ONLY add *new columns* and *new indexes*, | |
220 | // WON'T update current column's type or delete unused columns, to protect your data. | |
221 | // If the table is not existing, AutoMigrate will create the table automatically. | |
222 | ``` | |
223 | ||
224 | # Basic CRUD | |
225 | ||
226 | ## Create Record | |
227 | ||
228 | ```go | |
229 | user := User{Name: "Jinzhu", Age: 18, Birthday: time.Now()} | |
230 | ||
231 | db.NewRecord(user) // => returns `true` if primary key is blank | |
232 | ||
233 | db.Create(&user) | |
234 | ||
235 | db.NewRecord(user) // => return `false` after `user` created | |
236 | ||
237 | // Associations will be inserted automatically when save the record | |
238 | user := User{ | |
239 | Name: "jinzhu", | |
240 | BillingAddress: Address{Address1: "Billing Address - Address 1"}, | |
241 | ShippingAddress: Address{Address1: "Shipping Address - Address 1"}, | |
242 | Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}}, | |
243 | Languages: []Language{{Name: "ZH"}, {Name: "EN"}}, | |
244 | } | |
245 | ||
246 | db.Create(&user) | |
247 | //// BEGIN TRANSACTION; | |
248 | //// INSERT INTO "addresses" (address1) VALUES ("Billing Address - Address 1"); | |
249 | //// INSERT INTO "addresses" (address1) VALUES ("Shipping Address - Address 1"); | |
250 | //// INSERT INTO "users" (name,billing_address_id,shipping_address_id) VALUES ("jinzhu", 1, 2); | |
251 | //// INSERT INTO "emails" (user_id,email) VALUES (111, "jinzhu@example.com"); | |
252 | //// INSERT INTO "emails" (user_id,email) VALUES (111, "jinzhu-2@example.com"); | |
253 | //// INSERT INTO "languages" ("name") VALUES ('ZH'); | |
254 | //// INSERT INTO user_languages ("user_id","language_id") VALUES (111, 1); | |
255 | //// INSERT INTO "languages" ("name") VALUES ('EN'); | |
256 | //// INSERT INTO user_languages ("user_id","language_id") VALUES (111, 2); | |
257 | //// COMMIT; | |
258 | ``` | |
259 | ||
260 | Refer [Associations](#associations) for more details | |
261 | ||
262 | ## Query | |
263 | ||
264 | ```go | |
265 | // Get the first record | |
266 | db.First(&user) | |
267 | //// SELECT * FROM users ORDER BY id LIMIT 1; | |
268 | ||
269 | // Get the last record | |
270 | db.Last(&user) | |
271 | //// SELECT * FROM users ORDER BY id DESC LIMIT 1; | |
272 | ||
273 | // Get all records | |
274 | db.Find(&users) | |
275 | //// SELECT * FROM users; | |
276 | ||
277 | // Get record with primary key | |
278 | db.First(&user, 10) | |
279 | //// SELECT * FROM users WHERE id = 10; | |
280 | ``` | |
281 | ||
282 | ### Query With Where (Plain SQL) | |
283 | ||
284 | ```go | |
285 | // Get the first matched record | |
286 | db.Where("name = ?", "jinzhu").First(&user) | |
287 | //// SELECT * FROM users WHERE name = 'jinzhu' limit 1; | |
288 | ||
289 | // Get all matched records | |
290 | db.Where("name = ?", "jinzhu").Find(&users) | |
291 | //// SELECT * FROM users WHERE name = 'jinzhu'; | |
292 | ||
293 | db.Where("name <> ?", "jinzhu").Find(&users) | |
294 | ||
295 | // IN | |
296 | db.Where("name in (?)", []string{"jinzhu", "jinzhu 2"}).Find(&users) | |
297 | ||
298 | // LIKE | |
299 | db.Where("name LIKE ?", "%jin%").Find(&users) | |
300 | ||
301 | // AND | |
302 | db.Where("name = ? and age >= ?", "jinzhu", "22").Find(&users) | |
303 | ||
304 | // Time | |
305 | db.Where("updated_at > ?", lastWeek).Find(&users) | |
306 | ||
307 | db.Where("created_at BETWEEN ? AND ?", lastWeek, today).Find(&users) | |
308 | ``` | |
309 | ||
310 | ### Query With Where (Struct & Map) | |
311 | ||
312 | ```go | |
313 | // Struct | |
314 | db.Where(&User{Name: "jinzhu", Age: 20}).First(&user) | |
315 | //// SELECT * FROM users WHERE name = "jinzhu" AND age = 20 LIMIT 1; | |
316 | ||
317 | // Map | |
318 | db.Where(map[string]interface{}{"name": "jinzhu", "age": 20}).Find(&users) | |
319 | //// SELECT * FROM users WHERE name = "jinzhu" AND age = 20; | |
320 | ||
321 | // Slice of primary keys | |
322 | db.Where([]int64{20, 21, 22}).Find(&users) | |
323 | //// SELECT * FROM users WHERE id IN (20, 21, 22); | |
324 | ``` | |
325 | ||
326 | ### Query With Not | |
327 | ||
328 | ```go | |
329 | db.Not("name", "jinzhu").First(&user) | |
330 | //// SELECT * FROM users WHERE name <> "jinzhu" LIMIT 1; | |
331 | ||
332 | // Not In | |
333 | db.Not("name", []string{"jinzhu", "jinzhu 2"}).Find(&users) | |
334 | //// SELECT * FROM users WHERE name NOT IN ("jinzhu", "jinzhu 2"); | |
335 | ||
336 | // Not In slice of primary keys | |
337 | db.Not([]int64{1,2,3}).First(&user) | |
338 | //// SELECT * FROM users WHERE id NOT IN (1,2,3); | |
339 | ||
340 | db.Not([]int64{}).First(&user) | |
341 | //// SELECT * FROM users; | |
342 | ||
343 | // Plain SQL | |
344 | db.Not("name = ?", "jinzhu").First(&user) | |
345 | //// SELECT * FROM users WHERE NOT(name = "jinzhu"); | |
346 | ||
347 | // Struct | |
348 | db.Not(User{Name: "jinzhu"}).First(&user) | |
349 | //// SELECT * FROM users WHERE name <> "jinzhu"; | |
350 | ``` | |
351 | ||
352 | ### Query With Inline Condition | |
353 | ||
354 | ```go | |
355 | // Get by primary key | |
356 | db.First(&user, 23) | |
357 | //// SELECT * FROM users WHERE id = 23 LIMIT 1; | |
358 | ||
359 | // Plain SQL | |
360 | db.Find(&user, "name = ?", "jinzhu") | |
361 | //// SELECT * FROM users WHERE name = "jinzhu"; | |
362 | ||
363 | db.Find(&users, "name <> ? AND age > ?", "jinzhu", 20) | |
364 | //// SELECT * FROM users WHERE name <> "jinzhu" AND age > 20; | |
365 | ||
366 | // Struct | |
367 | db.Find(&users, User{Age: 20}) | |
368 | //// SELECT * FROM users WHERE age = 20; | |
369 | ||
370 | // Map | |
371 | db.Find(&users, map[string]interface{}{"age": 20}) | |
372 | //// SELECT * FROM users WHERE age = 20; | |
373 | ``` | |
374 | ||
375 | ### Query With Or | |
376 | ||
377 | ```go | |
378 | db.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&users) | |
379 | //// SELECT * FROM users WHERE role = 'admin' OR role = 'super_admin'; | |
380 | ||
381 | // Struct | |
382 | db.Where("name = 'jinzhu'").Or(User{Name: "jinzhu 2"}).Find(&users) | |
383 | //// SELECT * FROM users WHERE name = 'jinzhu' OR name = 'jinzhu 2'; | |
384 | ||
385 | // Map | |
386 | db.Where("name = 'jinzhu'").Or(map[string]interface{}{"name": "jinzhu 2"}).Find(&users) | |
387 | ``` | |
388 | ||
389 | ### Query Chains | |
390 | ||
391 | Gorm has a chainable API, you could use it like this | |
392 | ||
393 | ```go | |
394 | db.Where("name <> ?","jinzhu").Where("age >= ? and role <> ?",20,"admin").Find(&users) | |
395 | //// SELECT * FROM users WHERE name <> 'jinzhu' AND age >= 20 AND role <> 'admin'; | |
396 | ||
397 | db.Where("role = ?", "admin").Or("role = ?", "super_admin").Not("name = ?", "jinzhu").Find(&users) | |
398 | ``` | |
399 | ||
400 | ### Preloading (Eager loading) | |
401 | ||
402 | ```go | |
403 | db.Preload("Orders").Find(&users) | |
404 | //// SELECT * FROM users; | |
405 | //// SELECT * FROM orders WHERE user_id IN (1,2,3,4); | |
406 | ||
407 | db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) | |
408 | //// SELECT * FROM users; | |
409 | //// SELECT * FROM orders WHERE user_id IN (1,2,3,4) AND state NOT IN ('cancelled'); | |
410 | ||
411 | db.Where("state = ?", "active").Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) | |
412 | //// SELECT * FROM users WHERE state = 'active'; | |
413 | //// SELECT * FROM orders WHERE user_id IN (1,2) AND state NOT IN ('cancelled'); | |
414 | ||
415 | db.Preload("Orders").Preload("Profile").Preload("Role").Find(&users) | |
416 | //// SELECT * FROM users; | |
417 | //// SELECT * FROM orders WHERE user_id IN (1,2,3,4); // has many | |
418 | //// SELECT * FROM profiles WHERE user_id IN (1,2,3,4); // has one | |
419 | //// SELECT * FROM roles WHERE id IN (4,5,6); // belongs to | |
420 | ``` | |
421 | ||
422 | #### Nested Preloading | |
423 | ||
424 | ```go | |
425 | db.Preload("Orders.OrderItems").Find(&users) | |
426 | db.Preload("Orders", "state = ?", "paid").Preload("Orders.OrderItems").Find(&users) | |
427 | ``` | |
428 | ||
429 | ## Update | |
430 | ||
431 | ```go | |
432 | // Update an existing struct | |
433 | db.First(&user) | |
434 | user.Name = "jinzhu 2" | |
435 | user.Age = 100 | |
436 | db.Save(&user) | |
437 | //// UPDATE users SET name='jinzhu 2', age=100, updated_at = '2013-11-17 21:34:10' WHERE id=111; | |
438 | ||
439 | db.Where("active = ?", true).Save(&user) | |
440 | //// UPDATE users SET name='jinzhu 2', age=100, updated_at = '2013-11-17 21:34:10' WHERE id=111 AND active = true; | |
441 | ||
442 | // Update an attribute if it is changed | |
443 | db.Model(&user).Update("name", "hello") | |
444 | //// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111; | |
445 | ||
446 | db.Model(&user).Where("active = ?", true).Update("name", "hello") | |
447 | //// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111 AND active = true; | |
448 | ||
449 | db.First(&user, 111).Update("name", "hello") | |
450 | //// SELECT * FROM users LIMIT 1; | |
451 | //// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111; | |
452 | ||
453 | // Update multiple attributes if they are changed | |
454 | db.Model(&user).Updates(map[string]interface{}{"name": "hello", "age": 18, "actived": false}) | |
455 | ||
456 | // Update multiple attributes if they are changed (update with struct only works with none zero values) | |
457 | db.Model(&user).Updates(User{Name: "hello", Age: 18}) | |
458 | //// UPDATE users SET name='hello', age=18, updated_at = '2013-11-17 21:34:10' WHERE id = 111; | |
459 | ``` | |
460 | ||
461 | ### Update Without Callbacks | |
462 | ||
463 | By default, update will call BeforeUpdate, AfterUpdate callbacks, if you want to update w/o callbacks and w/o saving associations: | |
464 | ||
465 | ```go | |
466 | db.Model(&user).UpdateColumn("name", "hello") | |
467 | //// UPDATE users SET name='hello' WHERE id = 111; | |
468 | ||
469 | // Update with struct only works with none zero values, or use map[string]interface{} | |
470 | db.Model(&user).UpdateColumns(User{Name: "hello", Age: 18}) | |
471 | //// UPDATE users SET name='hello', age=18 WHERE id = 111; | |
472 | ``` | |
473 | ||
474 | ### Batch Updates | |
475 | ||
476 | ```go | |
477 | db.Table("users").Where("id = ?", 10).Updates(map[string]interface{}{"name": "hello", "age": 18}) | |
478 | //// UPDATE users SET name='hello', age=18 WHERE id = 10; | |
479 | ||
480 | // Update with struct only works with none zero values, or use map[string]interface{} | |
481 | db.Model(User{}).Updates(User{Name: "hello", Age: 18}) | |
482 | //// UPDATE users SET name='hello', age=18; | |
483 | ||
484 | // Callbacks won't run when do batch updates | |
485 | ||
486 | // Use `RowsAffected` to get the count of affected records | |
487 | db.Model(User{}).Updates(User{Name: "hello", Age: 18}).RowsAffected | |
488 | ``` | |
489 | ||
490 | ### Update with SQL Expression | |
491 | ||
492 | ```go | |
493 | DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) | |
494 | //// UPDATE "products" SET "code" = 'L1212', "price" = price * '2' + '100', "updated_at" = '2013-11-17 21:34:10' WHERE "id" = '2'; | |
495 | ||
496 | DB.Model(&product).Updates(map[string]interface{}{"price": gorm.Expr("price * ? + ?", 2, 100)}) | |
497 | //// UPDATE "products" SET "code" = 'L1212', "price" = price * '2' + '100', "updated_at" = '2013-11-17 21:34:10' WHERE "id" = '2'; | |
498 | ||
499 | DB.Model(&product).UpdateColumn("quantity", gorm.Expr("quantity - ?", 1)) | |
500 | //// UPDATE "products" SET "quantity" = quantity - 1 WHERE "id" = '2'; | |
501 | ||
502 | DB.Model(&product).Where("quantity > 1").UpdateColumn("quantity", gorm.Expr("quantity - ?", 1)) | |
503 | //// UPDATE "products" SET "quantity" = quantity - 1 WHERE "id" = '2' AND quantity > 1; | |
504 | ``` | |
505 | ||
506 | ## Delete | |
507 | ||
508 | ```go | |
509 | // Delete an existing record | |
510 | db.Delete(&email) | |
511 | //// DELETE from emails where id=10; | |
512 | ``` | |
513 | ||
514 | ### Batch Delete | |
515 | ||
516 | ```go | |
517 | db.Where("email LIKE ?", "%jinzhu%").Delete(Email{}) | |
518 | //// DELETE from emails where email LIKE "%jinhu%"; | |
519 | ``` | |
520 | ||
521 | ### Soft Delete | |
522 | ||
523 | If struct has `DeletedAt` field, it will get soft delete ability automatically! | |
524 | Then it won't be deleted from database permanently when call `Delete`. | |
525 | ||
526 | ```go | |
527 | db.Delete(&user) | |
528 | //// UPDATE users SET deleted_at="2013-10-29 10:23" WHERE id = 111; | |
529 | ||
530 | // Batch Delete | |
531 | db.Where("age = ?", 20).Delete(&User{}) | |
532 | //// UPDATE users SET deleted_at="2013-10-29 10:23" WHERE age = 20; | |
533 | ||
534 | // Soft deleted records will be ignored when query them | |
535 | db.Where("age = 20").Find(&user) | |
536 | //// SELECT * FROM users WHERE age = 20 AND (deleted_at IS NULL OR deleted_at <= '0001-01-02'); | |
537 | ||
538 | // Find soft deleted records with Unscoped | |
539 | db.Unscoped().Where("age = 20").Find(&users) | |
540 | //// SELECT * FROM users WHERE age = 20; | |
541 | ||
542 | // Delete record permanently with Unscoped | |
543 | db.Unscoped().Delete(&order) | |
544 | //// DELETE FROM orders WHERE id=10; | |
545 | ``` | |
546 | ||
547 | ## Associations | |
548 | ||
549 | ### Has One | |
550 | ||
551 | ```go | |
552 | // User has one address | |
553 | db.Model(&user).Related(&address) | |
554 | //// SELECT * FROM addresses WHERE id = 123; // 123 is user's foreign key AddressId | |
555 | ||
556 | // Specify the foreign key | |
557 | db.Model(&user).Related(&address1, "BillingAddressId") | |
558 | //// SELECT * FROM addresses WHERE id = 123; // 123 is user's foreign key BillingAddressId | |
559 | ``` | |
560 | ||
561 | ### Belongs To | |
562 | ||
563 | ```go | |
564 | // Email belongs to user | |
565 | db.Model(&email).Related(&user) | |
566 | //// SELECT * FROM users WHERE id = 111; // 111 is email's foreign key UserId | |
567 | ||
568 | // Specify the foreign key | |
569 | db.Model(&email).Related(&user, "ProfileId") | |
570 | //// SELECT * FROM users WHERE id = 111; // 111 is email's foreign key ProfileId | |
571 | ``` | |
572 | ||
573 | ### Has Many | |
574 | ||
575 | ```go | |
576 | // User has many emails | |
577 | db.Model(&user).Related(&emails) | |
578 | //// SELECT * FROM emails WHERE user_id = 111; | |
579 | // user_id is the foreign key, 111 is user's primary key's value | |
580 | ||
581 | // Specify the foreign key | |
582 | db.Model(&user).Related(&emails, "ProfileId") | |
583 | //// SELECT * FROM emails WHERE profile_id = 111; | |
584 | // profile_id is the foreign key, 111 is user's primary key's value | |
585 | ``` | |
586 | ||
587 | ### Many To Many | |
588 | ||
589 | ```go | |
590 | // User has many languages and belongs to many languages | |
591 | db.Model(&user).Related(&languages, "Languages") | |
592 | //// SELECT * FROM "languages" INNER JOIN "user_languages" ON "user_languages"."language_id" = "languages"."id" WHERE "user_languages"."user_id" = 111 | |
593 | // `Languages` is user's column name, this column's tag defined join table like this `gorm:"many2many:user_languages;"` | |
594 | ``` | |
595 | ||
596 | There is also a mode used to handle many to many relations easily | |
597 | ||
598 | ```go | |
599 | // Query | |
600 | db.Model(&user).Association("Languages").Find(&languages) | |
601 | // same as `db.Model(&user).Related(&languages, "Languages")` | |
602 | ||
603 | db.Where("name = ?", "ZH").First(&languageZH) | |
604 | db.Where("name = ?", "EN").First(&languageEN) | |
605 | ||
606 | // Append | |
607 | db.Model(&user).Association("Languages").Append([]Language{languageZH, languageEN}) | |
608 | db.Model(&user).Association("Languages").Append([]Language{{Name: "DE"}}) | |
609 | db.Model(&user).Association("Languages").Append(Language{Name: "DE"}) | |
610 | ||
611 | // Delete | |
612 | db.Model(&user).Association("Languages").Delete([]Language{languageZH, languageEN}) | |
613 | db.Model(&user).Association("Languages").Delete(languageZH, languageEN) | |
614 | ||
615 | // Replace | |
616 | db.Model(&user).Association("Languages").Replace([]Language{languageZH, languageEN}) | |
617 | db.Model(&user).Association("Languages").Replace(Language{Name: "DE"}, languageEN) | |
618 | ||
619 | // Count | |
620 | db.Model(&user).Association("Languages").Count() | |
621 | // Return the count of languages the user has | |
622 | ||
623 | // Clear | |
624 | db.Model(&user).Association("Languages").Clear() | |
625 | // Remove all relations between the user and languages | |
626 | ``` | |
627 | ||
628 | ### Polymorphism | |
629 | ||
630 | Supports polymorphic has-many and has-one associations. | |
631 | ||
632 | ```go | |
633 | type Cat struct { | |
634 | Id int | |
635 | Name string | |
636 | Toy Toy `gorm:"polymorphic:Owner;"` | |
637 | } | |
638 | ||
639 | type Dog struct { | |
640 | Id int | |
641 | Name string | |
642 | Toy Toy `gorm:"polymorphic:Owner;"` | |
643 | } | |
644 | ||
645 | type Toy struct { | |
646 | Id int | |
647 | Name string | |
648 | OwnerId int | |
649 | OwnerType string | |
650 | } | |
651 | ``` | |
652 | Note: polymorphic belongs-to and many-to-many are explicitly NOT supported, and will throw errors. | |
653 | ||
654 | ## Advanced Usage | |
655 | ||
656 | ## FirstOrInit | |
657 | ||
658 | Get the first matched record, or initialize a record with search conditions. | |
659 | ||
660 | ```go | |
661 | // Unfound | |
662 | db.FirstOrInit(&user, User{Name: "non_existing"}) | |
663 | //// user -> User{Name: "non_existing"} | |
664 | ||
665 | // Found | |
666 | db.Where(User{Name: "Jinzhu"}).FirstOrInit(&user) | |
667 | //// user -> User{Id: 111, Name: "Jinzhu", Age: 20} | |
668 | db.FirstOrInit(&user, map[string]interface{}{"name": "jinzhu"}) | |
669 | //// user -> User{Id: 111, Name: "Jinzhu", Age: 20} | |
670 | ``` | |
671 | ||
672 | ### Attrs | |
673 | ||
674 | Ignore some values when searching, but use them to initialize the struct if record is not found. | |
675 | ||
676 | ```go | |
677 | // Unfound | |
678 | db.Where(User{Name: "non_existing"}).Attrs(User{Age: 20}).FirstOrInit(&user) | |
679 | //// SELECT * FROM USERS WHERE name = 'non_existing'; | |
680 | //// user -> User{Name: "non_existing", Age: 20} | |
681 | ||
682 | db.Where(User{Name: "noexisting_user"}).Attrs("age", 20).FirstOrInit(&user) | |
683 | //// SELECT * FROM USERS WHERE name = 'non_existing'; | |
684 | //// user -> User{Name: "non_existing", Age: 20} | |
685 | ||
686 | // Found | |
687 | db.Where(User{Name: "Jinzhu"}).Attrs(User{Age: 30}).FirstOrInit(&user) | |
688 | //// SELECT * FROM USERS WHERE name = jinzhu'; | |
689 | //// user -> User{Id: 111, Name: "Jinzhu", Age: 20} | |
690 | ``` | |
691 | ||
692 | ### Assign | |
693 | ||
694 | Ignore some values when searching, but assign it to the result regardless it is found or not. | |
695 | ||
696 | ```go | |
697 | // Unfound | |
698 | db.Where(User{Name: "non_existing"}).Assign(User{Age: 20}).FirstOrInit(&user) | |
699 | //// user -> User{Name: "non_existing", Age: 20} | |
700 | ||
701 | // Found | |
702 | db.Where(User{Name: "Jinzhu"}).Assign(User{Age: 30}).FirstOrInit(&user) | |
703 | //// SELECT * FROM USERS WHERE name = jinzhu'; | |
704 | //// user -> User{Id: 111, Name: "Jinzhu", Age: 30} | |
705 | ``` | |
706 | ||
707 | ## FirstOrCreate | |
708 | ||
709 | Get the first matched record, or create with search conditions. | |
710 | ||
711 | ```go | |
712 | // Unfound | |
713 | db.FirstOrCreate(&user, User{Name: "non_existing"}) | |
714 | //// INSERT INTO "users" (name) VALUES ("non_existing"); | |
715 | //// user -> User{Id: 112, Name: "non_existing"} | |
716 | ||
717 | // Found | |
718 | db.Where(User{Name: "Jinzhu"}).FirstOrCreate(&user) | |
719 | //// user -> User{Id: 111, Name: "Jinzhu"} | |
720 | ``` | |
721 | ||
722 | ### Attrs | |
723 | ||
724 | Ignore some values when searching, but use them to create the struct if record is not found. like `FirstOrInit` | |
725 | ||
726 | ```go | |
727 | // Unfound | |
728 | db.Where(User{Name: "non_existing"}).Attrs(User{Age: 20}).FirstOrCreate(&user) | |
729 | //// SELECT * FROM users WHERE name = 'non_existing'; | |
730 | //// INSERT INTO "users" (name, age) VALUES ("non_existing", 20); | |
731 | //// user -> User{Id: 112, Name: "non_existing", Age: 20} | |
732 | ||
733 | // Found | |
734 | db.Where(User{Name: "jinzhu"}).Attrs(User{Age: 30}).FirstOrCreate(&user) | |
735 | //// SELECT * FROM users WHERE name = 'jinzhu'; | |
736 | //// user -> User{Id: 111, Name: "jinzhu", Age: 20} | |
737 | ``` | |
738 | ||
739 | ### Assign | |
740 | ||
741 | Ignore some values when searching, but assign it to the record regardless it is found or not, then save back to database. like `FirstOrInit` | |
742 | ||
743 | ```go | |
744 | // Unfound | |
745 | db.Where(User{Name: "non_existing"}).Assign(User{Age: 20}).FirstOrCreate(&user) | |
746 | //// SELECT * FROM users WHERE name = 'non_existing'; | |
747 | //// INSERT INTO "users" (name, age) VALUES ("non_existing", 20); | |
748 | //// user -> User{Id: 112, Name: "non_existing", Age: 20} | |
749 | ||
750 | // Found | |
751 | db.Where(User{Name: "jinzhu"}).Assign(User{Age: 30}).FirstOrCreate(&user) | |
752 | //// SELECT * FROM users WHERE name = 'jinzhu'; | |
753 | //// UPDATE users SET age=30 WHERE id = 111; | |
754 | //// user -> User{Id: 111, Name: "jinzhu", Age: 30} | |
755 | ``` | |
756 | ||
757 | ## Select | |
758 | ||
759 | ```go | |
760 | db.Select("name, age").Find(&users) | |
761 | //// SELECT name, age FROM users; | |
762 | ||
763 | db.Select([]string{"name", "age"}).Find(&users) | |
764 | //// SELECT name, age FROM users; | |
765 | ||
766 | db.Table("users").Select("COALESCE(age,?)", 42).Rows() | |
767 | //// SELECT COALESCE(age,'42') FROM users; | |
768 | ``` | |
769 | ||
770 | ## Order | |
771 | ||
772 | ```go | |
773 | db.Order("age desc, name").Find(&users) | |
774 | //// SELECT * FROM users ORDER BY age desc, name; | |
775 | ||
776 | // Multiple orders | |
777 | db.Order("age desc").Order("name").Find(&users) | |
778 | //// SELECT * FROM users ORDER BY age desc, name; | |
779 | ||
780 | // ReOrder | |
781 | db.Order("age desc").Find(&users1).Order("age", true).Find(&users2) | |
782 | //// SELECT * FROM users ORDER BY age desc; (users1) | |
783 | //// SELECT * FROM users ORDER BY age; (users2) | |
784 | ``` | |
785 | ||
786 | ## Limit | |
787 | ||
788 | ```go | |
789 | db.Limit(3).Find(&users) | |
790 | //// SELECT * FROM users LIMIT 3; | |
791 | ||
792 | // Cancel limit condition with -1 | |
793 | db.Limit(10).Find(&users1).Limit(-1).Find(&users2) | |
794 | //// SELECT * FROM users LIMIT 10; (users1) | |
795 | //// SELECT * FROM users; (users2) | |
796 | ``` | |
797 | ||
798 | ## Offset | |
799 | ||
800 | ```go | |
801 | db.Offset(3).Find(&users) | |
802 | //// SELECT * FROM users OFFSET 3; | |
803 | ||
804 | // Cancel offset condition with -1 | |
805 | db.Offset(10).Find(&users1).Offset(-1).Find(&users2) | |
806 | //// SELECT * FROM users OFFSET 10; (users1) | |
807 | //// SELECT * FROM users; (users2) | |
808 | ``` | |
809 | ||
810 | ## Count | |
811 | ||
812 | ```go | |
813 | db.Where("name = ?", "jinzhu").Or("name = ?", "jinzhu 2").Find(&users).Count(&count) | |
814 | //// SELECT * from USERS WHERE name = 'jinzhu' OR name = 'jinzhu 2'; (users) | |
815 | //// SELECT count(*) FROM users WHERE name = 'jinzhu' OR name = 'jinzhu 2'; (count) | |
816 | ||
817 | db.Model(User{}).Where("name = ?", "jinzhu").Count(&count) | |
818 | //// SELECT count(*) FROM users WHERE name = 'jinzhu'; (count) | |
819 | ||
820 | db.Table("deleted_users").Count(&count) | |
821 | //// SELECT count(*) FROM deleted_users; | |
822 | ``` | |
823 | ||
824 | ## Pluck | |
825 | ||
826 | Get selected attributes as map | |
827 | ||
828 | ```go | |
829 | var ages []int64 | |
830 | db.Find(&users).Pluck("age", &ages) | |
831 | ||
832 | var names []string | |
833 | db.Model(&User{}).Pluck("name", &names) | |
834 | ||
835 | db.Table("deleted_users").Pluck("name", &names) | |
836 | ||
837 | // Requesting more than one column? Do it like this: | |
838 | db.Select("name, age").Find(&users) | |
839 | ``` | |
840 | ||
841 | ## Raw SQL | |
842 | ||
843 | ```go | |
844 | db.Exec("DROP TABLE users;") | |
845 | db.Exec("UPDATE orders SET shipped_at=? WHERE id IN (?)", time.Now, []int64{11,22,33}) | |
846 | ``` | |
847 | ||
848 | ## Row & Rows | |
849 | ||
850 | It is even possible to get query result as `*sql.Row` or `*sql.Rows` | |
851 | ||
852 | ```go | |
853 | row := db.Table("users").Where("name = ?", "jinzhu").Select("name, age").Row() // (*sql.Row) | |
854 | row.Scan(&name, &age) | |
855 | ||
856 | rows, err := db.Model(User{}).Where("name = ?", "jinzhu").Select("name, age, email").Rows() // (*sql.Rows, error) | |
857 | defer rows.Close() | |
858 | for rows.Next() { | |
859 | ... | |
860 | rows.Scan(&name, &age, &email) | |
861 | ... | |
862 | } | |
863 | ||
864 | // Raw SQL | |
865 | rows, err := db.Raw("select name, age, email from users where name = ?", "jinzhu").Rows() // (*sql.Rows, error) | |
866 | defer rows.Close() | |
867 | for rows.Next() { | |
868 | ... | |
869 | rows.Scan(&name, &age, &email) | |
870 | ... | |
871 | } | |
872 | ``` | |
873 | ||
874 | ## Scan | |
875 | ||
876 | Scan results into another struct. | |
877 | ||
878 | ```go | |
879 | type Result struct { | |
880 | Name string | |
881 | Age int | |
882 | } | |
883 | ||
884 | var result Result | |
885 | db.Table("users").Select("name, age").Where("name = ?", 3).Scan(&result) | |
886 | ||
887 | // Raw SQL | |
888 | db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) | |
889 | ``` | |
890 | ||
891 | ## Group & Having | |
892 | ||
893 | ```go | |
894 | rows, err := db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Rows() | |
895 | for rows.Next() { | |
896 | ... | |
897 | } | |
898 | ||
899 | rows, err := db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Having("sum(amount) > ?", 100).Rows() | |
900 | for rows.Next() { | |
901 | ... | |
902 | } | |
903 | ||
904 | type Result struct { | |
905 | Date time.Time | |
906 | Total int64 | |
907 | } | |
908 | db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Having("sum(amount) > ?", 100).Scan(&results) | |
909 | ``` | |
910 | ||
911 | ## Joins | |
912 | ||
913 | ```go | |
914 | rows, err := db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Rows() | |
915 | for rows.Next() { | |
916 | ... | |
917 | } | |
918 | ||
919 | db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Scan(&results) | |
920 | ||
921 | // find a user by email address | |
922 | db.Joins("inner join emails on emails.user_id = users.id").Where("emails.email = ?", "x@example.org").Find(&user) | |
923 | ||
924 | // find all email addresses for a user | |
925 | db.Joins("left join users on users.id = emails.user_id").Where("users.name = ?", "jinzhu").Find(&emails) | |
926 | ``` | |
927 | ||
928 | ## Transactions | |
929 | ||
930 | To perform a set of operations within a transaction, the general flow is as below. | |
931 | The database handle returned from ``` db.Begin() ``` should be used for all operations within the transaction. | |
932 | (Note that all individual save and delete operations are run in a transaction by default.) | |
933 | ||
934 | ```go | |
935 | // begin | |
936 | tx := db.Begin() | |
937 | ||
938 | // do some database operations (use 'tx' from this point, not 'db') | |
939 | tx.Create(...) | |
940 | ... | |
941 | ||
942 | // rollback in case of error | |
943 | tx.Rollback() | |
944 | ||
945 | // Or commit if all is ok | |
946 | tx.Commit() | |
947 | ``` | |
948 | ||
949 | ### A Specific Example | |
950 | ``` | |
951 | func CreateAnimals(db *gorm.DB) err { | |
952 | tx := db.Begin() | |
953 | // Note the use of tx as the database handle once you are within a transaction | |
954 | ||
955 | if err := tx.Create(&Animal{Name: "Giraffe"}).Error; err != nil { | |
956 | tx.Rollback() | |
957 | return err | |
958 | } | |
959 | ||
960 | if err := tx.Create(&Animal{Name: "Lion"}).Error; err != nil { | |
961 | tx.Rollback() | |
962 | return err | |
963 | } | |
964 | ||
965 | tx.Commit() | |
966 | return nil | |
967 | } | |
968 | ``` | |
969 | ||
970 | ## Scopes | |
971 | ||
972 | ```go | |
973 | func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { | |
974 | return db.Where("amount > ?", 1000) | |
975 | } | |
976 | ||
977 | func PaidWithCreditCard(db *gorm.DB) *gorm.DB { | |
978 | return db.Where("pay_mode_sign = ?", "C") | |
979 | } | |
980 | ||
981 | func PaidWithCod(db *gorm.DB) *gorm.DB { | |
982 | return db.Where("pay_mode_sign = ?", "C") | |
983 | } | |
984 | ||
985 | func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { | |
986 | return func (db *gorm.DB) *gorm.DB { | |
987 | return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) | |
988 | } | |
989 | } | |
990 | ||
991 | db.Scopes(AmountGreaterThan1000, PaidWithCreditCard).Find(&orders) | |
992 | // Find all credit card orders and amount greater than 1000 | |
993 | ||
994 | db.Scopes(AmountGreaterThan1000, PaidWithCod).Find(&orders) | |
995 | // Find all COD orders and amount greater than 1000 | |
996 | ||
997 | db.Scopes(OrderStatus([]string{"paid", "shipped"})).Find(&orders) | |
998 | // Find all paid, shipped orders | |
999 | ``` | |
1000 | ||
1001 | ## Callbacks | |
1002 | ||
1003 | Callbacks are methods defined on the pointer of struct. | |
1004 | If any callback returns an error, gorm will stop future operations and rollback all changes. | |
1005 | ||
1006 | Here is the list of all available callbacks: | |
1007 | (listed in the same order in which they will get called during the respective operations) | |
1008 | ||
1009 | ### Creating An Object | |
1010 | ||
1011 | ```go | |
1012 | BeforeSave | |
1013 | BeforeCreate | |
1014 | // save before associations | |
1015 | // save self | |
1016 | // save after associations | |
1017 | AfterCreate | |
1018 | AfterSave | |
1019 | ``` | |
1020 | ### Updating An Object | |
1021 | ||
1022 | ```go | |
1023 | BeforeSave | |
1024 | BeforeUpdate | |
1025 | // save before associations | |
1026 | // save self | |
1027 | // save after associations | |
1028 | AfterUpdate | |
1029 | AfterSave | |
1030 | ``` | |
1031 | ||
1032 | ### Destroying An Object | |
1033 | ||
1034 | ```go | |
1035 | BeforeDelete | |
1036 | // delete self | |
1037 | AfterDelete | |
1038 | ``` | |
1039 | ||
1040 | ### After Find | |
1041 | ||
1042 | ```go | |
1043 | // load data from database | |
1044 | AfterFind | |
1045 | ``` | |
1046 | ||
1047 | ### Example | |
1048 | ||
1049 | ```go | |
1050 | func (u *User) BeforeUpdate() (err error) { | |
1051 | if u.readonly() { | |
1052 | err = errors.New("read only user") | |
1053 | } | |
1054 | return | |
1055 | } | |
1056 | ||
1057 | // Rollback the insertion if user's id greater than 1000 | |
1058 | func (u *User) AfterCreate() (err error) { | |
1059 | if (u.Id > 1000) { | |
1060 | err = errors.New("user id is already greater than 1000") | |
1061 | } | |
1062 | return | |
1063 | } | |
1064 | ``` | |
1065 | ||
1066 | As you know, save/delete operations in gorm are running in a transaction, | |
1067 | This is means if changes made in the transaction is not visiable unless it is commited, | |
1068 | So if you want to use those changes in your callbacks, you need to run SQL in same transaction. | |
1069 | Fortunately, gorm support pass transaction to callbacks as you needed, you could do it like this: | |
1070 | ||
1071 | ```go | |
1072 | func (u *User) AfterCreate(tx *gorm.DB) (err error) { | |
1073 | tx.Model(u).Update("role", "admin") | |
1074 | return | |
1075 | } | |
1076 | ``` | |
1077 | ||
1078 | ## Specifying The Table Name | |
1079 | ||
1080 | ```go | |
1081 | // Create `deleted_users` table with struct User's definition | |
1082 | db.Table("deleted_users").CreateTable(&User{}) | |
1083 | ||
1084 | var deleted_users []User | |
1085 | db.Table("deleted_users").Find(&deleted_users) | |
1086 | //// SELECT * FROM deleted_users; | |
1087 | ||
1088 | db.Table("deleted_users").Where("name = ?", "jinzhu").Delete() | |
1089 | //// DELETE FROM deleted_users WHERE name = 'jinzhu'; | |
1090 | ``` | |
1091 | ||
1092 | ### Specifying The Table Name For A Struct Permanently with TableName | |
1093 | ||
1094 | ```go | |
1095 | type Cart struct { | |
1096 | } | |
1097 | ||
1098 | func (c Cart) TableName() string { | |
1099 | return "shopping_cart" | |
1100 | } | |
1101 | ||
1102 | func (u User) TableName() string { | |
1103 | if u.Role == "admin" { | |
1104 | return "admin_users" | |
1105 | } else { | |
1106 | return "users" | |
1107 | } | |
1108 | } | |
1109 | ``` | |
1110 | ||
1111 | ## Error Handling | |
1112 | ||
1113 | ```go | |
1114 | query := db.Where("name = ?", "jinzhu").First(&user) | |
1115 | query := db.First(&user).Limit(10).Find(&users) | |
1116 | // query.Error will return the last happened error | |
1117 | ||
1118 | // So you could do error handing in your application like this: | |
1119 | if err := db.Where("name = ?", "jinzhu").First(&user).Error; err != nil { | |
1120 | // error handling... | |
1121 | } | |
1122 | ||
1123 | // RecordNotFound | |
1124 | // If no record found when you query data, gorm will return RecordNotFound error, you could check it like this: | |
1125 | db.Where("name = ?", "hello world").First(&User{}).Error == gorm.RecordNotFound | |
1126 | // Or use the shortcut method | |
1127 | db.Where("name = ?", "hello world").First(&user).RecordNotFound() | |
1128 | ||
1129 | if db.Model(&user).Related(&credit_card).RecordNotFound() { | |
1130 | // no credit card found error handling | |
1131 | } | |
1132 | ``` | |
1133 | ||
1134 | ## Logger | |
1135 | ||
1136 | Gorm has built-in logger support | |
1137 | ||
1138 | ```go | |
1139 | // Enable Logger | |
1140 | db.LogMode(true) | |
1141 | ||
1142 | // Diable Logger | |
1143 | db.LogMode(false) | |
1144 | ||
1145 | // Debug a single operation | |
1146 | db.Debug().Where("name = ?", "jinzhu").First(&User{}) | |
1147 | ``` | |
1148 | ||
1149 | ![logger](https://raw.github.com/jinzhu/gorm/master/images/logger.png) | |
1150 | ||
1151 | ### Customize Logger | |
1152 | ||
1153 | ```go | |
1154 | // Refer gorm's default logger for how to: https://github.com/jinzhu/gorm/blob/master/logger.go#files | |
1155 | db.SetLogger(gorm.Logger{revel.TRACE}) | |
1156 | db.SetLogger(log.New(os.Stdout, "\r\n", 0)) | |
1157 | ``` | |
1158 | ||
1159 | ## Existing Schema | |
1160 | ||
1161 | If you have an existing database schema, and the primary key field is different from `id`, you can add a tag to the field structure to specify that this field is a primary key. | |
1162 | ||
1163 | ```go | |
1164 | type Animal struct { | |
1165 | AnimalId int64 `gorm:"primary_key"` | |
1166 | Birthday time.Time `sql:"DEFAULT:current_timestamp"` | |
1167 | Name string `sql:"default:'galeone'"` | |
1168 | Age int64 | |
1169 | } | |
1170 | ``` | |
1171 | ||
1172 | If your column names differ from the struct fields, you can specify them like this: | |
1173 | ||
1174 | ```go | |
1175 | type Animal struct { | |
1176 | AnimalId int64 `gorm:"column:beast_id;primary_key"` | |
1177 | Birthday time.Time `gorm:"column:day_of_the_beast"` | |
1178 | Age int64 `gorm:"column:age_of_the_beast"` | |
1179 | } | |
1180 | ``` | |
1181 | ||
1182 | ## Composite Primary Key | |
1183 | ||
1184 | ```go | |
1185 | type Product struct { | |
1186 | ID string `gorm:"primary_key"` | |
1187 | LanguageCode string `gorm:"primary_key"` | |
1188 | } | |
1189 | ``` | |
1190 | ||
1191 | ## Database Indexes & Foreign Key | |
1192 | ||
1193 | ```go | |
1194 | // Add foreign key | |
1195 | // 1st param : foreignkey field | |
1196 | // 2nd param : destination table(id) | |
1197 | // 3rd param : ONDELETE | |
1198 | // 4th param : ONUPDATE | |
1199 | db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") | |
1200 | ||
1201 | // Add index | |
1202 | db.Model(&User{}).AddIndex("idx_user_name", "name") | |
1203 | ||
1204 | // Multiple column index | |
1205 | db.Model(&User{}).AddIndex("idx_user_name_age", "name", "age") | |
1206 | ||
1207 | // Add unique index | |
1208 | db.Model(&User{}).AddUniqueIndex("idx_user_name", "name") | |
1209 | ||
1210 | // Multiple column unique index | |
1211 | db.Model(&User{}).AddUniqueIndex("idx_user_name_age", "name", "age") | |
1212 | ||
1213 | // Remove index | |
1214 | db.Model(&User{}).RemoveIndex("idx_user_name") | |
1215 | ``` | |
1216 | ||
1217 | ## Default values | |
1218 | ||
1219 | ```go | |
1220 | type Animal struct { | |
1221 | ID int64 | |
1222 | Name string `sql:"default:'galeone'"` | |
1223 | Age int64 | |
1224 | } | |
1225 | ``` | |
1226 | ||
1227 | If you have defined a default value in the `sql` tag, the generated create SQl will ignore these fields if it is blank. | |
1228 | ||
1229 | Eg. | |
1230 | ||
1231 | ```go | |
1232 | db.Create(&Animal{Age: 99, Name: ""}) | |
1233 | ``` | |
1234 | ||
1235 | The generated SQL will be: | |
1236 | ||
1237 | ```sql | |
1238 | INSERT INTO animals("age") values('99'); | |
1239 | ``` | |
1240 | ||
1241 | The same thing occurs in update statements. | |
1242 | ||
1243 | ## More examples with query chain | |
1244 | ||
1245 | ```go | |
1246 | db.First(&first_article).Count(&total_count).Limit(10).Find(&first_page_articles).Offset(10).Find(&second_page_articles) | |
1247 | //// SELECT * FROM articles LIMIT 1; (first_article) | |
1248 | //// SELECT count(*) FROM articles; (total_count) | |
1249 | //// SELECT * FROM articles LIMIT 10; (first_page_articles) | |
1250 | //// SELECT * FROM articles LIMIT 10 OFFSET 10; (second_page_articles) | |
1251 | ||
1252 | ||
1253 | db.Where("created_at > ?", "2013-10-10").Find(&cancelled_orders, "state = ?", "cancelled").Find(&shipped_orders, "state = ?", "shipped") | |
1254 | //// SELECT * FROM orders WHERE created_at > '2013/10/10' AND state = 'cancelled'; (cancelled_orders) | |
1255 | //// SELECT * FROM orders WHERE created_at > '2013/10/10' AND state = 'shipped'; (shipped_orders) | |
1256 | ||
1257 | ||
1258 | // Use variables to keep query chain | |
1259 | todays_orders := db.Where("created_at > ?", "2013-10-29") | |
1260 | cancelled_orders := todays_orders.Where("state = ?", "cancelled") | |
1261 | shipped_orders := todays_orders.Where("state = ?", "shipped") | |
1262 | ||
1263 | ||
1264 | // Search with shared conditions for different tables | |
1265 | db.Where("product_name = ?", "fancy_product").Find(&orders).Find(&shopping_carts) | |
1266 | //// SELECT * FROM orders WHERE product_name = 'fancy_product'; (orders) | |
1267 | //// SELECT * FROM carts WHERE product_name = 'fancy_product'; (shopping_carts) | |
1268 | ||
1269 | ||
1270 | // Search with shared conditions from different tables with specified table | |
1271 | db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").Find(&users2) | |
1272 | //// SELECT * FROM users WHERE mail_type = 'TEXT'; (users1) | |
1273 | //// SELECT * FROM deleted_users WHERE mail_type = 'TEXT'; (users2) | |
1274 | ||
1275 | ||
1276 | // FirstOrCreate example | |
1277 | db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111"}).FirstOrCreate(&user) | |
1278 | //// SELECT * FROM users WHERE email = 'x@example.org'; | |
1279 | //// INSERT INTO "users" (email,registered_ip) VALUES ("x@example.org", "111.111.111.111") // if record not found | |
1280 | ``` | |
1281 | ||
1282 | ## TODO | |
1283 | * Github Pages | |
1284 | ||
1285 | # Author | |
1286 | ||
1287 | **jinzhu** | |
1288 | ||
1289 | * <http://github.com/jinzhu> | |
1290 | * <wosmvp@gmail.com> | |
1291 | * <http://twitter.com/zhangjinzhu> | |
33 | [You can help to deliver a better GORM, check out things you can do](http://gorm.io/contribute.html) | |
1292 | 34 | |
1293 | 35 | ## License |
1294 | 36 | |
1295 | Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License). | |
37 | © Jinzhu, 2013~time.Now | |
38 | ||
39 | Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License) |
3 | 3 | "errors" |
4 | 4 | "fmt" |
5 | 5 | "reflect" |
6 | "strings" | |
7 | 6 | ) |
8 | 7 | |
8 | // Association Mode contains some helper methods to handle relationship things easily. | |
9 | 9 | type Association struct { |
10 | Scope *Scope | |
11 | Column string | |
12 | 10 | Error error |
13 | Field *Field | |
11 | scope *Scope | |
12 | column string | |
13 | field *Field | |
14 | } | |
15 | ||
16 | // Find find out all related associations | |
17 | func (association *Association) Find(value interface{}) *Association { | |
18 | association.scope.related(value, association.column) | |
19 | return association.setErr(association.scope.db.Error) | |
20 | } | |
21 | ||
22 | // Append append new associations for many2many, has_many, replace current association for has_one, belongs_to | |
23 | func (association *Association) Append(values ...interface{}) *Association { | |
24 | if association.Error != nil { | |
25 | return association | |
26 | } | |
27 | ||
28 | if relationship := association.field.Relationship; relationship.Kind == "has_one" { | |
29 | return association.Replace(values...) | |
30 | } | |
31 | return association.saveAssociations(values...) | |
32 | } | |
33 | ||
34 | // Replace replace current associations with new one | |
35 | func (association *Association) Replace(values ...interface{}) *Association { | |
36 | if association.Error != nil { | |
37 | return association | |
38 | } | |
39 | ||
40 | var ( | |
41 | relationship = association.field.Relationship | |
42 | scope = association.scope | |
43 | field = association.field.Field | |
44 | newDB = scope.NewDB() | |
45 | ) | |
46 | ||
47 | // Append new values | |
48 | association.field.Set(reflect.Zero(association.field.Field.Type())) | |
49 | association.saveAssociations(values...) | |
50 | ||
51 | // Belongs To | |
52 | if relationship.Kind == "belongs_to" { | |
53 | // Set foreign key to be null when clearing value (length equals 0) | |
54 | if len(values) == 0 { | |
55 | // Set foreign key to be nil | |
56 | var foreignKeyMap = map[string]interface{}{} | |
57 | for _, foreignKey := range relationship.ForeignDBNames { | |
58 | foreignKeyMap[foreignKey] = nil | |
59 | } | |
60 | association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) | |
61 | } | |
62 | } else { | |
63 | // Polymorphic Relations | |
64 | if relationship.PolymorphicDBName != "" { | |
65 | newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) | |
66 | } | |
67 | ||
68 | // Delete Relations except new created | |
69 | if len(values) > 0 { | |
70 | var associationForeignFieldNames, associationForeignDBNames []string | |
71 | if relationship.Kind == "many_to_many" { | |
72 | // if many to many relations, get association fields name from association foreign keys | |
73 | associationScope := scope.New(reflect.New(field.Type()).Interface()) | |
74 | for idx, dbName := range relationship.AssociationForeignFieldNames { | |
75 | if field, ok := associationScope.FieldByName(dbName); ok { | |
76 | associationForeignFieldNames = append(associationForeignFieldNames, field.Name) | |
77 | associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx]) | |
78 | } | |
79 | } | |
80 | } else { | |
81 | // If has one/many relations, use primary keys | |
82 | for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { | |
83 | associationForeignFieldNames = append(associationForeignFieldNames, field.Name) | |
84 | associationForeignDBNames = append(associationForeignDBNames, field.DBName) | |
85 | } | |
86 | } | |
87 | ||
88 | newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) | |
89 | ||
90 | if len(newPrimaryKeys) > 0 { | |
91 | sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys)) | |
92 | newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) | |
93 | } | |
94 | } | |
95 | ||
96 | if relationship.Kind == "many_to_many" { | |
97 | // if many to many relations, delete related relations from join table | |
98 | var sourceForeignFieldNames []string | |
99 | ||
100 | for _, dbName := range relationship.ForeignFieldNames { | |
101 | if field, ok := scope.FieldByName(dbName); ok { | |
102 | sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name) | |
103 | } | |
104 | } | |
105 | ||
106 | if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { | |
107 | newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) | |
108 | ||
109 | association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) | |
110 | } | |
111 | } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { | |
112 | // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) | |
113 | var foreignKeyMap = map[string]interface{}{} | |
114 | for idx, foreignKey := range relationship.ForeignDBNames { | |
115 | foreignKeyMap[foreignKey] = nil | |
116 | if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { | |
117 | newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) | |
118 | } | |
119 | } | |
120 | ||
121 | fieldValue := reflect.New(association.field.Field.Type()).Interface() | |
122 | association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) | |
123 | } | |
124 | } | |
125 | return association | |
126 | } | |
127 | ||
128 | // Delete remove relationship between source & passed arguments, but won't delete those arguments | |
129 | func (association *Association) Delete(values ...interface{}) *Association { | |
130 | if association.Error != nil { | |
131 | return association | |
132 | } | |
133 | ||
134 | var ( | |
135 | relationship = association.field.Relationship | |
136 | scope = association.scope | |
137 | field = association.field.Field | |
138 | newDB = scope.NewDB() | |
139 | ) | |
140 | ||
141 | if len(values) == 0 { | |
142 | return association | |
143 | } | |
144 | ||
145 | var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string | |
146 | for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { | |
147 | deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) | |
148 | deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) | |
149 | } | |
150 | ||
151 | deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) | |
152 | ||
153 | if relationship.Kind == "many_to_many" { | |
154 | // source value's foreign keys | |
155 | for idx, foreignKey := range relationship.ForeignDBNames { | |
156 | if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { | |
157 | newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) | |
158 | } | |
159 | } | |
160 | ||
161 | // get association's foreign fields name | |
162 | var associationScope = scope.New(reflect.New(field.Type()).Interface()) | |
163 | var associationForeignFieldNames []string | |
164 | for _, associationDBName := range relationship.AssociationForeignFieldNames { | |
165 | if field, ok := associationScope.FieldByName(associationDBName); ok { | |
166 | associationForeignFieldNames = append(associationForeignFieldNames, field.Name) | |
167 | } | |
168 | } | |
169 | ||
170 | // association value's foreign keys | |
171 | deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) | |
172 | sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) | |
173 | newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) | |
174 | ||
175 | association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) | |
176 | } else { | |
177 | var foreignKeyMap = map[string]interface{}{} | |
178 | for _, foreignKey := range relationship.ForeignDBNames { | |
179 | foreignKeyMap[foreignKey] = nil | |
180 | } | |
181 | ||
182 | if relationship.Kind == "belongs_to" { | |
183 | // find with deleting relation's foreign keys | |
184 | primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) | |
185 | newDB = newDB.Where( | |
186 | fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), | |
187 | toQueryValues(primaryKeys)..., | |
188 | ) | |
189 | ||
190 | // set foreign key to be null if there are some records affected | |
191 | modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() | |
192 | if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { | |
193 | if results.RowsAffected > 0 { | |
194 | scope.updatedAttrsWithValues(foreignKeyMap) | |
195 | } | |
196 | } else { | |
197 | association.setErr(results.Error) | |
198 | } | |
199 | } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { | |
200 | // find all relations | |
201 | primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) | |
202 | newDB = newDB.Where( | |
203 | fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), | |
204 | toQueryValues(primaryKeys)..., | |
205 | ) | |
206 | ||
207 | // only include those deleting relations | |
208 | newDB = newDB.Where( | |
209 | fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)), | |
210 | toQueryValues(deletingPrimaryKeys)..., | |
211 | ) | |
212 | ||
213 | // set matched relation's foreign key to be null | |
214 | fieldValue := reflect.New(association.field.Field.Type()).Interface() | |
215 | association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) | |
216 | } | |
217 | } | |
218 | ||
219 | // Remove deleted records from source's field | |
220 | if association.Error == nil { | |
221 | if field.Kind() == reflect.Slice { | |
222 | leftValues := reflect.Zero(field.Type()) | |
223 | ||
224 | for i := 0; i < field.Len(); i++ { | |
225 | reflectValue := field.Index(i) | |
226 | primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] | |
227 | var isDeleted = false | |
228 | for _, pk := range deletingPrimaryKeys { | |
229 | if equalAsString(primaryKey, pk) { | |
230 | isDeleted = true | |
231 | break | |
232 | } | |
233 | } | |
234 | if !isDeleted { | |
235 | leftValues = reflect.Append(leftValues, reflectValue) | |
236 | } | |
237 | } | |
238 | ||
239 | association.field.Set(leftValues) | |
240 | } else if field.Kind() == reflect.Struct { | |
241 | primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] | |
242 | for _, pk := range deletingPrimaryKeys { | |
243 | if equalAsString(primaryKey, pk) { | |
244 | association.field.Set(reflect.Zero(field.Type())) | |
245 | break | |
246 | } | |
247 | } | |
248 | } | |
249 | } | |
250 | ||
251 | return association | |
252 | } | |
253 | ||
254 | // Clear remove relationship between source & current associations, won't delete those associations | |
255 | func (association *Association) Clear() *Association { | |
256 | return association.Replace() | |
257 | } | |
258 | ||
259 | // Count return the count of current associations | |
260 | func (association *Association) Count() int { | |
261 | var ( | |
262 | count = 0 | |
263 | relationship = association.field.Relationship | |
264 | scope = association.scope | |
265 | fieldValue = association.field.Field.Interface() | |
266 | query = scope.DB() | |
267 | ) | |
268 | ||
269 | if relationship.Kind == "many_to_many" { | |
270 | query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) | |
271 | } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { | |
272 | primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) | |
273 | query = query.Where( | |
274 | fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), | |
275 | toQueryValues(primaryKeys)..., | |
276 | ) | |
277 | } else if relationship.Kind == "belongs_to" { | |
278 | primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) | |
279 | query = query.Where( | |
280 | fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), | |
281 | toQueryValues(primaryKeys)..., | |
282 | ) | |
283 | } | |
284 | ||
285 | if relationship.PolymorphicType != "" { | |
286 | query = query.Where( | |
287 | fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), | |
288 | relationship.PolymorphicValue, | |
289 | ) | |
290 | } | |
291 | ||
292 | if err := query.Model(fieldValue).Count(&count).Error; err != nil { | |
293 | association.Error = err | |
294 | } | |
295 | return count | |
296 | } | |
297 | ||
298 | // saveAssociations save passed values as associations | |
299 | func (association *Association) saveAssociations(values ...interface{}) *Association { | |
300 | var ( | |
301 | scope = association.scope | |
302 | field = association.field | |
303 | relationship = field.Relationship | |
304 | ) | |
305 | ||
306 | saveAssociation := func(reflectValue reflect.Value) { | |
307 | // value has to been pointer | |
308 | if reflectValue.Kind() != reflect.Ptr { | |
309 | reflectPtr := reflect.New(reflectValue.Type()) | |
310 | reflectPtr.Elem().Set(reflectValue) | |
311 | reflectValue = reflectPtr | |
312 | } | |
313 | ||
314 | // value has to been saved for many2many | |
315 | if relationship.Kind == "many_to_many" { | |
316 | if scope.New(reflectValue.Interface()).PrimaryKeyZero() { | |
317 | association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) | |
318 | } | |
319 | } | |
320 | ||
321 | // Assign Fields | |
322 | var fieldType = field.Field.Type() | |
323 | var setFieldBackToValue, setSliceFieldBackToValue bool | |
324 | if reflectValue.Type().AssignableTo(fieldType) { | |
325 | field.Set(reflectValue) | |
326 | } else if reflectValue.Type().Elem().AssignableTo(fieldType) { | |
327 | // if field's type is struct, then need to set value back to argument after save | |
328 | setFieldBackToValue = true | |
329 | field.Set(reflectValue.Elem()) | |
330 | } else if fieldType.Kind() == reflect.Slice { | |
331 | if reflectValue.Type().AssignableTo(fieldType.Elem()) { | |
332 | field.Set(reflect.Append(field.Field, reflectValue)) | |
333 | } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { | |
334 | // if field's type is slice of struct, then need to set value back to argument after save | |
335 | setSliceFieldBackToValue = true | |
336 | field.Set(reflect.Append(field.Field, reflectValue.Elem())) | |
337 | } | |
338 | } | |
339 | ||
340 | if relationship.Kind == "many_to_many" { | |
341 | association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) | |
342 | } else { | |
343 | association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) | |
344 | ||
345 | if setFieldBackToValue { | |
346 | reflectValue.Elem().Set(field.Field) | |
347 | } else if setSliceFieldBackToValue { | |
348 | reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) | |
349 | } | |
350 | } | |
351 | } | |
352 | ||
353 | for _, value := range values { | |
354 | reflectValue := reflect.ValueOf(value) | |
355 | indirectReflectValue := reflect.Indirect(reflectValue) | |
356 | if indirectReflectValue.Kind() == reflect.Struct { | |
357 | saveAssociation(reflectValue) | |
358 | } else if indirectReflectValue.Kind() == reflect.Slice { | |
359 | for i := 0; i < indirectReflectValue.Len(); i++ { | |
360 | saveAssociation(indirectReflectValue.Index(i)) | |
361 | } | |
362 | } else { | |
363 | association.setErr(errors.New("invalid value type")) | |
364 | } | |
365 | } | |
366 | return association | |
14 | 367 | } |
15 | 368 | |
16 | 369 | func (association *Association) setErr(err error) *Association { |
19 | 372 | } |
20 | 373 | return association |
21 | 374 | } |
22 | ||
23 | func (association *Association) Find(value interface{}) *Association { | |
24 | association.Scope.related(value, association.Column) | |
25 | return association.setErr(association.Scope.db.Error) | |
26 | } | |
27 | ||
28 | func (association *Association) Append(values ...interface{}) *Association { | |
29 | scope := association.Scope | |
30 | field := association.Field | |
31 | ||
32 | createJoinTable := func(reflectValue reflect.Value) { | |
33 | var value = reflectValue.Interface() | |
34 | if reflectValue.Kind() != reflect.Ptr { | |
35 | reflectPtr := reflect.New(reflectValue.Type()) | |
36 | reflectPtr.Elem().Set(reflectValue) | |
37 | value = reflectPtr.Interface() | |
38 | } | |
39 | ||
40 | if scope.New(value).PrimaryKeyZero() { | |
41 | scope.NewDB().Save(value) | |
42 | } | |
43 | ||
44 | relationship := association.Field.Relationship | |
45 | association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, value)) | |
46 | ||
47 | result := reflect.ValueOf(value) | |
48 | fieldElemType := field.Field.Type().Elem() | |
49 | if result.Type().AssignableTo(fieldElemType) { | |
50 | field.Set(reflect.Append(field.Field, result)) | |
51 | } else if result.Type().Elem().AssignableTo(fieldElemType) { | |
52 | field.Set(reflect.Append(field.Field, result.Elem())) | |
53 | } | |
54 | } | |
55 | ||
56 | for _, value := range values { | |
57 | reflectValue := reflect.Indirect(reflect.ValueOf(value)) | |
58 | ||
59 | if reflectValue.Kind() == reflect.Struct { | |
60 | createJoinTable(reflectValue) | |
61 | } else if reflectValue.Kind() == reflect.Slice { | |
62 | for i := 0; i < reflectValue.Len(); i++ { | |
63 | createJoinTable(reflectValue.Index(i)) | |
64 | } | |
65 | } else { | |
66 | association.setErr(errors.New("invalid association type")) | |
67 | } | |
68 | } | |
69 | return association | |
70 | } | |
71 | ||
72 | func (association *Association) Delete(values ...interface{}) *Association { | |
73 | scope := association.Scope | |
74 | relationship := association.Field.Relationship | |
75 | ||
76 | // many to many | |
77 | if relationship.Kind == "many_to_many" { | |
78 | query := scope.NewDB() | |
79 | for idx, foreignKey := range relationship.ForeignDBNames { | |
80 | if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { | |
81 | query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) | |
82 | } | |
83 | } | |
84 | ||
85 | primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) | |
86 | sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)) | |
87 | query = query.Where(sql, toQueryValues(primaryKeys)...) | |
88 | ||
89 | if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { | |
90 | leftValues := reflect.Zero(association.Field.Field.Type()) | |
91 | for i := 0; i < association.Field.Field.Len(); i++ { | |
92 | reflectValue := association.Field.Field.Index(i) | |
93 | primaryKey := association.getPrimaryKeys(relationship.ForeignFieldNames, reflectValue.Interface())[0] | |
94 | var included = false | |
95 | for _, pk := range primaryKeys { | |
96 | if equalAsString(primaryKey, pk) { | |
97 | included = true | |
98 | } | |
99 | } | |
100 | if !included { | |
101 | leftValues = reflect.Append(leftValues, reflectValue) | |
102 | } | |
103 | } | |
104 | association.Field.Set(leftValues) | |
105 | } | |
106 | } else { | |
107 | association.setErr(errors.New("delete only support many to many")) | |
108 | } | |
109 | return association | |
110 | } | |
111 | ||
112 | func (association *Association) Replace(values ...interface{}) *Association { | |
113 | relationship := association.Field.Relationship | |
114 | scope := association.Scope | |
115 | if relationship.Kind == "many_to_many" { | |
116 | field := association.Field.Field | |
117 | ||
118 | oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) | |
119 | association.Field.Set(reflect.Zero(association.Field.Field.Type())) | |
120 | association.Append(values...) | |
121 | newPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) | |
122 | ||
123 | var addedPrimaryKeys = [][]interface{}{} | |
124 | for _, newKey := range newPrimaryKeys { | |
125 | hasEqual := false | |
126 | for _, oldKey := range oldPrimaryKeys { | |
127 | if equalAsString(newKey, oldKey) { | |
128 | hasEqual = true | |
129 | break | |
130 | } | |
131 | } | |
132 | if !hasEqual { | |
133 | addedPrimaryKeys = append(addedPrimaryKeys, newKey) | |
134 | } | |
135 | } | |
136 | ||
137 | for _, primaryKey := range association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) { | |
138 | addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) | |
139 | } | |
140 | ||
141 | query := scope.NewDB() | |
142 | for idx, foreignKey := range relationship.ForeignDBNames { | |
143 | if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { | |
144 | query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) | |
145 | } | |
146 | } | |
147 | ||
148 | if len(addedPrimaryKeys) > 0 { | |
149 | sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys)) | |
150 | query = query.Where(sql, toQueryValues(addedPrimaryKeys)...) | |
151 | } | |
152 | association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship)) | |
153 | } else { | |
154 | association.setErr(errors.New("replace only support many to many")) | |
155 | } | |
156 | return association | |
157 | } | |
158 | ||
159 | func (association *Association) Clear() *Association { | |
160 | relationship := association.Field.Relationship | |
161 | scope := association.Scope | |
162 | if relationship.Kind == "many_to_many" { | |
163 | query := scope.NewDB() | |
164 | for idx, foreignKey := range relationship.ForeignDBNames { | |
165 | if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { | |
166 | query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) | |
167 | } | |
168 | } | |
169 | ||
170 | if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { | |
171 | association.Field.Set(reflect.Zero(association.Field.Field.Type())) | |
172 | } else { | |
173 | association.setErr(err) | |
174 | } | |
175 | } else { | |
176 | association.setErr(errors.New("clear only support many to many")) | |
177 | } | |
178 | return association | |
179 | } | |
180 | ||
181 | func (association *Association) Count() int { | |
182 | count := -1 | |
183 | relationship := association.Field.Relationship | |
184 | scope := association.Scope | |
185 | newScope := scope.New(association.Field.Field.Interface()) | |
186 | ||
187 | if relationship.Kind == "many_to_many" { | |
188 | relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) | |
189 | } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { | |
190 | query := scope.DB() | |
191 | for idx, foreignKey := range relationship.ForeignDBNames { | |
192 | if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { | |
193 | query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), | |
194 | field.Field.Interface()) | |
195 | } | |
196 | } | |
197 | ||
198 | if relationship.PolymorphicType != "" { | |
199 | query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) | |
200 | } | |
201 | query.Table(newScope.TableName()).Count(&count) | |
202 | } else if relationship.Kind == "belongs_to" { | |
203 | query := scope.DB() | |
204 | for idx, foreignKey := range relationship.ForeignDBNames { | |
205 | if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { | |
206 | query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), | |
207 | field.Field.Interface()) | |
208 | } | |
209 | } | |
210 | query.Table(newScope.TableName()).Count(&count) | |
211 | } | |
212 | ||
213 | return count | |
214 | } | |
215 | ||
216 | func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) [][]interface{} { | |
217 | results := [][]interface{}{} | |
218 | scope := association.Scope | |
219 | ||
220 | for _, value := range values { | |
221 | reflectValue := reflect.Indirect(reflect.ValueOf(value)) | |
222 | if reflectValue.Kind() == reflect.Slice { | |
223 | for i := 0; i < reflectValue.Len(); i++ { | |
224 | primaryKeys := []interface{}{} | |
225 | newScope := scope.New(reflectValue.Index(i).Interface()) | |
226 | for _, column := range columns { | |
227 | if field, ok := newScope.FieldByName(column); ok { | |
228 | primaryKeys = append(primaryKeys, field.Field.Interface()) | |
229 | } else { | |
230 | primaryKeys = append(primaryKeys, "") | |
231 | } | |
232 | } | |
233 | results = append(results, primaryKeys) | |
234 | } | |
235 | } else if reflectValue.Kind() == reflect.Struct { | |
236 | newScope := scope.New(value) | |
237 | var primaryKeys []interface{} | |
238 | for _, column := range columns { | |
239 | if field, ok := newScope.FieldByName(column); ok { | |
240 | primaryKeys = append(primaryKeys, field.Field.Interface()) | |
241 | } else { | |
242 | primaryKeys = append(primaryKeys, "") | |
243 | } | |
244 | } | |
245 | ||
246 | results = append(results, primaryKeys) | |
247 | } | |
248 | } | |
249 | return results | |
250 | } | |
251 | ||
252 | func toQueryMarks(primaryValues [][]interface{}) string { | |
253 | var results []string | |
254 | ||
255 | for _, primaryValue := range primaryValues { | |
256 | var marks []string | |
257 | for _, _ = range primaryValue { | |
258 | marks = append(marks, "?") | |
259 | } | |
260 | ||
261 | if len(marks) > 1 { | |
262 | results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) | |
263 | } else { | |
264 | results = append(results, strings.Join(marks, "")) | |
265 | } | |
266 | } | |
267 | return strings.Join(results, ",") | |
268 | } | |
269 | ||
270 | func toQueryCondition(scope *Scope, columns []string) string { | |
271 | var newColumns []string | |
272 | for _, column := range columns { | |
273 | newColumns = append(newColumns, scope.Quote(column)) | |
274 | } | |
275 | ||
276 | if len(columns) > 1 { | |
277 | return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) | |
278 | } else { | |
279 | return strings.Join(columns, ",") | |
280 | } | |
281 | } | |
282 | ||
283 | func toQueryValues(primaryValues [][]interface{}) (values []interface{}) { | |
284 | for _, primaryValue := range primaryValues { | |
285 | for _, value := range primaryValue { | |
286 | values = append(values, value) | |
287 | } | |
288 | } | |
289 | return values | |
290 | } |
1 | 1 | |
2 | 2 | import ( |
3 | 3 | "fmt" |
4 | "os" | |
5 | "reflect" | |
6 | "sort" | |
4 | 7 | "testing" |
8 | ||
9 | "github.com/jinzhu/gorm" | |
5 | 10 | ) |
6 | 11 | |
7 | func TestHasOneAndHasManyAssociation(t *testing.T) { | |
8 | DB.DropTable(Category{}, Post{}, Comment{}) | |
9 | DB.CreateTable(Category{}, Post{}, Comment{}) | |
10 | ||
12 | func TestBelongsTo(t *testing.T) { | |
11 | 13 | post := Post{ |
12 | Title: "post 1", | |
13 | Body: "body 1", | |
14 | Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}}, | |
14 | Title: "post belongs to", | |
15 | Body: "body belongs to", | |
15 | 16 | Category: Category{Name: "Category 1"}, |
16 | 17 | MainCategory: Category{Name: "Main Category 1"}, |
17 | 18 | } |
18 | 19 | |
19 | 20 | if err := DB.Save(&post).Error; err != nil { |
20 | t.Errorf("Got errors when save post", err.Error()) | |
21 | } | |
22 | ||
23 | if err := DB.First(&Category{}, "name = ?", "Category 1").Error; err != nil { | |
24 | t.Errorf("Category should be saved", err.Error()) | |
25 | } | |
26 | ||
27 | var p Post | |
28 | DB.First(&p, post.Id) | |
29 | ||
30 | if post.CategoryId.Int64 == 0 || p.CategoryId.Int64 == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 { | |
31 | t.Errorf("Category Id should exist") | |
32 | } | |
33 | ||
21 | t.Error("Got errors when save post", err) | |
22 | } | |
23 | ||
24 | if post.Category.ID == 0 || post.MainCategory.ID == 0 { | |
25 | t.Errorf("Category's primary key should be updated") | |
26 | } | |
27 | ||
28 | if post.CategoryId.Int64 == 0 || post.MainCategoryId == 0 { | |
29 | t.Errorf("post's foreign key should be updated") | |
30 | } | |
31 | ||
32 | // Query | |
33 | var category1 Category | |
34 | DB.Model(&post).Association("Category").Find(&category1) | |
35 | if category1.Name != "Category 1" { | |
36 | t.Errorf("Query belongs to relations with Association") | |
37 | } | |
38 | ||
39 | var mainCategory1 Category | |
40 | DB.Model(&post).Association("MainCategory").Find(&mainCategory1) | |
41 | if mainCategory1.Name != "Main Category 1" { | |
42 | t.Errorf("Query belongs to relations with Association") | |
43 | } | |
44 | ||
45 | var category11 Category | |
46 | DB.Model(&post).Related(&category11) | |
47 | if category11.Name != "Category 1" { | |
48 | t.Errorf("Query belongs to relations with Related") | |
49 | } | |
50 | ||
51 | if DB.Model(&post).Association("Category").Count() != 1 { | |
52 | t.Errorf("Post's category count should be 1") | |
53 | } | |
54 | ||
55 | if DB.Model(&post).Association("MainCategory").Count() != 1 { | |
56 | t.Errorf("Post's main category count should be 1") | |
57 | } | |
58 | ||
59 | // Append | |
60 | var category2 = Category{ | |
61 | Name: "Category 2", | |
62 | } | |
63 | DB.Model(&post).Association("Category").Append(&category2) | |
64 | ||
65 | if category2.ID == 0 { | |
66 | t.Errorf("Category should has ID when created with Append") | |
67 | } | |
68 | ||
69 | var category21 Category | |
70 | DB.Model(&post).Related(&category21) | |
71 | ||
72 | if category21.Name != "Category 2" { | |
73 | t.Errorf("Category should be updated with Append") | |
74 | } | |
75 | ||
76 | if DB.Model(&post).Association("Category").Count() != 1 { | |
77 | t.Errorf("Post's category count should be 1") | |
78 | } | |
79 | ||
80 | // Replace | |
81 | var category3 = Category{ | |
82 | Name: "Category 3", | |
83 | } | |
84 | DB.Model(&post).Association("Category").Replace(&category3) | |
85 | ||
86 | if category3.ID == 0 { | |
87 | t.Errorf("Category should has ID when created with Replace") | |
88 | } | |
89 | ||
90 | var category31 Category | |
91 | DB.Model(&post).Related(&category31) | |
92 | if category31.Name != "Category 3" { | |
93 | t.Errorf("Category should be updated with Replace") | |
94 | } | |
95 | ||
96 | if DB.Model(&post).Association("Category").Count() != 1 { | |
97 | t.Errorf("Post's category count should be 1") | |
98 | } | |
99 | ||
100 | // Delete | |
101 | DB.Model(&post).Association("Category").Delete(&category2) | |
102 | if DB.Model(&post).Related(&Category{}).RecordNotFound() { | |
103 | t.Errorf("Should not delete any category when Delete a unrelated Category") | |
104 | } | |
105 | ||
106 | if post.Category.Name == "" { | |
107 | t.Errorf("Post's category should not be reseted when Delete a unrelated Category") | |
108 | } | |
109 | ||
110 | DB.Model(&post).Association("Category").Delete(&category3) | |
111 | ||
112 | if post.Category.Name != "" { | |
113 | t.Errorf("Post's category should be reseted after Delete") | |
114 | } | |
115 | ||
116 | var category41 Category | |
117 | DB.Model(&post).Related(&category41) | |
118 | if category41.Name != "" { | |
119 | t.Errorf("Category should be deleted with Delete") | |
120 | } | |
121 | ||
122 | if count := DB.Model(&post).Association("Category").Count(); count != 0 { | |
123 | t.Errorf("Post's category count should be 0 after Delete, but got %v", count) | |
124 | } | |
125 | ||
126 | // Clear | |
127 | DB.Model(&post).Association("Category").Append(&Category{ | |
128 | Name: "Category 2", | |
129 | }) | |
130 | ||
131 | if DB.Model(&post).Related(&Category{}).RecordNotFound() { | |
132 | t.Errorf("Should find category after append") | |
133 | } | |
134 | ||
135 | if post.Category.Name == "" { | |
136 | t.Errorf("Post's category should has value after Append") | |
137 | } | |
138 | ||
139 | DB.Model(&post).Association("Category").Clear() | |
140 | ||
141 | if post.Category.Name != "" { | |
142 | t.Errorf("Post's category should be cleared after Clear") | |
143 | } | |
144 | ||
145 | if !DB.Model(&post).Related(&Category{}).RecordNotFound() { | |
146 | t.Errorf("Should not find any category after Clear") | |
147 | } | |
148 | ||
149 | if count := DB.Model(&post).Association("Category").Count(); count != 0 { | |
150 | t.Errorf("Post's category count should be 0 after Clear, but got %v", count) | |
151 | } | |
152 | ||
153 | // Check Association mode with soft delete | |
154 | category6 := Category{ | |
155 | Name: "Category 6", | |
156 | } | |
157 | DB.Model(&post).Association("Category").Append(&category6) | |
158 | ||
159 | if count := DB.Model(&post).Association("Category").Count(); count != 1 { | |
160 | t.Errorf("Post's category count should be 1 after Append, but got %v", count) | |
161 | } | |
162 | ||
163 | DB.Delete(&category6) | |
164 | ||
165 | if count := DB.Model(&post).Association("Category").Count(); count != 0 { | |
166 | t.Errorf("Post's category count should be 0 after the category has been deleted, but got %v", count) | |
167 | } | |
168 | ||
169 | if err := DB.Model(&post).Association("Category").Find(&Category{}).Error; err == nil { | |
170 | t.Errorf("Post's category is not findable after Delete") | |
171 | } | |
172 | ||
173 | if count := DB.Unscoped().Model(&post).Association("Category").Count(); count != 1 { | |
174 | t.Errorf("Post's category count should be 1 when query with Unscoped, but got %v", count) | |
175 | } | |
176 | ||
177 | if err := DB.Unscoped().Model(&post).Association("Category").Find(&Category{}).Error; err != nil { | |
178 | t.Errorf("Post's category should be findable when query with Unscoped, got %v", err) | |
179 | } | |
180 | } | |
181 | ||
182 | func TestBelongsToOverrideForeignKey1(t *testing.T) { | |
183 | type Profile struct { | |
184 | gorm.Model | |
185 | Name string | |
186 | } | |
187 | ||
188 | type User struct { | |
189 | gorm.Model | |
190 | Profile Profile `gorm:"ForeignKey:ProfileRefer"` | |
191 | ProfileRefer int | |
192 | } | |
193 | ||
194 | if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { | |
195 | if relation.Relationship.Kind != "belongs_to" || | |
196 | !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) || | |
197 | !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { | |
198 | t.Errorf("Override belongs to foreign key with tag") | |
199 | } | |
200 | } | |
201 | } | |
202 | ||
203 | func TestBelongsToOverrideForeignKey2(t *testing.T) { | |
204 | type Profile struct { | |
205 | gorm.Model | |
206 | Refer string | |
207 | Name string | |
208 | } | |
209 | ||
210 | type User struct { | |
211 | gorm.Model | |
212 | Profile Profile `gorm:"ForeignKey:ProfileID;AssociationForeignKey:Refer"` | |
213 | ProfileID int | |
214 | } | |
215 | ||
216 | if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { | |
217 | if relation.Relationship.Kind != "belongs_to" || | |
218 | !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) || | |
219 | !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { | |
220 | t.Errorf("Override belongs to foreign key with tag") | |
221 | } | |
222 | } | |
223 | } | |
224 | ||
225 | func TestHasOne(t *testing.T) { | |
226 | user := User{ | |
227 | Name: "has one", | |
228 | CreditCard: CreditCard{Number: "411111111111"}, | |
229 | } | |
230 | ||
231 | if err := DB.Save(&user).Error; err != nil { | |
232 | t.Error("Got errors when save user", err.Error()) | |
233 | } | |
234 | ||
235 | if user.CreditCard.UserId.Int64 == 0 { | |
236 | t.Errorf("CreditCard's foreign key should be updated") | |
237 | } | |
238 | ||
239 | // Query | |
240 | var creditCard1 CreditCard | |
241 | DB.Model(&user).Related(&creditCard1) | |
242 | ||
243 | if creditCard1.Number != "411111111111" { | |
244 | t.Errorf("Query has one relations with Related") | |
245 | } | |
246 | ||
247 | var creditCard11 CreditCard | |
248 | DB.Model(&user).Association("CreditCard").Find(&creditCard11) | |
249 | ||
250 | if creditCard11.Number != "411111111111" { | |
251 | t.Errorf("Query has one relations with Related") | |
252 | } | |
253 | ||
254 | if DB.Model(&user).Association("CreditCard").Count() != 1 { | |
255 | t.Errorf("User's credit card count should be 1") | |
256 | } | |
257 | ||
258 | // Append | |
259 | var creditcard2 = CreditCard{ | |
260 | Number: "411111111112", | |
261 | } | |
262 | DB.Model(&user).Association("CreditCard").Append(&creditcard2) | |
263 | ||
264 | if creditcard2.ID == 0 { | |
265 | t.Errorf("Creditcard should has ID when created with Append") | |
266 | } | |
267 | ||
268 | var creditcard21 CreditCard | |
269 | DB.Model(&user).Related(&creditcard21) | |
270 | if creditcard21.Number != "411111111112" { | |
271 | t.Errorf("CreditCard should be updated with Append") | |
272 | } | |
273 | ||
274 | if DB.Model(&user).Association("CreditCard").Count() != 1 { | |
275 | t.Errorf("User's credit card count should be 1") | |
276 | } | |
277 | ||
278 | // Replace | |
279 | var creditcard3 = CreditCard{ | |
280 | Number: "411111111113", | |
281 | } | |
282 | DB.Model(&user).Association("CreditCard").Replace(&creditcard3) | |
283 | ||
284 | if creditcard3.ID == 0 { | |
285 | t.Errorf("Creditcard should has ID when created with Replace") | |
286 | } | |
287 | ||
288 | var creditcard31 CreditCard | |
289 | DB.Model(&user).Related(&creditcard31) | |
290 | if creditcard31.Number != "411111111113" { | |
291 | t.Errorf("CreditCard should be updated with Replace") | |
292 | } | |
293 | ||
294 | if DB.Model(&user).Association("CreditCard").Count() != 1 { | |
295 | t.Errorf("User's credit card count should be 1") | |
296 | } | |
297 | ||
298 | // Delete | |
299 | DB.Model(&user).Association("CreditCard").Delete(&creditcard2) | |
300 | var creditcard4 CreditCard | |
301 | DB.Model(&user).Related(&creditcard4) | |
302 | if creditcard4.Number != "411111111113" { | |
303 | t.Errorf("Should not delete credit card when Delete a unrelated CreditCard") | |
304 | } | |
305 | ||
306 | if DB.Model(&user).Association("CreditCard").Count() != 1 { | |
307 | t.Errorf("User's credit card count should be 1") | |
308 | } | |
309 | ||
310 | DB.Model(&user).Association("CreditCard").Delete(&creditcard3) | |
311 | if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { | |
312 | t.Errorf("Should delete credit card with Delete") | |
313 | } | |
314 | ||
315 | if DB.Model(&user).Association("CreditCard").Count() != 0 { | |
316 | t.Errorf("User's credit card count should be 0 after Delete") | |
317 | } | |
318 | ||
319 | // Clear | |
320 | var creditcard5 = CreditCard{ | |
321 | Number: "411111111115", | |
322 | } | |
323 | DB.Model(&user).Association("CreditCard").Append(&creditcard5) | |
324 | ||
325 | if DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { | |
326 | t.Errorf("Should added credit card with Append") | |
327 | } | |
328 | ||
329 | if DB.Model(&user).Association("CreditCard").Count() != 1 { | |
330 | t.Errorf("User's credit card count should be 1") | |
331 | } | |
332 | ||
333 | DB.Model(&user).Association("CreditCard").Clear() | |
334 | if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { | |
335 | t.Errorf("Credit card should be deleted with Clear") | |
336 | } | |
337 | ||
338 | if DB.Model(&user).Association("CreditCard").Count() != 0 { | |
339 | t.Errorf("User's credit card count should be 0 after Clear") | |
340 | } | |
341 | ||
342 | // Check Association mode with soft delete | |
343 | var creditcard6 = CreditCard{ | |
344 | Number: "411111111116", | |
345 | } | |
346 | DB.Model(&user).Association("CreditCard").Append(&creditcard6) | |
347 | ||
348 | if count := DB.Model(&user).Association("CreditCard").Count(); count != 1 { | |
349 | t.Errorf("User's credit card count should be 1 after Append, but got %v", count) | |
350 | } | |
351 | ||
352 | DB.Delete(&creditcard6) | |
353 | ||
354 | if count := DB.Model(&user).Association("CreditCard").Count(); count != 0 { | |
355 | t.Errorf("User's credit card count should be 0 after credit card deleted, but got %v", count) | |
356 | } | |
357 | ||
358 | if err := DB.Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err == nil { | |
359 | t.Errorf("User's creditcard is not findable after Delete") | |
360 | } | |
361 | ||
362 | if count := DB.Unscoped().Model(&user).Association("CreditCard").Count(); count != 1 { | |
363 | t.Errorf("User's credit card count should be 1 when query with Unscoped, but got %v", count) | |
364 | } | |
365 | ||
366 | if err := DB.Unscoped().Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err != nil { | |
367 | t.Errorf("User's creditcard should be findable when query with Unscoped, got %v", err) | |
368 | } | |
369 | } | |
370 | ||
371 | func TestHasOneOverrideForeignKey1(t *testing.T) { | |
372 | type Profile struct { | |
373 | gorm.Model | |
374 | Name string | |
375 | UserRefer uint | |
376 | } | |
377 | ||
378 | type User struct { | |
379 | gorm.Model | |
380 | Profile Profile `gorm:"ForeignKey:UserRefer"` | |
381 | } | |
382 | ||
383 | if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { | |
384 | if relation.Relationship.Kind != "has_one" || | |
385 | !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || | |
386 | !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { | |
387 | t.Errorf("Override belongs to foreign key with tag") | |
388 | } | |
389 | } | |
390 | } | |
391 | ||
392 | func TestHasOneOverrideForeignKey2(t *testing.T) { | |
393 | type Profile struct { | |
394 | gorm.Model | |
395 | Name string | |
396 | UserID uint | |
397 | } | |
398 | ||
399 | type User struct { | |
400 | gorm.Model | |
401 | Refer string | |
402 | Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` | |
403 | } | |
404 | ||
405 | if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { | |
406 | if relation.Relationship.Kind != "has_one" || | |
407 | !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || | |
408 | !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { | |
409 | t.Errorf("Override belongs to foreign key with tag") | |
410 | } | |
411 | } | |
412 | } | |
413 | ||
414 | func TestHasMany(t *testing.T) { | |
415 | post := Post{ | |
416 | Title: "post has many", | |
417 | Body: "body has many", | |
418 | Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}}, | |
419 | } | |
420 | ||
421 | if err := DB.Save(&post).Error; err != nil { | |
422 | t.Error("Got errors when save post", err) | |
423 | } | |
424 | ||
425 | for _, comment := range post.Comments { | |
426 | if comment.PostId == 0 { | |
427 | t.Errorf("comment's PostID should be updated") | |
428 | } | |
429 | } | |
430 | ||
431 | var compareComments = func(comments []Comment, contents []string) bool { | |
432 | var commentContents []string | |
433 | for _, comment := range comments { | |
434 | commentContents = append(commentContents, comment.Content) | |
435 | } | |
436 | sort.Strings(commentContents) | |
437 | sort.Strings(contents) | |
438 | return reflect.DeepEqual(commentContents, contents) | |
439 | } | |
440 | ||
441 | // Query | |
34 | 442 | if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil { |
35 | 443 | t.Errorf("Comment 1 should be saved") |
36 | 444 | } |
37 | if post.Comments[0].PostId == 0 { | |
38 | t.Errorf("Comment Should have post id") | |
39 | } | |
40 | ||
41 | var comment Comment | |
42 | if DB.First(&comment, "content = ?", "Comment 2").Error != nil { | |
43 | t.Errorf("Comment 2 should be saved") | |
44 | } | |
45 | ||
46 | if comment.PostId == 0 { | |
47 | t.Errorf("Comment 2 Should have post id") | |
48 | } | |
49 | ||
50 | comment3 := Comment{Content: "Comment 3", Post: Post{Title: "Title 3", Body: "Body 3"}} | |
51 | DB.Save(&comment3) | |
445 | ||
446 | var comments1 []Comment | |
447 | DB.Model(&post).Association("Comments").Find(&comments1) | |
448 | if !compareComments(comments1, []string{"Comment 1", "Comment 2"}) { | |
449 | t.Errorf("Query has many relations with Association") | |
450 | } | |
451 | ||
452 | var comments11 []Comment | |
453 | DB.Model(&post).Related(&comments11) | |
454 | if !compareComments(comments11, []string{"Comment 1", "Comment 2"}) { | |
455 | t.Errorf("Query has many relations with Related") | |
456 | } | |
457 | ||
458 | if DB.Model(&post).Association("Comments").Count() != 2 { | |
459 | t.Errorf("Post's comments count should be 2") | |
460 | } | |
461 | ||
462 | // Append | |
463 | DB.Model(&post).Association("Comments").Append(&Comment{Content: "Comment 3"}) | |
464 | ||
465 | var comments2 []Comment | |
466 | DB.Model(&post).Related(&comments2) | |
467 | if !compareComments(comments2, []string{"Comment 1", "Comment 2", "Comment 3"}) { | |
468 | t.Errorf("Append new record to has many relations") | |
469 | } | |
470 | ||
471 | if DB.Model(&post).Association("Comments").Count() != 3 { | |
472 | t.Errorf("Post's comments count should be 3 after Append") | |
473 | } | |
474 | ||
475 | // Delete | |
476 | DB.Model(&post).Association("Comments").Delete(comments11) | |
477 | ||
478 | var comments3 []Comment | |
479 | DB.Model(&post).Related(&comments3) | |
480 | if !compareComments(comments3, []string{"Comment 3"}) { | |
481 | t.Errorf("Delete an existing resource for has many relations") | |
482 | } | |
483 | ||
484 | if DB.Model(&post).Association("Comments").Count() != 1 { | |
485 | t.Errorf("Post's comments count should be 1 after Delete 2") | |
486 | } | |
487 | ||
488 | // Replace | |
489 | DB.Model(&Post{Id: 999}).Association("Comments").Replace() | |
490 | ||
491 | var comments4 []Comment | |
492 | DB.Model(&post).Related(&comments4) | |
493 | if len(comments4) == 0 { | |
494 | t.Errorf("Replace for other resource should not clear all comments") | |
495 | } | |
496 | ||
497 | DB.Model(&post).Association("Comments").Replace(&Comment{Content: "Comment 4"}, &Comment{Content: "Comment 5"}) | |
498 | ||
499 | var comments41 []Comment | |
500 | DB.Model(&post).Related(&comments41) | |
501 | if !compareComments(comments41, []string{"Comment 4", "Comment 5"}) { | |
502 | t.Errorf("Replace has many relations") | |
503 | } | |
504 | ||
505 | // Clear | |
506 | DB.Model(&Post{Id: 999}).Association("Comments").Clear() | |
507 | ||
508 | var comments5 []Comment | |
509 | DB.Model(&post).Related(&comments5) | |
510 | if len(comments5) == 0 { | |
511 | t.Errorf("Clear should not clear all comments") | |
512 | } | |
513 | ||
514 | DB.Model(&post).Association("Comments").Clear() | |
515 | ||
516 | var comments51 []Comment | |
517 | DB.Model(&post).Related(&comments51) | |
518 | if len(comments51) != 0 { | |
519 | t.Errorf("Clear has many relations") | |
520 | } | |
521 | ||
522 | // Check Association mode with soft delete | |
523 | var comment6 = Comment{ | |
524 | Content: "comment 6", | |
525 | } | |
526 | DB.Model(&post).Association("Comments").Append(&comment6) | |
527 | ||
528 | if count := DB.Model(&post).Association("Comments").Count(); count != 1 { | |
529 | t.Errorf("post's comments count should be 1 after Append, but got %v", count) | |
530 | } | |
531 | ||
532 | DB.Delete(&comment6) | |
533 | ||
534 | if count := DB.Model(&post).Association("Comments").Count(); count != 0 { | |
535 | t.Errorf("post's comments count should be 0 after comment been deleted, but got %v", count) | |
536 | } | |
537 | ||
538 | var comments6 []Comment | |
539 | if DB.Model(&post).Association("Comments").Find(&comments6); len(comments6) != 0 { | |
540 | t.Errorf("post's comments count should be 0 when find with Find, but got %v", len(comments6)) | |
541 | } | |
542 | ||
543 | if count := DB.Unscoped().Model(&post).Association("Comments").Count(); count != 1 { | |
544 | t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", count) | |
545 | } | |
546 | ||
547 | var comments61 []Comment | |
548 | if DB.Unscoped().Model(&post).Association("Comments").Find(&comments61); len(comments61) != 1 { | |
549 | t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", len(comments61)) | |
550 | } | |
551 | } | |
552 | ||
553 | func TestHasManyOverrideForeignKey1(t *testing.T) { | |
554 | type Profile struct { | |
555 | gorm.Model | |
556 | Name string | |
557 | UserRefer uint | |
558 | } | |
559 | ||
560 | type User struct { | |
561 | gorm.Model | |
562 | Profile []Profile `gorm:"ForeignKey:UserRefer"` | |
563 | } | |
564 | ||
565 | if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { | |
566 | if relation.Relationship.Kind != "has_many" || | |
567 | !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || | |
568 | !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { | |
569 | t.Errorf("Override belongs to foreign key with tag") | |
570 | } | |
571 | } | |
572 | } | |
573 | ||
574 | func TestHasManyOverrideForeignKey2(t *testing.T) { | |
575 | type Profile struct { | |
576 | gorm.Model | |
577 | Name string | |
578 | UserID uint | |
579 | } | |
580 | ||
581 | type User struct { | |
582 | gorm.Model | |
583 | Refer string | |
584 | Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` | |
585 | } | |
586 | ||
587 | if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { | |
588 | if relation.Relationship.Kind != "has_many" || | |
589 | !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || | |
590 | !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { | |
591 | t.Errorf("Override belongs to foreign key with tag") | |
592 | } | |
593 | } | |
594 | } | |
595 | ||
596 | func TestManyToMany(t *testing.T) { | |
597 | DB.Raw("delete from languages") | |
598 | var languages = []Language{{Name: "ZH"}, {Name: "EN"}} | |
599 | user := User{Name: "Many2Many", Languages: languages} | |
600 | DB.Save(&user) | |
601 | ||
602 | // Query | |
603 | var newLanguages []Language | |
604 | DB.Model(&user).Related(&newLanguages, "Languages") | |
605 | if len(newLanguages) != len([]string{"ZH", "EN"}) { | |
606 | t.Errorf("Query many to many relations") | |
607 | } | |
608 | ||
609 | DB.Model(&user).Association("Languages").Find(&newLanguages) | |
610 | if len(newLanguages) != len([]string{"ZH", "EN"}) { | |
611 | t.Errorf("Should be able to find many to many relations") | |
612 | } | |
613 | ||
614 | if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) { | |
615 | t.Errorf("Count should return correct result") | |
616 | } | |
617 | ||
618 | // Append | |
619 | DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"}) | |
620 | if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() { | |
621 | t.Errorf("New record should be saved when append") | |
622 | } | |
623 | ||
624 | languageA := Language{Name: "AA"} | |
625 | DB.Save(&languageA) | |
626 | DB.Model(&User{Id: user.Id}).Association("Languages").Append(&languageA) | |
627 | ||
628 | languageC := Language{Name: "CC"} | |
629 | DB.Save(&languageC) | |
630 | DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC}) | |
631 | ||
632 | DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}}) | |
633 | ||
634 | totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"} | |
635 | ||
636 | if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) { | |
637 | t.Errorf("All appended languages should be saved") | |
638 | } | |
639 | ||
640 | // Delete | |
641 | user.Languages = []Language{} | |
642 | DB.Model(&user).Association("Languages").Find(&user.Languages) | |
643 | ||
644 | var language Language | |
645 | DB.Where("name = ?", "EE").First(&language) | |
646 | DB.Model(&user).Association("Languages").Delete(language, &language) | |
647 | ||
648 | if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 { | |
649 | t.Errorf("Relations should be deleted with Delete") | |
650 | } | |
651 | if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() { | |
652 | t.Errorf("Language EE should not be deleted") | |
653 | } | |
654 | ||
655 | DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages) | |
656 | ||
657 | user2 := User{Name: "Many2Many_User2", Languages: languages} | |
658 | DB.Save(&user2) | |
659 | ||
660 | DB.Model(&user).Association("Languages").Delete(languages, &languages) | |
661 | if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 { | |
662 | t.Errorf("Relations should be deleted with Delete") | |
663 | } | |
664 | ||
665 | if DB.Model(&user2).Association("Languages").Count() == 0 { | |
666 | t.Errorf("Other user's relations should not be deleted") | |
667 | } | |
668 | ||
669 | // Replace | |
670 | var languageB Language | |
671 | DB.Where("name = ?", "BB").First(&languageB) | |
672 | DB.Model(&user).Association("Languages").Replace(languageB) | |
673 | if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 { | |
674 | t.Errorf("Relations should be replaced") | |
675 | } | |
676 | ||
677 | DB.Model(&user).Association("Languages").Replace() | |
678 | if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { | |
679 | t.Errorf("Relations should be replaced with empty") | |
680 | } | |
681 | ||
682 | DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}}) | |
683 | if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) { | |
684 | t.Errorf("Relations should be replaced") | |
685 | } | |
686 | ||
687 | // Clear | |
688 | DB.Model(&user).Association("Languages").Clear() | |
689 | if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { | |
690 | t.Errorf("Relations should be cleared") | |
691 | } | |
692 | ||
693 | // Check Association mode with soft delete | |
694 | var language6 = Language{ | |
695 | Name: "language 6", | |
696 | } | |
697 | DB.Model(&user).Association("Languages").Append(&language6) | |
698 | ||
699 | if count := DB.Model(&user).Association("Languages").Count(); count != 1 { | |
700 | t.Errorf("user's languages count should be 1 after Append, but got %v", count) | |
701 | } | |
702 | ||
703 | DB.Delete(&language6) | |
704 | ||
705 | if count := DB.Model(&user).Association("Languages").Count(); count != 0 { | |
706 | t.Errorf("user's languages count should be 0 after language been deleted, but got %v", count) | |
707 | } | |
708 | ||
709 | var languages6 []Language | |
710 | if DB.Model(&user).Association("Languages").Find(&languages6); len(languages6) != 0 { | |
711 | t.Errorf("user's languages count should be 0 when find with Find, but got %v", len(languages6)) | |
712 | } | |
713 | ||
714 | if count := DB.Unscoped().Model(&user).Association("Languages").Count(); count != 1 { | |
715 | t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", count) | |
716 | } | |
717 | ||
718 | var languages61 []Language | |
719 | if DB.Unscoped().Model(&user).Association("Languages").Find(&languages61); len(languages61) != 1 { | |
720 | t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", len(languages61)) | |
721 | } | |
52 | 722 | } |
53 | 723 | |
54 | 724 | func TestRelated(t *testing.T) { |
61 | 731 | Company: Company{Name: "company1"}, |
62 | 732 | } |
63 | 733 | |
64 | DB.Save(&user) | |
734 | if err := DB.Save(&user).Error; err != nil { | |
735 | t.Errorf("No error should happen when saving user") | |
736 | } | |
65 | 737 | |
66 | 738 | if user.CreditCard.ID == 0 { |
67 | 739 | t.Errorf("After user save, credit card should have id") |
84 | 756 | var emails2 []Email |
85 | 757 | DB.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2) |
86 | 758 | if len(emails2) != 1 { |
759 | t.Errorf("Should have two emails") | |
760 | } | |
761 | ||
762 | var emails3 []*Email | |
763 | DB.Model(&user).Related(&emails3) | |
764 | if len(emails3) != 2 { | |
87 | 765 | t.Errorf("Should have two emails") |
88 | 766 | } |
89 | 767 | |
129 | 807 | } |
130 | 808 | } |
131 | 809 | |
132 | func TestManyToMany(t *testing.T) { | |
133 | DB.Raw("delete from languages") | |
134 | var languages = []Language{{Name: "ZH"}, {Name: "EN"}} | |
135 | user := User{Name: "Many2Many", Languages: languages} | |
136 | DB.Save(&user) | |
137 | ||
138 | // Query | |
139 | var newLanguages []Language | |
140 | DB.Model(&user).Related(&newLanguages, "Languages") | |
141 | if len(newLanguages) != len([]string{"ZH", "EN"}) { | |
142 | t.Errorf("Query many to many relations") | |
143 | } | |
144 | ||
145 | DB.Model(&user).Association("Languages").Find(&newLanguages) | |
146 | if len(newLanguages) != len([]string{"ZH", "EN"}) { | |
147 | t.Errorf("Should be able to find many to many relations") | |
148 | } | |
149 | ||
150 | if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) { | |
151 | t.Errorf("Count should return correct result") | |
152 | } | |
153 | ||
154 | // Append | |
155 | DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"}) | |
156 | if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() { | |
157 | t.Errorf("New record should be saved when append") | |
158 | } | |
159 | ||
160 | languageA := Language{Name: "AA"} | |
161 | DB.Save(&languageA) | |
162 | DB.Model(&User{Id: user.Id}).Association("Languages").Append(&languageA) | |
163 | ||
164 | languageC := Language{Name: "CC"} | |
165 | DB.Save(&languageC) | |
166 | DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC}) | |
167 | ||
168 | DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}}) | |
169 | ||
170 | totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"} | |
171 | ||
172 | if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) { | |
173 | t.Errorf("All appended languages should be saved") | |
174 | } | |
175 | ||
176 | // Delete | |
177 | user.Languages = []Language{} | |
178 | DB.Model(&user).Association("Languages").Find(&user.Languages) | |
179 | ||
180 | var language Language | |
181 | DB.Where("name = ?", "EE").First(&language) | |
182 | DB.Model(&user).Association("Languages").Delete(language, &language) | |
183 | ||
184 | if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 { | |
185 | t.Errorf("Relations should be deleted with Delete") | |
186 | } | |
187 | if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() { | |
188 | t.Errorf("Language EE should not be deleted") | |
189 | } | |
190 | ||
191 | DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages) | |
192 | ||
193 | user2 := User{Name: "Many2Many_User2", Languages: languages} | |
194 | DB.Save(&user2) | |
195 | ||
196 | DB.Model(&user).Association("Languages").Delete(languages, &languages) | |
197 | if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 { | |
198 | t.Errorf("Relations should be deleted with Delete") | |
199 | } | |
200 | ||
201 | if DB.Model(&user2).Association("Languages").Count() == 0 { | |
202 | t.Errorf("Other user's relations should not be deleted") | |
203 | } | |
204 | ||
205 | // Replace | |
206 | var languageB Language | |
207 | DB.Where("name = ?", "BB").First(&languageB) | |
208 | DB.Model(&user).Association("Languages").Replace(languageB) | |
209 | if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 { | |
210 | t.Errorf("Relations should be replaced") | |
211 | } | |
212 | ||
213 | DB.Model(&user).Association("Languages").Replace() | |
214 | if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { | |
215 | t.Errorf("Relations should be replaced with empty") | |
216 | } | |
217 | ||
218 | DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}}) | |
219 | if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) { | |
220 | t.Errorf("Relations should be replaced") | |
221 | } | |
222 | ||
223 | // Clear | |
224 | DB.Model(&user).Association("Languages").Clear() | |
225 | if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { | |
226 | t.Errorf("Relations should be cleared") | |
227 | } | |
228 | } | |
229 | ||
230 | 810 | func TestForeignKey(t *testing.T) { |
231 | 811 | for _, structField := range DB.NewScope(&User{}).GetStructFields() { |
232 | 812 | for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} { |
260 | 840 | } |
261 | 841 | } |
262 | 842 | } |
843 | ||
844 | func testForeignKey(t *testing.T, source interface{}, sourceFieldName string, target interface{}, targetFieldName string) { | |
845 | if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { | |
846 | // sqlite does not support ADD CONSTRAINT in ALTER TABLE | |
847 | return | |
848 | } | |
849 | targetScope := DB.NewScope(target) | |
850 | targetTableName := targetScope.TableName() | |
851 | modelScope := DB.NewScope(source) | |
852 | modelField, ok := modelScope.FieldByName(sourceFieldName) | |
853 | if !ok { | |
854 | t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", sourceFieldName)) | |
855 | } | |
856 | targetField, ok := targetScope.FieldByName(targetFieldName) | |
857 | if !ok { | |
858 | t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", targetFieldName)) | |
859 | } | |
860 | dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName) | |
861 | err := DB.Model(source).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error | |
862 | if err != nil { | |
863 | t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err)) | |
864 | } | |
865 | } | |
866 | ||
867 | func TestLongForeignKey(t *testing.T) { | |
868 | testForeignKey(t, &NotSoLongTableName{}, "ReallyLongThingID", &ReallyLongTableNameToTestMySQLNameLengthLimit{}, "ID") | |
869 | } | |
870 | ||
871 | func TestLongForeignKeyWithShortDest(t *testing.T) { | |
872 | testForeignKey(t, &ReallyLongThingThatReferencesShort{}, "ShortID", &Short{}, "ID") | |
873 | } | |
874 | ||
875 | func TestHasManyChildrenWithOneStruct(t *testing.T) { | |
876 | category := Category{ | |
877 | Name: "main", | |
878 | Categories: []Category{ | |
879 | {Name: "sub1"}, | |
880 | {Name: "sub2"}, | |
881 | }, | |
882 | } | |
883 | ||
884 | DB.Save(&category) | |
885 | } | |
886 | ||
887 | func TestAutoSaveBelongsToAssociation(t *testing.T) { | |
888 | type Company struct { | |
889 | gorm.Model | |
890 | Name string | |
891 | } | |
892 | ||
893 | type User struct { | |
894 | gorm.Model | |
895 | Name string | |
896 | CompanyID uint | |
897 | Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` | |
898 | } | |
899 | ||
900 | DB.Where("name = ?", "auto_save_association").Delete(&Company{}) | |
901 | DB.AutoMigrate(&Company{}, &User{}) | |
902 | ||
903 | DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_association"}}) | |
904 | ||
905 | if !DB.Where("name = ?", "auto_save_association").First(&Company{}).RecordNotFound() { | |
906 | t.Errorf("Company auto_save_association should not have been saved when autosave is false") | |
907 | } | |
908 | ||
909 | // if foreign key is set, this should be saved even if association isn't | |
910 | company := Company{Name: "auto_save_association"} | |
911 | DB.Save(&company) | |
912 | ||
913 | company.Name = "auto_save_association_new_name" | |
914 | user := User{Name: "jinzhu", Company: company} | |
915 | ||
916 | DB.Save(&user) | |
917 | ||
918 | if !DB.Where("name = ?", "auto_save_association_new_name").First(&Company{}).RecordNotFound() { | |
919 | t.Errorf("Company should not have been updated") | |
920 | } | |
921 | ||
922 | if DB.Where("id = ? AND company_id = ?", user.ID, company.ID).First(&User{}).RecordNotFound() { | |
923 | t.Errorf("User's foreign key should have been saved") | |
924 | } | |
925 | ||
926 | user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_association_2"}} | |
927 | DB.Set("gorm:association_autocreate", true).Save(&user2) | |
928 | if DB.Where("name = ?", "auto_save_association_2").First(&Company{}).RecordNotFound() { | |
929 | t.Errorf("Company auto_save_association_2 should been created when autocreate is true") | |
930 | } | |
931 | ||
932 | user2.Company.Name = "auto_save_association_2_newname" | |
933 | DB.Set("gorm:association_autoupdate", true).Save(&user2) | |
934 | ||
935 | if DB.Where("name = ?", "auto_save_association_2_newname").First(&Company{}).RecordNotFound() { | |
936 | t.Errorf("Company should been updated") | |
937 | } | |
938 | } | |
939 | ||
940 | func TestAutoSaveHasOneAssociation(t *testing.T) { | |
941 | type Company struct { | |
942 | gorm.Model | |
943 | UserID uint | |
944 | Name string | |
945 | } | |
946 | ||
947 | type User struct { | |
948 | gorm.Model | |
949 | Name string | |
950 | Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` | |
951 | } | |
952 | ||
953 | DB.Where("name = ?", "auto_save_has_one_association").Delete(&Company{}) | |
954 | DB.AutoMigrate(&Company{}, &User{}) | |
955 | ||
956 | DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_has_one_association"}}) | |
957 | ||
958 | if !DB.Where("name = ?", "auto_save_has_one_association").First(&Company{}).RecordNotFound() { | |
959 | t.Errorf("Company auto_save_has_one_association should not have been saved when autosave is false") | |
960 | } | |
961 | ||
962 | company := Company{Name: "auto_save_has_one_association"} | |
963 | DB.Save(&company) | |
964 | ||
965 | company.Name = "auto_save_has_one_association_new_name" | |
966 | user := User{Name: "jinzhu", Company: company} | |
967 | ||
968 | DB.Save(&user) | |
969 | ||
970 | if !DB.Where("name = ?", "auto_save_has_one_association_new_name").First(&Company{}).RecordNotFound() { | |
971 | t.Errorf("Company should not have been updated") | |
972 | } | |
973 | ||
974 | if !DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association", user.ID).First(&Company{}).RecordNotFound() { | |
975 | t.Errorf("Company should not have been updated") | |
976 | } | |
977 | ||
978 | if user.Company.UserID == 0 { | |
979 | t.Errorf("UserID should be assigned") | |
980 | } | |
981 | ||
982 | company.Name = "auto_save_has_one_association_2_new_name" | |
983 | DB.Set("gorm:association_autoupdate", true).Save(&user) | |
984 | ||
985 | if DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association_new_name", user.ID).First(&Company{}).RecordNotFound() { | |
986 | t.Errorf("Company should been updated") | |
987 | } | |
988 | ||
989 | user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_has_one_association_2"}} | |
990 | DB.Set("gorm:association_autocreate", true).Save(&user2) | |
991 | if DB.Where("name = ?", "auto_save_has_one_association_2").First(&Company{}).RecordNotFound() { | |
992 | t.Errorf("Company auto_save_has_one_association_2 should been created when autocreate is true") | |
993 | } | |
994 | } | |
995 | ||
996 | func TestAutoSaveMany2ManyAssociation(t *testing.T) { | |
997 | type Company struct { | |
998 | gorm.Model | |
999 | Name string | |
1000 | } | |
1001 | ||
1002 | type User struct { | |
1003 | gorm.Model | |
1004 | Name string | |
1005 | Companies []Company `gorm:"many2many:user_companies;association_autoupdate:false;association_autocreate:false;"` | |
1006 | } | |
1007 | ||
1008 | DB.AutoMigrate(&Company{}, &User{}) | |
1009 | ||
1010 | DB.Save(&User{Name: "jinzhu", Companies: []Company{{Name: "auto_save_m2m_association"}}}) | |
1011 | ||
1012 | if !DB.Where("name = ?", "auto_save_m2m_association").First(&Company{}).RecordNotFound() { | |
1013 | t.Errorf("Company auto_save_m2m_association should not have been saved when autosave is false") | |
1014 | } | |
1015 | ||
1016 | company := Company{Name: "auto_save_m2m_association"} | |
1017 | DB.Save(&company) | |
1018 | ||
1019 | company.Name = "auto_save_m2m_association_new_name" | |
1020 | user := User{Name: "jinzhu", Companies: []Company{company, {Name: "auto_save_m2m_association_new_name_2"}}} | |
1021 | ||
1022 | DB.Save(&user) | |
1023 | ||
1024 | if !DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { | |
1025 | t.Errorf("Company should not have been updated") | |
1026 | } | |
1027 | ||
1028 | if !DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { | |
1029 | t.Errorf("Company should not been created") | |
1030 | } | |
1031 | ||
1032 | if DB.Model(&user).Association("Companies").Count() != 1 { | |
1033 | t.Errorf("Relationship should been saved") | |
1034 | } | |
1035 | ||
1036 | DB.Set("gorm:association_autoupdate", true).Set("gorm:association_autocreate", true).Save(&user) | |
1037 | ||
1038 | if DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { | |
1039 | t.Errorf("Company should been updated") | |
1040 | } | |
1041 | ||
1042 | if DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { | |
1043 | t.Errorf("Company should been created") | |
1044 | } | |
1045 | ||
1046 | if DB.Model(&user).Association("Companies").Count() != 2 { | |
1047 | t.Errorf("Relationship should been updated") | |
1048 | } | |
1049 | } |
0 | 0 | package gorm |
1 | 1 | |
2 | import ( | |
3 | "fmt" | |
4 | ) | |
5 | ||
6 | type callback struct { | |
2 | import "log" | |
3 | ||
4 | // DefaultCallback default callbacks defined by gorm | |
5 | var DefaultCallback = &Callback{} | |
6 | ||
7 | // Callback is a struct that contains all CRUD callbacks | |
8 | // Field `creates` contains callbacks will be call when creating object | |
9 | // Field `updates` contains callbacks will be call when updating object | |
10 | // Field `deletes` contains callbacks will be call when deleting object | |
11 | // Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association... | |
12 | // Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... | |
13 | // Field `processors` contains all callback processors, will be used to generate above callbacks in order | |
14 | type Callback struct { | |
7 | 15 | creates []*func(scope *Scope) |
8 | 16 | updates []*func(scope *Scope) |
9 | 17 | deletes []*func(scope *Scope) |
10 | 18 | queries []*func(scope *Scope) |
11 | 19 | rowQueries []*func(scope *Scope) |
12 | processors []*callbackProcessor | |
13 | } | |
14 | ||
15 | type callbackProcessor struct { | |
16 | name string | |
17 | before string | |
18 | after string | |
19 | replace bool | |
20 | remove bool | |
21 | typ string | |
22 | processor *func(scope *Scope) | |
23 | callback *callback | |
24 | } | |
25 | ||
26 | func (c *callback) addProcessor(typ string) *callbackProcessor { | |
27 | cp := &callbackProcessor{typ: typ, callback: c} | |
28 | c.processors = append(c.processors, cp) | |
29 | return cp | |
30 | } | |
31 | ||
32 | func (c *callback) clone() *callback { | |
33 | return &callback{ | |
20 | processors []*CallbackProcessor | |
21 | } | |
22 | ||
23 | // CallbackProcessor contains callback informations | |
24 | type CallbackProcessor struct { | |
25 | name string // current callback's name | |
26 | before string // register current callback before a callback | |
27 | after string // register current callback after a callback | |
28 | replace bool // replace callbacks with same name | |
29 | remove bool // delete callbacks with same name | |
30 | kind string // callback type: create, update, delete, query, row_query | |
31 | processor *func(scope *Scope) // callback handler | |
32 | parent *Callback | |
33 | } | |
34 | ||
35 | func (c *Callback) clone() *Callback { | |
36 | return &Callback{ | |
34 | 37 | creates: c.creates, |
35 | 38 | updates: c.updates, |
36 | 39 | deletes: c.deletes, |
37 | 40 | queries: c.queries, |
41 | rowQueries: c.rowQueries, | |
38 | 42 | processors: c.processors, |
39 | 43 | } |
40 | 44 | } |
41 | 45 | |
42 | func (c *callback) Create() *callbackProcessor { | |
43 | return c.addProcessor("create") | |
44 | } | |
45 | ||
46 | func (c *callback) Update() *callbackProcessor { | |
47 | return c.addProcessor("update") | |
48 | } | |
49 | ||
50 | func (c *callback) Delete() *callbackProcessor { | |
51 | return c.addProcessor("delete") | |
52 | } | |
53 | ||
54 | func (c *callback) Query() *callbackProcessor { | |
55 | return c.addProcessor("query") | |
56 | } | |
57 | ||
58 | func (c *callback) RowQuery() *callbackProcessor { | |
59 | return c.addProcessor("row_query") | |
60 | } | |
61 | ||
62 | func (cp *callbackProcessor) Before(name string) *callbackProcessor { | |
63 | cp.before = name | |
46 | // Create could be used to register callbacks for creating object | |
47 | // db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) { | |
48 | // // business logic | |
49 | // ... | |
50 | // | |
51 | // // set error if some thing wrong happened, will rollback the creating | |
52 | // scope.Err(errors.New("error")) | |
53 | // }) | |
54 | func (c *Callback) Create() *CallbackProcessor { | |
55 | return &CallbackProcessor{kind: "create", parent: c} | |
56 | } | |
57 | ||
58 | // Update could be used to register callbacks for updating object, refer `Create` for usage | |
59 | func (c *Callback) Update() *CallbackProcessor { | |
60 | return &CallbackProcessor{kind: "update", parent: c} | |
61 | } | |
62 | ||
63 | // Delete could be used to register callbacks for deleting object, refer `Create` for usage | |
64 | func (c *Callback) Delete() *CallbackProcessor { | |
65 | return &CallbackProcessor{kind: "delete", parent: c} | |
66 | } | |
67 | ||
68 | // Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`... | |
69 | // Refer `Create` for usage | |
70 | func (c *Callback) Query() *CallbackProcessor { | |
71 | return &CallbackProcessor{kind: "query", parent: c} | |
72 | } | |
73 | ||
74 | // RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage | |
75 | func (c *Callback) RowQuery() *CallbackProcessor { | |
76 | return &CallbackProcessor{kind: "row_query", parent: c} | |
77 | } | |
78 | ||
79 | // After insert a new callback after callback `callbackName`, refer `Callbacks.Create` | |
80 | func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor { | |
81 | cp.after = callbackName | |
64 | 82 | return cp |
65 | 83 | } |
66 | 84 | |
67 | func (cp *callbackProcessor) After(name string) *callbackProcessor { | |
68 | cp.after = name | |
85 | // Before insert a new callback before callback `callbackName`, refer `Callbacks.Create` | |
86 | func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { | |
87 | cp.before = callbackName | |
69 | 88 | return cp |
70 | 89 | } |
71 | 90 | |
72 | func (cp *callbackProcessor) Register(name string, fc func(scope *Scope)) { | |
73 | cp.name = name | |
74 | cp.processor = &fc | |
75 | cp.callback.sort() | |
76 | } | |
77 | ||
78 | func (cp *callbackProcessor) Remove(name string) { | |
79 | fmt.Printf("[info] removing callback `%v` from %v\n", name, fileWithLineNum()) | |
80 | cp.name = name | |
91 | // Register a new callback, refer `Callbacks.Create` | |
92 | func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { | |
93 | if cp.kind == "row_query" { | |
94 | if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { | |
95 | log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName) | |
96 | cp.before = "gorm:row_query" | |
97 | } | |
98 | } | |
99 | ||
100 | cp.name = callbackName | |
101 | cp.processor = &callback | |
102 | cp.parent.processors = append(cp.parent.processors, cp) | |
103 | cp.parent.reorder() | |
104 | } | |
105 | ||
106 | // Remove a registered callback | |
107 | // db.Callback().Create().Remove("gorm:update_time_stamp_when_create") | |
108 | func (cp *CallbackProcessor) Remove(callbackName string) { | |
109 | log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) | |
110 | cp.name = callbackName | |
81 | 111 | cp.remove = true |
82 | cp.callback.sort() | |
83 | } | |
84 | ||
85 | func (cp *callbackProcessor) Replace(name string, fc func(scope *Scope)) { | |
86 | fmt.Printf("[info] replacing callback `%v` from %v\n", name, fileWithLineNum()) | |
87 | cp.name = name | |
88 | cp.processor = &fc | |
112 | cp.parent.processors = append(cp.parent.processors, cp) | |
113 | cp.parent.reorder() | |
114 | } | |
115 | ||
116 | // Replace a registered callback with new callback | |
117 | // db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { | |
118 | // scope.SetColumn("Created", now) | |
119 | // scope.SetColumn("Updated", now) | |
120 | // }) | |
121 | func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { | |
122 | log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) | |
123 | cp.name = callbackName | |
124 | cp.processor = &callback | |
89 | 125 | cp.replace = true |
90 | cp.callback.sort() | |
91 | } | |
92 | ||
126 | cp.parent.processors = append(cp.parent.processors, cp) | |
127 | cp.parent.reorder() | |
128 | } | |
129 | ||
130 | // Get registered callback | |
131 | // db.Callback().Create().Get("gorm:create") | |
132 | func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { | |
133 | for _, p := range cp.parent.processors { | |
134 | if p.name == callbackName && p.kind == cp.kind && !cp.remove { | |
135 | return *p.processor | |
136 | } | |
137 | } | |
138 | return nil | |
139 | } | |
140 | ||
141 | // getRIndex get right index from string slice | |
93 | 142 | func getRIndex(strs []string, str string) int { |
94 | 143 | for i := len(strs) - 1; i >= 0; i-- { |
95 | 144 | if strs[i] == str { |
99 | 148 | return -1 |
100 | 149 | } |
101 | 150 | |
102 | func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) { | |
103 | var sortCallbackProcessor func(c *callbackProcessor) | |
104 | var names, sortedNames = []string{}, []string{} | |
151 | // sortProcessors sort callback processors based on its before, after, remove, replace | |
152 | func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { | |
153 | var ( | |
154 | allNames, sortedNames []string | |
155 | sortCallbackProcessor func(c *CallbackProcessor) | |
156 | ) | |
105 | 157 | |
106 | 158 | for _, cp := range cps { |
107 | if index := getRIndex(names, cp.name); index > -1 { | |
108 | if !cp.replace && !cp.remove { | |
109 | fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) | |
110 | } | |
111 | } | |
112 | names = append(names, cp.name) | |
113 | } | |
114 | ||
115 | sortCallbackProcessor = func(c *callbackProcessor) { | |
116 | if getRIndex(sortedNames, c.name) > -1 { | |
117 | return | |
118 | } | |
119 | ||
120 | if len(c.before) > 0 { | |
121 | if index := getRIndex(sortedNames, c.before); index > -1 { | |
122 | sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) | |
123 | } else if index := getRIndex(names, c.before); index > -1 { | |
159 | // show warning message the callback name already exists | |
160 | if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { | |
161 | log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) | |
162 | } | |
163 | allNames = append(allNames, cp.name) | |
164 | } | |
165 | ||
166 | sortCallbackProcessor = func(c *CallbackProcessor) { | |
167 | if getRIndex(sortedNames, c.name) == -1 { // if not sorted | |
168 | if c.before != "" { // if defined before callback | |
169 | if index := getRIndex(sortedNames, c.before); index != -1 { | |
170 | // if before callback already sorted, append current callback just after it | |
171 | sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) | |
172 | } else if index := getRIndex(allNames, c.before); index != -1 { | |
173 | // if before callback exists but haven't sorted, append current callback to last | |
174 | sortedNames = append(sortedNames, c.name) | |
175 | sortCallbackProcessor(cps[index]) | |
176 | } | |
177 | } | |
178 | ||
179 | if c.after != "" { // if defined after callback | |
180 | if index := getRIndex(sortedNames, c.after); index != -1 { | |
181 | // if after callback already sorted, append current callback just before it | |
182 | sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) | |
183 | } else if index := getRIndex(allNames, c.after); index != -1 { | |
184 | // if after callback exists but haven't sorted | |
185 | cp := cps[index] | |
186 | // set after callback's before callback to current callback | |
187 | if cp.before == "" { | |
188 | cp.before = c.name | |
189 | } | |
190 | sortCallbackProcessor(cp) | |
191 | } | |
192 | } | |
193 | ||
194 | // if current callback haven't been sorted, append it to last | |
195 | if getRIndex(sortedNames, c.name) == -1 { | |
124 | 196 | sortedNames = append(sortedNames, c.name) |
125 | sortCallbackProcessor(cps[index]) | |
126 | } else { | |
127 | sortedNames = append(sortedNames, c.name) | |
128 | } | |
129 | } | |
130 | ||
131 | if len(c.after) > 0 { | |
132 | if index := getRIndex(sortedNames, c.after); index > -1 { | |
133 | sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) | |
134 | } else if index := getRIndex(names, c.after); index > -1 { | |
135 | cp := cps[index] | |
136 | if len(cp.before) == 0 { | |
137 | cp.before = c.name | |
138 | } | |
139 | sortCallbackProcessor(cp) | |
140 | } else { | |
141 | sortedNames = append(sortedNames, c.name) | |
142 | } | |
143 | } | |
144 | ||
145 | if getRIndex(sortedNames, c.name) == -1 { | |
146 | sortedNames = append(sortedNames, c.name) | |
197 | } | |
147 | 198 | } |
148 | 199 | } |
149 | 200 | |
151 | 202 | sortCallbackProcessor(cp) |
152 | 203 | } |
153 | 204 | |
154 | var funcs = []*func(scope *Scope){} | |
155 | var sortedFuncs = []*func(scope *Scope){} | |
205 | var sortedFuncs []*func(scope *Scope) | |
156 | 206 | for _, name := range sortedNames { |
157 | index := getRIndex(names, name) | |
158 | if !cps[index].remove { | |
207 | if index := getRIndex(allNames, name); !cps[index].remove { | |
159 | 208 | sortedFuncs = append(sortedFuncs, cps[index].processor) |
160 | 209 | } |
161 | 210 | } |
162 | 211 | |
163 | for _, cp := range cps { | |
164 | if sindex := getRIndex(sortedNames, cp.name); sindex == -1 { | |
165 | if !cp.remove { | |
166 | funcs = append(funcs, cp.processor) | |
167 | } | |
168 | } | |
169 | } | |
170 | ||
171 | return append(sortedFuncs, funcs...) | |
172 | } | |
173 | ||
174 | func (c *callback) sort() { | |
175 | var creates, updates, deletes, queries, rowQueries []*callbackProcessor | |
212 | return sortedFuncs | |
213 | } | |
214 | ||
215 | // reorder all registered processors, and reset CRUD callbacks | |
216 | func (c *Callback) reorder() { | |
217 | var creates, updates, deletes, queries, rowQueries []*CallbackProcessor | |
176 | 218 | |
177 | 219 | for _, processor := range c.processors { |
178 | switch processor.typ { | |
179 | case "create": | |
180 | creates = append(creates, processor) | |
181 | case "update": | |
182 | updates = append(updates, processor) | |
183 | case "delete": | |
184 | deletes = append(deletes, processor) | |
185 | case "query": | |
186 | queries = append(queries, processor) | |
187 | case "row_query": | |
188 | rowQueries = append(rowQueries, processor) | |
220 | if processor.name != "" { | |
221 | switch processor.kind { | |
222 | case "create": | |
223 | creates = append(creates, processor) | |
224 | case "update": | |
225 | updates = append(updates, processor) | |
226 | case "delete": | |
227 | deletes = append(deletes, processor) | |
228 | case "query": | |
229 | queries = append(queries, processor) | |
230 | case "row_query": | |
231 | rowQueries = append(rowQueries, processor) | |
232 | } | |
189 | 233 | } |
190 | 234 | } |
191 | 235 | |
195 | 239 | c.queries = sortProcessors(queries) |
196 | 240 | c.rowQueries = sortProcessors(rowQueries) |
197 | 241 | } |
198 | ||
199 | var DefaultCallback = &callback{processors: []*callbackProcessor{}} |
4 | 4 | "strings" |
5 | 5 | ) |
6 | 6 | |
7 | func BeforeCreate(scope *Scope) { | |
8 | scope.CallMethodWithErrorCheck("BeforeSave") | |
9 | scope.CallMethodWithErrorCheck("BeforeCreate") | |
7 | // Define callbacks for creating | |
8 | func init() { | |
9 | DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) | |
10 | DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) | |
11 | DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) | |
12 | DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) | |
13 | DefaultCallback.Create().Register("gorm:create", createCallback) | |
14 | DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) | |
15 | DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) | |
16 | DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback) | |
17 | DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) | |
10 | 18 | } |
11 | 19 | |
12 | func UpdateTimeStampWhenCreate(scope *Scope) { | |
20 | // beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating | |
21 | func beforeCreateCallback(scope *Scope) { | |
13 | 22 | if !scope.HasError() { |
14 | now := NowFunc() | |
15 | scope.SetColumn("CreatedAt", now) | |
16 | scope.SetColumn("UpdatedAt", now) | |
23 | scope.CallMethod("BeforeSave") | |
24 | } | |
25 | if !scope.HasError() { | |
26 | scope.CallMethod("BeforeCreate") | |
17 | 27 | } |
18 | 28 | } |
19 | 29 | |
20 | func Create(scope *Scope) { | |
21 | defer scope.Trace(NowFunc()) | |
30 | // updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating | |
31 | func updateTimeStampForCreateCallback(scope *Scope) { | |
32 | if !scope.HasError() { | |
33 | now := NowFunc() | |
22 | 34 | |
35 | if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { | |
36 | if createdAtField.IsBlank { | |
37 | createdAtField.Set(now) | |
38 | } | |
39 | } | |
40 | ||
41 | if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok { | |
42 | if updatedAtField.IsBlank { | |
43 | updatedAtField.Set(now) | |
44 | } | |
45 | } | |
46 | } | |
47 | } | |
48 | ||
49 | // createCallback the callback used to insert data into database | |
50 | func createCallback(scope *Scope) { | |
23 | 51 | if !scope.HasError() { |
24 | // set create sql | |
25 | var sqls, columns []string | |
26 | fields := scope.Fields() | |
27 | for _, field := range fields { | |
52 | defer scope.trace(NowFunc()) | |
53 | ||
54 | var ( | |
55 | columns, placeholders []string | |
56 | blankColumnsWithDefaultValue []string | |
57 | ) | |
58 | ||
59 | for _, field := range scope.Fields() { | |
28 | 60 | if scope.changeableField(field) { |
29 | 61 | if field.IsNormal { |
30 | if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) { | |
31 | if !field.IsBlank || !field.HasDefaultValue { | |
32 | columns = append(columns, scope.Quote(field.DBName)) | |
33 | sqls = append(sqls, scope.AddToVars(field.Field.Interface())) | |
34 | } else if field.HasDefaultValue { | |
35 | scope.InstanceSet("gorm:force_reload_after_create", true) | |
36 | } | |
62 | if field.IsBlank && field.HasDefaultValue { | |
63 | blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) | |
64 | scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) | |
65 | } else if !field.IsPrimaryKey || !field.IsBlank { | |
66 | columns = append(columns, scope.Quote(field.DBName)) | |
67 | placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) | |
37 | 68 | } |
38 | } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { | |
39 | for _, dbName := range relationship.ForeignDBNames { | |
40 | if relationField := fields[dbName]; !scope.changeableField(relationField) { | |
41 | columns = append(columns, scope.Quote(relationField.DBName)) | |
42 | sqls = append(sqls, scope.AddToVars(relationField.Field.Interface())) | |
69 | } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { | |
70 | for _, foreignKey := range field.Relationship.ForeignDBNames { | |
71 | if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { | |
72 | columns = append(columns, scope.Quote(foreignField.DBName)) | |
73 | placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) | |
43 | 74 | } |
44 | 75 | } |
45 | 76 | } |
46 | 77 | } |
47 | 78 | } |
48 | 79 | |
49 | returningKey := "*" | |
50 | primaryField := scope.PrimaryField() | |
51 | if primaryField != nil { | |
52 | returningKey = scope.Quote(primaryField.DBName) | |
80 | var ( | |
81 | returningColumn = "*" | |
82 | quotedTableName = scope.QuotedTableName() | |
83 | primaryField = scope.PrimaryField() | |
84 | extraOption string | |
85 | ) | |
86 | ||
87 | if str, ok := scope.Get("gorm:insert_option"); ok { | |
88 | extraOption = fmt.Sprint(str) | |
53 | 89 | } |
54 | 90 | |
91 | if primaryField != nil { | |
92 | returningColumn = scope.Quote(primaryField.DBName) | |
93 | } | |
94 | ||
95 | lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) | |
96 | ||
55 | 97 | if len(columns) == 0 { |
56 | scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v", | |
57 | scope.QuotedTableName(), | |
58 | scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey), | |
98 | scope.Raw(fmt.Sprintf( | |
99 | "INSERT INTO %v %v%v%v", | |
100 | quotedTableName, | |
101 | scope.Dialect().DefaultValueStr(), | |
102 | addExtraSpaceIfExist(extraOption), | |
103 | addExtraSpaceIfExist(lastInsertIDReturningSuffix), | |
59 | 104 | )) |
60 | 105 | } else { |
61 | 106 | scope.Raw(fmt.Sprintf( |
62 | "INSERT INTO %v (%v) VALUES (%v) %v", | |
107 | "INSERT INTO %v (%v) VALUES (%v)%v%v", | |
63 | 108 | scope.QuotedTableName(), |
64 | 109 | strings.Join(columns, ","), |
65 | strings.Join(sqls, ","), | |
66 | scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey), | |
110 | strings.Join(placeholders, ","), | |
111 | addExtraSpaceIfExist(extraOption), | |
112 | addExtraSpaceIfExist(lastInsertIDReturningSuffix), | |
67 | 113 | )) |
68 | 114 | } |
69 | 115 | |
70 | 116 | // execute create sql |
71 | if scope.Dialect().SupportLastInsertId() { | |
72 | if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { | |
73 | id, err := result.LastInsertId() | |
74 | if scope.Err(err) == nil { | |
75 | scope.db.RowsAffected, _ = result.RowsAffected() | |
76 | if primaryField != nil && primaryField.IsBlank { | |
77 | scope.Err(scope.SetColumn(primaryField, id)) | |
117 | if lastInsertIDReturningSuffix == "" || primaryField == nil { | |
118 | if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { | |
119 | // set rows affected count | |
120 | scope.db.RowsAffected, _ = result.RowsAffected() | |
121 | ||
122 | // set primary value to primary field | |
123 | if primaryField != nil && primaryField.IsBlank { | |
124 | if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { | |
125 | scope.Err(primaryField.Set(primaryValue)) | |
78 | 126 | } |
79 | 127 | } |
80 | 128 | } |
81 | 129 | } else { |
82 | if primaryField == nil { | |
83 | if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil { | |
84 | scope.db.RowsAffected, _ = results.RowsAffected() | |
85 | } else { | |
86 | scope.Err(err) | |
130 | if primaryField.Field.CanAddr() { | |
131 | if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { | |
132 | primaryField.IsBlank = false | |
133 | scope.db.RowsAffected = 1 | |
87 | 134 | } |
88 | 135 | } else { |
89 | if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil { | |
90 | scope.db.RowsAffected = 1 | |
91 | } else { | |
92 | scope.Err(err) | |
93 | } | |
136 | scope.Err(ErrUnaddressable) | |
94 | 137 | } |
95 | 138 | } |
96 | 139 | } |
97 | 140 | } |
98 | 141 | |
99 | func ForceReloadAfterCreate(scope *Scope) { | |
100 | if _, ok := scope.InstanceGet("gorm:force_reload_after_create"); ok { | |
101 | scope.DB().New().First(scope.Value) | |
142 | // forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object | |
143 | func forceReloadAfterCreateCallback(scope *Scope) { | |
144 | if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { | |
145 | db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string)) | |
146 | for _, field := range scope.Fields() { | |
147 | if field.IsPrimaryKey && !field.IsBlank { | |
148 | db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface()) | |
149 | } | |
150 | } | |
151 | db.Scan(scope.Value) | |
102 | 152 | } |
103 | 153 | } |
104 | 154 | |
105 | func AfterCreate(scope *Scope) { | |
106 | scope.CallMethodWithErrorCheck("AfterCreate") | |
107 | scope.CallMethodWithErrorCheck("AfterSave") | |
155 | // afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating | |
156 | func afterCreateCallback(scope *Scope) { | |
157 | if !scope.HasError() { | |
158 | scope.CallMethod("AfterCreate") | |
159 | } | |
160 | if !scope.HasError() { | |
161 | scope.CallMethod("AfterSave") | |
162 | } | |
108 | 163 | } |
109 | ||
110 | func init() { | |
111 | DefaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction) | |
112 | DefaultCallback.Create().Register("gorm:before_create", BeforeCreate) | |
113 | DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations) | |
114 | DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate) | |
115 | DefaultCallback.Create().Register("gorm:create", Create) | |
116 | DefaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate) | |
117 | DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations) | |
118 | DefaultCallback.Create().Register("gorm:after_create", AfterCreate) | |
119 | DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | |
120 | } |
0 | 0 | package gorm |
1 | 1 | |
2 | import "fmt" | |
2 | import ( | |
3 | "errors" | |
4 | "fmt" | |
5 | ) | |
3 | 6 | |
4 | func BeforeDelete(scope *Scope) { | |
5 | scope.CallMethodWithErrorCheck("BeforeDelete") | |
7 | // Define callbacks for deleting | |
8 | func init() { | |
9 | DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) | |
10 | DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) | |
11 | DefaultCallback.Delete().Register("gorm:delete", deleteCallback) | |
12 | DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback) | |
13 | DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) | |
6 | 14 | } |
7 | 15 | |
8 | func Delete(scope *Scope) { | |
16 | // beforeDeleteCallback will invoke `BeforeDelete` method before deleting | |
17 | func beforeDeleteCallback(scope *Scope) { | |
18 | if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { | |
19 | scope.Err(errors.New("Missing WHERE clause while deleting")) | |
20 | return | |
21 | } | |
9 | 22 | if !scope.HasError() { |
10 | if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") { | |
11 | scope.Raw( | |
12 | fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", | |
13 | scope.QuotedTableName(), | |
14 | scope.AddToVars(NowFunc()), | |
15 | scope.CombinedConditionSql(), | |
16 | )) | |
17 | } else { | |
18 | scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.QuotedTableName(), scope.CombinedConditionSql())) | |
19 | } | |
20 | ||
21 | scope.Exec() | |
23 | scope.CallMethod("BeforeDelete") | |
22 | 24 | } |
23 | 25 | } |
24 | 26 | |
25 | func AfterDelete(scope *Scope) { | |
26 | scope.CallMethodWithErrorCheck("AfterDelete") | |
27 | // deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) | |
28 | func deleteCallback(scope *Scope) { | |
29 | if !scope.HasError() { | |
30 | var extraOption string | |
31 | if str, ok := scope.Get("gorm:delete_option"); ok { | |
32 | extraOption = fmt.Sprint(str) | |
33 | } | |
34 | ||
35 | deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt") | |
36 | ||
37 | if !scope.Search.Unscoped && hasDeletedAtField { | |
38 | scope.Raw(fmt.Sprintf( | |
39 | "UPDATE %v SET %v=%v%v%v", | |
40 | scope.QuotedTableName(), | |
41 | scope.Quote(deletedAtField.DBName), | |
42 | scope.AddToVars(NowFunc()), | |
43 | addExtraSpaceIfExist(scope.CombinedConditionSql()), | |
44 | addExtraSpaceIfExist(extraOption), | |
45 | )).Exec() | |
46 | } else { | |
47 | scope.Raw(fmt.Sprintf( | |
48 | "DELETE FROM %v%v%v", | |
49 | scope.QuotedTableName(), | |
50 | addExtraSpaceIfExist(scope.CombinedConditionSql()), | |
51 | addExtraSpaceIfExist(extraOption), | |
52 | )).Exec() | |
53 | } | |
54 | } | |
27 | 55 | } |
28 | 56 | |
29 | func init() { | |
30 | DefaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction) | |
31 | DefaultCallback.Delete().Register("gorm:before_delete", BeforeDelete) | |
32 | DefaultCallback.Delete().Register("gorm:delete", Delete) | |
33 | DefaultCallback.Delete().Register("gorm:after_delete", AfterDelete) | |
34 | DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | |
57 | // afterDeleteCallback will invoke `AfterDelete` method after deleting | |
58 | func afterDeleteCallback(scope *Scope) { | |
59 | if !scope.HasError() { | |
60 | scope.CallMethod("AfterDelete") | |
61 | } | |
35 | 62 | } |
5 | 5 | "reflect" |
6 | 6 | ) |
7 | 7 | |
8 | func Query(scope *Scope) { | |
9 | defer scope.Trace(NowFunc()) | |
8 | // Define callbacks for querying | |
9 | func init() { | |
10 | DefaultCallback.Query().Register("gorm:query", queryCallback) | |
11 | DefaultCallback.Query().Register("gorm:preload", preloadCallback) | |
12 | DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback) | |
13 | } | |
14 | ||
15 | // queryCallback used to query data from database | |
16 | func queryCallback(scope *Scope) { | |
17 | if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { | |
18 | return | |
19 | } | |
20 | ||
21 | defer scope.trace(NowFunc()) | |
10 | 22 | |
11 | 23 | var ( |
12 | isSlice bool | |
13 | isPtr bool | |
14 | anyRecordFound bool | |
15 | destType reflect.Type | |
24 | isSlice, isPtr bool | |
25 | resultType reflect.Type | |
26 | results = scope.IndirectValue() | |
16 | 27 | ) |
17 | 28 | |
18 | 29 | if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { |
19 | if primaryKey := scope.PrimaryKey(); primaryKey != "" { | |
20 | scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryKey), orderBy)) | |
30 | if primaryField := scope.PrimaryField(); primaryField != nil { | |
31 | scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy)) | |
21 | 32 | } |
22 | 33 | } |
23 | 34 | |
24 | var dest = scope.IndirectValue() | |
25 | 35 | if value, ok := scope.Get("gorm:query_destination"); ok { |
26 | dest = reflect.Indirect(reflect.ValueOf(value)) | |
36 | results = indirect(reflect.ValueOf(value)) | |
27 | 37 | } |
28 | 38 | |
29 | if kind := dest.Kind(); kind == reflect.Slice { | |
39 | if kind := results.Kind(); kind == reflect.Slice { | |
30 | 40 | isSlice = true |
31 | destType = dest.Type().Elem() | |
32 | dest.Set(reflect.MakeSlice(dest.Type(), 0, 0)) | |
41 | resultType = results.Type().Elem() | |
42 | results.Set(reflect.MakeSlice(results.Type(), 0, 0)) | |
33 | 43 | |
34 | if destType.Kind() == reflect.Ptr { | |
44 | if resultType.Kind() == reflect.Ptr { | |
35 | 45 | isPtr = true |
36 | destType = destType.Elem() | |
46 | resultType = resultType.Elem() | |
37 | 47 | } |
38 | 48 | } else if kind != reflect.Struct { |
39 | 49 | scope.Err(errors.New("unsupported destination, should be slice or struct")) |
40 | 50 | return |
41 | 51 | } |
42 | 52 | |
43 | scope.prepareQuerySql() | |
53 | scope.prepareQuerySQL() | |
44 | 54 | |
45 | 55 | if !scope.HasError() { |
46 | rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...) | |
47 | 56 | scope.db.RowsAffected = 0 |
57 | if str, ok := scope.Get("gorm:query_option"); ok { | |
58 | scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) | |
59 | } | |
48 | 60 | |
49 | if scope.Err(err) != nil { | |
50 | return | |
51 | } | |
52 | defer rows.Close() | |
61 | if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { | |
62 | defer rows.Close() | |
53 | 63 | |
54 | columns, _ := rows.Columns() | |
55 | for rows.Next() { | |
56 | scope.db.RowsAffected++ | |
64 | columns, _ := rows.Columns() | |
65 | for rows.Next() { | |
66 | scope.db.RowsAffected++ | |
57 | 67 | |
58 | anyRecordFound = true | |
59 | elem := dest | |
60 | if isSlice { | |
61 | elem = reflect.New(destType).Elem() | |
62 | } | |
68 | elem := results | |
69 | if isSlice { | |
70 | elem = reflect.New(resultType).Elem() | |
71 | } | |
63 | 72 | |
64 | var values = make([]interface{}, len(columns)) | |
73 | scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) | |
65 | 74 | |
66 | fields := scope.New(elem.Addr().Interface()).Fields() | |
67 | ||
68 | for index, column := range columns { | |
69 | if field, ok := fields[column]; ok { | |
70 | if field.Field.Kind() == reflect.Ptr { | |
71 | values[index] = field.Field.Addr().Interface() | |
75 | if isSlice { | |
76 | if isPtr { | |
77 | results.Set(reflect.Append(results, elem.Addr())) | |
72 | 78 | } else { |
73 | values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface() | |
74 | } | |
75 | } else { | |
76 | var value interface{} | |
77 | values[index] = &value | |
78 | } | |
79 | } | |
80 | ||
81 | scope.Err(rows.Scan(values...)) | |
82 | ||
83 | for index, column := range columns { | |
84 | value := values[index] | |
85 | if field, ok := fields[column]; ok { | |
86 | if field.Field.Kind() == reflect.Ptr { | |
87 | field.Field.Set(reflect.ValueOf(value).Elem()) | |
88 | } else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() { | |
89 | field.Field.Set(v) | |
79 | results.Set(reflect.Append(results, elem)) | |
90 | 80 | } |
91 | 81 | } |
92 | 82 | } |
93 | 83 | |
94 | if isSlice { | |
95 | if isPtr { | |
96 | dest.Set(reflect.Append(dest, elem.Addr())) | |
97 | } else { | |
98 | dest.Set(reflect.Append(dest, elem)) | |
99 | } | |
84 | if err := rows.Err(); err != nil { | |
85 | scope.Err(err) | |
86 | } else if scope.db.RowsAffected == 0 && !isSlice { | |
87 | scope.Err(ErrRecordNotFound) | |
100 | 88 | } |
101 | } | |
102 | ||
103 | if !anyRecordFound && !isSlice { | |
104 | scope.Err(RecordNotFound) | |
105 | 89 | } |
106 | 90 | } |
107 | 91 | } |
108 | 92 | |
109 | func AfterQuery(scope *Scope) { | |
110 | scope.CallMethodWithErrorCheck("AfterFind") | |
93 | // afterQueryCallback will invoke `AfterFind` method after querying | |
94 | func afterQueryCallback(scope *Scope) { | |
95 | if !scope.HasError() { | |
96 | scope.CallMethod("AfterFind") | |
97 | } | |
111 | 98 | } |
112 | ||
113 | func init() { | |
114 | DefaultCallback.Query().Register("gorm:query", Query) | |
115 | DefaultCallback.Query().Register("gorm:after_query", AfterQuery) | |
116 | DefaultCallback.Query().Register("gorm:preload", Preload) | |
117 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "fmt" | |
5 | "reflect" | |
6 | "strconv" | |
7 | "strings" | |
8 | ) | |
9 | ||
10 | // preloadCallback used to preload associations | |
11 | func preloadCallback(scope *Scope) { | |
12 | if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { | |
13 | return | |
14 | } | |
15 | ||
16 | if _, ok := scope.Get("gorm:auto_preload"); ok { | |
17 | autoPreload(scope) | |
18 | } | |
19 | ||
20 | if scope.Search.preload == nil || scope.HasError() { | |
21 | return | |
22 | } | |
23 | ||
24 | var ( | |
25 | preloadedMap = map[string]bool{} | |
26 | fields = scope.Fields() | |
27 | ) | |
28 | ||
29 | for _, preload := range scope.Search.preload { | |
30 | var ( | |
31 | preloadFields = strings.Split(preload.schema, ".") | |
32 | currentScope = scope | |
33 | currentFields = fields | |
34 | ) | |
35 | ||
36 | for idx, preloadField := range preloadFields { | |
37 | var currentPreloadConditions []interface{} | |
38 | ||
39 | if currentScope == nil { | |
40 | continue | |
41 | } | |
42 | ||
43 | // if not preloaded | |
44 | if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { | |
45 | ||
46 | // assign search conditions to last preload | |
47 | if idx == len(preloadFields)-1 { | |
48 | currentPreloadConditions = preload.conditions | |
49 | } | |
50 | ||
51 | for _, field := range currentFields { | |
52 | if field.Name != preloadField || field.Relationship == nil { | |
53 | continue | |
54 | } | |
55 | ||
56 | switch field.Relationship.Kind { | |
57 | case "has_one": | |
58 | currentScope.handleHasOnePreload(field, currentPreloadConditions) | |
59 | case "has_many": | |
60 | currentScope.handleHasManyPreload(field, currentPreloadConditions) | |
61 | case "belongs_to": | |
62 | currentScope.handleBelongsToPreload(field, currentPreloadConditions) | |
63 | case "many_to_many": | |
64 | currentScope.handleManyToManyPreload(field, currentPreloadConditions) | |
65 | default: | |
66 | scope.Err(errors.New("unsupported relation")) | |
67 | } | |
68 | ||
69 | preloadedMap[preloadKey] = true | |
70 | break | |
71 | } | |
72 | ||
73 | if !preloadedMap[preloadKey] { | |
74 | scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType)) | |
75 | return | |
76 | } | |
77 | } | |
78 | ||
79 | // preload next level | |
80 | if idx < len(preloadFields)-1 { | |
81 | currentScope = currentScope.getColumnAsScope(preloadField) | |
82 | if currentScope != nil { | |
83 | currentFields = currentScope.Fields() | |
84 | } | |
85 | } | |
86 | } | |
87 | } | |
88 | } | |
89 | ||
90 | func autoPreload(scope *Scope) { | |
91 | for _, field := range scope.Fields() { | |
92 | if field.Relationship == nil { | |
93 | continue | |
94 | } | |
95 | ||
96 | if val, ok := field.TagSettings["PRELOAD"]; ok { | |
97 | if preload, err := strconv.ParseBool(val); err != nil { | |
98 | scope.Err(errors.New("invalid preload option")) | |
99 | return | |
100 | } else if !preload { | |
101 | continue | |
102 | } | |
103 | } | |
104 | ||
105 | scope.Search.Preload(field.Name) | |
106 | } | |
107 | } | |
108 | ||
109 | func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { | |
110 | var ( | |
111 | preloadDB = scope.NewDB() | |
112 | preloadConditions []interface{} | |
113 | ) | |
114 | ||
115 | for _, condition := range conditions { | |
116 | if scopes, ok := condition.(func(*DB) *DB); ok { | |
117 | preloadDB = scopes(preloadDB) | |
118 | } else { | |
119 | preloadConditions = append(preloadConditions, condition) | |
120 | } | |
121 | } | |
122 | ||
123 | return preloadDB, preloadConditions | |
124 | } | |
125 | ||
126 | // handleHasOnePreload used to preload has one associations | |
127 | func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { | |
128 | relation := field.Relationship | |
129 | ||
130 | // get relations's primary keys | |
131 | primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) | |
132 | if len(primaryKeys) == 0 { | |
133 | return | |
134 | } | |
135 | ||
136 | // preload conditions | |
137 | preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) | |
138 | ||
139 | // find relations | |
140 | query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) | |
141 | values := toQueryValues(primaryKeys) | |
142 | if relation.PolymorphicType != "" { | |
143 | query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) | |
144 | values = append(values, relation.PolymorphicValue) | |
145 | } | |
146 | ||
147 | results := makeSlice(field.Struct.Type) | |
148 | scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) | |
149 | ||
150 | // assign find results | |
151 | var ( | |
152 | resultsValue = indirect(reflect.ValueOf(results)) | |
153 | indirectScopeValue = scope.IndirectValue() | |
154 | ) | |
155 | ||
156 | if indirectScopeValue.Kind() == reflect.Slice { | |
157 | for j := 0; j < indirectScopeValue.Len(); j++ { | |
158 | for i := 0; i < resultsValue.Len(); i++ { | |
159 | result := resultsValue.Index(i) | |
160 | foreignValues := getValueFromFields(result, relation.ForeignFieldNames) | |
161 | if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { | |
162 | indirectValue.FieldByName(field.Name).Set(result) | |
163 | break | |
164 | } | |
165 | } | |
166 | } | |
167 | } else { | |
168 | for i := 0; i < resultsValue.Len(); i++ { | |
169 | result := resultsValue.Index(i) | |
170 | scope.Err(field.Set(result)) | |
171 | } | |
172 | } | |
173 | } | |
174 | ||
175 | // handleHasManyPreload used to preload has many associations | |
176 | func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { | |
177 | relation := field.Relationship | |
178 | ||
179 | // get relations's primary keys | |
180 | primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) | |
181 | if len(primaryKeys) == 0 { | |
182 | return | |
183 | } | |
184 | ||
185 | // preload conditions | |
186 | preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) | |
187 | ||
188 | // find relations | |
189 | query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) | |
190 | values := toQueryValues(primaryKeys) | |
191 | if relation.PolymorphicType != "" { | |
192 | query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) | |
193 | values = append(values, relation.PolymorphicValue) | |
194 | } | |
195 | ||
196 | results := makeSlice(field.Struct.Type) | |
197 | scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) | |
198 | ||
199 | // assign find results | |
200 | var ( | |
201 | resultsValue = indirect(reflect.ValueOf(results)) | |
202 | indirectScopeValue = scope.IndirectValue() | |
203 | ) | |
204 | ||
205 | if indirectScopeValue.Kind() == reflect.Slice { | |
206 | preloadMap := make(map[string][]reflect.Value) | |
207 | for i := 0; i < resultsValue.Len(); i++ { | |
208 | result := resultsValue.Index(i) | |
209 | foreignValues := getValueFromFields(result, relation.ForeignFieldNames) | |
210 | preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result) | |
211 | } | |
212 | ||
213 | for j := 0; j < indirectScopeValue.Len(); j++ { | |
214 | object := indirect(indirectScopeValue.Index(j)) | |
215 | objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) | |
216 | f := object.FieldByName(field.Name) | |
217 | if results, ok := preloadMap[toString(objectRealValue)]; ok { | |
218 | f.Set(reflect.Append(f, results...)) | |
219 | } else { | |
220 | f.Set(reflect.MakeSlice(f.Type(), 0, 0)) | |
221 | } | |
222 | } | |
223 | } else { | |
224 | scope.Err(field.Set(resultsValue)) | |
225 | } | |
226 | } | |
227 | ||
228 | // handleBelongsToPreload used to preload belongs to associations | |
229 | func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { | |
230 | relation := field.Relationship | |
231 | ||
232 | // preload conditions | |
233 | preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) | |
234 | ||
235 | // get relations's primary keys | |
236 | primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) | |
237 | if len(primaryKeys) == 0 { | |
238 | return | |
239 | } | |
240 | ||
241 | // find relations | |
242 | results := makeSlice(field.Struct.Type) | |
243 | scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) | |
244 | ||
245 | // assign find results | |
246 | var ( | |
247 | resultsValue = indirect(reflect.ValueOf(results)) | |
248 | indirectScopeValue = scope.IndirectValue() | |
249 | ) | |
250 | ||
251 | for i := 0; i < resultsValue.Len(); i++ { | |
252 | result := resultsValue.Index(i) | |
253 | if indirectScopeValue.Kind() == reflect.Slice { | |
254 | value := getValueFromFields(result, relation.AssociationForeignFieldNames) | |
255 | for j := 0; j < indirectScopeValue.Len(); j++ { | |
256 | object := indirect(indirectScopeValue.Index(j)) | |
257 | if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { | |
258 | object.FieldByName(field.Name).Set(result) | |
259 | } | |
260 | } | |
261 | } else { | |
262 | scope.Err(field.Set(result)) | |
263 | } | |
264 | } | |
265 | } | |
266 | ||
267 | // handleManyToManyPreload used to preload many to many associations | |
268 | func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { | |
269 | var ( | |
270 | relation = field.Relationship | |
271 | joinTableHandler = relation.JoinTableHandler | |
272 | fieldType = field.Struct.Type.Elem() | |
273 | foreignKeyValue interface{} | |
274 | foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() | |
275 | linkHash = map[string][]reflect.Value{} | |
276 | isPtr bool | |
277 | ) | |
278 | ||
279 | if fieldType.Kind() == reflect.Ptr { | |
280 | isPtr = true | |
281 | fieldType = fieldType.Elem() | |
282 | } | |
283 | ||
284 | var sourceKeys = []string{} | |
285 | for _, key := range joinTableHandler.SourceForeignKeys() { | |
286 | sourceKeys = append(sourceKeys, key.DBName) | |
287 | } | |
288 | ||
289 | // preload conditions | |
290 | preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) | |
291 | ||
292 | // generate query with join table | |
293 | newScope := scope.New(reflect.New(fieldType).Interface()) | |
294 | preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value) | |
295 | ||
296 | if len(preloadDB.search.selects) == 0 { | |
297 | preloadDB = preloadDB.Select("*") | |
298 | } | |
299 | ||
300 | preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) | |
301 | ||
302 | // preload inline conditions | |
303 | if len(preloadConditions) > 0 { | |
304 | preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...) | |
305 | } | |
306 | ||
307 | rows, err := preloadDB.Rows() | |
308 | ||
309 | if scope.Err(err) != nil { | |
310 | return | |
311 | } | |
312 | defer rows.Close() | |
313 | ||
314 | columns, _ := rows.Columns() | |
315 | for rows.Next() { | |
316 | var ( | |
317 | elem = reflect.New(fieldType).Elem() | |
318 | fields = scope.New(elem.Addr().Interface()).Fields() | |
319 | ) | |
320 | ||
321 | // register foreign keys in join tables | |
322 | var joinTableFields []*Field | |
323 | for _, sourceKey := range sourceKeys { | |
324 | joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()}) | |
325 | } | |
326 | ||
327 | scope.scan(rows, columns, append(fields, joinTableFields...)) | |
328 | ||
329 | scope.New(elem.Addr().Interface()). | |
330 | InstanceSet("gorm:skip_query_callback", true). | |
331 | callCallbacks(scope.db.parent.callbacks.queries) | |
332 | ||
333 | var foreignKeys = make([]interface{}, len(sourceKeys)) | |
334 | // generate hashed forkey keys in join table | |
335 | for idx, joinTableField := range joinTableFields { | |
336 | if !joinTableField.Field.IsNil() { | |
337 | foreignKeys[idx] = joinTableField.Field.Elem().Interface() | |
338 | } | |
339 | } | |
340 | hashedSourceKeys := toString(foreignKeys) | |
341 | ||
342 | if isPtr { | |
343 | linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr()) | |
344 | } else { | |
345 | linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem) | |
346 | } | |
347 | } | |
348 | ||
349 | if err := rows.Err(); err != nil { | |
350 | scope.Err(err) | |
351 | } | |
352 | ||
353 | // assign find results | |
354 | var ( | |
355 | indirectScopeValue = scope.IndirectValue() | |
356 | fieldsSourceMap = map[string][]reflect.Value{} | |
357 | foreignFieldNames = []string{} | |
358 | ) | |
359 | ||
360 | for _, dbName := range relation.ForeignFieldNames { | |
361 | if field, ok := scope.FieldByName(dbName); ok { | |
362 | foreignFieldNames = append(foreignFieldNames, field.Name) | |
363 | } | |
364 | } | |
365 | ||
366 | if indirectScopeValue.Kind() == reflect.Slice { | |
367 | for j := 0; j < indirectScopeValue.Len(); j++ { | |
368 | object := indirect(indirectScopeValue.Index(j)) | |
369 | key := toString(getValueFromFields(object, foreignFieldNames)) | |
370 | fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name)) | |
371 | } | |
372 | } else if indirectScopeValue.IsValid() { | |
373 | key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) | |
374 | fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) | |
375 | } | |
376 | for source, link := range linkHash { | |
377 | for i, field := range fieldsSourceMap[source] { | |
378 | //If not 0 this means Value is a pointer and we already added preloaded models to it | |
379 | if fieldsSourceMap[source][i].Len() != 0 { | |
380 | continue | |
381 | } | |
382 | field.Set(reflect.Append(fieldsSourceMap[source][i], link...)) | |
383 | } | |
384 | ||
385 | } | |
386 | } |
0 | package gorm | |
1 | ||
2 | import "database/sql" | |
3 | ||
4 | // Define callbacks for row query | |
5 | func init() { | |
6 | DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback) | |
7 | } | |
8 | ||
9 | type RowQueryResult struct { | |
10 | Row *sql.Row | |
11 | } | |
12 | ||
13 | type RowsQueryResult struct { | |
14 | Rows *sql.Rows | |
15 | Error error | |
16 | } | |
17 | ||
18 | // queryCallback used to query data from database | |
19 | func rowQueryCallback(scope *Scope) { | |
20 | if result, ok := scope.InstanceGet("row_query_result"); ok { | |
21 | scope.prepareQuerySQL() | |
22 | ||
23 | if rowResult, ok := result.(*RowQueryResult); ok { | |
24 | rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) | |
25 | } else if rowsResult, ok := result.(*RowsQueryResult); ok { | |
26 | rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...) | |
27 | } | |
28 | } | |
29 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "reflect" | |
4 | "strings" | |
5 | ) | |
6 | ||
7 | func beginTransactionCallback(scope *Scope) { | |
8 | scope.Begin() | |
9 | } | |
10 | ||
11 | func commitOrRollbackTransactionCallback(scope *Scope) { | |
12 | scope.CommitOrRollback() | |
13 | } | |
14 | ||
15 | func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) { | |
16 | checkTruth := func(value interface{}) bool { | |
17 | if v, ok := value.(bool); ok && !v { | |
18 | return false | |
19 | } | |
20 | ||
21 | if v, ok := value.(string); ok { | |
22 | v = strings.ToLower(v) | |
23 | if v == "false" || v != "skip" { | |
24 | return false | |
25 | } | |
26 | } | |
27 | ||
28 | return true | |
29 | } | |
30 | ||
31 | if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { | |
32 | if r = field.Relationship; r != nil { | |
33 | autoUpdate, autoCreate, saveReference = true, true, true | |
34 | ||
35 | if value, ok := scope.Get("gorm:save_associations"); ok { | |
36 | autoUpdate = checkTruth(value) | |
37 | autoCreate = autoUpdate | |
38 | } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok { | |
39 | autoUpdate = checkTruth(value) | |
40 | autoCreate = autoUpdate | |
41 | } | |
42 | ||
43 | if value, ok := scope.Get("gorm:association_autoupdate"); ok { | |
44 | autoUpdate = checkTruth(value) | |
45 | } else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok { | |
46 | autoUpdate = checkTruth(value) | |
47 | } | |
48 | ||
49 | if value, ok := scope.Get("gorm:association_autocreate"); ok { | |
50 | autoCreate = checkTruth(value) | |
51 | } else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok { | |
52 | autoCreate = checkTruth(value) | |
53 | } | |
54 | ||
55 | if value, ok := scope.Get("gorm:association_save_reference"); ok { | |
56 | saveReference = checkTruth(value) | |
57 | } else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { | |
58 | saveReference = checkTruth(value) | |
59 | } | |
60 | } | |
61 | } | |
62 | ||
63 | return | |
64 | } | |
65 | ||
66 | func saveBeforeAssociationsCallback(scope *Scope) { | |
67 | for _, field := range scope.Fields() { | |
68 | autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) | |
69 | ||
70 | if relationship != nil && relationship.Kind == "belongs_to" { | |
71 | fieldValue := field.Field.Addr().Interface() | |
72 | newScope := scope.New(fieldValue) | |
73 | ||
74 | if newScope.PrimaryKeyZero() { | |
75 | if autoCreate { | |
76 | scope.Err(scope.NewDB().Save(fieldValue).Error) | |
77 | } | |
78 | } else if autoUpdate { | |
79 | scope.Err(scope.NewDB().Save(fieldValue).Error) | |
80 | } | |
81 | ||
82 | if saveReference { | |
83 | if len(relationship.ForeignFieldNames) != 0 { | |
84 | // set value's foreign key | |
85 | for idx, fieldName := range relationship.ForeignFieldNames { | |
86 | associationForeignName := relationship.AssociationForeignDBNames[idx] | |
87 | if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { | |
88 | scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) | |
89 | } | |
90 | } | |
91 | } | |
92 | } | |
93 | } | |
94 | } | |
95 | } | |
96 | ||
97 | func saveAfterAssociationsCallback(scope *Scope) { | |
98 | for _, field := range scope.Fields() { | |
99 | autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) | |
100 | ||
101 | if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { | |
102 | value := field.Field | |
103 | ||
104 | switch value.Kind() { | |
105 | case reflect.Slice: | |
106 | for i := 0; i < value.Len(); i++ { | |
107 | newDB := scope.NewDB() | |
108 | elem := value.Index(i).Addr().Interface() | |
109 | newScope := newDB.NewScope(elem) | |
110 | ||
111 | if saveReference { | |
112 | if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { | |
113 | for idx, fieldName := range relationship.ForeignFieldNames { | |
114 | associationForeignName := relationship.AssociationForeignDBNames[idx] | |
115 | if f, ok := scope.FieldByName(associationForeignName); ok { | |
116 | scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) | |
117 | } | |
118 | } | |
119 | } | |
120 | ||
121 | if relationship.PolymorphicType != "" { | |
122 | scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) | |
123 | } | |
124 | } | |
125 | ||
126 | if newScope.PrimaryKeyZero() { | |
127 | if autoCreate { | |
128 | scope.Err(newDB.Save(elem).Error) | |
129 | } | |
130 | } else if autoUpdate { | |
131 | scope.Err(newDB.Save(elem).Error) | |
132 | } | |
133 | ||
134 | if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference { | |
135 | if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { | |
136 | scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) | |
137 | } | |
138 | } | |
139 | } | |
140 | default: | |
141 | elem := value.Addr().Interface() | |
142 | newScope := scope.New(elem) | |
143 | ||
144 | if saveReference { | |
145 | if len(relationship.ForeignFieldNames) != 0 { | |
146 | for idx, fieldName := range relationship.ForeignFieldNames { | |
147 | associationForeignName := relationship.AssociationForeignDBNames[idx] | |
148 | if f, ok := scope.FieldByName(associationForeignName); ok { | |
149 | scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) | |
150 | } | |
151 | } | |
152 | } | |
153 | ||
154 | if relationship.PolymorphicType != "" { | |
155 | scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) | |
156 | } | |
157 | } | |
158 | ||
159 | if newScope.PrimaryKeyZero() { | |
160 | if autoCreate { | |
161 | scope.Err(scope.NewDB().Save(elem).Error) | |
162 | } | |
163 | } else if autoUpdate { | |
164 | scope.Err(scope.NewDB().Save(elem).Error) | |
165 | } | |
166 | } | |
167 | } | |
168 | } | |
169 | } |
0 | package gorm | |
1 | ||
2 | import "reflect" | |
3 | ||
4 | func BeginTransaction(scope *Scope) { | |
5 | scope.Begin() | |
6 | } | |
7 | ||
8 | func CommitOrRollbackTransaction(scope *Scope) { | |
9 | scope.CommitOrRollback() | |
10 | } | |
11 | ||
12 | func SaveBeforeAssociations(scope *Scope) { | |
13 | if !scope.shouldSaveAssociations() { | |
14 | return | |
15 | } | |
16 | for _, field := range scope.Fields() { | |
17 | if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { | |
18 | if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { | |
19 | value := field.Field | |
20 | scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error) | |
21 | if len(relationship.ForeignFieldNames) != 0 { | |
22 | for idx, fieldName := range relationship.ForeignFieldNames { | |
23 | associationForeignName := relationship.AssociationForeignDBNames[idx] | |
24 | if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok { | |
25 | scope.Err(scope.SetColumn(fieldName, f.Field.Interface())) | |
26 | } | |
27 | } | |
28 | } | |
29 | } | |
30 | } | |
31 | } | |
32 | } | |
33 | ||
34 | func SaveAfterAssociations(scope *Scope) { | |
35 | if !scope.shouldSaveAssociations() { | |
36 | return | |
37 | } | |
38 | for _, field := range scope.Fields() { | |
39 | if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { | |
40 | if relationship := field.Relationship; relationship != nil && | |
41 | (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { | |
42 | value := field.Field | |
43 | ||
44 | switch value.Kind() { | |
45 | case reflect.Slice: | |
46 | for i := 0; i < value.Len(); i++ { | |
47 | newDB := scope.NewDB() | |
48 | elem := value.Index(i).Addr().Interface() | |
49 | newScope := newDB.NewScope(elem) | |
50 | ||
51 | if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { | |
52 | for idx, fieldName := range relationship.ForeignFieldNames { | |
53 | associationForeignName := relationship.AssociationForeignDBNames[idx] | |
54 | if f, ok := scope.FieldByName(associationForeignName); ok { | |
55 | scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) | |
56 | } | |
57 | } | |
58 | } | |
59 | ||
60 | if relationship.PolymorphicType != "" { | |
61 | scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) | |
62 | } | |
63 | ||
64 | scope.Err(newDB.Save(elem).Error) | |
65 | ||
66 | if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { | |
67 | scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value)) | |
68 | } | |
69 | } | |
70 | default: | |
71 | elem := value.Addr().Interface() | |
72 | newScope := scope.New(elem) | |
73 | if len(relationship.ForeignFieldNames) != 0 { | |
74 | for idx, fieldName := range relationship.ForeignFieldNames { | |
75 | associationForeignName := relationship.AssociationForeignDBNames[idx] | |
76 | if f, ok := scope.FieldByName(associationForeignName); ok { | |
77 | scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) | |
78 | } | |
79 | } | |
80 | } | |
81 | ||
82 | if relationship.PolymorphicType != "" { | |
83 | scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) | |
84 | } | |
85 | scope.Err(scope.NewDB().Save(elem).Error) | |
86 | } | |
87 | } | |
88 | } | |
89 | } | |
90 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "reflect" | |
4 | "runtime" | |
5 | "strings" | |
6 | "testing" | |
7 | ) | |
8 | ||
9 | func equalFuncs(funcs []*func(s *Scope), fnames []string) bool { | |
10 | var names []string | |
11 | for _, f := range funcs { | |
12 | fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") | |
13 | names = append(names, fnames[len(fnames)-1]) | |
14 | } | |
15 | return reflect.DeepEqual(names, fnames) | |
16 | } | |
17 | ||
18 | func create(s *Scope) {} | |
19 | func beforeCreate1(s *Scope) {} | |
20 | func beforeCreate2(s *Scope) {} | |
21 | func afterCreate1(s *Scope) {} | |
22 | func afterCreate2(s *Scope) {} | |
23 | ||
24 | func TestRegisterCallback(t *testing.T) { | |
25 | var callback = &Callback{} | |
26 | ||
27 | callback.Create().Register("before_create1", beforeCreate1) | |
28 | callback.Create().Register("before_create2", beforeCreate2) | |
29 | callback.Create().Register("create", create) | |
30 | callback.Create().Register("after_create1", afterCreate1) | |
31 | callback.Create().Register("after_create2", afterCreate2) | |
32 | ||
33 | if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { | |
34 | t.Errorf("register callback") | |
35 | } | |
36 | } | |
37 | ||
38 | func TestRegisterCallbackWithOrder(t *testing.T) { | |
39 | var callback1 = &Callback{} | |
40 | callback1.Create().Register("before_create1", beforeCreate1) | |
41 | callback1.Create().Register("create", create) | |
42 | callback1.Create().Register("after_create1", afterCreate1) | |
43 | callback1.Create().Before("after_create1").Register("after_create2", afterCreate2) | |
44 | if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { | |
45 | t.Errorf("register callback with order") | |
46 | } | |
47 | ||
48 | var callback2 = &Callback{} | |
49 | ||
50 | callback2.Update().Register("create", create) | |
51 | callback2.Update().Before("create").Register("before_create1", beforeCreate1) | |
52 | callback2.Update().After("after_create2").Register("after_create1", afterCreate1) | |
53 | callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2) | |
54 | callback2.Update().Register("after_create2", afterCreate2) | |
55 | ||
56 | if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { | |
57 | t.Errorf("register callback with order") | |
58 | } | |
59 | } | |
60 | ||
61 | func TestRegisterCallbackWithComplexOrder(t *testing.T) { | |
62 | var callback1 = &Callback{} | |
63 | ||
64 | callback1.Query().Before("after_create1").After("before_create1").Register("create", create) | |
65 | callback1.Query().Register("before_create1", beforeCreate1) | |
66 | callback1.Query().Register("after_create1", afterCreate1) | |
67 | ||
68 | if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) { | |
69 | t.Errorf("register callback with order") | |
70 | } | |
71 | ||
72 | var callback2 = &Callback{} | |
73 | ||
74 | callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) | |
75 | callback2.Delete().Before("create").Register("before_create1", beforeCreate1) | |
76 | callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) | |
77 | callback2.Delete().Register("after_create1", afterCreate1) | |
78 | callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) | |
79 | ||
80 | if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { | |
81 | t.Errorf("register callback with order") | |
82 | } | |
83 | } | |
84 | ||
85 | func replaceCreate(s *Scope) {} | |
86 | ||
87 | func TestReplaceCallback(t *testing.T) { | |
88 | var callback = &Callback{} | |
89 | ||
90 | callback.Create().Before("after_create1").After("before_create1").Register("create", create) | |
91 | callback.Create().Register("before_create1", beforeCreate1) | |
92 | callback.Create().Register("after_create1", afterCreate1) | |
93 | callback.Create().Replace("create", replaceCreate) | |
94 | ||
95 | if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) { | |
96 | t.Errorf("replace callback") | |
97 | } | |
98 | } | |
99 | ||
100 | func TestRemoveCallback(t *testing.T) { | |
101 | var callback = &Callback{} | |
102 | ||
103 | callback.Create().Before("after_create1").After("before_create1").Register("create", create) | |
104 | callback.Create().Register("before_create1", beforeCreate1) | |
105 | callback.Create().Register("after_create1", afterCreate1) | |
106 | callback.Create().Remove("create") | |
107 | ||
108 | if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) { | |
109 | t.Errorf("remove callback") | |
110 | } | |
111 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "reflect" | |
4 | "runtime" | |
5 | "strings" | |
6 | "testing" | |
7 | ) | |
8 | ||
9 | func equalFuncs(funcs []*func(s *Scope), fnames []string) bool { | |
10 | var names []string | |
11 | for _, f := range funcs { | |
12 | fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") | |
13 | names = append(names, fnames[len(fnames)-1]) | |
14 | } | |
15 | return reflect.DeepEqual(names, fnames) | |
16 | } | |
17 | ||
18 | func create(s *Scope) {} | |
19 | func beforeCreate1(s *Scope) {} | |
20 | func beforeCreate2(s *Scope) {} | |
21 | func afterCreate1(s *Scope) {} | |
22 | func afterCreate2(s *Scope) {} | |
23 | ||
24 | func TestRegisterCallback(t *testing.T) { | |
25 | var callback = &callback{processors: []*callbackProcessor{}} | |
26 | ||
27 | callback.Create().Register("before_create1", beforeCreate1) | |
28 | callback.Create().Register("before_create2", beforeCreate2) | |
29 | callback.Create().Register("create", create) | |
30 | callback.Create().Register("after_create1", afterCreate1) | |
31 | callback.Create().Register("after_create2", afterCreate2) | |
32 | ||
33 | if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { | |
34 | t.Errorf("register callback") | |
35 | } | |
36 | } | |
37 | ||
38 | func TestRegisterCallbackWithOrder(t *testing.T) { | |
39 | var callback1 = &callback{processors: []*callbackProcessor{}} | |
40 | callback1.Create().Register("before_create1", beforeCreate1) | |
41 | callback1.Create().Register("create", create) | |
42 | callback1.Create().Register("after_create1", afterCreate1) | |
43 | callback1.Create().Before("after_create1").Register("after_create2", afterCreate2) | |
44 | if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { | |
45 | t.Errorf("register callback with order") | |
46 | } | |
47 | ||
48 | var callback2 = &callback{processors: []*callbackProcessor{}} | |
49 | ||
50 | callback2.Update().Register("create", create) | |
51 | callback2.Update().Before("create").Register("before_create1", beforeCreate1) | |
52 | callback2.Update().After("after_create2").Register("after_create1", afterCreate1) | |
53 | callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2) | |
54 | callback2.Update().Register("after_create2", afterCreate2) | |
55 | ||
56 | if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { | |
57 | t.Errorf("register callback with order") | |
58 | } | |
59 | } | |
60 | ||
61 | func TestRegisterCallbackWithComplexOrder(t *testing.T) { | |
62 | var callback1 = &callback{processors: []*callbackProcessor{}} | |
63 | ||
64 | callback1.Query().Before("after_create1").After("before_create1").Register("create", create) | |
65 | callback1.Query().Register("before_create1", beforeCreate1) | |
66 | callback1.Query().Register("after_create1", afterCreate1) | |
67 | ||
68 | if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) { | |
69 | t.Errorf("register callback with order") | |
70 | } | |
71 | ||
72 | var callback2 = &callback{processors: []*callbackProcessor{}} | |
73 | ||
74 | callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) | |
75 | callback2.Delete().Before("create").Register("before_create1", beforeCreate1) | |
76 | callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) | |
77 | callback2.Delete().Register("after_create1", afterCreate1) | |
78 | callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) | |
79 | ||
80 | if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { | |
81 | t.Errorf("register callback with order") | |
82 | } | |
83 | } | |
84 | ||
85 | func replaceCreate(s *Scope) {} | |
86 | ||
87 | func TestReplaceCallback(t *testing.T) { | |
88 | var callback = &callback{processors: []*callbackProcessor{}} | |
89 | ||
90 | callback.Create().Before("after_create1").After("before_create1").Register("create", create) | |
91 | callback.Create().Register("before_create1", beforeCreate1) | |
92 | callback.Create().Register("after_create1", afterCreate1) | |
93 | callback.Create().Replace("create", replaceCreate) | |
94 | ||
95 | if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) { | |
96 | t.Errorf("replace callback") | |
97 | } | |
98 | } | |
99 | ||
100 | func TestRemoveCallback(t *testing.T) { | |
101 | var callback = &callback{processors: []*callbackProcessor{}} | |
102 | ||
103 | callback.Create().Before("after_create1").After("before_create1").Register("create", create) | |
104 | callback.Create().Register("before_create1", beforeCreate1) | |
105 | callback.Create().Register("after_create1", afterCreate1) | |
106 | callback.Create().Remove("create") | |
107 | ||
108 | if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) { | |
109 | t.Errorf("remove callback") | |
110 | } | |
111 | } |
0 | 0 | package gorm |
1 | 1 | |
2 | 2 | import ( |
3 | "errors" | |
3 | 4 | "fmt" |
5 | "sort" | |
4 | 6 | "strings" |
5 | 7 | ) |
6 | 8 | |
7 | func AssignUpdateAttributes(scope *Scope) { | |
9 | // Define callbacks for updating | |
10 | func init() { | |
11 | DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback) | |
12 | DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) | |
13 | DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) | |
14 | DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) | |
15 | DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) | |
16 | DefaultCallback.Update().Register("gorm:update", updateCallback) | |
17 | DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) | |
18 | DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) | |
19 | DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) | |
20 | } | |
21 | ||
22 | // assignUpdatingAttributesCallback assign updating attributes to model | |
23 | func assignUpdatingAttributesCallback(scope *Scope) { | |
8 | 24 | if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { |
9 | if maps := convertInterfaceToMap(attrs); len(maps) > 0 { | |
10 | protected, ok := scope.Get("gorm:ignore_protected_attrs") | |
11 | _, updateColumn := scope.Get("gorm:update_column") | |
12 | updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool)) | |
13 | ||
14 | if updateColumn { | |
15 | scope.InstanceSet("gorm:update_attrs", maps) | |
16 | } else if len(updateAttrs) > 0 { | |
17 | scope.InstanceSet("gorm:update_attrs", updateAttrs) | |
18 | } else if !hasUpdate { | |
19 | scope.SkipLeft() | |
20 | return | |
21 | } | |
25 | if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate { | |
26 | scope.InstanceSet("gorm:update_attrs", updateMaps) | |
27 | } else { | |
28 | scope.SkipLeft() | |
22 | 29 | } |
23 | 30 | } |
24 | 31 | } |
25 | 32 | |
26 | func BeforeUpdate(scope *Scope) { | |
33 | // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating | |
34 | func beforeUpdateCallback(scope *Scope) { | |
35 | if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { | |
36 | scope.Err(errors.New("Missing WHERE clause while updating")) | |
37 | return | |
38 | } | |
27 | 39 | if _, ok := scope.Get("gorm:update_column"); !ok { |
28 | scope.CallMethodWithErrorCheck("BeforeSave") | |
29 | scope.CallMethodWithErrorCheck("BeforeUpdate") | |
40 | if !scope.HasError() { | |
41 | scope.CallMethod("BeforeSave") | |
42 | } | |
43 | if !scope.HasError() { | |
44 | scope.CallMethod("BeforeUpdate") | |
45 | } | |
30 | 46 | } |
31 | 47 | } |
32 | 48 | |
33 | func UpdateTimeStampWhenUpdate(scope *Scope) { | |
49 | // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating | |
50 | func updateTimeStampForUpdateCallback(scope *Scope) { | |
34 | 51 | if _, ok := scope.Get("gorm:update_column"); !ok { |
35 | 52 | scope.SetColumn("UpdatedAt", NowFunc()) |
36 | 53 | } |
37 | 54 | } |
38 | 55 | |
39 | func Update(scope *Scope) { | |
56 | // updateCallback the callback used to update data to database | |
57 | func updateCallback(scope *Scope) { | |
40 | 58 | if !scope.HasError() { |
41 | 59 | var sqls []string |
42 | 60 | |
43 | 61 | if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { |
44 | for key, value := range updateAttrs.(map[string]interface{}) { | |
45 | if scope.changeableDBColumn(key) { | |
46 | sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value))) | |
47 | } | |
62 | // Sort the column names so that the generated SQL is the same every time. | |
63 | updateMap := updateAttrs.(map[string]interface{}) | |
64 | var columns []string | |
65 | for c := range updateMap { | |
66 | columns = append(columns, c) | |
67 | } | |
68 | sort.Strings(columns) | |
69 | ||
70 | for _, column := range columns { | |
71 | value := updateMap[column] | |
72 | sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) | |
48 | 73 | } |
49 | 74 | } else { |
50 | fields := scope.Fields() | |
51 | for _, field := range fields { | |
52 | if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal { | |
53 | sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) | |
54 | } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { | |
55 | for _, dbName := range relationship.ForeignDBNames { | |
56 | if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank { | |
57 | sql := fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface())) | |
58 | sqls = append(sqls, sql) | |
75 | for _, field := range scope.Fields() { | |
76 | if scope.changeableField(field) { | |
77 | if !field.IsPrimaryKey && field.IsNormal { | |
78 | sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) | |
79 | } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { | |
80 | for _, foreignKey := range relationship.ForeignDBNames { | |
81 | if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { | |
82 | sqls = append(sqls, | |
83 | fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface()))) | |
84 | } | |
59 | 85 | } |
60 | 86 | } |
61 | 87 | } |
62 | 88 | } |
63 | 89 | } |
64 | 90 | |
91 | var extraOption string | |
92 | if str, ok := scope.Get("gorm:update_option"); ok { | |
93 | extraOption = fmt.Sprint(str) | |
94 | } | |
95 | ||
65 | 96 | if len(sqls) > 0 { |
66 | 97 | scope.Raw(fmt.Sprintf( |
67 | "UPDATE %v SET %v %v", | |
98 | "UPDATE %v SET %v%v%v", | |
68 | 99 | scope.QuotedTableName(), |
69 | 100 | strings.Join(sqls, ", "), |
70 | scope.CombinedConditionSql(), | |
71 | )) | |
72 | scope.Exec() | |
101 | addExtraSpaceIfExist(scope.CombinedConditionSql()), | |
102 | addExtraSpaceIfExist(extraOption), | |
103 | )).Exec() | |
73 | 104 | } |
74 | 105 | } |
75 | 106 | } |
76 | 107 | |
77 | func AfterUpdate(scope *Scope) { | |
108 | // afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating | |
109 | func afterUpdateCallback(scope *Scope) { | |
78 | 110 | if _, ok := scope.Get("gorm:update_column"); !ok { |
79 | scope.CallMethodWithErrorCheck("AfterUpdate") | |
80 | scope.CallMethodWithErrorCheck("AfterSave") | |
111 | if !scope.HasError() { | |
112 | scope.CallMethod("AfterUpdate") | |
113 | } | |
114 | if !scope.HasError() { | |
115 | scope.CallMethod("AfterSave") | |
116 | } | |
81 | 117 | } |
82 | 118 | } |
83 | ||
84 | func init() { | |
85 | DefaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes) | |
86 | DefaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction) | |
87 | DefaultCallback.Update().Register("gorm:before_update", BeforeUpdate) | |
88 | DefaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations) | |
89 | DefaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate) | |
90 | DefaultCallback.Update().Register("gorm:update", Update) | |
91 | DefaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations) | |
92 | DefaultCallback.Update().Register("gorm:after_update", AfterUpdate) | |
93 | DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | |
94 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | "time" | |
6 | ) | |
7 | ||
8 | type commonDialect struct{} | |
9 | ||
10 | func (commonDialect) BinVar(i int) string { | |
11 | return "$$" // ? | |
12 | } | |
13 | ||
14 | func (commonDialect) SupportLastInsertId() bool { | |
15 | return true | |
16 | } | |
17 | ||
18 | func (commonDialect) HasTop() bool { | |
19 | return false | |
20 | } | |
21 | ||
22 | func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string { | |
23 | switch value.Kind() { | |
24 | case reflect.Bool: | |
25 | return "BOOLEAN" | |
26 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | |
27 | if autoIncrease { | |
28 | return "INTEGER AUTO_INCREMENT" | |
29 | } | |
30 | return "INTEGER" | |
31 | case reflect.Int64, reflect.Uint64: | |
32 | if autoIncrease { | |
33 | return "BIGINT AUTO_INCREMENT" | |
34 | } | |
35 | return "BIGINT" | |
36 | case reflect.Float32, reflect.Float64: | |
37 | return "FLOAT" | |
38 | case reflect.String: | |
39 | if size > 0 && size < 65532 { | |
40 | return fmt.Sprintf("VARCHAR(%d)", size) | |
41 | } | |
42 | return "VARCHAR(65532)" | |
43 | case reflect.Struct: | |
44 | if _, ok := value.Interface().(time.Time); ok { | |
45 | return "TIMESTAMP" | |
46 | } | |
47 | default: | |
48 | if _, ok := value.Interface().([]byte); ok { | |
49 | if size > 0 && size < 65532 { | |
50 | return fmt.Sprintf("BINARY(%d)", size) | |
51 | } | |
52 | return "BINARY(65532)" | |
53 | } | |
54 | } | |
55 | panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String())) | |
56 | } | |
57 | ||
58 | func (commonDialect) ReturningStr(tableName, key string) string { | |
59 | return "" | |
60 | } | |
61 | ||
62 | func (commonDialect) SelectFromDummyTable() string { | |
63 | return "" | |
64 | } | |
65 | ||
66 | func (commonDialect) Quote(key string) string { | |
67 | return fmt.Sprintf(`"%s"`, key) | |
68 | } | |
69 | ||
70 | func (c commonDialect) HasTable(scope *Scope, tableName string) bool { | |
71 | var ( | |
72 | count int | |
73 | databaseName = c.CurrentDatabase(scope) | |
74 | ) | |
75 | c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", databaseName, tableName) | |
76 | return count > 0 | |
77 | } | |
78 | ||
79 | func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { | |
80 | var ( | |
81 | count int | |
82 | databaseName = c.CurrentDatabase(scope) | |
83 | ) | |
84 | c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName) | |
85 | return count > 0 | |
86 | } | |
87 | ||
88 | func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { | |
89 | var ( | |
90 | count int | |
91 | databaseName = c.CurrentDatabase(scope) | |
92 | ) | |
93 | c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName) | |
94 | return count > 0 | |
95 | } | |
96 | ||
97 | func (commonDialect) RemoveIndex(scope *Scope, indexName string) { | |
98 | scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error) | |
99 | } | |
100 | ||
101 | // RawScanInt scans the first column of the first row into the `scan' int pointer. | |
102 | // This function captures raw query errors and propagates them to the original scope. | |
103 | func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) { | |
104 | scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr)) | |
105 | } | |
106 | ||
107 | // RawScanString scans the first column of the first row into the `scan' string pointer. | |
108 | // This function captures raw query errors and propagates them to the original scope. | |
109 | func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string, args ...interface{}) { | |
110 | scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr)) | |
111 | } | |
112 | ||
113 | func (commonDialect) CurrentDatabase(scope *Scope) (name string) { | |
114 | scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name)) | |
115 | return | |
116 | } |
0 | 0 | package gorm_test |
1 | 1 | |
2 | 2 | import ( |
3 | "os" | |
3 | 4 | "reflect" |
4 | 5 | "testing" |
5 | 6 | "time" |
7 | ||
8 | "github.com/jinzhu/now" | |
6 | 9 | ) |
7 | 10 | |
8 | 11 | func TestCreate(t *testing.T) { |
9 | 12 | float := 35.03554004971999 |
10 | user := User{Name: "CreateUser", Age: 18, Birthday: time.Now(), UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float} | |
13 | now := time.Now() | |
14 | user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float} | |
11 | 15 | |
12 | 16 | if !DB.NewRecord(user) || !DB.NewRecord(&user) { |
13 | 17 | t.Error("User should be new record before create") |
22 | 26 | } |
23 | 27 | |
24 | 28 | var newUser User |
25 | DB.First(&newUser, user.Id) | |
29 | if err := DB.First(&newUser, user.Id).Error; err != nil { | |
30 | t.Errorf("No error should happen, but got %v", err) | |
31 | } | |
26 | 32 | |
27 | 33 | if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) { |
28 | 34 | t.Errorf("User's PasswordHash should be saved ([]byte)") |
33 | 39 | } |
34 | 40 | |
35 | 41 | if newUser.UserNum != Num(111) { |
36 | t.Errorf("User's UserNum should be saved (custom type)") | |
42 | t.Errorf("User's UserNum should be saved (custom type), but got %v", newUser.UserNum) | |
37 | 43 | } |
38 | 44 | |
39 | 45 | if newUser.Latitude != float { |
50 | 56 | |
51 | 57 | DB.Model(user).Update("name", "create_user_new_name") |
52 | 58 | DB.First(&user, user.Id) |
53 | if user.CreatedAt != newUser.CreatedAt { | |
59 | if user.CreatedAt.Format(time.RFC3339Nano) != newUser.CreatedAt.Format(time.RFC3339Nano) { | |
54 | 60 | t.Errorf("CreatedAt should not be changed after update") |
55 | 61 | } |
56 | 62 | } |
57 | 63 | |
64 | func TestCreateEmptyStrut(t *testing.T) { | |
65 | type EmptyStruct struct { | |
66 | ID uint | |
67 | } | |
68 | DB.AutoMigrate(&EmptyStruct{}) | |
69 | ||
70 | if err := DB.Create(&EmptyStruct{}).Error; err != nil { | |
71 | t.Errorf("No error should happen when creating user, but got %v", err) | |
72 | } | |
73 | } | |
74 | ||
75 | func TestCreateWithExistingTimestamp(t *testing.T) { | |
76 | user := User{Name: "CreateUserExistingTimestamp"} | |
77 | ||
78 | timeA := now.MustParse("2016-01-01") | |
79 | user.CreatedAt = timeA | |
80 | user.UpdatedAt = timeA | |
81 | DB.Save(&user) | |
82 | ||
83 | if user.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { | |
84 | t.Errorf("CreatedAt should not be changed") | |
85 | } | |
86 | ||
87 | if user.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { | |
88 | t.Errorf("UpdatedAt should not be changed") | |
89 | } | |
90 | ||
91 | var newUser User | |
92 | DB.First(&newUser, user.Id) | |
93 | ||
94 | if newUser.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { | |
95 | t.Errorf("CreatedAt should not be changed") | |
96 | } | |
97 | ||
98 | if newUser.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { | |
99 | t.Errorf("UpdatedAt should not be changed") | |
100 | } | |
101 | } | |
102 | ||
103 | type AutoIncrementUser struct { | |
104 | User | |
105 | Sequence uint `gorm:"AUTO_INCREMENT"` | |
106 | } | |
107 | ||
108 | func TestCreateWithAutoIncrement(t *testing.T) { | |
109 | if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { | |
110 | t.Skip("Skipping this because only postgres properly support auto_increment on a non-primary_key column") | |
111 | } | |
112 | ||
113 | DB.AutoMigrate(&AutoIncrementUser{}) | |
114 | ||
115 | user1 := AutoIncrementUser{} | |
116 | user2 := AutoIncrementUser{} | |
117 | ||
118 | DB.Create(&user1) | |
119 | DB.Create(&user2) | |
120 | ||
121 | if user2.Sequence-user1.Sequence != 1 { | |
122 | t.Errorf("Auto increment should apply on Sequence") | |
123 | } | |
124 | } | |
125 | ||
58 | 126 | func TestCreateWithNoGORMPrimayKey(t *testing.T) { |
127 | if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" { | |
128 | t.Skip("Skipping this because MSSQL will return identity only if the table has an Id column") | |
129 | } | |
130 | ||
59 | 131 | jt := JoinTable{From: 1, To: 2} |
60 | 132 | err := DB.Create(&jt).Error |
61 | 133 | if err != nil { |
86 | 158 | |
87 | 159 | // We must fetch the value again, to have the default fields updated |
88 | 160 | // (We can't do this in the update statements, since sql default can be expressions |
89 | // And be different from the fields' type (eg. a time.Time fiels has a default value of "now()" | |
161 | // And be different from the fields' type (eg. a time.Time fields has a default value of "now()" | |
90 | 162 | DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an) |
91 | 163 | |
92 | 164 | if an.Name != "galeone" { |
104 | 176 | t.Errorf("Should be able to get anonymous scanner") |
105 | 177 | } |
106 | 178 | |
107 | if !user2.IsAdmin() { | |
179 | if !user2.Role.IsAdmin() { | |
108 | 180 | t.Errorf("Should be able to get anonymous scanner") |
109 | 181 | } |
110 | 182 | } |
153 | 225 | |
154 | 226 | if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 || |
155 | 227 | queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 { |
156 | t.Errorf("Should not create omited relationships") | |
157 | } | |
158 | } | |
228 | t.Errorf("Should not create omitted relationships") | |
229 | } | |
230 | } |
2 | 2 | import ( |
3 | 3 | "testing" |
4 | 4 | "time" |
5 | ||
6 | "github.com/jinzhu/gorm" | |
5 | 7 | ) |
6 | 8 | |
7 | 9 | type CustomizeColumn struct { |
8 | ID int64 `gorm:"column:mapped_id; primary_key:yes"` | |
9 | Name string `gorm:"column:mapped_name"` | |
10 | Date time.Time `gorm:"column:mapped_time"` | |
10 | ID int64 `gorm:"column:mapped_id; primary_key:yes"` | |
11 | Name string `gorm:"column:mapped_name"` | |
12 | Date *time.Time `gorm:"column:mapped_time"` | |
11 | 13 | } |
12 | 14 | |
13 | 15 | // Make sure an ignored field does not interfere with another field's custom |
23 | 25 | DB.AutoMigrate(&CustomizeColumn{}) |
24 | 26 | |
25 | 27 | scope := DB.NewScope(&CustomizeColumn{}) |
26 | if !scope.Dialect().HasColumn(scope, scope.TableName(), col) { | |
28 | if !scope.Dialect().HasColumn(scope.TableName(), col) { | |
27 | 29 | t.Errorf("CustomizeColumn should have column %s", col) |
28 | 30 | } |
29 | 31 | |
33 | 35 | } |
34 | 36 | |
35 | 37 | expected := "foo" |
36 | cc := CustomizeColumn{ID: 666, Name: expected, Date: time.Now()} | |
38 | now := time.Now() | |
39 | cc := CustomizeColumn{ID: 666, Name: expected, Date: &now} | |
37 | 40 | |
38 | 41 | if count := DB.Create(&cc).RowsAffected; count != 1 { |
39 | 42 | t.Error("There should be one record be affected when create record") |
62 | 65 | t.Errorf("Should not raise error: %s", err) |
63 | 66 | } |
64 | 67 | } |
68 | ||
69 | type CustomizePerson struct { | |
70 | IdPerson string `gorm:"column:idPerson;primary_key:true"` | |
71 | Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"` | |
72 | } | |
73 | ||
74 | type CustomizeAccount struct { | |
75 | IdAccount string `gorm:"column:idAccount;primary_key:true"` | |
76 | Name string | |
77 | } | |
78 | ||
79 | func TestManyToManyWithCustomizedColumn(t *testing.T) { | |
80 | DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount") | |
81 | DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{}) | |
82 | ||
83 | account := CustomizeAccount{IdAccount: "account", Name: "id1"} | |
84 | person := CustomizePerson{ | |
85 | IdPerson: "person", | |
86 | Accounts: []CustomizeAccount{account}, | |
87 | } | |
88 | ||
89 | if err := DB.Create(&account).Error; err != nil { | |
90 | t.Errorf("no error should happen, but got %v", err) | |
91 | } | |
92 | ||
93 | if err := DB.Create(&person).Error; err != nil { | |
94 | t.Errorf("no error should happen, but got %v", err) | |
95 | } | |
96 | ||
97 | var person1 CustomizePerson | |
98 | scope := DB.NewScope(nil) | |
99 | if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil { | |
100 | t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err) | |
101 | } | |
102 | ||
103 | if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" { | |
104 | t.Errorf("should preload correct accounts") | |
105 | } | |
106 | } | |
107 | ||
108 | type CustomizeUser struct { | |
109 | gorm.Model | |
110 | Email string `sql:"column:email_address"` | |
111 | } | |
112 | ||
113 | type CustomizeInvitation struct { | |
114 | gorm.Model | |
115 | Address string `sql:"column:invitation"` | |
116 | Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"` | |
117 | } | |
118 | ||
119 | func TestOneToOneWithCustomizedColumn(t *testing.T) { | |
120 | DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{}) | |
121 | DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{}) | |
122 | ||
123 | user := CustomizeUser{ | |
124 | Email: "hello@example.com", | |
125 | } | |
126 | invitation := CustomizeInvitation{ | |
127 | Address: "hello@example.com", | |
128 | } | |
129 | ||
130 | DB.Create(&user) | |
131 | DB.Create(&invitation) | |
132 | ||
133 | var invitation2 CustomizeInvitation | |
134 | if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil { | |
135 | t.Errorf("no error should happen, but got %v", err) | |
136 | } | |
137 | ||
138 | if invitation2.Person.Email != user.Email { | |
139 | t.Errorf("Should preload one to one relation with customize foreign keys") | |
140 | } | |
141 | } | |
142 | ||
143 | type PromotionDiscount struct { | |
144 | gorm.Model | |
145 | Name string | |
146 | Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"` | |
147 | Rule *PromotionRule `gorm:"ForeignKey:discount_id"` | |
148 | Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"` | |
149 | } | |
150 | ||
151 | type PromotionBenefit struct { | |
152 | gorm.Model | |
153 | Name string | |
154 | PromotionID uint | |
155 | Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"` | |
156 | } | |
157 | ||
158 | type PromotionCoupon struct { | |
159 | gorm.Model | |
160 | Code string | |
161 | DiscountID uint | |
162 | Discount PromotionDiscount | |
163 | } | |
164 | ||
165 | type PromotionRule struct { | |
166 | gorm.Model | |
167 | Name string | |
168 | Begin *time.Time | |
169 | End *time.Time | |
170 | DiscountID uint | |
171 | Discount *PromotionDiscount | |
172 | } | |
173 | ||
174 | func TestOneToManyWithCustomizedColumn(t *testing.T) { | |
175 | DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{}) | |
176 | DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{}) | |
177 | ||
178 | discount := PromotionDiscount{ | |
179 | Name: "Happy New Year", | |
180 | Coupons: []*PromotionCoupon{ | |
181 | {Code: "newyear1"}, | |
182 | {Code: "newyear2"}, | |
183 | }, | |
184 | } | |
185 | ||
186 | if err := DB.Create(&discount).Error; err != nil { | |
187 | t.Errorf("no error should happen but got %v", err) | |
188 | } | |
189 | ||
190 | var discount1 PromotionDiscount | |
191 | if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil { | |
192 | t.Errorf("no error should happen but got %v", err) | |
193 | } | |
194 | ||
195 | if len(discount.Coupons) != 2 { | |
196 | t.Errorf("should find two coupons") | |
197 | } | |
198 | ||
199 | var coupon PromotionCoupon | |
200 | if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil { | |
201 | t.Errorf("no error should happen but got %v", err) | |
202 | } | |
203 | ||
204 | if coupon.Discount.Name != "Happy New Year" { | |
205 | t.Errorf("should preload discount from coupon") | |
206 | } | |
207 | } | |
208 | ||
209 | func TestHasOneWithPartialCustomizedColumn(t *testing.T) { | |
210 | DB.DropTable(&PromotionDiscount{}, &PromotionRule{}) | |
211 | DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{}) | |
212 | ||
213 | var begin = time.Now() | |
214 | var end = time.Now().Add(24 * time.Hour) | |
215 | discount := PromotionDiscount{ | |
216 | Name: "Happy New Year 2", | |
217 | Rule: &PromotionRule{ | |
218 | Name: "time_limited", | |
219 | Begin: &begin, | |
220 | End: &end, | |
221 | }, | |
222 | } | |
223 | ||
224 | if err := DB.Create(&discount).Error; err != nil { | |
225 | t.Errorf("no error should happen but got %v", err) | |
226 | } | |
227 | ||
228 | var discount1 PromotionDiscount | |
229 | if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil { | |
230 | t.Errorf("no error should happen but got %v", err) | |
231 | } | |
232 | ||
233 | if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) { | |
234 | t.Errorf("Should be able to preload Rule") | |
235 | } | |
236 | ||
237 | var rule PromotionRule | |
238 | if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil { | |
239 | t.Errorf("no error should happen but got %v", err) | |
240 | } | |
241 | ||
242 | if rule.Discount.Name != "Happy New Year 2" { | |
243 | t.Errorf("should preload discount from rule") | |
244 | } | |
245 | } | |
246 | ||
247 | func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { | |
248 | DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{}) | |
249 | DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{}) | |
250 | ||
251 | discount := PromotionDiscount{ | |
252 | Name: "Happy New Year 3", | |
253 | Benefits: []PromotionBenefit{ | |
254 | {Name: "free cod"}, | |
255 | {Name: "free shipping"}, | |
256 | }, | |
257 | } | |
258 | ||
259 | if err := DB.Create(&discount).Error; err != nil { | |
260 | t.Errorf("no error should happen but got %v", err) | |
261 | } | |
262 | ||
263 | var discount1 PromotionDiscount | |
264 | if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil { | |
265 | t.Errorf("no error should happen but got %v", err) | |
266 | } | |
267 | ||
268 | if len(discount.Benefits) != 2 { | |
269 | t.Errorf("should find two benefits") | |
270 | } | |
271 | ||
272 | var benefit PromotionBenefit | |
273 | if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil { | |
274 | t.Errorf("no error should happen but got %v", err) | |
275 | } | |
276 | ||
277 | if benefit.Discount.Name != "Happy New Year 3" { | |
278 | t.Errorf("should preload discount from coupon") | |
279 | } | |
280 | } | |
281 | ||
282 | type SelfReferencingUser struct { | |
283 | gorm.Model | |
284 | Name string | |
285 | Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"` | |
286 | } | |
287 | ||
288 | func TestSelfReferencingMany2ManyColumn(t *testing.T) { | |
289 | DB.DropTable(&SelfReferencingUser{}, "UserFriends") | |
290 | DB.AutoMigrate(&SelfReferencingUser{}) | |
291 | ||
292 | friend1 := SelfReferencingUser{Name: "friend1_m2m"} | |
293 | if err := DB.Create(&friend1).Error; err != nil { | |
294 | t.Errorf("no error should happen, but got %v", err) | |
295 | } | |
296 | ||
297 | friend2 := SelfReferencingUser{Name: "friend2_m2m"} | |
298 | if err := DB.Create(&friend2).Error; err != nil { | |
299 | t.Errorf("no error should happen, but got %v", err) | |
300 | } | |
301 | ||
302 | user := SelfReferencingUser{ | |
303 | Name: "self_m2m", | |
304 | Friends: []*SelfReferencingUser{&friend1, &friend2}, | |
305 | } | |
306 | ||
307 | if err := DB.Create(&user).Error; err != nil { | |
308 | t.Errorf("no error should happen, but got %v", err) | |
309 | } | |
310 | ||
311 | if DB.Model(&user).Association("Friends").Count() != 2 { | |
312 | t.Errorf("Should find created friends correctly") | |
313 | } | |
314 | ||
315 | var newUser = SelfReferencingUser{} | |
316 | ||
317 | if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil { | |
318 | t.Errorf("no error should happen, but got %v", err) | |
319 | } | |
320 | ||
321 | if len(newUser.Friends) != 2 { | |
322 | t.Errorf("Should preload created frineds for self reference m2m") | |
323 | } | |
324 | ||
325 | DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"}) | |
326 | if DB.Model(&user).Association("Friends").Count() != 3 { | |
327 | t.Errorf("Should find created friends correctly") | |
328 | } | |
329 | ||
330 | DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"}) | |
331 | if DB.Model(&user).Association("Friends").Count() != 1 { | |
332 | t.Errorf("Should find created friends correctly") | |
333 | } | |
334 | ||
335 | friend := SelfReferencingUser{} | |
336 | DB.Model(&newUser).Association("Friends").Find(&friend) | |
337 | if friend.Name != "friend4_m2m" { | |
338 | t.Errorf("Should find created friends correctly") | |
339 | } | |
340 | ||
341 | DB.Model(&newUser).Association("Friends").Delete(friend) | |
342 | if DB.Model(&user).Association("Friends").Count() != 0 { | |
343 | t.Errorf("All friends should be deleted") | |
344 | } | |
345 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "testing" | |
4 | ) | |
5 | ||
6 | func TestDdlErrors(t *testing.T) { | |
7 | var err error | |
8 | ||
9 | if err = DB.Close(); err != nil { | |
10 | t.Errorf("Closing DDL test db connection err=%s", err) | |
11 | } | |
12 | defer func() { | |
13 | // Reopen DB connection. | |
14 | if DB, err = OpenTestConnection(); err != nil { | |
15 | t.Fatalf("Failed re-opening db connection: %s", err) | |
16 | } | |
17 | }() | |
18 | ||
19 | DB.HasTable("foobarbaz") | |
20 | if DB.Error == nil { | |
21 | t.Errorf("Expected operation on closed db to produce an error, but err was nil") | |
22 | } | |
23 | } |
44 | 44 | type User struct { |
45 | 45 | Id int64 |
46 | 46 | Name string |
47 | DeletedAt time.Time | |
47 | DeletedAt *time.Time | |
48 | 48 | } |
49 | 49 | DB.AutoMigrate(&User{}) |
50 | 50 | |
65 | 65 | t.Errorf("Can't find permanently deleted record") |
66 | 66 | } |
67 | 67 | } |
68 | ||
69 | func TestSoftDeleteWithCustomizedDeletedAtColumnName(t *testing.T) { | |
70 | creditCard := CreditCard{Number: "411111111234567"} | |
71 | DB.Save(&creditCard) | |
72 | DB.Delete(&creditCard) | |
73 | ||
74 | if deletedAtField, ok := DB.NewScope(&CreditCard{}).FieldByName("DeletedAt"); !ok || deletedAtField.DBName != "deleted_time" { | |
75 | t.Errorf("CreditCard's DeletedAt's column name should be `deleted_time`") | |
76 | } | |
77 | ||
78 | if DB.First(&CreditCard{}, "number = ?", creditCard.Number).Error == nil { | |
79 | t.Errorf("Can't find a soft deleted record") | |
80 | } | |
81 | ||
82 | if err := DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).Error; err != nil { | |
83 | t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err) | |
84 | } | |
85 | ||
86 | DB.Unscoped().Delete(&creditCard) | |
87 | if !DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).RecordNotFound() { | |
88 | t.Errorf("Can't find permanently deleted record") | |
89 | } | |
90 | } |
0 | 0 | package gorm |
1 | 1 | |
2 | 2 | import ( |
3 | "database/sql" | |
3 | 4 | "fmt" |
4 | 5 | "reflect" |
6 | "strconv" | |
7 | "strings" | |
5 | 8 | ) |
6 | 9 | |
10 | // Dialect interface contains behaviors that differ across SQL database | |
7 | 11 | type Dialect interface { |
8 | BinVar(i int) string | |
9 | SupportLastInsertId() bool | |
10 | HasTop() bool | |
11 | SqlTag(value reflect.Value, size int, autoIncrease bool) string | |
12 | ReturningStr(tableName, key string) string | |
12 | // GetName get dialect's name | |
13 | GetName() string | |
14 | ||
15 | // SetDB set db for dialect | |
16 | SetDB(db SQLCommon) | |
17 | ||
18 | // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 | |
19 | BindVar(i int) string | |
20 | // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name | |
21 | Quote(key string) string | |
22 | // DataTypeOf return data's sql type | |
23 | DataTypeOf(field *StructField) string | |
24 | ||
25 | // HasIndex check has index or not | |
26 | HasIndex(tableName string, indexName string) bool | |
27 | // HasForeignKey check has foreign key or not | |
28 | HasForeignKey(tableName string, foreignKeyName string) bool | |
29 | // RemoveIndex remove index | |
30 | RemoveIndex(tableName string, indexName string) error | |
31 | // HasTable check has table or not | |
32 | HasTable(tableName string) bool | |
33 | // HasColumn check has column or not | |
34 | HasColumn(tableName string, columnName string) bool | |
35 | // ModifyColumn modify column's type | |
36 | ModifyColumn(tableName string, columnName string, typ string) error | |
37 | ||
38 | // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case | |
39 | LimitAndOffsetSQL(limit, offset interface{}) string | |
40 | // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` | |
13 | 41 | SelectFromDummyTable() string |
14 | Quote(key string) string | |
15 | HasTable(scope *Scope, tableName string) bool | |
16 | HasColumn(scope *Scope, tableName string, columnName string) bool | |
17 | HasIndex(scope *Scope, tableName string, indexName string) bool | |
18 | RemoveIndex(scope *Scope, indexName string) | |
19 | CurrentDatabase(scope *Scope) string | |
42 | // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` | |
43 | LastInsertIDReturningSuffix(tableName, columnName string) string | |
44 | // DefaultValueStr | |
45 | DefaultValueStr() string | |
46 | ||
47 | // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference | |
48 | BuildKeyName(kind, tableName string, fields ...string) string | |
49 | ||
50 | // CurrentDatabase return current database name | |
51 | CurrentDatabase() string | |
20 | 52 | } |
21 | 53 | |
22 | func NewDialect(driver string) Dialect { | |
23 | var d Dialect | |
24 | switch driver { | |
25 | case "postgres": | |
26 | d = &postgres{} | |
27 | case "foundation": | |
28 | d = &foundation{} | |
29 | case "mysql": | |
30 | d = &mysql{} | |
31 | case "sqlite3": | |
32 | d = &sqlite3{} | |
33 | case "mssql": | |
34 | d = &mssql{} | |
35 | default: | |
36 | fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver) | |
37 | d = &commonDialect{} | |
54 | var dialectsMap = map[string]Dialect{} | |
55 | ||
56 | func newDialect(name string, db SQLCommon) Dialect { | |
57 | if value, ok := dialectsMap[name]; ok { | |
58 | dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) | |
59 | dialect.SetDB(db) | |
60 | return dialect | |
38 | 61 | } |
39 | return d | |
62 | ||
63 | fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) | |
64 | commontDialect := &commonDialect{} | |
65 | commontDialect.SetDB(db) | |
66 | return commontDialect | |
40 | 67 | } |
68 | ||
69 | // RegisterDialect register new dialect | |
70 | func RegisterDialect(name string, dialect Dialect) { | |
71 | dialectsMap[name] = dialect | |
72 | } | |
73 | ||
74 | // ParseFieldStructForDialect get field's sql data type | |
75 | var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { | |
76 | // Get redirected field type | |
77 | var ( | |
78 | reflectType = field.Struct.Type | |
79 | dataType = field.TagSettings["TYPE"] | |
80 | ) | |
81 | ||
82 | for reflectType.Kind() == reflect.Ptr { | |
83 | reflectType = reflectType.Elem() | |
84 | } | |
85 | ||
86 | // Get redirected field value | |
87 | fieldValue = reflect.Indirect(reflect.New(reflectType)) | |
88 | ||
89 | if gormDataType, ok := fieldValue.Interface().(interface { | |
90 | GormDataType(Dialect) string | |
91 | }); ok { | |
92 | dataType = gormDataType.GormDataType(dialect) | |
93 | } | |
94 | ||
95 | // Get scanner's real value | |
96 | if dataType == "" { | |
97 | var getScannerValue func(reflect.Value) | |
98 | getScannerValue = func(value reflect.Value) { | |
99 | fieldValue = value | |
100 | if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { | |
101 | getScannerValue(fieldValue.Field(0)) | |
102 | } | |
103 | } | |
104 | getScannerValue(fieldValue) | |
105 | } | |
106 | ||
107 | // Default Size | |
108 | if num, ok := field.TagSettings["SIZE"]; ok { | |
109 | size, _ = strconv.Atoi(num) | |
110 | } else { | |
111 | size = 255 | |
112 | } | |
113 | ||
114 | // Default type from tag setting | |
115 | additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] | |
116 | if value, ok := field.TagSettings["DEFAULT"]; ok { | |
117 | additionalType = additionalType + " DEFAULT " + value | |
118 | } | |
119 | ||
120 | return fieldValue, dataType, size, strings.TrimSpace(additionalType) | |
121 | } | |
122 | ||
123 | func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { | |
124 | if strings.Contains(tableName, ".") { | |
125 | splitStrings := strings.SplitN(tableName, ".", 2) | |
126 | return splitStrings[0], splitStrings[1] | |
127 | } | |
128 | return dialect.CurrentDatabase(), tableName | |
129 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | "regexp" | |
6 | "strconv" | |
7 | "strings" | |
8 | "time" | |
9 | ) | |
10 | ||
11 | // DefaultForeignKeyNamer contains the default foreign key name generator method | |
12 | type DefaultForeignKeyNamer struct { | |
13 | } | |
14 | ||
15 | type commonDialect struct { | |
16 | db SQLCommon | |
17 | DefaultForeignKeyNamer | |
18 | } | |
19 | ||
20 | func init() { | |
21 | RegisterDialect("common", &commonDialect{}) | |
22 | } | |
23 | ||
24 | func (commonDialect) GetName() string { | |
25 | return "common" | |
26 | } | |
27 | ||
28 | func (s *commonDialect) SetDB(db SQLCommon) { | |
29 | s.db = db | |
30 | } | |
31 | ||
32 | func (commonDialect) BindVar(i int) string { | |
33 | return "$$$" // ? | |
34 | } | |
35 | ||
36 | func (commonDialect) Quote(key string) string { | |
37 | return fmt.Sprintf(`"%s"`, key) | |
38 | } | |
39 | ||
40 | func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { | |
41 | if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { | |
42 | return strings.ToLower(value) != "false" | |
43 | } | |
44 | return field.IsPrimaryKey | |
45 | } | |
46 | ||
47 | func (s *commonDialect) DataTypeOf(field *StructField) string { | |
48 | var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) | |
49 | ||
50 | if sqlType == "" { | |
51 | switch dataValue.Kind() { | |
52 | case reflect.Bool: | |
53 | sqlType = "BOOLEAN" | |
54 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | |
55 | if s.fieldCanAutoIncrement(field) { | |
56 | sqlType = "INTEGER AUTO_INCREMENT" | |
57 | } else { | |
58 | sqlType = "INTEGER" | |
59 | } | |
60 | case reflect.Int64, reflect.Uint64: | |
61 | if s.fieldCanAutoIncrement(field) { | |
62 | sqlType = "BIGINT AUTO_INCREMENT" | |
63 | } else { | |
64 | sqlType = "BIGINT" | |
65 | } | |
66 | case reflect.Float32, reflect.Float64: | |
67 | sqlType = "FLOAT" | |
68 | case reflect.String: | |
69 | if size > 0 && size < 65532 { | |
70 | sqlType = fmt.Sprintf("VARCHAR(%d)", size) | |
71 | } else { | |
72 | sqlType = "VARCHAR(65532)" | |
73 | } | |
74 | case reflect.Struct: | |
75 | if _, ok := dataValue.Interface().(time.Time); ok { | |
76 | sqlType = "TIMESTAMP" | |
77 | } | |
78 | default: | |
79 | if _, ok := dataValue.Interface().([]byte); ok { | |
80 | if size > 0 && size < 65532 { | |
81 | sqlType = fmt.Sprintf("BINARY(%d)", size) | |
82 | } else { | |
83 | sqlType = "BINARY(65532)" | |
84 | } | |
85 | } | |
86 | } | |
87 | } | |
88 | ||
89 | if sqlType == "" { | |
90 | panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String())) | |
91 | } | |
92 | ||
93 | if strings.TrimSpace(additionalType) == "" { | |
94 | return sqlType | |
95 | } | |
96 | return fmt.Sprintf("%v %v", sqlType, additionalType) | |
97 | } | |
98 | ||
99 | func (s commonDialect) HasIndex(tableName string, indexName string) bool { | |
100 | var count int | |
101 | currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) | |
102 | s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) | |
103 | return count > 0 | |
104 | } | |
105 | ||
106 | func (s commonDialect) RemoveIndex(tableName string, indexName string) error { | |
107 | _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) | |
108 | return err | |
109 | } | |
110 | ||
111 | func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { | |
112 | return false | |
113 | } | |
114 | ||
115 | func (s commonDialect) HasTable(tableName string) bool { | |
116 | var count int | |
117 | currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) | |
118 | s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) | |
119 | return count > 0 | |
120 | } | |
121 | ||
122 | func (s commonDialect) HasColumn(tableName string, columnName string) bool { | |
123 | var count int | |
124 | currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) | |
125 | s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) | |
126 | return count > 0 | |
127 | } | |
128 | ||
129 | func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { | |
130 | _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) | |
131 | return err | |
132 | } | |
133 | ||
134 | func (s commonDialect) CurrentDatabase() (name string) { | |
135 | s.db.QueryRow("SELECT DATABASE()").Scan(&name) | |
136 | return | |
137 | } | |
138 | ||
139 | func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { | |
140 | if limit != nil { | |
141 | if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { | |
142 | sql += fmt.Sprintf(" LIMIT %d", parsedLimit) | |
143 | } | |
144 | } | |
145 | if offset != nil { | |
146 | if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { | |
147 | sql += fmt.Sprintf(" OFFSET %d", parsedOffset) | |
148 | } | |
149 | } | |
150 | return | |
151 | } | |
152 | ||
153 | func (commonDialect) SelectFromDummyTable() string { | |
154 | return "" | |
155 | } | |
156 | ||
157 | func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { | |
158 | return "" | |
159 | } | |
160 | ||
161 | func (commonDialect) DefaultValueStr() string { | |
162 | return "DEFAULT VALUES" | |
163 | } | |
164 | ||
165 | // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference | |
166 | func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { | |
167 | keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) | |
168 | keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_") | |
169 | return keyName | |
170 | } | |
171 | ||
172 | // IsByteArrayOrSlice returns true of the reflected value is an array or slice | |
173 | func IsByteArrayOrSlice(value reflect.Value) bool { | |
174 | return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) | |
175 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "crypto/sha1" | |
4 | "fmt" | |
5 | "reflect" | |
6 | "regexp" | |
7 | "strconv" | |
8 | "strings" | |
9 | "time" | |
10 | "unicode/utf8" | |
11 | ) | |
12 | ||
13 | type mysql struct { | |
14 | commonDialect | |
15 | } | |
16 | ||
17 | func init() { | |
18 | RegisterDialect("mysql", &mysql{}) | |
19 | } | |
20 | ||
21 | func (mysql) GetName() string { | |
22 | return "mysql" | |
23 | } | |
24 | ||
25 | func (mysql) Quote(key string) string { | |
26 | return fmt.Sprintf("`%s`", key) | |
27 | } | |
28 | ||
29 | // Get Data Type for MySQL Dialect | |
30 | func (s *mysql) DataTypeOf(field *StructField) string { | |
31 | var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) | |
32 | ||
33 | // MySQL allows only one auto increment column per table, and it must | |
34 | // be a KEY column. | |
35 | if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { | |
36 | if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey { | |
37 | delete(field.TagSettings, "AUTO_INCREMENT") | |
38 | } | |
39 | } | |
40 | ||
41 | if sqlType == "" { | |
42 | switch dataValue.Kind() { | |
43 | case reflect.Bool: | |
44 | sqlType = "boolean" | |
45 | case reflect.Int8: | |
46 | if s.fieldCanAutoIncrement(field) { | |
47 | field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" | |
48 | sqlType = "tinyint AUTO_INCREMENT" | |
49 | } else { | |
50 | sqlType = "tinyint" | |
51 | } | |
52 | case reflect.Int, reflect.Int16, reflect.Int32: | |
53 | if s.fieldCanAutoIncrement(field) { | |
54 | field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" | |
55 | sqlType = "int AUTO_INCREMENT" | |
56 | } else { | |
57 | sqlType = "int" | |
58 | } | |
59 | case reflect.Uint8: | |
60 | if s.fieldCanAutoIncrement(field) { | |
61 | field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" | |
62 | sqlType = "tinyint unsigned AUTO_INCREMENT" | |
63 | } else { | |
64 | sqlType = "tinyint unsigned" | |
65 | } | |
66 | case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | |
67 | if s.fieldCanAutoIncrement(field) { | |
68 | field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" | |
69 | sqlType = "int unsigned AUTO_INCREMENT" | |
70 | } else { | |
71 | sqlType = "int unsigned" | |
72 | } | |
73 | case reflect.Int64: | |
74 | if s.fieldCanAutoIncrement(field) { | |
75 | field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" | |
76 | sqlType = "bigint AUTO_INCREMENT" | |
77 | } else { | |
78 | sqlType = "bigint" | |
79 | } | |
80 | case reflect.Uint64: | |
81 | if s.fieldCanAutoIncrement(field) { | |
82 | field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" | |
83 | sqlType = "bigint unsigned AUTO_INCREMENT" | |
84 | } else { | |
85 | sqlType = "bigint unsigned" | |
86 | } | |
87 | case reflect.Float32, reflect.Float64: | |
88 | sqlType = "double" | |
89 | case reflect.String: | |
90 | if size > 0 && size < 65532 { | |
91 | sqlType = fmt.Sprintf("varchar(%d)", size) | |
92 | } else { | |
93 | sqlType = "longtext" | |
94 | } | |
95 | case reflect.Struct: | |
96 | if _, ok := dataValue.Interface().(time.Time); ok { | |
97 | precision := "" | |
98 | if p, ok := field.TagSettings["PRECISION"]; ok { | |
99 | precision = fmt.Sprintf("(%s)", p) | |
100 | } | |
101 | ||
102 | if _, ok := field.TagSettings["NOT NULL"]; ok { | |
103 | sqlType = fmt.Sprintf("timestamp%v", precision) | |
104 | } else { | |
105 | sqlType = fmt.Sprintf("timestamp%v NULL", precision) | |
106 | } | |
107 | } | |
108 | default: | |
109 | if IsByteArrayOrSlice(dataValue) { | |
110 | if size > 0 && size < 65532 { | |
111 | sqlType = fmt.Sprintf("varbinary(%d)", size) | |
112 | } else { | |
113 | sqlType = "longblob" | |
114 | } | |
115 | } | |
116 | } | |
117 | } | |
118 | ||
119 | if sqlType == "" { | |
120 | panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String())) | |
121 | } | |
122 | ||
123 | if strings.TrimSpace(additionalType) == "" { | |
124 | return sqlType | |
125 | } | |
126 | return fmt.Sprintf("%v %v", sqlType, additionalType) | |
127 | } | |
128 | ||
129 | func (s mysql) RemoveIndex(tableName string, indexName string) error { | |
130 | _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) | |
131 | return err | |
132 | } | |
133 | ||
134 | func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error { | |
135 | _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) | |
136 | return err | |
137 | } | |
138 | ||
139 | func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { | |
140 | if limit != nil { | |
141 | if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { | |
142 | sql += fmt.Sprintf(" LIMIT %d", parsedLimit) | |
143 | ||
144 | if offset != nil { | |
145 | if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { | |
146 | sql += fmt.Sprintf(" OFFSET %d", parsedOffset) | |
147 | } | |
148 | } | |
149 | } | |
150 | } | |
151 | return | |
152 | } | |
153 | ||
154 | func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { | |
155 | var count int | |
156 | currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) | |
157 | s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) | |
158 | return count > 0 | |
159 | } | |
160 | ||
161 | func (s mysql) CurrentDatabase() (name string) { | |
162 | s.db.QueryRow("SELECT DATABASE()").Scan(&name) | |
163 | return | |
164 | } | |
165 | ||
166 | func (mysql) SelectFromDummyTable() string { | |
167 | return "FROM DUAL" | |
168 | } | |
169 | ||
170 | func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { | |
171 | keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...) | |
172 | if utf8.RuneCountInString(keyName) <= 64 { | |
173 | return keyName | |
174 | } | |
175 | h := sha1.New() | |
176 | h.Write([]byte(keyName)) | |
177 | bs := h.Sum(nil) | |
178 | ||
179 | // sha1 is 40 characters, keep first 24 characters of destination | |
180 | destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_")) | |
181 | if len(destRunes) > 24 { | |
182 | destRunes = destRunes[:24] | |
183 | } | |
184 | ||
185 | return fmt.Sprintf("%s%x", string(destRunes), bs) | |
186 | } | |
187 | ||
188 | func (mysql) DefaultValueStr() string { | |
189 | return "VALUES()" | |
190 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "encoding/json" | |
4 | "fmt" | |
5 | "reflect" | |
6 | "strings" | |
7 | "time" | |
8 | ) | |
9 | ||
10 | type postgres struct { | |
11 | commonDialect | |
12 | } | |
13 | ||
14 | func init() { | |
15 | RegisterDialect("postgres", &postgres{}) | |
16 | RegisterDialect("cloudsqlpostgres", &postgres{}) | |
17 | } | |
18 | ||
19 | func (postgres) GetName() string { | |
20 | return "postgres" | |
21 | } | |
22 | ||
23 | func (postgres) BindVar(i int) string { | |
24 | return fmt.Sprintf("$%v", i) | |
25 | } | |
26 | ||
27 | func (s *postgres) DataTypeOf(field *StructField) string { | |
28 | var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) | |
29 | ||
30 | if sqlType == "" { | |
31 | switch dataValue.Kind() { | |
32 | case reflect.Bool: | |
33 | sqlType = "boolean" | |
34 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: | |
35 | if s.fieldCanAutoIncrement(field) { | |
36 | field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" | |
37 | sqlType = "serial" | |
38 | } else { | |
39 | sqlType = "integer" | |
40 | } | |
41 | case reflect.Int64, reflect.Uint32, reflect.Uint64: | |
42 | if s.fieldCanAutoIncrement(field) { | |
43 | field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" | |
44 | sqlType = "bigserial" | |
45 | } else { | |
46 | sqlType = "bigint" | |
47 | } | |
48 | case reflect.Float32, reflect.Float64: | |
49 | sqlType = "numeric" | |
50 | case reflect.String: | |
51 | if _, ok := field.TagSettings["SIZE"]; !ok { | |
52 | size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different | |
53 | } | |
54 | ||
55 | if size > 0 && size < 65532 { | |
56 | sqlType = fmt.Sprintf("varchar(%d)", size) | |
57 | } else { | |
58 | sqlType = "text" | |
59 | } | |
60 | case reflect.Struct: | |
61 | if _, ok := dataValue.Interface().(time.Time); ok { | |
62 | sqlType = "timestamp with time zone" | |
63 | } | |
64 | case reflect.Map: | |
65 | if dataValue.Type().Name() == "Hstore" { | |
66 | sqlType = "hstore" | |
67 | } | |
68 | default: | |
69 | if IsByteArrayOrSlice(dataValue) { | |
70 | sqlType = "bytea" | |
71 | ||
72 | if isUUID(dataValue) { | |
73 | sqlType = "uuid" | |
74 | } | |
75 | ||
76 | if isJSON(dataValue) { | |
77 | sqlType = "jsonb" | |
78 | } | |
79 | } | |
80 | } | |
81 | } | |
82 | ||
83 | if sqlType == "" { | |
84 | panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String())) | |
85 | } | |
86 | ||
87 | if strings.TrimSpace(additionalType) == "" { | |
88 | return sqlType | |
89 | } | |
90 | return fmt.Sprintf("%v %v", sqlType, additionalType) | |
91 | } | |
92 | ||
93 | func (s postgres) HasIndex(tableName string, indexName string) bool { | |
94 | var count int | |
95 | s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) | |
96 | return count > 0 | |
97 | } | |
98 | ||
99 | func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { | |
100 | var count int | |
101 | s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) | |
102 | return count > 0 | |
103 | } | |
104 | ||
105 | func (s postgres) HasTable(tableName string) bool { | |
106 | var count int | |
107 | s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) | |
108 | return count > 0 | |
109 | } | |
110 | ||
111 | func (s postgres) HasColumn(tableName string, columnName string) bool { | |
112 | var count int | |
113 | s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) | |
114 | return count > 0 | |
115 | } | |
116 | ||
117 | func (s postgres) CurrentDatabase() (name string) { | |
118 | s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) | |
119 | return | |
120 | } | |
121 | ||
122 | func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { | |
123 | return fmt.Sprintf("RETURNING %v.%v", tableName, key) | |
124 | } | |
125 | ||
126 | func (postgres) SupportLastInsertID() bool { | |
127 | return false | |
128 | } | |
129 | ||
130 | func isUUID(value reflect.Value) bool { | |
131 | if value.Kind() != reflect.Array || value.Type().Len() != 16 { | |
132 | return false | |
133 | } | |
134 | typename := value.Type().Name() | |
135 | lower := strings.ToLower(typename) | |
136 | return "uuid" == lower || "guid" == lower | |
137 | } | |
138 | ||
139 | func isJSON(value reflect.Value) bool { | |
140 | _, ok := value.Interface().(json.RawMessage) | |
141 | return ok | |
142 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | "strings" | |
6 | "time" | |
7 | ) | |
8 | ||
9 | type sqlite3 struct { | |
10 | commonDialect | |
11 | } | |
12 | ||
13 | func init() { | |
14 | RegisterDialect("sqlite3", &sqlite3{}) | |
15 | } | |
16 | ||
17 | func (sqlite3) GetName() string { | |
18 | return "sqlite3" | |
19 | } | |
20 | ||
21 | // Get Data Type for Sqlite Dialect | |
22 | func (s *sqlite3) DataTypeOf(field *StructField) string { | |
23 | var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) | |
24 | ||
25 | if sqlType == "" { | |
26 | switch dataValue.Kind() { | |
27 | case reflect.Bool: | |
28 | sqlType = "bool" | |
29 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | |
30 | if s.fieldCanAutoIncrement(field) { | |
31 | field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" | |
32 | sqlType = "integer primary key autoincrement" | |
33 | } else { | |
34 | sqlType = "integer" | |
35 | } | |
36 | case reflect.Int64, reflect.Uint64: | |
37 | if s.fieldCanAutoIncrement(field) { | |
38 | field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" | |
39 | sqlType = "integer primary key autoincrement" | |
40 | } else { | |
41 | sqlType = "bigint" | |
42 | } | |
43 | case reflect.Float32, reflect.Float64: | |
44 | sqlType = "real" | |
45 | case reflect.String: | |
46 | if size > 0 && size < 65532 { | |
47 | sqlType = fmt.Sprintf("varchar(%d)", size) | |
48 | } else { | |
49 | sqlType = "text" | |
50 | } | |
51 | case reflect.Struct: | |
52 | if _, ok := dataValue.Interface().(time.Time); ok { | |
53 | sqlType = "datetime" | |
54 | } | |
55 | default: | |
56 | if IsByteArrayOrSlice(dataValue) { | |
57 | sqlType = "blob" | |
58 | } | |
59 | } | |
60 | } | |
61 | ||
62 | if sqlType == "" { | |
63 | panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String())) | |
64 | } | |
65 | ||
66 | if strings.TrimSpace(additionalType) == "" { | |
67 | return sqlType | |
68 | } | |
69 | return fmt.Sprintf("%v %v", sqlType, additionalType) | |
70 | } | |
71 | ||
72 | func (s sqlite3) HasIndex(tableName string, indexName string) bool { | |
73 | var count int | |
74 | s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) | |
75 | return count > 0 | |
76 | } | |
77 | ||
78 | func (s sqlite3) HasTable(tableName string) bool { | |
79 | var count int | |
80 | s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) | |
81 | return count > 0 | |
82 | } | |
83 | ||
84 | func (s sqlite3) HasColumn(tableName string, columnName string) bool { | |
85 | var count int | |
86 | s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) | |
87 | return count > 0 | |
88 | } | |
89 | ||
90 | func (s sqlite3) CurrentDatabase() (name string) { | |
91 | var ( | |
92 | ifaces = make([]interface{}, 3) | |
93 | pointers = make([]*string, 3) | |
94 | i int | |
95 | ) | |
96 | for i = 0; i < 3; i++ { | |
97 | ifaces[i] = &pointers[i] | |
98 | } | |
99 | if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { | |
100 | return | |
101 | } | |
102 | if pointers[1] != nil { | |
103 | name = *pointers[1] | |
104 | } | |
105 | return | |
106 | } |
0 | package mssql | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | "strconv" | |
6 | "strings" | |
7 | "time" | |
8 | ||
9 | _ "github.com/denisenkom/go-mssqldb" | |
10 | "github.com/jinzhu/gorm" | |
11 | ) | |
12 | ||
13 | func setIdentityInsert(scope *gorm.Scope) { | |
14 | if scope.Dialect().GetName() == "mssql" { | |
15 | for _, field := range scope.PrimaryFields() { | |
16 | if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank { | |
17 | scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) | |
18 | scope.InstanceSet("mssql:identity_insert_on", true) | |
19 | } | |
20 | } | |
21 | } | |
22 | } | |
23 | ||
24 | func turnOffIdentityInsert(scope *gorm.Scope) { | |
25 | if scope.Dialect().GetName() == "mssql" { | |
26 | if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok { | |
27 | scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName())) | |
28 | } | |
29 | } | |
30 | } | |
31 | ||
32 | func init() { | |
33 | gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) | |
34 | gorm.DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert) | |
35 | gorm.RegisterDialect("mssql", &mssql{}) | |
36 | } | |
37 | ||
38 | type mssql struct { | |
39 | db gorm.SQLCommon | |
40 | gorm.DefaultForeignKeyNamer | |
41 | } | |
42 | ||
43 | func (mssql) GetName() string { | |
44 | return "mssql" | |
45 | } | |
46 | ||
47 | func (s *mssql) SetDB(db gorm.SQLCommon) { | |
48 | s.db = db | |
49 | } | |
50 | ||
51 | func (mssql) BindVar(i int) string { | |
52 | return "$$$" // ? | |
53 | } | |
54 | ||
55 | func (mssql) Quote(key string) string { | |
56 | return fmt.Sprintf(`[%s]`, key) | |
57 | } | |
58 | ||
59 | func (s *mssql) DataTypeOf(field *gorm.StructField) string { | |
60 | var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s) | |
61 | ||
62 | if sqlType == "" { | |
63 | switch dataValue.Kind() { | |
64 | case reflect.Bool: | |
65 | sqlType = "bit" | |
66 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | |
67 | if s.fieldCanAutoIncrement(field) { | |
68 | field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" | |
69 | sqlType = "int IDENTITY(1,1)" | |
70 | } else { | |
71 | sqlType = "int" | |
72 | } | |
73 | case reflect.Int64, reflect.Uint64: | |
74 | if s.fieldCanAutoIncrement(field) { | |
75 | field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" | |
76 | sqlType = "bigint IDENTITY(1,1)" | |
77 | } else { | |
78 | sqlType = "bigint" | |
79 | } | |
80 | case reflect.Float32, reflect.Float64: | |
81 | sqlType = "float" | |
82 | case reflect.String: | |
83 | if size > 0 && size < 8000 { | |
84 | sqlType = fmt.Sprintf("nvarchar(%d)", size) | |
85 | } else { | |
86 | sqlType = "nvarchar(max)" | |
87 | } | |
88 | case reflect.Struct: | |
89 | if _, ok := dataValue.Interface().(time.Time); ok { | |
90 | sqlType = "datetimeoffset" | |
91 | } | |
92 | default: | |
93 | if gorm.IsByteArrayOrSlice(dataValue) { | |
94 | if size > 0 && size < 8000 { | |
95 | sqlType = fmt.Sprintf("varbinary(%d)", size) | |
96 | } else { | |
97 | sqlType = "varbinary(max)" | |
98 | } | |
99 | } | |
100 | } | |
101 | } | |
102 | ||
103 | if sqlType == "" { | |
104 | panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String())) | |
105 | } | |
106 | ||
107 | if strings.TrimSpace(additionalType) == "" { | |
108 | return sqlType | |
109 | } | |
110 | return fmt.Sprintf("%v %v", sqlType, additionalType) | |
111 | } | |
112 | ||
113 | func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { | |
114 | if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { | |
115 | return value != "FALSE" | |
116 | } | |
117 | return field.IsPrimaryKey | |
118 | } | |
119 | ||
120 | func (s mssql) HasIndex(tableName string, indexName string) bool { | |
121 | var count int | |
122 | s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) | |
123 | return count > 0 | |
124 | } | |
125 | ||
126 | func (s mssql) RemoveIndex(tableName string, indexName string) error { | |
127 | _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) | |
128 | return err | |
129 | } | |
130 | ||
131 | func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { | |
132 | return false | |
133 | } | |
134 | ||
135 | func (s mssql) HasTable(tableName string) bool { | |
136 | var count int | |
137 | currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) | |
138 | s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) | |
139 | return count > 0 | |
140 | } | |
141 | ||
142 | func (s mssql) HasColumn(tableName string, columnName string) bool { | |
143 | var count int | |
144 | currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) | |
145 | s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) | |
146 | return count > 0 | |
147 | } | |
148 | ||
149 | func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { | |
150 | _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) | |
151 | return err | |
152 | } | |
153 | ||
154 | func (s mssql) CurrentDatabase() (name string) { | |
155 | s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) | |
156 | return | |
157 | } | |
158 | ||
159 | func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { | |
160 | if offset != nil { | |
161 | if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { | |
162 | sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset) | |
163 | } | |
164 | } | |
165 | if limit != nil { | |
166 | if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { | |
167 | if sql == "" { | |
168 | // add default zero offset | |
169 | sql += " OFFSET 0 ROWS" | |
170 | } | |
171 | sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit) | |
172 | } | |
173 | } | |
174 | return | |
175 | } | |
176 | ||
177 | func (mssql) SelectFromDummyTable() string { | |
178 | return "" | |
179 | } | |
180 | ||
181 | func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { | |
182 | return "" | |
183 | } | |
184 | ||
185 | func (mssql) DefaultValueStr() string { | |
186 | return "DEFAULT VALUES" | |
187 | } | |
188 | ||
189 | func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { | |
190 | if strings.Contains(tableName, ".") { | |
191 | splitStrings := strings.SplitN(tableName, ".", 2) | |
192 | return splitStrings[0], splitStrings[1] | |
193 | } | |
194 | return dialect.CurrentDatabase(), tableName | |
195 | } |
0 | package postgres | |
1 | ||
2 | import ( | |
3 | "database/sql" | |
4 | "database/sql/driver" | |
5 | ||
6 | _ "github.com/lib/pq" | |
7 | "github.com/lib/pq/hstore" | |
8 | "encoding/json" | |
9 | "errors" | |
10 | "fmt" | |
11 | ) | |
12 | ||
13 | type Hstore map[string]*string | |
14 | ||
15 | // Value get value of Hstore | |
16 | func (h Hstore) Value() (driver.Value, error) { | |
17 | hstore := hstore.Hstore{Map: map[string]sql.NullString{}} | |
18 | if len(h) == 0 { | |
19 | return nil, nil | |
20 | } | |
21 | ||
22 | for key, value := range h { | |
23 | var s sql.NullString | |
24 | if value != nil { | |
25 | s.String = *value | |
26 | s.Valid = true | |
27 | } | |
28 | hstore.Map[key] = s | |
29 | } | |
30 | return hstore.Value() | |
31 | } | |
32 | ||
33 | // Scan scan value into Hstore | |
34 | func (h *Hstore) Scan(value interface{}) error { | |
35 | hstore := hstore.Hstore{} | |
36 | ||
37 | if err := hstore.Scan(value); err != nil { | |
38 | return err | |
39 | } | |
40 | ||
41 | if len(hstore.Map) == 0 { | |
42 | return nil | |
43 | } | |
44 | ||
45 | *h = Hstore{} | |
46 | for k := range hstore.Map { | |
47 | if hstore.Map[k].Valid { | |
48 | s := hstore.Map[k].String | |
49 | (*h)[k] = &s | |
50 | } else { | |
51 | (*h)[k] = nil | |
52 | } | |
53 | } | |
54 | ||
55 | return nil | |
56 | } | |
57 | ||
58 | // Jsonb Postgresql's JSONB data type | |
59 | type Jsonb struct { | |
60 | json.RawMessage | |
61 | } | |
62 | ||
63 | // Value get value of Jsonb | |
64 | func (j Jsonb) Value() (driver.Value, error) { | |
65 | if len(j.RawMessage) == 0 { | |
66 | return nil, nil | |
67 | } | |
68 | return j.MarshalJSON() | |
69 | } | |
70 | ||
71 | // Scan scan value into Jsonb | |
72 | func (j *Jsonb) Scan(value interface{}) error { | |
73 | bytes, ok := value.([]byte) | |
74 | if !ok { | |
75 | return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value)) | |
76 | } | |
77 | ||
78 | return json.Unmarshal(bytes, j) | |
79 | } |
0 | # Gorm Development | |
1 | ||
2 | ## Architecture | |
3 | ||
4 | The most notable component of Gorm is`gorm.DB`, which hold database connection. It could be initialized like this: | |
5 | ||
6 | db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") | |
7 | ||
8 | Gorm has chainable API, `gorm.DB` is the bridge of chains, it save related information and pass it to the next chain. | |
9 | ||
10 | Lets use below code to explain how it works: | |
11 | ||
12 | db.Where("name = ?", "jinzhu").Find(&users) | |
13 | ||
14 | // equivalent code | |
15 | newdb := db.Where("name =?", "jinzhu") | |
16 | newdb.Find(&user) | |
17 | ||
18 | `newdb` is `db`'s clone, in addition, it contains search conditions from the `Where` method. | |
19 | `Find` is a query method, it creates a `Scope` instance, and pass it as argument to query callbacks. | |
20 | ||
21 | There are four kinds of callbacks corresponds to sql's CURD: create callbacks, update callbacks, query callbacks, delete callbacks. | |
22 | ||
23 | ## Callbacks | |
24 | ||
25 | ### Register a new callback | |
26 | ||
27 | func updateCreated(scope *Scope) { | |
28 | if scope.HasColumn("Created") { | |
29 | scope.SetColumn("Created", NowFunc()) | |
30 | } | |
31 | } | |
32 | ||
33 | db.Callback().Create().Register("update_created_at", updateCreated) | |
34 | // register a callback for Create process | |
35 | ||
36 | ### Delete an existing callback | |
37 | ||
38 | db.Callback().Create().Remove("gorm:create") | |
39 | // delete callback `gorm:create` from Create callbacks | |
40 | ||
41 | ### Replace an existing callback | |
42 | ||
43 | db.Callback().Create().Replace("gorm:create", newCreateFunction) | |
44 | // replace callback `gorm:create` with new function `newCreateFunction` for Create process | |
45 | ||
46 | ### Register callback orders | |
47 | ||
48 | db.Callback().Create().Before("gorm:create").Register("update_created_at", updateCreated) | |
49 | db.Callback().Create().After("gorm:create").Register("update_created_at", updateCreated) | |
50 | db.Callback().Query().After("gorm:query").Register("my_plugin:after_query", afterQuery) | |
51 | db.Callback().Delete().After("gorm:delete").Register("my_plugin:after_delete", afterDelete) | |
52 | db.Callback().Update().Before("gorm:update").Register("my_plugin:before_update", beforeUpdate) | |
53 | db.Callback().Create().Before("gorm:create").After("gorm:before_create").Register("my_plugin:before_create", beforeCreate) | |
54 | ||
55 | ### Callback API | |
56 | ||
57 | Gorm is powered by callbacks, so you could refer below links to learn how to write callbacks | |
58 | ||
59 | [Create callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go) | |
60 | ||
61 | [Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go) | |
62 | ||
63 | [Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go) | |
64 | ||
65 | [Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go) | |
66 | ||
67 | View [https://github.com/jinzhu/gorm/blob/master/scope.go](https://github.com/jinzhu/gorm/blob/master/scope.go) for all available API |
0 | version: '3' | |
1 | ||
2 | services: | |
3 | mysql: | |
4 | image: 'mysql:latest' | |
5 | ports: | |
6 | - 9910:3306 | |
7 | environment: | |
8 | - MYSQL_DATABASE=gorm | |
9 | - MYSQL_USER=gorm | |
10 | - MYSQL_PASSWORD=gorm | |
11 | - MYSQL_RANDOM_ROOT_PASSWORD="yes" | |
12 | postgres: | |
13 | image: 'postgres:latest' | |
14 | ports: | |
15 | - 9920:5432 | |
16 | environment: | |
17 | - POSTGRES_USER=gorm | |
18 | - POSTGRES_DB=gorm | |
19 | - POSTGRES_PASSWORD=gorm | |
20 | mssql: | |
21 | image: 'mcmoe/mssqldocker:latest' | |
22 | ports: | |
23 | - 9930:1433 | |
24 | environment: | |
25 | - ACCEPT_EULA=Y | |
26 | - SA_PASSWORD=LoremIpsum86 | |
27 | - MSSQL_DB=gorm | |
28 | - MSSQL_USER=gorm | |
29 | - MSSQL_PASSWORD=LoremIpsum86 |
7 | 7 | URL string |
8 | 8 | } |
9 | 9 | |
10 | type Author struct { | |
11 | ID string | |
12 | Name string | |
13 | Email string | |
14 | } | |
15 | ||
10 | 16 | type HNPost struct { |
11 | 17 | BasePost |
18 | Author `gorm:"embedded_prefix:user_"` // Embedded struct | |
12 | 19 | Upvotes int32 |
13 | 20 | } |
14 | 21 | |
15 | 22 | type EngadgetPost struct { |
16 | 23 | BasePost BasePost `gorm:"embedded"` |
24 | Author Author `gorm:"embedded;embedded_prefix:author_"` // Embedded struct | |
17 | 25 | ImageUrl string |
26 | } | |
27 | ||
28 | func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) { | |
29 | dialect := DB.NewScope(&EngadgetPost{}).Dialect() | |
30 | engadgetPostScope := DB.NewScope(&EngadgetPost{}) | |
31 | if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") { | |
32 | t.Errorf("should has prefix for embedded columns") | |
33 | } | |
34 | ||
35 | if len(engadgetPostScope.PrimaryFields()) != 1 { | |
36 | t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields())) | |
37 | } | |
38 | ||
39 | hnScope := DB.NewScope(&HNPost{}) | |
40 | if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") { | |
41 | t.Errorf("should has prefix for embedded columns") | |
42 | } | |
18 | 43 | } |
19 | 44 | |
20 | 45 | func TestSaveAndQueryEmbeddedStruct(t *testing.T) { |
45 | 70 | } |
46 | 71 | } |
47 | 72 | } |
73 | ||
74 | func TestEmbeddedPointerTypeStruct(t *testing.T) { | |
75 | type HNPost struct { | |
76 | *BasePost | |
77 | Upvotes int32 | |
78 | } | |
79 | ||
80 | DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}}) | |
81 | ||
82 | var hnPost HNPost | |
83 | if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil { | |
84 | t.Errorf("No error should happen when find embedded pointer type, but got %v", err) | |
85 | } | |
86 | ||
87 | if hnPost.Title != "embedded_pointer_type" { | |
88 | t.Errorf("Should find correct value for embedded pointer type") | |
89 | } | |
90 | } |
5 | 5 | ) |
6 | 6 | |
7 | 7 | var ( |
8 | RecordNotFound = errors.New("record not found") | |
9 | InvalidSql = errors.New("invalid sql") | |
10 | NoNewAttrs = errors.New("no new attributes") | |
11 | NoValidTransaction = errors.New("no valid transaction") | |
12 | CantStartTransaction = errors.New("can't start transaction") | |
8 | // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct | |
9 | ErrRecordNotFound = errors.New("record not found") | |
10 | // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL | |
11 | ErrInvalidSQL = errors.New("invalid SQL") | |
12 | // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` | |
13 | ErrInvalidTransaction = errors.New("no valid transaction") | |
14 | // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` | |
15 | ErrCantStartTransaction = errors.New("can't start transaction") | |
16 | // ErrUnaddressable unaddressable value | |
17 | ErrUnaddressable = errors.New("using unaddressable value") | |
13 | 18 | ) |
14 | 19 | |
15 | type errorsInterface interface { | |
16 | GetErrors() []error | |
20 | // Errors contains all happened errors | |
21 | type Errors []error | |
22 | ||
23 | // IsRecordNotFoundError returns current error has record not found error or not | |
24 | func IsRecordNotFoundError(err error) bool { | |
25 | if errs, ok := err.(Errors); ok { | |
26 | for _, err := range errs { | |
27 | if err == ErrRecordNotFound { | |
28 | return true | |
29 | } | |
30 | } | |
31 | } | |
32 | return err == ErrRecordNotFound | |
17 | 33 | } |
18 | 34 | |
19 | type Errors struct { | |
20 | errors []error | |
35 | // GetErrors gets all happened errors | |
36 | func (errs Errors) GetErrors() []error { | |
37 | return errs | |
21 | 38 | } |
22 | 39 | |
23 | func (errs Errors) GetErrors() []error { | |
24 | return errs.errors | |
40 | // Add adds an error | |
41 | func (errs Errors) Add(newErrors ...error) Errors { | |
42 | for _, err := range newErrors { | |
43 | if err == nil { | |
44 | continue | |
45 | } | |
46 | ||
47 | if errors, ok := err.(Errors); ok { | |
48 | errs = errs.Add(errors...) | |
49 | } else { | |
50 | ok = true | |
51 | for _, e := range errs { | |
52 | if err == e { | |
53 | ok = false | |
54 | } | |
55 | } | |
56 | if ok { | |
57 | errs = append(errs, err) | |
58 | } | |
59 | } | |
60 | } | |
61 | return errs | |
25 | 62 | } |
26 | 63 | |
27 | func (errs *Errors) Add(err error) { | |
28 | if errors, ok := err.(errorsInterface); ok { | |
29 | for _, err := range errors.GetErrors() { | |
30 | errs.Add(err) | |
31 | } | |
32 | } else { | |
33 | for _, e := range errs.errors { | |
34 | if err == e { | |
35 | return | |
36 | } | |
37 | } | |
38 | errs.errors = append(errs.errors, err) | |
39 | } | |
40 | } | |
41 | ||
64 | // Error format happened errors | |
42 | 65 | func (errs Errors) Error() string { |
43 | 66 | var errors = []string{} |
44 | for _, e := range errs.errors { | |
67 | for _, e := range errs { | |
45 | 68 | errors = append(errors, e.Error()) |
46 | 69 | } |
47 | 70 | return strings.Join(errors, "; ") |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "testing" | |
5 | ||
6 | "github.com/jinzhu/gorm" | |
7 | ) | |
8 | ||
9 | func TestErrorsCanBeUsedOutsideGorm(t *testing.T) { | |
10 | errs := []error{errors.New("First"), errors.New("Second")} | |
11 | ||
12 | gErrs := gorm.Errors(errs) | |
13 | gErrs = gErrs.Add(errors.New("Third")) | |
14 | gErrs = gErrs.Add(gErrs) | |
15 | ||
16 | if gErrs.Error() != "First; Second; Third" { | |
17 | t.Fatalf("Gave wrong error, got %s", gErrs.Error()) | |
18 | } | |
19 | } |
2 | 2 | import ( |
3 | 3 | "database/sql" |
4 | 4 | "errors" |
5 | "fmt" | |
5 | 6 | "reflect" |
6 | 7 | ) |
7 | 8 | |
9 | // Field model field definition | |
8 | 10 | type Field struct { |
9 | 11 | *StructField |
10 | 12 | IsBlank bool |
11 | 13 | Field reflect.Value |
12 | 14 | } |
13 | 15 | |
14 | func (field *Field) Set(value interface{}) error { | |
16 | // Set set a value to the field | |
17 | func (field *Field) Set(value interface{}) (err error) { | |
15 | 18 | if !field.Field.IsValid() { |
16 | 19 | return errors.New("field value not valid") |
17 | 20 | } |
18 | 21 | |
19 | 22 | if !field.Field.CanAddr() { |
20 | return errors.New("unaddressable value") | |
23 | return ErrUnaddressable | |
21 | 24 | } |
22 | 25 | |
23 | if rvalue, ok := value.(reflect.Value); ok { | |
24 | value = rvalue.Interface() | |
26 | reflectValue, ok := value.(reflect.Value) | |
27 | if !ok { | |
28 | reflectValue = reflect.ValueOf(value) | |
25 | 29 | } |
26 | 30 | |
27 | if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok { | |
28 | if v, ok := value.(reflect.Value); ok { | |
29 | if err := scanner.Scan(v.Interface()); err != nil { | |
30 | return err | |
31 | fieldValue := field.Field | |
32 | if reflectValue.IsValid() { | |
33 | if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { | |
34 | fieldValue.Set(reflectValue.Convert(fieldValue.Type())) | |
35 | } else { | |
36 | if fieldValue.Kind() == reflect.Ptr { | |
37 | if fieldValue.IsNil() { | |
38 | fieldValue.Set(reflect.New(field.Struct.Type.Elem())) | |
39 | } | |
40 | fieldValue = fieldValue.Elem() | |
31 | 41 | } |
32 | } else { | |
33 | if err := scanner.Scan(value); err != nil { | |
34 | return err | |
42 | ||
43 | if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { | |
44 | fieldValue.Set(reflectValue.Convert(fieldValue.Type())) | |
45 | } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { | |
46 | err = scanner.Scan(reflectValue.Interface()) | |
47 | } else { | |
48 | err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) | |
35 | 49 | } |
36 | 50 | } |
37 | 51 | } else { |
38 | reflectValue, ok := value.(reflect.Value) | |
39 | if !ok { | |
40 | reflectValue = reflect.ValueOf(value) | |
41 | } | |
42 | ||
43 | if reflectValue.Type().ConvertibleTo(field.Field.Type()) { | |
44 | field.Field.Set(reflectValue.Convert(field.Field.Type())) | |
45 | } else { | |
46 | return errors.New("could not convert argument") | |
47 | } | |
52 | field.Field.Set(reflect.Zero(field.Field.Type())) | |
48 | 53 | } |
49 | 54 | |
50 | 55 | field.IsBlank = isBlank(field.Field) |
51 | return nil | |
56 | return err | |
52 | 57 | } |
53 | ||
54 | // Fields get value's fields | |
55 | func (scope *Scope) Fields() map[string]*Field { | |
56 | if scope.fields == nil { | |
57 | fields := map[string]*Field{} | |
58 | modelStruct := scope.GetModelStruct() | |
59 | ||
60 | indirectValue := scope.IndirectValue() | |
61 | isStruct := indirectValue.Kind() == reflect.Struct | |
62 | for _, structField := range modelStruct.StructFields { | |
63 | if isStruct { | |
64 | fields[structField.DBName] = getField(indirectValue, structField) | |
65 | } else { | |
66 | fields[structField.DBName] = &Field{StructField: structField, IsBlank: true} | |
67 | } | |
68 | } | |
69 | ||
70 | if modelStruct.cached { | |
71 | scope.fields = fields | |
72 | } | |
73 | return fields | |
74 | } | |
75 | return scope.fields | |
76 | } | |
77 | ||
78 | func getField(indirectValue reflect.Value, structField *StructField) *Field { | |
79 | field := &Field{StructField: structField} | |
80 | for _, name := range structField.Names { | |
81 | indirectValue = reflect.Indirect(indirectValue).FieldByName(name) | |
82 | } | |
83 | field.Field = indirectValue | |
84 | field.IsBlank = isBlank(indirectValue) | |
85 | return field | |
86 | } |
10 | 10 | Name string |
11 | 11 | Children []CalculateFieldChild |
12 | 12 | Category CalculateFieldCategory |
13 | EmbeddedField | |
14 | } | |
15 | ||
16 | type EmbeddedField struct { | |
17 | EmbeddedName string `sql:"NOT NULL;DEFAULT:'hello'"` | |
13 | 18 | } |
14 | 19 | |
15 | 20 | type CalculateFieldChild struct { |
26 | 31 | |
27 | 32 | func TestCalculateField(t *testing.T) { |
28 | 33 | var field CalculateField |
29 | fields := DB.NewScope(&field).Fields() | |
30 | if fields["children"].Relationship == nil || fields["category"].Relationship == nil { | |
34 | var scope = DB.NewScope(&field) | |
35 | if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil { | |
31 | 36 | t.Errorf("Should calculate fields correctly for the first time") |
32 | 37 | } |
38 | ||
39 | if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil { | |
40 | t.Errorf("Should calculate fields correctly for the first time") | |
41 | } | |
42 | ||
43 | if field, ok := scope.FieldByName("embedded_name"); !ok { | |
44 | t.Errorf("should find embedded field") | |
45 | } else if _, ok := field.TagSettings["NOT NULL"]; !ok { | |
46 | t.Errorf("should find embedded field's tag settings") | |
47 | } | |
33 | 48 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | "time" | |
6 | ) | |
7 | ||
8 | type foundation struct { | |
9 | commonDialect | |
10 | } | |
11 | ||
12 | func (foundation) BinVar(i int) string { | |
13 | return fmt.Sprintf("$%v", i) | |
14 | } | |
15 | ||
16 | func (foundation) SupportLastInsertId() bool { | |
17 | return false | |
18 | } | |
19 | ||
20 | func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) string { | |
21 | switch value.Kind() { | |
22 | case reflect.Bool: | |
23 | return "boolean" | |
24 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | |
25 | if autoIncrease { | |
26 | return "serial" | |
27 | } | |
28 | return "int" | |
29 | case reflect.Int64, reflect.Uint64: | |
30 | if autoIncrease { | |
31 | return "bigserial" | |
32 | } | |
33 | return "bigint" | |
34 | case reflect.Float32, reflect.Float64: | |
35 | return "double" | |
36 | case reflect.String: | |
37 | if size > 0 && size < 65532 { | |
38 | return fmt.Sprintf("varchar(%d)", size) | |
39 | } | |
40 | return "clob" | |
41 | case reflect.Struct: | |
42 | if _, ok := value.Interface().(time.Time); ok { | |
43 | return "datetime" | |
44 | } | |
45 | default: | |
46 | if _, ok := value.Interface().([]byte); ok { | |
47 | return "blob" | |
48 | } | |
49 | } | |
50 | panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String())) | |
51 | } | |
52 | ||
53 | func (s foundation) ReturningStr(tableName, key string) string { | |
54 | return fmt.Sprintf("RETURNING %v.%v", tableName, key) | |
55 | } | |
56 | ||
57 | func (s foundation) HasTable(scope *Scope, tableName string) bool { | |
58 | var count int | |
59 | s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName) | |
60 | return count > 0 | |
61 | } | |
62 | ||
63 | func (s foundation) HasColumn(scope *Scope, tableName string, columnName string) bool { | |
64 | var count int | |
65 | s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName) | |
66 | return count > 0 | |
67 | } | |
68 | ||
69 | func (s foundation) RemoveIndex(scope *Scope, indexName string) { | |
70 | scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", s.Quote(indexName))) | |
71 | } | |
72 | ||
73 | func (s foundation) HasIndex(scope *Scope, tableName string, indexName string) bool { | |
74 | var count int | |
75 | s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName) | |
76 | return count > 0 | |
77 | } | |
78 | ||
79 | func (s foundation) CurrentDatabase(scope *Scope) (name string) { | |
80 | s.RawScanString(scope, &name, "SELECT CURRENT_SCHEMA") | |
81 | return | |
82 | } |
1 | 1 | |
2 | 2 | import "database/sql" |
3 | 3 | |
4 | type sqlCommon interface { | |
4 | // SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. | |
5 | type SQLCommon interface { | |
5 | 6 | Exec(query string, args ...interface{}) (sql.Result, error) |
6 | 7 | Prepare(query string) (*sql.Stmt, error) |
7 | 8 | Query(query string, args ...interface{}) (*sql.Rows, error) |
6 | 6 | "strings" |
7 | 7 | ) |
8 | 8 | |
9 | // JoinTableHandlerInterface is an interface for how to handle many2many relations | |
9 | 10 | type JoinTableHandlerInterface interface { |
11 | // initialize join table handler | |
10 | 12 | Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) |
13 | // Table return join table's table name | |
11 | 14 | Table(db *DB) string |
15 | // Add create relationship in join table for source and destination | |
12 | 16 | Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error |
17 | // Delete delete relationship in join table for sources | |
13 | 18 | Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error |
19 | // JoinWith query with `Join` conditions | |
14 | 20 | JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB |
21 | // SourceForeignKeys return source foreign keys | |
15 | 22 | SourceForeignKeys() []JoinTableForeignKey |
23 | // DestinationForeignKeys return destination foreign keys | |
16 | 24 | DestinationForeignKeys() []JoinTableForeignKey |
17 | 25 | } |
18 | 26 | |
27 | // JoinTableForeignKey join table foreign key struct | |
19 | 28 | type JoinTableForeignKey struct { |
20 | 29 | DBName string |
21 | 30 | AssociationDBName string |
22 | 31 | } |
23 | 32 | |
33 | // JoinTableSource is a struct that contains model type and foreign keys | |
24 | 34 | type JoinTableSource struct { |
25 | 35 | ModelType reflect.Type |
26 | 36 | ForeignKeys []JoinTableForeignKey |
27 | 37 | } |
28 | 38 | |
39 | // JoinTableHandler default join table handler | |
29 | 40 | type JoinTableHandler struct { |
30 | 41 | TableName string `sql:"-"` |
31 | 42 | Source JoinTableSource `sql:"-"` |
32 | 43 | Destination JoinTableSource `sql:"-"` |
33 | 44 | } |
34 | 45 | |
46 | // SourceForeignKeys return source foreign keys | |
35 | 47 | func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { |
36 | 48 | return s.Source.ForeignKeys |
37 | 49 | } |
38 | 50 | |
51 | // DestinationForeignKeys return destination foreign keys | |
39 | 52 | func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { |
40 | 53 | return s.Destination.ForeignKeys |
41 | 54 | } |
42 | 55 | |
56 | // Setup initialize a default join table handler | |
43 | 57 | func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { |
44 | 58 | s.TableName = tableName |
45 | 59 | |
46 | 60 | s.Source = JoinTableSource{ModelType: source} |
61 | s.Source.ForeignKeys = []JoinTableForeignKey{} | |
47 | 62 | for idx, dbName := range relationship.ForeignFieldNames { |
48 | 63 | s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ |
49 | 64 | DBName: relationship.ForeignDBNames[idx], |
52 | 67 | } |
53 | 68 | |
54 | 69 | s.Destination = JoinTableSource{ModelType: destination} |
70 | s.Destination.ForeignKeys = []JoinTableForeignKey{} | |
55 | 71 | for idx, dbName := range relationship.AssociationForeignFieldNames { |
56 | 72 | s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ |
57 | 73 | DBName: relationship.AssociationForeignDBNames[idx], |
60 | 76 | } |
61 | 77 | } |
62 | 78 | |
79 | // Table return join table's table name | |
63 | 80 | func (s JoinTableHandler) Table(db *DB) string { |
64 | return s.TableName | |
65 | } | |
66 | ||
67 | func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} { | |
68 | values := map[string]interface{}{} | |
69 | ||
81 | return DefaultTableNameHandler(db, s.TableName) | |
82 | } | |
83 | ||
84 | func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) { | |
70 | 85 | for _, source := range sources { |
71 | 86 | scope := db.NewScope(source) |
72 | 87 | modelType := scope.GetModelStruct().ModelType |
73 | 88 | |
74 | if s.Source.ModelType == modelType { | |
75 | for _, foreignKey := range s.Source.ForeignKeys { | |
76 | values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface() | |
89 | for _, joinTableSource := range joinTableSources { | |
90 | if joinTableSource.ModelType == modelType { | |
91 | for _, foreignKey := range joinTableSource.ForeignKeys { | |
92 | if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { | |
93 | conditionMap[foreignKey.DBName] = field.Field.Interface() | |
94 | } | |
95 | } | |
96 | break | |
77 | 97 | } |
78 | } else if s.Destination.ModelType == modelType { | |
79 | for _, foreignKey := range s.Destination.ForeignKeys { | |
80 | values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface() | |
81 | } | |
82 | } | |
83 | } | |
84 | return values | |
85 | } | |
86 | ||
87 | func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error { | |
88 | scope := db.NewScope("") | |
89 | searchMap := s.GetSearchMap(db, source1, source2) | |
98 | } | |
99 | } | |
100 | } | |
101 | ||
102 | // Add create relationship in join table for source and destination | |
103 | func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { | |
104 | var ( | |
105 | scope = db.NewScope("") | |
106 | conditionMap = map[string]interface{}{} | |
107 | ) | |
108 | ||
109 | // Update condition map for source | |
110 | s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source) | |
111 | ||
112 | // Update condition map for destination | |
113 | s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination) | |
90 | 114 | |
91 | 115 | var assignColumns, binVars, conditions []string |
92 | 116 | var values []interface{} |
93 | for key, value := range searchMap { | |
94 | assignColumns = append(assignColumns, key) | |
117 | for key, value := range conditionMap { | |
118 | assignColumns = append(assignColumns, scope.Quote(key)) | |
95 | 119 | binVars = append(binVars, `?`) |
96 | 120 | conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) |
97 | 121 | values = append(values, value) |
101 | 125 | values = append(values, value) |
102 | 126 | } |
103 | 127 | |
104 | quotedTable := handler.Table(db) | |
128 | quotedTable := scope.Quote(handler.Table(db)) | |
105 | 129 | sql := fmt.Sprintf( |
106 | 130 | "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", |
107 | 131 | quotedTable, |
115 | 139 | return db.Exec(sql, values...).Error |
116 | 140 | } |
117 | 141 | |
142 | // Delete delete relationship in join table for sources | |
118 | 143 | func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { |
119 | var conditions []string | |
120 | var values []interface{} | |
121 | ||
122 | for key, value := range s.GetSearchMap(db, sources...) { | |
123 | conditions = append(conditions, fmt.Sprintf("%v = ?", key)) | |
144 | var ( | |
145 | scope = db.NewScope(nil) | |
146 | conditions []string | |
147 | values []interface{} | |
148 | conditionMap = map[string]interface{}{} | |
149 | ) | |
150 | ||
151 | s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...) | |
152 | ||
153 | for key, value := range conditionMap { | |
154 | conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) | |
124 | 155 | values = append(values, value) |
125 | 156 | } |
126 | 157 | |
127 | 158 | return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error |
128 | 159 | } |
129 | 160 | |
161 | // JoinWith query with `Join` conditions | |
130 | 162 | func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { |
131 | quotedTable := handler.Table(db) | |
132 | ||
133 | scope := db.NewScope(source) | |
134 | modelType := scope.GetModelStruct().ModelType | |
135 | var joinConditions []string | |
136 | var queryConditions []string | |
137 | var values []interface{} | |
138 | if s.Source.ModelType == modelType { | |
163 | var ( | |
164 | scope = db.NewScope(source) | |
165 | tableName = handler.Table(db) | |
166 | quotedTableName = scope.Quote(tableName) | |
167 | joinConditions []string | |
168 | values []interface{} | |
169 | ) | |
170 | ||
171 | if s.Source.ModelType == scope.GetModelStruct().ModelType { | |
139 | 172 | destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() |
140 | 173 | for _, foreignKey := range s.Destination.ForeignKeys { |
141 | joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) | |
174 | joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) | |
142 | 175 | } |
143 | 176 | |
144 | 177 | var foreignDBNames []string |
146 | 179 | |
147 | 180 | for _, foreignKey := range s.Source.ForeignKeys { |
148 | 181 | foreignDBNames = append(foreignDBNames, foreignKey.DBName) |
149 | foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name) | |
150 | } | |
151 | ||
152 | foreignFieldValues := scope.getColumnAsArray(foreignFieldNames) | |
153 | ||
154 | condString := fmt.Sprintf("%v in (%v)", toQueryCondition(scope, foreignDBNames), toQueryMarks(foreignFieldValues)) | |
155 | ||
156 | keys := scope.getColumnAsArray(foreignFieldNames) | |
157 | values = append(values, toQueryValues(keys)) | |
158 | ||
159 | queryConditions = append(queryConditions, condString) | |
160 | ||
161 | return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). | |
182 | if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { | |
183 | foreignFieldNames = append(foreignFieldNames, field.Name) | |
184 | } | |
185 | } | |
186 | ||
187 | foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) | |
188 | ||
189 | var condString string | |
190 | if len(foreignFieldValues) > 0 { | |
191 | var quotedForeignDBNames []string | |
192 | for _, dbName := range foreignDBNames { | |
193 | quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName) | |
194 | } | |
195 | ||
196 | condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) | |
197 | ||
198 | keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) | |
199 | values = append(values, toQueryValues(keys)) | |
200 | } else { | |
201 | condString = fmt.Sprintf("1 <> 1") | |
202 | } | |
203 | ||
204 | return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). | |
162 | 205 | Where(condString, toQueryValues(foreignFieldValues)...) |
163 | } else { | |
164 | db.Error = errors.New("wrong source type for join table handler") | |
165 | return db | |
166 | } | |
167 | } | |
206 | } | |
207 | ||
208 | db.Error = errors.New("wrong source type for join table handler") | |
209 | return db | |
210 | } |
1 | 1 | |
2 | 2 | import ( |
3 | 3 | "fmt" |
4 | "strconv" | |
4 | 5 | "testing" |
5 | 6 | "time" |
6 | 7 | |
17 | 18 | gorm.JoinTableHandler |
18 | 19 | PersonID int |
19 | 20 | AddressID int |
20 | DeletedAt time.Time | |
21 | DeletedAt *time.Time | |
21 | 22 | CreatedAt time.Time |
22 | 23 | } |
23 | 24 | |
24 | 25 | func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { |
25 | return db.Where(map[string]interface{}{ | |
26 | "person_id": db.NewScope(foreignValue).PrimaryKeyValue(), | |
27 | "address_id": db.NewScope(associationValue).PrimaryKeyValue(), | |
28 | }).Assign(map[string]interface{}{ | |
29 | "person_id": foreignValue, | |
30 | "address_id": associationValue, | |
26 | foreignPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(foreignValue).PrimaryKeyValue())) | |
27 | associationPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(associationValue).PrimaryKeyValue())) | |
28 | if result := db.Unscoped().Model(&PersonAddress{}).Where(map[string]interface{}{ | |
29 | "person_id": foreignPrimaryKey, | |
30 | "address_id": associationPrimaryKey, | |
31 | }).Update(map[string]interface{}{ | |
32 | "person_id": foreignPrimaryKey, | |
33 | "address_id": associationPrimaryKey, | |
31 | 34 | "deleted_at": gorm.Expr("NULL"), |
32 | }).FirstOrCreate(&PersonAddress{}).Error | |
35 | }).RowsAffected; result == 0 { | |
36 | return db.Create(&PersonAddress{ | |
37 | PersonID: foreignPrimaryKey, | |
38 | AddressID: associationPrimaryKey, | |
39 | }).Error | |
40 | } | |
41 | ||
42 | return nil | |
33 | 43 | } |
34 | 44 | |
35 | 45 | func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error { |
38 | 48 | |
39 | 49 | func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { |
40 | 50 | table := pa.Table(db) |
41 | return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) | |
51 | return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) | |
42 | 52 | } |
43 | 53 | |
44 | 54 | func TestJoinTable(t *testing.T) { |
69 | 79 | t.Errorf("Should deleted all addresses") |
70 | 80 | } |
71 | 81 | } |
82 | ||
83 | func TestEmbeddedMany2ManyRelationship(t *testing.T) { | |
84 | type EmbeddedPerson struct { | |
85 | ID int | |
86 | Name string | |
87 | Addresses []*Address `gorm:"many2many:person_addresses;"` | |
88 | } | |
89 | ||
90 | type NewPerson struct { | |
91 | EmbeddedPerson | |
92 | ExternalID uint | |
93 | } | |
94 | DB.Exec("drop table person_addresses;") | |
95 | DB.AutoMigrate(&NewPerson{}) | |
96 | ||
97 | address1 := &Address{Address1: "address 1"} | |
98 | address2 := &Address{Address1: "address 2"} | |
99 | person := &NewPerson{ExternalID: 100, EmbeddedPerson: EmbeddedPerson{Name: "person", Addresses: []*Address{address1, address2}}} | |
100 | if err := DB.Save(person).Error; err != nil { | |
101 | t.Errorf("no error should return when save embedded many2many relationship, but got %v", err) | |
102 | } | |
103 | ||
104 | if err := DB.Model(person).Association("Addresses").Delete(address1).Error; err != nil { | |
105 | t.Errorf("no error should return when delete embedded many2many relationship, but got %v", err) | |
106 | } | |
107 | ||
108 | association := DB.Model(person).Association("Addresses") | |
109 | if count := association.Count(); count != 1 || association.Error != nil { | |
110 | t.Errorf("Should found one address, but got %v, error is %v", count, association.Error) | |
111 | } | |
112 | ||
113 | if association.Clear(); association.Count() != 0 { | |
114 | t.Errorf("Should deleted all addresses") | |
115 | } | |
116 | } |
6 | 6 | "os" |
7 | 7 | "reflect" |
8 | 8 | "regexp" |
9 | "strconv" | |
9 | 10 | "time" |
11 | "unicode" | |
10 | 12 | ) |
13 | ||
14 | var ( | |
15 | defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} | |
16 | sqlRegexp = regexp.MustCompile(`\?`) | |
17 | numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`) | |
18 | ) | |
19 | ||
20 | func isPrintable(s string) bool { | |
21 | for _, r := range s { | |
22 | if !unicode.IsPrint(r) { | |
23 | return false | |
24 | } | |
25 | } | |
26 | return true | |
27 | } | |
28 | ||
29 | var LogFormatter = func(values ...interface{}) (messages []interface{}) { | |
30 | if len(values) > 1 { | |
31 | var ( | |
32 | sql string | |
33 | formattedValues []string | |
34 | level = values[0] | |
35 | currentTime = "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" | |
36 | source = fmt.Sprintf("\033[35m(%v)\033[0m", values[1]) | |
37 | ) | |
38 | ||
39 | messages = []interface{}{source, currentTime} | |
40 | ||
41 | if level == "sql" { | |
42 | // duration | |
43 | messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) | |
44 | // sql | |
45 | ||
46 | for _, value := range values[4].([]interface{}) { | |
47 | indirectValue := reflect.Indirect(reflect.ValueOf(value)) | |
48 | if indirectValue.IsValid() { | |
49 | value = indirectValue.Interface() | |
50 | if t, ok := value.(time.Time); ok { | |
51 | formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05"))) | |
52 | } else if b, ok := value.([]byte); ok { | |
53 | if str := string(b); isPrintable(str) { | |
54 | formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) | |
55 | } else { | |
56 | formattedValues = append(formattedValues, "'<binary>'") | |
57 | } | |
58 | } else if r, ok := value.(driver.Valuer); ok { | |
59 | if value, err := r.Value(); err == nil && value != nil { | |
60 | formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) | |
61 | } else { | |
62 | formattedValues = append(formattedValues, "NULL") | |
63 | } | |
64 | } else { | |
65 | formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) | |
66 | } | |
67 | } else { | |
68 | formattedValues = append(formattedValues, "NULL") | |
69 | } | |
70 | } | |
71 | ||
72 | // differentiate between $n placeholders or else treat like ? | |
73 | if numericPlaceHolderRegexp.MatchString(values[3].(string)) { | |
74 | sql = values[3].(string) | |
75 | for index, value := range formattedValues { | |
76 | placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1) | |
77 | sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1") | |
78 | } | |
79 | } else { | |
80 | formattedValuesLength := len(formattedValues) | |
81 | for index, value := range sqlRegexp.Split(values[3].(string), -1) { | |
82 | sql += value | |
83 | if index < formattedValuesLength { | |
84 | sql += formattedValues[index] | |
85 | } | |
86 | } | |
87 | } | |
88 | ||
89 | messages = append(messages, sql) | |
90 | messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned ")) | |
91 | } else { | |
92 | messages = append(messages, "\033[31;1m") | |
93 | messages = append(messages, values[2:]...) | |
94 | messages = append(messages, "\033[0m") | |
95 | } | |
96 | } | |
97 | ||
98 | return | |
99 | } | |
11 | 100 | |
12 | 101 | type logger interface { |
13 | 102 | Print(v ...interface{}) |
14 | 103 | } |
15 | 104 | |
105 | // LogWriter log writer interface | |
16 | 106 | type LogWriter interface { |
17 | 107 | Println(v ...interface{}) |
18 | 108 | } |
19 | 109 | |
110 | // Logger default logger | |
20 | 111 | type Logger struct { |
21 | 112 | LogWriter |
22 | 113 | } |
23 | 114 | |
24 | var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} | |
25 | ||
26 | // Format log | |
27 | var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`) | |
28 | ||
115 | // Print format & print log | |
29 | 116 | func (logger Logger) Print(values ...interface{}) { |
30 | if len(values) > 1 { | |
31 | level := values[0] | |
32 | currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" | |
33 | source := fmt.Sprintf("\033[35m(%v)\033[0m", values[1]) | |
34 | messages := []interface{}{source, currentTime} | |
35 | ||
36 | if level == "sql" { | |
37 | // duration | |
38 | messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) | |
39 | // sql | |
40 | var formatedValues []interface{} | |
41 | for _, value := range values[4].([]interface{}) { | |
42 | indirectValue := reflect.Indirect(reflect.ValueOf(value)) | |
43 | if indirectValue.IsValid() { | |
44 | value = indirectValue.Interface() | |
45 | if t, ok := value.(time.Time); ok { | |
46 | formatedValues = append(formatedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339))) | |
47 | } else if b, ok := value.([]byte); ok { | |
48 | formatedValues = append(formatedValues, fmt.Sprintf("'%v'", string(b))) | |
49 | } else if r, ok := value.(driver.Valuer); ok { | |
50 | if value, err := r.Value(); err == nil && value != nil { | |
51 | formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value)) | |
52 | } else { | |
53 | formatedValues = append(formatedValues, "NULL") | |
54 | } | |
55 | } else { | |
56 | formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value)) | |
57 | } | |
58 | } else { | |
59 | formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value)) | |
60 | } | |
61 | } | |
62 | messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(values[3].(string), "%v"), formatedValues...)) | |
63 | } else { | |
64 | messages = append(messages, "\033[31;1m") | |
65 | messages = append(messages, values[2:]...) | |
66 | messages = append(messages, "\033[0m") | |
67 | } | |
68 | logger.Println(messages...) | |
69 | } | |
117 | logger.Println(LogFormatter(values...)...) | |
70 | 118 | } |
8 | 8 | "time" |
9 | 9 | ) |
10 | 10 | |
11 | // NowFunc returns current time, this function is exported in order to be able | |
12 | // to give the flexibility to the developer to customize it according to their | |
13 | // needs | |
14 | // | |
15 | // e.g: return time.Now().UTC() | |
16 | // | |
17 | var NowFunc = func() time.Time { | |
18 | return time.Now() | |
19 | } | |
20 | ||
11 | // DB contains information for current db connection | |
21 | 12 | type DB struct { |
22 | Value interface{} | |
23 | Error error | |
24 | RowsAffected int64 | |
25 | callback *callback | |
26 | db sqlCommon | |
27 | parent *DB | |
28 | search *search | |
13 | Value interface{} | |
14 | Error error | |
15 | RowsAffected int64 | |
16 | ||
17 | // single db | |
18 | db SQLCommon | |
19 | blockGlobalUpdate bool | |
29 | 20 | logMode int |
30 | 21 | logger logger |
31 | dialect Dialect | |
32 | singularTable bool | |
33 | source string | |
22 | search *search | |
34 | 23 | values map[string]interface{} |
35 | joinTableHandlers map[string]JoinTableHandler | |
36 | } | |
37 | ||
38 | func Open(dialect string, args ...interface{}) (DB, error) { | |
39 | var db DB | |
40 | var err error | |
41 | ||
24 | ||
25 | // global db | |
26 | parent *DB | |
27 | callbacks *Callback | |
28 | dialect Dialect | |
29 | singularTable bool | |
30 | } | |
31 | ||
32 | // Open initialize a new db connection, need to import driver first, e.g: | |
33 | // | |
34 | // import _ "github.com/go-sql-driver/mysql" | |
35 | // func main() { | |
36 | // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") | |
37 | // } | |
38 | // GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with | |
39 | // import _ "github.com/jinzhu/gorm/dialects/mysql" | |
40 | // // import _ "github.com/jinzhu/gorm/dialects/postgres" | |
41 | // // import _ "github.com/jinzhu/gorm/dialects/sqlite" | |
42 | // // import _ "github.com/jinzhu/gorm/dialects/mssql" | |
43 | func Open(dialect string, args ...interface{}) (db *DB, err error) { | |
42 | 44 | if len(args) == 0 { |
43 | 45 | err = errors.New("invalid database source") |
44 | } else { | |
45 | var source string | |
46 | var dbSql sqlCommon | |
47 | ||
48 | switch value := args[0].(type) { | |
49 | case string: | |
50 | var driver = dialect | |
51 | if len(args) == 1 { | |
52 | source = value | |
53 | } else if len(args) >= 2 { | |
54 | driver = value | |
55 | source = args[1].(string) | |
56 | } | |
57 | if driver == "foundation" { | |
58 | driver = "postgres" // FoundationDB speaks a postgres-compatible protocol. | |
59 | } | |
60 | dbSql, err = sql.Open(driver, source) | |
61 | case sqlCommon: | |
62 | source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() | |
63 | dbSql = value | |
64 | } | |
65 | ||
66 | db = DB{ | |
67 | dialect: NewDialect(dialect), | |
68 | logger: defaultLogger, | |
69 | callback: DefaultCallback, | |
70 | source: source, | |
71 | values: map[string]interface{}{}, | |
72 | db: dbSql, | |
73 | } | |
74 | db.parent = &db | |
75 | ||
76 | if err == nil { | |
77 | err = db.DB().Ping() // Send a ping to make sure the database connection is alive. | |
78 | } | |
79 | } | |
80 | ||
81 | return db, err | |
82 | } | |
83 | ||
84 | func (s *DB) Close() error { | |
85 | return s.parent.db.(*sql.DB).Close() | |
86 | } | |
87 | ||
88 | func (s *DB) DB() *sql.DB { | |
89 | return s.db.(*sql.DB) | |
90 | } | |
91 | ||
46 | return nil, err | |
47 | } | |
48 | var source string | |
49 | var dbSQL SQLCommon | |
50 | ||
51 | switch value := args[0].(type) { | |
52 | case string: | |
53 | var driver = dialect | |
54 | if len(args) == 1 { | |
55 | source = value | |
56 | } else if len(args) >= 2 { | |
57 | driver = value | |
58 | source = args[1].(string) | |
59 | } | |
60 | dbSQL, err = sql.Open(driver, source) | |
61 | case SQLCommon: | |
62 | dbSQL = value | |
63 | } | |
64 | ||
65 | db = &DB{ | |
66 | db: dbSQL, | |
67 | logger: defaultLogger, | |
68 | values: map[string]interface{}{}, | |
69 | callbacks: DefaultCallback, | |
70 | dialect: newDialect(dialect, dbSQL), | |
71 | } | |
72 | db.parent = db | |
73 | if err != nil { | |
74 | return | |
75 | } | |
76 | // Send a ping to make sure the database connection is alive. | |
77 | if d, ok := dbSQL.(*sql.DB); ok { | |
78 | if err = d.Ping(); err != nil { | |
79 | d.Close() | |
80 | } | |
81 | } | |
82 | return | |
83 | } | |
84 | ||
85 | // New clone a new db connection without search conditions | |
92 | 86 | func (s *DB) New() *DB { |
93 | 87 | clone := s.clone() |
94 | 88 | clone.search = nil |
96 | 90 | return clone |
97 | 91 | } |
98 | 92 | |
99 | // NewScope create scope for callbacks, including DB's search information | |
100 | func (db *DB) NewScope(value interface{}) *Scope { | |
101 | dbClone := db.clone() | |
102 | dbClone.Value = value | |
103 | return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} | |
104 | } | |
105 | ||
106 | // CommonDB Return the underlying sql.DB or sql.Tx instance. | |
107 | // Use of this method is discouraged. It's mainly intended to allow | |
108 | // coexistence with legacy non-GORM code. | |
109 | func (s *DB) CommonDB() sqlCommon { | |
93 | type closer interface { | |
94 | Close() error | |
95 | } | |
96 | ||
97 | // Close close current db connection. If database connection is not an io.Closer, returns an error. | |
98 | func (s *DB) Close() error { | |
99 | if db, ok := s.parent.db.(closer); ok { | |
100 | return db.Close() | |
101 | } | |
102 | return errors.New("can't close current db") | |
103 | } | |
104 | ||
105 | // DB get `*sql.DB` from current connection | |
106 | // If the underlying database connection is not a *sql.DB, returns nil | |
107 | func (s *DB) DB() *sql.DB { | |
108 | db, _ := s.db.(*sql.DB) | |
109 | return db | |
110 | } | |
111 | ||
112 | // CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. | |
113 | func (s *DB) CommonDB() SQLCommon { | |
110 | 114 | return s.db |
111 | 115 | } |
112 | 116 | |
113 | func (s *DB) Callback() *callback { | |
114 | s.parent.callback = s.parent.callback.clone() | |
115 | return s.parent.callback | |
116 | } | |
117 | ||
118 | func (s *DB) SetLogger(l logger) { | |
119 | s.logger = l | |
120 | } | |
121 | ||
117 | // Dialect get dialect | |
118 | func (s *DB) Dialect() Dialect { | |
119 | return s.parent.dialect | |
120 | } | |
121 | ||
122 | // Callback return `Callbacks` container, you could add/change/delete callbacks with it | |
123 | // db.Callback().Create().Register("update_created_at", updateCreated) | |
124 | // Refer https://jinzhu.github.io/gorm/development.html#callbacks | |
125 | func (s *DB) Callback() *Callback { | |
126 | s.parent.callbacks = s.parent.callbacks.clone() | |
127 | return s.parent.callbacks | |
128 | } | |
129 | ||
130 | // SetLogger replace default logger | |
131 | func (s *DB) SetLogger(log logger) { | |
132 | s.logger = log | |
133 | } | |
134 | ||
135 | // LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs | |
122 | 136 | func (s *DB) LogMode(enable bool) *DB { |
123 | 137 | if enable { |
124 | 138 | s.logMode = 2 |
128 | 142 | return s |
129 | 143 | } |
130 | 144 | |
145 | // BlockGlobalUpdate if true, generates an error on update/delete without where clause. | |
146 | // This is to prevent eventual error with empty objects updates/deletions | |
147 | func (s *DB) BlockGlobalUpdate(enable bool) *DB { | |
148 | s.blockGlobalUpdate = enable | |
149 | return s | |
150 | } | |
151 | ||
152 | // HasBlockGlobalUpdate return state of block | |
153 | func (s *DB) HasBlockGlobalUpdate() bool { | |
154 | return s.blockGlobalUpdate | |
155 | } | |
156 | ||
157 | // SingularTable use singular table by default | |
131 | 158 | func (s *DB) SingularTable(enable bool) { |
132 | 159 | modelStructsMap = newModelStructsMap() |
133 | 160 | s.parent.singularTable = enable |
134 | 161 | } |
135 | 162 | |
163 | // NewScope create a scope for current operation | |
164 | func (s *DB) NewScope(value interface{}) *Scope { | |
165 | dbClone := s.clone() | |
166 | dbClone.Value = value | |
167 | return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} | |
168 | } | |
169 | ||
170 | // QueryExpr returns the query as expr object | |
171 | func (s *DB) QueryExpr() *expr { | |
172 | scope := s.NewScope(s.Value) | |
173 | scope.InstanceSet("skip_bindvar", true) | |
174 | scope.prepareQuerySQL() | |
175 | ||
176 | return Expr(scope.SQL, scope.SQLVars...) | |
177 | } | |
178 | ||
179 | // Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query | |
136 | 180 | func (s *DB) Where(query interface{}, args ...interface{}) *DB { |
137 | 181 | return s.clone().search.Where(query, args...).db |
138 | 182 | } |
139 | 183 | |
184 | // Or filter records that match before conditions or this one, similar to `Where` | |
140 | 185 | func (s *DB) Or(query interface{}, args ...interface{}) *DB { |
141 | 186 | return s.clone().search.Or(query, args...).db |
142 | 187 | } |
143 | 188 | |
189 | // Not filter records that don't match current conditions, similar to `Where` | |
144 | 190 | func (s *DB) Not(query interface{}, args ...interface{}) *DB { |
145 | 191 | return s.clone().search.Not(query, args...).db |
146 | 192 | } |
147 | 193 | |
148 | func (s *DB) Limit(value interface{}) *DB { | |
149 | return s.clone().search.Limit(value).db | |
150 | } | |
151 | ||
152 | func (s *DB) Offset(value interface{}) *DB { | |
153 | return s.clone().search.Offset(value).db | |
154 | } | |
155 | ||
156 | func (s *DB) Order(value string, reorder ...bool) *DB { | |
194 | // Limit specify the number of records to be retrieved | |
195 | func (s *DB) Limit(limit interface{}) *DB { | |
196 | return s.clone().search.Limit(limit).db | |
197 | } | |
198 | ||
199 | // Offset specify the number of records to skip before starting to return the records | |
200 | func (s *DB) Offset(offset interface{}) *DB { | |
201 | return s.clone().search.Offset(offset).db | |
202 | } | |
203 | ||
204 | // Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions | |
205 | // db.Order("name DESC") | |
206 | // db.Order("name DESC", true) // reorder | |
207 | // db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression | |
208 | func (s *DB) Order(value interface{}, reorder ...bool) *DB { | |
157 | 209 | return s.clone().search.Order(value, reorder...).db |
158 | 210 | } |
159 | 211 | |
212 | // Select specify fields that you want to retrieve from database when querying, by default, will select all fields; | |
213 | // When creating/updating, specify fields that you want to save to database | |
160 | 214 | func (s *DB) Select(query interface{}, args ...interface{}) *DB { |
161 | 215 | return s.clone().search.Select(query, args...).db |
162 | 216 | } |
163 | 217 | |
218 | // Omit specify fields that you want to ignore when saving to database for creating, updating | |
164 | 219 | func (s *DB) Omit(columns ...string) *DB { |
165 | 220 | return s.clone().search.Omit(columns...).db |
166 | 221 | } |
167 | 222 | |
223 | // Group specify the group method on the find | |
168 | 224 | func (s *DB) Group(query string) *DB { |
169 | 225 | return s.clone().search.Group(query).db |
170 | 226 | } |
171 | 227 | |
172 | func (s *DB) Having(query string, values ...interface{}) *DB { | |
228 | // Having specify HAVING conditions for GROUP BY | |
229 | func (s *DB) Having(query interface{}, values ...interface{}) *DB { | |
173 | 230 | return s.clone().search.Having(query, values...).db |
174 | 231 | } |
175 | 232 | |
176 | func (s *DB) Joins(query string) *DB { | |
177 | return s.clone().search.Joins(query).db | |
178 | } | |
179 | ||
233 | // Joins specify Joins conditions | |
234 | // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) | |
235 | func (s *DB) Joins(query string, args ...interface{}) *DB { | |
236 | return s.clone().search.Joins(query, args...).db | |
237 | } | |
238 | ||
239 | // Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically | |
240 | // func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { | |
241 | // return db.Where("amount > ?", 1000) | |
242 | // } | |
243 | // | |
244 | // func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { | |
245 | // return func (db *gorm.DB) *gorm.DB { | |
246 | // return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) | |
247 | // } | |
248 | // } | |
249 | // | |
250 | // db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) | |
251 | // Refer https://jinzhu.github.io/gorm/crud.html#scopes | |
180 | 252 | func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { |
181 | 253 | for _, f := range funcs { |
182 | 254 | s = f(s) |
184 | 256 | return s |
185 | 257 | } |
186 | 258 | |
259 | // Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/crud.html#soft-delete | |
187 | 260 | func (s *DB) Unscoped() *DB { |
188 | 261 | return s.clone().search.unscoped().db |
189 | 262 | } |
190 | 263 | |
264 | // Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate | |
191 | 265 | func (s *DB) Attrs(attrs ...interface{}) *DB { |
192 | 266 | return s.clone().search.Attrs(attrs...).db |
193 | 267 | } |
194 | 268 | |
269 | // Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate | |
195 | 270 | func (s *DB) Assign(attrs ...interface{}) *DB { |
196 | 271 | return s.clone().search.Assign(attrs...).db |
197 | 272 | } |
198 | 273 | |
274 | // First find first record that match given conditions, order by primary key | |
199 | 275 | func (s *DB) First(out interface{}, where ...interface{}) *DB { |
200 | newScope := s.clone().NewScope(out) | |
276 | newScope := s.NewScope(out) | |
201 | 277 | newScope.Search.Limit(1) |
202 | 278 | return newScope.Set("gorm:order_by_primary_key", "ASC"). |
203 | inlineCondition(where...).callCallbacks(s.parent.callback.queries).db | |
204 | } | |
205 | ||
279 | inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db | |
280 | } | |
281 | ||
282 | // Take return a record that match given conditions, the order will depend on the database implementation | |
283 | func (s *DB) Take(out interface{}, where ...interface{}) *DB { | |
284 | newScope := s.NewScope(out) | |
285 | newScope.Search.Limit(1) | |
286 | return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db | |
287 | } | |
288 | ||
289 | // Last find last record that match given conditions, order by primary key | |
206 | 290 | func (s *DB) Last(out interface{}, where ...interface{}) *DB { |
207 | newScope := s.clone().NewScope(out) | |
291 | newScope := s.NewScope(out) | |
208 | 292 | newScope.Search.Limit(1) |
209 | 293 | return newScope.Set("gorm:order_by_primary_key", "DESC"). |
210 | inlineCondition(where...).callCallbacks(s.parent.callback.queries).db | |
211 | } | |
212 | ||
294 | inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db | |
295 | } | |
296 | ||
297 | // Find find records that match given conditions | |
213 | 298 | func (s *DB) Find(out interface{}, where ...interface{}) *DB { |
214 | return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db | |
215 | } | |
216 | ||
299 | return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db | |
300 | } | |
301 | ||
302 | // Scan scan value to a struct | |
217 | 303 | func (s *DB) Scan(dest interface{}) *DB { |
218 | return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db | |
219 | } | |
220 | ||
304 | return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db | |
305 | } | |
306 | ||
307 | // Row return `*sql.Row` with given conditions | |
221 | 308 | func (s *DB) Row() *sql.Row { |
222 | 309 | return s.NewScope(s.Value).row() |
223 | 310 | } |
224 | 311 | |
312 | // Rows return `*sql.Rows` with given conditions | |
225 | 313 | func (s *DB) Rows() (*sql.Rows, error) { |
226 | 314 | return s.NewScope(s.Value).rows() |
227 | 315 | } |
228 | 316 | |
317 | // ScanRows scan `*sql.Rows` to give struct | |
318 | func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { | |
319 | var ( | |
320 | scope = s.NewScope(result) | |
321 | clone = scope.db | |
322 | columns, err = rows.Columns() | |
323 | ) | |
324 | ||
325 | if clone.AddError(err) == nil { | |
326 | scope.scan(rows, columns, scope.Fields()) | |
327 | } | |
328 | ||
329 | return clone.Error | |
330 | } | |
331 | ||
332 | // Pluck used to query single column from a model as a map | |
333 | // var ages []int64 | |
334 | // db.Find(&users).Pluck("age", &ages) | |
229 | 335 | func (s *DB) Pluck(column string, value interface{}) *DB { |
230 | 336 | return s.NewScope(s.Value).pluck(column, value).db |
231 | 337 | } |
232 | 338 | |
339 | // Count get how many records for a model | |
233 | 340 | func (s *DB) Count(value interface{}) *DB { |
234 | 341 | return s.NewScope(s.Value).count(value).db |
235 | 342 | } |
236 | 343 | |
344 | // Related get related associations | |
237 | 345 | func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { |
238 | return s.clone().NewScope(s.Value).related(value, foreignKeys...).db | |
239 | } | |
240 | ||
346 | return s.NewScope(s.Value).related(value, foreignKeys...).db | |
347 | } | |
348 | ||
349 | // FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) | |
350 | // https://jinzhu.github.io/gorm/crud.html#firstorinit | |
241 | 351 | func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { |
242 | 352 | c := s.clone() |
243 | 353 | if result := c.First(out, where...); result.Error != nil { |
246 | 356 | } |
247 | 357 | c.NewScope(out).inlineCondition(where...).initialize() |
248 | 358 | } else { |
249 | c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs), false) | |
359 | c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs) | |
250 | 360 | } |
251 | 361 | return c |
252 | 362 | } |
253 | 363 | |
364 | // FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions) | |
365 | // https://jinzhu.github.io/gorm/crud.html#firstorcreate | |
254 | 366 | func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { |
255 | 367 | c := s.clone() |
256 | if result := c.First(out, where...); result.Error != nil { | |
368 | if result := s.First(out, where...); result.Error != nil { | |
257 | 369 | if !result.RecordNotFound() { |
258 | 370 | return result |
259 | 371 | } |
260 | c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callback.creates).db.Error) | |
372 | return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db | |
261 | 373 | } else if len(c.search.assignAttrs) > 0 { |
262 | c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callback.updates).db.Error) | |
374 | return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db | |
263 | 375 | } |
264 | 376 | return c |
265 | 377 | } |
266 | 378 | |
379 | // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update | |
267 | 380 | func (s *DB) Update(attrs ...interface{}) *DB { |
268 | 381 | return s.Updates(toSearchableMap(attrs...), true) |
269 | 382 | } |
270 | 383 | |
384 | // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update | |
271 | 385 | func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { |
272 | return s.clone().NewScope(s.Value). | |
386 | return s.NewScope(s.Value). | |
273 | 387 | Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). |
274 | 388 | InstanceSet("gorm:update_interface", values). |
275 | callCallbacks(s.parent.callback.updates).db | |
276 | } | |
277 | ||
389 | callCallbacks(s.parent.callbacks.updates).db | |
390 | } | |
391 | ||
392 | // UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update | |
278 | 393 | func (s *DB) UpdateColumn(attrs ...interface{}) *DB { |
279 | 394 | return s.UpdateColumns(toSearchableMap(attrs...)) |
280 | 395 | } |
281 | 396 | |
397 | // UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update | |
282 | 398 | func (s *DB) UpdateColumns(values interface{}) *DB { |
283 | return s.clone().NewScope(s.Value). | |
399 | return s.NewScope(s.Value). | |
284 | 400 | Set("gorm:update_column", true). |
285 | 401 | Set("gorm:save_associations", false). |
286 | 402 | InstanceSet("gorm:update_interface", values). |
287 | callCallbacks(s.parent.callback.updates).db | |
288 | } | |
289 | ||
403 | callCallbacks(s.parent.callbacks.updates).db | |
404 | } | |
405 | ||
406 | // Save update value in database, if the value doesn't have primary key, will insert it | |
290 | 407 | func (s *DB) Save(value interface{}) *DB { |
291 | scope := s.clone().NewScope(value) | |
292 | if scope.PrimaryKeyZero() { | |
293 | return scope.callCallbacks(s.parent.callback.creates).db | |
294 | } | |
295 | return scope.callCallbacks(s.parent.callback.updates).db | |
296 | } | |
297 | ||
408 | scope := s.NewScope(value) | |
409 | if !scope.PrimaryKeyZero() { | |
410 | newDB := scope.callCallbacks(s.parent.callbacks.updates).db | |
411 | if newDB.Error == nil && newDB.RowsAffected == 0 { | |
412 | return s.New().FirstOrCreate(value) | |
413 | } | |
414 | return newDB | |
415 | } | |
416 | return scope.callCallbacks(s.parent.callbacks.creates).db | |
417 | } | |
418 | ||
419 | // Create insert the value into database | |
298 | 420 | func (s *DB) Create(value interface{}) *DB { |
299 | scope := s.clone().NewScope(value) | |
300 | return scope.callCallbacks(s.parent.callback.creates).db | |
301 | } | |
302 | ||
421 | scope := s.NewScope(value) | |
422 | return scope.callCallbacks(s.parent.callbacks.creates).db | |
423 | } | |
424 | ||
425 | // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition | |
303 | 426 | func (s *DB) Delete(value interface{}, where ...interface{}) *DB { |
304 | return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callback.deletes).db | |
305 | } | |
306 | ||
427 | return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db | |
428 | } | |
429 | ||
430 | // Raw use raw sql as conditions, won't run it unless invoked by other methods | |
431 | // db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) | |
307 | 432 | func (s *DB) Raw(sql string, values ...interface{}) *DB { |
308 | 433 | return s.clone().search.Raw(true).Where(sql, values...).db |
309 | 434 | } |
310 | 435 | |
436 | // Exec execute raw sql | |
311 | 437 | func (s *DB) Exec(sql string, values ...interface{}) *DB { |
312 | scope := s.clone().NewScope(nil) | |
313 | generatedSql := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) | |
314 | generatedSql = strings.TrimSuffix(strings.TrimPrefix(generatedSql, "("), ")") | |
315 | scope.Raw(generatedSql) | |
438 | scope := s.NewScope(nil) | |
439 | generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true) | |
440 | generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") | |
441 | scope.Raw(generatedSQL) | |
316 | 442 | return scope.Exec().db |
317 | 443 | } |
318 | 444 | |
445 | // Model specify the model you would like to run db operations | |
446 | // // update all users's name to `hello` | |
447 | // db.Model(&User{}).Update("name", "hello") | |
448 | // // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` | |
449 | // db.Model(&user).Update("name", "hello") | |
319 | 450 | func (s *DB) Model(value interface{}) *DB { |
320 | 451 | c := s.clone() |
321 | 452 | c.Value = value |
322 | 453 | return c |
323 | 454 | } |
324 | 455 | |
456 | // Table specify the table you would like to run db operations | |
325 | 457 | func (s *DB) Table(name string) *DB { |
326 | 458 | clone := s.clone() |
327 | 459 | clone.search.Table(name) |
329 | 461 | return clone |
330 | 462 | } |
331 | 463 | |
464 | // Debug start debug mode | |
332 | 465 | func (s *DB) Debug() *DB { |
333 | 466 | return s.clone().LogMode(true) |
334 | 467 | } |
335 | 468 | |
469 | // Begin begin a transaction | |
336 | 470 | func (s *DB) Begin() *DB { |
337 | 471 | c := s.clone() |
338 | if db, ok := c.db.(sqlDb); ok { | |
472 | if db, ok := c.db.(sqlDb); ok && db != nil { | |
339 | 473 | tx, err := db.Begin() |
340 | c.db = interface{}(tx).(sqlCommon) | |
474 | c.db = interface{}(tx).(SQLCommon) | |
341 | 475 | c.AddError(err) |
342 | 476 | } else { |
343 | c.AddError(CantStartTransaction) | |
477 | c.AddError(ErrCantStartTransaction) | |
344 | 478 | } |
345 | 479 | return c |
346 | 480 | } |
347 | 481 | |
482 | // Commit commit a transaction | |
348 | 483 | func (s *DB) Commit() *DB { |
349 | if db, ok := s.db.(sqlTx); ok { | |
484 | if db, ok := s.db.(sqlTx); ok && db != nil { | |
350 | 485 | s.AddError(db.Commit()) |
351 | 486 | } else { |
352 | s.AddError(NoValidTransaction) | |
487 | s.AddError(ErrInvalidTransaction) | |
353 | 488 | } |
354 | 489 | return s |
355 | 490 | } |
356 | 491 | |
492 | // Rollback rollback a transaction | |
357 | 493 | func (s *DB) Rollback() *DB { |
358 | if db, ok := s.db.(sqlTx); ok { | |
494 | if db, ok := s.db.(sqlTx); ok && db != nil { | |
359 | 495 | s.AddError(db.Rollback()) |
360 | 496 | } else { |
361 | s.AddError(NoValidTransaction) | |
497 | s.AddError(ErrInvalidTransaction) | |
362 | 498 | } |
363 | 499 | return s |
364 | 500 | } |
365 | 501 | |
502 | // NewRecord check if value's primary key is blank | |
366 | 503 | func (s *DB) NewRecord(value interface{}) bool { |
367 | return s.clone().NewScope(value).PrimaryKeyZero() | |
368 | } | |
369 | ||
504 | return s.NewScope(value).PrimaryKeyZero() | |
505 | } | |
506 | ||
507 | // RecordNotFound check if returning ErrRecordNotFound error | |
370 | 508 | func (s *DB) RecordNotFound() bool { |
371 | return s.Error == RecordNotFound | |
372 | } | |
373 | ||
374 | // Migrations | |
375 | func (s *DB) CreateTable(values ...interface{}) *DB { | |
376 | db := s.clone() | |
377 | for _, value := range values { | |
378 | db = db.NewScope(value).createTable().db | |
509 | for _, err := range s.GetErrors() { | |
510 | if err == ErrRecordNotFound { | |
511 | return true | |
512 | } | |
513 | } | |
514 | return false | |
515 | } | |
516 | ||
517 | // CreateTable create table for models | |
518 | func (s *DB) CreateTable(models ...interface{}) *DB { | |
519 | db := s.Unscoped() | |
520 | for _, model := range models { | |
521 | db = db.NewScope(model).createTable().db | |
379 | 522 | } |
380 | 523 | return db |
381 | 524 | } |
382 | 525 | |
526 | // DropTable drop table for models | |
383 | 527 | func (s *DB) DropTable(values ...interface{}) *DB { |
384 | 528 | db := s.clone() |
385 | 529 | for _, value := range values { |
530 | if tableName, ok := value.(string); ok { | |
531 | db = db.Table(tableName) | |
532 | } | |
533 | ||
386 | 534 | db = db.NewScope(value).dropTable().db |
387 | 535 | } |
388 | 536 | return db |
389 | 537 | } |
390 | 538 | |
539 | // DropTableIfExists drop table if it is exist | |
391 | 540 | func (s *DB) DropTableIfExists(values ...interface{}) *DB { |
392 | 541 | db := s.clone() |
393 | 542 | for _, value := range values { |
394 | db = db.NewScope(value).dropTableIfExists().db | |
543 | if s.HasTable(value) { | |
544 | db.AddError(s.DropTable(value).Error) | |
545 | } | |
395 | 546 | } |
396 | 547 | return db |
397 | 548 | } |
398 | 549 | |
550 | // HasTable check has table or not | |
399 | 551 | func (s *DB) HasTable(value interface{}) bool { |
400 | scope := s.clone().NewScope(value) | |
401 | tableName := scope.TableName() | |
402 | has := scope.Dialect().HasTable(scope, tableName) | |
552 | var ( | |
553 | scope = s.NewScope(value) | |
554 | tableName string | |
555 | ) | |
556 | ||
557 | if name, ok := value.(string); ok { | |
558 | tableName = name | |
559 | } else { | |
560 | tableName = scope.TableName() | |
561 | } | |
562 | ||
563 | has := scope.Dialect().HasTable(tableName) | |
403 | 564 | s.AddError(scope.db.Error) |
404 | 565 | return has |
405 | 566 | } |
406 | 567 | |
568 | // AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data | |
407 | 569 | func (s *DB) AutoMigrate(values ...interface{}) *DB { |
408 | db := s.clone() | |
570 | db := s.Unscoped() | |
409 | 571 | for _, value := range values { |
410 | db = db.NewScope(value).NeedPtr().autoMigrate().db | |
572 | db = db.NewScope(value).autoMigrate().db | |
411 | 573 | } |
412 | 574 | return db |
413 | 575 | } |
414 | 576 | |
577 | // ModifyColumn modify column to type | |
415 | 578 | func (s *DB) ModifyColumn(column string, typ string) *DB { |
416 | scope := s.clone().NewScope(s.Value) | |
579 | scope := s.NewScope(s.Value) | |
417 | 580 | scope.modifyColumn(column, typ) |
418 | 581 | return scope.db |
419 | 582 | } |
420 | 583 | |
584 | // DropColumn drop a column | |
421 | 585 | func (s *DB) DropColumn(column string) *DB { |
422 | scope := s.clone().NewScope(s.Value) | |
586 | scope := s.NewScope(s.Value) | |
423 | 587 | scope.dropColumn(column) |
424 | 588 | return scope.db |
425 | 589 | } |
426 | 590 | |
427 | func (s *DB) AddIndex(indexName string, column ...string) *DB { | |
428 | scope := s.clone().NewScope(s.Value) | |
429 | scope.addIndex(false, indexName, column...) | |
591 | // AddIndex add index for columns with given name | |
592 | func (s *DB) AddIndex(indexName string, columns ...string) *DB { | |
593 | scope := s.Unscoped().NewScope(s.Value) | |
594 | scope.addIndex(false, indexName, columns...) | |
430 | 595 | return scope.db |
431 | 596 | } |
432 | 597 | |
433 | func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB { | |
434 | scope := s.clone().NewScope(s.Value) | |
435 | scope.addIndex(true, indexName, column...) | |
598 | // AddUniqueIndex add unique index for columns with given name | |
599 | func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { | |
600 | scope := s.Unscoped().NewScope(s.Value) | |
601 | scope.addIndex(true, indexName, columns...) | |
436 | 602 | return scope.db |
437 | 603 | } |
438 | 604 | |
605 | // RemoveIndex remove index with name | |
439 | 606 | func (s *DB) RemoveIndex(indexName string) *DB { |
440 | scope := s.clone().NewScope(s.Value) | |
607 | scope := s.NewScope(s.Value) | |
441 | 608 | scope.removeIndex(indexName) |
442 | 609 | return scope.db |
443 | 610 | } |
444 | 611 | |
445 | func (s *DB) CurrentDatabase() string { | |
446 | var ( | |
447 | scope = s.clone().NewScope(s.Value) | |
448 | name = s.dialect.CurrentDatabase(scope) | |
449 | ) | |
450 | return name | |
451 | } | |
452 | ||
453 | /* | |
454 | Add foreign key to the given scope | |
455 | ||
456 | Example: | |
457 | db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") | |
458 | */ | |
612 | // AddForeignKey Add foreign key to the given scope, e.g: | |
613 | // db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") | |
459 | 614 | func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { |
460 | scope := s.clone().NewScope(s.Value) | |
615 | scope := s.NewScope(s.Value) | |
461 | 616 | scope.addForeignKey(field, dest, onDelete, onUpdate) |
462 | 617 | return scope.db |
463 | 618 | } |
464 | 619 | |
620 | // RemoveForeignKey Remove foreign key from the given scope, e.g: | |
621 | // db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") | |
622 | func (s *DB) RemoveForeignKey(field string, dest string) *DB { | |
623 | scope := s.clone().NewScope(s.Value) | |
624 | scope.removeForeignKey(field, dest) | |
625 | return scope.db | |
626 | } | |
627 | ||
628 | // Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode | |
465 | 629 | func (s *DB) Association(column string) *Association { |
466 | 630 | var err error |
467 | scope := s.clone().NewScope(s.Value) | |
631 | var scope = s.Set("gorm:association:source", s.Value).NewScope(s.Value) | |
468 | 632 | |
469 | 633 | if primaryField := scope.PrimaryField(); primaryField.IsBlank { |
470 | 634 | err = errors.New("primary key can't be nil") |
473 | 637 | if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { |
474 | 638 | err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) |
475 | 639 | } else { |
476 | return &Association{Scope: scope, Column: column, Field: field} | |
640 | return &Association{scope: scope, column: column, field: field} | |
477 | 641 | } |
478 | 642 | } else { |
479 | 643 | err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) |
483 | 647 | return &Association{Error: err} |
484 | 648 | } |
485 | 649 | |
650 | // Preload preload associations with given conditions | |
651 | // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) | |
486 | 652 | func (s *DB) Preload(column string, conditions ...interface{}) *DB { |
487 | 653 | return s.clone().search.Preload(column, conditions...).db |
488 | 654 | } |
489 | 655 | |
490 | // Set set value by name | |
656 | // Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting | |
491 | 657 | func (s *DB) Set(name string, value interface{}) *DB { |
492 | 658 | return s.clone().InstantSet(name, value) |
493 | 659 | } |
494 | 660 | |
661 | // InstantSet instant set setting, will affect current db | |
495 | 662 | func (s *DB) InstantSet(name string, value interface{}) *DB { |
496 | 663 | s.values[name] = value |
497 | 664 | return s |
498 | 665 | } |
499 | 666 | |
500 | // Get get value by name | |
667 | // Get get setting by name | |
501 | 668 | func (s *DB) Get(name string) (value interface{}, ok bool) { |
502 | 669 | value, ok = s.values[name] |
503 | 670 | return |
504 | 671 | } |
505 | 672 | |
673 | // SetJoinTableHandler set a model's join table handler for a relation | |
506 | 674 | func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { |
507 | 675 | scope := s.NewScope(source) |
508 | 676 | for _, field := range scope.GetModelStruct().StructFields { |
509 | 677 | if field.Name == column || field.DBName == column { |
510 | if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" { | |
678 | if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { | |
511 | 679 | source := (&Scope{Value: source}).GetModelStruct().ModelType |
512 | 680 | destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType |
513 | 681 | handler.Setup(field.Relationship, many2many, source, destination) |
514 | 682 | field.Relationship.JoinTableHandler = handler |
515 | if table := handler.Table(s); scope.Dialect().HasTable(scope, table) { | |
683 | if table := handler.Table(s); scope.Dialect().HasTable(table) { | |
516 | 684 | s.Table(table).AutoMigrate(handler) |
517 | 685 | } |
518 | 686 | } |
520 | 688 | } |
521 | 689 | } |
522 | 690 | |
691 | // AddError add error to the db | |
523 | 692 | func (s *DB) AddError(err error) error { |
524 | 693 | if err != nil { |
525 | if err != RecordNotFound { | |
694 | if err != ErrRecordNotFound { | |
526 | 695 | if s.logMode == 0 { |
527 | 696 | go s.print(fileWithLineNum(), err) |
528 | 697 | } else { |
529 | 698 | s.log(err) |
530 | 699 | } |
531 | 700 | |
532 | errors := Errors{errors: s.GetErrors()} | |
533 | errors.Add(err) | |
534 | if len(errors.GetErrors()) > 1 { | |
701 | errors := Errors(s.GetErrors()) | |
702 | errors = errors.Add(err) | |
703 | if len(errors) > 1 { | |
535 | 704 | err = errors |
536 | 705 | } |
537 | 706 | } |
541 | 710 | return err |
542 | 711 | } |
543 | 712 | |
544 | func (s *DB) GetErrors() (errors []error) { | |
545 | if errs, ok := s.Error.(errorsInterface); ok { | |
546 | return errs.GetErrors() | |
713 | // GetErrors get happened errors from the db | |
714 | func (s *DB) GetErrors() []error { | |
715 | if errs, ok := s.Error.(Errors); ok { | |
716 | return errs | |
547 | 717 | } else if s.Error != nil { |
548 | 718 | return []error{s.Error} |
549 | 719 | } |
550 | return | |
551 | } | |
720 | return []error{} | |
721 | } | |
722 | ||
723 | //////////////////////////////////////////////////////////////////////////////// | |
724 | // Private Methods For DB | |
725 | //////////////////////////////////////////////////////////////////////////////// | |
726 | ||
727 | func (s *DB) clone() *DB { | |
728 | db := &DB{ | |
729 | db: s.db, | |
730 | parent: s.parent, | |
731 | logger: s.logger, | |
732 | logMode: s.logMode, | |
733 | values: map[string]interface{}{}, | |
734 | Value: s.Value, | |
735 | Error: s.Error, | |
736 | blockGlobalUpdate: s.blockGlobalUpdate, | |
737 | } | |
738 | ||
739 | for key, value := range s.values { | |
740 | db.values[key] = value | |
741 | } | |
742 | ||
743 | if s.search == nil { | |
744 | db.search = &search{limit: -1, offset: -1} | |
745 | } else { | |
746 | db.search = s.search.clone() | |
747 | } | |
748 | ||
749 | db.search.db = db | |
750 | return db | |
751 | } | |
752 | ||
753 | func (s *DB) print(v ...interface{}) { | |
754 | s.logger.Print(v...) | |
755 | } | |
756 | ||
757 | func (s *DB) log(v ...interface{}) { | |
758 | if s != nil && s.logMode == 2 { | |
759 | s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) | |
760 | } | |
761 | } | |
762 | ||
763 | func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { | |
764 | if s.logMode == 2 { | |
765 | s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) | |
766 | } | |
767 | } |
0 | package gorm | |
1 | ||
2 | import "time" | |
3 | ||
4 | func (s *DB) clone() *DB { | |
5 | db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} | |
6 | ||
7 | for key, value := range s.values { | |
8 | db.values[key] = value | |
9 | } | |
10 | ||
11 | if s.search == nil { | |
12 | db.search = &search{} | |
13 | } else { | |
14 | db.search = s.search.clone() | |
15 | } | |
16 | ||
17 | db.search.db = &db | |
18 | return &db | |
19 | } | |
20 | ||
21 | func (s *DB) print(v ...interface{}) { | |
22 | s.logger.(logger).Print(v...) | |
23 | } | |
24 | ||
25 | func (s *DB) log(v ...interface{}) { | |
26 | if s != nil && s.logMode == 2 { | |
27 | s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) | |
28 | } | |
29 | } | |
30 | ||
31 | func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { | |
32 | if s.logMode == 2 { | |
33 | s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars) | |
34 | } | |
35 | } |
3 | 3 | "database/sql" |
4 | 4 | "database/sql/driver" |
5 | 5 | "fmt" |
6 | "os" | |
7 | "path/filepath" | |
8 | "reflect" | |
6 | 9 | "strconv" |
7 | ||
8 | _ "github.com/denisenkom/go-mssqldb" | |
9 | testdb "github.com/erikstmartin/go-testdb" | |
10 | _ "github.com/go-sql-driver/mysql" | |
11 | "github.com/jinzhu/gorm" | |
12 | "github.com/jinzhu/now" | |
13 | _ "github.com/lib/pq" | |
14 | _ "github.com/mattn/go-sqlite3" | |
15 | ||
16 | "os" | |
17 | 10 | "testing" |
18 | 11 | "time" |
12 | ||
13 | "github.com/erikstmartin/go-testdb" | |
14 | "github.com/jinzhu/gorm" | |
15 | _ "github.com/jinzhu/gorm/dialects/mssql" | |
16 | _ "github.com/jinzhu/gorm/dialects/mysql" | |
17 | "github.com/jinzhu/gorm/dialects/postgres" | |
18 | _ "github.com/jinzhu/gorm/dialects/sqlite" | |
19 | "github.com/jinzhu/now" | |
19 | 20 | ) |
20 | 21 | |
21 | 22 | var ( |
22 | DB gorm.DB | |
23 | DB *gorm.DB | |
23 | 24 | t1, t2, t3, t4, t5 time.Time |
24 | 25 | ) |
25 | 26 | |
30 | 31 | panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err)) |
31 | 32 | } |
32 | 33 | |
33 | // DB.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) | |
34 | // DB.SetLogger(log.New(os.Stdout, "\r\n", 0)) | |
35 | // DB.LogMode(true) | |
36 | DB.LogMode(false) | |
37 | ||
38 | DB.DB().SetMaxIdleConns(10) | |
39 | ||
40 | 34 | runMigration() |
41 | 35 | } |
42 | 36 | |
43 | func OpenTestConnection() (db gorm.DB, err error) { | |
37 | func OpenTestConnection() (db *gorm.DB, err error) { | |
38 | dbDSN := os.Getenv("GORM_DSN") | |
44 | 39 | switch os.Getenv("GORM_DIALECT") { |
45 | 40 | case "mysql": |
46 | // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; | |
47 | // CREATE DATABASE gorm; | |
48 | // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; | |
49 | 41 | fmt.Println("testing mysql...") |
50 | db, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") | |
42 | if dbDSN == "" { | |
43 | dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" | |
44 | } | |
45 | db, err = gorm.Open("mysql", dbDSN) | |
51 | 46 | case "postgres": |
52 | 47 | fmt.Println("testing postgres...") |
53 | db, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable") | |
54 | case "foundation": | |
55 | fmt.Println("testing foundation...") | |
56 | db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") | |
48 | if dbDSN == "" { | |
49 | dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" | |
50 | } | |
51 | db, err = gorm.Open("postgres", dbDSN) | |
57 | 52 | case "mssql": |
53 | // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; | |
54 | // CREATE DATABASE gorm; | |
55 | // USE gorm; | |
56 | // CREATE USER gorm FROM LOGIN gorm; | |
57 | // sp_changedbowner 'gorm'; | |
58 | 58 | fmt.Println("testing mssql...") |
59 | db, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433") | |
59 | if dbDSN == "" { | |
60 | dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" | |
61 | } | |
62 | db, err = gorm.Open("mssql", dbDSN) | |
60 | 63 | default: |
61 | 64 | fmt.Println("testing sqlite3...") |
62 | db, err = gorm.Open("sqlite3", "/tmp/gorm.db") | |
63 | } | |
65 | db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) | |
66 | } | |
67 | ||
68 | // db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) | |
69 | // db.SetLogger(log.New(os.Stdout, "\r\n", 0)) | |
70 | if debug := os.Getenv("DEBUG"); debug == "true" { | |
71 | db.LogMode(true) | |
72 | } else if debug == "false" { | |
73 | db.LogMode(false) | |
74 | } | |
75 | ||
76 | db.DB().SetMaxIdleConns(10) | |
77 | ||
64 | 78 | return |
65 | 79 | } |
66 | 80 | |
69 | 83 | ID string `gorm:"primary_key"` |
70 | 84 | Name string |
71 | 85 | } |
86 | DB.DropTable(&UUIDStruct{}) | |
72 | 87 | DB.AutoMigrate(&UUIDStruct{}) |
73 | 88 | |
74 | 89 | data := UUIDStruct{ID: "uuid", Name: "hello"} |
75 | if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" { | |
90 | if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.Name != "hello" { | |
91 | t.Errorf("string primary key should not be populated") | |
92 | } | |
93 | ||
94 | data = UUIDStruct{ID: "uuid", Name: "hello world"} | |
95 | if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.Name != "hello world" { | |
76 | 96 | t.Errorf("string primary key should not be populated") |
77 | 97 | } |
78 | 98 | } |
113 | 133 | DB.Create(getPreparedUser("pluck_user3", "pluck_user")) |
114 | 134 | |
115 | 135 | if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil { |
116 | t.Errorf("No errors should happen if set table for pluck", err.Error()) | |
136 | t.Error("No errors should happen if set table for pluck", err) | |
117 | 137 | } |
118 | 138 | |
119 | 139 | var users []User |
163 | 183 | Stuff string |
164 | 184 | } |
165 | 185 | DB.DropTable(&Foo{}) |
186 | ||
187 | // Table should not exist at this point, HasTable should return false | |
188 | if ok := DB.HasTable("foos"); ok { | |
189 | t.Errorf("Table should not exist, but does") | |
190 | } | |
166 | 191 | if ok := DB.HasTable(&Foo{}); ok { |
167 | 192 | t.Errorf("Table should not exist, but does") |
168 | 193 | } |
194 | ||
195 | // We create the table | |
169 | 196 | if err := DB.CreateTable(&Foo{}).Error; err != nil { |
170 | 197 | t.Errorf("Table should be created") |
198 | } | |
199 | ||
200 | // And now it should exits, and HasTable should return true | |
201 | if ok := DB.HasTable("foos"); !ok { | |
202 | t.Errorf("Table should exist, but HasTable informs it does not") | |
171 | 203 | } |
172 | 204 | if ok := DB.HasTable(&Foo{}); !ok { |
173 | 205 | t.Errorf("Table should exist, but HasTable informs it does not") |
227 | 259 | DB.SingularTable(false) |
228 | 260 | } |
229 | 261 | |
230 | func TestSqlNullValue(t *testing.T) { | |
262 | func TestNullValues(t *testing.T) { | |
231 | 263 | DB.DropTable(&NullValue{}) |
232 | 264 | DB.AutoMigrate(&NullValue{}) |
233 | 265 | |
234 | if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello", Valid: true}, | |
266 | if err := DB.Save(&NullValue{ | |
267 | Name: sql.NullString{String: "hello", Valid: true}, | |
268 | Gender: &sql.NullString{String: "M", Valid: true}, | |
235 | 269 | Age: sql.NullInt64{Int64: 18, Valid: true}, |
236 | 270 | Male: sql.NullBool{Bool: true, Valid: true}, |
237 | 271 | Height: sql.NullFloat64{Float64: 100.11, Valid: true}, |
243 | 277 | var nv NullValue |
244 | 278 | DB.First(&nv, "name = ?", "hello") |
245 | 279 | |
246 | if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true { | |
280 | if nv.Name.String != "hello" || nv.Gender.String != "M" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true { | |
247 | 281 | t.Errorf("Should be able to fetch null value") |
248 | 282 | } |
249 | 283 | |
250 | if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello-2", Valid: true}, | |
284 | if err := DB.Save(&NullValue{ | |
285 | Name: sql.NullString{String: "hello-2", Valid: true}, | |
286 | Gender: &sql.NullString{String: "F", Valid: true}, | |
251 | 287 | Age: sql.NullInt64{Int64: 18, Valid: false}, |
252 | 288 | Male: sql.NullBool{Bool: true, Valid: true}, |
253 | 289 | Height: sql.NullFloat64{Float64: 100.11, Valid: true}, |
258 | 294 | |
259 | 295 | var nv2 NullValue |
260 | 296 | DB.First(&nv2, "name = ?", "hello-2") |
261 | if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { | |
297 | if nv2.Name.String != "hello-2" || nv2.Gender.String != "F" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { | |
262 | 298 | t.Errorf("Should be able to fetch null value") |
263 | 299 | } |
264 | 300 | |
265 | if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello-3", Valid: false}, | |
301 | if err := DB.Save(&NullValue{ | |
302 | Name: sql.NullString{String: "hello-3", Valid: false}, | |
303 | Gender: &sql.NullString{String: "M", Valid: true}, | |
266 | 304 | Age: sql.NullInt64{Int64: 18, Valid: false}, |
267 | 305 | Male: sql.NullBool{Bool: true, Valid: true}, |
268 | 306 | Height: sql.NullFloat64{Float64: 100.11, Valid: true}, |
272 | 310 | } |
273 | 311 | } |
274 | 312 | |
313 | func TestNullValuesWithFirstOrCreate(t *testing.T) { | |
314 | var nv1 = NullValue{ | |
315 | Name: sql.NullString{String: "first_or_create", Valid: true}, | |
316 | Gender: &sql.NullString{String: "M", Valid: true}, | |
317 | } | |
318 | ||
319 | var nv2 NullValue | |
320 | result := DB.Where(nv1).FirstOrCreate(&nv2) | |
321 | ||
322 | if result.RowsAffected != 1 { | |
323 | t.Errorf("RowsAffected should be 1 after create some record") | |
324 | } | |
325 | ||
326 | if result.Error != nil { | |
327 | t.Errorf("Should not raise any error, but got %v", result.Error) | |
328 | } | |
329 | ||
330 | if nv2.Name.String != "first_or_create" || nv2.Gender.String != "M" { | |
331 | t.Errorf("first or create with nullvalues") | |
332 | } | |
333 | ||
334 | if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil { | |
335 | t.Errorf("Should not raise any error, but got %v", err) | |
336 | } | |
337 | ||
338 | if nv2.Age.Int64 != 18 { | |
339 | t.Errorf("should update age to 18") | |
340 | } | |
341 | } | |
342 | ||
275 | 343 | func TestTransaction(t *testing.T) { |
276 | 344 | tx := DB.Begin() |
277 | 345 | u := User{Name: "transcation"} |
311 | 379 | } |
312 | 380 | |
313 | 381 | func TestRow(t *testing.T) { |
314 | user1 := User{Name: "RowUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} | |
315 | user2 := User{Name: "RowUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} | |
316 | user3 := User{Name: "RowUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} | |
382 | user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")} | |
383 | user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")} | |
384 | user3 := User{Name: "RowUser3", Age: 20, Birthday: parseTime("2020-1-1")} | |
317 | 385 | DB.Save(&user1).Save(&user2).Save(&user3) |
318 | 386 | |
319 | 387 | row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row() |
325 | 393 | } |
326 | 394 | |
327 | 395 | func TestRows(t *testing.T) { |
328 | user1 := User{Name: "RowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} | |
329 | user2 := User{Name: "RowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} | |
330 | user3 := User{Name: "RowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} | |
396 | user1 := User{Name: "RowsUser1", Age: 1, Birthday: parseTime("2000-1-1")} | |
397 | user2 := User{Name: "RowsUser2", Age: 10, Birthday: parseTime("2010-1-1")} | |
398 | user3 := User{Name: "RowsUser3", Age: 20, Birthday: parseTime("2020-1-1")} | |
331 | 399 | DB.Save(&user1).Save(&user2).Save(&user3) |
332 | 400 | |
333 | 401 | rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() |
334 | 402 | if err != nil { |
335 | t.Errorf("Not error should happen, but got") | |
403 | t.Errorf("Not error should happen, got %v", err) | |
336 | 404 | } |
337 | 405 | |
338 | 406 | count := 0 |
342 | 410 | rows.Scan(&name, &age) |
343 | 411 | count++ |
344 | 412 | } |
413 | ||
345 | 414 | if count != 2 { |
346 | t.Errorf("Should found two records with name 3") | |
415 | t.Errorf("Should found two records") | |
416 | } | |
417 | } | |
418 | ||
419 | func TestScanRows(t *testing.T) { | |
420 | user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: parseTime("2000-1-1")} | |
421 | user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: parseTime("2010-1-1")} | |
422 | user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: parseTime("2020-1-1")} | |
423 | DB.Save(&user1).Save(&user2).Save(&user3) | |
424 | ||
425 | rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() | |
426 | if err != nil { | |
427 | t.Errorf("Not error should happen, got %v", err) | |
428 | } | |
429 | ||
430 | type Result struct { | |
431 | Name string | |
432 | Age int | |
433 | } | |
434 | ||
435 | var results []Result | |
436 | for rows.Next() { | |
437 | var result Result | |
438 | if err := DB.ScanRows(rows, &result); err != nil { | |
439 | t.Errorf("should get no error, but got %v", err) | |
440 | } | |
441 | results = append(results, result) | |
442 | } | |
443 | ||
444 | if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { | |
445 | t.Errorf("Should find expected results") | |
347 | 446 | } |
348 | 447 | } |
349 | 448 | |
350 | 449 | func TestScan(t *testing.T) { |
351 | user1 := User{Name: "ScanUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} | |
352 | user2 := User{Name: "ScanUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} | |
353 | user3 := User{Name: "ScanUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} | |
450 | user1 := User{Name: "ScanUser1", Age: 1, Birthday: parseTime("2000-1-1")} | |
451 | user2 := User{Name: "ScanUser2", Age: 10, Birthday: parseTime("2010-1-1")} | |
452 | user3 := User{Name: "ScanUser3", Age: 20, Birthday: parseTime("2020-1-1")} | |
354 | 453 | DB.Save(&user1).Save(&user2).Save(&user3) |
355 | 454 | |
356 | 455 | type result struct { |
364 | 463 | t.Errorf("Scan into struct should work") |
365 | 464 | } |
366 | 465 | |
367 | var doubleAgeRes result | |
368 | DB.Table("users").Select("age + age as age").Where("name = ?", user3.Name).Scan(&doubleAgeRes) | |
466 | var doubleAgeRes = &result{} | |
467 | if err := DB.Table("users").Select("age + age as age").Where("name = ?", user3.Name).Scan(&doubleAgeRes).Error; err != nil { | |
468 | t.Errorf("Scan to pointer of pointer") | |
469 | } | |
369 | 470 | if doubleAgeRes.Age != res.Age*2 { |
370 | 471 | t.Errorf("Scan double age as age") |
371 | 472 | } |
378 | 479 | } |
379 | 480 | |
380 | 481 | func TestRaw(t *testing.T) { |
381 | user1 := User{Name: "ExecRawSqlUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} | |
382 | user2 := User{Name: "ExecRawSqlUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} | |
383 | user3 := User{Name: "ExecRawSqlUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} | |
482 | user1 := User{Name: "ExecRawSqlUser1", Age: 1, Birthday: parseTime("2000-1-1")} | |
483 | user2 := User{Name: "ExecRawSqlUser2", Age: 10, Birthday: parseTime("2010-1-1")} | |
484 | user3 := User{Name: "ExecRawSqlUser3", Age: 20, Birthday: parseTime("2020-1-1")} | |
384 | 485 | DB.Save(&user1).Save(&user2).Save(&user3) |
385 | 486 | |
386 | 487 | type result struct { |
404 | 505 | } |
405 | 506 | |
406 | 507 | DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) |
407 | if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound { | |
508 | if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound { | |
408 | 509 | t.Error("Raw sql to update records") |
409 | 510 | } |
410 | 511 | } |
425 | 526 | |
426 | 527 | func TestJoins(t *testing.T) { |
427 | 528 | var user = User{ |
428 | Name: "joins", | |
429 | Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, | |
430 | } | |
431 | DB.Save(&user) | |
432 | ||
433 | var result User | |
434 | DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result) | |
435 | if result.Name != "joins" || result.Id != user.Id { | |
436 | t.Errorf("Should find all two emails with Join") | |
529 | Name: "joins", | |
530 | CreditCard: CreditCard{Number: "411111111111"}, | |
531 | Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, | |
532 | } | |
533 | DB.Save(&user) | |
534 | ||
535 | var users1 []User | |
536 | DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1) | |
537 | if len(users1) != 2 { | |
538 | t.Errorf("should find two users using left join") | |
539 | } | |
540 | ||
541 | var users2 []User | |
542 | DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2) | |
543 | if len(users2) != 1 { | |
544 | t.Errorf("should find one users using left join with conditions") | |
545 | } | |
546 | ||
547 | var users3 []User | |
548 | DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where("name = ?", "joins").First(&users3) | |
549 | if len(users3) != 1 { | |
550 | t.Errorf("should find one users using multiple left join conditions") | |
551 | } | |
552 | ||
553 | var users4 []User | |
554 | DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "422222222222").Where("name = ?", "joins").First(&users4) | |
555 | if len(users4) != 0 { | |
556 | t.Errorf("should find no user when searching with unexisting credit card") | |
557 | } | |
558 | ||
559 | var users5 []User | |
560 | db5 := DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where(User{Id: 1}).Where(Email{Id: 1}).Not(Email{Id: 10}).First(&users5) | |
561 | if db5.Error != nil { | |
562 | t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) | |
437 | 563 | } |
438 | 564 | } |
439 | 565 | |
450 | 576 | DB.Save(&user) |
451 | 577 | |
452 | 578 | var results []result |
453 | DB.Table("users").Select("name, email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results) | |
579 | DB.Table("users").Select("name, emails.email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results) | |
454 | 580 | if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" { |
455 | 581 | t.Errorf("Should find all two emails with Join select") |
456 | 582 | } |
478 | 604 | } |
479 | 605 | } |
480 | 606 | |
607 | func TestQueryBuilderSubselectInWhere(t *testing.T) { | |
608 | user := User{Name: "query_expr_select_ruser1", Email: "root@user1.com", Age: 32} | |
609 | DB.Save(&user) | |
610 | user = User{Name: "query_expr_select_ruser2", Email: "nobody@user2.com", Age: 16} | |
611 | DB.Save(&user) | |
612 | user = User{Name: "query_expr_select_ruser3", Email: "root@user3.com", Age: 64} | |
613 | DB.Save(&user) | |
614 | user = User{Name: "query_expr_select_ruser4", Email: "somebody@user3.com", Age: 128} | |
615 | DB.Save(&user) | |
616 | ||
617 | var users []User | |
618 | DB.Select("*").Where("name IN (?)", DB. | |
619 | Select("name").Table("users").Where("name LIKE ?", "query_expr_select%").QueryExpr()).Find(&users) | |
620 | ||
621 | if len(users) != 4 { | |
622 | t.Errorf("Four users should be found, instead found %d", len(users)) | |
623 | } | |
624 | ||
625 | DB.Select("*").Where("name LIKE ?", "query_expr_select%").Where("age >= (?)", DB. | |
626 | Select("AVG(age)").Table("users").Where("name LIKE ?", "query_expr_select%").QueryExpr()).Find(&users) | |
627 | ||
628 | if len(users) != 2 { | |
629 | t.Errorf("Two users should be found, instead found %d", len(users)) | |
630 | } | |
631 | } | |
632 | ||
633 | func TestQueryBuilderRawQueryWithSubquery(t *testing.T) { | |
634 | user := User{Name: "subquery_test_user1", Age: 10} | |
635 | DB.Save(&user) | |
636 | user = User{Name: "subquery_test_user2", Age: 11} | |
637 | DB.Save(&user) | |
638 | user = User{Name: "subquery_test_user3", Age: 12} | |
639 | DB.Save(&user) | |
640 | ||
641 | var count int | |
642 | err := DB.Raw("select count(*) from (?) tmp", | |
643 | DB.Table("users"). | |
644 | Select("name"). | |
645 | Where("age >= ? and name in (?)", 10, []string{"subquery_test_user1", "subquery_test_user2"}). | |
646 | Group("name"). | |
647 | QueryExpr(), | |
648 | ).Count(&count).Error | |
649 | ||
650 | if err != nil { | |
651 | t.Errorf("Expected to get no errors, but got %v", err) | |
652 | } | |
653 | if count != 2 { | |
654 | t.Errorf("Row count must be 2, instead got %d", count) | |
655 | } | |
656 | ||
657 | err = DB.Raw("select count(*) from (?) tmp", | |
658 | DB.Table("users"). | |
659 | Select("name"). | |
660 | Where("name LIKE ?", "subquery_test%"). | |
661 | Not("age <= ?", 10).Not("name in (?)", []string{"subquery_test_user1", "subquery_test_user2"}). | |
662 | Group("name"). | |
663 | QueryExpr(), | |
664 | ).Count(&count).Error | |
665 | ||
666 | if err != nil { | |
667 | t.Errorf("Expected to get no errors, but got %v", err) | |
668 | } | |
669 | if count != 1 { | |
670 | t.Errorf("Row count must be 1, instead got %d", count) | |
671 | } | |
672 | } | |
673 | ||
674 | func TestQueryBuilderSubselectInHaving(t *testing.T) { | |
675 | user := User{Name: "query_expr_having_ruser1", Email: "root@user1.com", Age: 64} | |
676 | DB.Save(&user) | |
677 | user = User{Name: "query_expr_having_ruser2", Email: "root@user2.com", Age: 128} | |
678 | DB.Save(&user) | |
679 | user = User{Name: "query_expr_having_ruser3", Email: "root@user1.com", Age: 64} | |
680 | DB.Save(&user) | |
681 | user = User{Name: "query_expr_having_ruser4", Email: "root@user2.com", Age: 128} | |
682 | DB.Save(&user) | |
683 | ||
684 | var users []User | |
685 | DB.Select("AVG(age) as avgage").Where("name LIKE ?", "query_expr_having_%").Group("email").Having("AVG(age) > (?)", DB. | |
686 | Select("AVG(age)").Where("name LIKE ?", "query_expr_having_%").Table("users").QueryExpr()).Find(&users) | |
687 | ||
688 | if len(users) != 1 { | |
689 | t.Errorf("Two user group should be found, instead found %d", len(users)) | |
690 | } | |
691 | } | |
692 | ||
481 | 693 | func DialectHasTzSupport() bool { |
482 | 694 | // NB: mssql and FoundationDB do not support time zones. |
483 | if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" { | |
695 | if dialect := os.Getenv("GORM_DIALECT"); dialect == "foundation" { | |
484 | 696 | return false |
485 | 697 | } |
486 | 698 | return true |
495 | 707 | |
496 | 708 | for index, vtime := range times { |
497 | 709 | name := "time_with_zone_" + strconv.Itoa(index) |
498 | user := User{Name: name, Birthday: vtime} | |
710 | user := User{Name: name, Birthday: &vtime} | |
499 | 711 | |
500 | 712 | if !DialectHasTzSupport() { |
501 | 713 | // If our driver dialect doesn't support TZ's, just use UTC for everything here. |
502 | user.Birthday = vtime.UTC() | |
714 | utcBirthday := user.Birthday.UTC() | |
715 | user.Birthday = &utcBirthday | |
503 | 716 | } |
504 | 717 | |
505 | 718 | DB.Save(&user) |
513 | 726 | DB.First(&findUser, "name = ?", name) |
514 | 727 | foundBirthday = findUser.Birthday.UTC().Format(format) |
515 | 728 | if foundBirthday != expectedBirthday { |
516 | t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v or %+v", name, expectedBirthday, foundBirthday) | |
729 | t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) | |
517 | 730 | } |
518 | 731 | |
519 | 732 | if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() { |
529 | 742 | func TestHstore(t *testing.T) { |
530 | 743 | type Details struct { |
531 | 744 | Id int64 |
532 | Bulk gorm.Hstore | |
745 | Bulk postgres.Hstore | |
533 | 746 | } |
534 | 747 | |
535 | 748 | if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { |
615 | 828 | } |
616 | 829 | |
617 | 830 | var user User |
618 | if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound { | |
831 | if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound { | |
619 | 832 | t.Errorf("Should have found existing record") |
833 | } | |
834 | } | |
835 | ||
836 | func TestDdlErrors(t *testing.T) { | |
837 | var err error | |
838 | ||
839 | if err = DB.Close(); err != nil { | |
840 | t.Errorf("Closing DDL test db connection err=%s", err) | |
841 | } | |
842 | defer func() { | |
843 | // Reopen DB connection. | |
844 | if DB, err = OpenTestConnection(); err != nil { | |
845 | t.Fatalf("Failed re-opening db connection: %s", err) | |
846 | } | |
847 | }() | |
848 | ||
849 | if err := DB.Find(&User{}).Error; err == nil { | |
850 | t.Errorf("Expected operation on closed db to produce an error, but err was nil") | |
851 | } | |
852 | } | |
853 | ||
854 | func TestOpenWithOneParameter(t *testing.T) { | |
855 | db, err := gorm.Open("dialect") | |
856 | if db != nil { | |
857 | t.Error("Open with one parameter returned non nil for db") | |
858 | } | |
859 | if err == nil { | |
860 | t.Error("Open with one parameter returned err as nil") | |
861 | } | |
862 | } | |
863 | ||
864 | func TestBlockGlobalUpdate(t *testing.T) { | |
865 | db := DB.New() | |
866 | db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) | |
867 | ||
868 | err := db.Model(&Toy{}).Update("OwnerType", "Human").Error | |
869 | if err != nil { | |
870 | t.Error("Unexpected error on global update") | |
871 | } | |
872 | ||
873 | err = db.Delete(&Toy{}).Error | |
874 | if err != nil { | |
875 | t.Error("Unexpected error on global delete") | |
876 | } | |
877 | ||
878 | db.BlockGlobalUpdate(true) | |
879 | ||
880 | db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) | |
881 | ||
882 | err = db.Model(&Toy{}).Update("OwnerType", "Human").Error | |
883 | if err == nil { | |
884 | t.Error("Expected error on global update") | |
885 | } | |
886 | ||
887 | err = db.Model(&Toy{}).Where(&Toy{OwnerType: "Martian"}).Update("OwnerType", "Astronaut").Error | |
888 | if err != nil { | |
889 | t.Error("Unxpected error on conditional update") | |
890 | } | |
891 | ||
892 | err = db.Delete(&Toy{}).Error | |
893 | if err == nil { | |
894 | t.Error("Expected error on global delete") | |
895 | } | |
896 | err = db.Where(&Toy{OwnerType: "Martian"}).Delete(&Toy{}).Error | |
897 | if err != nil { | |
898 | t.Error("Unexpected error on conditional delete") | |
620 | 899 | } |
621 | 900 | } |
622 | 901 | |
624 | 903 | b.N = 2000 |
625 | 904 | for x := 0; x < b.N; x++ { |
626 | 905 | e := strconv.Itoa(x) + "benchmark@example.org" |
627 | email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} | |
906 | now := time.Now() | |
907 | email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now} | |
628 | 908 | // Insert |
629 | 909 | DB.Save(&email) |
630 | 910 | // Query |
631 | DB.First(&BigEmail{}, "email = ?", e) | |
911 | DB.First(&EmailWithIdx{}, "email = ?", e) | |
632 | 912 | // Update |
633 | 913 | DB.Model(&email).UpdateColumn("email", "new-"+e) |
634 | 914 | // Delete |
648 | 928 | for x := 0; x < b.N; x++ { |
649 | 929 | var id int64 |
650 | 930 | e := strconv.Itoa(x) + "benchmark@example.org" |
651 | email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} | |
931 | now := time.Now() | |
932 | email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now} | |
652 | 933 | // Insert |
653 | 934 | DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) |
654 | 935 | // Query |
660 | 941 | DB.Exec(deleteSql, id) |
661 | 942 | } |
662 | 943 | } |
944 | ||
945 | func parseTime(str string) *time.Time { | |
946 | t := now.New(time.Now().UTC()).MustParse(str) | |
947 | return &t | |
948 | } |
0 | 0 | package gorm_test |
1 | 1 | |
2 | 2 | import ( |
3 | "database/sql" | |
4 | "database/sql/driver" | |
5 | "errors" | |
3 | 6 | "fmt" |
7 | "os" | |
8 | "reflect" | |
9 | "strconv" | |
4 | 10 | "testing" |
5 | 11 | "time" |
12 | ||
13 | "github.com/jinzhu/gorm" | |
6 | 14 | ) |
15 | ||
16 | type User struct { | |
17 | Id int64 | |
18 | Age int64 | |
19 | UserNum Num | |
20 | Name string `sql:"size:255"` | |
21 | Email string | |
22 | Birthday *time.Time // Time | |
23 | CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically | |
24 | UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically | |
25 | Emails []Email // Embedded structs | |
26 | BillingAddress Address // Embedded struct | |
27 | BillingAddressID sql.NullInt64 // Embedded struct's foreign key | |
28 | ShippingAddress Address // Embedded struct | |
29 | ShippingAddressId int64 // Embedded struct's foreign key | |
30 | CreditCard CreditCard | |
31 | Latitude float64 | |
32 | Languages []Language `gorm:"many2many:user_languages;"` | |
33 | CompanyID *int | |
34 | Company Company | |
35 | Role Role | |
36 | Password EncryptedData | |
37 | PasswordHash []byte | |
38 | IgnoreMe int64 `sql:"-"` | |
39 | IgnoreStringSlice []string `sql:"-"` | |
40 | Ignored struct{ Name string } `sql:"-"` | |
41 | IgnoredPointer *User `sql:"-"` | |
42 | } | |
43 | ||
44 | type NotSoLongTableName struct { | |
45 | Id int64 | |
46 | ReallyLongThingID int64 | |
47 | ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit | |
48 | } | |
49 | ||
50 | type ReallyLongTableNameToTestMySQLNameLengthLimit struct { | |
51 | Id int64 | |
52 | } | |
53 | ||
54 | type ReallyLongThingThatReferencesShort struct { | |
55 | Id int64 | |
56 | ShortID int64 | |
57 | Short Short | |
58 | } | |
59 | ||
60 | type Short struct { | |
61 | Id int64 | |
62 | } | |
63 | ||
64 | type CreditCard struct { | |
65 | ID int8 | |
66 | Number string | |
67 | UserId sql.NullInt64 | |
68 | CreatedAt time.Time `sql:"not null"` | |
69 | UpdatedAt time.Time | |
70 | DeletedAt *time.Time `sql:"column:deleted_time"` | |
71 | } | |
72 | ||
73 | type Email struct { | |
74 | Id int16 | |
75 | UserId int | |
76 | Email string `sql:"type:varchar(100);"` | |
77 | CreatedAt time.Time | |
78 | UpdatedAt time.Time | |
79 | } | |
80 | ||
81 | type Address struct { | |
82 | ID int | |
83 | Address1 string | |
84 | Address2 string | |
85 | Post string | |
86 | CreatedAt time.Time | |
87 | UpdatedAt time.Time | |
88 | DeletedAt *time.Time | |
89 | } | |
90 | ||
91 | type Language struct { | |
92 | gorm.Model | |
93 | Name string | |
94 | Users []User `gorm:"many2many:user_languages;"` | |
95 | } | |
96 | ||
97 | type Product struct { | |
98 | Id int64 | |
99 | Code string | |
100 | Price int64 | |
101 | CreatedAt time.Time | |
102 | UpdatedAt time.Time | |
103 | AfterFindCallTimes int64 | |
104 | BeforeCreateCallTimes int64 | |
105 | AfterCreateCallTimes int64 | |
106 | BeforeUpdateCallTimes int64 | |
107 | AfterUpdateCallTimes int64 | |
108 | BeforeSaveCallTimes int64 | |
109 | AfterSaveCallTimes int64 | |
110 | BeforeDeleteCallTimes int64 | |
111 | AfterDeleteCallTimes int64 | |
112 | } | |
113 | ||
114 | type Company struct { | |
115 | Id int64 | |
116 | Name string | |
117 | Owner *User `sql:"-"` | |
118 | } | |
119 | ||
120 | type EncryptedData []byte | |
121 | ||
122 | func (data *EncryptedData) Scan(value interface{}) error { | |
123 | if b, ok := value.([]byte); ok { | |
124 | if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { | |
125 | return errors.New("Too short") | |
126 | } | |
127 | ||
128 | *data = b[3:] | |
129 | return nil | |
130 | } | |
131 | ||
132 | return errors.New("Bytes expected") | |
133 | } | |
134 | ||
135 | func (data EncryptedData) Value() (driver.Value, error) { | |
136 | if len(data) > 0 && data[0] == 'x' { | |
137 | //needed to test failures | |
138 | return nil, errors.New("Should not start with 'x'") | |
139 | } | |
140 | ||
141 | //prepend asterisks | |
142 | return append([]byte("***"), data...), nil | |
143 | } | |
144 | ||
145 | type Role struct { | |
146 | Name string `gorm:"size:256"` | |
147 | } | |
148 | ||
149 | func (role *Role) Scan(value interface{}) error { | |
150 | if b, ok := value.([]uint8); ok { | |
151 | role.Name = string(b) | |
152 | } else { | |
153 | role.Name = value.(string) | |
154 | } | |
155 | return nil | |
156 | } | |
157 | ||
158 | func (role Role) Value() (driver.Value, error) { | |
159 | return role.Name, nil | |
160 | } | |
161 | ||
162 | func (role Role) IsAdmin() bool { | |
163 | return role.Name == "admin" | |
164 | } | |
165 | ||
166 | type Num int64 | |
167 | ||
168 | func (i *Num) Scan(src interface{}) error { | |
169 | switch s := src.(type) { | |
170 | case []byte: | |
171 | n, _ := strconv.Atoi(string(s)) | |
172 | *i = Num(n) | |
173 | case int64: | |
174 | *i = Num(s) | |
175 | default: | |
176 | return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) | |
177 | } | |
178 | return nil | |
179 | } | |
180 | ||
181 | type Animal struct { | |
182 | Counter uint64 `gorm:"primary_key:yes"` | |
183 | Name string `sql:"DEFAULT:'galeone'"` | |
184 | From string //test reserved sql keyword as field name | |
185 | Age time.Time `sql:"DEFAULT:current_timestamp"` | |
186 | unexported string // unexported value | |
187 | CreatedAt time.Time | |
188 | UpdatedAt time.Time | |
189 | } | |
190 | ||
191 | type JoinTable struct { | |
192 | From uint64 | |
193 | To uint64 | |
194 | Time time.Time `sql:"default: null"` | |
195 | } | |
196 | ||
197 | type Post struct { | |
198 | Id int64 | |
199 | CategoryId sql.NullInt64 | |
200 | MainCategoryId int64 | |
201 | Title string | |
202 | Body string | |
203 | Comments []*Comment | |
204 | Category Category | |
205 | MainCategory Category | |
206 | } | |
207 | ||
208 | type Category struct { | |
209 | gorm.Model | |
210 | Name string | |
211 | ||
212 | Categories []Category | |
213 | CategoryID *uint | |
214 | } | |
215 | ||
216 | type Comment struct { | |
217 | gorm.Model | |
218 | PostId int64 | |
219 | Content string | |
220 | Post Post | |
221 | } | |
222 | ||
223 | // Scanner | |
224 | type NullValue struct { | |
225 | Id int64 | |
226 | Name sql.NullString `sql:"not null"` | |
227 | Gender *sql.NullString `sql:"not null"` | |
228 | Age sql.NullInt64 | |
229 | Male sql.NullBool | |
230 | Height sql.NullFloat64 | |
231 | AddedAt NullTime | |
232 | } | |
233 | ||
234 | type NullTime struct { | |
235 | Time time.Time | |
236 | Valid bool | |
237 | } | |
238 | ||
239 | func (nt *NullTime) Scan(value interface{}) error { | |
240 | if value == nil { | |
241 | nt.Valid = false | |
242 | return nil | |
243 | } | |
244 | nt.Time, nt.Valid = value.(time.Time), true | |
245 | return nil | |
246 | } | |
247 | ||
248 | func (nt NullTime) Value() (driver.Value, error) { | |
249 | if !nt.Valid { | |
250 | return nil, nil | |
251 | } | |
252 | return nt.Time, nil | |
253 | } | |
254 | ||
255 | func getPreparedUser(name string, role string) *User { | |
256 | var company Company | |
257 | DB.Where(Company{Name: role}).FirstOrCreate(&company) | |
258 | ||
259 | return &User{ | |
260 | Name: name, | |
261 | Age: 20, | |
262 | Role: Role{role}, | |
263 | BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, | |
264 | ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, | |
265 | CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, | |
266 | Emails: []Email{ | |
267 | {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, | |
268 | }, | |
269 | Company: company, | |
270 | Languages: []Language{ | |
271 | {Name: fmt.Sprintf("lang_1_%v", name)}, | |
272 | {Name: fmt.Sprintf("lang_2_%v", name)}, | |
273 | }, | |
274 | } | |
275 | } | |
7 | 276 | |
8 | 277 | func runMigration() { |
9 | 278 | if err := DB.DropTableIfExists(&User{}).Error; err != nil { |
14 | 283 | DB.Exec(fmt.Sprintf("drop table %v;", table)) |
15 | 284 | } |
16 | 285 | |
17 | values := []interface{}{&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}} | |
286 | values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}} | |
18 | 287 | for _, value := range values { |
19 | 288 | DB.DropTable(value) |
20 | 289 | } |
21 | ||
22 | 290 | if err := DB.AutoMigrate(values...).Error; err != nil { |
23 | 291 | panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) |
24 | 292 | } |
30 | 298 | } |
31 | 299 | |
32 | 300 | scope := DB.NewScope(&Email{}) |
33 | if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") { | |
301 | if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { | |
34 | 302 | t.Errorf("Email should have index idx_email_email") |
35 | 303 | } |
36 | 304 | |
38 | 306 | t.Errorf("Got error when tried to remove index: %+v", err) |
39 | 307 | } |
40 | 308 | |
41 | if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") { | |
309 | if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { | |
42 | 310 | t.Errorf("Email's index idx_email_email should be deleted") |
43 | 311 | } |
44 | 312 | |
46 | 314 | t.Errorf("Got error when tried to create index: %+v", err) |
47 | 315 | } |
48 | 316 | |
49 | if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { | |
317 | if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { | |
50 | 318 | t.Errorf("Email should have index idx_email_email_and_user_id") |
51 | 319 | } |
52 | 320 | |
54 | 322 | t.Errorf("Got error when tried to remove index: %+v", err) |
55 | 323 | } |
56 | 324 | |
57 | if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { | |
325 | if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { | |
58 | 326 | t.Errorf("Email's index idx_email_email_and_user_id should be deleted") |
59 | 327 | } |
60 | 328 | |
62 | 330 | t.Errorf("Got error when tried to create index: %+v", err) |
63 | 331 | } |
64 | 332 | |
65 | if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { | |
333 | if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { | |
66 | 334 | t.Errorf("Email should have index idx_email_email_and_user_id") |
67 | 335 | } |
68 | 336 | |
70 | 338 | t.Errorf("Should get to create duplicate record when having unique index") |
71 | 339 | } |
72 | 340 | |
341 | var user = User{Name: "sample_user"} | |
342 | DB.Save(&user) | |
343 | if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil { | |
344 | t.Errorf("Should get no error when append two emails for user") | |
345 | } | |
346 | ||
347 | if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil { | |
348 | t.Errorf("Should get no duplicated email error when insert duplicated emails for a user") | |
349 | } | |
350 | ||
73 | 351 | if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { |
74 | 352 | t.Errorf("Got error when tried to remove index: %+v", err) |
75 | 353 | } |
76 | 354 | |
77 | if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { | |
355 | if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { | |
78 | 356 | t.Errorf("Email's index idx_email_email_and_user_id should be deleted") |
79 | 357 | } |
80 | 358 | |
83 | 361 | } |
84 | 362 | } |
85 | 363 | |
86 | type BigEmail struct { | |
364 | type EmailWithIdx struct { | |
87 | 365 | Id int64 |
88 | 366 | UserId int64 |
89 | Email string `sql:"index:idx_email_agent"` | |
90 | UserAgent string `sql:"index:idx_email_agent"` | |
91 | RegisteredAt time.Time `sql:"unique_index"` | |
367 | Email string `sql:"index:idx_email_agent"` | |
368 | UserAgent string `sql:"index:idx_email_agent"` | |
369 | RegisteredAt *time.Time `sql:"unique_index"` | |
92 | 370 | CreatedAt time.Time |
93 | 371 | UpdatedAt time.Time |
94 | 372 | } |
95 | 373 | |
96 | func (b BigEmail) TableName() string { | |
97 | return "emails" | |
98 | } | |
99 | ||
100 | 374 | func TestAutoMigration(t *testing.T) { |
101 | 375 | DB.AutoMigrate(&Address{}) |
102 | if err := DB.Table("emails").AutoMigrate(&BigEmail{}).Error; err != nil { | |
376 | DB.DropTable(&EmailWithIdx{}) | |
377 | if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { | |
103 | 378 | t.Errorf("Auto Migrate should not raise any error") |
104 | 379 | } |
105 | 380 | |
106 | DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()}) | |
107 | ||
108 | scope := DB.NewScope(&BigEmail{}) | |
109 | if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_agent") { | |
110 | t.Errorf("Failed to create index") | |
111 | } | |
112 | ||
113 | if !scope.Dialect().HasIndex(scope, scope.TableName(), "uix_emails_registered_at") { | |
114 | t.Errorf("Failed to create index") | |
115 | } | |
116 | ||
117 | var bigemail BigEmail | |
381 | now := time.Now() | |
382 | DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now}) | |
383 | ||
384 | scope := DB.NewScope(&EmailWithIdx{}) | |
385 | if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { | |
386 | t.Errorf("Failed to create index") | |
387 | } | |
388 | ||
389 | if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") { | |
390 | t.Errorf("Failed to create index") | |
391 | } | |
392 | ||
393 | var bigemail EmailWithIdx | |
118 | 394 | DB.First(&bigemail, "user_agent = ?", "pc") |
119 | 395 | if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() { |
120 | 396 | t.Error("Big Emails should be saved and fetched correctly") |
121 | 397 | } |
122 | 398 | } |
399 | ||
400 | type MultipleIndexes struct { | |
401 | ID int64 | |
402 | UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` | |
403 | Name string `sql:"unique_index:uix_multipleindexes_user_name"` | |
404 | Email string `sql:"unique_index:,uix_multipleindexes_user_email"` | |
405 | Other string `sql:"index:,idx_multipleindexes_user_other"` | |
406 | } | |
407 | ||
408 | func TestMultipleIndexes(t *testing.T) { | |
409 | if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil { | |
410 | fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err) | |
411 | } | |
412 | ||
413 | DB.AutoMigrate(&MultipleIndexes{}) | |
414 | if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { | |
415 | t.Errorf("Auto Migrate should not raise any error") | |
416 | } | |
417 | ||
418 | DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"}) | |
419 | ||
420 | scope := DB.NewScope(&MultipleIndexes{}) | |
421 | if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") { | |
422 | t.Errorf("Failed to create index") | |
423 | } | |
424 | ||
425 | if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") { | |
426 | t.Errorf("Failed to create index") | |
427 | } | |
428 | ||
429 | if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") { | |
430 | t.Errorf("Failed to create index") | |
431 | } | |
432 | ||
433 | if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") { | |
434 | t.Errorf("Failed to create index") | |
435 | } | |
436 | ||
437 | if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") { | |
438 | t.Errorf("Failed to create index") | |
439 | } | |
440 | ||
441 | var mutipleIndexes MultipleIndexes | |
442 | DB.First(&mutipleIndexes, "name = ?", "jinzhu") | |
443 | if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" { | |
444 | t.Error("MutipleIndexes should be saved and fetched correctly") | |
445 | } | |
446 | ||
447 | // Check unique constraints | |
448 | if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { | |
449 | t.Error("MultipleIndexes unique index failed") | |
450 | } | |
451 | ||
452 | if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil { | |
453 | t.Error("MultipleIndexes unique index failed") | |
454 | } | |
455 | ||
456 | if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { | |
457 | t.Error("MultipleIndexes unique index failed") | |
458 | } | |
459 | ||
460 | if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil { | |
461 | t.Error("MultipleIndexes unique index failed") | |
462 | } | |
463 | } | |
464 | ||
465 | func TestModifyColumnType(t *testing.T) { | |
466 | if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" { | |
467 | t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type") | |
468 | } | |
469 | ||
470 | type ModifyColumnType struct { | |
471 | gorm.Model | |
472 | Name1 string `gorm:"length:100"` | |
473 | Name2 string `gorm:"length:200"` | |
474 | } | |
475 | DB.DropTable(&ModifyColumnType{}) | |
476 | DB.CreateTable(&ModifyColumnType{}) | |
477 | ||
478 | name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2") | |
479 | name2Type := DB.Dialect().DataTypeOf(name2Field.StructField) | |
480 | ||
481 | if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil { | |
482 | t.Errorf("No error should happen when ModifyColumn, but got %v", err) | |
483 | } | |
484 | } |
1 | 1 | |
2 | 2 | import "time" |
3 | 3 | |
4 | // Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models | |
5 | // type User struct { | |
6 | // gorm.Model | |
7 | // } | |
4 | 8 | type Model struct { |
5 | 9 | ID uint `gorm:"primary_key"` |
6 | 10 | CreatedAt time.Time |
1 | 1 | |
2 | 2 | import ( |
3 | 3 | "database/sql" |
4 | "fmt" | |
4 | "errors" | |
5 | 5 | "go/ast" |
6 | 6 | "reflect" |
7 | "strconv" | |
8 | 7 | "strings" |
9 | 8 | "sync" |
10 | 9 | "time" |
11 | 10 | |
12 | "github.com/qor/inflection" | |
11 | "github.com/jinzhu/inflection" | |
13 | 12 | ) |
14 | 13 | |
14 | // DefaultTableNameHandler default table name handler | |
15 | 15 | var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { |
16 | 16 | return defaultTableName |
17 | 17 | } |
39 | 39 | |
40 | 40 | var modelStructsMap = newModelStructsMap() |
41 | 41 | |
42 | // ModelStruct model definition | |
42 | 43 | type ModelStruct struct { |
43 | 44 | PrimaryFields []*StructField |
44 | 45 | StructFields []*StructField |
45 | 46 | ModelType reflect.Type |
46 | 47 | defaultTableName string |
47 | cached bool | |
48 | } | |
49 | ||
50 | func (s ModelStruct) TableName(db *DB) string { | |
48 | } | |
49 | ||
50 | // TableName get model's table name | |
51 | func (s *ModelStruct) TableName(db *DB) string { | |
52 | if s.defaultTableName == "" && db != nil && s.ModelType != nil { | |
53 | // Set default table name | |
54 | if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { | |
55 | s.defaultTableName = tabler.TableName() | |
56 | } else { | |
57 | tableName := ToDBName(s.ModelType.Name()) | |
58 | if db == nil || !db.parent.singularTable { | |
59 | tableName = inflection.Plural(tableName) | |
60 | } | |
61 | s.defaultTableName = tableName | |
62 | } | |
63 | } | |
64 | ||
51 | 65 | return DefaultTableNameHandler(db, s.defaultTableName) |
52 | 66 | } |
53 | 67 | |
68 | // StructField model field's struct definition | |
54 | 69 | type StructField struct { |
55 | 70 | DBName string |
56 | 71 | Name string |
61 | 76 | IsScanner bool |
62 | 77 | HasDefaultValue bool |
63 | 78 | Tag reflect.StructTag |
79 | TagSettings map[string]string | |
64 | 80 | Struct reflect.StructField |
65 | 81 | IsForeignKey bool |
66 | 82 | Relationship *Relationship |
67 | 83 | } |
68 | 84 | |
69 | 85 | func (structField *StructField) clone() *StructField { |
70 | return &StructField{ | |
86 | clone := &StructField{ | |
71 | 87 | DBName: structField.DBName, |
72 | 88 | Name: structField.Name, |
73 | 89 | Names: structField.Names, |
77 | 93 | IsScanner: structField.IsScanner, |
78 | 94 | HasDefaultValue: structField.HasDefaultValue, |
79 | 95 | Tag: structField.Tag, |
96 | TagSettings: map[string]string{}, | |
80 | 97 | Struct: structField.Struct, |
81 | 98 | IsForeignKey: structField.IsForeignKey, |
82 | Relationship: structField.Relationship, | |
83 | } | |
84 | } | |
85 | ||
99 | } | |
100 | ||
101 | if structField.Relationship != nil { | |
102 | relationship := *structField.Relationship | |
103 | clone.Relationship = &relationship | |
104 | } | |
105 | ||
106 | for key, value := range structField.TagSettings { | |
107 | clone.TagSettings[key] = value | |
108 | } | |
109 | ||
110 | return clone | |
111 | } | |
112 | ||
113 | // Relationship described the relationship between models | |
86 | 114 | type Relationship struct { |
87 | Kind string | |
88 | PolymorphicType string | |
89 | PolymorphicDBName string | |
90 | ForeignFieldNames []string | |
91 | ForeignDBNames []string | |
92 | AssociationForeignFieldNames []string | |
93 | AssociationForeignStructFieldNames []string | |
94 | AssociationForeignDBNames []string | |
95 | JoinTableHandler JoinTableHandlerInterface | |
96 | } | |
97 | ||
115 | Kind string | |
116 | PolymorphicType string | |
117 | PolymorphicDBName string | |
118 | PolymorphicValue string | |
119 | ForeignFieldNames []string | |
120 | ForeignDBNames []string | |
121 | AssociationForeignFieldNames []string | |
122 | AssociationForeignDBNames []string | |
123 | JoinTableHandler JoinTableHandlerInterface | |
124 | } | |
125 | ||
126 | func getForeignField(column string, fields []*StructField) *StructField { | |
127 | for _, field := range fields { | |
128 | if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) { | |
129 | return field | |
130 | } | |
131 | } | |
132 | return nil | |
133 | } | |
134 | ||
135 | // GetModelStruct get value's model struct, relationships based on struct and tag definition | |
98 | 136 | func (scope *Scope) GetModelStruct() *ModelStruct { |
99 | 137 | var modelStruct ModelStruct |
100 | ||
101 | reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) | |
102 | if !reflectValue.IsValid() { | |
138 | // Scope value can't be nil | |
139 | if scope.Value == nil { | |
103 | 140 | return &modelStruct |
104 | 141 | } |
105 | 142 | |
106 | if reflectValue.Kind() == reflect.Slice { | |
107 | reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem())) | |
108 | } | |
109 | ||
110 | scopeType := reflectValue.Type() | |
111 | ||
112 | if scopeType.Kind() == reflect.Ptr { | |
113 | scopeType = scopeType.Elem() | |
114 | } | |
115 | ||
116 | if value := modelStructsMap.Get(scopeType); value != nil { | |
143 | reflectType := reflect.ValueOf(scope.Value).Type() | |
144 | for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr { | |
145 | reflectType = reflectType.Elem() | |
146 | } | |
147 | ||
148 | // Scope value need to be a struct | |
149 | if reflectType.Kind() != reflect.Struct { | |
150 | return &modelStruct | |
151 | } | |
152 | ||
153 | // Get Cached model struct | |
154 | if value := modelStructsMap.Get(reflectType); value != nil { | |
117 | 155 | return value |
118 | 156 | } |
119 | 157 | |
120 | modelStruct.ModelType = scopeType | |
121 | if scopeType.Kind() != reflect.Struct { | |
122 | return &modelStruct | |
123 | } | |
124 | ||
125 | if tabler, ok := reflect.New(scopeType).Interface().(interface { | |
126 | TableName() string | |
127 | }); ok { | |
128 | modelStruct.defaultTableName = tabler.TableName() | |
129 | } else { | |
130 | name := ToDBName(scopeType.Name()) | |
131 | if scope.db == nil || !scope.db.parent.singularTable { | |
132 | name = inflection.Plural(name) | |
133 | } | |
134 | ||
135 | modelStruct.defaultTableName = name | |
136 | } | |
158 | modelStruct.ModelType = reflectType | |
137 | 159 | |
138 | 160 | // Get all fields |
139 | fields := []*StructField{} | |
140 | for i := 0; i < scopeType.NumField(); i++ { | |
141 | if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) { | |
161 | for i := 0; i < reflectType.NumField(); i++ { | |
162 | if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { | |
142 | 163 | field := &StructField{ |
143 | Struct: fieldStruct, | |
144 | Name: fieldStruct.Name, | |
145 | Names: []string{fieldStruct.Name}, | |
146 | Tag: fieldStruct.Tag, | |
164 | Struct: fieldStruct, | |
165 | Name: fieldStruct.Name, | |
166 | Names: []string{fieldStruct.Name}, | |
167 | Tag: fieldStruct.Tag, | |
168 | TagSettings: parseTagSetting(fieldStruct.Tag), | |
147 | 169 | } |
148 | 170 | |
149 | if fieldStruct.Tag.Get("sql") == "-" { | |
171 | // is ignored field | |
172 | if _, ok := field.TagSettings["-"]; ok { | |
150 | 173 | field.IsIgnored = true |
151 | 174 | } else { |
152 | sqlSettings := parseTagSetting(field.Tag.Get("sql")) | |
153 | gormSettings := parseTagSetting(field.Tag.Get("gorm")) | |
154 | if _, ok := gormSettings["PRIMARY_KEY"]; ok { | |
175 | if _, ok := field.TagSettings["PRIMARY_KEY"]; ok { | |
155 | 176 | field.IsPrimaryKey = true |
156 | 177 | modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) |
157 | 178 | } |
158 | 179 | |
159 | if _, ok := sqlSettings["DEFAULT"]; ok { | |
180 | if _, ok := field.TagSettings["DEFAULT"]; ok { | |
160 | 181 | field.HasDefaultValue = true |
161 | 182 | } |
162 | 183 | |
163 | if value, ok := gormSettings["COLUMN"]; ok { | |
164 | field.DBName = value | |
165 | } else { | |
166 | field.DBName = ToDBName(fieldStruct.Name) | |
184 | if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsPrimaryKey { | |
185 | field.HasDefaultValue = true | |
167 | 186 | } |
168 | } | |
169 | fields = append(fields, field) | |
170 | } | |
171 | } | |
172 | ||
173 | var finished = make(chan bool) | |
174 | go func(finished chan bool) { | |
175 | for _, field := range fields { | |
176 | if !field.IsIgnored { | |
177 | fieldStruct := field.Struct | |
187 | ||
178 | 188 | indirectType := fieldStruct.Type |
179 | if indirectType.Kind() == reflect.Ptr { | |
189 | for indirectType.Kind() == reflect.Ptr { | |
180 | 190 | indirectType = indirectType.Elem() |
181 | 191 | } |
182 | 192 | |
183 | if _, isScanner := reflect.New(indirectType).Interface().(sql.Scanner); isScanner { | |
193 | fieldValue := reflect.New(indirectType).Interface() | |
194 | if _, isScanner := fieldValue.(sql.Scanner); isScanner { | |
195 | // is scanner | |
184 | 196 | field.IsScanner, field.IsNormal = true, true |
185 | } | |
186 | ||
187 | if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime { | |
188 | field.IsNormal = true | |
189 | } | |
190 | ||
191 | if !field.IsNormal { | |
192 | gormSettings := parseTagSetting(field.Tag.Get("gorm")) | |
193 | toScope := scope.New(reflect.New(fieldStruct.Type).Interface()) | |
194 | ||
195 | getForeignField := func(column string, fields []*StructField) *StructField { | |
196 | for _, field := range fields { | |
197 | if field.Name == column || field.DBName == ToDBName(column) { | |
198 | return field | |
199 | } | |
200 | } | |
201 | return nil | |
202 | } | |
203 | ||
204 | var relationship = &Relationship{} | |
205 | ||
206 | if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { | |
207 | if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { | |
208 | if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { | |
209 | relationship.ForeignFieldNames = []string{polymorphicField.Name} | |
210 | relationship.ForeignDBNames = []string{polymorphicField.DBName} | |
211 | relationship.AssociationForeignFieldNames = []string{scope.PrimaryField().Name} | |
212 | relationship.AssociationForeignDBNames = []string{scope.PrimaryField().DBName} | |
213 | relationship.PolymorphicType = polymorphicType.Name | |
214 | relationship.PolymorphicDBName = polymorphicType.DBName | |
215 | polymorphicType.IsForeignKey = true | |
216 | polymorphicField.IsForeignKey = true | |
197 | if indirectType.Kind() == reflect.Struct { | |
198 | for i := 0; i < indirectType.NumField(); i++ { | |
199 | for key, value := range parseTagSetting(indirectType.Field(i).Tag) { | |
200 | if _, ok := field.TagSettings[key]; !ok { | |
201 | field.TagSettings[key] = value | |
202 | } | |
217 | 203 | } |
218 | 204 | } |
219 | 205 | } |
220 | ||
221 | var foreignKeys []string | |
222 | if foreignKey, ok := gormSettings["FOREIGNKEY"]; ok { | |
223 | foreignKeys = append(foreignKeys, foreignKey) | |
206 | } else if _, isTime := fieldValue.(*time.Time); isTime { | |
207 | // is time | |
208 | field.IsNormal = true | |
209 | } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { | |
210 | // is embedded struct | |
211 | for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields { | |
212 | subField = subField.clone() | |
213 | subField.Names = append([]string{fieldStruct.Name}, subField.Names...) | |
214 | if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok { | |
215 | subField.DBName = prefix + subField.DBName | |
216 | } | |
217 | ||
218 | if subField.IsPrimaryKey { | |
219 | if _, ok := subField.TagSettings["PRIMARY_KEY"]; ok { | |
220 | modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) | |
221 | } else { | |
222 | subField.IsPrimaryKey = false | |
223 | } | |
224 | } | |
225 | ||
226 | if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil { | |
227 | if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok { | |
228 | newJoinTableHandler := &JoinTableHandler{} | |
229 | newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType) | |
230 | subField.Relationship.JoinTableHandler = newJoinTableHandler | |
231 | } | |
232 | } | |
233 | ||
234 | modelStruct.StructFields = append(modelStruct.StructFields, subField) | |
224 | 235 | } |
236 | continue | |
237 | } else { | |
238 | // build relationships | |
225 | 239 | switch indirectType.Kind() { |
226 | 240 | case reflect.Slice: |
227 | elemType := indirectType.Elem() | |
228 | if elemType.Kind() == reflect.Ptr { | |
229 | elemType = elemType.Elem() | |
230 | } | |
231 | ||
232 | if elemType.Kind() == reflect.Struct { | |
233 | if many2many := gormSettings["MANY2MANY"]; many2many != "" { | |
234 | relationship.Kind = "many_to_many" | |
235 | ||
236 | // foreign keys | |
241 | defer func(field *StructField) { | |
242 | var ( | |
243 | relationship = &Relationship{} | |
244 | toScope = scope.New(reflect.New(field.Struct.Type).Interface()) | |
245 | foreignKeys []string | |
246 | associationForeignKeys []string | |
247 | elemType = field.Struct.Type | |
248 | ) | |
249 | ||
250 | if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { | |
251 | foreignKeys = strings.Split(foreignKey, ",") | |
252 | } | |
253 | ||
254 | if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { | |
255 | associationForeignKeys = strings.Split(foreignKey, ",") | |
256 | } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { | |
257 | associationForeignKeys = strings.Split(foreignKey, ",") | |
258 | } | |
259 | ||
260 | for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { | |
261 | elemType = elemType.Elem() | |
262 | } | |
263 | ||
264 | if elemType.Kind() == reflect.Struct { | |
265 | if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { | |
266 | relationship.Kind = "many_to_many" | |
267 | ||
268 | { // Foreign Keys for Source | |
269 | joinTableDBNames := []string{} | |
270 | ||
271 | if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" { | |
272 | joinTableDBNames = strings.Split(foreignKey, ",") | |
273 | } | |
274 | ||
275 | // if no foreign keys defined with tag | |
276 | if len(foreignKeys) == 0 { | |
277 | for _, field := range modelStruct.PrimaryFields { | |
278 | foreignKeys = append(foreignKeys, field.DBName) | |
279 | } | |
280 | } | |
281 | ||
282 | for idx, foreignKey := range foreignKeys { | |
283 | if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { | |
284 | // source foreign keys (db names) | |
285 | relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) | |
286 | ||
287 | // setup join table foreign keys for source | |
288 | if len(joinTableDBNames) > idx { | |
289 | // if defined join table's foreign key | |
290 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) | |
291 | } else { | |
292 | defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName | |
293 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) | |
294 | } | |
295 | } | |
296 | } | |
297 | } | |
298 | ||
299 | { // Foreign Keys for Association (Destination) | |
300 | associationJoinTableDBNames := []string{} | |
301 | ||
302 | if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" { | |
303 | associationJoinTableDBNames = strings.Split(foreignKey, ",") | |
304 | } | |
305 | ||
306 | // if no association foreign keys defined with tag | |
307 | if len(associationForeignKeys) == 0 { | |
308 | for _, field := range toScope.PrimaryFields() { | |
309 | associationForeignKeys = append(associationForeignKeys, field.DBName) | |
310 | } | |
311 | } | |
312 | ||
313 | for idx, name := range associationForeignKeys { | |
314 | if field, ok := toScope.FieldByName(name); ok { | |
315 | // association foreign keys (db names) | |
316 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) | |
317 | ||
318 | // setup join table foreign keys for association | |
319 | if len(associationJoinTableDBNames) > idx { | |
320 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) | |
321 | } else { | |
322 | // join table foreign keys for association | |
323 | joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName | |
324 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) | |
325 | } | |
326 | } | |
327 | } | |
328 | } | |
329 | ||
330 | joinTableHandler := JoinTableHandler{} | |
331 | joinTableHandler.Setup(relationship, many2many, reflectType, elemType) | |
332 | relationship.JoinTableHandler = &joinTableHandler | |
333 | field.Relationship = relationship | |
334 | } else { | |
335 | // User has many comments, associationType is User, comment use UserID as foreign key | |
336 | var associationType = reflectType.Name() | |
337 | var toFields = toScope.GetStructFields() | |
338 | relationship.Kind = "has_many" | |
339 | ||
340 | if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { | |
341 | // Dog has many toys, tag polymorphic is Owner, then associationType is Owner | |
342 | // Toy use OwnerID, OwnerType ('dogs') as foreign key | |
343 | if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { | |
344 | associationType = polymorphic | |
345 | relationship.PolymorphicType = polymorphicType.Name | |
346 | relationship.PolymorphicDBName = polymorphicType.DBName | |
347 | // if Dog has multiple set of toys set name of the set (instead of default 'dogs') | |
348 | if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { | |
349 | relationship.PolymorphicValue = value | |
350 | } else { | |
351 | relationship.PolymorphicValue = scope.TableName() | |
352 | } | |
353 | polymorphicType.IsForeignKey = true | |
354 | } | |
355 | } | |
356 | ||
357 | // if no foreign keys defined with tag | |
358 | if len(foreignKeys) == 0 { | |
359 | // if no association foreign keys defined with tag | |
360 | if len(associationForeignKeys) == 0 { | |
361 | for _, field := range modelStruct.PrimaryFields { | |
362 | foreignKeys = append(foreignKeys, associationType+field.Name) | |
363 | associationForeignKeys = append(associationForeignKeys, field.Name) | |
364 | } | |
365 | } else { | |
366 | // generate foreign keys from defined association foreign keys | |
367 | for _, scopeFieldName := range associationForeignKeys { | |
368 | if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil { | |
369 | foreignKeys = append(foreignKeys, associationType+foreignField.Name) | |
370 | associationForeignKeys = append(associationForeignKeys, foreignField.Name) | |
371 | } | |
372 | } | |
373 | } | |
374 | } else { | |
375 | // generate association foreign keys from foreign keys | |
376 | if len(associationForeignKeys) == 0 { | |
377 | for _, foreignKey := range foreignKeys { | |
378 | if strings.HasPrefix(foreignKey, associationType) { | |
379 | associationForeignKey := strings.TrimPrefix(foreignKey, associationType) | |
380 | if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { | |
381 | associationForeignKeys = append(associationForeignKeys, associationForeignKey) | |
382 | } | |
383 | } | |
384 | } | |
385 | if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { | |
386 | associationForeignKeys = []string{scope.PrimaryKey()} | |
387 | } | |
388 | } else if len(foreignKeys) != len(associationForeignKeys) { | |
389 | scope.Err(errors.New("invalid foreign keys, should have same length")) | |
390 | return | |
391 | } | |
392 | } | |
393 | ||
394 | for idx, foreignKey := range foreignKeys { | |
395 | if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { | |
396 | if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { | |
397 | // source foreign keys | |
398 | foreignField.IsForeignKey = true | |
399 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) | |
400 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) | |
401 | ||
402 | // association foreign keys | |
403 | relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) | |
404 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) | |
405 | } | |
406 | } | |
407 | } | |
408 | ||
409 | if len(relationship.ForeignFieldNames) != 0 { | |
410 | field.Relationship = relationship | |
411 | } | |
412 | } | |
413 | } else { | |
414 | field.IsNormal = true | |
415 | } | |
416 | }(field) | |
417 | case reflect.Struct: | |
418 | defer func(field *StructField) { | |
419 | var ( | |
420 | // user has one profile, associationType is User, profile use UserID as foreign key | |
421 | // user belongs to profile, associationType is Profile, user use ProfileID as foreign key | |
422 | associationType = reflectType.Name() | |
423 | relationship = &Relationship{} | |
424 | toScope = scope.New(reflect.New(field.Struct.Type).Interface()) | |
425 | toFields = toScope.GetStructFields() | |
426 | tagForeignKeys []string | |
427 | tagAssociationForeignKeys []string | |
428 | ) | |
429 | ||
430 | if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { | |
431 | tagForeignKeys = strings.Split(foreignKey, ",") | |
432 | } | |
433 | ||
434 | if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { | |
435 | tagAssociationForeignKeys = strings.Split(foreignKey, ",") | |
436 | } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { | |
437 | tagAssociationForeignKeys = strings.Split(foreignKey, ",") | |
438 | } | |
439 | ||
440 | if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { | |
441 | // Cat has one toy, tag polymorphic is Owner, then associationType is Owner | |
442 | // Toy use OwnerID, OwnerType ('cats') as foreign key | |
443 | if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { | |
444 | associationType = polymorphic | |
445 | relationship.PolymorphicType = polymorphicType.Name | |
446 | relationship.PolymorphicDBName = polymorphicType.DBName | |
447 | // if Cat has several different types of toys set name for each (instead of default 'cats') | |
448 | if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { | |
449 | relationship.PolymorphicValue = value | |
450 | } else { | |
451 | relationship.PolymorphicValue = scope.TableName() | |
452 | } | |
453 | polymorphicType.IsForeignKey = true | |
454 | } | |
455 | } | |
456 | ||
457 | // Has One | |
458 | { | |
459 | var foreignKeys = tagForeignKeys | |
460 | var associationForeignKeys = tagAssociationForeignKeys | |
461 | // if no foreign keys defined with tag | |
237 | 462 | if len(foreignKeys) == 0 { |
238 | for _, field := range scope.PrimaryFields() { | |
239 | foreignKeys = append(foreignKeys, field.DBName) | |
240 | } | |
241 | } | |
242 | ||
243 | for _, foreignKey := range foreignKeys { | |
244 | if field, ok := scope.FieldByName(foreignKey); ok { | |
245 | relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) | |
246 | joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName | |
247 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) | |
248 | } | |
249 | } | |
250 | ||
251 | // association foreign keys | |
252 | var associationForeignKeys []string | |
253 | if foreignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { | |
254 | associationForeignKeys = []string{gormSettings["ASSOCIATIONFOREIGNKEY"]} | |
463 | // if no association foreign keys defined with tag | |
464 | if len(associationForeignKeys) == 0 { | |
465 | for _, primaryField := range modelStruct.PrimaryFields { | |
466 | foreignKeys = append(foreignKeys, associationType+primaryField.Name) | |
467 | associationForeignKeys = append(associationForeignKeys, primaryField.Name) | |
468 | } | |
469 | } else { | |
470 | // generate foreign keys form association foreign keys | |
471 | for _, associationForeignKey := range tagAssociationForeignKeys { | |
472 | if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { | |
473 | foreignKeys = append(foreignKeys, associationType+foreignField.Name) | |
474 | associationForeignKeys = append(associationForeignKeys, foreignField.Name) | |
475 | } | |
476 | } | |
477 | } | |
255 | 478 | } else { |
256 | for _, field := range toScope.PrimaryFields() { | |
257 | associationForeignKeys = append(associationForeignKeys, field.DBName) | |
258 | } | |
259 | } | |
260 | ||
261 | for _, name := range associationForeignKeys { | |
262 | if field, ok := toScope.FieldByName(name); ok { | |
263 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) | |
264 | relationship.AssociationForeignStructFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) | |
265 | joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName | |
266 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) | |
267 | } | |
268 | } | |
269 | ||
270 | joinTableHandler := JoinTableHandler{} | |
271 | joinTableHandler.Setup(relationship, many2many, scopeType, elemType) | |
272 | relationship.JoinTableHandler = &joinTableHandler | |
273 | field.Relationship = relationship | |
274 | } else { | |
275 | relationship.Kind = "has_many" | |
276 | ||
277 | if len(foreignKeys) == 0 { | |
278 | for _, field := range scope.PrimaryFields() { | |
279 | if foreignField := getForeignField(scopeType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil { | |
280 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) | |
281 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) | |
479 | // generate association foreign keys from foreign keys | |
480 | if len(associationForeignKeys) == 0 { | |
481 | for _, foreignKey := range foreignKeys { | |
482 | if strings.HasPrefix(foreignKey, associationType) { | |
483 | associationForeignKey := strings.TrimPrefix(foreignKey, associationType) | |
484 | if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { | |
485 | associationForeignKeys = append(associationForeignKeys, associationForeignKey) | |
486 | } | |
487 | } | |
488 | } | |
489 | if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { | |
490 | associationForeignKeys = []string{scope.PrimaryKey()} | |
491 | } | |
492 | } else if len(foreignKeys) != len(associationForeignKeys) { | |
493 | scope.Err(errors.New("invalid foreign keys, should have same length")) | |
494 | return | |
495 | } | |
496 | } | |
497 | ||
498 | for idx, foreignKey := range foreignKeys { | |
499 | if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { | |
500 | if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { | |
501 | foreignField.IsForeignKey = true | |
502 | // source foreign keys | |
503 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) | |
504 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) | |
505 | ||
506 | // association foreign keys | |
282 | 507 | relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) |
283 | 508 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) |
284 | foreignField.IsForeignKey = true | |
285 | } | |
286 | } | |
287 | } else { | |
288 | for _, foreignKey := range foreignKeys { | |
289 | if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { | |
290 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) | |
291 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) | |
292 | relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) | |
293 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) | |
294 | foreignField.IsForeignKey = true | |
295 | } | |
296 | } | |
297 | } | |
298 | ||
299 | if len(relationship.ForeignFieldNames) != 0 { | |
300 | field.Relationship = relationship | |
301 | } | |
302 | } | |
303 | } else { | |
304 | field.IsNormal = true | |
305 | } | |
306 | case reflect.Struct: | |
307 | if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { | |
308 | for _, toField := range toScope.GetStructFields() { | |
309 | toField = toField.clone() | |
310 | toField.Names = append([]string{fieldStruct.Name}, toField.Names...) | |
311 | modelStruct.StructFields = append(modelStruct.StructFields, toField) | |
312 | if toField.IsPrimaryKey { | |
313 | modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField) | |
314 | } | |
315 | } | |
316 | continue | |
317 | } else { | |
318 | if len(foreignKeys) == 0 { | |
319 | for _, f := range scope.PrimaryFields() { | |
320 | if foreignField := getForeignField(modelStruct.ModelType.Name()+f.Name, toScope.GetStructFields()); foreignField != nil { | |
321 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name) | |
322 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName) | |
323 | relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) | |
324 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) | |
325 | foreignField.IsForeignKey = true | |
326 | } | |
327 | } | |
328 | } else { | |
329 | for _, foreignKey := range foreignKeys { | |
330 | if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { | |
331 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) | |
332 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) | |
333 | relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) | |
334 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) | |
335 | foreignField.IsForeignKey = true | |
509 | } | |
336 | 510 | } |
337 | 511 | } |
338 | 512 | } |
341 | 515 | relationship.Kind = "has_one" |
342 | 516 | field.Relationship = relationship |
343 | 517 | } else { |
518 | var foreignKeys = tagForeignKeys | |
519 | var associationForeignKeys = tagAssociationForeignKeys | |
520 | ||
344 | 521 | if len(foreignKeys) == 0 { |
345 | for _, f := range toScope.PrimaryFields() { | |
346 | if foreignField := getForeignField(field.Name+f.Name, fields); foreignField != nil { | |
347 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name) | |
348 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName) | |
522 | // generate foreign keys & association foreign keys | |
523 | if len(associationForeignKeys) == 0 { | |
524 | for _, primaryField := range toScope.PrimaryFields() { | |
525 | foreignKeys = append(foreignKeys, field.Name+primaryField.Name) | |
526 | associationForeignKeys = append(associationForeignKeys, primaryField.Name) | |
527 | } | |
528 | } else { | |
529 | // generate foreign keys with association foreign keys | |
530 | for _, associationForeignKey := range associationForeignKeys { | |
531 | if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { | |
532 | foreignKeys = append(foreignKeys, field.Name+foreignField.Name) | |
533 | associationForeignKeys = append(associationForeignKeys, foreignField.Name) | |
534 | } | |
535 | } | |
536 | } | |
537 | } else { | |
538 | // generate foreign keys & association foreign keys | |
539 | if len(associationForeignKeys) == 0 { | |
540 | for _, foreignKey := range foreignKeys { | |
541 | if strings.HasPrefix(foreignKey, field.Name) { | |
542 | associationForeignKey := strings.TrimPrefix(foreignKey, field.Name) | |
543 | if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { | |
544 | associationForeignKeys = append(associationForeignKeys, associationForeignKey) | |
545 | } | |
546 | } | |
547 | } | |
548 | if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { | |
549 | associationForeignKeys = []string{toScope.PrimaryKey()} | |
550 | } | |
551 | } else if len(foreignKeys) != len(associationForeignKeys) { | |
552 | scope.Err(errors.New("invalid foreign keys, should have same length")) | |
553 | return | |
554 | } | |
555 | } | |
556 | ||
557 | for idx, foreignKey := range foreignKeys { | |
558 | if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { | |
559 | if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { | |
560 | foreignField.IsForeignKey = true | |
561 | ||
562 | // association foreign keys | |
563 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) | |
564 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) | |
565 | ||
566 | // source foreign keys | |
349 | 567 | relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) |
350 | 568 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) |
351 | foreignField.IsForeignKey = true | |
352 | } | |
353 | } | |
354 | } else { | |
355 | for _, foreignKey := range foreignKeys { | |
356 | if foreignField := getForeignField(foreignKey, fields); foreignField != nil { | |
357 | relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name) | |
358 | relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName) | |
359 | relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) | |
360 | relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) | |
361 | foreignField.IsForeignKey = true | |
362 | 569 | } |
363 | 570 | } |
364 | 571 | } |
368 | 575 | field.Relationship = relationship |
369 | 576 | } |
370 | 577 | } |
371 | } | |
578 | }(field) | |
372 | 579 | default: |
373 | 580 | field.IsNormal = true |
374 | 581 | } |
375 | 582 | } |
376 | ||
377 | if field.IsNormal { | |
378 | if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" { | |
379 | field.IsPrimaryKey = true | |
380 | modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) | |
381 | } | |
382 | } | |
383 | 583 | } |
584 | ||
585 | // Even it is ignored, also possible to decode db value into the field | |
586 | if value, ok := field.TagSettings["COLUMN"]; ok { | |
587 | field.DBName = value | |
588 | } else { | |
589 | field.DBName = ToDBName(fieldStruct.Name) | |
590 | } | |
591 | ||
384 | 592 | modelStruct.StructFields = append(modelStruct.StructFields, field) |
385 | 593 | } |
386 | finished <- true | |
387 | }(finished) | |
388 | ||
389 | modelStructsMap.Set(scopeType, &modelStruct) | |
390 | ||
391 | <-finished | |
392 | modelStruct.cached = true | |
594 | } | |
595 | ||
596 | if len(modelStruct.PrimaryFields) == 0 { | |
597 | if field := getForeignField("id", modelStruct.StructFields); field != nil { | |
598 | field.IsPrimaryKey = true | |
599 | modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) | |
600 | } | |
601 | } | |
602 | ||
603 | modelStructsMap.Set(reflectType, &modelStruct) | |
393 | 604 | |
394 | 605 | return &modelStruct |
395 | 606 | } |
396 | 607 | |
608 | // GetStructFields get model's field structs | |
397 | 609 | func (scope *Scope) GetStructFields() (fields []*StructField) { |
398 | 610 | return scope.GetModelStruct().StructFields |
399 | 611 | } |
400 | 612 | |
401 | func (scope *Scope) generateSqlTag(field *StructField) string { | |
402 | var sqlType string | |
403 | structType := field.Struct.Type | |
404 | if structType.Kind() == reflect.Ptr { | |
405 | structType = structType.Elem() | |
406 | } | |
407 | reflectValue := reflect.Indirect(reflect.New(structType)) | |
408 | sqlSettings := parseTagSetting(field.Tag.Get("sql")) | |
409 | ||
410 | if value, ok := sqlSettings["TYPE"]; ok { | |
411 | sqlType = value | |
412 | } | |
413 | ||
414 | additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"] | |
415 | if value, ok := sqlSettings["DEFAULT"]; ok { | |
416 | additionalType = additionalType + " DEFAULT " + value | |
417 | } | |
418 | ||
419 | if field.IsScanner { | |
420 | var getScannerValue func(reflect.Value) | |
421 | getScannerValue = func(value reflect.Value) { | |
422 | reflectValue = value | |
423 | if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct { | |
424 | getScannerValue(reflectValue.Field(0)) | |
613 | func parseTagSetting(tags reflect.StructTag) map[string]string { | |
614 | setting := map[string]string{} | |
615 | for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { | |
616 | tags := strings.Split(str, ";") | |
617 | for _, value := range tags { | |
618 | v := strings.Split(value, ":") | |
619 | k := strings.TrimSpace(strings.ToUpper(v[0])) | |
620 | if len(v) >= 2 { | |
621 | setting[k] = strings.Join(v[1:], ":") | |
622 | } else { | |
623 | setting[k] = k | |
425 | 624 | } |
426 | 625 | } |
427 | getScannerValue(reflectValue) | |
428 | } | |
429 | ||
430 | if sqlType == "" { | |
431 | var size = 255 | |
432 | ||
433 | if value, ok := sqlSettings["SIZE"]; ok { | |
434 | size, _ = strconv.Atoi(value) | |
435 | } | |
436 | ||
437 | _, autoIncrease := sqlSettings["AUTO_INCREMENT"] | |
438 | if field.IsPrimaryKey { | |
439 | autoIncrease = true | |
440 | } | |
441 | ||
442 | sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease) | |
443 | } | |
444 | ||
445 | if strings.TrimSpace(additionalType) == "" { | |
446 | return sqlType | |
447 | } else { | |
448 | return fmt.Sprintf("%v %v", sqlType, additionalType) | |
449 | } | |
450 | } | |
451 | ||
452 | func parseTagSetting(str string) map[string]string { | |
453 | tags := strings.Split(str, ";") | |
454 | setting := map[string]string{} | |
455 | for _, value := range tags { | |
456 | v := strings.Split(value, ":") | |
457 | k := strings.TrimSpace(strings.ToUpper(v[0])) | |
458 | if len(v) >= 2 { | |
459 | setting[k] = strings.Join(v[1:], ":") | |
460 | } else { | |
461 | setting[k] = k | |
462 | } | |
463 | 626 | } |
464 | 627 | return setting |
465 | 628 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | "time" | |
6 | ) | |
7 | ||
8 | type mssql struct { | |
9 | commonDialect | |
10 | } | |
11 | ||
12 | func (mssql) HasTop() bool { | |
13 | return true | |
14 | } | |
15 | ||
16 | func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { | |
17 | switch value.Kind() { | |
18 | case reflect.Bool: | |
19 | return "bit" | |
20 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | |
21 | if autoIncrease { | |
22 | return "int IDENTITY(1,1)" | |
23 | } | |
24 | return "int" | |
25 | case reflect.Int64, reflect.Uint64: | |
26 | if autoIncrease { | |
27 | return "bigint IDENTITY(1,1)" | |
28 | } | |
29 | return "bigint" | |
30 | case reflect.Float32, reflect.Float64: | |
31 | return "float" | |
32 | case reflect.String: | |
33 | if size > 0 && size < 65532 { | |
34 | return fmt.Sprintf("nvarchar(%d)", size) | |
35 | } | |
36 | return "text" | |
37 | case reflect.Struct: | |
38 | if _, ok := value.Interface().(time.Time); ok { | |
39 | return "datetime2" | |
40 | } | |
41 | default: | |
42 | if _, ok := value.Interface().([]byte); ok { | |
43 | if size > 0 && size < 65532 { | |
44 | return fmt.Sprintf("varchar(%d)", size) | |
45 | } | |
46 | return "text" | |
47 | } | |
48 | } | |
49 | panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) | |
50 | } | |
51 | ||
52 | func (s mssql) HasTable(scope *Scope, tableName string) bool { | |
53 | var ( | |
54 | count int | |
55 | databaseName = s.CurrentDatabase(scope) | |
56 | ) | |
57 | s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName) | |
58 | return count > 0 | |
59 | } | |
60 | ||
61 | func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { | |
62 | var ( | |
63 | count int | |
64 | databaseName = s.CurrentDatabase(scope) | |
65 | ) | |
66 | s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName) | |
67 | return count > 0 | |
68 | } | |
69 | ||
70 | func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { | |
71 | var count int | |
72 | s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName) | |
73 | return count > 0 | |
74 | } | |
75 | ||
76 | func (s mssql) CurrentDatabase(scope *Scope) (name string) { | |
77 | s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]") | |
78 | return | |
79 | } |
0 | 0 | package gorm_test |
1 | 1 | |
2 | 2 | import ( |
3 | "fmt" | |
4 | 3 | "os" |
4 | "reflect" | |
5 | "sort" | |
5 | 6 | "testing" |
6 | 7 | ) |
7 | 8 | |
8 | 9 | type Blog struct { |
9 | ID uint `gorm:"primary_key"` | |
10 | Locale string `gorm:"primary_key"` | |
11 | Subject string | |
12 | Body string | |
13 | Tags []Tag `gorm:"many2many:blog_tags;"` | |
10 | ID uint `gorm:"primary_key"` | |
11 | Locale string `gorm:"primary_key"` | |
12 | Subject string | |
13 | Body string | |
14 | Tags []Tag `gorm:"many2many:blog_tags;"` | |
15 | SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;AssociationForeignKey:id"` | |
16 | LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;AssociationForeignKey:id"` | |
14 | 17 | } |
15 | 18 | |
16 | 19 | type Tag struct { |
17 | 20 | ID uint `gorm:"primary_key"` |
18 | 21 | Locale string `gorm:"primary_key"` |
19 | 22 | Value string |
23 | Blogs []*Blog `gorm:"many2many:blogs_tags"` | |
24 | } | |
25 | ||
26 | func compareTags(tags []Tag, contents []string) bool { | |
27 | var tagContents []string | |
28 | for _, tag := range tags { | |
29 | tagContents = append(tagContents, tag.Value) | |
30 | } | |
31 | sort.Strings(tagContents) | |
32 | sort.Strings(contents) | |
33 | return reflect.DeepEqual(tagContents, contents) | |
20 | 34 | } |
21 | 35 | |
22 | 36 | func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { |
23 | if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { | |
24 | DB.Exec(fmt.Sprintf("drop table blog_tags;")) | |
25 | DB.AutoMigrate(&Blog{}, &Tag{}) | |
37 | if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { | |
38 | DB.DropTable(&Blog{}, &Tag{}) | |
39 | DB.DropTable("blog_tags") | |
40 | DB.CreateTable(&Blog{}, &Tag{}) | |
26 | 41 | blog := Blog{ |
27 | 42 | Locale: "ZH", |
28 | 43 | Subject: "subject", |
34 | 49 | } |
35 | 50 | |
36 | 51 | DB.Save(&blog) |
37 | DB.Model(&blog).Association("Tags").Append([]Tag{{Locale: "ZH", Value: "tag3"}}) | |
52 | if !compareTags(blog.Tags, []string{"tag1", "tag2"}) { | |
53 | t.Errorf("Blog should has two tags") | |
54 | } | |
55 | ||
56 | // Append | |
57 | var tag3 = &Tag{Locale: "ZH", Value: "tag3"} | |
58 | DB.Model(&blog).Association("Tags").Append([]*Tag{tag3}) | |
59 | if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) { | |
60 | t.Errorf("Blog should has three tags after Append") | |
61 | } | |
62 | ||
63 | if DB.Model(&blog).Association("Tags").Count() != 3 { | |
64 | t.Errorf("Blog should has three tags after Append") | |
65 | } | |
38 | 66 | |
39 | 67 | var tags []Tag |
40 | 68 | DB.Model(&blog).Related(&tags, "Tags") |
41 | if len(tags) != 3 { | |
42 | t.Errorf("should found 3 tags with blog") | |
69 | if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { | |
70 | t.Errorf("Should find 3 tags with Related") | |
71 | } | |
72 | ||
73 | var blog1 Blog | |
74 | DB.Preload("Tags").Find(&blog1) | |
75 | if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) { | |
76 | t.Errorf("Preload many2many relations") | |
77 | } | |
78 | ||
79 | // Replace | |
80 | var tag5 = &Tag{Locale: "ZH", Value: "tag5"} | |
81 | var tag6 = &Tag{Locale: "ZH", Value: "tag6"} | |
82 | DB.Model(&blog).Association("Tags").Replace(tag5, tag6) | |
83 | var tags2 []Tag | |
84 | DB.Model(&blog).Related(&tags2, "Tags") | |
85 | if !compareTags(tags2, []string{"tag5", "tag6"}) { | |
86 | t.Errorf("Should find 2 tags after Replace") | |
87 | } | |
88 | ||
89 | if DB.Model(&blog).Association("Tags").Count() != 2 { | |
90 | t.Errorf("Blog should has three tags after Replace") | |
91 | } | |
92 | ||
93 | // Delete | |
94 | DB.Model(&blog).Association("Tags").Delete(tag5) | |
95 | var tags3 []Tag | |
96 | DB.Model(&blog).Related(&tags3, "Tags") | |
97 | if !compareTags(tags3, []string{"tag6"}) { | |
98 | t.Errorf("Should find 1 tags after Delete") | |
99 | } | |
100 | ||
101 | if DB.Model(&blog).Association("Tags").Count() != 1 { | |
102 | t.Errorf("Blog should has three tags after Delete") | |
103 | } | |
104 | ||
105 | DB.Model(&blog).Association("Tags").Delete(tag3) | |
106 | var tags4 []Tag | |
107 | DB.Model(&blog).Related(&tags4, "Tags") | |
108 | if !compareTags(tags4, []string{"tag6"}) { | |
109 | t.Errorf("Tag should not be deleted when Delete with a unrelated tag") | |
110 | } | |
111 | ||
112 | // Clear | |
113 | DB.Model(&blog).Association("Tags").Clear() | |
114 | if DB.Model(&blog).Association("Tags").Count() != 0 { | |
115 | t.Errorf("All tags should be cleared") | |
43 | 116 | } |
44 | 117 | } |
45 | 118 | } |
119 | ||
120 | func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { | |
121 | if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { | |
122 | DB.DropTable(&Blog{}, &Tag{}) | |
123 | DB.DropTable("shared_blog_tags") | |
124 | DB.CreateTable(&Blog{}, &Tag{}) | |
125 | blog := Blog{ | |
126 | Locale: "ZH", | |
127 | Subject: "subject", | |
128 | Body: "body", | |
129 | SharedTags: []Tag{ | |
130 | {Locale: "ZH", Value: "tag1"}, | |
131 | {Locale: "ZH", Value: "tag2"}, | |
132 | }, | |
133 | } | |
134 | DB.Save(&blog) | |
135 | ||
136 | blog2 := Blog{ | |
137 | ID: blog.ID, | |
138 | Locale: "EN", | |
139 | } | |
140 | DB.Create(&blog2) | |
141 | ||
142 | if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { | |
143 | t.Errorf("Blog should has two tags") | |
144 | } | |
145 | ||
146 | // Append | |
147 | var tag3 = &Tag{Locale: "ZH", Value: "tag3"} | |
148 | DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) | |
149 | if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { | |
150 | t.Errorf("Blog should has three tags after Append") | |
151 | } | |
152 | ||
153 | if DB.Model(&blog).Association("SharedTags").Count() != 3 { | |
154 | t.Errorf("Blog should has three tags after Append") | |
155 | } | |
156 | ||
157 | if DB.Model(&blog2).Association("SharedTags").Count() != 3 { | |
158 | t.Errorf("Blog should has three tags after Append") | |
159 | } | |
160 | ||
161 | var tags []Tag | |
162 | DB.Model(&blog).Related(&tags, "SharedTags") | |
163 | if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { | |
164 | t.Errorf("Should find 3 tags with Related") | |
165 | } | |
166 | ||
167 | DB.Model(&blog2).Related(&tags, "SharedTags") | |
168 | if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { | |
169 | t.Errorf("Should find 3 tags with Related") | |
170 | } | |
171 | ||
172 | var blog1 Blog | |
173 | DB.Preload("SharedTags").Find(&blog1) | |
174 | if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) { | |
175 | t.Errorf("Preload many2many relations") | |
176 | } | |
177 | ||
178 | var tag4 = &Tag{Locale: "ZH", Value: "tag4"} | |
179 | DB.Model(&blog2).Association("SharedTags").Append(tag4) | |
180 | ||
181 | DB.Model(&blog).Related(&tags, "SharedTags") | |
182 | if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { | |
183 | t.Errorf("Should find 3 tags with Related") | |
184 | } | |
185 | ||
186 | DB.Model(&blog2).Related(&tags, "SharedTags") | |
187 | if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { | |
188 | t.Errorf("Should find 3 tags with Related") | |
189 | } | |
190 | ||
191 | // Replace | |
192 | var tag5 = &Tag{Locale: "ZH", Value: "tag5"} | |
193 | var tag6 = &Tag{Locale: "ZH", Value: "tag6"} | |
194 | DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) | |
195 | var tags2 []Tag | |
196 | DB.Model(&blog).Related(&tags2, "SharedTags") | |
197 | if !compareTags(tags2, []string{"tag5", "tag6"}) { | |
198 | t.Errorf("Should find 2 tags after Replace") | |
199 | } | |
200 | ||
201 | DB.Model(&blog2).Related(&tags2, "SharedTags") | |
202 | if !compareTags(tags2, []string{"tag5", "tag6"}) { | |
203 | t.Errorf("Should find 2 tags after Replace") | |
204 | } | |
205 | ||
206 | if DB.Model(&blog).Association("SharedTags").Count() != 2 { | |
207 | t.Errorf("Blog should has three tags after Replace") | |
208 | } | |
209 | ||
210 | // Delete | |
211 | DB.Model(&blog).Association("SharedTags").Delete(tag5) | |
212 | var tags3 []Tag | |
213 | DB.Model(&blog).Related(&tags3, "SharedTags") | |
214 | if !compareTags(tags3, []string{"tag6"}) { | |
215 | t.Errorf("Should find 1 tags after Delete") | |
216 | } | |
217 | ||
218 | if DB.Model(&blog).Association("SharedTags").Count() != 1 { | |
219 | t.Errorf("Blog should has three tags after Delete") | |
220 | } | |
221 | ||
222 | DB.Model(&blog2).Association("SharedTags").Delete(tag3) | |
223 | var tags4 []Tag | |
224 | DB.Model(&blog).Related(&tags4, "SharedTags") | |
225 | if !compareTags(tags4, []string{"tag6"}) { | |
226 | t.Errorf("Tag should not be deleted when Delete with a unrelated tag") | |
227 | } | |
228 | ||
229 | // Clear | |
230 | DB.Model(&blog2).Association("SharedTags").Clear() | |
231 | if DB.Model(&blog).Association("SharedTags").Count() != 0 { | |
232 | t.Errorf("All tags should be cleared") | |
233 | } | |
234 | } | |
235 | } | |
236 | ||
237 | func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { | |
238 | if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { | |
239 | DB.DropTable(&Blog{}, &Tag{}) | |
240 | DB.DropTable("locale_blog_tags") | |
241 | DB.CreateTable(&Blog{}, &Tag{}) | |
242 | blog := Blog{ | |
243 | Locale: "ZH", | |
244 | Subject: "subject", | |
245 | Body: "body", | |
246 | LocaleTags: []Tag{ | |
247 | {Locale: "ZH", Value: "tag1"}, | |
248 | {Locale: "ZH", Value: "tag2"}, | |
249 | }, | |
250 | } | |
251 | DB.Save(&blog) | |
252 | ||
253 | blog2 := Blog{ | |
254 | ID: blog.ID, | |
255 | Locale: "EN", | |
256 | } | |
257 | DB.Create(&blog2) | |
258 | ||
259 | // Append | |
260 | var tag3 = &Tag{Locale: "ZH", Value: "tag3"} | |
261 | DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) | |
262 | if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { | |
263 | t.Errorf("Blog should has three tags after Append") | |
264 | } | |
265 | ||
266 | if DB.Model(&blog).Association("LocaleTags").Count() != 3 { | |
267 | t.Errorf("Blog should has three tags after Append") | |
268 | } | |
269 | ||
270 | if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { | |
271 | t.Errorf("EN Blog should has 0 tags after ZH Blog Append") | |
272 | } | |
273 | ||
274 | var tags []Tag | |
275 | DB.Model(&blog).Related(&tags, "LocaleTags") | |
276 | if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { | |
277 | t.Errorf("Should find 3 tags with Related") | |
278 | } | |
279 | ||
280 | DB.Model(&blog2).Related(&tags, "LocaleTags") | |
281 | if len(tags) != 0 { | |
282 | t.Errorf("Should find 0 tags with Related for EN Blog") | |
283 | } | |
284 | ||
285 | var blog1 Blog | |
286 | DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID) | |
287 | if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { | |
288 | t.Errorf("Preload many2many relations") | |
289 | } | |
290 | ||
291 | var tag4 = &Tag{Locale: "ZH", Value: "tag4"} | |
292 | DB.Model(&blog2).Association("LocaleTags").Append(tag4) | |
293 | ||
294 | DB.Model(&blog).Related(&tags, "LocaleTags") | |
295 | if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { | |
296 | t.Errorf("Should find 3 tags with Related for EN Blog") | |
297 | } | |
298 | ||
299 | DB.Model(&blog2).Related(&tags, "LocaleTags") | |
300 | if !compareTags(tags, []string{"tag4"}) { | |
301 | t.Errorf("Should find 1 tags with Related for EN Blog") | |
302 | } | |
303 | ||
304 | // Replace | |
305 | var tag5 = &Tag{Locale: "ZH", Value: "tag5"} | |
306 | var tag6 = &Tag{Locale: "ZH", Value: "tag6"} | |
307 | DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) | |
308 | ||
309 | var tags2 []Tag | |
310 | DB.Model(&blog).Related(&tags2, "LocaleTags") | |
311 | if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) { | |
312 | t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") | |
313 | } | |
314 | ||
315 | var blog11 Blog | |
316 | DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale) | |
317 | if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { | |
318 | t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") | |
319 | } | |
320 | ||
321 | DB.Model(&blog2).Related(&tags2, "LocaleTags") | |
322 | if !compareTags(tags2, []string{"tag5", "tag6"}) { | |
323 | t.Errorf("Should find 2 tags after Replace") | |
324 | } | |
325 | ||
326 | var blog21 Blog | |
327 | DB.Preload("LocaleTags").First(&blog21, "id = ? AND locale = ?", blog2.ID, blog2.Locale) | |
328 | if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { | |
329 | t.Errorf("EN Blog's tags should be changed after Replace") | |
330 | } | |
331 | ||
332 | if DB.Model(&blog).Association("LocaleTags").Count() != 3 { | |
333 | t.Errorf("ZH Blog should has three tags after Replace") | |
334 | } | |
335 | ||
336 | if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { | |
337 | t.Errorf("EN Blog should has two tags after Replace") | |
338 | } | |
339 | ||
340 | // Delete | |
341 | DB.Model(&blog).Association("LocaleTags").Delete(tag5) | |
342 | ||
343 | if DB.Model(&blog).Association("LocaleTags").Count() != 3 { | |
344 | t.Errorf("ZH Blog should has three tags after Delete with EN's tag") | |
345 | } | |
346 | ||
347 | if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { | |
348 | t.Errorf("EN Blog should has two tags after ZH Blog Delete with EN's tag") | |
349 | } | |
350 | ||
351 | DB.Model(&blog2).Association("LocaleTags").Delete(tag5) | |
352 | ||
353 | if DB.Model(&blog).Association("LocaleTags").Count() != 3 { | |
354 | t.Errorf("ZH Blog should has three tags after EN Blog Delete with EN's tag") | |
355 | } | |
356 | ||
357 | if DB.Model(&blog2).Association("LocaleTags").Count() != 1 { | |
358 | t.Errorf("EN Blog should has 1 tags after EN Blog Delete with EN's tag") | |
359 | } | |
360 | ||
361 | // Clear | |
362 | DB.Model(&blog2).Association("LocaleTags").Clear() | |
363 | if DB.Model(&blog).Association("LocaleTags").Count() != 3 { | |
364 | t.Errorf("ZH Blog's tags should not be cleared when clear EN Blog's tags") | |
365 | } | |
366 | ||
367 | if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { | |
368 | t.Errorf("EN Blog's tags should be cleared when clear EN Blog's tags") | |
369 | } | |
370 | ||
371 | DB.Model(&blog).Association("LocaleTags").Clear() | |
372 | if DB.Model(&blog).Association("LocaleTags").Count() != 0 { | |
373 | t.Errorf("ZH Blog's tags should be cleared when clear ZH Blog's tags") | |
374 | } | |
375 | ||
376 | if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { | |
377 | t.Errorf("EN Blog's tags should be cleared") | |
378 | } | |
379 | } | |
380 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | "time" | |
6 | ) | |
7 | ||
8 | type mysql struct { | |
9 | commonDialect | |
10 | } | |
11 | ||
12 | func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { | |
13 | switch value.Kind() { | |
14 | case reflect.Bool: | |
15 | return "boolean" | |
16 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: | |
17 | if autoIncrease { | |
18 | return "int AUTO_INCREMENT" | |
19 | } | |
20 | return "int" | |
21 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | |
22 | if autoIncrease { | |
23 | return "int unsigned AUTO_INCREMENT" | |
24 | } | |
25 | return "int unsigned" | |
26 | case reflect.Int64: | |
27 | if autoIncrease { | |
28 | return "bigint AUTO_INCREMENT" | |
29 | } | |
30 | return "bigint" | |
31 | case reflect.Uint64: | |
32 | if autoIncrease { | |
33 | return "bigint unsigned AUTO_INCREMENT" | |
34 | } | |
35 | return "bigint unsigned" | |
36 | case reflect.Float32, reflect.Float64: | |
37 | return "double" | |
38 | case reflect.String: | |
39 | if size > 0 && size < 65532 { | |
40 | return fmt.Sprintf("varchar(%d)", size) | |
41 | } | |
42 | return "longtext" | |
43 | case reflect.Struct: | |
44 | if _, ok := value.Interface().(time.Time); ok { | |
45 | return "timestamp NULL" | |
46 | } | |
47 | default: | |
48 | if _, ok := value.Interface().([]byte); ok { | |
49 | if size > 0 && size < 65532 { | |
50 | return fmt.Sprintf("varbinary(%d)", size) | |
51 | } | |
52 | return "longblob" | |
53 | } | |
54 | } | |
55 | panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) | |
56 | } | |
57 | ||
58 | func (mysql) Quote(key string) string { | |
59 | return fmt.Sprintf("`%s`", key) | |
60 | } | |
61 | ||
62 | func (mysql) SelectFromDummyTable() string { | |
63 | return "FROM DUAL" | |
64 | } | |
65 | ||
66 | func (s mysql) CurrentDatabase(scope *Scope) (name string) { | |
67 | s.RawScanString(scope, &name, "SELECT DATABASE()") | |
68 | return | |
69 | } |
38 | 38 | |
39 | 39 | var nilPointerStruct = PointerStruct{} |
40 | 40 | if err := DB.Create(&nilPointerStruct).Error; err != nil { |
41 | t.Errorf("Failed to save nil pointer struct", err) | |
41 | t.Error("Failed to save nil pointer struct", err) | |
42 | 42 | } |
43 | 43 | |
44 | 44 | var pointerStruct2 PointerStruct |
45 | 45 | if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { |
46 | t.Errorf("Failed to query saved nil pointer struct", err) | |
46 | t.Error("Failed to query saved nil pointer struct", err) | |
47 | 47 | } |
48 | 48 | |
49 | 49 | var normalStruct2 NormalStruct |
50 | 50 | if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { |
51 | t.Errorf("Failed to query saved nil pointer struct", err) | |
51 | t.Error("Failed to query saved nil pointer struct", err) | |
52 | 52 | } |
53 | 53 | |
54 | 54 | var partialNilPointerStruct1 = PointerStruct{Num: &num} |
55 | 55 | if err := DB.Create(&partialNilPointerStruct1).Error; err != nil { |
56 | t.Errorf("Failed to save partial nil pointer struct", err) | |
56 | t.Error("Failed to save partial nil pointer struct", err) | |
57 | 57 | } |
58 | 58 | |
59 | 59 | var pointerStruct3 PointerStruct |
60 | 60 | if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num { |
61 | t.Errorf("Failed to query saved partial nil pointer struct", err) | |
61 | t.Error("Failed to query saved partial nil pointer struct", err) | |
62 | 62 | } |
63 | 63 | |
64 | 64 | var normalStruct3 NormalStruct |
65 | 65 | if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num { |
66 | t.Errorf("Failed to query saved partial pointer struct", err) | |
66 | t.Error("Failed to query saved partial pointer struct", err) | |
67 | 67 | } |
68 | 68 | |
69 | 69 | var partialNilPointerStruct2 = PointerStruct{Name: &name} |
70 | 70 | if err := DB.Create(&partialNilPointerStruct2).Error; err != nil { |
71 | t.Errorf("Failed to save partial nil pointer struct", err) | |
71 | t.Error("Failed to save partial nil pointer struct", err) | |
72 | 72 | } |
73 | 73 | |
74 | 74 | var pointerStruct4 PointerStruct |
75 | 75 | if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name { |
76 | t.Errorf("Failed to query saved partial nil pointer struct", err) | |
76 | t.Error("Failed to query saved partial nil pointer struct", err) | |
77 | 77 | } |
78 | 78 | |
79 | 79 | var normalStruct4 NormalStruct |
80 | 80 | if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name { |
81 | t.Errorf("Failed to query saved partial pointer struct", err) | |
81 | t.Error("Failed to query saved partial pointer struct", err) | |
82 | 82 | } |
83 | 83 | } |
0 | 0 | package gorm_test |
1 | 1 | |
2 | import "testing" | |
2 | import ( | |
3 | "reflect" | |
4 | "sort" | |
5 | "testing" | |
6 | ) | |
3 | 7 | |
4 | 8 | type Cat struct { |
5 | 9 | Id int |
11 | 15 | Id int |
12 | 16 | Name string |
13 | 17 | Toys []Toy `gorm:"polymorphic:Owner;"` |
18 | } | |
19 | ||
20 | type Hamster struct { | |
21 | Id int | |
22 | Name string | |
23 | PreferredToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_preferred"` | |
24 | OtherToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_other"` | |
14 | 25 | } |
15 | 26 | |
16 | 27 | type Toy struct { |
20 | 31 | OwnerType string |
21 | 32 | } |
22 | 33 | |
34 | var compareToys = func(toys []Toy, contents []string) bool { | |
35 | var toyContents []string | |
36 | for _, toy := range toys { | |
37 | toyContents = append(toyContents, toy.Name) | |
38 | } | |
39 | sort.Strings(toyContents) | |
40 | sort.Strings(contents) | |
41 | return reflect.DeepEqual(toyContents, contents) | |
42 | } | |
43 | ||
23 | 44 | func TestPolymorphic(t *testing.T) { |
24 | DB.AutoMigrate(&Cat{}) | |
25 | DB.AutoMigrate(&Dog{}) | |
26 | DB.AutoMigrate(&Toy{}) | |
27 | ||
28 | cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat nip"}} | |
29 | dog := Dog{Name: "Pluto", Toys: []Toy{Toy{Name: "orange ball"}, Toy{Name: "yellow ball"}}} | |
45 | cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat toy"}} | |
46 | dog := Dog{Name: "Pluto", Toys: []Toy{{Name: "dog toy 1"}, {Name: "dog toy 2"}}} | |
30 | 47 | DB.Save(&cat).Save(&dog) |
31 | 48 | |
49 | if DB.Model(&cat).Association("Toy").Count() != 1 { | |
50 | t.Errorf("Cat's toys count should be 1") | |
51 | } | |
52 | ||
53 | if DB.Model(&dog).Association("Toys").Count() != 2 { | |
54 | t.Errorf("Dog's toys count should be 2") | |
55 | } | |
56 | ||
57 | // Query | |
32 | 58 | var catToys []Toy |
33 | 59 | if DB.Model(&cat).Related(&catToys, "Toy").RecordNotFound() { |
34 | 60 | t.Errorf("Did not find any has one polymorphic association") |
45 | 71 | t.Errorf("Should have found all polymorphic has many associations") |
46 | 72 | } |
47 | 73 | |
48 | if DB.Model(&cat).Association("Toy").Count() != 1 { | |
49 | t.Errorf("Should return one polymorphic has one association") | |
74 | var catToy Toy | |
75 | DB.Model(&cat).Association("Toy").Find(&catToy) | |
76 | if catToy.Name != cat.Toy.Name { | |
77 | t.Errorf("Should find has one polymorphic association") | |
78 | } | |
79 | ||
80 | var dogToys1 []Toy | |
81 | DB.Model(&dog).Association("Toys").Find(&dogToys1) | |
82 | if !compareToys(dogToys1, []string{"dog toy 1", "dog toy 2"}) { | |
83 | t.Errorf("Should find has many polymorphic association") | |
84 | } | |
85 | ||
86 | // Append | |
87 | DB.Model(&cat).Association("Toy").Append(&Toy{ | |
88 | Name: "cat toy 2", | |
89 | }) | |
90 | ||
91 | var catToy2 Toy | |
92 | DB.Model(&cat).Association("Toy").Find(&catToy2) | |
93 | if catToy2.Name != "cat toy 2" { | |
94 | t.Errorf("Should update has one polymorphic association with Append") | |
95 | } | |
96 | ||
97 | if DB.Model(&cat).Association("Toy").Count() != 1 { | |
98 | t.Errorf("Cat's toys count should be 1 after Append") | |
50 | 99 | } |
51 | 100 | |
52 | 101 | if DB.Model(&dog).Association("Toys").Count() != 2 { |
53 | 102 | t.Errorf("Should return two polymorphic has many associations") |
54 | 103 | } |
55 | } | |
104 | ||
105 | DB.Model(&dog).Association("Toys").Append(&Toy{ | |
106 | Name: "dog toy 3", | |
107 | }) | |
108 | ||
109 | var dogToys2 []Toy | |
110 | DB.Model(&dog).Association("Toys").Find(&dogToys2) | |
111 | if !compareToys(dogToys2, []string{"dog toy 1", "dog toy 2", "dog toy 3"}) { | |
112 | t.Errorf("Dog's toys should be updated with Append") | |
113 | } | |
114 | ||
115 | if DB.Model(&dog).Association("Toys").Count() != 3 { | |
116 | t.Errorf("Should return three polymorphic has many associations") | |
117 | } | |
118 | ||
119 | // Replace | |
120 | DB.Model(&cat).Association("Toy").Replace(&Toy{ | |
121 | Name: "cat toy 3", | |
122 | }) | |
123 | ||
124 | var catToy3 Toy | |
125 | DB.Model(&cat).Association("Toy").Find(&catToy3) | |
126 | if catToy3.Name != "cat toy 3" { | |
127 | t.Errorf("Should update has one polymorphic association with Replace") | |
128 | } | |
129 | ||
130 | if DB.Model(&cat).Association("Toy").Count() != 1 { | |
131 | t.Errorf("Cat's toys count should be 1 after Replace") | |
132 | } | |
133 | ||
134 | if DB.Model(&dog).Association("Toys").Count() != 3 { | |
135 | t.Errorf("Should return three polymorphic has many associations") | |
136 | } | |
137 | ||
138 | DB.Model(&dog).Association("Toys").Replace(&Toy{ | |
139 | Name: "dog toy 4", | |
140 | }, []Toy{ | |
141 | {Name: "dog toy 5"}, {Name: "dog toy 6"}, {Name: "dog toy 7"}, | |
142 | }) | |
143 | ||
144 | var dogToys3 []Toy | |
145 | DB.Model(&dog).Association("Toys").Find(&dogToys3) | |
146 | if !compareToys(dogToys3, []string{"dog toy 4", "dog toy 5", "dog toy 6", "dog toy 7"}) { | |
147 | t.Errorf("Dog's toys should be updated with Replace") | |
148 | } | |
149 | ||
150 | if DB.Model(&dog).Association("Toys").Count() != 4 { | |
151 | t.Errorf("Should return three polymorphic has many associations") | |
152 | } | |
153 | ||
154 | // Delete | |
155 | DB.Model(&cat).Association("Toy").Delete(&catToy2) | |
156 | ||
157 | var catToy4 Toy | |
158 | DB.Model(&cat).Association("Toy").Find(&catToy4) | |
159 | if catToy4.Name != "cat toy 3" { | |
160 | t.Errorf("Should not update has one polymorphic association when Delete a unrelated Toy") | |
161 | } | |
162 | ||
163 | if DB.Model(&cat).Association("Toy").Count() != 1 { | |
164 | t.Errorf("Cat's toys count should be 1") | |
165 | } | |
166 | ||
167 | if DB.Model(&dog).Association("Toys").Count() != 4 { | |
168 | t.Errorf("Dog's toys count should be 4") | |
169 | } | |
170 | ||
171 | DB.Model(&cat).Association("Toy").Delete(&catToy3) | |
172 | ||
173 | if !DB.Model(&cat).Related(&Toy{}, "Toy").RecordNotFound() { | |
174 | t.Errorf("Toy should be deleted with Delete") | |
175 | } | |
176 | ||
177 | if DB.Model(&cat).Association("Toy").Count() != 0 { | |
178 | t.Errorf("Cat's toys count should be 0 after Delete") | |
179 | } | |
180 | ||
181 | if DB.Model(&dog).Association("Toys").Count() != 4 { | |
182 | t.Errorf("Dog's toys count should not be changed when delete cat's toy") | |
183 | } | |
184 | ||
185 | DB.Model(&dog).Association("Toys").Delete(&dogToys2) | |
186 | ||
187 | if DB.Model(&dog).Association("Toys").Count() != 4 { | |
188 | t.Errorf("Dog's toys count should not be changed when delete unrelated toys") | |
189 | } | |
190 | ||
191 | DB.Model(&dog).Association("Toys").Delete(&dogToys3) | |
192 | ||
193 | if DB.Model(&dog).Association("Toys").Count() != 0 { | |
194 | t.Errorf("Dog's toys count should be deleted with Delete") | |
195 | } | |
196 | ||
197 | // Clear | |
198 | DB.Model(&cat).Association("Toy").Append(&Toy{ | |
199 | Name: "cat toy 2", | |
200 | }) | |
201 | ||
202 | if DB.Model(&cat).Association("Toy").Count() != 1 { | |
203 | t.Errorf("Cat's toys should be added with Append") | |
204 | } | |
205 | ||
206 | DB.Model(&cat).Association("Toy").Clear() | |
207 | ||
208 | if DB.Model(&cat).Association("Toy").Count() != 0 { | |
209 | t.Errorf("Cat's toys should be cleared with Clear") | |
210 | } | |
211 | ||
212 | DB.Model(&dog).Association("Toys").Append(&Toy{ | |
213 | Name: "dog toy 8", | |
214 | }) | |
215 | ||
216 | if DB.Model(&dog).Association("Toys").Count() != 1 { | |
217 | t.Errorf("Dog's toys should be added with Append") | |
218 | } | |
219 | ||
220 | DB.Model(&dog).Association("Toys").Clear() | |
221 | ||
222 | if DB.Model(&dog).Association("Toys").Count() != 0 { | |
223 | t.Errorf("Dog's toys should be cleared with Clear") | |
224 | } | |
225 | } | |
226 | ||
227 | func TestNamedPolymorphic(t *testing.T) { | |
228 | hamster := Hamster{Name: "Mr. Hammond", PreferredToy: Toy{Name: "bike"}, OtherToy: Toy{Name: "treadmill"}} | |
229 | DB.Save(&hamster) | |
230 | ||
231 | hamster2 := Hamster{} | |
232 | DB.Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id) | |
233 | if hamster2.PreferredToy.Id != hamster.PreferredToy.Id || hamster2.PreferredToy.Name != hamster.PreferredToy.Name { | |
234 | t.Errorf("Hamster's preferred toy couldn't be preloaded") | |
235 | } | |
236 | if hamster2.OtherToy.Id != hamster.OtherToy.Id || hamster2.OtherToy.Name != hamster.OtherToy.Name { | |
237 | t.Errorf("Hamster's other toy couldn't be preloaded") | |
238 | } | |
239 | ||
240 | // clear to omit Toy.Id in count | |
241 | hamster2.PreferredToy = Toy{} | |
242 | hamster2.OtherToy = Toy{} | |
243 | ||
244 | if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { | |
245 | t.Errorf("Hamster's preferred toy count should be 1") | |
246 | } | |
247 | ||
248 | if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { | |
249 | t.Errorf("Hamster's other toy count should be 1") | |
250 | } | |
251 | ||
252 | // Query | |
253 | var hamsterToys []Toy | |
254 | if DB.Model(&hamster).Related(&hamsterToys, "PreferredToy").RecordNotFound() { | |
255 | t.Errorf("Did not find any has one polymorphic association") | |
256 | } else if len(hamsterToys) != 1 { | |
257 | t.Errorf("Should have found only one polymorphic has one association") | |
258 | } else if hamsterToys[0].Name != hamster.PreferredToy.Name { | |
259 | t.Errorf("Should have found the proper has one polymorphic association") | |
260 | } | |
261 | ||
262 | if DB.Model(&hamster).Related(&hamsterToys, "OtherToy").RecordNotFound() { | |
263 | t.Errorf("Did not find any has one polymorphic association") | |
264 | } else if len(hamsterToys) != 1 { | |
265 | t.Errorf("Should have found only one polymorphic has one association") | |
266 | } else if hamsterToys[0].Name != hamster.OtherToy.Name { | |
267 | t.Errorf("Should have found the proper has one polymorphic association") | |
268 | } | |
269 | ||
270 | hamsterToy := Toy{} | |
271 | DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) | |
272 | if hamsterToy.Name != hamster.PreferredToy.Name { | |
273 | t.Errorf("Should find has one polymorphic association") | |
274 | } | |
275 | hamsterToy = Toy{} | |
276 | DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) | |
277 | if hamsterToy.Name != hamster.OtherToy.Name { | |
278 | t.Errorf("Should find has one polymorphic association") | |
279 | } | |
280 | ||
281 | // Append | |
282 | DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ | |
283 | Name: "bike 2", | |
284 | }) | |
285 | DB.Model(&hamster).Association("OtherToy").Append(&Toy{ | |
286 | Name: "treadmill 2", | |
287 | }) | |
288 | ||
289 | hamsterToy = Toy{} | |
290 | DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) | |
291 | if hamsterToy.Name != "bike 2" { | |
292 | t.Errorf("Should update has one polymorphic association with Append") | |
293 | } | |
294 | ||
295 | hamsterToy = Toy{} | |
296 | DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) | |
297 | if hamsterToy.Name != "treadmill 2" { | |
298 | t.Errorf("Should update has one polymorphic association with Append") | |
299 | } | |
300 | ||
301 | if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { | |
302 | t.Errorf("Hamster's toys count should be 1 after Append") | |
303 | } | |
304 | ||
305 | if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { | |
306 | t.Errorf("Hamster's toys count should be 1 after Append") | |
307 | } | |
308 | ||
309 | // Replace | |
310 | DB.Model(&hamster).Association("PreferredToy").Replace(&Toy{ | |
311 | Name: "bike 3", | |
312 | }) | |
313 | DB.Model(&hamster).Association("OtherToy").Replace(&Toy{ | |
314 | Name: "treadmill 3", | |
315 | }) | |
316 | ||
317 | hamsterToy = Toy{} | |
318 | DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) | |
319 | if hamsterToy.Name != "bike 3" { | |
320 | t.Errorf("Should update has one polymorphic association with Replace") | |
321 | } | |
322 | ||
323 | hamsterToy = Toy{} | |
324 | DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) | |
325 | if hamsterToy.Name != "treadmill 3" { | |
326 | t.Errorf("Should update has one polymorphic association with Replace") | |
327 | } | |
328 | ||
329 | if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { | |
330 | t.Errorf("hamster's toys count should be 1 after Replace") | |
331 | } | |
332 | ||
333 | if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { | |
334 | t.Errorf("hamster's toys count should be 1 after Replace") | |
335 | } | |
336 | ||
337 | // Clear | |
338 | DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ | |
339 | Name: "bike 2", | |
340 | }) | |
341 | DB.Model(&hamster).Association("OtherToy").Append(&Toy{ | |
342 | Name: "treadmill 2", | |
343 | }) | |
344 | ||
345 | if DB.Model(&hamster).Association("PreferredToy").Count() != 1 { | |
346 | t.Errorf("Hamster's toys should be added with Append") | |
347 | } | |
348 | if DB.Model(&hamster).Association("OtherToy").Count() != 1 { | |
349 | t.Errorf("Hamster's toys should be added with Append") | |
350 | } | |
351 | ||
352 | DB.Model(&hamster).Association("PreferredToy").Clear() | |
353 | ||
354 | if DB.Model(&hamster2).Association("PreferredToy").Count() != 0 { | |
355 | t.Errorf("Hamster's preferred toy should be cleared with Clear") | |
356 | } | |
357 | if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { | |
358 | t.Errorf("Hamster's other toy should be still available") | |
359 | } | |
360 | ||
361 | DB.Model(&hamster).Association("OtherToy").Clear() | |
362 | if DB.Model(&hamster).Association("OtherToy").Count() != 0 { | |
363 | t.Errorf("Hamster's other toy should be cleared with Clear") | |
364 | } | |
365 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "database/sql" | |
4 | "database/sql/driver" | |
5 | "fmt" | |
6 | "reflect" | |
7 | "time" | |
8 | ||
9 | "github.com/lib/pq/hstore" | |
10 | ) | |
11 | ||
12 | type postgres struct { | |
13 | commonDialect | |
14 | } | |
15 | ||
16 | func (postgres) BinVar(i int) string { | |
17 | return fmt.Sprintf("$%v", i) | |
18 | } | |
19 | ||
20 | func (postgres) SupportLastInsertId() bool { | |
21 | return false | |
22 | } | |
23 | ||
24 | func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string { | |
25 | switch value.Kind() { | |
26 | case reflect.Bool: | |
27 | return "boolean" | |
28 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | |
29 | if autoIncrease { | |
30 | return "serial" | |
31 | } | |
32 | return "integer" | |
33 | case reflect.Int64, reflect.Uint64: | |
34 | if autoIncrease { | |
35 | return "bigserial" | |
36 | } | |
37 | return "bigint" | |
38 | case reflect.Float32, reflect.Float64: | |
39 | return "numeric" | |
40 | case reflect.String: | |
41 | if size > 0 && size < 65532 { | |
42 | return fmt.Sprintf("varchar(%d)", size) | |
43 | } | |
44 | return "text" | |
45 | case reflect.Struct: | |
46 | if _, ok := value.Interface().(time.Time); ok { | |
47 | return "timestamp with time zone" | |
48 | } | |
49 | case reflect.Map: | |
50 | if value.Type() == hstoreType { | |
51 | return "hstore" | |
52 | } | |
53 | default: | |
54 | if _, ok := value.Interface().([]byte); ok { | |
55 | return "bytea" | |
56 | } | |
57 | } | |
58 | panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String())) | |
59 | } | |
60 | ||
61 | func (s postgres) ReturningStr(tableName, key string) string { | |
62 | return fmt.Sprintf("RETURNING %v.%v", tableName, key) | |
63 | } | |
64 | ||
65 | func (s postgres) HasTable(scope *Scope, tableName string) bool { | |
66 | var count int | |
67 | s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName) | |
68 | return count > 0 | |
69 | } | |
70 | ||
71 | func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) bool { | |
72 | var count int | |
73 | s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName) | |
74 | return count > 0 | |
75 | } | |
76 | ||
77 | func (postgres) RemoveIndex(scope *Scope, indexName string) { | |
78 | scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) | |
79 | } | |
80 | ||
81 | func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { | |
82 | var count int | |
83 | s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName) | |
84 | return count > 0 | |
85 | } | |
86 | ||
87 | func (s postgres) CurrentDatabase(scope *Scope) (name string) { | |
88 | s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()") | |
89 | return | |
90 | } | |
91 | ||
92 | var hstoreType = reflect.TypeOf(Hstore{}) | |
93 | ||
94 | type Hstore map[string]*string | |
95 | ||
96 | func (h Hstore) Value() (driver.Value, error) { | |
97 | hstore := hstore.Hstore{Map: map[string]sql.NullString{}} | |
98 | if len(h) == 0 { | |
99 | return nil, nil | |
100 | } | |
101 | ||
102 | for key, value := range h { | |
103 | var s sql.NullString | |
104 | if value != nil { | |
105 | s.String = *value | |
106 | s.Valid = true | |
107 | } | |
108 | hstore.Map[key] = s | |
109 | } | |
110 | return hstore.Value() | |
111 | } | |
112 | ||
113 | func (h *Hstore) Scan(value interface{}) error { | |
114 | hstore := hstore.Hstore{} | |
115 | ||
116 | if err := hstore.Scan(value); err != nil { | |
117 | return err | |
118 | } | |
119 | ||
120 | if len(hstore.Map) == 0 { | |
121 | return nil | |
122 | } | |
123 | ||
124 | *h = Hstore{} | |
125 | for k := range hstore.Map { | |
126 | if hstore.Map[k].Valid { | |
127 | s := hstore.Map[k].String | |
128 | (*h)[k] = &s | |
129 | } else { | |
130 | (*h)[k] = nil | |
131 | } | |
132 | } | |
133 | ||
134 | return nil | |
135 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "database/sql/driver" | |
4 | "errors" | |
5 | "fmt" | |
6 | "reflect" | |
7 | "strings" | |
8 | ) | |
9 | ||
10 | func getRealValue(value reflect.Value, columns []string) (results []interface{}) { | |
11 | for _, column := range columns { | |
12 | if reflect.Indirect(value).FieldByName(column).IsValid() { | |
13 | result := reflect.Indirect(value).FieldByName(column).Interface() | |
14 | if r, ok := result.(driver.Valuer); ok { | |
15 | result, _ = r.Value() | |
16 | } | |
17 | results = append(results, result) | |
18 | } | |
19 | } | |
20 | return | |
21 | } | |
22 | ||
23 | func equalAsString(a interface{}, b interface{}) bool { | |
24 | return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b) | |
25 | } | |
26 | ||
27 | func Preload(scope *Scope) { | |
28 | if scope.Search.preload == nil { | |
29 | return | |
30 | } | |
31 | ||
32 | preloadMap := map[string]bool{} | |
33 | fields := scope.Fields() | |
34 | for _, preload := range scope.Search.preload { | |
35 | schema, conditions := preload.schema, preload.conditions | |
36 | keys := strings.Split(schema, ".") | |
37 | currentScope := scope | |
38 | currentFields := fields | |
39 | originalConditions := conditions | |
40 | conditions = []interface{}{} | |
41 | for i, key := range keys { | |
42 | var found bool | |
43 | if preloadMap[strings.Join(keys[:i+1], ".")] { | |
44 | goto nextLoop | |
45 | } | |
46 | ||
47 | if i == len(keys)-1 { | |
48 | conditions = originalConditions | |
49 | } | |
50 | ||
51 | for _, field := range currentFields { | |
52 | if field.Name != key || field.Relationship == nil { | |
53 | continue | |
54 | } | |
55 | ||
56 | found = true | |
57 | switch field.Relationship.Kind { | |
58 | case "has_one": | |
59 | currentScope.handleHasOnePreload(field, conditions) | |
60 | case "has_many": | |
61 | currentScope.handleHasManyPreload(field, conditions) | |
62 | case "belongs_to": | |
63 | currentScope.handleBelongsToPreload(field, conditions) | |
64 | case "many_to_many": | |
65 | currentScope.handleHasManyToManyPreload(field, conditions) | |
66 | default: | |
67 | currentScope.Err(errors.New("not supported relation")) | |
68 | } | |
69 | break | |
70 | } | |
71 | ||
72 | if !found { | |
73 | value := reflect.ValueOf(currentScope.Value) | |
74 | if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface { | |
75 | value = value.Index(0).Elem() | |
76 | } | |
77 | scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type())) | |
78 | return | |
79 | } | |
80 | ||
81 | preloadMap[strings.Join(keys[:i+1], ".")] = true | |
82 | ||
83 | nextLoop: | |
84 | if i < len(keys)-1 { | |
85 | currentScope = currentScope.getColumnsAsScope(key) | |
86 | currentFields = currentScope.Fields() | |
87 | } | |
88 | } | |
89 | } | |
90 | ||
91 | } | |
92 | ||
93 | func makeSlice(typ reflect.Type) interface{} { | |
94 | if typ.Kind() == reflect.Slice { | |
95 | typ = typ.Elem() | |
96 | } | |
97 | sliceType := reflect.SliceOf(typ) | |
98 | slice := reflect.New(sliceType) | |
99 | slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) | |
100 | return slice.Interface() | |
101 | } | |
102 | ||
103 | func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { | |
104 | relation := field.Relationship | |
105 | ||
106 | primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) | |
107 | if len(primaryKeys) == 0 { | |
108 | return | |
109 | } | |
110 | ||
111 | results := makeSlice(field.Struct.Type) | |
112 | scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) | |
113 | resultValues := reflect.Indirect(reflect.ValueOf(results)) | |
114 | ||
115 | for i := 0; i < resultValues.Len(); i++ { | |
116 | result := resultValues.Index(i) | |
117 | if scope.IndirectValue().Kind() == reflect.Slice { | |
118 | value := getRealValue(result, relation.ForeignFieldNames) | |
119 | objects := scope.IndirectValue() | |
120 | for j := 0; j < objects.Len(); j++ { | |
121 | if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) { | |
122 | reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) | |
123 | break | |
124 | } | |
125 | } | |
126 | } else { | |
127 | if err := scope.SetColumn(field, result); err != nil { | |
128 | scope.Err(err) | |
129 | return | |
130 | } | |
131 | } | |
132 | } | |
133 | } | |
134 | ||
135 | func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { | |
136 | relation := field.Relationship | |
137 | primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) | |
138 | if len(primaryKeys) == 0 { | |
139 | return | |
140 | } | |
141 | ||
142 | results := makeSlice(field.Struct.Type) | |
143 | scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) | |
144 | resultValues := reflect.Indirect(reflect.ValueOf(results)) | |
145 | ||
146 | if scope.IndirectValue().Kind() == reflect.Slice { | |
147 | for i := 0; i < resultValues.Len(); i++ { | |
148 | result := resultValues.Index(i) | |
149 | value := getRealValue(result, relation.ForeignFieldNames) | |
150 | objects := scope.IndirectValue() | |
151 | for j := 0; j < objects.Len(); j++ { | |
152 | object := reflect.Indirect(objects.Index(j)) | |
153 | if equalAsString(getRealValue(object, relation.AssociationForeignFieldNames), value) { | |
154 | f := object.FieldByName(field.Name) | |
155 | f.Set(reflect.Append(f, result)) | |
156 | break | |
157 | } | |
158 | } | |
159 | } | |
160 | } else { | |
161 | scope.SetColumn(field, resultValues) | |
162 | } | |
163 | } | |
164 | ||
165 | func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { | |
166 | relation := field.Relationship | |
167 | primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames) | |
168 | if len(primaryKeys) == 0 { | |
169 | return | |
170 | } | |
171 | ||
172 | results := makeSlice(field.Struct.Type) | |
173 | scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) | |
174 | resultValues := reflect.Indirect(reflect.ValueOf(results)) | |
175 | ||
176 | for i := 0; i < resultValues.Len(); i++ { | |
177 | result := resultValues.Index(i) | |
178 | if scope.IndirectValue().Kind() == reflect.Slice { | |
179 | value := getRealValue(result, relation.AssociationForeignFieldNames) | |
180 | objects := scope.IndirectValue() | |
181 | for j := 0; j < objects.Len(); j++ { | |
182 | object := reflect.Indirect(objects.Index(j)) | |
183 | if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) { | |
184 | object.FieldByName(field.Name).Set(result) | |
185 | } | |
186 | } | |
187 | } else { | |
188 | scope.SetColumn(field, result) | |
189 | } | |
190 | } | |
191 | } | |
192 | ||
193 | func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interface{}) { | |
194 | relation := field.Relationship | |
195 | ||
196 | joinTableHandler := relation.JoinTableHandler | |
197 | destType := field.StructField.Struct.Type.Elem() | |
198 | var isPtr bool | |
199 | if destType.Kind() == reflect.Ptr { | |
200 | isPtr = true | |
201 | destType = destType.Elem() | |
202 | } | |
203 | ||
204 | var sourceKeys []string | |
205 | var linkHash = make(map[string][]reflect.Value) | |
206 | ||
207 | for _, key := range joinTableHandler.SourceForeignKeys() { | |
208 | sourceKeys = append(sourceKeys, key.DBName) | |
209 | } | |
210 | ||
211 | db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*") | |
212 | preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value) | |
213 | if len(conditions) > 0 { | |
214 | preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...) | |
215 | } | |
216 | rows, err := preloadJoinDB.Rows() | |
217 | ||
218 | if scope.Err(err) != nil { | |
219 | return | |
220 | } | |
221 | defer rows.Close() | |
222 | ||
223 | columns, _ := rows.Columns() | |
224 | for rows.Next() { | |
225 | elem := reflect.New(destType).Elem() | |
226 | var values = make([]interface{}, len(columns)) | |
227 | ||
228 | fields := scope.New(elem.Addr().Interface()).Fields() | |
229 | ||
230 | for index, column := range columns { | |
231 | if field, ok := fields[column]; ok { | |
232 | if field.Field.Kind() == reflect.Ptr { | |
233 | values[index] = field.Field.Addr().Interface() | |
234 | } else { | |
235 | values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface() | |
236 | } | |
237 | } else { | |
238 | var i interface{} | |
239 | values[index] = &i | |
240 | } | |
241 | } | |
242 | ||
243 | scope.Err(rows.Scan(values...)) | |
244 | ||
245 | var sourceKey []interface{} | |
246 | ||
247 | for index, column := range columns { | |
248 | value := values[index] | |
249 | if field, ok := fields[column]; ok { | |
250 | if field.Field.Kind() == reflect.Ptr { | |
251 | field.Field.Set(reflect.ValueOf(value).Elem()) | |
252 | } else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() { | |
253 | field.Field.Set(v) | |
254 | } | |
255 | } else if strInSlice(column, sourceKeys) { | |
256 | sourceKey = append(sourceKey, *(value.(*interface{}))) | |
257 | } | |
258 | } | |
259 | ||
260 | if len(sourceKey) != 0 { | |
261 | if isPtr { | |
262 | linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem.Addr()) | |
263 | } else { | |
264 | linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem) | |
265 | } | |
266 | } | |
267 | } | |
268 | ||
269 | var associationForeignStructFieldNames []string | |
270 | for _, dbName := range relation.AssociationForeignFieldNames { | |
271 | if field, ok := scope.FieldByName(dbName); ok { | |
272 | associationForeignStructFieldNames = append(associationForeignStructFieldNames, field.Name) | |
273 | } | |
274 | } | |
275 | ||
276 | if scope.IndirectValue().Kind() == reflect.Slice { | |
277 | objects := scope.IndirectValue() | |
278 | for j := 0; j < objects.Len(); j++ { | |
279 | object := reflect.Indirect(objects.Index(j)) | |
280 | source := getRealValue(object, associationForeignStructFieldNames) | |
281 | field := object.FieldByName(field.Name) | |
282 | for _, link := range linkHash[toString(source)] { | |
283 | field.Set(reflect.Append(field, link)) | |
284 | } | |
285 | } | |
286 | } else { | |
287 | object := scope.IndirectValue() | |
288 | source := getRealValue(object, associationForeignStructFieldNames) | |
289 | field := object.FieldByName(field.Name) | |
290 | for _, link := range linkHash[toString(source)] { | |
291 | field.Set(reflect.Append(field, link)) | |
292 | } | |
293 | } | |
294 | } | |
295 | ||
296 | func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) { | |
297 | values := scope.IndirectValue() | |
298 | switch values.Kind() { | |
299 | case reflect.Slice: | |
300 | for i := 0; i < values.Len(); i++ { | |
301 | var result []interface{} | |
302 | for _, column := range columns { | |
303 | result = append(result, reflect.Indirect(values.Index(i)).FieldByName(column).Interface()) | |
304 | } | |
305 | results = append(results, result) | |
306 | } | |
307 | case reflect.Struct: | |
308 | var result []interface{} | |
309 | for _, column := range columns { | |
310 | result = append(result, values.FieldByName(column).Interface()) | |
311 | } | |
312 | return [][]interface{}{result} | |
313 | } | |
314 | return | |
315 | } | |
316 | ||
317 | func (scope *Scope) getColumnsAsScope(column string) *Scope { | |
318 | values := scope.IndirectValue() | |
319 | switch values.Kind() { | |
320 | case reflect.Slice: | |
321 | modelType := values.Type().Elem() | |
322 | if modelType.Kind() == reflect.Ptr { | |
323 | modelType = modelType.Elem() | |
324 | } | |
325 | fieldStruct, _ := modelType.FieldByName(column) | |
326 | var columns reflect.Value | |
327 | if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr { | |
328 | columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem() | |
329 | } else { | |
330 | columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem() | |
331 | } | |
332 | for i := 0; i < values.Len(); i++ { | |
333 | column := reflect.Indirect(values.Index(i)).FieldByName(column) | |
334 | if column.Kind() == reflect.Ptr { | |
335 | column = column.Elem() | |
336 | } | |
337 | if column.Kind() == reflect.Slice { | |
338 | for i := 0; i < column.Len(); i++ { | |
339 | elem := column.Index(i) | |
340 | if elem.CanAddr() { | |
341 | columns = reflect.Append(columns, elem.Addr()) | |
342 | } | |
343 | } | |
344 | } else { | |
345 | if column.CanAddr() { | |
346 | columns = reflect.Append(columns, column.Addr()) | |
347 | } | |
348 | } | |
349 | } | |
350 | return scope.New(columns.Interface()) | |
351 | case reflect.Struct: | |
352 | field := values.FieldByName(column) | |
353 | if !field.CanAddr() { | |
354 | return nil | |
355 | } | |
356 | return scope.New(field.Addr().Interface()) | |
357 | } | |
358 | return nil | |
359 | } |
0 | 0 | package gorm_test |
1 | 1 | |
2 | 2 | import ( |
3 | "database/sql" | |
3 | 4 | "encoding/json" |
4 | 5 | "os" |
5 | 6 | "reflect" |
6 | 7 | "testing" |
8 | ||
9 | "github.com/jinzhu/gorm" | |
7 | 10 | ) |
8 | 11 | |
9 | 12 | func getPreloadUser(name string) *User { |
86 | 89 | } |
87 | 90 | } else if len(user.Emails) != 0 { |
88 | 91 | t.Errorf("should not preload any emails for other users when with condition") |
89 | } | |
92 | } else if user.Emails == nil { | |
93 | t.Errorf("should return an empty slice to indicate zero results") | |
94 | } | |
95 | } | |
96 | } | |
97 | ||
98 | func TestAutoPreload(t *testing.T) { | |
99 | user1 := getPreloadUser("auto_user1") | |
100 | DB.Save(user1) | |
101 | ||
102 | preloadDB := DB.Set("gorm:auto_preload", true).Where("role = ?", "Preload") | |
103 | var user User | |
104 | preloadDB.Find(&user) | |
105 | checkUserHasPreloadData(user, t) | |
106 | ||
107 | user2 := getPreloadUser("auto_user2") | |
108 | DB.Save(user2) | |
109 | ||
110 | var users []User | |
111 | preloadDB.Find(&users) | |
112 | ||
113 | for _, user := range users { | |
114 | checkUserHasPreloadData(user, t) | |
115 | } | |
116 | ||
117 | var users2 []*User | |
118 | preloadDB.Find(&users2) | |
119 | ||
120 | for _, user := range users2 { | |
121 | checkUserHasPreloadData(*user, t) | |
90 | 122 | } |
91 | 123 | } |
92 | 124 | |
112 | 144 | DB.DropTableIfExists(&Level2{}) |
113 | 145 | DB.DropTableIfExists(&Level1{}) |
114 | 146 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { |
115 | panic(err) | |
147 | t.Error(err) | |
116 | 148 | } |
117 | 149 | |
118 | 150 | want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} |
119 | 151 | if err := DB.Create(&want).Error; err != nil { |
120 | panic(err) | |
152 | t.Error(err) | |
121 | 153 | } |
122 | 154 | |
123 | 155 | var got Level3 |
124 | 156 | if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { |
125 | panic(err) | |
126 | } | |
127 | ||
128 | if !reflect.DeepEqual(got, want) { | |
129 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
157 | t.Error(err) | |
158 | } | |
159 | ||
160 | if !reflect.DeepEqual(got, want) { | |
161 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
162 | } | |
163 | ||
164 | if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound { | |
165 | t.Error(err) | |
130 | 166 | } |
131 | 167 | } |
132 | 168 | |
152 | 188 | DB.DropTableIfExists(&Level2{}) |
153 | 189 | DB.DropTableIfExists(&Level1{}) |
154 | 190 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { |
155 | panic(err) | |
191 | t.Error(err) | |
156 | 192 | } |
157 | 193 | |
158 | 194 | want := Level3{ |
159 | 195 | Level2s: []Level2{ |
160 | 196 | { |
161 | 197 | Level1s: []*Level1{ |
162 | &Level1{Value: "value1"}, | |
163 | &Level1{Value: "value2"}, | |
198 | {Value: "value1"}, | |
199 | {Value: "value2"}, | |
164 | 200 | }, |
165 | 201 | }, |
166 | 202 | { |
167 | 203 | Level1s: []*Level1{ |
168 | &Level1{Value: "value3"}, | |
204 | {Value: "value3"}, | |
169 | 205 | }, |
170 | 206 | }, |
171 | 207 | }, |
172 | 208 | } |
173 | 209 | if err := DB.Create(&want).Error; err != nil { |
174 | panic(err) | |
210 | t.Error(err) | |
175 | 211 | } |
176 | 212 | |
177 | 213 | var got Level3 |
178 | 214 | if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { |
179 | panic(err) | |
215 | t.Error(err) | |
180 | 216 | } |
181 | 217 | |
182 | 218 | if !reflect.DeepEqual(got, want) { |
206 | 242 | DB.DropTableIfExists(&Level2{}) |
207 | 243 | DB.DropTableIfExists(&Level1{}) |
208 | 244 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { |
209 | panic(err) | |
245 | t.Error(err) | |
210 | 246 | } |
211 | 247 | |
212 | 248 | want := Level3{ |
216 | 252 | }, |
217 | 253 | } |
218 | 254 | if err := DB.Create(&want).Error; err != nil { |
219 | panic(err) | |
255 | t.Error(err) | |
220 | 256 | } |
221 | 257 | |
222 | 258 | var got Level3 |
223 | 259 | if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { |
224 | panic(err) | |
260 | t.Error(err) | |
225 | 261 | } |
226 | 262 | |
227 | 263 | if !reflect.DeepEqual(got, want) { |
251 | 287 | DB.DropTableIfExists(&Level2{}) |
252 | 288 | DB.DropTableIfExists(&Level1{}) |
253 | 289 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { |
254 | panic(err) | |
290 | t.Error(err) | |
255 | 291 | } |
256 | 292 | |
257 | 293 | want := Level3{ |
258 | 294 | Level2: Level2{ |
259 | 295 | Level1s: []Level1{ |
260 | Level1{Value: "value1"}, | |
261 | Level1{Value: "value2"}, | |
296 | {Value: "value1"}, | |
297 | {Value: "value2"}, | |
262 | 298 | }, |
263 | 299 | }, |
264 | 300 | } |
265 | 301 | if err := DB.Create(&want).Error; err != nil { |
266 | panic(err) | |
302 | t.Error(err) | |
267 | 303 | } |
268 | 304 | |
269 | 305 | var got Level3 |
270 | 306 | if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { |
271 | panic(err) | |
307 | t.Error(err) | |
272 | 308 | } |
273 | 309 | |
274 | 310 | if !reflect.DeepEqual(got, want) { |
299 | 335 | DB.DropTableIfExists(&Level2{}) |
300 | 336 | DB.DropTableIfExists(&Level1{}) |
301 | 337 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { |
302 | panic(err) | |
338 | t.Error(err) | |
303 | 339 | } |
304 | 340 | |
305 | 341 | want := make([]Level3, 2) |
306 | 342 | want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} |
307 | 343 | if err := DB.Create(&want[0]).Error; err != nil { |
308 | panic(err) | |
344 | t.Error(err) | |
309 | 345 | } |
310 | 346 | want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} |
311 | 347 | if err := DB.Create(&want[1]).Error; err != nil { |
312 | panic(err) | |
348 | t.Error(err) | |
313 | 349 | } |
314 | 350 | |
315 | 351 | var got []Level3 |
316 | 352 | if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { |
317 | panic(err) | |
353 | t.Error(err) | |
318 | 354 | } |
319 | 355 | |
320 | 356 | if !reflect.DeepEqual(got, want) { |
344 | 380 | DB.DropTableIfExists(&Level2{}) |
345 | 381 | DB.DropTableIfExists(&Level1{}) |
346 | 382 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { |
347 | panic(err) | |
383 | t.Error(err) | |
348 | 384 | } |
349 | 385 | |
350 | 386 | want := make([]Level3, 2) |
364 | 400 | }, |
365 | 401 | } |
366 | 402 | if err := DB.Create(&want[0]).Error; err != nil { |
367 | panic(err) | |
403 | t.Error(err) | |
368 | 404 | } |
369 | 405 | |
370 | 406 | want[1] = Level3{ |
383 | 419 | }, |
384 | 420 | } |
385 | 421 | if err := DB.Create(&want[1]).Error; err != nil { |
386 | panic(err) | |
422 | t.Error(err) | |
387 | 423 | } |
388 | 424 | |
389 | 425 | var got []Level3 |
390 | 426 | if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { |
391 | panic(err) | |
427 | t.Error(err) | |
392 | 428 | } |
393 | 429 | |
394 | 430 | if !reflect.DeepEqual(got, want) { |
418 | 454 | DB.DropTableIfExists(&Level2{}) |
419 | 455 | DB.DropTableIfExists(&Level1{}) |
420 | 456 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { |
421 | panic(err) | |
457 | t.Error(err) | |
422 | 458 | } |
423 | 459 | |
424 | 460 | want := make([]Level3, 2) |
429 | 465 | }, |
430 | 466 | } |
431 | 467 | if err := DB.Create(&want[0]).Error; err != nil { |
432 | panic(err) | |
468 | t.Error(err) | |
433 | 469 | } |
434 | 470 | |
435 | 471 | want[1] = Level3{ |
439 | 475 | }, |
440 | 476 | } |
441 | 477 | if err := DB.Create(&want[1]).Error; err != nil { |
442 | panic(err) | |
478 | t.Error(err) | |
443 | 479 | } |
444 | 480 | |
445 | 481 | var got []Level3 |
446 | 482 | if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { |
447 | panic(err) | |
483 | t.Error(err) | |
448 | 484 | } |
449 | 485 | |
450 | 486 | if !reflect.DeepEqual(got, want) { |
474 | 510 | DB.DropTableIfExists(&Level2{}) |
475 | 511 | DB.DropTableIfExists(&Level1{}) |
476 | 512 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { |
477 | panic(err) | |
513 | t.Error(err) | |
478 | 514 | } |
479 | 515 | |
480 | 516 | want := make([]Level3, 2) |
481 | 517 | want[0] = Level3{ |
482 | 518 | Level2: Level2{ |
483 | 519 | Level1s: []Level1{ |
484 | Level1{Value: "value1"}, | |
485 | Level1{Value: "value2"}, | |
520 | {Value: "value1"}, | |
521 | {Value: "value2"}, | |
486 | 522 | }, |
487 | 523 | }, |
488 | 524 | } |
489 | 525 | if err := DB.Create(&want[0]).Error; err != nil { |
490 | panic(err) | |
526 | t.Error(err) | |
491 | 527 | } |
492 | 528 | want[1] = Level3{ |
493 | 529 | Level2: Level2{ |
494 | 530 | Level1s: []Level1{ |
495 | Level1{Value: "value3"}, | |
496 | Level1{Value: "value4"}, | |
531 | {Value: "value3"}, | |
532 | {Value: "value4"}, | |
497 | 533 | }, |
498 | 534 | }, |
499 | 535 | } |
500 | 536 | if err := DB.Create(&want[1]).Error; err != nil { |
501 | panic(err) | |
537 | t.Error(err) | |
502 | 538 | } |
503 | 539 | |
504 | 540 | var got []Level3 |
505 | 541 | if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { |
506 | panic(err) | |
542 | t.Error(err) | |
507 | 543 | } |
508 | 544 | |
509 | 545 | if !reflect.DeepEqual(got, want) { |
548 | 584 | DB.DropTableIfExists(&Level1{}) |
549 | 585 | DB.DropTableIfExists(&Level0{}) |
550 | 586 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil { |
551 | panic(err) | |
587 | t.Error(err) | |
552 | 588 | } |
553 | 589 | |
554 | 590 | want := make([]Level3, 2) |
555 | 591 | want[0] = Level3{ |
556 | 592 | Level2: Level2{ |
557 | 593 | Level1s: []Level1{ |
558 | Level1{Value: "value1"}, | |
559 | Level1{Value: "value2"}, | |
594 | {Value: "value1"}, | |
595 | {Value: "value2"}, | |
560 | 596 | }, |
561 | 597 | }, |
562 | 598 | Level2_1: Level2_1{ |
563 | 599 | Level1s: []Level1{ |
564 | Level1{ | |
600 | { | |
565 | 601 | Value: "value1-1", |
566 | 602 | Level0s: []Level0{{Value: "Level0-1"}}, |
567 | 603 | }, |
568 | Level1{ | |
604 | { | |
569 | 605 | Value: "value2-2", |
570 | 606 | Level0s: []Level0{{Value: "Level0-2"}}, |
571 | 607 | }, |
573 | 609 | }, |
574 | 610 | } |
575 | 611 | if err := DB.Create(&want[0]).Error; err != nil { |
576 | panic(err) | |
612 | t.Error(err) | |
577 | 613 | } |
578 | 614 | want[1] = Level3{ |
579 | 615 | Level2: Level2{ |
580 | 616 | Level1s: []Level1{ |
581 | Level1{Value: "value3"}, | |
582 | Level1{Value: "value4"}, | |
617 | {Value: "value3"}, | |
618 | {Value: "value4"}, | |
583 | 619 | }, |
584 | 620 | }, |
585 | 621 | Level2_1: Level2_1{ |
586 | 622 | Level1s: []Level1{ |
587 | Level1{Value: "value3-3"}, | |
588 | Level1{Value: "value4-4"}, | |
623 | { | |
624 | Value: "value3-3", | |
625 | Level0s: []Level0{}, | |
626 | }, | |
627 | { | |
628 | Value: "value4-4", | |
629 | Level0s: []Level0{}, | |
630 | }, | |
589 | 631 | }, |
590 | 632 | }, |
591 | 633 | } |
592 | 634 | if err := DB.Create(&want[1]).Error; err != nil { |
593 | panic(err) | |
635 | t.Error(err) | |
594 | 636 | } |
595 | 637 | |
596 | 638 | var got []Level3 |
597 | 639 | if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil { |
598 | panic(err) | |
640 | t.Error(err) | |
641 | } | |
642 | ||
643 | if !reflect.DeepEqual(got, want) { | |
644 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
645 | } | |
646 | } | |
647 | ||
648 | type LevelA1 struct { | |
649 | ID uint | |
650 | Value string | |
651 | } | |
652 | ||
653 | type LevelA2 struct { | |
654 | ID uint | |
655 | Value string | |
656 | LevelA3s []*LevelA3 | |
657 | } | |
658 | ||
659 | type LevelA3 struct { | |
660 | ID uint | |
661 | Value string | |
662 | LevelA1ID sql.NullInt64 | |
663 | LevelA1 *LevelA1 | |
664 | LevelA2ID sql.NullInt64 | |
665 | LevelA2 *LevelA2 | |
666 | } | |
667 | ||
668 | func TestNestedPreload10(t *testing.T) { | |
669 | DB.DropTableIfExists(&LevelA3{}) | |
670 | DB.DropTableIfExists(&LevelA2{}) | |
671 | DB.DropTableIfExists(&LevelA1{}) | |
672 | ||
673 | if err := DB.AutoMigrate(&LevelA1{}, &LevelA2{}, &LevelA3{}).Error; err != nil { | |
674 | t.Error(err) | |
675 | } | |
676 | ||
677 | levelA1 := &LevelA1{Value: "foo"} | |
678 | if err := DB.Save(levelA1).Error; err != nil { | |
679 | t.Error(err) | |
680 | } | |
681 | ||
682 | want := []*LevelA2{ | |
683 | { | |
684 | Value: "bar", | |
685 | LevelA3s: []*LevelA3{ | |
686 | { | |
687 | Value: "qux", | |
688 | LevelA1: levelA1, | |
689 | }, | |
690 | }, | |
691 | }, | |
692 | { | |
693 | Value: "bar 2", | |
694 | LevelA3s: []*LevelA3{}, | |
695 | }, | |
696 | } | |
697 | for _, levelA2 := range want { | |
698 | if err := DB.Save(levelA2).Error; err != nil { | |
699 | t.Error(err) | |
700 | } | |
701 | } | |
702 | ||
703 | var got []*LevelA2 | |
704 | if err := DB.Preload("LevelA3s.LevelA1").Find(&got).Error; err != nil { | |
705 | t.Error(err) | |
706 | } | |
707 | ||
708 | if !reflect.DeepEqual(got, want) { | |
709 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
710 | } | |
711 | } | |
712 | ||
713 | type LevelB1 struct { | |
714 | ID uint | |
715 | Value string | |
716 | LevelB3s []*LevelB3 | |
717 | } | |
718 | ||
719 | type LevelB2 struct { | |
720 | ID uint | |
721 | Value string | |
722 | } | |
723 | ||
724 | type LevelB3 struct { | |
725 | ID uint | |
726 | Value string | |
727 | LevelB1ID sql.NullInt64 | |
728 | LevelB1 *LevelB1 | |
729 | LevelB2s []*LevelB2 `gorm:"many2many:levelb1_levelb3_levelb2s"` | |
730 | } | |
731 | ||
732 | func TestNestedPreload11(t *testing.T) { | |
733 | DB.DropTableIfExists(&LevelB2{}) | |
734 | DB.DropTableIfExists(&LevelB3{}) | |
735 | DB.DropTableIfExists(&LevelB1{}) | |
736 | if err := DB.AutoMigrate(&LevelB1{}, &LevelB2{}, &LevelB3{}).Error; err != nil { | |
737 | t.Error(err) | |
738 | } | |
739 | ||
740 | levelB1 := &LevelB1{Value: "foo"} | |
741 | if err := DB.Create(levelB1).Error; err != nil { | |
742 | t.Error(err) | |
743 | } | |
744 | ||
745 | levelB3 := &LevelB3{ | |
746 | Value: "bar", | |
747 | LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)}, | |
748 | } | |
749 | if err := DB.Create(levelB3).Error; err != nil { | |
750 | t.Error(err) | |
751 | } | |
752 | levelB1.LevelB3s = []*LevelB3{levelB3} | |
753 | ||
754 | want := []*LevelB1{levelB1} | |
755 | var got []*LevelB1 | |
756 | if err := DB.Preload("LevelB3s.LevelB2s").Find(&got).Error; err != nil { | |
757 | t.Error(err) | |
758 | } | |
759 | ||
760 | if !reflect.DeepEqual(got, want) { | |
761 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
762 | } | |
763 | } | |
764 | ||
765 | type LevelC1 struct { | |
766 | ID uint | |
767 | Value string | |
768 | LevelC2ID uint | |
769 | } | |
770 | ||
771 | type LevelC2 struct { | |
772 | ID uint | |
773 | Value string | |
774 | LevelC1 LevelC1 | |
775 | } | |
776 | ||
777 | type LevelC3 struct { | |
778 | ID uint | |
779 | Value string | |
780 | LevelC2ID uint | |
781 | LevelC2 LevelC2 | |
782 | } | |
783 | ||
784 | func TestNestedPreload12(t *testing.T) { | |
785 | DB.DropTableIfExists(&LevelC2{}) | |
786 | DB.DropTableIfExists(&LevelC3{}) | |
787 | DB.DropTableIfExists(&LevelC1{}) | |
788 | if err := DB.AutoMigrate(&LevelC1{}, &LevelC2{}, &LevelC3{}).Error; err != nil { | |
789 | t.Error(err) | |
790 | } | |
791 | ||
792 | level2 := LevelC2{ | |
793 | Value: "c2", | |
794 | LevelC1: LevelC1{ | |
795 | Value: "c1", | |
796 | }, | |
797 | } | |
798 | DB.Create(&level2) | |
799 | ||
800 | want := []LevelC3{ | |
801 | { | |
802 | Value: "c3-1", | |
803 | LevelC2: level2, | |
804 | }, { | |
805 | Value: "c3-2", | |
806 | LevelC2: level2, | |
807 | }, | |
808 | } | |
809 | ||
810 | for i := range want { | |
811 | if err := DB.Create(&want[i]).Error; err != nil { | |
812 | t.Error(err) | |
813 | } | |
814 | } | |
815 | ||
816 | var got []LevelC3 | |
817 | if err := DB.Preload("LevelC2").Preload("LevelC2.LevelC1").Find(&got).Error; err != nil { | |
818 | t.Error(err) | |
599 | 819 | } |
600 | 820 | |
601 | 821 | if !reflect.DeepEqual(got, want) { |
604 | 824 | } |
605 | 825 | |
606 | 826 | func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { |
607 | if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { | |
827 | if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" || dialect == "mssql" { | |
608 | 828 | return |
609 | 829 | } |
610 | 830 | |
624 | 844 | |
625 | 845 | DB.DropTableIfExists(&Level2{}) |
626 | 846 | DB.DropTableIfExists(&Level1{}) |
627 | DB.Table("levels").DropTableIfExists("levels") | |
847 | DB.DropTableIfExists("levels") | |
628 | 848 | |
629 | 849 | if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { |
630 | panic(err) | |
850 | t.Error(err) | |
631 | 851 | } |
632 | 852 | |
633 | 853 | want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{ |
635 | 855 | {Value: "en", LanguageCode: "en"}, |
636 | 856 | }} |
637 | 857 | if err := DB.Save(&want).Error; err != nil { |
638 | panic(err) | |
858 | t.Error(err) | |
639 | 859 | } |
640 | 860 | |
641 | 861 | want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{ |
643 | 863 | {Value: "de", LanguageCode: "de"}, |
644 | 864 | }} |
645 | 865 | if err := DB.Save(&want2).Error; err != nil { |
646 | panic(err) | |
866 | t.Error(err) | |
647 | 867 | } |
648 | 868 | |
649 | 869 | var got Level2 |
650 | 870 | if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { |
651 | panic(err) | |
871 | t.Error(err) | |
652 | 872 | } |
653 | 873 | |
654 | 874 | if !reflect.DeepEqual(got, want) { |
657 | 877 | |
658 | 878 | var got2 Level2 |
659 | 879 | if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { |
660 | panic(err) | |
880 | t.Error(err) | |
661 | 881 | } |
662 | 882 | |
663 | 883 | if !reflect.DeepEqual(got2, want2) { |
666 | 886 | |
667 | 887 | var got3 []Level2 |
668 | 888 | if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { |
669 | panic(err) | |
889 | t.Error(err) | |
670 | 890 | } |
671 | 891 | |
672 | 892 | if !reflect.DeepEqual(got3, []Level2{got, got2}) { |
675 | 895 | |
676 | 896 | var got4 []Level2 |
677 | 897 | if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { |
678 | panic(err) | |
898 | t.Error(err) | |
679 | 899 | } |
680 | 900 | |
681 | 901 | var ruLevel1 Level1 |
688 | 908 | if !reflect.DeepEqual(got4, []Level2{got, got2}) { |
689 | 909 | t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) |
690 | 910 | } |
691 | } | |
692 | ||
693 | func TestManyToManyPreloadForPointer(t *testing.T) { | |
694 | type ( | |
695 | Level1 struct { | |
696 | ID uint `gorm:"primary_key;"` | |
911 | ||
912 | if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).Error; err != nil { | |
913 | t.Error(err) | |
914 | } | |
915 | } | |
916 | ||
917 | func TestManyToManyPreloadForNestedPointer(t *testing.T) { | |
918 | type ( | |
919 | Level1 struct { | |
920 | ID uint | |
697 | 921 | Value string |
698 | 922 | } |
699 | 923 | Level2 struct { |
700 | ID uint `gorm:"primary_key;"` | |
924 | ID uint | |
701 | 925 | Value string |
702 | 926 | Level1s []*Level1 `gorm:"many2many:levels;"` |
703 | 927 | } |
704 | ) | |
705 | ||
706 | DB.DropTableIfExists(&Level2{}) | |
707 | DB.DropTableIfExists(&Level1{}) | |
708 | DB.Table("levels").DropTableIfExists("levels") | |
928 | Level3 struct { | |
929 | ID uint | |
930 | Value string | |
931 | Level2ID sql.NullInt64 | |
932 | Level2 *Level2 | |
933 | } | |
934 | ) | |
935 | ||
936 | DB.DropTableIfExists(&Level3{}) | |
937 | DB.DropTableIfExists(&Level2{}) | |
938 | DB.DropTableIfExists(&Level1{}) | |
939 | DB.DropTableIfExists("levels") | |
940 | ||
941 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { | |
942 | t.Error(err) | |
943 | } | |
944 | ||
945 | want := Level3{ | |
946 | Value: "Bob", | |
947 | Level2: &Level2{ | |
948 | Value: "Foo", | |
949 | Level1s: []*Level1{ | |
950 | {Value: "ru"}, | |
951 | {Value: "en"}, | |
952 | }, | |
953 | }, | |
954 | } | |
955 | if err := DB.Save(&want).Error; err != nil { | |
956 | t.Error(err) | |
957 | } | |
958 | ||
959 | want2 := Level3{ | |
960 | Value: "Tom", | |
961 | Level2: &Level2{ | |
962 | Value: "Bar", | |
963 | Level1s: []*Level1{ | |
964 | {Value: "zh"}, | |
965 | {Value: "de"}, | |
966 | }, | |
967 | }, | |
968 | } | |
969 | if err := DB.Save(&want2).Error; err != nil { | |
970 | t.Error(err) | |
971 | } | |
972 | ||
973 | var got Level3 | |
974 | if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { | |
975 | t.Error(err) | |
976 | } | |
977 | ||
978 | if !reflect.DeepEqual(got, want) { | |
979 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
980 | } | |
981 | ||
982 | var got2 Level3 | |
983 | if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { | |
984 | t.Error(err) | |
985 | } | |
986 | ||
987 | if !reflect.DeepEqual(got2, want2) { | |
988 | t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) | |
989 | } | |
990 | ||
991 | var got3 []Level3 | |
992 | if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { | |
993 | t.Error(err) | |
994 | } | |
995 | ||
996 | if !reflect.DeepEqual(got3, []Level3{got, got2}) { | |
997 | t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level3{got, got2})) | |
998 | } | |
999 | ||
1000 | var got4 []Level3 | |
1001 | if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { | |
1002 | t.Error(err) | |
1003 | } | |
1004 | ||
1005 | var got5 Level3 | |
1006 | DB.Preload("Level2.Level1s").Find(&got5, "value = ?", "bogus") | |
1007 | ||
1008 | var ruLevel1 Level1 | |
1009 | var zhLevel1 Level1 | |
1010 | DB.First(&ruLevel1, "value = ?", "ru") | |
1011 | DB.First(&zhLevel1, "value = ?", "zh") | |
1012 | ||
1013 | got.Level2.Level1s = []*Level1{&ruLevel1} | |
1014 | got2.Level2.Level1s = []*Level1{&zhLevel1} | |
1015 | if !reflect.DeepEqual(got4, []Level3{got, got2}) { | |
1016 | t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level3{got, got2})) | |
1017 | } | |
1018 | } | |
1019 | ||
1020 | func TestNestedManyToManyPreload(t *testing.T) { | |
1021 | type ( | |
1022 | Level1 struct { | |
1023 | ID uint | |
1024 | Value string | |
1025 | } | |
1026 | Level2 struct { | |
1027 | ID uint | |
1028 | Value string | |
1029 | Level1s []*Level1 `gorm:"many2many:level1_level2;"` | |
1030 | } | |
1031 | Level3 struct { | |
1032 | ID uint | |
1033 | Value string | |
1034 | Level2s []Level2 `gorm:"many2many:level2_level3;"` | |
1035 | } | |
1036 | ) | |
1037 | ||
1038 | DB.DropTableIfExists(&Level1{}) | |
1039 | DB.DropTableIfExists(&Level2{}) | |
1040 | DB.DropTableIfExists(&Level3{}) | |
1041 | DB.DropTableIfExists("level1_level2") | |
1042 | DB.DropTableIfExists("level2_level3") | |
1043 | ||
1044 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { | |
1045 | t.Error(err) | |
1046 | } | |
1047 | ||
1048 | want := Level3{ | |
1049 | Value: "Level3", | |
1050 | Level2s: []Level2{ | |
1051 | { | |
1052 | Value: "Bob", | |
1053 | Level1s: []*Level1{ | |
1054 | {Value: "ru"}, | |
1055 | {Value: "en"}, | |
1056 | }, | |
1057 | }, { | |
1058 | Value: "Tom", | |
1059 | Level1s: []*Level1{ | |
1060 | {Value: "zh"}, | |
1061 | {Value: "de"}, | |
1062 | }, | |
1063 | }, | |
1064 | }, | |
1065 | } | |
1066 | ||
1067 | if err := DB.Save(&want).Error; err != nil { | |
1068 | t.Error(err) | |
1069 | } | |
1070 | ||
1071 | var got Level3 | |
1072 | if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { | |
1073 | t.Error(err) | |
1074 | } | |
1075 | ||
1076 | if !reflect.DeepEqual(got, want) { | |
1077 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
1078 | } | |
1079 | ||
1080 | if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { | |
1081 | t.Error(err) | |
1082 | } | |
1083 | } | |
1084 | ||
1085 | func TestNestedManyToManyPreload2(t *testing.T) { | |
1086 | type ( | |
1087 | Level1 struct { | |
1088 | ID uint | |
1089 | Value string | |
1090 | } | |
1091 | Level2 struct { | |
1092 | ID uint | |
1093 | Value string | |
1094 | Level1s []*Level1 `gorm:"many2many:level1_level2;"` | |
1095 | } | |
1096 | Level3 struct { | |
1097 | ID uint | |
1098 | Value string | |
1099 | Level2ID sql.NullInt64 | |
1100 | Level2 *Level2 | |
1101 | } | |
1102 | ) | |
1103 | ||
1104 | DB.DropTableIfExists(&Level1{}) | |
1105 | DB.DropTableIfExists(&Level2{}) | |
1106 | DB.DropTableIfExists(&Level3{}) | |
1107 | DB.DropTableIfExists("level1_level2") | |
1108 | ||
1109 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { | |
1110 | t.Error(err) | |
1111 | } | |
1112 | ||
1113 | want := Level3{ | |
1114 | Value: "Level3", | |
1115 | Level2: &Level2{ | |
1116 | Value: "Bob", | |
1117 | Level1s: []*Level1{ | |
1118 | {Value: "ru"}, | |
1119 | {Value: "en"}, | |
1120 | }, | |
1121 | }, | |
1122 | } | |
1123 | ||
1124 | if err := DB.Save(&want).Error; err != nil { | |
1125 | t.Error(err) | |
1126 | } | |
1127 | ||
1128 | var got Level3 | |
1129 | if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { | |
1130 | t.Error(err) | |
1131 | } | |
1132 | ||
1133 | if !reflect.DeepEqual(got, want) { | |
1134 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
1135 | } | |
1136 | ||
1137 | if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { | |
1138 | t.Error(err) | |
1139 | } | |
1140 | } | |
1141 | ||
1142 | func TestNestedManyToManyPreload3(t *testing.T) { | |
1143 | type ( | |
1144 | Level1 struct { | |
1145 | ID uint | |
1146 | Value string | |
1147 | } | |
1148 | Level2 struct { | |
1149 | ID uint | |
1150 | Value string | |
1151 | Level1s []*Level1 `gorm:"many2many:level1_level2;"` | |
1152 | } | |
1153 | Level3 struct { | |
1154 | ID uint | |
1155 | Value string | |
1156 | Level2ID sql.NullInt64 | |
1157 | Level2 *Level2 | |
1158 | } | |
1159 | ) | |
1160 | ||
1161 | DB.DropTableIfExists(&Level1{}) | |
1162 | DB.DropTableIfExists(&Level2{}) | |
1163 | DB.DropTableIfExists(&Level3{}) | |
1164 | DB.DropTableIfExists("level1_level2") | |
1165 | ||
1166 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { | |
1167 | t.Error(err) | |
1168 | } | |
1169 | ||
1170 | level1Zh := &Level1{Value: "zh"} | |
1171 | level1Ru := &Level1{Value: "ru"} | |
1172 | level1En := &Level1{Value: "en"} | |
1173 | ||
1174 | level21 := &Level2{ | |
1175 | Value: "Level2-1", | |
1176 | Level1s: []*Level1{level1Zh, level1Ru}, | |
1177 | } | |
1178 | ||
1179 | level22 := &Level2{ | |
1180 | Value: "Level2-2", | |
1181 | Level1s: []*Level1{level1Zh, level1En}, | |
1182 | } | |
1183 | ||
1184 | wants := []*Level3{ | |
1185 | { | |
1186 | Value: "Level3-1", | |
1187 | Level2: level21, | |
1188 | }, | |
1189 | { | |
1190 | Value: "Level3-2", | |
1191 | Level2: level22, | |
1192 | }, | |
1193 | { | |
1194 | Value: "Level3-3", | |
1195 | Level2: level21, | |
1196 | }, | |
1197 | } | |
1198 | ||
1199 | for _, want := range wants { | |
1200 | if err := DB.Save(&want).Error; err != nil { | |
1201 | t.Error(err) | |
1202 | } | |
1203 | } | |
1204 | ||
1205 | var gots []*Level3 | |
1206 | if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { | |
1207 | return db.Order("level1.id ASC") | |
1208 | }).Find(&gots).Error; err != nil { | |
1209 | t.Error(err) | |
1210 | } | |
1211 | ||
1212 | if !reflect.DeepEqual(gots, wants) { | |
1213 | t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) | |
1214 | } | |
1215 | } | |
1216 | ||
1217 | func TestNestedManyToManyPreload3ForStruct(t *testing.T) { | |
1218 | type ( | |
1219 | Level1 struct { | |
1220 | ID uint | |
1221 | Value string | |
1222 | } | |
1223 | Level2 struct { | |
1224 | ID uint | |
1225 | Value string | |
1226 | Level1s []Level1 `gorm:"many2many:level1_level2;"` | |
1227 | } | |
1228 | Level3 struct { | |
1229 | ID uint | |
1230 | Value string | |
1231 | Level2ID sql.NullInt64 | |
1232 | Level2 Level2 | |
1233 | } | |
1234 | ) | |
1235 | ||
1236 | DB.DropTableIfExists(&Level1{}) | |
1237 | DB.DropTableIfExists(&Level2{}) | |
1238 | DB.DropTableIfExists(&Level3{}) | |
1239 | DB.DropTableIfExists("level1_level2") | |
1240 | ||
1241 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { | |
1242 | t.Error(err) | |
1243 | } | |
1244 | ||
1245 | level1Zh := Level1{Value: "zh"} | |
1246 | level1Ru := Level1{Value: "ru"} | |
1247 | level1En := Level1{Value: "en"} | |
1248 | ||
1249 | level21 := Level2{ | |
1250 | Value: "Level2-1", | |
1251 | Level1s: []Level1{level1Zh, level1Ru}, | |
1252 | } | |
1253 | ||
1254 | level22 := Level2{ | |
1255 | Value: "Level2-2", | |
1256 | Level1s: []Level1{level1Zh, level1En}, | |
1257 | } | |
1258 | ||
1259 | wants := []*Level3{ | |
1260 | { | |
1261 | Value: "Level3-1", | |
1262 | Level2: level21, | |
1263 | }, | |
1264 | { | |
1265 | Value: "Level3-2", | |
1266 | Level2: level22, | |
1267 | }, | |
1268 | { | |
1269 | Value: "Level3-3", | |
1270 | Level2: level21, | |
1271 | }, | |
1272 | } | |
1273 | ||
1274 | for _, want := range wants { | |
1275 | if err := DB.Save(&want).Error; err != nil { | |
1276 | t.Error(err) | |
1277 | } | |
1278 | } | |
1279 | ||
1280 | var gots []*Level3 | |
1281 | if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { | |
1282 | return db.Order("level1.id ASC") | |
1283 | }).Find(&gots).Error; err != nil { | |
1284 | t.Error(err) | |
1285 | } | |
1286 | ||
1287 | if !reflect.DeepEqual(gots, wants) { | |
1288 | t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) | |
1289 | } | |
1290 | } | |
1291 | ||
1292 | func TestNestedManyToManyPreload4(t *testing.T) { | |
1293 | type ( | |
1294 | Level4 struct { | |
1295 | ID uint | |
1296 | Value string | |
1297 | Level3ID uint | |
1298 | } | |
1299 | Level3 struct { | |
1300 | ID uint | |
1301 | Value string | |
1302 | Level4s []*Level4 | |
1303 | } | |
1304 | Level2 struct { | |
1305 | ID uint | |
1306 | Value string | |
1307 | Level3s []*Level3 `gorm:"many2many:level2_level3;"` | |
1308 | } | |
1309 | Level1 struct { | |
1310 | ID uint | |
1311 | Value string | |
1312 | Level2s []*Level2 `gorm:"many2many:level1_level2;"` | |
1313 | } | |
1314 | ) | |
1315 | ||
1316 | DB.DropTableIfExists(&Level1{}) | |
1317 | DB.DropTableIfExists(&Level2{}) | |
1318 | DB.DropTableIfExists(&Level3{}) | |
1319 | DB.DropTableIfExists(&Level4{}) | |
1320 | DB.DropTableIfExists("level1_level2") | |
1321 | DB.DropTableIfExists("level2_level3") | |
1322 | ||
1323 | dummy := Level1{ | |
1324 | Value: "Level1", | |
1325 | Level2s: []*Level2{{ | |
1326 | Value: "Level2", | |
1327 | Level3s: []*Level3{{ | |
1328 | Value: "Level3", | |
1329 | Level4s: []*Level4{{ | |
1330 | Value: "Level4", | |
1331 | }}, | |
1332 | }}, | |
1333 | }}, | |
1334 | } | |
1335 | ||
1336 | if err := DB.AutoMigrate(&Level4{}, &Level3{}, &Level2{}, &Level1{}).Error; err != nil { | |
1337 | t.Error(err) | |
1338 | } | |
1339 | ||
1340 | if err := DB.Save(&dummy).Error; err != nil { | |
1341 | t.Error(err) | |
1342 | } | |
1343 | ||
1344 | var level1 Level1 | |
1345 | if err := DB.Preload("Level2s").Preload("Level2s.Level3s").Preload("Level2s.Level3s.Level4s").First(&level1).Error; err != nil { | |
1346 | t.Error(err) | |
1347 | } | |
1348 | } | |
1349 | ||
1350 | func TestManyToManyPreloadForPointer(t *testing.T) { | |
1351 | type ( | |
1352 | Level1 struct { | |
1353 | ID uint | |
1354 | Value string | |
1355 | } | |
1356 | Level2 struct { | |
1357 | ID uint | |
1358 | Value string | |
1359 | Level1s []*Level1 `gorm:"many2many:levels;"` | |
1360 | } | |
1361 | ) | |
1362 | ||
1363 | DB.DropTableIfExists(&Level2{}) | |
1364 | DB.DropTableIfExists(&Level1{}) | |
1365 | DB.DropTableIfExists("levels") | |
709 | 1366 | |
710 | 1367 | if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { |
711 | panic(err) | |
1368 | t.Error(err) | |
712 | 1369 | } |
713 | 1370 | |
714 | 1371 | want := Level2{Value: "Bob", Level1s: []*Level1{ |
716 | 1373 | {Value: "en"}, |
717 | 1374 | }} |
718 | 1375 | if err := DB.Save(&want).Error; err != nil { |
719 | panic(err) | |
1376 | t.Error(err) | |
720 | 1377 | } |
721 | 1378 | |
722 | 1379 | want2 := Level2{Value: "Tom", Level1s: []*Level1{ |
724 | 1381 | {Value: "de"}, |
725 | 1382 | }} |
726 | 1383 | if err := DB.Save(&want2).Error; err != nil { |
727 | panic(err) | |
1384 | t.Error(err) | |
728 | 1385 | } |
729 | 1386 | |
730 | 1387 | var got Level2 |
731 | 1388 | if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { |
732 | panic(err) | |
1389 | t.Error(err) | |
733 | 1390 | } |
734 | 1391 | |
735 | 1392 | if !reflect.DeepEqual(got, want) { |
738 | 1395 | |
739 | 1396 | var got2 Level2 |
740 | 1397 | if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { |
741 | panic(err) | |
1398 | t.Error(err) | |
742 | 1399 | } |
743 | 1400 | |
744 | 1401 | if !reflect.DeepEqual(got2, want2) { |
747 | 1404 | |
748 | 1405 | var got3 []Level2 |
749 | 1406 | if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { |
750 | panic(err) | |
1407 | t.Error(err) | |
751 | 1408 | } |
752 | 1409 | |
753 | 1410 | if !reflect.DeepEqual(got3, []Level2{got, got2}) { |
756 | 1413 | |
757 | 1414 | var got4 []Level2 |
758 | 1415 | if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { |
759 | panic(err) | |
760 | } | |
1416 | t.Error(err) | |
1417 | } | |
1418 | ||
1419 | var got5 Level2 | |
1420 | DB.Preload("Level1s").First(&got5, "value = ?", "bogus") | |
761 | 1421 | |
762 | 1422 | var ruLevel1 Level1 |
763 | 1423 | var zhLevel1 Level1 |
774 | 1434 | func TestNilPointerSlice(t *testing.T) { |
775 | 1435 | type ( |
776 | 1436 | Level3 struct { |
777 | ID uint `gorm:"primary_key;"` | |
1437 | ID uint | |
778 | 1438 | Value string |
779 | 1439 | } |
780 | 1440 | Level2 struct { |
781 | ID uint `gorm:"primary_key;"` | |
1441 | ID uint | |
782 | 1442 | Value string |
783 | 1443 | Level3ID uint |
784 | 1444 | Level3 *Level3 |
785 | 1445 | } |
786 | 1446 | Level1 struct { |
787 | ID uint `gorm:"primary_key;"` | |
1447 | ID uint | |
788 | 1448 | Value string |
789 | 1449 | Level2ID uint |
790 | 1450 | Level2 *Level2 |
796 | 1456 | DB.DropTableIfExists(&Level1{}) |
797 | 1457 | |
798 | 1458 | if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { |
799 | panic(err) | |
800 | } | |
801 | ||
802 | want := Level1{Value: "Bob", Level2: &Level2{ | |
803 | Value: "en", | |
804 | Level3: &Level3{ | |
805 | Value: "native", | |
806 | }, | |
807 | }} | |
1459 | t.Error(err) | |
1460 | } | |
1461 | ||
1462 | want := Level1{ | |
1463 | Value: "Bob", | |
1464 | Level2: &Level2{ | |
1465 | Value: "en", | |
1466 | Level3: &Level3{ | |
1467 | Value: "native", | |
1468 | }, | |
1469 | }, | |
1470 | } | |
808 | 1471 | if err := DB.Save(&want).Error; err != nil { |
809 | panic(err) | |
810 | } | |
811 | ||
812 | want2 := Level1{Value: "Tom", Level2: nil} | |
1472 | t.Error(err) | |
1473 | } | |
1474 | ||
1475 | want2 := Level1{ | |
1476 | Value: "Tom", | |
1477 | Level2: nil, | |
1478 | } | |
813 | 1479 | if err := DB.Save(&want2).Error; err != nil { |
814 | panic(err) | |
1480 | t.Error(err) | |
815 | 1481 | } |
816 | 1482 | |
817 | 1483 | var got []Level1 |
818 | 1484 | if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).Error; err != nil { |
819 | panic(err) | |
1485 | t.Error(err) | |
820 | 1486 | } |
821 | 1487 | |
822 | 1488 | if len(got) != 2 { |
823 | t.Fatalf("got %v items, expected 2", len(got)) | |
1489 | t.Errorf("got %v items, expected 2", len(got)) | |
824 | 1490 | } |
825 | 1491 | |
826 | 1492 | if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { |
829 | 1495 | |
830 | 1496 | if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { |
831 | 1497 | t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want2)) |
1498 | } | |
1499 | } | |
1500 | ||
1501 | func TestNilPointerSlice2(t *testing.T) { | |
1502 | type ( | |
1503 | Level4 struct { | |
1504 | ID uint | |
1505 | } | |
1506 | Level3 struct { | |
1507 | ID uint | |
1508 | Level4ID sql.NullInt64 `sql:"index"` | |
1509 | Level4 *Level4 | |
1510 | } | |
1511 | Level2 struct { | |
1512 | ID uint | |
1513 | Level3s []*Level3 `gorm:"many2many:level2_level3s"` | |
1514 | } | |
1515 | Level1 struct { | |
1516 | ID uint | |
1517 | Level2ID sql.NullInt64 `sql:"index"` | |
1518 | Level2 *Level2 | |
1519 | } | |
1520 | ) | |
1521 | ||
1522 | DB.DropTableIfExists(new(Level4)) | |
1523 | DB.DropTableIfExists(new(Level3)) | |
1524 | DB.DropTableIfExists(new(Level2)) | |
1525 | DB.DropTableIfExists(new(Level1)) | |
1526 | ||
1527 | if err := DB.AutoMigrate(new(Level4), new(Level3), new(Level2), new(Level1)).Error; err != nil { | |
1528 | t.Error(err) | |
1529 | } | |
1530 | ||
1531 | want := new(Level1) | |
1532 | if err := DB.Save(want).Error; err != nil { | |
1533 | t.Error(err) | |
1534 | } | |
1535 | ||
1536 | got := new(Level1) | |
1537 | err := DB.Preload("Level2.Level3s.Level4").Last(&got).Error | |
1538 | if err != nil { | |
1539 | t.Error(err) | |
1540 | } | |
1541 | ||
1542 | if !reflect.DeepEqual(got, want) { | |
1543 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
1544 | } | |
1545 | } | |
1546 | ||
1547 | func TestPrefixedPreloadDuplication(t *testing.T) { | |
1548 | type ( | |
1549 | Level4 struct { | |
1550 | ID uint | |
1551 | Name string | |
1552 | Level3ID uint | |
1553 | } | |
1554 | Level3 struct { | |
1555 | ID uint | |
1556 | Name string | |
1557 | Level4s []*Level4 | |
1558 | } | |
1559 | Level2 struct { | |
1560 | ID uint | |
1561 | Name string | |
1562 | Level3ID sql.NullInt64 `sql:"index"` | |
1563 | Level3 *Level3 | |
1564 | } | |
1565 | Level1 struct { | |
1566 | ID uint | |
1567 | Name string | |
1568 | Level2ID sql.NullInt64 `sql:"index"` | |
1569 | Level2 *Level2 | |
1570 | } | |
1571 | ) | |
1572 | ||
1573 | DB.DropTableIfExists(new(Level3)) | |
1574 | DB.DropTableIfExists(new(Level4)) | |
1575 | DB.DropTableIfExists(new(Level2)) | |
1576 | DB.DropTableIfExists(new(Level1)) | |
1577 | ||
1578 | if err := DB.AutoMigrate(new(Level3), new(Level4), new(Level2), new(Level1)).Error; err != nil { | |
1579 | t.Error(err) | |
1580 | } | |
1581 | ||
1582 | lvl := &Level3{} | |
1583 | if err := DB.Save(lvl).Error; err != nil { | |
1584 | t.Error(err) | |
1585 | } | |
1586 | ||
1587 | sublvl1 := &Level4{Level3ID: lvl.ID} | |
1588 | if err := DB.Save(sublvl1).Error; err != nil { | |
1589 | t.Error(err) | |
1590 | } | |
1591 | sublvl2 := &Level4{Level3ID: lvl.ID} | |
1592 | if err := DB.Save(sublvl2).Error; err != nil { | |
1593 | t.Error(err) | |
1594 | } | |
1595 | ||
1596 | lvl.Level4s = []*Level4{sublvl1, sublvl2} | |
1597 | ||
1598 | want1 := Level1{ | |
1599 | Level2: &Level2{ | |
1600 | Level3: lvl, | |
1601 | }, | |
1602 | } | |
1603 | if err := DB.Save(&want1).Error; err != nil { | |
1604 | t.Error(err) | |
1605 | } | |
1606 | ||
1607 | want2 := Level1{ | |
1608 | Level2: &Level2{ | |
1609 | Level3: lvl, | |
1610 | }, | |
1611 | } | |
1612 | if err := DB.Save(&want2).Error; err != nil { | |
1613 | t.Error(err) | |
1614 | } | |
1615 | ||
1616 | want := []Level1{want1, want2} | |
1617 | ||
1618 | var got []Level1 | |
1619 | err := DB.Preload("Level2.Level3.Level4s").Find(&got).Error | |
1620 | if err != nil { | |
1621 | t.Error(err) | |
1622 | } | |
1623 | ||
1624 | if !reflect.DeepEqual(got, want) { | |
1625 | t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) | |
1626 | } | |
1627 | } | |
1628 | ||
1629 | func TestPreloadManyToManyCallbacks(t *testing.T) { | |
1630 | type ( | |
1631 | Level2 struct { | |
1632 | ID uint | |
1633 | Name string | |
1634 | } | |
1635 | Level1 struct { | |
1636 | ID uint | |
1637 | Name string | |
1638 | Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"` | |
1639 | } | |
1640 | ) | |
1641 | ||
1642 | DB.DropTableIfExists("level1_level2s") | |
1643 | DB.DropTableIfExists(new(Level1)) | |
1644 | DB.DropTableIfExists(new(Level2)) | |
1645 | ||
1646 | if err := DB.AutoMigrate(new(Level1), new(Level2)).Error; err != nil { | |
1647 | t.Error(err) | |
1648 | } | |
1649 | ||
1650 | lvl := Level1{ | |
1651 | Name: "l1", | |
1652 | Level2s: []Level2{ | |
1653 | Level2{Name: "l2-1"}, Level2{Name: "l2-2"}, | |
1654 | }, | |
1655 | } | |
1656 | DB.Save(&lvl) | |
1657 | ||
1658 | called := 0 | |
1659 | ||
1660 | DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) { | |
1661 | called = called + 1 | |
1662 | }) | |
1663 | ||
1664 | DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) | |
1665 | ||
1666 | if called != 3 { | |
1667 | t.Errorf("Wanted callback to be called 3 times but got %d", called) | |
832 | 1668 | } |
833 | 1669 | } |
834 | 1670 |
4 | 4 | "reflect" |
5 | 5 | |
6 | 6 | "github.com/jinzhu/gorm" |
7 | "github.com/jinzhu/now" | |
8 | 7 | |
9 | 8 | "testing" |
10 | 9 | "time" |
18 | 17 | DB.First(&user1) |
19 | 18 | DB.Order("id").Limit(1).Find(&user2) |
20 | 19 | |
21 | DB.Last(&user3) | |
20 | ptrOfUser3 := &user3 | |
21 | DB.Last(&ptrOfUser3) | |
22 | 22 | DB.Order("id desc").Limit(1).Find(&user4) |
23 | 23 | if user1.Id != user2.Id || user3.Id != user4.Id { |
24 | 24 | t.Errorf("First and Last should by order by primary key") |
30 | 30 | t.Errorf("Find first record as slice") |
31 | 31 | } |
32 | 32 | |
33 | if DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).Error != nil { | |
33 | var user User | |
34 | if DB.Joins("left join emails on emails.user_id = users.id").First(&user).Error != nil { | |
34 | 35 | t.Errorf("Should not raise any error when order with Join table") |
36 | } | |
37 | ||
38 | if user.Email != "" { | |
39 | t.Errorf("User's Email should be blank as no one set it") | |
35 | 40 | } |
36 | 41 | } |
37 | 42 | |
47 | 52 | DB.Order("counter desc").Limit(1).Find(&animal4) |
48 | 53 | if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter { |
49 | 54 | t.Errorf("First and Last should work correctly") |
55 | } | |
56 | } | |
57 | ||
58 | func TestFirstAndLastWithRaw(t *testing.T) { | |
59 | user1 := User{Name: "user", Emails: []Email{{Email: "user1@example.com"}}} | |
60 | user2 := User{Name: "user", Emails: []Email{{Email: "user2@example.com"}}} | |
61 | DB.Save(&user1) | |
62 | DB.Save(&user2) | |
63 | ||
64 | var user3, user4 User | |
65 | DB.Raw("select * from users WHERE name = ?", "user").First(&user3) | |
66 | if user3.Id != user1.Id { | |
67 | t.Errorf("Find first record with raw") | |
68 | } | |
69 | ||
70 | DB.Raw("select * from users WHERE name = ?", "user").Last(&user4) | |
71 | if user4.Id != user2.Id { | |
72 | t.Errorf("Find last record with raw") | |
50 | 73 | } |
51 | 74 | } |
52 | 75 | |
63 | 86 | } |
64 | 87 | } |
65 | 88 | |
89 | func TestCustomizedTypePrimaryKey(t *testing.T) { | |
90 | type ID uint | |
91 | type CustomizedTypePrimaryKey struct { | |
92 | ID ID | |
93 | Name string | |
94 | } | |
95 | ||
96 | DB.AutoMigrate(&CustomizedTypePrimaryKey{}) | |
97 | ||
98 | p1 := CustomizedTypePrimaryKey{Name: "p1"} | |
99 | p2 := CustomizedTypePrimaryKey{Name: "p2"} | |
100 | p3 := CustomizedTypePrimaryKey{Name: "p3"} | |
101 | DB.Create(&p1) | |
102 | DB.Create(&p2) | |
103 | DB.Create(&p3) | |
104 | ||
105 | var p CustomizedTypePrimaryKey | |
106 | ||
107 | if err := DB.First(&p, p2.ID).Error; err == nil { | |
108 | t.Errorf("Should return error for invalid query condition") | |
109 | } | |
110 | ||
111 | if err := DB.First(&p, "id = ?", p2.ID).Error; err != nil { | |
112 | t.Errorf("No error should happen when querying with customized type for primary key, got err %v", err) | |
113 | } | |
114 | ||
115 | if p.Name != "p2" { | |
116 | t.Errorf("Should find correct value when querying with customized type for primary key") | |
117 | } | |
118 | } | |
119 | ||
120 | func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { | |
121 | type AddressByZipCode struct { | |
122 | ZipCode string `gorm:"primary_key"` | |
123 | Address string | |
124 | } | |
125 | ||
126 | DB.AutoMigrate(&AddressByZipCode{}) | |
127 | DB.Create(&AddressByZipCode{ZipCode: "00501", Address: "Holtsville"}) | |
128 | ||
129 | var address AddressByZipCode | |
130 | DB.First(&address, "00501") | |
131 | if address.ZipCode != "00501" { | |
132 | t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode) | |
133 | } | |
134 | } | |
135 | ||
66 | 136 | func TestFindAsSliceOfPointers(t *testing.T) { |
67 | 137 | DB.Save(&User{Name: "user"}) |
68 | 138 | |
78 | 148 | } |
79 | 149 | |
80 | 150 | func TestSearchWithPlainSQL(t *testing.T) { |
81 | user1 := User{Name: "PlainSqlUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} | |
82 | user2 := User{Name: "PlainSqlUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} | |
83 | user3 := User{Name: "PlainSqlUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} | |
151 | user1 := User{Name: "PlainSqlUser1", Age: 1, Birthday: parseTime("2000-1-1")} | |
152 | user2 := User{Name: "PlainSqlUser2", Age: 10, Birthday: parseTime("2010-1-1")} | |
153 | user3 := User{Name: "PlainSqlUser3", Age: 20, Birthday: parseTime("2020-1-1")} | |
84 | 154 | DB.Save(&user1).Save(&user2).Save(&user3) |
85 | 155 | scopedb := DB.Where("name LIKE ?", "%PlainSqlUser%") |
86 | 156 | |
108 | 178 | t.Errorf("Should found 2 users age != 20, but got %v", len(users)) |
109 | 179 | } |
110 | 180 | |
111 | scopedb.Where("birthday > ?", now.MustParse("2000-1-1")).Find(&users) | |
181 | scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users) | |
112 | 182 | if len(users) != 2 { |
113 | 183 | t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) |
114 | 184 | } |
138 | 208 | t.Errorf("Should found 1 users, but got %v", len(users)) |
139 | 209 | } |
140 | 210 | |
211 | if err := DB.Where("id IN (?)", []string{}).Find(&users).Error; err != nil { | |
212 | t.Error("no error should happen when query with empty slice, but got: ", err) | |
213 | } | |
214 | ||
215 | if err := DB.Not("id IN (?)", []string{}).Find(&users).Error; err != nil { | |
216 | t.Error("no error should happen when query with empty slice, but got: ", err) | |
217 | } | |
218 | ||
141 | 219 | if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() { |
142 | 220 | t.Errorf("Should not get RecordNotFound error when looking for none existing records") |
143 | 221 | } |
144 | 222 | } |
145 | 223 | |
224 | func TestSearchWithTwoDimensionalArray(t *testing.T) { | |
225 | var users []User | |
226 | user1 := User{Name: "2DSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} | |
227 | user2 := User{Name: "2DSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} | |
228 | user3 := User{Name: "2DSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} | |
229 | DB.Create(&user1) | |
230 | DB.Create(&user2) | |
231 | DB.Create(&user3) | |
232 | ||
233 | if dialect := DB.Dialect().GetName(); dialect == "mysql" || dialect == "postgres" { | |
234 | if err := DB.Where("(name, age) IN (?)", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil { | |
235 | t.Errorf("No error should happen when query with 2D array, but got %v", err) | |
236 | ||
237 | if len(users) != 2 { | |
238 | t.Errorf("Should find 2 users with 2D array, but got %v", len(users)) | |
239 | } | |
240 | } | |
241 | } | |
242 | ||
243 | if dialect := DB.Dialect().GetName(); dialect == "mssql" { | |
244 | if err := DB.Joins("JOIN (VALUES ?) AS x (col1, col2) ON x.col1 = name AND x.col2 = age", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil { | |
245 | t.Errorf("No error should happen when query with 2D array, but got %v", err) | |
246 | ||
247 | if len(users) != 2 { | |
248 | t.Errorf("Should find 2 users with 2D array, but got %v", len(users)) | |
249 | } | |
250 | } | |
251 | } | |
252 | } | |
253 | ||
146 | 254 | func TestSearchWithStruct(t *testing.T) { |
147 | user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} | |
148 | user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} | |
149 | user3 := User{Name: "StructSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} | |
255 | user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} | |
256 | user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} | |
257 | user3 := User{Name: "StructSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} | |
150 | 258 | DB.Save(&user1).Save(&user2).Save(&user3) |
151 | 259 | |
152 | 260 | if DB.Where(user1.Id).First(&User{}).RecordNotFound() { |
174 | 282 | } |
175 | 283 | |
176 | 284 | DB.First(&user, User{Name: user1.Name}) |
177 | if user.Id == 0 || user.Name != user.Name { | |
285 | if user.Id == 0 || user.Name != user1.Name { | |
178 | 286 | t.Errorf("Search first record with inline struct") |
179 | 287 | } |
180 | 288 | |
190 | 298 | } |
191 | 299 | |
192 | 300 | func TestSearchWithMap(t *testing.T) { |
193 | user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} | |
194 | user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} | |
195 | user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} | |
196 | DB.Save(&user1).Save(&user2).Save(&user3) | |
301 | companyID := 1 | |
302 | user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} | |
303 | user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} | |
304 | user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} | |
305 | user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: parseTime("2020-1-1"), CompanyID: &companyID} | |
306 | DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4) | |
197 | 307 | |
198 | 308 | var user User |
199 | 309 | DB.First(&user, map[string]interface{}{"name": user1.Name}) |
217 | 327 | if len(users) != 1 { |
218 | 328 | t.Errorf("Search all records with inline map") |
219 | 329 | } |
330 | ||
331 | DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": nil}) | |
332 | if len(users) != 0 { | |
333 | t.Errorf("Search all records with inline map containing null value finding 0 records") | |
334 | } | |
335 | ||
336 | DB.Find(&users, map[string]interface{}{"name": user1.Name, "company_id": nil}) | |
337 | if len(users) != 1 { | |
338 | t.Errorf("Search all records with inline map containing null value finding 1 record") | |
339 | } | |
340 | ||
341 | DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": companyID}) | |
342 | if len(users) != 1 { | |
343 | t.Errorf("Search all records with inline multiple value map") | |
344 | } | |
220 | 345 | } |
221 | 346 | |
222 | 347 | func TestSearchWithEmptyChain(t *testing.T) { |
223 | user1 := User{Name: "ChainSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} | |
224 | user2 := User{Name: "ChainearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} | |
225 | user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} | |
348 | user1 := User{Name: "ChainSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} | |
349 | user2 := User{Name: "ChainearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} | |
350 | user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} | |
226 | 351 | DB.Save(&user1).Save(&user2).Save(&user3) |
227 | 352 | |
228 | 353 | if DB.Where("").Where("").First(&User{}).Error != nil { |
259 | 384 | user3 := User{Name: "OrderPluckUser3", Age: 20} |
260 | 385 | DB.Save(&user1).Save(&user2).Save(&user3) |
261 | 386 | scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%") |
387 | ||
388 | var user User | |
389 | scopedb.Order(gorm.Expr("case when name = ? then 0 else 1 end", "OrderPluckUser2")).First(&user) | |
390 | if user.Name != "OrderPluckUser2" { | |
391 | t.Errorf("Order with sql expression") | |
392 | } | |
262 | 393 | |
263 | 394 | var ages []int64 |
264 | 395 | scopedb.Order("age desc").Pluck("age", &ages) |
289 | 420 | t.Errorf("Order with multiple orders") |
290 | 421 | } |
291 | 422 | |
423 | var ages6 []int64 | |
424 | if err := scopedb.Order("").Pluck("age", &ages6).Error; err != nil { | |
425 | t.Errorf("An empty string as order clause produces invalid queries") | |
426 | } | |
427 | ||
292 | 428 | DB.Model(User{}).Select("name, age").Find(&[]User{}) |
293 | 429 | } |
294 | 430 | |
313 | 449 | DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) |
314 | 450 | } |
315 | 451 | var users1, users2, users3, users4 []User |
316 | DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) | |
452 | DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) | |
317 | 453 | |
318 | 454 | if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { |
319 | 455 | t.Errorf("Offset should work") |
354 | 490 | if count1 != 1 || count2 != 3 { |
355 | 491 | t.Errorf("Multiple count in chain") |
356 | 492 | } |
493 | ||
494 | var count3 int | |
495 | if err := DB.Model(&User{}).Where("name in (?)", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { | |
496 | t.Errorf("Not error should happen, but got %v", err) | |
497 | } | |
498 | ||
499 | if count3 != 2 { | |
500 | t.Errorf("Should get correct count, but got %v", count3) | |
501 | } | |
357 | 502 | } |
358 | 503 | |
359 | 504 | func TestNot(t *testing.T) { |
360 | 505 | DB.Create(getPreparedUser("user1", "not")) |
361 | 506 | DB.Create(getPreparedUser("user2", "not")) |
362 | 507 | DB.Create(getPreparedUser("user3", "not")) |
363 | DB.Create(getPreparedUser("user4", "not")) | |
508 | ||
509 | user4 := getPreparedUser("user4", "not") | |
510 | user4.Company = Company{} | |
511 | DB.Create(user4) | |
512 | ||
364 | 513 | DB := DB.Where("role = ?", "not") |
365 | 514 | |
366 | var users1, users2, users3, users4, users5, users6, users7, users8 []User | |
515 | var users1, users2, users3, users4, users5, users6, users7, users8, users9 []User | |
367 | 516 | if DB.Find(&users1).RowsAffected != 4 { |
368 | 517 | t.Errorf("should find 4 not users") |
369 | 518 | } |
406 | 555 | t.Errorf("Should find all users's name not equal 3") |
407 | 556 | } |
408 | 557 | |
409 | DB.Not("name", []string{"user3"}).Find(&users7) | |
410 | if len(users1)-len(users7) != int(name3Count) { | |
558 | DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) | |
559 | if len(users1)-len(users7) != 2 { // not user3 or user4 | |
560 | t.Errorf("Should find all user's name not equal to 3 who do not have company id") | |
561 | } | |
562 | ||
563 | DB.Not("name", []string{"user3"}).Find(&users8) | |
564 | if len(users1)-len(users8) != int(name3Count) { | |
411 | 565 | t.Errorf("Should find all users's name not equal 3") |
412 | 566 | } |
413 | 567 | |
414 | 568 | var name2Count int64 |
415 | 569 | DB.Table("users").Where("name = ?", "user2").Count(&name2Count) |
416 | DB.Not("name", []string{"user3", "user2"}).Find(&users8) | |
417 | if len(users1)-len(users8) != (int(name3Count) + int(name2Count)) { | |
570 | DB.Not("name", []string{"user3", "user2"}).Find(&users9) | |
571 | if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { | |
418 | 572 | t.Errorf("Should find all users's name not equal 3") |
419 | 573 | } |
420 | 574 | } |
566 | 720 | t.Errorf("Should only contains one column") |
567 | 721 | } |
568 | 722 | } |
723 | ||
724 | rows.Close() | |
569 | 725 | } |
570 | 726 | |
571 | 727 | func TestSelectWithArrayInput(t *testing.T) { |
579 | 735 | } |
580 | 736 | } |
581 | 737 | |
582 | func TestCurrentDatabase(t *testing.T) { | |
583 | databaseName := DB.CurrentDatabase() | |
584 | if err := DB.Error; err != nil { | |
585 | t.Errorf("Problem getting current db name: %s", err) | |
586 | } | |
587 | if databaseName == "" { | |
588 | t.Errorf("Current db name returned empty; this should never happen!") | |
589 | } | |
590 | t.Logf("Got current db name: %v", databaseName) | |
591 | } | |
738 | func TestPluckWithSelect(t *testing.T) { | |
739 | var ( | |
740 | user = User{Name: "matematik7_pluck_with_select", Age: 25} | |
741 | combinedName = fmt.Sprintf("%v%v", user.Name, user.Age) | |
742 | combineUserAgeSQL = fmt.Sprintf("concat(%v, %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age")) | |
743 | ) | |
744 | ||
745 | if dialect := DB.Dialect().GetName(); dialect == "sqlite3" { | |
746 | combineUserAgeSQL = fmt.Sprintf("(%v || %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age")) | |
747 | } | |
748 | ||
749 | DB.Save(&user) | |
750 | ||
751 | selectStr := combineUserAgeSQL + " as user_age" | |
752 | var userAges []string | |
753 | err := DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error | |
754 | if err != nil { | |
755 | t.Error(err) | |
756 | } | |
757 | ||
758 | if len(userAges) != 1 || userAges[0] != combinedName { | |
759 | t.Errorf("Should correctly pluck with select, got: %s", userAges) | |
760 | } | |
761 | ||
762 | selectStr = combineUserAgeSQL + fmt.Sprintf(" as %v", DB.Dialect().Quote("user_age")) | |
763 | userAges = userAges[:0] | |
764 | err = DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error | |
765 | if err != nil { | |
766 | t.Error(err) | |
767 | } | |
768 | ||
769 | if len(userAges) != 1 || userAges[0] != combinedName { | |
770 | t.Errorf("Should correctly pluck with select, got: %s", userAges) | |
771 | } | |
772 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "database/sql/driver" | |
4 | "encoding/json" | |
5 | "errors" | |
6 | "testing" | |
7 | ||
8 | "github.com/jinzhu/gorm" | |
9 | ) | |
10 | ||
11 | func TestScannableSlices(t *testing.T) { | |
12 | if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil { | |
13 | t.Errorf("Should create table with slice values correctly: %s", err) | |
14 | } | |
15 | ||
16 | r1 := RecordWithSlice{ | |
17 | Strings: ExampleStringSlice{"a", "b", "c"}, | |
18 | Structs: ExampleStructSlice{ | |
19 | {"name1", "value1"}, | |
20 | {"name2", "value2"}, | |
21 | }, | |
22 | } | |
23 | ||
24 | if err := DB.Save(&r1).Error; err != nil { | |
25 | t.Errorf("Should save record with slice values") | |
26 | } | |
27 | ||
28 | var r2 RecordWithSlice | |
29 | ||
30 | if err := DB.Find(&r2).Error; err != nil { | |
31 | t.Errorf("Should fetch record with slice values") | |
32 | } | |
33 | ||
34 | if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" { | |
35 | t.Errorf("Should have serialised and deserialised a string array") | |
36 | } | |
37 | ||
38 | if len(r2.Structs) != 2 || r2.Structs[0].Name != "name1" || r2.Structs[0].Value != "value1" || r2.Structs[1].Name != "name2" || r2.Structs[1].Value != "value2" { | |
39 | t.Errorf("Should have serialised and deserialised a struct array") | |
40 | } | |
41 | } | |
42 | ||
43 | type RecordWithSlice struct { | |
44 | ID uint64 | |
45 | Strings ExampleStringSlice `sql:"type:text"` | |
46 | Structs ExampleStructSlice `sql:"type:text"` | |
47 | } | |
48 | ||
49 | type ExampleStringSlice []string | |
50 | ||
51 | func (l ExampleStringSlice) Value() (driver.Value, error) { | |
52 | bytes, err := json.Marshal(l) | |
53 | return string(bytes), err | |
54 | } | |
55 | ||
56 | func (l *ExampleStringSlice) Scan(input interface{}) error { | |
57 | switch value := input.(type) { | |
58 | case string: | |
59 | return json.Unmarshal([]byte(value), l) | |
60 | case []byte: | |
61 | return json.Unmarshal(value, l) | |
62 | default: | |
63 | return errors.New("not supported") | |
64 | } | |
65 | } | |
66 | ||
67 | type ExampleStruct struct { | |
68 | Name string | |
69 | Value string | |
70 | } | |
71 | ||
72 | type ExampleStructSlice []ExampleStruct | |
73 | ||
74 | func (l ExampleStructSlice) Value() (driver.Value, error) { | |
75 | bytes, err := json.Marshal(l) | |
76 | return string(bytes), err | |
77 | } | |
78 | ||
79 | func (l *ExampleStructSlice) Scan(input interface{}) error { | |
80 | switch value := input.(type) { | |
81 | case string: | |
82 | return json.Unmarshal([]byte(value), l) | |
83 | case []byte: | |
84 | return json.Unmarshal(value, l) | |
85 | default: | |
86 | return errors.New("not supported") | |
87 | } | |
88 | } | |
89 | ||
90 | type ScannerDataType struct { | |
91 | Street string `sql:"TYPE:varchar(24)"` | |
92 | } | |
93 | ||
94 | func (ScannerDataType) Value() (driver.Value, error) { | |
95 | return nil, nil | |
96 | } | |
97 | ||
98 | func (*ScannerDataType) Scan(input interface{}) error { | |
99 | return nil | |
100 | } | |
101 | ||
102 | type ScannerDataTypeTestStruct struct { | |
103 | Field1 int | |
104 | ScannerDataType *ScannerDataType `sql:"TYPE:json"` | |
105 | } | |
106 | ||
107 | type ScannerDataType2 struct { | |
108 | Street string `sql:"TYPE:varchar(24)"` | |
109 | } | |
110 | ||
111 | func (ScannerDataType2) Value() (driver.Value, error) { | |
112 | return nil, nil | |
113 | } | |
114 | ||
115 | func (*ScannerDataType2) Scan(input interface{}) error { | |
116 | return nil | |
117 | } | |
118 | ||
119 | type ScannerDataTypeTestStruct2 struct { | |
120 | Field1 int | |
121 | ScannerDataType *ScannerDataType2 | |
122 | } | |
123 | ||
124 | func TestScannerDataType(t *testing.T) { | |
125 | scope := gorm.Scope{Value: &ScannerDataTypeTestStruct{}} | |
126 | if field, ok := scope.FieldByName("ScannerDataType"); ok { | |
127 | if DB.Dialect().DataTypeOf(field.StructField) != "json" { | |
128 | t.Errorf("data type for scanner is wrong") | |
129 | } | |
130 | } | |
131 | ||
132 | scope = gorm.Scope{Value: &ScannerDataTypeTestStruct2{}} | |
133 | if field, ok := scope.FieldByName("ScannerDataType"); ok { | |
134 | if DB.Dialect().DataTypeOf(field.StructField) != "varchar(24)" { | |
135 | t.Errorf("data type for scanner is wrong") | |
136 | } | |
137 | } | |
138 | } |
0 | 0 | package gorm |
1 | 1 | |
2 | 2 | import ( |
3 | "bytes" | |
4 | "database/sql" | |
5 | "database/sql/driver" | |
3 | 6 | "errors" |
4 | 7 | "fmt" |
8 | "reflect" | |
5 | 9 | "regexp" |
6 | 10 | "strings" |
7 | 11 | "time" |
8 | ||
9 | "reflect" | |
10 | 12 | ) |
11 | 13 | |
14 | // Scope contain current operation's information when you perform any operation on the database | |
12 | 15 | type Scope struct { |
13 | 16 | Search *search |
14 | 17 | Value interface{} |
15 | Sql string | |
16 | SqlVars []interface{} | |
18 | SQL string | |
19 | SQLVars []interface{} | |
17 | 20 | db *DB |
18 | indirectValue *reflect.Value | |
19 | instanceId string | |
21 | instanceID string | |
20 | 22 | primaryKeyField *Field |
21 | 23 | skipLeft bool |
22 | fields map[string]*Field | |
24 | fields *[]*Field | |
23 | 25 | selectAttrs *[]string |
24 | 26 | } |
25 | 27 | |
28 | // IndirectValue return scope's reflect value's indirect value | |
26 | 29 | func (scope *Scope) IndirectValue() reflect.Value { |
27 | if scope.indirectValue == nil { | |
28 | value := reflect.Indirect(reflect.ValueOf(scope.Value)) | |
29 | if value.Kind() == reflect.Ptr { | |
30 | value = value.Elem() | |
31 | } | |
32 | scope.indirectValue = &value | |
33 | } | |
34 | return *scope.indirectValue | |
35 | } | |
36 | ||
37 | func (scope *Scope) NeedPtr() *Scope { | |
38 | reflectKind := reflect.ValueOf(scope.Value).Kind() | |
39 | if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) { | |
40 | err := fmt.Errorf("%v %v\n", fileWithLineNum(), "using unaddressable value") | |
41 | scope.Err(err) | |
42 | fmt.Printf(err.Error()) | |
43 | } | |
44 | return scope | |
30 | return indirect(reflect.ValueOf(scope.Value)) | |
45 | 31 | } |
46 | 32 | |
47 | 33 | // New create a new Scope without search information |
48 | 34 | func (scope *Scope) New(value interface{}) *Scope { |
49 | 35 | return &Scope{db: scope.NewDB(), Search: &search{}, Value: value} |
36 | } | |
37 | ||
38 | //////////////////////////////////////////////////////////////////////////////// | |
39 | // Scope DB | |
40 | //////////////////////////////////////////////////////////////////////////////// | |
41 | ||
42 | // DB return scope's DB connection | |
43 | func (scope *Scope) DB() *DB { | |
44 | return scope.db | |
50 | 45 | } |
51 | 46 | |
52 | 47 | // NewDB create a new DB without search information |
60 | 55 | return nil |
61 | 56 | } |
62 | 57 | |
63 | func (scope *Scope) DB() *DB { | |
64 | return scope.db | |
65 | } | |
66 | ||
67 | // SqlDB return *sql.DB | |
68 | func (scope *Scope) SqlDB() sqlCommon { | |
58 | // SQLDB return *sql.DB | |
59 | func (scope *Scope) SQLDB() SQLCommon { | |
69 | 60 | return scope.db.db |
70 | 61 | } |
71 | 62 | |
72 | // SkipLeft skip remaining callbacks | |
73 | func (scope *Scope) SkipLeft() { | |
74 | scope.skipLeft = true | |
75 | } | |
76 | ||
77 | // Quote used to quote database column name according to database dialect | |
63 | // Dialect get dialect | |
64 | func (scope *Scope) Dialect() Dialect { | |
65 | return scope.db.parent.dialect | |
66 | } | |
67 | ||
68 | // Quote used to quote string to escape them for database | |
78 | 69 | func (scope *Scope) Quote(str string) string { |
79 | 70 | if strings.Index(str, ".") != -1 { |
80 | 71 | newStrs := []string{} |
82 | 73 | newStrs = append(newStrs, scope.Dialect().Quote(str)) |
83 | 74 | } |
84 | 75 | return strings.Join(newStrs, ".") |
85 | } else { | |
86 | return scope.Dialect().Quote(str) | |
87 | } | |
88 | } | |
89 | ||
90 | func (scope *Scope) QuoteIfPossible(str string) string { | |
91 | if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) { | |
92 | return scope.Quote(str) | |
93 | } | |
94 | return str | |
95 | } | |
96 | ||
97 | // Dialect get dialect | |
98 | func (scope *Scope) Dialect() Dialect { | |
99 | return scope.db.parent.dialect | |
100 | } | |
101 | ||
102 | // Err write error | |
76 | } | |
77 | ||
78 | return scope.Dialect().Quote(str) | |
79 | } | |
80 | ||
81 | // Err add error to Scope | |
103 | 82 | func (scope *Scope) Err(err error) error { |
104 | 83 | if err != nil { |
105 | 84 | scope.db.AddError(err) |
107 | 86 | return err |
108 | 87 | } |
109 | 88 | |
89 | // HasError check if there are any error | |
90 | func (scope *Scope) HasError() bool { | |
91 | return scope.db.Error != nil | |
92 | } | |
93 | ||
110 | 94 | // Log print log message |
111 | 95 | func (scope *Scope) Log(v ...interface{}) { |
112 | 96 | scope.db.log(v...) |
113 | 97 | } |
114 | 98 | |
115 | // HasError check if there are any error | |
116 | func (scope *Scope) HasError() bool { | |
117 | return scope.db.Error != nil | |
118 | } | |
119 | ||
120 | func (scope *Scope) PrimaryFields() []*Field { | |
121 | var fields = []*Field{} | |
122 | for _, field := range scope.GetModelStruct().PrimaryFields { | |
123 | fields = append(fields, scope.Fields()[field.DBName]) | |
99 | // SkipLeft skip remaining callbacks | |
100 | func (scope *Scope) SkipLeft() { | |
101 | scope.skipLeft = true | |
102 | } | |
103 | ||
104 | // Fields get value's fields | |
105 | func (scope *Scope) Fields() []*Field { | |
106 | if scope.fields == nil { | |
107 | var ( | |
108 | fields []*Field | |
109 | indirectScopeValue = scope.IndirectValue() | |
110 | isStruct = indirectScopeValue.Kind() == reflect.Struct | |
111 | ) | |
112 | ||
113 | for _, structField := range scope.GetModelStruct().StructFields { | |
114 | if isStruct { | |
115 | fieldValue := indirectScopeValue | |
116 | for _, name := range structField.Names { | |
117 | if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { | |
118 | fieldValue.Set(reflect.New(fieldValue.Type().Elem())) | |
119 | } | |
120 | fieldValue = reflect.Indirect(fieldValue).FieldByName(name) | |
121 | } | |
122 | fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}) | |
123 | } else { | |
124 | fields = append(fields, &Field{StructField: structField, IsBlank: true}) | |
125 | } | |
126 | } | |
127 | scope.fields = &fields | |
128 | } | |
129 | ||
130 | return *scope.fields | |
131 | } | |
132 | ||
133 | // FieldByName find `gorm.Field` with field name or db name | |
134 | func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { | |
135 | var ( | |
136 | dbName = ToDBName(name) | |
137 | mostMatchedField *Field | |
138 | ) | |
139 | ||
140 | for _, field := range scope.Fields() { | |
141 | if field.Name == name || field.DBName == name { | |
142 | return field, true | |
143 | } | |
144 | if field.DBName == dbName { | |
145 | mostMatchedField = field | |
146 | } | |
147 | } | |
148 | return mostMatchedField, mostMatchedField != nil | |
149 | } | |
150 | ||
151 | // PrimaryFields return scope's primary fields | |
152 | func (scope *Scope) PrimaryFields() (fields []*Field) { | |
153 | for _, field := range scope.Fields() { | |
154 | if field.IsPrimaryKey { | |
155 | fields = append(fields, field) | |
156 | } | |
124 | 157 | } |
125 | 158 | return fields |
126 | 159 | } |
127 | 160 | |
161 | // PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one | |
128 | 162 | func (scope *Scope) PrimaryField() *Field { |
129 | 163 | if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { |
130 | 164 | if len(primaryFields) > 1 { |
131 | if field, ok := scope.Fields()["id"]; ok { | |
165 | if field, ok := scope.FieldByName("id"); ok { | |
132 | 166 | return field |
133 | 167 | } |
134 | 168 | } |
135 | return scope.Fields()[primaryFields[0].DBName] | |
169 | return scope.PrimaryFields()[0] | |
136 | 170 | } |
137 | 171 | return nil |
138 | 172 | } |
139 | 173 | |
140 | // PrimaryKey get the primary key's column name | |
174 | // PrimaryKey get main primary field's db name | |
141 | 175 | func (scope *Scope) PrimaryKey() string { |
142 | 176 | if field := scope.PrimaryField(); field != nil { |
143 | 177 | return field.DBName |
145 | 179 | return "" |
146 | 180 | } |
147 | 181 | |
148 | // PrimaryKeyZero check the primary key is blank or not | |
182 | // PrimaryKeyZero check main primary field's value is blank or not | |
149 | 183 | func (scope *Scope) PrimaryKeyZero() bool { |
150 | 184 | field := scope.PrimaryField() |
151 | 185 | return field == nil || field.IsBlank |
169 | 203 | return false |
170 | 204 | } |
171 | 205 | |
172 | // SetColumn to set the column's value | |
206 | // SetColumn to set the column's value, column could be field or field's name/dbname | |
173 | 207 | func (scope *Scope) SetColumn(column interface{}, value interface{}) error { |
208 | var updateAttrs = map[string]interface{}{} | |
209 | if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { | |
210 | updateAttrs = attrs.(map[string]interface{}) | |
211 | defer scope.InstanceSet("gorm:update_attrs", updateAttrs) | |
212 | } | |
213 | ||
174 | 214 | if field, ok := column.(*Field); ok { |
215 | updateAttrs[field.DBName] = value | |
175 | 216 | return field.Set(value) |
176 | 217 | } else if name, ok := column.(string); ok { |
177 | ||
178 | if field, ok := scope.Fields()[name]; ok { | |
179 | return field.Set(value) | |
180 | } | |
181 | ||
182 | dbName := ToDBName(name) | |
183 | if field, ok := scope.Fields()[dbName]; ok { | |
184 | return field.Set(value) | |
185 | } | |
186 | ||
187 | if field, ok := scope.FieldByName(name); ok { | |
188 | return field.Set(value) | |
218 | var ( | |
219 | dbName = ToDBName(name) | |
220 | mostMatchedField *Field | |
221 | ) | |
222 | for _, field := range scope.Fields() { | |
223 | if field.DBName == value { | |
224 | updateAttrs[field.DBName] = value | |
225 | return field.Set(value) | |
226 | } | |
227 | if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) { | |
228 | mostMatchedField = field | |
229 | } | |
230 | } | |
231 | ||
232 | if mostMatchedField != nil { | |
233 | updateAttrs[mostMatchedField.DBName] = value | |
234 | return mostMatchedField.Set(value) | |
189 | 235 | } |
190 | 236 | } |
191 | 237 | return errors.New("could not convert column to field") |
192 | 238 | } |
193 | 239 | |
194 | func (scope *Scope) CallMethod(name string, checkError bool) { | |
195 | if scope.Value == nil || (checkError && scope.HasError()) { | |
240 | // CallMethod call scope value's method, if it is a slice, will call its element's method one by one | |
241 | func (scope *Scope) CallMethod(methodName string) { | |
242 | if scope.Value == nil { | |
196 | 243 | return |
197 | 244 | } |
198 | 245 | |
199 | call := func(value interface{}) { | |
200 | if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() { | |
201 | switch f := fm.Interface().(type) { | |
202 | case func(): | |
203 | f() | |
204 | case func(s *Scope): | |
205 | f(scope) | |
206 | case func(s *DB): | |
207 | newDB := scope.NewDB() | |
208 | f(newDB) | |
209 | scope.Err(newDB.Error) | |
210 | case func() error: | |
211 | scope.Err(f()) | |
212 | case func(s *Scope) error: | |
213 | scope.Err(f(scope)) | |
214 | case func(s *DB) error: | |
215 | newDB := scope.NewDB() | |
216 | scope.Err(f(newDB)) | |
217 | scope.Err(newDB.Error) | |
218 | default: | |
219 | scope.Err(fmt.Errorf("unsupported function %v", name)) | |
220 | } | |
221 | } | |
222 | } | |
223 | ||
224 | if values := scope.IndirectValue(); values.Kind() == reflect.Slice { | |
225 | for i := 0; i < values.Len(); i++ { | |
226 | value := values.Index(i).Addr().Interface() | |
227 | if values.Index(i).Kind() == reflect.Ptr { | |
228 | value = values.Index(i).Interface() | |
229 | } | |
230 | call(value) | |
246 | if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice { | |
247 | for i := 0; i < indirectScopeValue.Len(); i++ { | |
248 | scope.callMethod(methodName, indirectScopeValue.Index(i)) | |
231 | 249 | } |
232 | 250 | } else { |
233 | if scope.IndirectValue().CanAddr() { | |
234 | call(scope.IndirectValue().Addr().Interface()) | |
235 | } else { | |
236 | call(scope.IndirectValue().Interface()) | |
237 | } | |
238 | } | |
239 | } | |
240 | ||
241 | func (scope *Scope) CallMethodWithErrorCheck(name string) { | |
242 | scope.CallMethod(name, true) | |
243 | } | |
244 | ||
245 | // AddToVars add value as sql's vars, gorm will escape them | |
251 | scope.callMethod(methodName, indirectScopeValue) | |
252 | } | |
253 | } | |
254 | ||
255 | // AddToVars add value as sql's vars, used to prevent SQL injection | |
246 | 256 | func (scope *Scope) AddToVars(value interface{}) string { |
257 | _, skipBindVar := scope.InstanceGet("skip_bindvar") | |
258 | ||
247 | 259 | if expr, ok := value.(*expr); ok { |
248 | 260 | exp := expr.expr |
249 | 261 | for _, arg := range expr.args { |
250 | exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) | |
262 | if skipBindVar { | |
263 | scope.AddToVars(arg) | |
264 | } else { | |
265 | exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) | |
266 | } | |
251 | 267 | } |
252 | 268 | return exp |
253 | } else { | |
254 | scope.SqlVars = append(scope.SqlVars, value) | |
255 | return scope.Dialect().BinVar(len(scope.SqlVars)) | |
256 | } | |
257 | } | |
258 | ||
259 | type tabler interface { | |
260 | TableName() string | |
261 | } | |
262 | ||
263 | type dbTabler interface { | |
264 | TableName(*DB) string | |
265 | } | |
266 | ||
267 | // TableName get table name | |
268 | func (scope *Scope) TableName() string { | |
269 | if scope.Search != nil && len(scope.Search.tableName) > 0 { | |
270 | return scope.Search.tableName | |
271 | } | |
272 | ||
273 | if tabler, ok := scope.Value.(tabler); ok { | |
274 | return tabler.TableName() | |
275 | } | |
276 | ||
277 | if tabler, ok := scope.Value.(dbTabler); ok { | |
278 | return tabler.TableName(scope.db) | |
279 | } | |
280 | ||
281 | return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) | |
282 | } | |
283 | ||
284 | func (scope *Scope) QuotedTableName() (name string) { | |
285 | if scope.Search != nil && len(scope.Search.tableName) > 0 { | |
286 | if strings.Index(scope.Search.tableName, " ") != -1 { | |
287 | return scope.Search.tableName | |
288 | } | |
289 | return scope.Quote(scope.Search.tableName) | |
290 | } else { | |
291 | return scope.Quote(scope.TableName()) | |
292 | } | |
293 | } | |
294 | ||
295 | // CombinedConditionSql get combined condition sql | |
296 | func (scope *Scope) CombinedConditionSql() string { | |
297 | return scope.joinsSql() + scope.whereSql() + scope.groupSql() + | |
298 | scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql() | |
299 | } | |
300 | ||
301 | func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { | |
302 | for _, field := range scope.Fields() { | |
303 | if field.Name == name || field.DBName == name { | |
304 | return field, true | |
305 | } | |
306 | } | |
307 | return nil, false | |
308 | } | |
309 | ||
310 | // Raw set sql | |
311 | func (scope *Scope) Raw(sql string) *Scope { | |
312 | scope.Sql = strings.Replace(sql, "$$", "?", -1) | |
313 | return scope | |
314 | } | |
315 | ||
316 | // Exec invoke sql | |
317 | func (scope *Scope) Exec() *Scope { | |
318 | defer scope.Trace(NowFunc()) | |
319 | ||
320 | if !scope.HasError() { | |
321 | if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { | |
322 | if count, err := result.RowsAffected(); scope.Err(err) == nil { | |
323 | scope.db.RowsAffected = count | |
324 | } | |
325 | } | |
326 | } | |
327 | return scope | |
328 | } | |
329 | ||
330 | // Set set value by name | |
331 | func (scope *Scope) Set(name string, value interface{}) *Scope { | |
332 | scope.db.InstantSet(name, value) | |
333 | return scope | |
334 | } | |
335 | ||
336 | // Get get value by name | |
337 | func (scope *Scope) Get(name string) (interface{}, bool) { | |
338 | return scope.db.Get(name) | |
339 | } | |
340 | ||
341 | // InstanceId get InstanceId for scope | |
342 | func (scope *Scope) InstanceId() string { | |
343 | if scope.instanceId == "" { | |
344 | scope.instanceId = fmt.Sprintf("%v%v", &scope, &scope.db) | |
345 | } | |
346 | return scope.instanceId | |
347 | } | |
348 | ||
349 | func (scope *Scope) InstanceSet(name string, value interface{}) *Scope { | |
350 | return scope.Set(name+scope.InstanceId(), value) | |
351 | } | |
352 | ||
353 | func (scope *Scope) InstanceGet(name string) (interface{}, bool) { | |
354 | return scope.Get(name + scope.InstanceId()) | |
355 | } | |
356 | ||
357 | // Trace print sql log | |
358 | func (scope *Scope) Trace(t time.Time) { | |
359 | if len(scope.Sql) > 0 { | |
360 | scope.db.slog(scope.Sql, t, scope.SqlVars...) | |
361 | } | |
362 | } | |
363 | ||
364 | // Begin start a transaction | |
365 | func (scope *Scope) Begin() *Scope { | |
366 | if db, ok := scope.SqlDB().(sqlDb); ok { | |
367 | if tx, err := db.Begin(); err == nil { | |
368 | scope.db.db = interface{}(tx).(sqlCommon) | |
369 | scope.InstanceSet("gorm:started_transaction", true) | |
370 | } | |
371 | } | |
372 | return scope | |
373 | } | |
374 | ||
375 | // CommitOrRollback commit current transaction if there is no error, otherwise rollback it | |
376 | func (scope *Scope) CommitOrRollback() *Scope { | |
377 | if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { | |
378 | if db, ok := scope.db.db.(sqlTx); ok { | |
379 | if scope.HasError() { | |
380 | db.Rollback() | |
381 | } else { | |
382 | db.Commit() | |
383 | } | |
384 | scope.db.db = scope.db.parent.db | |
385 | } | |
386 | } | |
387 | return scope | |
388 | } | |
389 | ||
269 | } | |
270 | ||
271 | scope.SQLVars = append(scope.SQLVars, value) | |
272 | ||
273 | if skipBindVar { | |
274 | return "?" | |
275 | } | |
276 | return scope.Dialect().BindVar(len(scope.SQLVars)) | |
277 | } | |
278 | ||
279 | // SelectAttrs return selected attributes | |
390 | 280 | func (scope *Scope) SelectAttrs() []string { |
391 | 281 | if scope.selectAttrs == nil { |
392 | 282 | attrs := []string{} |
406 | 296 | return *scope.selectAttrs |
407 | 297 | } |
408 | 298 | |
299 | // OmitAttrs return omitted attributes | |
409 | 300 | func (scope *Scope) OmitAttrs() []string { |
410 | 301 | return scope.Search.omits |
411 | 302 | } |
412 | 303 | |
413 | func (scope *Scope) changeableDBColumn(column string) bool { | |
414 | selectAttrs := scope.SelectAttrs() | |
415 | omitAttrs := scope.OmitAttrs() | |
416 | ||
417 | if len(selectAttrs) > 0 { | |
418 | for _, attr := range selectAttrs { | |
419 | if column == ToDBName(attr) { | |
420 | return true | |
421 | } | |
422 | } | |
423 | return false | |
424 | } | |
425 | ||
426 | for _, attr := range omitAttrs { | |
427 | if column == ToDBName(attr) { | |
428 | return false | |
429 | } | |
430 | } | |
431 | return true | |
304 | type tabler interface { | |
305 | TableName() string | |
306 | } | |
307 | ||
308 | type dbTabler interface { | |
309 | TableName(*DB) string | |
310 | } | |
311 | ||
312 | // TableName return table name | |
313 | func (scope *Scope) TableName() string { | |
314 | if scope.Search != nil && len(scope.Search.tableName) > 0 { | |
315 | return scope.Search.tableName | |
316 | } | |
317 | ||
318 | if tabler, ok := scope.Value.(tabler); ok { | |
319 | return tabler.TableName() | |
320 | } | |
321 | ||
322 | if tabler, ok := scope.Value.(dbTabler); ok { | |
323 | return tabler.TableName(scope.db) | |
324 | } | |
325 | ||
326 | return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) | |
327 | } | |
328 | ||
329 | // QuotedTableName return quoted table name | |
330 | func (scope *Scope) QuotedTableName() (name string) { | |
331 | if scope.Search != nil && len(scope.Search.tableName) > 0 { | |
332 | if strings.Index(scope.Search.tableName, " ") != -1 { | |
333 | return scope.Search.tableName | |
334 | } | |
335 | return scope.Quote(scope.Search.tableName) | |
336 | } | |
337 | ||
338 | return scope.Quote(scope.TableName()) | |
339 | } | |
340 | ||
341 | // CombinedConditionSql return combined condition sql | |
342 | func (scope *Scope) CombinedConditionSql() string { | |
343 | joinSQL := scope.joinsSQL() | |
344 | whereSQL := scope.whereSQL() | |
345 | if scope.Search.raw { | |
346 | whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")") | |
347 | } | |
348 | return joinSQL + whereSQL + scope.groupSQL() + | |
349 | scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() | |
350 | } | |
351 | ||
352 | // Raw set raw sql | |
353 | func (scope *Scope) Raw(sql string) *Scope { | |
354 | scope.SQL = strings.Replace(sql, "$$$", "?", -1) | |
355 | return scope | |
356 | } | |
357 | ||
358 | // Exec perform generated SQL | |
359 | func (scope *Scope) Exec() *Scope { | |
360 | defer scope.trace(NowFunc()) | |
361 | ||
362 | if !scope.HasError() { | |
363 | if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { | |
364 | if count, err := result.RowsAffected(); scope.Err(err) == nil { | |
365 | scope.db.RowsAffected = count | |
366 | } | |
367 | } | |
368 | } | |
369 | return scope | |
370 | } | |
371 | ||
372 | // Set set value by name | |
373 | func (scope *Scope) Set(name string, value interface{}) *Scope { | |
374 | scope.db.InstantSet(name, value) | |
375 | return scope | |
376 | } | |
377 | ||
378 | // Get get setting by name | |
379 | func (scope *Scope) Get(name string) (interface{}, bool) { | |
380 | return scope.db.Get(name) | |
381 | } | |
382 | ||
383 | // InstanceID get InstanceID for scope | |
384 | func (scope *Scope) InstanceID() string { | |
385 | if scope.instanceID == "" { | |
386 | scope.instanceID = fmt.Sprintf("%v%v", &scope, &scope.db) | |
387 | } | |
388 | return scope.instanceID | |
389 | } | |
390 | ||
391 | // InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback | |
392 | func (scope *Scope) InstanceSet(name string, value interface{}) *Scope { | |
393 | return scope.Set(name+scope.InstanceID(), value) | |
394 | } | |
395 | ||
396 | // InstanceGet get instance setting from current operation | |
397 | func (scope *Scope) InstanceGet(name string) (interface{}, bool) { | |
398 | return scope.Get(name + scope.InstanceID()) | |
399 | } | |
400 | ||
401 | // Begin start a transaction | |
402 | func (scope *Scope) Begin() *Scope { | |
403 | if db, ok := scope.SQLDB().(sqlDb); ok { | |
404 | if tx, err := db.Begin(); err == nil { | |
405 | scope.db.db = interface{}(tx).(SQLCommon) | |
406 | scope.InstanceSet("gorm:started_transaction", true) | |
407 | } | |
408 | } | |
409 | return scope | |
410 | } | |
411 | ||
412 | // CommitOrRollback commit current transaction if no error happened, otherwise will rollback it | |
413 | func (scope *Scope) CommitOrRollback() *Scope { | |
414 | if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { | |
415 | if db, ok := scope.db.db.(sqlTx); ok { | |
416 | if scope.HasError() { | |
417 | db.Rollback() | |
418 | } else { | |
419 | scope.Err(db.Commit()) | |
420 | } | |
421 | scope.db.db = scope.db.parent.db | |
422 | } | |
423 | } | |
424 | return scope | |
425 | } | |
426 | ||
427 | //////////////////////////////////////////////////////////////////////////////// | |
428 | // Private Methods For *gorm.Scope | |
429 | //////////////////////////////////////////////////////////////////////////////// | |
430 | ||
431 | func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { | |
432 | // Only get address from non-pointer | |
433 | if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr { | |
434 | reflectValue = reflectValue.Addr() | |
435 | } | |
436 | ||
437 | if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() { | |
438 | switch method := methodValue.Interface().(type) { | |
439 | case func(): | |
440 | method() | |
441 | case func(*Scope): | |
442 | method(scope) | |
443 | case func(*DB): | |
444 | newDB := scope.NewDB() | |
445 | method(newDB) | |
446 | scope.Err(newDB.Error) | |
447 | case func() error: | |
448 | scope.Err(method()) | |
449 | case func(*Scope) error: | |
450 | scope.Err(method(scope)) | |
451 | case func(*DB) error: | |
452 | newDB := scope.NewDB() | |
453 | scope.Err(method(newDB)) | |
454 | scope.Err(newDB.Error) | |
455 | default: | |
456 | scope.Err(fmt.Errorf("unsupported function %v", methodName)) | |
457 | } | |
458 | } | |
459 | } | |
460 | ||
461 | var ( | |
462 | columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name` | |
463 | isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number | |
464 | comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ") | |
465 | countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") | |
466 | ) | |
467 | ||
468 | func (scope *Scope) quoteIfPossible(str string) string { | |
469 | if columnRegexp.MatchString(str) { | |
470 | return scope.Quote(str) | |
471 | } | |
472 | return str | |
473 | } | |
474 | ||
475 | func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { | |
476 | var ( | |
477 | ignored interface{} | |
478 | values = make([]interface{}, len(columns)) | |
479 | selectFields []*Field | |
480 | selectedColumnsMap = map[string]int{} | |
481 | resetFields = map[int]*Field{} | |
482 | ) | |
483 | ||
484 | for index, column := range columns { | |
485 | values[index] = &ignored | |
486 | ||
487 | selectFields = fields | |
488 | if idx, ok := selectedColumnsMap[column]; ok { | |
489 | selectFields = selectFields[idx+1:] | |
490 | } | |
491 | ||
492 | for fieldIndex, field := range selectFields { | |
493 | if field.DBName == column { | |
494 | if field.Field.Kind() == reflect.Ptr { | |
495 | values[index] = field.Field.Addr().Interface() | |
496 | } else { | |
497 | reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) | |
498 | reflectValue.Elem().Set(field.Field.Addr()) | |
499 | values[index] = reflectValue.Interface() | |
500 | resetFields[index] = field | |
501 | } | |
502 | ||
503 | selectedColumnsMap[column] = fieldIndex | |
504 | ||
505 | if field.IsNormal { | |
506 | break | |
507 | } | |
508 | } | |
509 | } | |
510 | } | |
511 | ||
512 | scope.Err(rows.Scan(values...)) | |
513 | ||
514 | for index, field := range resetFields { | |
515 | if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { | |
516 | field.Field.Set(v) | |
517 | } | |
518 | } | |
519 | } | |
520 | ||
521 | func (scope *Scope) primaryCondition(value interface{}) string { | |
522 | return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value) | |
523 | } | |
524 | ||
525 | func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) { | |
526 | var ( | |
527 | quotedTableName = scope.QuotedTableName() | |
528 | quotedPrimaryKey = scope.Quote(scope.PrimaryKey()) | |
529 | equalSQL = "=" | |
530 | inSQL = "IN" | |
531 | ) | |
532 | ||
533 | // If building not conditions | |
534 | if !include { | |
535 | equalSQL = "<>" | |
536 | inSQL = "NOT IN" | |
537 | } | |
538 | ||
539 | switch value := clause["query"].(type) { | |
540 | case sql.NullInt64: | |
541 | return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64) | |
542 | case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: | |
543 | return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value) | |
544 | case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: | |
545 | if !include && reflect.ValueOf(value).Len() == 0 { | |
546 | return | |
547 | } | |
548 | str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL) | |
549 | clause["args"] = []interface{}{value} | |
550 | case string: | |
551 | if isNumberRegexp.MatchString(value) { | |
552 | return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value)) | |
553 | } | |
554 | ||
555 | if value != "" { | |
556 | if !include { | |
557 | if comparisonRegexp.MatchString(value) { | |
558 | str = fmt.Sprintf("NOT (%v)", value) | |
559 | } else { | |
560 | str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value)) | |
561 | } | |
562 | } else { | |
563 | str = fmt.Sprintf("(%v)", value) | |
564 | } | |
565 | } | |
566 | case map[string]interface{}: | |
567 | var sqls []string | |
568 | for key, value := range value { | |
569 | if value != nil { | |
570 | sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value))) | |
571 | } else { | |
572 | if !include { | |
573 | sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key))) | |
574 | } else { | |
575 | sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key))) | |
576 | } | |
577 | } | |
578 | } | |
579 | return strings.Join(sqls, " AND ") | |
580 | case interface{}: | |
581 | var sqls []string | |
582 | newScope := scope.New(value) | |
583 | ||
584 | if len(newScope.Fields()) == 0 { | |
585 | scope.Err(fmt.Errorf("invalid query condition: %v", value)) | |
586 | return | |
587 | } | |
588 | ||
589 | for _, field := range newScope.Fields() { | |
590 | if !field.IsIgnored && !field.IsBlank { | |
591 | sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) | |
592 | } | |
593 | } | |
594 | return strings.Join(sqls, " AND ") | |
595 | default: | |
596 | scope.Err(fmt.Errorf("invalid query condition: %v", value)) | |
597 | return | |
598 | } | |
599 | ||
600 | replacements := []string{} | |
601 | args := clause["args"].([]interface{}) | |
602 | for _, arg := range args { | |
603 | var err error | |
604 | switch reflect.ValueOf(arg).Kind() { | |
605 | case reflect.Slice: // For where("id in (?)", []int64{1,2}) | |
606 | if scanner, ok := interface{}(arg).(driver.Valuer); ok { | |
607 | arg, err = scanner.Value() | |
608 | replacements = append(replacements, scope.AddToVars(arg)) | |
609 | } else if b, ok := arg.([]byte); ok { | |
610 | replacements = append(replacements, scope.AddToVars(b)) | |
611 | } else if as, ok := arg.([][]interface{}); ok { | |
612 | var tempMarks []string | |
613 | for _, a := range as { | |
614 | var arrayMarks []string | |
615 | for _, v := range a { | |
616 | arrayMarks = append(arrayMarks, scope.AddToVars(v)) | |
617 | } | |
618 | ||
619 | if len(arrayMarks) > 0 { | |
620 | tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ","))) | |
621 | } | |
622 | } | |
623 | ||
624 | if len(tempMarks) > 0 { | |
625 | replacements = append(replacements, strings.Join(tempMarks, ",")) | |
626 | } | |
627 | } else if values := reflect.ValueOf(arg); values.Len() > 0 { | |
628 | var tempMarks []string | |
629 | for i := 0; i < values.Len(); i++ { | |
630 | tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) | |
631 | } | |
632 | replacements = append(replacements, strings.Join(tempMarks, ",")) | |
633 | } else { | |
634 | replacements = append(replacements, scope.AddToVars(Expr("NULL"))) | |
635 | } | |
636 | default: | |
637 | if valuer, ok := interface{}(arg).(driver.Valuer); ok { | |
638 | arg, err = valuer.Value() | |
639 | } | |
640 | ||
641 | replacements = append(replacements, scope.AddToVars(arg)) | |
642 | } | |
643 | ||
644 | if err != nil { | |
645 | scope.Err(err) | |
646 | } | |
647 | } | |
648 | ||
649 | buff := bytes.NewBuffer([]byte{}) | |
650 | i := 0 | |
651 | for _, s := range str { | |
652 | if s == '?' { | |
653 | buff.WriteString(replacements[i]) | |
654 | i++ | |
655 | } else { | |
656 | buff.WriteRune(s) | |
657 | } | |
658 | } | |
659 | ||
660 | str = buff.String() | |
661 | ||
662 | return | |
663 | } | |
664 | ||
665 | func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { | |
666 | switch value := clause["query"].(type) { | |
667 | case string: | |
668 | str = value | |
669 | case []string: | |
670 | str = strings.Join(value, ", ") | |
671 | } | |
672 | ||
673 | args := clause["args"].([]interface{}) | |
674 | replacements := []string{} | |
675 | for _, arg := range args { | |
676 | switch reflect.ValueOf(arg).Kind() { | |
677 | case reflect.Slice: | |
678 | values := reflect.ValueOf(arg) | |
679 | var tempMarks []string | |
680 | for i := 0; i < values.Len(); i++ { | |
681 | tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) | |
682 | } | |
683 | replacements = append(replacements, strings.Join(tempMarks, ",")) | |
684 | default: | |
685 | if valuer, ok := interface{}(arg).(driver.Valuer); ok { | |
686 | arg, _ = valuer.Value() | |
687 | } | |
688 | replacements = append(replacements, scope.AddToVars(arg)) | |
689 | } | |
690 | } | |
691 | ||
692 | buff := bytes.NewBuffer([]byte{}) | |
693 | i := 0 | |
694 | for pos := range str { | |
695 | if str[pos] == '?' { | |
696 | buff.WriteString(replacements[i]) | |
697 | i++ | |
698 | } else { | |
699 | buff.WriteByte(str[pos]) | |
700 | } | |
701 | } | |
702 | ||
703 | str = buff.String() | |
704 | ||
705 | return | |
706 | } | |
707 | ||
708 | func (scope *Scope) whereSQL() (sql string) { | |
709 | var ( | |
710 | quotedTableName = scope.QuotedTableName() | |
711 | deletedAtField, hasDeletedAtField = scope.FieldByName("DeletedAt") | |
712 | primaryConditions, andConditions, orConditions []string | |
713 | ) | |
714 | ||
715 | if !scope.Search.Unscoped && hasDeletedAtField { | |
716 | sql := fmt.Sprintf("%v.%v IS NULL", quotedTableName, scope.Quote(deletedAtField.DBName)) | |
717 | primaryConditions = append(primaryConditions, sql) | |
718 | } | |
719 | ||
720 | if !scope.PrimaryKeyZero() { | |
721 | for _, field := range scope.PrimaryFields() { | |
722 | sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) | |
723 | primaryConditions = append(primaryConditions, sql) | |
724 | } | |
725 | } | |
726 | ||
727 | for _, clause := range scope.Search.whereConditions { | |
728 | if sql := scope.buildCondition(clause, true); sql != "" { | |
729 | andConditions = append(andConditions, sql) | |
730 | } | |
731 | } | |
732 | ||
733 | for _, clause := range scope.Search.orConditions { | |
734 | if sql := scope.buildCondition(clause, true); sql != "" { | |
735 | orConditions = append(orConditions, sql) | |
736 | } | |
737 | } | |
738 | ||
739 | for _, clause := range scope.Search.notConditions { | |
740 | if sql := scope.buildCondition(clause, false); sql != "" { | |
741 | andConditions = append(andConditions, sql) | |
742 | } | |
743 | } | |
744 | ||
745 | orSQL := strings.Join(orConditions, " OR ") | |
746 | combinedSQL := strings.Join(andConditions, " AND ") | |
747 | if len(combinedSQL) > 0 { | |
748 | if len(orSQL) > 0 { | |
749 | combinedSQL = combinedSQL + " OR " + orSQL | |
750 | } | |
751 | } else { | |
752 | combinedSQL = orSQL | |
753 | } | |
754 | ||
755 | if len(primaryConditions) > 0 { | |
756 | sql = "WHERE " + strings.Join(primaryConditions, " AND ") | |
757 | if len(combinedSQL) > 0 { | |
758 | sql = sql + " AND (" + combinedSQL + ")" | |
759 | } | |
760 | } else if len(combinedSQL) > 0 { | |
761 | sql = "WHERE " + combinedSQL | |
762 | } | |
763 | return | |
764 | } | |
765 | ||
766 | func (scope *Scope) selectSQL() string { | |
767 | if len(scope.Search.selects) == 0 { | |
768 | if len(scope.Search.joinConditions) > 0 { | |
769 | return fmt.Sprintf("%v.*", scope.QuotedTableName()) | |
770 | } | |
771 | return "*" | |
772 | } | |
773 | return scope.buildSelectQuery(scope.Search.selects) | |
774 | } | |
775 | ||
776 | func (scope *Scope) orderSQL() string { | |
777 | if len(scope.Search.orders) == 0 || scope.Search.ignoreOrderQuery { | |
778 | return "" | |
779 | } | |
780 | ||
781 | var orders []string | |
782 | for _, order := range scope.Search.orders { | |
783 | if str, ok := order.(string); ok { | |
784 | orders = append(orders, scope.quoteIfPossible(str)) | |
785 | } else if expr, ok := order.(*expr); ok { | |
786 | exp := expr.expr | |
787 | for _, arg := range expr.args { | |
788 | exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) | |
789 | } | |
790 | orders = append(orders, exp) | |
791 | } | |
792 | } | |
793 | return " ORDER BY " + strings.Join(orders, ",") | |
794 | } | |
795 | ||
796 | func (scope *Scope) limitAndOffsetSQL() string { | |
797 | return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) | |
798 | } | |
799 | ||
800 | func (scope *Scope) groupSQL() string { | |
801 | if len(scope.Search.group) == 0 { | |
802 | return "" | |
803 | } | |
804 | return " GROUP BY " + scope.Search.group | |
805 | } | |
806 | ||
807 | func (scope *Scope) havingSQL() string { | |
808 | if len(scope.Search.havingConditions) == 0 { | |
809 | return "" | |
810 | } | |
811 | ||
812 | var andConditions []string | |
813 | for _, clause := range scope.Search.havingConditions { | |
814 | if sql := scope.buildCondition(clause, true); sql != "" { | |
815 | andConditions = append(andConditions, sql) | |
816 | } | |
817 | } | |
818 | ||
819 | combinedSQL := strings.Join(andConditions, " AND ") | |
820 | if len(combinedSQL) == 0 { | |
821 | return "" | |
822 | } | |
823 | ||
824 | return " HAVING " + combinedSQL | |
825 | } | |
826 | ||
827 | func (scope *Scope) joinsSQL() string { | |
828 | var joinConditions []string | |
829 | for _, clause := range scope.Search.joinConditions { | |
830 | if sql := scope.buildCondition(clause, true); sql != "" { | |
831 | joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) | |
832 | } | |
833 | } | |
834 | ||
835 | return strings.Join(joinConditions, " ") + " " | |
836 | } | |
837 | ||
838 | func (scope *Scope) prepareQuerySQL() { | |
839 | if scope.Search.raw { | |
840 | scope.Raw(scope.CombinedConditionSql()) | |
841 | } else { | |
842 | scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) | |
843 | } | |
844 | return | |
845 | } | |
846 | ||
847 | func (scope *Scope) inlineCondition(values ...interface{}) *Scope { | |
848 | if len(values) > 0 { | |
849 | scope.Search.Where(values[0], values[1:]...) | |
850 | } | |
851 | return scope | |
852 | } | |
853 | ||
854 | func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { | |
855 | for _, f := range funcs { | |
856 | (*f)(scope) | |
857 | if scope.skipLeft { | |
858 | break | |
859 | } | |
860 | } | |
861 | return scope | |
862 | } | |
863 | ||
864 | func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string]interface{} { | |
865 | var attrs = map[string]interface{}{} | |
866 | ||
867 | switch value := values.(type) { | |
868 | case map[string]interface{}: | |
869 | return value | |
870 | case []interface{}: | |
871 | for _, v := range value { | |
872 | for key, value := range convertInterfaceToMap(v, withIgnoredField) { | |
873 | attrs[key] = value | |
874 | } | |
875 | } | |
876 | case interface{}: | |
877 | reflectValue := reflect.ValueOf(values) | |
878 | ||
879 | switch reflectValue.Kind() { | |
880 | case reflect.Map: | |
881 | for _, key := range reflectValue.MapKeys() { | |
882 | attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() | |
883 | } | |
884 | default: | |
885 | for _, field := range (&Scope{Value: values}).Fields() { | |
886 | if !field.IsBlank && (withIgnoredField || !field.IsIgnored) { | |
887 | attrs[field.DBName] = field.Field.Interface() | |
888 | } | |
889 | } | |
890 | } | |
891 | } | |
892 | return attrs | |
893 | } | |
894 | ||
895 | func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) { | |
896 | if scope.IndirectValue().Kind() != reflect.Struct { | |
897 | return convertInterfaceToMap(value, false), true | |
898 | } | |
899 | ||
900 | results = map[string]interface{}{} | |
901 | ||
902 | for key, value := range convertInterfaceToMap(value, true) { | |
903 | if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { | |
904 | if _, ok := value.(*expr); ok { | |
905 | hasUpdate = true | |
906 | results[field.DBName] = value | |
907 | } else { | |
908 | err := field.Set(value) | |
909 | if field.IsNormal { | |
910 | hasUpdate = true | |
911 | if err == ErrUnaddressable { | |
912 | results[field.DBName] = value | |
913 | } else { | |
914 | results[field.DBName] = field.Field.Interface() | |
915 | } | |
916 | } | |
917 | } | |
918 | } | |
919 | } | |
920 | return | |
921 | } | |
922 | ||
923 | func (scope *Scope) row() *sql.Row { | |
924 | defer scope.trace(NowFunc()) | |
925 | ||
926 | result := &RowQueryResult{} | |
927 | scope.InstanceSet("row_query_result", result) | |
928 | scope.callCallbacks(scope.db.parent.callbacks.rowQueries) | |
929 | ||
930 | return result.Row | |
931 | } | |
932 | ||
933 | func (scope *Scope) rows() (*sql.Rows, error) { | |
934 | defer scope.trace(NowFunc()) | |
935 | ||
936 | result := &RowsQueryResult{} | |
937 | scope.InstanceSet("row_query_result", result) | |
938 | scope.callCallbacks(scope.db.parent.callbacks.rowQueries) | |
939 | ||
940 | return result.Rows, result.Error | |
941 | } | |
942 | ||
943 | func (scope *Scope) initialize() *Scope { | |
944 | for _, clause := range scope.Search.whereConditions { | |
945 | scope.updatedAttrsWithValues(clause["query"]) | |
946 | } | |
947 | scope.updatedAttrsWithValues(scope.Search.initAttrs) | |
948 | scope.updatedAttrsWithValues(scope.Search.assignAttrs) | |
949 | return scope | |
950 | } | |
951 | ||
952 | func (scope *Scope) isQueryForColumn(query interface{}, column string) bool { | |
953 | queryStr := strings.ToLower(fmt.Sprint(query)) | |
954 | if queryStr == column { | |
955 | return true | |
956 | } | |
957 | ||
958 | if strings.HasSuffix(queryStr, "as "+column) { | |
959 | return true | |
960 | } | |
961 | ||
962 | if strings.HasSuffix(queryStr, "as "+scope.Quote(column)) { | |
963 | return true | |
964 | } | |
965 | ||
966 | return false | |
967 | } | |
968 | ||
969 | func (scope *Scope) pluck(column string, value interface{}) *Scope { | |
970 | dest := reflect.Indirect(reflect.ValueOf(value)) | |
971 | if dest.Kind() != reflect.Slice { | |
972 | scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) | |
973 | return scope | |
974 | } | |
975 | ||
976 | if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) { | |
977 | scope.Search.Select(column) | |
978 | } | |
979 | ||
980 | rows, err := scope.rows() | |
981 | if scope.Err(err) == nil { | |
982 | defer rows.Close() | |
983 | for rows.Next() { | |
984 | elem := reflect.New(dest.Type().Elem()).Interface() | |
985 | scope.Err(rows.Scan(elem)) | |
986 | dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) | |
987 | } | |
988 | ||
989 | if err := rows.Err(); err != nil { | |
990 | scope.Err(err) | |
991 | } | |
992 | } | |
993 | return scope | |
994 | } | |
995 | ||
996 | func (scope *Scope) count(value interface{}) *Scope { | |
997 | if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { | |
998 | if len(scope.Search.group) != 0 { | |
999 | scope.Search.Select("count(*) FROM ( SELECT count(*) as name ") | |
1000 | scope.Search.group += " ) AS count_table" | |
1001 | } else { | |
1002 | scope.Search.Select("count(*)") | |
1003 | } | |
1004 | } | |
1005 | scope.Search.ignoreOrderQuery = true | |
1006 | scope.Err(scope.row().Scan(value)) | |
1007 | return scope | |
1008 | } | |
1009 | ||
1010 | func (scope *Scope) typeName() string { | |
1011 | typ := scope.IndirectValue().Type() | |
1012 | ||
1013 | for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr { | |
1014 | typ = typ.Elem() | |
1015 | } | |
1016 | ||
1017 | return typ.Name() | |
1018 | } | |
1019 | ||
1020 | // trace print sql log | |
1021 | func (scope *Scope) trace(t time.Time) { | |
1022 | if len(scope.SQL) > 0 { | |
1023 | scope.db.slog(scope.SQL, t, scope.SQLVars...) | |
1024 | } | |
432 | 1025 | } |
433 | 1026 | |
434 | 1027 | func (scope *Scope) changeableField(field *Field) bool { |
435 | selectAttrs := scope.SelectAttrs() | |
436 | omitAttrs := scope.OmitAttrs() | |
437 | ||
438 | if len(selectAttrs) > 0 { | |
1028 | if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 { | |
439 | 1029 | for _, attr := range selectAttrs { |
440 | 1030 | if field.Name == attr || field.DBName == attr { |
441 | 1031 | return true |
444 | 1034 | return false |
445 | 1035 | } |
446 | 1036 | |
447 | for _, attr := range omitAttrs { | |
1037 | for _, attr := range scope.OmitAttrs() { | |
448 | 1038 | if field.Name == attr || field.DBName == attr { |
449 | 1039 | return false |
450 | 1040 | } |
451 | 1041 | } |
452 | 1042 | |
453 | return !field.IsIgnored | |
454 | } | |
455 | ||
456 | func (scope *Scope) shouldSaveAssociations() bool { | |
457 | saveAssociations, ok := scope.Get("gorm:save_associations") | |
458 | if ok && !saveAssociations.(bool) { | |
459 | return false | |
460 | } | |
461 | return true && !scope.HasError() | |
462 | } | |
1043 | return true | |
1044 | } | |
1045 | ||
1046 | func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { | |
1047 | toScope := scope.db.NewScope(value) | |
1048 | tx := scope.db.Set("gorm:association:source", scope.Value) | |
1049 | ||
1050 | for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { | |
1051 | fromField, _ := scope.FieldByName(foreignKey) | |
1052 | toField, _ := toScope.FieldByName(foreignKey) | |
1053 | ||
1054 | if fromField != nil { | |
1055 | if relationship := fromField.Relationship; relationship != nil { | |
1056 | if relationship.Kind == "many_to_many" { | |
1057 | joinTableHandler := relationship.JoinTableHandler | |
1058 | scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Find(value).Error) | |
1059 | } else if relationship.Kind == "belongs_to" { | |
1060 | for idx, foreignKey := range relationship.ForeignDBNames { | |
1061 | if field, ok := scope.FieldByName(foreignKey); ok { | |
1062 | tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) | |
1063 | } | |
1064 | } | |
1065 | scope.Err(tx.Find(value).Error) | |
1066 | } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { | |
1067 | for idx, foreignKey := range relationship.ForeignDBNames { | |
1068 | if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { | |
1069 | tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) | |
1070 | } | |
1071 | } | |
1072 | ||
1073 | if relationship.PolymorphicType != "" { | |
1074 | tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) | |
1075 | } | |
1076 | scope.Err(tx.Find(value).Error) | |
1077 | } | |
1078 | } else { | |
1079 | sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) | |
1080 | scope.Err(tx.Where(sql, fromField.Field.Interface()).Find(value).Error) | |
1081 | } | |
1082 | return scope | |
1083 | } else if toField != nil { | |
1084 | sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) | |
1085 | scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) | |
1086 | return scope | |
1087 | } | |
1088 | } | |
1089 | ||
1090 | scope.Err(fmt.Errorf("invalid association %v", foreignKeys)) | |
1091 | return scope | |
1092 | } | |
1093 | ||
1094 | // getTableOptions return the table options string or an empty string if the table options does not exist | |
1095 | func (scope *Scope) getTableOptions() string { | |
1096 | tableOptions, ok := scope.Get("gorm:table_options") | |
1097 | if !ok { | |
1098 | return "" | |
1099 | } | |
1100 | return " " + tableOptions.(string) | |
1101 | } | |
1102 | ||
1103 | func (scope *Scope) createJoinTable(field *StructField) { | |
1104 | if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { | |
1105 | joinTableHandler := relationship.JoinTableHandler | |
1106 | joinTable := joinTableHandler.Table(scope.db) | |
1107 | if !scope.Dialect().HasTable(joinTable) { | |
1108 | toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} | |
1109 | ||
1110 | var sqlTypes, primaryKeys []string | |
1111 | for idx, fieldName := range relationship.ForeignFieldNames { | |
1112 | if field, ok := scope.FieldByName(fieldName); ok { | |
1113 | foreignKeyStruct := field.clone() | |
1114 | foreignKeyStruct.IsPrimaryKey = false | |
1115 | foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" | |
1116 | delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT") | |
1117 | sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) | |
1118 | primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) | |
1119 | } | |
1120 | } | |
1121 | ||
1122 | for idx, fieldName := range relationship.AssociationForeignFieldNames { | |
1123 | if field, ok := toScope.FieldByName(fieldName); ok { | |
1124 | foreignKeyStruct := field.clone() | |
1125 | foreignKeyStruct.IsPrimaryKey = false | |
1126 | foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" | |
1127 | delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT") | |
1128 | sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) | |
1129 | primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) | |
1130 | } | |
1131 | } | |
1132 | ||
1133 | scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) | |
1134 | } | |
1135 | scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) | |
1136 | } | |
1137 | } | |
1138 | ||
1139 | func (scope *Scope) createTable() *Scope { | |
1140 | var tags []string | |
1141 | var primaryKeys []string | |
1142 | var primaryKeyInColumnType = false | |
1143 | for _, field := range scope.GetModelStruct().StructFields { | |
1144 | if field.IsNormal { | |
1145 | sqlTag := scope.Dialect().DataTypeOf(field) | |
1146 | ||
1147 | // Check if the primary key constraint was specified as | |
1148 | // part of the column type. If so, we can only support | |
1149 | // one column as the primary key. | |
1150 | if strings.Contains(strings.ToLower(sqlTag), "primary key") { | |
1151 | primaryKeyInColumnType = true | |
1152 | } | |
1153 | ||
1154 | tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag) | |
1155 | } | |
1156 | ||
1157 | if field.IsPrimaryKey { | |
1158 | primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) | |
1159 | } | |
1160 | scope.createJoinTable(field) | |
1161 | } | |
1162 | ||
1163 | var primaryKeyStr string | |
1164 | if len(primaryKeys) > 0 && !primaryKeyInColumnType { | |
1165 | primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) | |
1166 | } | |
1167 | ||
1168 | scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() | |
1169 | ||
1170 | scope.autoIndex() | |
1171 | return scope | |
1172 | } | |
1173 | ||
1174 | func (scope *Scope) dropTable() *Scope { | |
1175 | scope.Raw(fmt.Sprintf("DROP TABLE %v%s", scope.QuotedTableName(), scope.getTableOptions())).Exec() | |
1176 | return scope | |
1177 | } | |
1178 | ||
1179 | func (scope *Scope) modifyColumn(column string, typ string) { | |
1180 | scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ)) | |
1181 | } | |
1182 | ||
1183 | func (scope *Scope) dropColumn(column string) { | |
1184 | scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() | |
1185 | } | |
1186 | ||
1187 | func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { | |
1188 | if scope.Dialect().HasIndex(scope.TableName(), indexName) { | |
1189 | return | |
1190 | } | |
1191 | ||
1192 | var columns []string | |
1193 | for _, name := range column { | |
1194 | columns = append(columns, scope.quoteIfPossible(name)) | |
1195 | } | |
1196 | ||
1197 | sqlCreate := "CREATE INDEX" | |
1198 | if unique { | |
1199 | sqlCreate = "CREATE UNIQUE INDEX" | |
1200 | } | |
1201 | ||
1202 | scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() | |
1203 | } | |
1204 | ||
1205 | func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { | |
1206 | // Compatible with old generated key | |
1207 | keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") | |
1208 | ||
1209 | if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { | |
1210 | return | |
1211 | } | |
1212 | var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` | |
1213 | scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() | |
1214 | } | |
1215 | ||
1216 | func (scope *Scope) removeForeignKey(field string, dest string) { | |
1217 | keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest) | |
1218 | ||
1219 | if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { | |
1220 | return | |
1221 | } | |
1222 | var query = `ALTER TABLE %s DROP CONSTRAINT %s;` | |
1223 | scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() | |
1224 | } | |
1225 | ||
1226 | func (scope *Scope) removeIndex(indexName string) { | |
1227 | scope.Dialect().RemoveIndex(scope.TableName(), indexName) | |
1228 | } | |
1229 | ||
1230 | func (scope *Scope) autoMigrate() *Scope { | |
1231 | tableName := scope.TableName() | |
1232 | quotedTableName := scope.QuotedTableName() | |
1233 | ||
1234 | if !scope.Dialect().HasTable(tableName) { | |
1235 | scope.createTable() | |
1236 | } else { | |
1237 | for _, field := range scope.GetModelStruct().StructFields { | |
1238 | if !scope.Dialect().HasColumn(tableName, field.DBName) { | |
1239 | if field.IsNormal { | |
1240 | sqlTag := scope.Dialect().DataTypeOf(field) | |
1241 | scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() | |
1242 | } | |
1243 | } | |
1244 | scope.createJoinTable(field) | |
1245 | } | |
1246 | scope.autoIndex() | |
1247 | } | |
1248 | return scope | |
1249 | } | |
1250 | ||
1251 | func (scope *Scope) autoIndex() *Scope { | |
1252 | var indexes = map[string][]string{} | |
1253 | var uniqueIndexes = map[string][]string{} | |
1254 | ||
1255 | for _, field := range scope.GetStructFields() { | |
1256 | if name, ok := field.TagSettings["INDEX"]; ok { | |
1257 | names := strings.Split(name, ",") | |
1258 | ||
1259 | for _, name := range names { | |
1260 | if name == "INDEX" || name == "" { | |
1261 | name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName) | |
1262 | } | |
1263 | indexes[name] = append(indexes[name], field.DBName) | |
1264 | } | |
1265 | } | |
1266 | ||
1267 | if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok { | |
1268 | names := strings.Split(name, ",") | |
1269 | ||
1270 | for _, name := range names { | |
1271 | if name == "UNIQUE_INDEX" || name == "" { | |
1272 | name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName) | |
1273 | } | |
1274 | uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) | |
1275 | } | |
1276 | } | |
1277 | } | |
1278 | ||
1279 | for name, columns := range indexes { | |
1280 | if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddIndex(name, columns...); db.Error != nil { | |
1281 | scope.db.AddError(db.Error) | |
1282 | } | |
1283 | } | |
1284 | ||
1285 | for name, columns := range uniqueIndexes { | |
1286 | if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil { | |
1287 | scope.db.AddError(db.Error) | |
1288 | } | |
1289 | } | |
1290 | ||
1291 | return scope | |
1292 | } | |
1293 | ||
1294 | func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { | |
1295 | for _, value := range values { | |
1296 | indirectValue := indirect(reflect.ValueOf(value)) | |
1297 | ||
1298 | switch indirectValue.Kind() { | |
1299 | case reflect.Slice: | |
1300 | for i := 0; i < indirectValue.Len(); i++ { | |
1301 | var result []interface{} | |
1302 | var object = indirect(indirectValue.Index(i)) | |
1303 | var hasValue = false | |
1304 | for _, column := range columns { | |
1305 | field := object.FieldByName(column) | |
1306 | if hasValue || !isBlank(field) { | |
1307 | hasValue = true | |
1308 | } | |
1309 | result = append(result, field.Interface()) | |
1310 | } | |
1311 | ||
1312 | if hasValue { | |
1313 | results = append(results, result) | |
1314 | } | |
1315 | } | |
1316 | case reflect.Struct: | |
1317 | var result []interface{} | |
1318 | var hasValue = false | |
1319 | for _, column := range columns { | |
1320 | field := indirectValue.FieldByName(column) | |
1321 | if hasValue || !isBlank(field) { | |
1322 | hasValue = true | |
1323 | } | |
1324 | result = append(result, field.Interface()) | |
1325 | } | |
1326 | ||
1327 | if hasValue { | |
1328 | results = append(results, result) | |
1329 | } | |
1330 | } | |
1331 | } | |
1332 | ||
1333 | return | |
1334 | } | |
1335 | ||
1336 | func (scope *Scope) getColumnAsScope(column string) *Scope { | |
1337 | indirectScopeValue := scope.IndirectValue() | |
1338 | ||
1339 | switch indirectScopeValue.Kind() { | |
1340 | case reflect.Slice: | |
1341 | if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok { | |
1342 | fieldType := fieldStruct.Type | |
1343 | if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr { | |
1344 | fieldType = fieldType.Elem() | |
1345 | } | |
1346 | ||
1347 | resultsMap := map[interface{}]bool{} | |
1348 | results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() | |
1349 | ||
1350 | for i := 0; i < indirectScopeValue.Len(); i++ { | |
1351 | result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column)) | |
1352 | ||
1353 | if result.Kind() == reflect.Slice { | |
1354 | for j := 0; j < result.Len(); j++ { | |
1355 | if elem := result.Index(j); elem.CanAddr() && resultsMap[elem.Addr()] != true { | |
1356 | resultsMap[elem.Addr()] = true | |
1357 | results = reflect.Append(results, elem.Addr()) | |
1358 | } | |
1359 | } | |
1360 | } else if result.CanAddr() && resultsMap[result.Addr()] != true { | |
1361 | resultsMap[result.Addr()] = true | |
1362 | results = reflect.Append(results, result.Addr()) | |
1363 | } | |
1364 | } | |
1365 | return scope.New(results.Interface()) | |
1366 | } | |
1367 | case reflect.Struct: | |
1368 | if field := indirectScopeValue.FieldByName(column); field.CanAddr() { | |
1369 | return scope.New(field.Addr().Interface()) | |
1370 | } | |
1371 | } | |
1372 | return nil | |
1373 | } | |
1374 | ||
1375 | func (scope *Scope) hasConditions() bool { | |
1376 | return !scope.PrimaryKeyZero() || | |
1377 | len(scope.Search.whereConditions) > 0 || | |
1378 | len(scope.Search.orConditions) > 0 || | |
1379 | len(scope.Search.notConditions) > 0 | |
1380 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "database/sql" | |
4 | "database/sql/driver" | |
5 | "fmt" | |
6 | "reflect" | |
7 | "regexp" | |
8 | "strconv" | |
9 | "strings" | |
10 | ) | |
11 | ||
12 | func (scope *Scope) primaryCondition(value interface{}) string { | |
13 | return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value) | |
14 | } | |
15 | ||
16 | func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { | |
17 | switch value := clause["query"].(type) { | |
18 | case string: | |
19 | // if string is number | |
20 | if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { | |
21 | id, _ := strconv.Atoi(value) | |
22 | return scope.primaryCondition(scope.AddToVars(id)) | |
23 | } else if value != "" { | |
24 | str = fmt.Sprintf("(%v)", value) | |
25 | } | |
26 | case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: | |
27 | return scope.primaryCondition(scope.AddToVars(value)) | |
28 | case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: | |
29 | str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey())) | |
30 | clause["args"] = []interface{}{value} | |
31 | case map[string]interface{}: | |
32 | var sqls []string | |
33 | for key, value := range value { | |
34 | sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value))) | |
35 | } | |
36 | return strings.Join(sqls, " AND ") | |
37 | case interface{}: | |
38 | var sqls []string | |
39 | for _, field := range scope.New(value).Fields() { | |
40 | if !field.IsIgnored && !field.IsBlank { | |
41 | sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) | |
42 | } | |
43 | } | |
44 | return strings.Join(sqls, " AND ") | |
45 | } | |
46 | ||
47 | args := clause["args"].([]interface{}) | |
48 | for _, arg := range args { | |
49 | switch reflect.ValueOf(arg).Kind() { | |
50 | case reflect.Slice: // For where("id in (?)", []int64{1,2}) | |
51 | values := reflect.ValueOf(arg) | |
52 | var tempMarks []string | |
53 | for i := 0; i < values.Len(); i++ { | |
54 | tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) | |
55 | } | |
56 | str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) | |
57 | default: | |
58 | if valuer, ok := interface{}(arg).(driver.Valuer); ok { | |
59 | arg, _ = valuer.Value() | |
60 | } | |
61 | ||
62 | str = strings.Replace(str, "?", scope.AddToVars(arg), 1) | |
63 | } | |
64 | } | |
65 | return | |
66 | } | |
67 | ||
68 | func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { | |
69 | var notEqualSql string | |
70 | var primaryKey = scope.PrimaryKey() | |
71 | ||
72 | switch value := clause["query"].(type) { | |
73 | case string: | |
74 | // is number | |
75 | if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { | |
76 | id, _ := strconv.Atoi(value) | |
77 | return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) | |
78 | } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) { | |
79 | str = fmt.Sprintf(" NOT (%v) ", value) | |
80 | notEqualSql = fmt.Sprintf("NOT (%v)", value) | |
81 | } else { | |
82 | str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) | |
83 | notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) | |
84 | } | |
85 | case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: | |
86 | return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value) | |
87 | case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: | |
88 | if reflect.ValueOf(value).Len() > 0 { | |
89 | str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey)) | |
90 | clause["args"] = []interface{}{value} | |
91 | } | |
92 | return "" | |
93 | case map[string]interface{}: | |
94 | var sqls []string | |
95 | for key, value := range value { | |
96 | sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value))) | |
97 | } | |
98 | return strings.Join(sqls, " AND ") | |
99 | case interface{}: | |
100 | var sqls []string | |
101 | for _, field := range scope.New(value).Fields() { | |
102 | if !field.IsBlank { | |
103 | sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) | |
104 | } | |
105 | } | |
106 | return strings.Join(sqls, " AND ") | |
107 | } | |
108 | ||
109 | args := clause["args"].([]interface{}) | |
110 | for _, arg := range args { | |
111 | switch reflect.ValueOf(arg).Kind() { | |
112 | case reflect.Slice: // For where("id in (?)", []int64{1,2}) | |
113 | values := reflect.ValueOf(arg) | |
114 | var tempMarks []string | |
115 | for i := 0; i < values.Len(); i++ { | |
116 | tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) | |
117 | } | |
118 | str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) | |
119 | default: | |
120 | if scanner, ok := interface{}(arg).(driver.Valuer); ok { | |
121 | arg, _ = scanner.Value() | |
122 | } | |
123 | str = strings.Replace(notEqualSql, "?", scope.AddToVars(arg), 1) | |
124 | } | |
125 | } | |
126 | return | |
127 | } | |
128 | ||
129 | func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { | |
130 | switch value := clause["query"].(type) { | |
131 | case string: | |
132 | str = value | |
133 | case []string: | |
134 | str = strings.Join(value, ", ") | |
135 | } | |
136 | ||
137 | args := clause["args"].([]interface{}) | |
138 | for _, arg := range args { | |
139 | switch reflect.ValueOf(arg).Kind() { | |
140 | case reflect.Slice: | |
141 | values := reflect.ValueOf(arg) | |
142 | var tempMarks []string | |
143 | for i := 0; i < values.Len(); i++ { | |
144 | tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) | |
145 | } | |
146 | str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) | |
147 | default: | |
148 | if valuer, ok := interface{}(arg).(driver.Valuer); ok { | |
149 | arg, _ = valuer.Value() | |
150 | } | |
151 | str = strings.Replace(str, "?", scope.AddToVars(arg), 1) | |
152 | } | |
153 | } | |
154 | return | |
155 | } | |
156 | ||
157 | func (scope *Scope) whereSql() (sql string) { | |
158 | var primaryConditions, andConditions, orConditions []string | |
159 | ||
160 | if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil { | |
161 | sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName()) | |
162 | primaryConditions = append(primaryConditions, sql) | |
163 | } | |
164 | ||
165 | if !scope.PrimaryKeyZero() { | |
166 | for _, field := range scope.PrimaryFields() { | |
167 | sql := fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) | |
168 | primaryConditions = append(primaryConditions, sql) | |
169 | } | |
170 | } | |
171 | ||
172 | for _, clause := range scope.Search.whereConditions { | |
173 | if sql := scope.buildWhereCondition(clause); sql != "" { | |
174 | andConditions = append(andConditions, sql) | |
175 | } | |
176 | } | |
177 | ||
178 | for _, clause := range scope.Search.orConditions { | |
179 | if sql := scope.buildWhereCondition(clause); sql != "" { | |
180 | orConditions = append(orConditions, sql) | |
181 | } | |
182 | } | |
183 | ||
184 | for _, clause := range scope.Search.notConditions { | |
185 | if sql := scope.buildNotCondition(clause); sql != "" { | |
186 | andConditions = append(andConditions, sql) | |
187 | } | |
188 | } | |
189 | ||
190 | orSql := strings.Join(orConditions, " OR ") | |
191 | combinedSql := strings.Join(andConditions, " AND ") | |
192 | if len(combinedSql) > 0 { | |
193 | if len(orSql) > 0 { | |
194 | combinedSql = combinedSql + " OR " + orSql | |
195 | } | |
196 | } else { | |
197 | combinedSql = orSql | |
198 | } | |
199 | ||
200 | if len(primaryConditions) > 0 { | |
201 | sql = "WHERE " + strings.Join(primaryConditions, " AND ") | |
202 | if len(combinedSql) > 0 { | |
203 | sql = sql + " AND (" + combinedSql + ")" | |
204 | } | |
205 | } else if len(combinedSql) > 0 { | |
206 | sql = "WHERE " + combinedSql | |
207 | } | |
208 | return | |
209 | } | |
210 | ||
211 | var hasCountRegexp = regexp.MustCompile(`(?i)count(.+)`) | |
212 | ||
213 | func (scope *Scope) selectSql() string { | |
214 | if len(scope.Search.selects) == 0 { | |
215 | if scope.Search.joins != "" { | |
216 | return fmt.Sprintf("%v.*", scope.QuotedTableName()) | |
217 | } | |
218 | return "*" | |
219 | } | |
220 | sql := scope.buildSelectQuery(scope.Search.selects) | |
221 | scope.Search.countingQuery = (len(scope.Search.group) == 0) && hasCountRegexp.MatchString(sql) | |
222 | return sql | |
223 | } | |
224 | ||
225 | func (scope *Scope) orderSql() string { | |
226 | if len(scope.Search.orders) == 0 || scope.Search.countingQuery { | |
227 | return "" | |
228 | } | |
229 | return " ORDER BY " + strings.Join(scope.Search.orders, ",") | |
230 | } | |
231 | ||
232 | func (scope *Scope) limitSql() string { | |
233 | if !scope.Dialect().HasTop() { | |
234 | if len(scope.Search.limit) == 0 { | |
235 | return "" | |
236 | } | |
237 | return " LIMIT " + scope.Search.limit | |
238 | } | |
239 | ||
240 | return "" | |
241 | } | |
242 | ||
243 | func (scope *Scope) topSql() string { | |
244 | if scope.Dialect().HasTop() && len(scope.Search.offset) == 0 { | |
245 | if len(scope.Search.limit) == 0 { | |
246 | return "" | |
247 | } | |
248 | return " TOP(" + scope.Search.limit + ")" | |
249 | } | |
250 | ||
251 | return "" | |
252 | } | |
253 | ||
254 | func (scope *Scope) offsetSql() string { | |
255 | if len(scope.Search.offset) == 0 { | |
256 | return "" | |
257 | } | |
258 | ||
259 | if scope.Dialect().HasTop() { | |
260 | sql := " OFFSET " + scope.Search.offset + " ROW " | |
261 | if len(scope.Search.limit) > 0 { | |
262 | sql += "FETCH NEXT " + scope.Search.limit + " ROWS ONLY" | |
263 | } | |
264 | return sql | |
265 | } | |
266 | return " OFFSET " + scope.Search.offset | |
267 | } | |
268 | ||
269 | func (scope *Scope) groupSql() string { | |
270 | if len(scope.Search.group) == 0 { | |
271 | return "" | |
272 | } | |
273 | return " GROUP BY " + scope.Search.group | |
274 | } | |
275 | ||
276 | func (scope *Scope) havingSql() string { | |
277 | if scope.Search.havingConditions == nil { | |
278 | return "" | |
279 | } | |
280 | ||
281 | var andConditions []string | |
282 | ||
283 | for _, clause := range scope.Search.havingConditions { | |
284 | if sql := scope.buildWhereCondition(clause); sql != "" { | |
285 | andConditions = append(andConditions, sql) | |
286 | } | |
287 | } | |
288 | ||
289 | combinedSql := strings.Join(andConditions, " AND ") | |
290 | if len(combinedSql) == 0 { | |
291 | return "" | |
292 | } | |
293 | ||
294 | return " HAVING " + combinedSql | |
295 | } | |
296 | ||
297 | func (scope *Scope) joinsSql() string { | |
298 | return scope.Search.joins + " " | |
299 | } | |
300 | ||
301 | func (scope *Scope) prepareQuerySql() { | |
302 | if scope.Search.raw { | |
303 | scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")")) | |
304 | } else { | |
305 | scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql())) | |
306 | } | |
307 | return | |
308 | } | |
309 | ||
310 | func (scope *Scope) inlineCondition(values ...interface{}) *Scope { | |
311 | if len(values) > 0 { | |
312 | scope.Search.Where(values[0], values[1:]...) | |
313 | } | |
314 | return scope | |
315 | } | |
316 | ||
317 | func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { | |
318 | for _, f := range funcs { | |
319 | (*f)(scope) | |
320 | if scope.skipLeft { | |
321 | break | |
322 | } | |
323 | } | |
324 | return scope | |
325 | } | |
326 | ||
327 | func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) { | |
328 | if !scope.IndirectValue().CanAddr() { | |
329 | return values, true | |
330 | } | |
331 | ||
332 | var hasExpr bool | |
333 | fields := scope.Fields() | |
334 | for key, value := range values { | |
335 | if field, ok := fields[ToDBName(key)]; ok && field.Field.IsValid() { | |
336 | if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) { | |
337 | if _, ok := value.(*expr); ok { | |
338 | hasExpr = true | |
339 | } else if !equalAsString(field.Field.Interface(), value) { | |
340 | hasUpdate = true | |
341 | field.Set(value) | |
342 | } | |
343 | } | |
344 | } | |
345 | } | |
346 | if hasExpr { | |
347 | var updateMap = map[string]interface{}{} | |
348 | for key, value := range fields { | |
349 | if v, ok := values[key]; ok { | |
350 | updateMap[key] = v | |
351 | } else { | |
352 | updateMap[key] = value.Field.Interface() | |
353 | } | |
354 | } | |
355 | return updateMap, true | |
356 | } | |
357 | return | |
358 | } | |
359 | ||
360 | func (scope *Scope) row() *sql.Row { | |
361 | defer scope.Trace(NowFunc()) | |
362 | scope.callCallbacks(scope.db.parent.callback.rowQueries) | |
363 | scope.prepareQuerySql() | |
364 | return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...) | |
365 | } | |
366 | ||
367 | func (scope *Scope) rows() (*sql.Rows, error) { | |
368 | defer scope.Trace(NowFunc()) | |
369 | scope.callCallbacks(scope.db.parent.callback.rowQueries) | |
370 | scope.prepareQuerySql() | |
371 | return scope.SqlDB().Query(scope.Sql, scope.SqlVars...) | |
372 | } | |
373 | ||
374 | func (scope *Scope) initialize() *Scope { | |
375 | for _, clause := range scope.Search.whereConditions { | |
376 | scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false) | |
377 | } | |
378 | scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false) | |
379 | scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false) | |
380 | return scope | |
381 | } | |
382 | ||
383 | func (scope *Scope) pluck(column string, value interface{}) *Scope { | |
384 | dest := reflect.Indirect(reflect.ValueOf(value)) | |
385 | scope.Search.Select(column) | |
386 | if dest.Kind() != reflect.Slice { | |
387 | scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) | |
388 | return scope | |
389 | } | |
390 | ||
391 | rows, err := scope.rows() | |
392 | if scope.Err(err) == nil { | |
393 | defer rows.Close() | |
394 | for rows.Next() { | |
395 | elem := reflect.New(dest.Type().Elem()).Interface() | |
396 | scope.Err(rows.Scan(elem)) | |
397 | dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) | |
398 | } | |
399 | } | |
400 | return scope | |
401 | } | |
402 | ||
403 | func (scope *Scope) count(value interface{}) *Scope { | |
404 | scope.Search.Select("count(*)") | |
405 | scope.Err(scope.row().Scan(value)) | |
406 | return scope | |
407 | } | |
408 | ||
409 | func (scope *Scope) typeName() string { | |
410 | value := scope.IndirectValue() | |
411 | if value.Kind() == reflect.Slice { | |
412 | return value.Type().Elem().Name() | |
413 | } | |
414 | ||
415 | return value.Type().Name() | |
416 | } | |
417 | ||
418 | func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { | |
419 | toScope := scope.db.NewScope(value) | |
420 | fromFields := scope.Fields() | |
421 | toFields := toScope.Fields() | |
422 | for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { | |
423 | var fromField, toField *Field | |
424 | if field, ok := scope.FieldByName(foreignKey); ok { | |
425 | fromField = field | |
426 | } else { | |
427 | fromField = fromFields[ToDBName(foreignKey)] | |
428 | } | |
429 | if field, ok := toScope.FieldByName(foreignKey); ok { | |
430 | toField = field | |
431 | } else { | |
432 | toField = toFields[ToDBName(foreignKey)] | |
433 | } | |
434 | ||
435 | if fromField != nil { | |
436 | if relationship := fromField.Relationship; relationship != nil { | |
437 | if relationship.Kind == "many_to_many" { | |
438 | joinTableHandler := relationship.JoinTableHandler | |
439 | scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) | |
440 | } else if relationship.Kind == "belongs_to" { | |
441 | query := toScope.db | |
442 | for idx, foreignKey := range relationship.ForeignDBNames { | |
443 | if field, ok := scope.FieldByName(foreignKey); ok { | |
444 | query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) | |
445 | } | |
446 | } | |
447 | scope.Err(query.Find(value).Error) | |
448 | } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { | |
449 | query := toScope.db | |
450 | for idx, foreignKey := range relationship.ForeignDBNames { | |
451 | if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { | |
452 | query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) | |
453 | } | |
454 | } | |
455 | ||
456 | if relationship.PolymorphicType != "" { | |
457 | query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) | |
458 | } | |
459 | scope.Err(query.Find(value).Error) | |
460 | } | |
461 | } else { | |
462 | sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) | |
463 | scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error) | |
464 | } | |
465 | return scope | |
466 | } else if toField != nil { | |
467 | sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) | |
468 | scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) | |
469 | return scope | |
470 | } | |
471 | } | |
472 | ||
473 | scope.Err(fmt.Errorf("invalid association %v", foreignKeys)) | |
474 | return scope | |
475 | } | |
476 | ||
477 | /** | |
478 | Return the table options string or an empty string if the table options does not exist | |
479 | */ | |
480 | func (scope *Scope) getTableOptions() string { | |
481 | tableOptions, ok := scope.Get("gorm:table_options") | |
482 | if !ok { | |
483 | return "" | |
484 | } | |
485 | return tableOptions.(string) | |
486 | } | |
487 | ||
488 | func (scope *Scope) createJoinTable(field *StructField) { | |
489 | if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { | |
490 | joinTableHandler := relationship.JoinTableHandler | |
491 | joinTable := joinTableHandler.Table(scope.db) | |
492 | if !scope.Dialect().HasTable(scope, joinTable) { | |
493 | toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} | |
494 | ||
495 | var sqlTypes []string | |
496 | for idx, fieldName := range relationship.ForeignFieldNames { | |
497 | if field, ok := scope.Fields()[fieldName]; ok { | |
498 | value := reflect.Indirect(reflect.New(field.Struct.Type)) | |
499 | primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false) | |
500 | sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType) | |
501 | } | |
502 | } | |
503 | ||
504 | for idx, fieldName := range relationship.AssociationForeignFieldNames { | |
505 | if field, ok := toScope.Fields()[fieldName]; ok { | |
506 | value := reflect.Indirect(reflect.New(field.Struct.Type)) | |
507 | primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false) | |
508 | sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType) | |
509 | } | |
510 | } | |
511 | scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), scope.getTableOptions())).Error) | |
512 | } | |
513 | scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) | |
514 | } | |
515 | } | |
516 | ||
517 | func (scope *Scope) createTable() *Scope { | |
518 | var tags []string | |
519 | var primaryKeys []string | |
520 | var primaryKeyInColumnType bool = false | |
521 | for _, field := range scope.GetStructFields() { | |
522 | if field.IsNormal { | |
523 | sqlTag := scope.generateSqlTag(field) | |
524 | ||
525 | // Check if the primary key constraint was specified as | |
526 | // part of the column type. If so, we can only support | |
527 | // one column as the primary key. | |
528 | if strings.Contains(strings.ToLower(sqlTag), "primary key") { | |
529 | primaryKeyInColumnType = true | |
530 | } | |
531 | ||
532 | tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag) | |
533 | } | |
534 | ||
535 | if field.IsPrimaryKey { | |
536 | primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) | |
537 | } | |
538 | scope.createJoinTable(field) | |
539 | } | |
540 | ||
541 | var primaryKeyStr string | |
542 | if len(primaryKeys) > 0 && !primaryKeyInColumnType { | |
543 | primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) | |
544 | } | |
545 | scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() | |
546 | return scope | |
547 | } | |
548 | ||
549 | func (scope *Scope) dropTable() *Scope { | |
550 | scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() | |
551 | return scope | |
552 | } | |
553 | ||
554 | func (scope *Scope) dropTableIfExists() *Scope { | |
555 | if scope.Dialect().HasTable(scope, scope.TableName()) { | |
556 | scope.dropTable() | |
557 | } | |
558 | return scope | |
559 | } | |
560 | ||
561 | func (scope *Scope) modifyColumn(column string, typ string) { | |
562 | scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() | |
563 | } | |
564 | ||
565 | func (scope *Scope) dropColumn(column string) { | |
566 | scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() | |
567 | } | |
568 | ||
569 | func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { | |
570 | if scope.Dialect().HasIndex(scope, scope.TableName(), indexName) { | |
571 | return | |
572 | } | |
573 | ||
574 | var columns []string | |
575 | for _, name := range column { | |
576 | columns = append(columns, scope.QuoteIfPossible(name)) | |
577 | } | |
578 | ||
579 | sqlCreate := "CREATE INDEX" | |
580 | if unique { | |
581 | sqlCreate = "CREATE UNIQUE INDEX" | |
582 | } | |
583 | ||
584 | scope.Search.Unscoped = true | |
585 | scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSql())).Exec() | |
586 | scope.Search.Unscoped = false | |
587 | } | |
588 | ||
589 | func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { | |
590 | var table = scope.TableName() | |
591 | var keyName = fmt.Sprintf("%s_%s_%s_foreign", table, field, regexp.MustCompile("[^a-zA-Z]").ReplaceAllString(dest, "_")) | |
592 | keyName = regexp.MustCompile("_+").ReplaceAllString(keyName, "_") | |
593 | var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` | |
594 | scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.QuoteIfPossible(keyName), scope.QuoteIfPossible(field), dest, onDelete, onUpdate)).Exec() | |
595 | } | |
596 | ||
597 | func (scope *Scope) removeIndex(indexName string) { | |
598 | scope.Dialect().RemoveIndex(scope, indexName) | |
599 | } | |
600 | ||
601 | func (scope *Scope) autoMigrate() *Scope { | |
602 | tableName := scope.TableName() | |
603 | quotedTableName := scope.QuotedTableName() | |
604 | ||
605 | if !scope.Dialect().HasTable(scope, tableName) { | |
606 | scope.createTable() | |
607 | } else { | |
608 | for _, field := range scope.GetStructFields() { | |
609 | if !scope.Dialect().HasColumn(scope, tableName, field.DBName) { | |
610 | if field.IsNormal { | |
611 | sqlTag := scope.generateSqlTag(field) | |
612 | scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() | |
613 | } | |
614 | } | |
615 | scope.createJoinTable(field) | |
616 | } | |
617 | } | |
618 | ||
619 | scope.autoIndex() | |
620 | return scope | |
621 | } | |
622 | ||
623 | func (scope *Scope) autoIndex() *Scope { | |
624 | var indexes = map[string][]string{} | |
625 | var uniqueIndexes = map[string][]string{} | |
626 | ||
627 | for _, field := range scope.GetStructFields() { | |
628 | sqlSettings := parseTagSetting(field.Tag.Get("sql")) | |
629 | if name, ok := sqlSettings["INDEX"]; ok { | |
630 | if name == "INDEX" { | |
631 | name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) | |
632 | } | |
633 | indexes[name] = append(indexes[name], field.DBName) | |
634 | } | |
635 | ||
636 | if name, ok := sqlSettings["UNIQUE_INDEX"]; ok { | |
637 | if name == "UNIQUE_INDEX" { | |
638 | name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) | |
639 | } | |
640 | uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) | |
641 | } | |
642 | } | |
643 | ||
644 | for name, columns := range indexes { | |
645 | scope.addIndex(false, name, columns...) | |
646 | } | |
647 | ||
648 | for name, columns := range uniqueIndexes { | |
649 | scope.addIndex(true, name, columns...) | |
650 | } | |
651 | ||
652 | return scope | |
653 | } |
0 | 0 | package gorm_test |
1 | 1 | |
2 | 2 | import ( |
3 | "encoding/hex" | |
4 | "math/rand" | |
5 | "strings" | |
6 | "testing" | |
7 | ||
3 | 8 | "github.com/jinzhu/gorm" |
4 | "testing" | |
5 | 9 | ) |
6 | 10 | |
7 | 11 | func NameIn1And2(d *gorm.DB) *gorm.DB { |
40 | 44 | t.Errorf("Should found two users's name in 1, 3") |
41 | 45 | } |
42 | 46 | } |
47 | ||
48 | func randName() string { | |
49 | data := make([]byte, 8) | |
50 | rand.Read(data) | |
51 | ||
52 | return "n-" + hex.EncodeToString(data) | |
53 | } | |
54 | ||
55 | func TestValuer(t *testing.T) { | |
56 | name := randName() | |
57 | ||
58 | origUser := User{Name: name, Age: 1, Password: EncryptedData("pass1"), PasswordHash: []byte("abc")} | |
59 | if err := DB.Save(&origUser).Error; err != nil { | |
60 | t.Errorf("No error should happen when saving user, but got %v", err) | |
61 | } | |
62 | ||
63 | var user2 User | |
64 | if err := DB.Where("name = ? AND password = ? AND password_hash = ?", name, EncryptedData("pass1"), []byte("abc")).First(&user2).Error; err != nil { | |
65 | t.Errorf("No error should happen when querying user with valuer, but got %v", err) | |
66 | } | |
67 | } | |
68 | ||
69 | func TestFailedValuer(t *testing.T) { | |
70 | name := randName() | |
71 | ||
72 | err := DB.Exec("INSERT INTO users(name, password) VALUES(?, ?)", name, EncryptedData("xpass1")).Error | |
73 | ||
74 | if err == nil { | |
75 | t.Errorf("There should be an error should happen when insert data") | |
76 | } else if !strings.HasPrefix(err.Error(), "Should not start with") { | |
77 | t.Errorf("The error should be returned from Valuer, but get %v", err) | |
78 | } | |
79 | } |
0 | 0 | package gorm |
1 | 1 | |
2 | import "fmt" | |
2 | import ( | |
3 | "fmt" | |
4 | ) | |
3 | 5 | |
4 | 6 | type search struct { |
5 | 7 | db *DB |
7 | 9 | orConditions []map[string]interface{} |
8 | 10 | notConditions []map[string]interface{} |
9 | 11 | havingConditions []map[string]interface{} |
12 | joinConditions []map[string]interface{} | |
10 | 13 | initAttrs []interface{} |
11 | 14 | assignAttrs []interface{} |
12 | 15 | selects map[string]interface{} |
13 | 16 | omits []string |
14 | orders []string | |
15 | joins string | |
17 | orders []interface{} | |
16 | 18 | preload []searchPreload |
17 | offset string | |
18 | limit string | |
19 | offset interface{} | |
20 | limit interface{} | |
19 | 21 | group string |
20 | 22 | tableName string |
21 | 23 | raw bool |
22 | 24 | Unscoped bool |
23 | countingQuery bool | |
25 | ignoreOrderQuery bool | |
24 | 26 | } |
25 | 27 | |
26 | 28 | type searchPreload struct { |
58 | 60 | return s |
59 | 61 | } |
60 | 62 | |
61 | func (s *search) Order(value string, reorder ...bool) *search { | |
63 | func (s *search) Order(value interface{}, reorder ...bool) *search { | |
62 | 64 | if len(reorder) > 0 && reorder[0] { |
63 | if value != "" { | |
64 | s.orders = []string{value} | |
65 | } else { | |
66 | s.orders = []string{} | |
67 | } | |
68 | } else if value != "" { | |
65 | s.orders = []interface{}{} | |
66 | } | |
67 | ||
68 | if value != nil && value != "" { | |
69 | 69 | s.orders = append(s.orders, value) |
70 | 70 | } |
71 | 71 | return s |
81 | 81 | return s |
82 | 82 | } |
83 | 83 | |
84 | func (s *search) Limit(value interface{}) *search { | |
85 | s.limit = s.getInterfaceAsSql(value) | |
84 | func (s *search) Limit(limit interface{}) *search { | |
85 | s.limit = limit | |
86 | 86 | return s |
87 | 87 | } |
88 | 88 | |
89 | func (s *search) Offset(value interface{}) *search { | |
90 | s.offset = s.getInterfaceAsSql(value) | |
89 | func (s *search) Offset(offset interface{}) *search { | |
90 | s.offset = offset | |
91 | 91 | return s |
92 | 92 | } |
93 | 93 | |
94 | 94 | func (s *search) Group(query string) *search { |
95 | s.group = s.getInterfaceAsSql(query) | |
95 | s.group = s.getInterfaceAsSQL(query) | |
96 | 96 | return s |
97 | 97 | } |
98 | 98 | |
99 | func (s *search) Having(query string, values ...interface{}) *search { | |
100 | s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) | |
99 | func (s *search) Having(query interface{}, values ...interface{}) *search { | |
100 | if val, ok := query.(*expr); ok { | |
101 | s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args}) | |
102 | } else { | |
103 | s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) | |
104 | } | |
101 | 105 | return s |
102 | 106 | } |
103 | 107 | |
104 | func (s *search) Joins(query string) *search { | |
105 | s.joins = query | |
108 | func (s *search) Joins(query string, values ...interface{}) *search { | |
109 | s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) | |
106 | 110 | return s |
107 | 111 | } |
108 | 112 | |
133 | 137 | return s |
134 | 138 | } |
135 | 139 | |
136 | func (s *search) getInterfaceAsSql(value interface{}) (str string) { | |
140 | func (s *search) getInterfaceAsSQL(value interface{}) (str string) { | |
137 | 141 | switch value.(type) { |
138 | 142 | case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: |
139 | 143 | str = fmt.Sprintf("%v", value) |
140 | 144 | default: |
141 | s.db.AddError(InvalidSql) | |
145 | s.db.AddError(ErrInvalidSQL) | |
142 | 146 | } |
143 | 147 | |
144 | 148 | if str == "-1" { |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "database/sql/driver" | |
4 | "encoding/json" | |
5 | "testing" | |
6 | ) | |
7 | ||
8 | func TestScannableSlices(t *testing.T) { | |
9 | if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil { | |
10 | t.Errorf("Should create table with slice values correctly: %s", err) | |
11 | } | |
12 | ||
13 | r1 := RecordWithSlice{ | |
14 | Strings: ExampleStringSlice{"a", "b", "c"}, | |
15 | Structs: ExampleStructSlice{ | |
16 | {"name1", "value1"}, | |
17 | {"name2", "value2"}, | |
18 | }, | |
19 | } | |
20 | ||
21 | if err := DB.Save(&r1).Error; err != nil { | |
22 | t.Errorf("Should save record with slice values") | |
23 | } | |
24 | ||
25 | var r2 RecordWithSlice | |
26 | ||
27 | if err := DB.Find(&r2).Error; err != nil { | |
28 | t.Errorf("Should fetch record with slice values") | |
29 | } | |
30 | ||
31 | if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" { | |
32 | t.Errorf("Should have serialised and deserialised a string array") | |
33 | } | |
34 | ||
35 | if len(r2.Structs) != 2 || r2.Structs[0].Name != "name1" || r2.Structs[0].Value != "value1" || r2.Structs[1].Name != "name2" || r2.Structs[1].Value != "value2" { | |
36 | t.Errorf("Should have serialised and deserialised a struct array") | |
37 | } | |
38 | } | |
39 | ||
40 | type RecordWithSlice struct { | |
41 | ID uint64 | |
42 | Strings ExampleStringSlice `sql:"type:text"` | |
43 | Structs ExampleStructSlice `sql:"type:text"` | |
44 | } | |
45 | ||
46 | type ExampleStringSlice []string | |
47 | ||
48 | func (l ExampleStringSlice) Value() (driver.Value, error) { | |
49 | return json.Marshal(l) | |
50 | } | |
51 | ||
52 | func (l *ExampleStringSlice) Scan(input interface{}) error { | |
53 | return json.Unmarshal(input.([]byte), l) | |
54 | } | |
55 | ||
56 | type ExampleStruct struct { | |
57 | Name string | |
58 | Value string | |
59 | } | |
60 | ||
61 | type ExampleStructSlice []ExampleStruct | |
62 | ||
63 | func (l ExampleStructSlice) Value() (driver.Value, error) { | |
64 | return json.Marshal(l) | |
65 | } | |
66 | ||
67 | func (l *ExampleStructSlice) Scan(input interface{}) error { | |
68 | return json.Unmarshal(input.([]byte), l) | |
69 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | "time" | |
6 | ) | |
7 | ||
8 | type sqlite3 struct { | |
9 | commonDialect | |
10 | } | |
11 | ||
12 | func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string { | |
13 | switch value.Kind() { | |
14 | case reflect.Bool: | |
15 | return "bool" | |
16 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | |
17 | if autoIncrease { | |
18 | return "integer primary key autoincrement" | |
19 | } | |
20 | return "integer" | |
21 | case reflect.Int64, reflect.Uint64: | |
22 | if autoIncrease { | |
23 | return "integer primary key autoincrement" | |
24 | } | |
25 | return "bigint" | |
26 | case reflect.Float32, reflect.Float64: | |
27 | return "real" | |
28 | case reflect.String: | |
29 | if size > 0 && size < 65532 { | |
30 | return fmt.Sprintf("varchar(%d)", size) | |
31 | } | |
32 | return "text" | |
33 | case reflect.Struct: | |
34 | if _, ok := value.Interface().(time.Time); ok { | |
35 | return "datetime" | |
36 | } | |
37 | default: | |
38 | if _, ok := value.Interface().([]byte); ok { | |
39 | return "blob" | |
40 | } | |
41 | } | |
42 | panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String())) | |
43 | } | |
44 | ||
45 | func (s sqlite3) HasTable(scope *Scope, tableName string) bool { | |
46 | var count int | |
47 | s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName) | |
48 | return count > 0 | |
49 | } | |
50 | ||
51 | func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { | |
52 | var count int | |
53 | s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName) | |
54 | return count > 0 | |
55 | } | |
56 | ||
57 | func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { | |
58 | var count int | |
59 | s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName) | |
60 | return count > 0 | |
61 | } | |
62 | ||
63 | func (sqlite3) RemoveIndex(scope *Scope, indexName string) { | |
64 | scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) | |
65 | } | |
66 | ||
67 | func (sqlite3) CurrentDatabase(scope *Scope) (name string) { | |
68 | var ( | |
69 | ifaces = make([]interface{}, 3) | |
70 | pointers = make([]*string, 3) | |
71 | i int | |
72 | ) | |
73 | for i = 0; i < 3; i++ { | |
74 | ifaces[i] = &pointers[i] | |
75 | } | |
76 | if err := scope.NewDB().Raw("PRAGMA database_list").Row().Scan(ifaces...); scope.Err(err) != nil { | |
77 | return | |
78 | } | |
79 | if pointers[1] != nil { | |
80 | name = *pointers[1] | |
81 | } | |
82 | return | |
83 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "database/sql" | |
4 | "database/sql/driver" | |
5 | "errors" | |
6 | "fmt" | |
7 | ||
8 | "reflect" | |
9 | "time" | |
10 | ) | |
11 | ||
12 | type User struct { | |
13 | Id int64 | |
14 | Age int64 | |
15 | UserNum Num | |
16 | Name string `sql:"size:255"` | |
17 | Birthday time.Time // Time | |
18 | CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically | |
19 | UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically | |
20 | Emails []Email // Embedded structs | |
21 | BillingAddress Address // Embedded struct | |
22 | BillingAddressID sql.NullInt64 // Embedded struct's foreign key | |
23 | ShippingAddress Address // Embedded struct | |
24 | ShippingAddressId int64 // Embedded struct's foreign key | |
25 | CreditCard CreditCard | |
26 | Latitude float64 | |
27 | Languages []Language `gorm:"many2many:user_languages;"` | |
28 | CompanyID int64 | |
29 | Company Company | |
30 | Role | |
31 | PasswordHash []byte | |
32 | IgnoreMe int64 `sql:"-"` | |
33 | IgnoreStringSlice []string `sql:"-"` | |
34 | Ignored struct{ Name string } `sql:"-"` | |
35 | IgnoredPointer *User `sql:"-"` | |
36 | } | |
37 | ||
38 | type CreditCard struct { | |
39 | ID int8 | |
40 | Number string | |
41 | UserId sql.NullInt64 | |
42 | CreatedAt time.Time | |
43 | UpdatedAt time.Time | |
44 | DeletedAt time.Time | |
45 | } | |
46 | ||
47 | type Email struct { | |
48 | Id int16 | |
49 | UserId int | |
50 | Email string `sql:"type:varchar(100);"` | |
51 | CreatedAt time.Time | |
52 | UpdatedAt time.Time | |
53 | } | |
54 | ||
55 | type Address struct { | |
56 | ID int | |
57 | Address1 string | |
58 | Address2 string | |
59 | Post string | |
60 | CreatedAt time.Time | |
61 | UpdatedAt time.Time | |
62 | DeletedAt time.Time | |
63 | } | |
64 | ||
65 | type Language struct { | |
66 | Id int | |
67 | Name string | |
68 | Users []User `gorm:"many2many:user_languages;"` | |
69 | } | |
70 | ||
71 | type Product struct { | |
72 | Id int64 | |
73 | Code string | |
74 | Price int64 | |
75 | CreatedAt time.Time | |
76 | UpdatedAt time.Time | |
77 | AfterFindCallTimes int64 | |
78 | BeforeCreateCallTimes int64 | |
79 | AfterCreateCallTimes int64 | |
80 | BeforeUpdateCallTimes int64 | |
81 | AfterUpdateCallTimes int64 | |
82 | BeforeSaveCallTimes int64 | |
83 | AfterSaveCallTimes int64 | |
84 | BeforeDeleteCallTimes int64 | |
85 | AfterDeleteCallTimes int64 | |
86 | } | |
87 | ||
88 | type Company struct { | |
89 | Id int64 | |
90 | Name string | |
91 | Owner *User `sql:"-"` | |
92 | } | |
93 | ||
94 | type Role struct { | |
95 | Name string | |
96 | } | |
97 | ||
98 | func (role *Role) Scan(value interface{}) error { | |
99 | if b, ok := value.([]uint8); ok { | |
100 | role.Name = string(b) | |
101 | } else { | |
102 | role.Name = value.(string) | |
103 | } | |
104 | return nil | |
105 | } | |
106 | ||
107 | func (role Role) Value() (driver.Value, error) { | |
108 | return role.Name, nil | |
109 | } | |
110 | ||
111 | func (role Role) IsAdmin() bool { | |
112 | return role.Name == "admin" | |
113 | } | |
114 | ||
115 | type Num int64 | |
116 | ||
117 | func (i *Num) Scan(src interface{}) error { | |
118 | switch s := src.(type) { | |
119 | case []byte: | |
120 | case int64: | |
121 | *i = Num(s) | |
122 | default: | |
123 | return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) | |
124 | } | |
125 | return nil | |
126 | } | |
127 | ||
128 | type Animal struct { | |
129 | Counter uint64 `gorm:"primary_key:yes"` | |
130 | Name string `sql:"DEFAULT:'galeone'"` | |
131 | From string //test reserved sql keyword as field name | |
132 | Age time.Time `sql:"DEFAULT:current_timestamp"` | |
133 | unexported string // unexported value | |
134 | CreatedAt time.Time | |
135 | UpdatedAt time.Time | |
136 | } | |
137 | ||
138 | type JoinTable struct { | |
139 | From uint64 | |
140 | To uint64 | |
141 | Time time.Time `sql:"default: null"` | |
142 | } | |
143 | ||
144 | type Post struct { | |
145 | Id int64 | |
146 | CategoryId sql.NullInt64 | |
147 | MainCategoryId int64 | |
148 | Title string | |
149 | Body string | |
150 | Comments []*Comment | |
151 | Category Category | |
152 | MainCategory Category | |
153 | } | |
154 | ||
155 | type Category struct { | |
156 | Id int64 | |
157 | Name string | |
158 | } | |
159 | ||
160 | type Comment struct { | |
161 | Id int64 | |
162 | PostId int64 | |
163 | Content string | |
164 | Post Post | |
165 | } | |
166 | ||
167 | // Scanner | |
168 | type NullValue struct { | |
169 | Id int64 | |
170 | Name sql.NullString `sql:"not null"` | |
171 | Age sql.NullInt64 | |
172 | Male sql.NullBool | |
173 | Height sql.NullFloat64 | |
174 | AddedAt NullTime | |
175 | } | |
176 | ||
177 | type NullTime struct { | |
178 | Time time.Time | |
179 | Valid bool | |
180 | } | |
181 | ||
182 | func (nt *NullTime) Scan(value interface{}) error { | |
183 | if value == nil { | |
184 | nt.Valid = false | |
185 | return nil | |
186 | } | |
187 | nt.Time, nt.Valid = value.(time.Time), true | |
188 | return nil | |
189 | } | |
190 | ||
191 | func (nt NullTime) Value() (driver.Value, error) { | |
192 | if !nt.Valid { | |
193 | return nil, nil | |
194 | } | |
195 | return nt.Time, nil | |
196 | } | |
197 | ||
198 | func getPreparedUser(name string, role string) *User { | |
199 | var company Company | |
200 | DB.Where(Company{Name: role}).FirstOrCreate(&company) | |
201 | ||
202 | return &User{ | |
203 | Name: name, | |
204 | Age: 20, | |
205 | Role: Role{role}, | |
206 | BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, | |
207 | ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, | |
208 | CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, | |
209 | Emails: []Email{ | |
210 | {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, | |
211 | }, | |
212 | Company: company, | |
213 | Languages: []Language{ | |
214 | {Name: fmt.Sprintf("lang_1_%v", name)}, | |
215 | {Name: fmt.Sprintf("lang_2_%v", name)}, | |
216 | }, | |
217 | } | |
218 | } |
0 | dialects=("postgres" "mysql" "sqlite") | |
0 | dialects=("postgres" "mysql" "mssql" "sqlite") | |
1 | 1 | |
2 | 2 | for dialect in "${dialects[@]}" ; do |
3 | GORM_DIALECT=${dialect} go test | |
3 | DEBUG=false GORM_DIALECT=${dialect} go test | |
4 | 4 | done |
19 | 19 | DB.First(&product1, product1.Id) |
20 | 20 | DB.First(&product2, product2.Id) |
21 | 21 | updatedAt1 := product1.UpdatedAt |
22 | updatedAt2 := product2.UpdatedAt | |
23 | ||
24 | var product3 Product | |
25 | DB.First(&product3, product2.Id).Update("code", "product2newcode") | |
26 | if updatedAt2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) { | |
27 | t.Errorf("updatedAt should not be updated if nothing changed") | |
28 | } | |
29 | 22 | |
30 | 23 | if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() { |
31 | 24 | t.Errorf("Product1 should not be updated") |
70 | 63 | } |
71 | 64 | |
72 | 65 | DB.First(&product4, product4.Id) |
66 | updatedAt4 := product4.UpdatedAt | |
73 | 67 | DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50)) |
74 | 68 | var product5 Product |
75 | 69 | DB.First(&product5, product4.Id) |
76 | 70 | if product5.Price != product4.Price+100-50 { |
77 | 71 | t.Errorf("Update with expression") |
78 | 72 | } |
79 | if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) { | |
73 | if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { | |
80 | 74 | t.Errorf("Update with expression should update UpdatedAt") |
81 | 75 | } |
82 | 76 | } |
102 | 96 | DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched |
103 | 97 | DB.First(&animal, animal.Counter) |
104 | 98 | if animal.Name != "galeone" { |
105 | t.Errorf("Name fiels shouldn't be changed if untouched, but got %v", animal.Name) | |
99 | t.Errorf("Name fields shouldn't be changed if untouched, but got %v", animal.Name) | |
106 | 100 | } |
107 | 101 | |
108 | 102 | // When changing a field with a default value, the change must occur |
133 | 127 | |
134 | 128 | DB.First(&product1, product1.Id) |
135 | 129 | DB.First(&product2, product2.Id) |
136 | updatedAt1 := product1.UpdatedAt | |
137 | 130 | updatedAt2 := product2.UpdatedAt |
138 | ||
139 | var product3 Product | |
140 | DB.First(&product3, product1.Id).Updates(Product{Code: "product1newcode", Price: 100}) | |
141 | if product3.Code != "product1newcode" || product3.Price != 100 { | |
142 | t.Errorf("Record should be updated with struct") | |
143 | } | |
144 | ||
145 | if updatedAt1.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) { | |
146 | t.Errorf("updatedAt should not be updated if nothing changed") | |
147 | } | |
148 | 131 | |
149 | 132 | if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() { |
150 | 133 | t.Errorf("Product2 should not be updated") |
169 | 152 | t.Errorf("product2's code should be updated") |
170 | 153 | } |
171 | 154 | |
155 | updatedAt4 := product4.UpdatedAt | |
172 | 156 | DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)}) |
173 | 157 | var product5 Product |
174 | 158 | DB.First(&product5, product4.Id) |
175 | 159 | if product5.Price != product4.Price+100 { |
176 | 160 | t.Errorf("Updates with expression") |
177 | 161 | } |
178 | if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) { | |
162 | // product4's UpdatedAt will be reset when updating | |
163 | if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { | |
179 | 164 | t.Errorf("Updates with expression should update UpdatedAt") |
180 | 165 | } |
181 | 166 | } |
314 | 299 | queryUser.ShippingAddressId == user.ShippingAddressId || |
315 | 300 | queryUser.CreditCard.ID != user.CreditCard.ID || |
316 | 301 | len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { |
317 | t.Errorf("Should only update relationships that not omited") | |
302 | t.Errorf("Should only update relationships that not omitted") | |
318 | 303 | } |
319 | 304 | } |
320 | 305 | |
350 | 335 | queryUser.ShippingAddressId == user.ShippingAddressId || |
351 | 336 | queryUser.CreditCard.ID != user.CreditCard.ID || |
352 | 337 | len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { |
353 | t.Errorf("Should only update relationships not omited") | |
338 | t.Errorf("Should only update relationships not omitted") | |
354 | 339 | } |
355 | 340 | } |
356 | 341 | |
418 | 403 | t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1) |
419 | 404 | } |
420 | 405 | } |
406 | ||
407 | func TestUpdatesWithBlankValues(t *testing.T) { | |
408 | product := Product{Code: "product1", Price: 10} | |
409 | DB.Save(&product) | |
410 | ||
411 | DB.Model(&Product{Id: product.Id}).Updates(&Product{Price: 100}) | |
412 | ||
413 | var product1 Product | |
414 | DB.First(&product1, product.Id) | |
415 | ||
416 | if product1.Code != "product1" || product1.Price != 100 { | |
417 | t.Errorf("product's code should not be updated") | |
418 | } | |
419 | } | |
420 | ||
421 | type ElementWithIgnoredField struct { | |
422 | Id int64 | |
423 | Value string | |
424 | IgnoredField int64 `sql:"-"` | |
425 | } | |
426 | ||
427 | func (e ElementWithIgnoredField) TableName() string { | |
428 | return "element_with_ignored_field" | |
429 | } | |
430 | ||
431 | func TestUpdatesTableWithIgnoredValues(t *testing.T) { | |
432 | elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10} | |
433 | DB.Save(&elem) | |
434 | ||
435 | DB.Table(elem.TableName()). | |
436 | Where("id = ?", elem.Id). | |
437 | // DB.Model(&ElementWithIgnoredField{Id: elem.Id}). | |
438 | Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100}) | |
439 | ||
440 | var elem1 ElementWithIgnoredField | |
441 | err := DB.First(&elem1, elem.Id).Error | |
442 | if err != nil { | |
443 | t.Errorf("error getting an element from database: %s", err.Error()) | |
444 | } | |
445 | ||
446 | if elem1.IgnoredField != 0 { | |
447 | t.Errorf("element's ignored field should not be updated") | |
448 | } | |
449 | } | |
450 | ||
451 | func TestUpdateDecodeVirtualAttributes(t *testing.T) { | |
452 | var user = User{ | |
453 | Name: "jinzhu", | |
454 | IgnoreMe: 88, | |
455 | } | |
456 | ||
457 | DB.Save(&user) | |
458 | ||
459 | DB.Model(&user).Updates(User{Name: "jinzhu2", IgnoreMe: 100}) | |
460 | ||
461 | if user.IgnoreMe != 100 { | |
462 | t.Errorf("should decode virtual attributes to struct, so it could be used in callbacks") | |
463 | } | |
464 | } |
1 | 1 | |
2 | 2 | import ( |
3 | 3 | "bytes" |
4 | "database/sql/driver" | |
5 | "fmt" | |
6 | "reflect" | |
7 | "regexp" | |
8 | "runtime" | |
4 | 9 | "strings" |
5 | 10 | "sync" |
11 | "time" | |
6 | 12 | ) |
7 | 13 | |
14 | // NowFunc returns current time, this function is exported in order to be able | |
15 | // to give the flexibility to the developer to customize it according to their | |
16 | // needs, e.g: | |
17 | // gorm.NowFunc = func() time.Time { | |
18 | // return time.Now().UTC() | |
19 | // } | |
20 | var NowFunc = func() time.Time { | |
21 | return time.Now() | |
22 | } | |
23 | ||
8 | 24 | // Copied from golint |
9 | var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} | |
25 | var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} | |
10 | 26 | var commonInitialismsReplacer *strings.Replacer |
27 | ||
28 | var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) | |
29 | var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`) | |
11 | 30 | |
12 | 31 | func init() { |
13 | 32 | var commonInitialismsForReplacer []string |
40 | 59 | |
41 | 60 | var smap = newSafeMap() |
42 | 61 | |
62 | type strCase bool | |
63 | ||
64 | const ( | |
65 | lower strCase = false | |
66 | upper strCase = true | |
67 | ) | |
68 | ||
69 | // ToDBName convert string to db name | |
43 | 70 | func ToDBName(name string) string { |
44 | 71 | if v := smap.Get(name); v != "" { |
45 | 72 | return v |
46 | 73 | } |
47 | 74 | |
48 | value := commonInitialismsReplacer.Replace(name) | |
49 | buf := bytes.NewBufferString("") | |
50 | for i, v := range value { | |
51 | if i > 0 && v >= 'A' && v <= 'Z' { | |
52 | buf.WriteRune('_') | |
53 | } | |
54 | buf.WriteRune(v) | |
55 | } | |
75 | if name == "" { | |
76 | return "" | |
77 | } | |
78 | ||
79 | var ( | |
80 | value = commonInitialismsReplacer.Replace(name) | |
81 | buf = bytes.NewBufferString("") | |
82 | lastCase, currCase, nextCase strCase | |
83 | ) | |
84 | ||
85 | for i, v := range value[:len(value)-1] { | |
86 | nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') | |
87 | if i > 0 { | |
88 | if currCase == upper { | |
89 | if lastCase == upper && nextCase == upper { | |
90 | buf.WriteRune(v) | |
91 | } else { | |
92 | if value[i-1] != '_' && value[i+1] != '_' { | |
93 | buf.WriteRune('_') | |
94 | } | |
95 | buf.WriteRune(v) | |
96 | } | |
97 | } else { | |
98 | buf.WriteRune(v) | |
99 | if i == len(value)-2 && nextCase == upper { | |
100 | buf.WriteRune('_') | |
101 | } | |
102 | } | |
103 | } else { | |
104 | currCase = upper | |
105 | buf.WriteRune(v) | |
106 | } | |
107 | lastCase = currCase | |
108 | currCase = nextCase | |
109 | } | |
110 | ||
111 | buf.WriteByte(value[len(value)-1]) | |
56 | 112 | |
57 | 113 | s := strings.ToLower(buf.String()) |
58 | 114 | smap.Set(name, s) |
59 | 115 | return s |
60 | 116 | } |
61 | 117 | |
118 | // SQL expression | |
62 | 119 | type expr struct { |
63 | 120 | expr string |
64 | 121 | args []interface{} |
65 | 122 | } |
66 | 123 | |
124 | // Expr generate raw SQL expression, for example: | |
125 | // DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) | |
67 | 126 | func Expr(expression string, args ...interface{}) *expr { |
68 | 127 | return &expr{expr: expression, args: args} |
69 | 128 | } |
129 | ||
130 | func indirect(reflectValue reflect.Value) reflect.Value { | |
131 | for reflectValue.Kind() == reflect.Ptr { | |
132 | reflectValue = reflectValue.Elem() | |
133 | } | |
134 | return reflectValue | |
135 | } | |
136 | ||
137 | func toQueryMarks(primaryValues [][]interface{}) string { | |
138 | var results []string | |
139 | ||
140 | for _, primaryValue := range primaryValues { | |
141 | var marks []string | |
142 | for range primaryValue { | |
143 | marks = append(marks, "?") | |
144 | } | |
145 | ||
146 | if len(marks) > 1 { | |
147 | results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) | |
148 | } else { | |
149 | results = append(results, strings.Join(marks, "")) | |
150 | } | |
151 | } | |
152 | return strings.Join(results, ",") | |
153 | } | |
154 | ||
155 | func toQueryCondition(scope *Scope, columns []string) string { | |
156 | var newColumns []string | |
157 | for _, column := range columns { | |
158 | newColumns = append(newColumns, scope.Quote(column)) | |
159 | } | |
160 | ||
161 | if len(columns) > 1 { | |
162 | return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) | |
163 | } | |
164 | return strings.Join(newColumns, ",") | |
165 | } | |
166 | ||
167 | func toQueryValues(values [][]interface{}) (results []interface{}) { | |
168 | for _, value := range values { | |
169 | for _, v := range value { | |
170 | results = append(results, v) | |
171 | } | |
172 | } | |
173 | return | |
174 | } | |
175 | ||
176 | func fileWithLineNum() string { | |
177 | for i := 2; i < 15; i++ { | |
178 | _, file, line, ok := runtime.Caller(i) | |
179 | if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) { | |
180 | return fmt.Sprintf("%v:%v", file, line) | |
181 | } | |
182 | } | |
183 | return "" | |
184 | } | |
185 | ||
186 | func isBlank(value reflect.Value) bool { | |
187 | switch value.Kind() { | |
188 | case reflect.String: | |
189 | return value.Len() == 0 | |
190 | case reflect.Bool: | |
191 | return !value.Bool() | |
192 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |
193 | return value.Int() == 0 | |
194 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: | |
195 | return value.Uint() == 0 | |
196 | case reflect.Float32, reflect.Float64: | |
197 | return value.Float() == 0 | |
198 | case reflect.Interface, reflect.Ptr: | |
199 | return value.IsNil() | |
200 | } | |
201 | ||
202 | return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) | |
203 | } | |
204 | ||
205 | func toSearchableMap(attrs ...interface{}) (result interface{}) { | |
206 | if len(attrs) > 1 { | |
207 | if str, ok := attrs[0].(string); ok { | |
208 | result = map[string]interface{}{str: attrs[1]} | |
209 | } | |
210 | } else if len(attrs) == 1 { | |
211 | if attr, ok := attrs[0].(map[string]interface{}); ok { | |
212 | result = attr | |
213 | } | |
214 | ||
215 | if attr, ok := attrs[0].(interface{}); ok { | |
216 | result = attr | |
217 | } | |
218 | } | |
219 | return | |
220 | } | |
221 | ||
222 | func equalAsString(a interface{}, b interface{}) bool { | |
223 | return toString(a) == toString(b) | |
224 | } | |
225 | ||
226 | func toString(str interface{}) string { | |
227 | if values, ok := str.([]interface{}); ok { | |
228 | var results []string | |
229 | for _, value := range values { | |
230 | results = append(results, toString(value)) | |
231 | } | |
232 | return strings.Join(results, "_") | |
233 | } else if bytes, ok := str.([]byte); ok { | |
234 | return string(bytes) | |
235 | } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { | |
236 | return fmt.Sprintf("%v", reflectValue.Interface()) | |
237 | } | |
238 | return "" | |
239 | } | |
240 | ||
241 | func makeSlice(elemType reflect.Type) interface{} { | |
242 | if elemType.Kind() == reflect.Slice { | |
243 | elemType = elemType.Elem() | |
244 | } | |
245 | sliceType := reflect.SliceOf(elemType) | |
246 | slice := reflect.New(sliceType) | |
247 | slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) | |
248 | return slice.Interface() | |
249 | } | |
250 | ||
251 | func strInSlice(a string, list []string) bool { | |
252 | for _, b := range list { | |
253 | if b == a { | |
254 | return true | |
255 | } | |
256 | } | |
257 | return false | |
258 | } | |
259 | ||
260 | // getValueFromFields return given fields's value | |
261 | func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) { | |
262 | // If value is a nil pointer, Indirect returns a zero Value! | |
263 | // Therefor we need to check for a zero value, | |
264 | // as FieldByName could panic | |
265 | if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { | |
266 | for _, fieldName := range fieldNames { | |
267 | if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { | |
268 | result := fieldValue.Interface() | |
269 | if r, ok := result.(driver.Valuer); ok { | |
270 | result, _ = r.Value() | |
271 | } | |
272 | results = append(results, result) | |
273 | } | |
274 | } | |
275 | } | |
276 | return | |
277 | } | |
278 | ||
279 | func addExtraSpaceIfExist(str string) string { | |
280 | if str != "" { | |
281 | return " " + str | |
282 | } | |
283 | return "" | |
284 | } |
0 | package gorm | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "reflect" | |
5 | "regexp" | |
6 | "runtime" | |
7 | "strings" | |
8 | ) | |
9 | ||
10 | func fileWithLineNum() string { | |
11 | for i := 2; i < 15; i++ { | |
12 | _, file, line, ok := runtime.Caller(i) | |
13 | if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { | |
14 | return fmt.Sprintf("%v:%v", file, line) | |
15 | } | |
16 | } | |
17 | return "" | |
18 | } | |
19 | ||
20 | func isBlank(value reflect.Value) bool { | |
21 | return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) | |
22 | } | |
23 | ||
24 | func toSearchableMap(attrs ...interface{}) (result interface{}) { | |
25 | if len(attrs) > 1 { | |
26 | if str, ok := attrs[0].(string); ok { | |
27 | result = map[string]interface{}{str: attrs[1]} | |
28 | } | |
29 | } else if len(attrs) == 1 { | |
30 | if attr, ok := attrs[0].(map[string]interface{}); ok { | |
31 | result = attr | |
32 | } | |
33 | ||
34 | if attr, ok := attrs[0].(interface{}); ok { | |
35 | result = attr | |
36 | } | |
37 | } | |
38 | return | |
39 | } | |
40 | ||
41 | func convertInterfaceToMap(values interface{}) map[string]interface{} { | |
42 | attrs := map[string]interface{}{} | |
43 | ||
44 | switch value := values.(type) { | |
45 | case map[string]interface{}: | |
46 | for k, v := range value { | |
47 | attrs[ToDBName(k)] = v | |
48 | } | |
49 | case []interface{}: | |
50 | for _, v := range value { | |
51 | for key, value := range convertInterfaceToMap(v) { | |
52 | attrs[key] = value | |
53 | } | |
54 | } | |
55 | case interface{}: | |
56 | reflectValue := reflect.ValueOf(values) | |
57 | ||
58 | switch reflectValue.Kind() { | |
59 | case reflect.Map: | |
60 | for _, key := range reflectValue.MapKeys() { | |
61 | attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() | |
62 | } | |
63 | default: | |
64 | scope := Scope{Value: values} | |
65 | for _, field := range scope.Fields() { | |
66 | if !field.IsBlank && !field.IsIgnored { | |
67 | attrs[field.DBName] = field.Field.Interface() | |
68 | } | |
69 | } | |
70 | } | |
71 | } | |
72 | return attrs | |
73 | } | |
74 | ||
75 | func toString(str interface{}) string { | |
76 | if values, ok := str.([]interface{}); ok { | |
77 | var results []string | |
78 | for _, value := range values { | |
79 | results = append(results, toString(value)) | |
80 | } | |
81 | return strings.Join(results, "_") | |
82 | } else if bytes, ok := str.([]byte); ok { | |
83 | return string(bytes) | |
84 | } else { | |
85 | return fmt.Sprintf("%v", str) | |
86 | } | |
87 | } | |
88 | ||
89 | func strInSlice(a string, list []string) bool { | |
90 | for _, b := range list { | |
91 | if b == a { | |
92 | return true | |
93 | } | |
94 | } | |
95 | return false | |
96 | } |
0 | package gorm_test | |
1 | ||
2 | import ( | |
3 | "testing" | |
4 | ||
5 | "github.com/jinzhu/gorm" | |
6 | ) | |
7 | ||
8 | func TestToDBNameGenerateFriendlyName(t *testing.T) { | |
9 | var maps = map[string]string{ | |
10 | "": "", | |
11 | "X": "x", | |
12 | "ThisIsATest": "this_is_a_test", | |
13 | "PFAndESI": "pf_and_esi", | |
14 | "AbcAndJkl": "abc_and_jkl", | |
15 | "EmployeeID": "employee_id", | |
16 | "SKU_ID": "sku_id", | |
17 | "FieldX": "field_x", | |
18 | "HTTPAndSMTP": "http_and_smtp", | |
19 | "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", | |
20 | "UUID": "uuid", | |
21 | "HTTPURL": "http_url", | |
22 | "HTTP_URL": "http_url", | |
23 | "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", | |
24 | } | |
25 | ||
26 | for key, value := range maps { | |
27 | if gorm.ToDBName(key) != value { | |
28 | t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key)) | |
29 | } | |
30 | } | |
31 | } |
0 | # use the default golang container from Docker Hub | |
1 | box: golang | |
2 | ||
3 | services: | |
4 | - name: mariadb | |
5 | id: mariadb:latest | |
6 | env: | |
7 | MYSQL_DATABASE: gorm | |
8 | MYSQL_USER: gorm | |
9 | MYSQL_PASSWORD: gorm | |
10 | MYSQL_RANDOM_ROOT_PASSWORD: "yes" | |
11 | - name: mysql | |
12 | id: mysql:8 | |
13 | env: | |
14 | MYSQL_DATABASE: gorm | |
15 | MYSQL_USER: gorm | |
16 | MYSQL_PASSWORD: gorm | |
17 | MYSQL_RANDOM_ROOT_PASSWORD: "yes" | |
18 | - name: mysql57 | |
19 | id: mysql:5.7 | |
20 | env: | |
21 | MYSQL_DATABASE: gorm | |
22 | MYSQL_USER: gorm | |
23 | MYSQL_PASSWORD: gorm | |
24 | MYSQL_RANDOM_ROOT_PASSWORD: "yes" | |
25 | - name: mysql56 | |
26 | id: mysql:5.6 | |
27 | env: | |
28 | MYSQL_DATABASE: gorm | |
29 | MYSQL_USER: gorm | |
30 | MYSQL_PASSWORD: gorm | |
31 | MYSQL_RANDOM_ROOT_PASSWORD: "yes" | |
32 | - name: mysql55 | |
33 | id: mysql:5.5 | |
34 | env: | |
35 | MYSQL_DATABASE: gorm | |
36 | MYSQL_USER: gorm | |
37 | MYSQL_PASSWORD: gorm | |
38 | MYSQL_RANDOM_ROOT_PASSWORD: "yes" | |
39 | - name: postgres | |
40 | id: postgres:latest | |
41 | env: | |
42 | POSTGRES_USER: gorm | |
43 | POSTGRES_PASSWORD: gorm | |
44 | POSTGRES_DB: gorm | |
45 | - name: postgres96 | |
46 | id: postgres:9.6 | |
47 | env: | |
48 | POSTGRES_USER: gorm | |
49 | POSTGRES_PASSWORD: gorm | |
50 | POSTGRES_DB: gorm | |
51 | - name: postgres95 | |
52 | id: postgres:9.5 | |
53 | env: | |
54 | POSTGRES_USER: gorm | |
55 | POSTGRES_PASSWORD: gorm | |
56 | POSTGRES_DB: gorm | |
57 | - name: postgres94 | |
58 | id: postgres:9.4 | |
59 | env: | |
60 | POSTGRES_USER: gorm | |
61 | POSTGRES_PASSWORD: gorm | |
62 | POSTGRES_DB: gorm | |
63 | - name: postgres93 | |
64 | id: postgres:9.3 | |
65 | env: | |
66 | POSTGRES_USER: gorm | |
67 | POSTGRES_PASSWORD: gorm | |
68 | POSTGRES_DB: gorm | |
69 | - name: mssql | |
70 | id: mcmoe/mssqldocker:latest | |
71 | env: | |
72 | ACCEPT_EULA: Y | |
73 | SA_PASSWORD: LoremIpsum86 | |
74 | MSSQL_DB: gorm | |
75 | MSSQL_USER: gorm | |
76 | MSSQL_PASSWORD: LoremIpsum86 | |
77 | ||
78 | # The steps that will be executed in the build pipeline | |
79 | build: | |
80 | # The steps that will be executed on build | |
81 | steps: | |
82 | # Sets the go workspace and places you package | |
83 | # at the right place in the workspace tree | |
84 | - setup-go-workspace | |
85 | ||
86 | # Gets the dependencies | |
87 | - script: | |
88 | name: go get | |
89 | code: | | |
90 | cd $WERCKER_SOURCE_DIR | |
91 | go version | |
92 | go get -t ./... | |
93 | ||
94 | # Build the project | |
95 | - script: | |
96 | name: go build | |
97 | code: | | |
98 | go build ./... | |
99 | ||
100 | # Test the project | |
101 | - script: | |
102 | name: test sqlite | |
103 | code: | | |
104 | go test ./... | |
105 | ||
106 | - script: | |
107 | name: test mariadb | |
108 | code: | | |
109 | GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./... | |
110 | ||
111 | - script: | |
112 | name: test mysql | |
113 | code: | | |
114 | GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test ./... | |
115 | ||
116 | - script: | |
117 | name: test mysql5.7 | |
118 | code: | | |
119 | GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test ./... | |
120 | ||
121 | - script: | |
122 | name: test mysql5.6 | |
123 | code: | | |
124 | GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test ./... | |
125 | ||
126 | - script: | |
127 | name: test mysql5.5 | |
128 | code: | | |
129 | GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test ./... | |
130 | ||
131 | - script: | |
132 | name: test postgres | |
133 | code: | | |
134 | GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... | |
135 | ||
136 | - script: | |
137 | name: test postgres96 | |
138 | code: | | |
139 | GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... | |
140 | ||
141 | - script: | |
142 | name: test postgres95 | |
143 | code: | | |
144 | GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... | |
145 | ||
146 | - script: | |
147 | name: test postgres94 | |
148 | code: | | |
149 | GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... | |
150 | ||
151 | - script: | |
152 | name: test postgres93 | |
153 | code: | | |
154 | GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... | |
155 | ||
156 | - script: | |
157 | name: test mssql | |
158 | code: | | |
159 | GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test ./... |