From 6f7b1fcb349960703752abc56e9f8a6d4e94469b Mon Sep 17 00:00:00 2001 From: Radu Sora <254848+radusora@users.noreply.github.com> Date: Mon, 20 Jan 2025 10:48:33 +0000 Subject: [PATCH] feat: add support for databases enforcing strict data integrity through PKs. Fixes #13611 Signed-off-by: Radu Sora <254848+radusora@users.noreply.github.com> --- persist/sqldb/migrate.go | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/persist/sqldb/migrate.go b/persist/sqldb/migrate.go index 336c33e98304..0274d2bcda31 100644 --- a/persist/sqldb/migrate.go +++ b/persist/sqldb/migrate.go @@ -40,12 +40,43 @@ func ternary(condition bool, left, right change) change { } func (m migrate) Exec(ctx context.Context) (err error) { + + dbType := dbTypeFor(m.session) + { // poor mans SQL migration - _, err = m.session.SQL().Exec("create table if not exists schema_history(schema_version int not null)") + _, err = m.session.SQL().Exec("create table if not exists schema_history(schema_version int not null, primary key(schema_version))") + if err != nil { + return err + } + + // Ensure the schema_history table has a primary key, creating it if necessary + // This logic is implemented separately from regular migrations to improve compatibility with databases running in strict or HA modes + dbIdentifierColumn := "table_schema" + if dbType == Postgres { + dbIdentifierColumn = "table_catalog" + } + rows, err := m.session.SQL().Query( + "select 1 from information_schema.table_constraints where constraint_type = 'PRIMARY KEY' and table_name = 'schema_history' and "+dbIdentifierColumn+" = ?", + m.session.Name()) if err != nil { return err } + defer func() { + tmpErr := rows.Close() + if err == nil { + err = tmpErr + } + }() + if !rows.Next() { + _, err := m.session.SQL().Exec("alter table schema_history add primary key(schema_version)") + if err != nil { + return err + } + } else if err := rows.Err(); err != nil { + return err + } + rs, err := m.session.SQL().Query("select schema_version from schema_history") if err != nil { return err @@ -65,7 +96,6 @@ func (m migrate) Exec(ctx context.Context) (err error) { return err } } - dbType := dbTypeFor(m.session) log.WithFields(log.Fields{"clusterName": m.clusterName, "dbType": dbType}).Info("Migrating database schema")