Skip to content

Commit

Permalink
fix(mariadb): Fix restore when db name has special characters
Browse files Browse the repository at this point in the history
  • Loading branch information
gabe565 committed Nov 22, 2023
1 parent b721420 commit 72ae32d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
9 changes: 8 additions & 1 deletion internal/database/dialect/mariadb.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ func (MariaDB) DefaultUser() string {
return "root"
}

func (MariaDB) DropDatabaseQuery(database string) string {
func (db MariaDB) DropDatabaseQuery(database string) string {
database = db.quoteIdentifier(database)
return "set FOREIGN_KEY_CHECKS=0; create or replace database " + database + "; set FOREIGN_KEY_CHECKS=1; use " + database + ";"
}

Expand Down Expand Up @@ -160,3 +161,9 @@ func (db MariaDB) DumpExtension(format sqlformat.Format) string {
}
return ""
}

func (db MariaDB) quoteIdentifier(param string) string {
param = strings.ReplaceAll(param, "`", "``")
param = "`" + param + "`"
return param
}
22 changes: 21 additions & 1 deletion internal/database/dialect/mariadb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestMariaDB_DropDatabaseQuery(t *testing.T) {
args args
want string
}{
{"database", args{"database"}, "set FOREIGN_KEY_CHECKS=0; create or replace database database; set FOREIGN_KEY_CHECKS=1; use database;"},
{"database", args{"database"}, "set FOREIGN_KEY_CHECKS=0; create or replace database `database`; set FOREIGN_KEY_CHECKS=1; use `database`;"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -264,3 +264,23 @@ func TestMariaDB_FormatFromFilename(t *testing.T) {
})
}
}

func TestMariaDB_quoteIdentifier(t *testing.T) {
type args struct {
param string
}
tests := []struct {
name string
args args
want string
}{
{"simple", args{"table"}, "`table`"},
{"escaped", args{"T`able"}, "`T``able`"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := MariaDB{}
assert.Equal(t, tt.want, db.quoteIdentifier(tt.args.param))
})
}
}

0 comments on commit 72ae32d

Please sign in to comment.