Skip to content

Commit 1afb3d7

Browse files
authored
fix(otgorm): hooks (#236)
* fix(otgorm): hooks * fix: remove debug prints
1 parent 0c0ae20 commit 1afb3d7

File tree

2 files changed

+73
-14
lines changed

2 files changed

+73
-14
lines changed

otgorm/otgorm.go

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ func AddGormCallbacks(db *gorm.DB, tracer opentracing.Tracer) {
1818
registerCallbacks(db, "query", callbacks)
1919
registerCallbacks(db, "update", callbacks)
2020
registerCallbacks(db, "delete", callbacks)
21-
registerCallbacks(db, "row_query", callbacks)
21+
registerCallbacks(db, "row", callbacks)
22+
registerCallbacks(db, "raw", callbacks)
2223
}
2324

2425
type callbacks struct {
@@ -29,16 +30,18 @@ func newCallbacks(tracer opentracing.Tracer) *callbacks {
2930
return &callbacks{tracer}
3031
}
3132

32-
func (c *callbacks) beforeCreate(scope *gorm.DB) { c.before(scope) }
33-
func (c *callbacks) afterCreate(scope *gorm.DB) { c.after(scope, "INSERT") }
34-
func (c *callbacks) beforeQuery(scope *gorm.DB) { c.before(scope) }
35-
func (c *callbacks) afterQuery(scope *gorm.DB) { c.after(scope, "SELECT") }
36-
func (c *callbacks) beforeUpdate(scope *gorm.DB) { c.before(scope) }
37-
func (c *callbacks) afterUpdate(scope *gorm.DB) { c.after(scope, "UPDATE") }
38-
func (c *callbacks) beforeDelete(scope *gorm.DB) { c.before(scope) }
39-
func (c *callbacks) afterDelete(scope *gorm.DB) { c.after(scope, "DELETE") }
40-
func (c *callbacks) beforeRowQuery(scope *gorm.DB) { c.before(scope) }
41-
func (c *callbacks) afterRowQuery(scope *gorm.DB) { c.after(scope, "") }
33+
func (c *callbacks) beforeCreate(scope *gorm.DB) { c.before(scope) }
34+
func (c *callbacks) afterCreate(scope *gorm.DB) { c.after(scope, "INSERT") }
35+
func (c *callbacks) beforeQuery(scope *gorm.DB) { c.before(scope) }
36+
func (c *callbacks) afterQuery(scope *gorm.DB) { c.after(scope, "SELECT") }
37+
func (c *callbacks) beforeUpdate(scope *gorm.DB) { c.before(scope) }
38+
func (c *callbacks) afterUpdate(scope *gorm.DB) { c.after(scope, "UPDATE") }
39+
func (c *callbacks) beforeDelete(scope *gorm.DB) { c.before(scope) }
40+
func (c *callbacks) afterDelete(scope *gorm.DB) { c.after(scope, "DELETE") }
41+
func (c *callbacks) beforeRow(scope *gorm.DB) { c.before(scope) }
42+
func (c *callbacks) afterRow(scope *gorm.DB) { c.after(scope, "") }
43+
func (c *callbacks) beforeRaw(scope *gorm.DB) { c.before(scope) }
44+
func (c *callbacks) afterRaw(scope *gorm.DB) { c.after(scope, "") }
4245

4346
func (c *callbacks) before(db *gorm.DB) {
4447
span, newCtx := opentracing.StartSpanFromContextWithTracer(db.Statement.Context, c.tracer, "sql")
@@ -48,6 +51,7 @@ func (c *callbacks) before(db *gorm.DB) {
4851
}
4952

5053
func (c *callbacks) after(db *gorm.DB, operation string) {
54+
5155
spanInterface, ok := db.Get("span")
5256
if !ok {
5357
return
@@ -85,8 +89,11 @@ func registerCallbacks(db *gorm.DB, name string, c *callbacks) {
8589
case "delete":
8690
db.Callback().Delete().Before(gormCallbackName).Register(beforeName, c.beforeDelete)
8791
db.Callback().Delete().After(gormCallbackName).Register(afterName, c.afterDelete)
88-
case "row_query":
89-
db.Callback().Row().Before(gormCallbackName).Register(beforeName, c.beforeRowQuery)
90-
db.Callback().Row().After(gormCallbackName).Register(afterName, c.afterRowQuery)
92+
case "row":
93+
db.Callback().Row().Before(gormCallbackName).Register(beforeName, c.beforeRow)
94+
db.Callback().Row().After(gormCallbackName).Register(afterName, c.afterRow)
95+
case "raw":
96+
db.Callback().Raw().Before(gormCallbackName).Register(beforeName, c.beforeRaw)
97+
db.Callback().Raw().After(gormCallbackName).Register(afterName, c.afterRaw)
9198
}
9299
}

otgorm/otgorm_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,55 @@ func TestHook(t *testing.T) {
5050

5151
assert.True(t, interceptorCalled)
5252
}
53+
54+
func TestHook_raw(t *testing.T) {
55+
tracer := mocktracer.New()
56+
out, cleanup, _ := provideDBFactory(&providersOption{
57+
drivers: map[string]func(dsn string) gorm.Dialector{"sqlite": sqlite.Open},
58+
})(factoryIn{
59+
Conf: config.MapAdapter{
60+
"gorm": map[string]interface{}{
61+
"default": map[string]interface{}{
62+
"database": "sqlite",
63+
"dsn": ":memory:",
64+
},
65+
},
66+
},
67+
Logger: log.NewNopLogger(),
68+
Tracer: tracer,
69+
})
70+
defer cleanup()
71+
72+
factory := out.Factory
73+
74+
db, err := factory.Make("default")
75+
assert.NoError(t, err)
76+
77+
_, ctx := opentracing.StartSpanFromContextWithTracer(context.Background(), tracer, "test")
78+
79+
err = db.WithContext(ctx).Exec("CREATE TABLE test (id uint)").Error
80+
assert.NoError(t, err)
81+
82+
err = db.WithContext(ctx).Exec("INSERT INTO test (id) VALUES (1)").Error
83+
assert.NoError(t, err)
84+
85+
err = db.WithContext(ctx).Exec("INSERT INTO test (id) VALUES (2)").Error
86+
assert.NoError(t, err)
87+
88+
rows, err := db.WithContext(ctx).Raw("SELECT * FROM test").Rows()
89+
assert.NoError(t, err)
90+
91+
var models []mockModel
92+
for rows.Next() {
93+
var m mockModel
94+
err = db.WithContext(ctx).ScanRows(rows, &m)
95+
assert.NoError(t, err)
96+
models = append(models, m)
97+
}
98+
t.Log(models)
99+
100+
db.WithContext(ctx).Raw("SELECT * FROM test").Scan(&models)
101+
t.Log(models)
102+
103+
assert.Len(t, tracer.FinishedSpans(), 5)
104+
}

0 commit comments

Comments
 (0)