Skip to content

Commit 3ca308c

Browse files
Allow custom drivers (go-gorm#11)
Much like the `mysql` driver's `DriverName`, this allows you to specify a custom driver for SQLite. This is important when creating custom functions, for example.
1 parent 459b36b commit 3ca308c

File tree

2 files changed

+142
-2
lines changed

2 files changed

+142
-2
lines changed

sqlite.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@ import (
1414
"gorm.io/gorm/schema"
1515
)
1616

17+
// DriverName is the default driver name for SQLite.
18+
const DriverName = "sqlite3"
19+
1720
type Dialector struct {
18-
DSN string
21+
DriverName string
22+
DSN string
1923
}
2024

2125
func Open(dsn string) gorm.Dialector {
@@ -27,11 +31,15 @@ func (dialector Dialector) Name() string {
2731
}
2832

2933
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
34+
if dialector.DriverName == "" {
35+
dialector.DriverName = DriverName
36+
}
37+
3038
// register callbacks
3139
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
3240
LastInsertIDReversed: true,
3341
})
34-
db.ConnPool, err = sql.Open("sqlite3", dialector.DSN)
42+
db.ConnPool, err = sql.Open(dialector.DriverName, dialector.DSN)
3543

3644
for k, v := range dialector.ClauseBuilders() {
3745
db.ClauseBuilders[k] = v

sqlite_test.go

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package sqlite
2+
3+
import (
4+
"database/sql"
5+
"fmt"
6+
"testing"
7+
8+
"github.com/mattn/go-sqlite3"
9+
"gorm.io/gorm"
10+
)
11+
12+
func TestDialector(t *testing.T) {
13+
// This is the DSN of the in-memory SQLite database for these tests.
14+
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
15+
// This is the custom SQLite driver name.
16+
const CustomDriverName = "my_custom_driver"
17+
18+
// Register the custom SQlite3 driver.
19+
// It will have one custom function called "my_custom_function".
20+
sql.Register(CustomDriverName,
21+
&sqlite3.SQLiteDriver{
22+
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
23+
// Define the `concat` function, since we use this elsewhere.
24+
err := conn.RegisterFunc(
25+
"my_custom_function",
26+
func(arguments ...interface{}) (string, error) {
27+
return "my-result", nil // Return a string value.
28+
},
29+
true,
30+
)
31+
return err
32+
},
33+
},
34+
)
35+
36+
rows := []struct {
37+
description string
38+
dialector *Dialector
39+
openSuccess bool
40+
query string
41+
querySuccess bool
42+
}{
43+
{
44+
description: "Default driver",
45+
dialector: &Dialector{
46+
DSN: InMemoryDSN,
47+
},
48+
openSuccess: true,
49+
query: "SELECT 1",
50+
querySuccess: true,
51+
},
52+
{
53+
description: "Explicit default driver",
54+
dialector: &Dialector{
55+
DriverName: DriverName,
56+
DSN: InMemoryDSN,
57+
},
58+
openSuccess: true,
59+
query: "SELECT 1",
60+
querySuccess: true,
61+
},
62+
{
63+
description: "Bad driver",
64+
dialector: &Dialector{
65+
DriverName: "not-a-real-driver",
66+
DSN: InMemoryDSN,
67+
},
68+
openSuccess: false,
69+
},
70+
{
71+
description: "Explicit default driver, custom function",
72+
dialector: &Dialector{
73+
DriverName: DriverName,
74+
DSN: InMemoryDSN,
75+
},
76+
openSuccess: true,
77+
query: "SELECT my_custom_function()",
78+
querySuccess: false,
79+
},
80+
{
81+
description: "Custom driver",
82+
dialector: &Dialector{
83+
DriverName: CustomDriverName,
84+
DSN: InMemoryDSN,
85+
},
86+
openSuccess: true,
87+
query: "SELECT 1",
88+
querySuccess: true,
89+
},
90+
{
91+
description: "Custom driver, custom function",
92+
dialector: &Dialector{
93+
DriverName: CustomDriverName,
94+
DSN: InMemoryDSN,
95+
},
96+
openSuccess: true,
97+
query: "SELECT my_custom_function()",
98+
querySuccess: true,
99+
},
100+
}
101+
for rowIndex, row := range rows {
102+
t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) {
103+
db, err := gorm.Open(row.dialector, &gorm.Config{})
104+
if !row.openSuccess {
105+
if err == nil {
106+
t.Errorf("Expected Open to fail.")
107+
}
108+
return
109+
}
110+
111+
if err != nil {
112+
t.Errorf("Expected Open to succeed; got error: %v", err)
113+
}
114+
if db == nil {
115+
t.Errorf("Expected db to be non-nil.")
116+
}
117+
if row.query != "" {
118+
err = db.Exec(row.query).Error
119+
if !row.querySuccess {
120+
if err == nil {
121+
t.Errorf("Expected query to fail.")
122+
}
123+
return
124+
}
125+
126+
if err != nil {
127+
t.Errorf("Expected query to succeed; got error: %v", err)
128+
}
129+
}
130+
})
131+
}
132+
}

0 commit comments

Comments
 (0)