diff --git a/pkg/database/mysql/dump.go b/pkg/database/mysql/dump.go index 236883d8..fa6b64ea 100644 --- a/pkg/database/mysql/dump.go +++ b/pkg/database/mysql/dump.go @@ -19,10 +19,7 @@ import ( "context" "database/sql" "errors" - "fmt" "io" - "reflect" - "strings" "text/template" "time" ) @@ -49,21 +46,10 @@ type Data struct { tx *sql.Tx headerTmpl *template.Template - tableTmpl *template.Template footerTmpl *template.Template err error } -type table struct { - Name string - Err error - - cols []string - data *Data - rows *sql.Rows - values []interface{} -} - type metaData struct { DumpVersion string ServerVersion string @@ -122,41 +108,6 @@ const footerTmpl = `/*!40103 SET TIME_ZONE=@OLD_TIME_ZONE */; const footerTmplCompact = `` -// Takes a *table -const tableTmpl = ` --- --- Table structure for table {{ .NameEsc }} --- - -DROP TABLE IF EXISTS {{ .NameEsc }}; -/*!40101 SET @saved_cs_client = @@character_set_client */; -/*!50503 SET character_set_client = utf8mb4 */; -{{ .CreateSQL }}; -/*!40101 SET character_set_client = @saved_cs_client */; - --- --- Dumping data for table {{ .NameEsc }} --- - -LOCK TABLES {{ .NameEsc }} WRITE; -/*!40000 ALTER TABLE {{ .NameEsc }} DISABLE KEYS */; -{{ range $value := .Stream }} -{{- $value }} -{{ end -}} -/*!40000 ALTER TABLE {{ .NameEsc }} ENABLE KEYS */; -UNLOCK TABLES; -` - -const tableTmplCompact = ` -/*!40101 SET @saved_cs_client = @@character_set_client */; -/*!50503 SET character_set_client = utf8mb4 */; -{{ .CreateSQL }}; -/*!40101 SET character_set_client = @saved_cs_client */; -{{ range $value := .Stream }} -{{- $value }} -{{ end -}} -` - const nullType = "NULL" // Dump data using struct @@ -206,11 +157,11 @@ func (data *Data) Dump() error { if data.LockTables && len(tables) > 0 { var b bytes.Buffer b.WriteString("LOCK TABLES ") - for index, name := range tables { + for index, table := range tables { if index != 0 { b.WriteString(",") } - b.WriteString("`" + name + "` READ /*!32311 LOCAL */") + b.WriteString("`" + table.Name() + "` READ /*!32311 LOCAL */") } if _, err := data.Connection.Exec(b.String()); err != nil { @@ -263,19 +214,14 @@ func (data *Data) rollback() error { // MARK: writter methods -func (data *Data) dumpTable(name string) error { +func (data *Data) dumpTable(table Table) error { if data.err != nil { return data.err } - table := data.createTable(name) - return data.writeTable(table) -} - -func (data *Data) writeTable(table *table) error { - if err := data.tableTmpl.Execute(data.Out, table); err != nil { + if err := table.Init(); err != nil { return err } - return table.Err + return table.Execute(data.Out, data.Compact) } // MARK: get methods @@ -284,10 +230,8 @@ func (data *Data) writeTable(table *table) error { func (data *Data) getTemplates() (err error) { var hTmpl string fTmpl := footerTmpl - tTmpl := tableTmpl if data.Compact { fTmpl = footerTmplCompact - tTmpl = tableTmplCompact } else { hTmpl = headerTmpl } @@ -304,34 +248,42 @@ func (data *Data) getTemplates() (err error) { return } - data.tableTmpl, err = template.New("mysqldumpTable").Parse(tTmpl) - if err != nil { - return - } - - data.footerTmpl, err = template.New("mysqldumpTable").Parse(fTmpl) + data.footerTmpl, err = template.New("mysqldumpFooter").Parse(fTmpl) if err != nil { return } return } -func (data *Data) getTables() ([]string, error) { - tables := make([]string, 0) +func (data *Data) getTables() ([]Table, error) { + tables := make([]Table, 0) - rows, err := data.tx.Query("SHOW TABLES") + rows, err := data.tx.Query("SHOW FULL TABLES") if err != nil { - return tables, err + return nil, err } defer rows.Close() for rows.Next() { - var table sql.NullString - if err := rows.Scan(&table); err != nil { - return tables, err + var tableName, tableType sql.NullString + if err := rows.Scan(&tableName, &tableType); err != nil { + return nil, err + } + if !tableName.Valid || data.isIgnoredTable(tableName.String) { + continue } - if table.Valid && !data.isIgnoredTable(table.String) { - tables = append(tables, table.String) + table := baseTable{ + name: tableName.String, + data: data, + database: data.Schema, + } + switch tableType.String { + case "VIEW": + tables = append(tables, &view{baseTable: table}) + case "BASE TABLE": + tables = append(tables, &table) + default: + return nil, errors.New("unknown table type: " + tableType.String) } } return tables, rows.Err() @@ -353,266 +305,10 @@ func (meta *metaData) updateServerVersion(data *Data) (err error) { return } -// MARK: create methods - -func (data *Data) createTable(name string) *table { - return &table{ - Name: name, - data: data, - } -} - -func (table *table) NameEsc() string { - return "`" + table.Name + "`" -} - -func (table *table) CreateSQL() (string, error) { - var tableReturn, tableSQL sql.NullString - if err := table.data.tx.QueryRow("SHOW CREATE TABLE "+table.NameEsc()).Scan(&tableReturn, &tableSQL); err != nil { - return "", err - } - - if tableReturn.String != table.Name { - return "", errors.New("Returned table is not the same as requested table") - } - - return tableSQL.String, nil -} - -func (table *table) initColumnData() error { - colInfo, err := table.data.tx.Query("SHOW COLUMNS FROM " + table.NameEsc()) - if err != nil { - return err - } - defer colInfo.Close() - - cols, err := colInfo.Columns() - if err != nil { - return err - } - - fieldIndex, extraIndex := -1, -1 - for i, col := range cols { - switch col { - case "Field", "field": - fieldIndex = i - case "Extra", "extra": - extraIndex = i - } - if fieldIndex >= 0 && extraIndex >= 0 { - break - } - } - if fieldIndex < 0 || extraIndex < 0 { - return errors.New("database column information is malformed") - } - - info := make([]sql.NullString, len(cols)) - scans := make([]interface{}, len(cols)) - for i := range info { - scans[i] = &info[i] - } - - var result []string - for colInfo.Next() { - // Read into the pointers to the info marker - if err := colInfo.Scan(scans...); err != nil { - return err - } - - // Ignore the virtual columns and generated columns - // if there is an Extra column and it is a valid string, then only include this column if - // the column is not marked as VIRTUAL or GENERATED - if !info[extraIndex].Valid || (!strings.Contains(info[extraIndex].String, "VIRTUAL") && !strings.Contains(info[extraIndex].String, "GENERATED")) { - result = append(result, info[fieldIndex].String) - } - } - table.cols = result - return nil -} - -func (table *table) columnsList() string { - return "`" + strings.Join(table.cols, "`, `") + "`" -} - -func (table *table) Init() error { - if len(table.values) != 0 { - return errors.New("can't init twice") - } - - if err := table.initColumnData(); err != nil { - return err - } - - if len(table.cols) == 0 { - // No data to dump since this is a virtual table - return nil - } - - var err error - table.rows, err = table.data.tx.Query("SELECT " + table.columnsList() + " FROM " + table.NameEsc()) - if err != nil { - return err - } - - tt, err := table.rows.ColumnTypes() - if err != nil { - return err - } - - table.values = make([]interface{}, len(tt)) - for i, tp := range tt { - table.values[i] = reflect.New(reflectColumnType(tp)).Interface() - } - return nil -} - -func reflectColumnType(tp *sql.ColumnType) reflect.Type { - // reflect for scanable - switch tp.ScanType().Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return reflect.TypeOf(sql.NullInt64{}) - case reflect.Float32, reflect.Float64: - return reflect.TypeOf(sql.NullFloat64{}) - case reflect.String: - return reflect.TypeOf(sql.NullString{}) - } - - // determine by name - switch tp.DatabaseTypeName() { - case "BLOB", "BINARY": - return reflect.TypeOf(sql.RawBytes{}) - case "VARCHAR", "TEXT", "DECIMAL": - return reflect.TypeOf(sql.NullString{}) - case "BIGINT", "TINYINT", "INT": - return reflect.TypeOf(sql.NullInt64{}) - case "DOUBLE": - return reflect.TypeOf(sql.NullFloat64{}) - case "TIMESTAMP", "DATETIME": - return reflect.TypeOf(sql.NullTime{}) - case "DATE": - return reflect.TypeOf(NullDate{}) - case "TIME": - return reflect.TypeOf(sql.NullString{}) - } - - // unknown datatype - return tp.ScanType() +func sub(a, b int) int { + return a - b } -func (table *table) Next() bool { - if table.rows == nil { - if err := table.Init(); err != nil { - table.Err = err - return false - } - } - // Fallthrough - if table.rows.Next() { - if err := table.rows.Scan(table.values...); err != nil { - table.Err = err - return false - } else if err := table.rows.Err(); err != nil { - table.Err = err - return false - } - } else { - table.rows.Close() - table.rows = nil - return false - } - return true -} - -func (table *table) RowValues() string { - return table.RowBuffer().String() -} - -func (table *table) RowBuffer() *bytes.Buffer { - var b bytes.Buffer - b.WriteString("(") - - for key, value := range table.values { - if key != 0 { - b.WriteString(",") - } - switch s := value.(type) { - case nil: - b.WriteString(nullType) - case *sql.NullString: - if s.Valid { - fmt.Fprintf(&b, "'%s'", sanitize(s.String)) - } else { - b.WriteString(nullType) - } - case *sql.NullInt64: - if s.Valid { - fmt.Fprintf(&b, "%d", s.Int64) - } else { - b.WriteString(nullType) - } - case *sql.NullFloat64: - if s.Valid { - fmt.Fprintf(&b, "%f", s.Float64) - } else { - b.WriteString(nullType) - } - case *sql.RawBytes: - if len(*s) == 0 { - b.WriteString(nullType) - } else { - fmt.Fprintf(&b, "_binary '%s'", sanitize(string(*s))) - } - case *NullDate: - if s.Valid { - fmt.Fprintf(&b, "'%s'", sanitize(s.Date.Format("2006-01-02"))) - } else { - b.WriteString(nullType) - } - case *sql.NullTime: - if s.Valid { - fmt.Fprintf(&b, "'%s'", sanitize(s.Time.Format("2006-01-02 15:04:05"))) - } else { - b.WriteString(nullType) - } - default: - fmt.Fprintf(&b, "'%s'", value) - } - } - b.WriteString(")") - - return &b -} - -func (table *table) Stream() <-chan string { - valueOut := make(chan string, 1) - go func() { - defer close(valueOut) - var insert bytes.Buffer - - for table.Next() { - b := table.RowBuffer() - // Truncate our insert if it won't fit - if insert.Len() != 0 && insert.Len()+b.Len() > table.data.MaxAllowedPacket-1 { - _, _ = insert.WriteString(";") - valueOut <- insert.String() - insert.Reset() - } - - if insert.Len() == 0 { - _, _ = fmt.Fprint(&insert, strings.Join( - // extra "" at the end so we get an extra whitespace as needed - []string{"INSERT", "INTO", table.NameEsc(), "(" + table.columnsList() + ")", "VALUES", ""}, - " ")) - } else { - _, _ = insert.WriteString(",") - } - _, _ = b.WriteTo(&insert) - } - if insert.Len() != 0 { - _, _ = insert.WriteString(";") - valueOut <- insert.String() - } - }() - return valueOut +func esc(in string) string { + return "`" + in + "`" } diff --git a/pkg/database/mysql/table.go b/pkg/database/mysql/table.go new file mode 100644 index 00000000..de3a4ff6 --- /dev/null +++ b/pkg/database/mysql/table.go @@ -0,0 +1,366 @@ +package mysql + +import ( + "bytes" + "database/sql" + "errors" + "fmt" + "io" + "reflect" + "strings" + "text/template" +) + +var tableFullTemplate, tableCompactTemplate *template.Template + +func init() { + tmpl, err := template.New("mysqldumpTable").Funcs(template.FuncMap{ + "sub": sub, + "esc": esc, + }).Parse(tableTmpl) + if err != nil { + panic(fmt.Errorf("could not parse table template: %w", err)) + } + tableFullTemplate = tmpl + + tmpl, err = template.New("mysqldumpTableCompact").Funcs(template.FuncMap{ + "sub": sub, + "esc": esc, + }).Parse(tableTmplCompact) + if err != nil { + panic(fmt.Errorf("could not parse table compact template: %w", err)) + } + tableCompactTemplate = tmpl +} + +type Table interface { + Name() string + Err() error + Database() string + Columns() []string + Init() error + Start() error + Next() bool + RowValues() string + RowBuffer() *bytes.Buffer + Execute(io.Writer, bool) error + Stream() <-chan string +} + +type baseTable struct { + name string + err error + + cols []string + data *Data + rows *sql.Rows + database string + values []interface{} +} + +func (table *baseTable) Name() string { + return table.name +} + +func (table *baseTable) Err() error { + return table.err +} + +func (table *baseTable) Columns() []string { + return table.cols +} +func (table *baseTable) Database() string { + return table.database +} + +func (table *baseTable) CreateSQL() ([]string, error) { + var tableReturn, tableSQL sql.NullString + if err := table.data.tx.QueryRow("SHOW CREATE TABLE "+esc(table.Name())).Scan(&tableReturn, &tableSQL); err != nil { + return nil, err + } + + if tableReturn.String != table.name { + return nil, errors.New("returned table is not the same as requested table") + } + + return []string{strings.TrimSpace(tableSQL.String)}, nil +} + +func (table *baseTable) initColumnData() error { + colInfo, err := table.data.tx.Query("SHOW COLUMNS FROM " + esc(table.Name())) + if err != nil { + return err + } + defer colInfo.Close() + + cols, err := colInfo.Columns() + if err != nil { + return err + } + + fieldIndex, extraIndex := -1, -1 + for i, col := range cols { + switch col { + case "Field", "field": + fieldIndex = i + case "Extra", "extra": + extraIndex = i + } + if fieldIndex >= 0 && extraIndex >= 0 { + break + } + } + if fieldIndex < 0 || extraIndex < 0 { + return errors.New("database column information is malformed") + } + + info := make([]sql.NullString, len(cols)) + scans := make([]interface{}, len(cols)) + for i := range info { + scans[i] = &info[i] + } + + var result []string + for colInfo.Next() { + // Read into the pointers to the info marker + if err := colInfo.Scan(scans...); err != nil { + return err + } + + // Ignore the virtual columns and generated columns + // if there is an Extra column and it is a valid string, then only include this column if + // the column is not marked as VIRTUAL or GENERATED + if !info[extraIndex].Valid || (!strings.Contains(info[extraIndex].String, "VIRTUAL") && !strings.Contains(info[extraIndex].String, "GENERATED")) { + result = append(result, info[fieldIndex].String) + } + } + table.cols = result + return nil +} + +func (table *baseTable) columnsList() string { + return "`" + strings.Join(table.cols, "`, `") + "`" +} + +func (table *baseTable) Init() error { + return table.initColumnData() +} + +func (table *baseTable) Start() error { + if table.rows != nil { + return errors.New("can't start twice") + } + + if len(table.cols) == 0 { + // No data to dump since this is a virtual table + return nil + } + + var err error + table.rows, err = table.data.tx.Query("SELECT " + table.columnsList() + " FROM " + esc(table.Name())) + if err != nil { + return err + } + tt, err := table.rows.ColumnTypes() + if err != nil { + return err + } + + table.values = make([]interface{}, len(tt)) + for i, tp := range tt { + table.values[i] = reflect.New(reflectColumnType(tp)).Interface() + } + + return nil +} + +func (table *baseTable) Next() bool { + if table.rows == nil { + if err := table.Start(); err != nil { + table.err = err + return false + } + } + // Fallthrough + if table.rows.Next() { + if err := table.rows.Scan(table.values...); err != nil { + table.err = err + return false + } else if err := table.rows.Err(); err != nil { + table.err = err + return false + } + } else { + table.rows.Close() + table.rows = nil + return false + } + return true +} + +func (table *baseTable) RowValues() string { + return table.RowBuffer().String() +} + +func (table *baseTable) RowBuffer() *bytes.Buffer { + var b bytes.Buffer + b.WriteString("(") + + for key, value := range table.values { + if key != 0 { + b.WriteString(",") + } + switch s := value.(type) { + case nil: + b.WriteString(nullType) + case *sql.NullString: + if s.Valid { + fmt.Fprintf(&b, "'%s'", sanitize(s.String)) + } else { + b.WriteString(nullType) + } + case *sql.NullInt64: + if s.Valid { + fmt.Fprintf(&b, "%d", s.Int64) + } else { + b.WriteString(nullType) + } + case *sql.NullFloat64: + if s.Valid { + fmt.Fprintf(&b, "%f", s.Float64) + } else { + b.WriteString(nullType) + } + case *sql.RawBytes: + if len(*s) == 0 { + b.WriteString(nullType) + } else { + fmt.Fprintf(&b, "_binary '%s'", sanitize(string(*s))) + } + case *NullDate: + if s.Valid { + fmt.Fprintf(&b, "'%s'", sanitize(s.Date.Format("2006-01-02"))) + } else { + b.WriteString(nullType) + } + case *sql.NullTime: + if s.Valid { + fmt.Fprintf(&b, "'%s'", sanitize(s.Time.Format("2006-01-02 15:04:05"))) + } else { + b.WriteString(nullType) + } + default: + fmt.Fprintf(&b, "'%s'", value) + } + } + b.WriteString(")") + + return &b +} + +func (table *baseTable) Stream() <-chan string { + valueOut := make(chan string, 1) + go func() { + defer close(valueOut) + var insert bytes.Buffer + + for table.Next() { + b := table.RowBuffer() + // Truncate our insert if it won't fit + if insert.Len() != 0 && insert.Len()+b.Len() > table.data.MaxAllowedPacket-1 { + _, _ = insert.WriteString(";") + valueOut <- insert.String() + insert.Reset() + } + + if insert.Len() == 0 { + _, _ = fmt.Fprint(&insert, strings.Join( + // extra "" at the end so we get an extra whitespace as needed + []string{"INSERT", "INTO", esc(table.Name()), "(" + table.columnsList() + ")", "VALUES", ""}, + " ")) + } else { + _, _ = insert.WriteString(",") + } + _, _ = b.WriteTo(&insert) + } + if insert.Len() != 0 { + _, _ = insert.WriteString(";") + valueOut <- insert.String() + } + }() + return valueOut +} + +func (table *baseTable) Execute(out io.Writer, compact bool) error { + tmpl := tableFullTemplate + if compact { + tmpl = tableCompactTemplate + } + return tmpl.Execute(out, table) +} + +func reflectColumnType(tp *sql.ColumnType) reflect.Type { + // reflect for scanable + switch tp.ScanType().Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return reflect.TypeOf(sql.NullInt64{}) + case reflect.Float32, reflect.Float64: + return reflect.TypeOf(sql.NullFloat64{}) + case reflect.String: + return reflect.TypeOf(sql.NullString{}) + } + + // determine by name + switch tp.DatabaseTypeName() { + case "BLOB", "BINARY": + return reflect.TypeOf(sql.RawBytes{}) + case "VARCHAR", "TEXT", "DECIMAL": + return reflect.TypeOf(sql.NullString{}) + case "BIGINT", "TINYINT", "INT": + return reflect.TypeOf(sql.NullInt64{}) + case "DOUBLE": + return reflect.TypeOf(sql.NullFloat64{}) + case "TIMESTAMP", "DATETIME": + return reflect.TypeOf(sql.NullTime{}) + case "DATE": + return reflect.TypeOf(NullDate{}) + case "TIME": + return reflect.TypeOf(sql.NullString{}) + } + + // unknown datatype + return tp.ScanType() +} + +// Takes a Table, but is a baseTable +const tableTmpl = ` +-- +-- Table structure for table {{ esc .Name }} +-- + +DROP TABLE IF EXISTS {{ esc .Name }}; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!50503 SET character_set_client = utf8mb4 */; +{{ index .CreateSQL 0 }}; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table {{ esc .Name }} +-- + +LOCK TABLES {{ esc .Name }} WRITE; +/*!40000 ALTER TABLE {{ esc .Name }} DISABLE KEYS */; +{{ range $value := .Stream }} +{{- $value }} +{{ end -}} +/*!40000 ALTER TABLE {{ esc .Name }} ENABLE KEYS */; +UNLOCK TABLES; +` + +const tableTmplCompact = ` +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!50503 SET character_set_client = utf8mb4 */; +{{ index .CreateSQL 0 }}; +/*!40101 SET character_set_client = @saved_cs_client */; +{{ range $value := .Stream }}{{- $value }}{{ end -}} +` diff --git a/pkg/database/mysql/view.go b/pkg/database/mysql/view.go new file mode 100644 index 00000000..cda1b5c4 --- /dev/null +++ b/pkg/database/mysql/view.go @@ -0,0 +1,170 @@ +package mysql + +import ( + "database/sql" + "errors" + "fmt" + "io" + "strings" + "text/template" +) + +type view struct { + baseTable + charset string + collation string +} + +var viewFullTemplate, viewCompactTemplate *template.Template + +func init() { + tmpl, err := template.New("mysqldumpView").Funcs(template.FuncMap{ + "sub": sub, + "esc": esc, + }).Parse(viewTmpl) + if err != nil { + panic(fmt.Errorf("could not parse view template: %w", err)) + } + viewFullTemplate = tmpl + + tmpl, err = template.New("mysqldumpViewCompact").Funcs(template.FuncMap{ + "sub": sub, + "esc": esc, + }).Parse(viewTmplCompact) + if err != nil { + panic(fmt.Errorf("could not parse view compact template: %w", err)) + } + viewCompactTemplate = tmpl +} + +func (v *view) CreateSQL() ([]string, error) { + var tableReturn, tableSQL, charSetClient, collationConnection sql.NullString + if err := v.data.tx.QueryRow("SHOW CREATE VIEW "+esc(v.Name())).Scan(&tableReturn, &tableSQL, &charSetClient, &collationConnection); err != nil { + return nil, err + } + + if tableReturn.String != v.Name() { + return nil, errors.New("returned view is not the same as requested view") + } + + // this comes in one string, which we need to break down into 3 parts for the template + // CREATE ALGORITHM=UNDEFINED DEFINER=`testadmin`@`%` SQL SECURITY DEFINER VIEW `view1` AS select `t1`.`id` AS `id`,`t1`.`name` AS `name` from `t1` + // becomes: + // CREATE ALGORITHM=UNDEFINED + // DEFINER=`testadmin`@`%` SQL SECURITY DEFINER + // VIEW `view1` AS select `t1`.`id` AS `id`,`t1`.`name` AS `name` from `t1` + in := tableSQL.String + indexDefiner := strings.Index(in, "DEFINER") + indexView := strings.Index(in, "VIEW") + + parts := make([]string, 3) + parts[0] = strings.TrimSpace(in[:indexDefiner]) + parts[1] = strings.TrimSpace(in[indexDefiner:indexView]) + parts[2] = strings.TrimSpace(in[indexView:]) + + v.charset = charSetClient.String + v.collation = collationConnection.String + + return parts, nil +} + +// SELECT TABLE_NAME,CHARACTER_SET_CLIENT,COLLATION_CONNECTION FROM INFORMATION_SCHEMA.VIEWS; +func (v *view) Init() error { + if err := v.initColumnData(); err != nil { + return fmt.Errorf("failed to initialize column data for view %s: %w", v.name, err) + } + var tableName, charSetClient, collationConnection sql.NullString + + if err := v.data.tx.QueryRow("SELECT TABLE_NAME,CHARACTER_SET_CLIENT,COLLATION_CONNECTION FROM INFORMATION_SCHEMA.VIEWS WHERE table_name = '"+v.name+"'").Scan(&tableName, &charSetClient, &collationConnection); err != nil { + return fmt.Errorf("failed to get view information schema for view %s: %w", v.name, err) + } + if tableName.String != v.name { + return fmt.Errorf("returned view name %s is not the same as requested view %s", tableName.String, v.name) + } + if !charSetClient.Valid { + return fmt.Errorf("returned charset is not valid for view %s", v.name) + } + if !collationConnection.Valid { + return fmt.Errorf("returned collation is not valid for view %s", v.name) + } + v.charset = charSetClient.String + v.collation = collationConnection.String + return nil +} + +func (v *view) Execute(out io.Writer, compact bool) error { + tmpl := viewFullTemplate + if compact { + tmpl = viewCompactTemplate + } + return tmpl.Execute(out, v) +} + +func (v *view) Charset() string { + return v.charset +} + +func (v *view) Collation() string { + return v.collation +} + +// takes a Table, but is a view +const viewTmpl = ` +-- +-- Temporary view structure for view {{ esc .Name }} +-- + +DROP TABLE IF EXISTS {{ esc .Name }}; +/*!50001 DROP VIEW IF EXISTS {{ esc .Name }}*/; +SET @saved_cs_client = @@character_set_client; +/*!50503 SET character_set_client = utf8mb4 */; +/*!50001 CREATE VIEW {{ esc .Name }} AS SELECT +{{ $columns := .Columns }}{{ range $index, $column := .Columns }} 1 AS {{ esc $column }}{{ if ne $index (sub (len $columns) 1) }},{{ printf "%c" 10 }}{{ else }}*/;{{ end }}{{ end }} +SET character_set_client = @saved_cs_client; + +-- +-- Current Database: {{ esc .Database }} +-- + +USE {{ esc .Database }}; + +-- +-- Final view structure for view {{ esc .Name }} +-- + +/*!50001 DROP VIEW IF EXISTS {{ esc .Name }}*/; +/*!50001 SET @saved_cs_client = @@character_set_client */; +/*!50001 SET @saved_cs_results = @@character_set_results */; +/*!50001 SET @saved_col_connection = @@collation_connection */; +/*!50001 SET character_set_client = {{ .Charset }} */; +/*!50001 SET character_set_results = {{ .Charset }} */; +/*!50001 SET collation_connection = {{ .Collation }} */; +/*!50001 {{ $sql := .CreateSQL }}{{ index $sql 0 }} */ +/*!50013 {{ index $sql 1 }} */ +/*!50001 {{ index $sql 2 }} */; +/*!50001 SET character_set_client = @saved_cs_client */; +/*!50001 SET character_set_results = @saved_cs_results */; +/*!50001 SET collation_connection = @saved_col_connection */; +` +const viewTmplCompact = ` +SET @saved_cs_client = @@character_set_client; +/*!50503 SET character_set_client = utf8mb4 */; +/*!50001 CREATE VIEW {{ esc .Name }} AS SELECT +{{ $columns := .Columns }}{{ range $index, $column := .Columns }} 1 AS {{ esc $column }}{{ if ne $index (sub (len $columns) 1) }},{{ printf "%c" 10 }}{{ else }}*/;{{ end }}{{ end }} +SET character_set_client = @saved_cs_client; + +USE {{ esc .Database }}; +/*!50001 DROP VIEW IF EXISTS {{ esc .Name }}*/; +/*!50001 SET @saved_cs_client = @@character_set_client */; +/*!50001 SET @saved_cs_results = @@character_set_results */; +/*!50001 SET @saved_col_connection = @@collation_connection */; +/*!50001 SET character_set_client = {{ .Charset }} */; +/*!50001 SET character_set_results = {{ .Charset }} */; +/*!50001 SET collation_connection = {{ .Collation }} */; +/*!50001 {{ $sql := .CreateSQL }}{{ index $sql 0 }} */ +/*!50013 {{ index $sql 1 }} */ +/*!50001 {{ index $sql 2 }} */; +/*!50001 SET character_set_client = @saved_cs_client */; +/*!50001 SET character_set_results = @saved_cs_results */; +/*!50001 SET collation_connection = @saved_col_connection */; +` diff --git a/test/backup_test.go b/test/backup_test.go index 6b771f6c..956f7482 100644 --- a/test/backup_test.go +++ b/test/backup_test.go @@ -313,6 +313,7 @@ func (d *dockerContext) createBackupFile(mysqlCID, mysqlUser, mysqlPass, outfile (2, "Jill", "2012-11-02", "00:16:00", "2012-11-02 00:16:00", "2012-11-02 00:16:00"), (3, "Sam", "2012-11-03", "00:17:00", "2012-11-03 00:17:00", "2012-11-03 00:17:00"), (4, "Sarah", "2012-11-04", "00:18:00", "2012-11-04 00:18:00", "2012-11-04 00:18:00"); + create view view1 as select id, name from t1; `} attachResp, exitCode, err := d.execInContainer(ctx, mysqlCID, mysqlCreateCmd) if err != nil {