diff --git a/command/ddl_runner.go b/command/ddl_runner.go index bfdc14ca..1924443e 100644 --- a/command/ddl_runner.go +++ b/command/ddl_runner.go @@ -77,8 +77,9 @@ const ( func NewDDLCommandRunner(ce *Executor) *DDLCommandRunner { return &DDLCommandRunner{ - ce: ce, - idSeq: -1, + ce: ce, + idSeq: -1, + cancellationSequences: make(map[cancelledCommandsKey]int64), } } @@ -106,6 +107,16 @@ type DDLCommandRunner struct { ce *Executor commands sync.Map idSeq int64 + + handleCancelLock sync.Mutex + // Cancel can be received on a node during or *before* the command has been received + // so we keep a map of the sequences for each schema and originating node before which commands should be cancelled + cancellationSequences map[cancelledCommandsKey]int64 +} + +type cancelledCommandsKey struct { + schemaName string + originatingNodeID int64 } func (d *DDLCommandRunner) generateCommandKey(origNodeID uint64, commandID uint64) string { @@ -144,10 +155,12 @@ func (d *DDLCommandRunner) HandleCancelMessage(clusterMsg remoting.ClusterMessag if !ok { panic("not a cancel msg") } - return d.cancelCommandsForSchema(cancelMsg.SchemaName) + return d.cancelCommandsForSchema(cancelMsg.SchemaName, cancelMsg.CommandId, cancelMsg.OriginatingNodeId) } -func (d *DDLCommandRunner) cancelCommandsForSchema(schemaName string) error { +// cancel commands for the schema up to and including the commandID +func (d *DDLCommandRunner) cancelCommandsForSchema(schemaName string, commandID int64, originatingNodeID int64) error { + found := false d.commands.Range(func(key, value interface{}) bool { command, ok := value.(DDLCommand) if !ok { @@ -157,9 +170,21 @@ func (d *DDLCommandRunner) cancelCommandsForSchema(schemaName string) error { d.commands.Delete(key) command.Cancel() command.Cleanup() + found = true } return true }) + if !found { + d.handleCancelLock.Lock() + defer d.handleCancelLock.Unlock() + // If we didn't find the command to delete it's possible that the cancel has arrived before the original command + // so we add it to a map so we know to ignore the command if it arrives later + key := cancelledCommandsKey{ + schemaName: schemaName, + originatingNodeID: originatingNodeID, + } + d.cancellationSequences[key] = commandID + } return nil } @@ -219,7 +244,9 @@ func (d *DDLCommandRunner) HandleDdlMessage(ddlMsg remoting.ClusterMessage) erro } com = NewDDLCommand(d.ce, DDLCommandType(ddlInfo.CommandType), ddlInfo.GetSchemaName(), ddlInfo.GetSql(), ddlInfo.GetTableSequences(), ddlInfo.GetExtraData()) - d.commands.Store(skey, com) + if !d.storeIfNotCancelled(skey, ddlInfo.CommandId, ddlInfo.GetOriginatingNodeId(), com) { + return nil + } } } else if !ok { // This can happen if ddlMsg comes in after commands are cancelled @@ -232,13 +259,32 @@ func (d *DDLCommandRunner) HandleDdlMessage(ddlMsg remoting.ClusterMessage) erro com.Cleanup() } log.Debugf("Running phase %d for DDL message %d %s returned err %v", phase, com.CommandType(), ddlInfo.Sql, err) - if phase == int32(com.NumPhases()-1) { - // Final phase so delete the command + if phase == int32(com.NumPhases()-1) || err != nil { + // Final phase or err so delete the command d.commands.Delete(skey) } return err } +func (d *DDLCommandRunner) storeIfNotCancelled(skey string, commandID int64, originatingNodeID int64, com DDLCommand) bool { + d.handleCancelLock.Lock() + defer d.handleCancelLock.Unlock() + // We first check if we have already received a cancel for commands up to this id - cancels can come in before the + // original command was fielded + key := 
cancelledCommandsKey{ + schemaName: com.SchemaName(), + originatingNodeID: originatingNodeID, + } + cid, ok := d.cancellationSequences[key] + if ok && cid >= commandID { + log.Debugf("ddl command arrived after cancellation, it will be ignored command id %d cid %d", commandID, cid) + return false + } + + d.commands.Store(skey, com) + return true +} + func (d *DDLCommandRunner) RunCommand(ctx context.Context, command DDLCommand) error { log.Debugf("Attempting to run DDL command %d", command.CommandType()) lockName := getLockName(command.SchemaName()) @@ -297,6 +343,7 @@ func (d *DDLCommandRunner) RunWithLock(commandKey string, command DDLCommand, dd } if err != nil { log.Debugf("Error return from broadcasting phase %d for DDL command %d %s %v cancel will be broadcast", phase, command.CommandType(), ddlInfo.Sql, err) + d.commands.Delete(commandKey) // Broadcast a cancel to clean up command state across the cluster if err2 := d.broadcastCancel(command.SchemaName()); err2 != nil { // Ignore @@ -309,7 +356,11 @@ func (d *DDLCommandRunner) RunWithLock(commandKey string, command DDLCommand, dd } func (d *DDLCommandRunner) broadcastCancel(schemaName string) error { - return d.ce.ddlResetClient.Broadcast(&clustermsgs.DDLCancelMessage{SchemaName: schemaName}) + return d.ce.ddlResetClient.Broadcast(&clustermsgs.DDLCancelMessage{ + SchemaName: schemaName, + CommandId: atomic.LoadInt64(&d.idSeq), + OriginatingNodeId: int64(d.ce.config.NodeID), + }) } func (d *DDLCommandRunner) broadcastDDL(phase int32, ddlInfo *clustermsgs.DDLStatementInfo) error { @@ -340,8 +391,10 @@ func (d *DDLCommandRunner) getLock(lockName string) error { } func (d *DDLCommandRunner) empty() bool { + log.Debug("DDLCommand Runner state:") count := 0 d.commands.Range(func(key, value interface{}) bool { + log.Debugf("DDLCommand runner has command: %s", value.(DDLCommand).SQL()) //nolint:forcetypeassert count++ return true }) diff --git a/kafkatest/kafka_integration_test.go b/kafkatest/kafka_integration_test.go index 6c58449f..5694e32c 100644 --- a/kafkatest/kafka_integration_test.go +++ b/kafkatest/kafka_integration_test.go @@ -1,6 +1,3 @@ -//go:build integration -// +build integration - package kafkatest import ( @@ -194,6 +191,7 @@ func stopPranaCluster(t *testing.T, cluster []*server.Server) { } func waitUntilRowsInPayments(t *testing.T, numRows int, cli *client.Client) { + t.Helper() ok, err := commontest.WaitUntilWithError(func() (bool, error) { ch, err := cli.ExecuteStatement("select * from payments order by payment_id", nil, nil) require.NoError(t, err) diff --git a/msggen/generators.go b/msggen/generators.go index 391f6611..56821e52 100644 --- a/msggen/generators.go +++ b/msggen/generators.go @@ -26,9 +26,7 @@ func (p *PaymentGenerator) GenerateMessage(_ int32, index int64, rnd *rand.Rand) paymentTypes := []string{"btc", "p2p", "other"} currencies := []string{"gbp", "usd", "eur", "aud"} - // timestamp needs to be in the future - otherwise, if it's in the past Kafka might start deleting log entries - // thinking they're past log retention time. 
- timestamp := time.Date(2100, time.Month(4), 12, 9, 0, 0, 0, time.UTC) + timestamp := time.Now() m := make(map[string]interface{}) paymentID := fmt.Sprintf("payment%06d", index) diff --git a/protos/descriptors/squareup/cash/pranadb/clustermsgs/v1/clustermsgs.bin b/protos/descriptors/squareup/cash/pranadb/clustermsgs/v1/clustermsgs.bin index 25018dd6..8d6ad4c8 100644 --- a/protos/descriptors/squareup/cash/pranadb/clustermsgs/v1/clustermsgs.bin +++ b/protos/descriptors/squareup/cash/pranadb/clustermsgs/v1/clustermsgs.bin @@ -1,5 +1,5 @@ - + 6squareup/cash/pranadb/clustermsgs/v1/clustermsgs.proto$squareup.cash.pranadb.clustermsgs.v1" DDLStatementInfo. @@ -13,10 +13,13 @@ schemaName sql ( Rsql' table_sequences (RtableSequences -extra_data ( R extraData"3 -DDLCancelMessage - schema_name ( R -schemaName" +extra_data ( R extraData" +DDLCancelMessage. +originating_node_id (RoriginatingNodeId + schema_name ( R +schemaName + +command_id (R commandId" ReloadProtobuf"U ClusterProposeRequest shard_id (RshardId! diff --git a/protos/squareup/cash/pranadb/clustermsgs/v1/clustermsgs.proto b/protos/squareup/cash/pranadb/clustermsgs/v1/clustermsgs.proto index 0f09a216..aa1501b5 100644 --- a/protos/squareup/cash/pranadb/clustermsgs/v1/clustermsgs.proto +++ b/protos/squareup/cash/pranadb/clustermsgs/v1/clustermsgs.proto @@ -16,7 +16,9 @@ message DDLStatementInfo { } message DDLCancelMessage { - string schema_name = 1; + int64 originating_node_id = 1; + string schema_name = 2; + int64 command_id = 3; } message ReloadProtobuf { diff --git a/protos/squareup/cash/pranadb/v1/clustermsgs/clustermsgs.pb.go b/protos/squareup/cash/pranadb/v1/clustermsgs/clustermsgs.pb.go index 9698338b..fd0c96df 100644 --- a/protos/squareup/cash/pranadb/v1/clustermsgs/clustermsgs.pb.go +++ b/protos/squareup/cash/pranadb/v1/clustermsgs/clustermsgs.pb.go @@ -128,7 +128,9 @@ type DDLCancelMessage struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - SchemaName string `protobuf:"bytes,1,opt,name=schema_name,json=schemaName,proto3" json:"schema_name,omitempty"` + OriginatingNodeId int64 `protobuf:"varint,1,opt,name=originating_node_id,json=originatingNodeId,proto3" json:"originating_node_id,omitempty"` + SchemaName string `protobuf:"bytes,2,opt,name=schema_name,json=schemaName,proto3" json:"schema_name,omitempty"` + CommandId int64 `protobuf:"varint,3,opt,name=command_id,json=commandId,proto3" json:"command_id,omitempty"` } func (x *DDLCancelMessage) Reset() { @@ -163,6 +165,13 @@ func (*DDLCancelMessage) Descriptor() ([]byte, []int) { return file_squareup_cash_pranadb_clustermsgs_v1_clustermsgs_proto_rawDescGZIP(), []int{1} } +func (x *DDLCancelMessage) GetOriginatingNodeId() int64 { + if x != nil { + return x.OriginatingNodeId + } + return 0 +} + func (x *DDLCancelMessage) GetSchemaName() string { if x != nil { return x.SchemaName @@ -170,6 +179,13 @@ func (x *DDLCancelMessage) GetSchemaName() string { return "" } +func (x *DDLCancelMessage) GetCommandId() int64 { + if x != nil { + return x.CommandId + } + return 0 +} + type ReloadProtobuf struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -751,10 +767,15 @@ var file_squareup_cash_pranadb_clustermsgs_v1_clustermsgs_proto_rawDesc = []byte 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x04, 0x52, 0x0e, 0x74, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x65, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x65, 0x78, 0x74, 0x72, 0x61, 0x5f, 0x64, 0x61, 0x74, 0x61, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x65, 0x78, 0x74, - 0x72, 0x61, 0x44, 
0x61, 0x74, 0x61, 0x22, 0x33, 0x0a, 0x10, 0x44, 0x44, 0x4c, 0x43, 0x61, 0x6e, - 0x63, 0x65, 0x6c, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x63, - 0x68, 0x65, 0x6d, 0x61, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0a, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0x10, 0x0a, 0x0e, 0x52, + 0x72, 0x61, 0x44, 0x61, 0x74, 0x61, 0x22, 0x82, 0x01, 0x0a, 0x10, 0x44, 0x44, 0x4c, 0x43, 0x61, + 0x6e, 0x63, 0x65, 0x6c, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x2e, 0x0a, 0x13, 0x6f, + 0x72, 0x69, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6e, 0x67, 0x5f, 0x6e, 0x6f, 0x64, 0x65, 0x5f, + 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x11, 0x6f, 0x72, 0x69, 0x67, 0x69, 0x6e, + 0x61, 0x74, 0x69, 0x6e, 0x67, 0x4e, 0x6f, 0x64, 0x65, 0x49, 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x73, + 0x63, 0x68, 0x65, 0x6d, 0x61, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0a, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1d, 0x0a, 0x0a, + 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x09, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x49, 0x64, 0x22, 0x10, 0x0a, 0x0e, 0x52, 0x65, 0x6c, 0x6f, 0x61, 0x64, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x22, 0x55, 0x0a, 0x15, 0x43, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x50, 0x72, 0x6f, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x73, 0x68, 0x61, 0x72, 0x64, 0x5f, diff --git a/sqltest/sql_test.go b/sqltest/sql_test.go index 1ab75099..f4cf5b00 100644 --- a/sqltest/sql_test.go +++ b/sqltest/sql_test.go @@ -24,5 +24,5 @@ func TestSQLClusteredThreeNodes(t *testing.T) { t.Skip("-short: skipped") } log.Info("Running TestSQLClusteredThreeNodes") - testSQL(t, false, 3, 3, false, false, tlsKeysInfo) + testSQL(t, false, 3, 3, false, true, tlsKeysInfo) }
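
---

Reviewer note on the approach in `command/ddl_runner.go`: the patch handles the race where a `DDLCancelMessage` reaches a node before the DDL command it refers to, by recording, per `(schema, originating node)`, the command id up to which commands are cancelled, and dropping any command that later arrives at or below that id. Below is a minimal standalone sketch of that pattern, not the pranadb implementation itself; all names (`commandTracker`, `recordCancel`, `registerIfNotCancelled`, `cancelKey`) are invented for illustration.

```go
package main

// Sketch of the "cancel may arrive before the command" pattern:
// a cancel stores a per-(schema, origin-node) watermark of cancelled command ids,
// and a command arriving later with an id at or below the watermark is ignored.

import (
	"fmt"
	"sync"
)

type cancelKey struct {
	schemaName string
	originNode int64
}

type commandTracker struct {
	mu        sync.Mutex
	cancelled map[cancelKey]int64 // highest cancelled command id per key
	active    map[string]struct{} // command keys currently registered
}

func newCommandTracker() *commandTracker {
	return &commandTracker{
		cancelled: make(map[cancelKey]int64),
		active:    make(map[string]struct{}),
	}
}

// recordCancel is called when a cancel is handled but no matching command is
// found locally: remember the watermark so the late-arriving command is ignored.
func (t *commandTracker) recordCancel(schema string, origin, commandID int64) {
	t.mu.Lock()
	defer t.mu.Unlock()
	k := cancelKey{schemaName: schema, originNode: origin}
	if commandID > t.cancelled[k] {
		t.cancelled[k] = commandID
	}
}

// registerIfNotCancelled registers a command unless a cancel with an id >= commandID
// has already been seen for the same schema and originating node.
func (t *commandTracker) registerIfNotCancelled(key, schema string, origin, commandID int64) bool {
	t.mu.Lock()
	defer t.mu.Unlock()
	k := cancelKey{schemaName: schema, originNode: origin}
	if cid, ok := t.cancelled[k]; ok && cid >= commandID {
		return false // command was cancelled before it arrived
	}
	t.active[key] = struct{}{}
	return true
}

func main() {
	t := newCommandTracker()

	// Cancel for command 7 arrives before the command itself.
	t.recordCancel("sales", 1, 7)

	// The late command with id 7 is dropped; id 8 is accepted.
	fmt.Println(t.registerIfNotCancelled("1-7", "sales", 1, 7)) // false
	fmt.Println(t.registerIfNotCancelled("1-8", "sales", 1, 8)) // true
}
```

The sketch takes a single mutex around both the cancel bookkeeping and the store, mirroring `handleCancelLock` in the patch, so a cancel and a command for the same schema cannot interleave between the check and the registration.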