diff --git a/callback.go b/callback.go index b020fe3..d305691 100644 --- a/callback.go +++ b/callback.go @@ -353,6 +353,20 @@ return nil } +func callbackRetGeneric(ctx *C.sqlite3_context, v reflect.Value) error { + if v.IsNil() { + C.sqlite3_result_null(ctx) + return nil + } + + cb, err := callbackRet(v.Elem().Type()) + if err != nil { + return err + } + + return cb(ctx, v.Elem()) +} + func callbackRet(typ reflect.Type) (callbackRetConverter, error) { switch typ.Kind() { case reflect.Interface: @@ -360,6 +374,11 @@ if typ.Implements(errorInterface) { return callbackRetNil, nil } + + if typ.NumMethod() == 0 { + return callbackRetGeneric, nil + } + fallthrough case reflect.Slice: if typ.Elem().Kind() != reflect.Uint8 { diff --git a/callback_test.go b/callback_test.go index 714ed60..b09122a 100644 --- a/callback_test.go +++ b/callback_test.go @@ -102,3 +102,15 @@ } } } + +func TestCallbackReturnAny(t *testing.T) { + udf := func() interface{} { + return 1 + } + + typ := reflect.TypeOf(udf) + _, err := callbackRet(typ.Out(0)) + if err != nil { + t.Errorf("Expected valid callback for any return type, got: %s", err) + } +} diff --git a/sqlite3_test.go b/sqlite3_test.go index c86aba4..9ee87e7 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1446,6 +1446,63 @@ if ret != test.sum { t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum) } + } +} + +type mode struct { + counts map[interface{}]int + top interface{} + topCount int +} + +func newMode() *mode { + return &mode{ + counts: map[interface{}]int{}, + } +} + +func (m *mode) Step(x interface{}) { + m.counts[x]++ + c := m.counts[x] + if c > m.topCount { + m.top = x + m.topCount = c + } +} + +func (m *mode) Done() interface{} { + return m.top +} + +func TestAggregatorRegistration_GenericReturn(t *testing.T) { + sql.Register("sqlite3_AggregatorRegistration_GenericReturn", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + return conn.RegisterAggregator("mode", newMode, true) + }, + }) + db, err := sql.Open("sqlite3_AggregatorRegistration_GenericReturn", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("create table foo (department integer, profits integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + _, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115), (2, 20)") + if err != nil { + t.Fatal("Failed to insert records:", err) + } + + var mode int + err = db.QueryRow("select mode(profits) from foo").Scan(&mode) + if err != nil { + t.Fatal("MODE query error:", err) + } + + if mode != 20 { + t.Fatal("Got incorrect mode. Wanted 20, got: ", mode) } }