From 5fd7504fc69c8c0624a429f59f1051e749c541e3 Mon Sep 17 00:00:00 2001
From: Wenqi Mou <wenqimou@gmail.com>
Date: Thu, 2 Jan 2025 18:02:52 -0500
Subject: [PATCH] add one more unit test, and some minor fixes

Signed-off-by: Wenqi Mou <wenqimou@gmail.com>
---
 .../log_client/batch_meta_processor.go        |  2 +-
 br/pkg/restore/log_client/client.go           |  4 +-
 br/pkg/stream/table_mapping.go                |  2 +-
 br/pkg/stream/table_mapping_test.go           | 22 +++----
 br/pkg/task/common.go                         |  4 +-
 br/pkg/task/restore.go                        |  6 +-
 br/pkg/task/restore_test.go                   |  2 -
 br/pkg/task/stream.go                         | 10 ++-
 br/pkg/utils/BUILD.bazel                      |  1 +
 br/pkg/utils/filter.go                        | 61 ++++++++++---------
 br/pkg/utils/filter_test.go                   | 53 ++++++++++++++++
 11 files changed, 112 insertions(+), 55 deletions(-)
 create mode 100644 br/pkg/utils/filter_test.go

diff --git a/br/pkg/restore/log_client/batch_meta_processor.go b/br/pkg/restore/log_client/batch_meta_processor.go
index b0e908b2c5d787..10032b8f2b1c18 100644
--- a/br/pkg/restore/log_client/batch_meta_processor.go
+++ b/br/pkg/restore/log_client/batch_meta_processor.go
@@ -84,7 +84,7 @@ func (rp *RestoreMetaKVProcessor) RestoreAndRewriteMetaKVFiles(
 		return errors.Trace(err)
 	}
 
-	// UpdateTable global schema version to trigger a full reload so every TiDB node in the cluster will get synced with
+	// AddTable global schema version to trigger a full reload so every TiDB node in the cluster will get synced with
 	// the latest schema update.
 	if err := rp.client.UpdateSchemaVersionFullReload(ctx); err != nil {
 		return errors.Trace(err)
diff --git a/br/pkg/restore/log_client/client.go b/br/pkg/restore/log_client/client.go
index 3475b8acf5e60e..8d1eda42a9c621 100644
--- a/br/pkg/restore/log_client/client.go
+++ b/br/pkg/restore/log_client/client.go
@@ -857,7 +857,7 @@ func readFilteredFullBackupTables(
 	ctx context.Context,
 	s storage.ExternalStorage,
 	tableFilter filter.Filter,
-	piTRTableFilter *utils.PiTRTableFilter,
+	piTRTableFilter *utils.PiTRTableTracker,
 	cipherInfo *backuppb.CipherInfo,
 ) (map[int64]*metautil.Table, error) {
 	metaData, err := s.ReadFile(ctx, metautil.MetaFile)
@@ -934,7 +934,7 @@ type GetIDMapConfig struct {
 	// optional
 	FullBackupStorage *FullBackupStorageConfig
 	CipherInfo        *backuppb.CipherInfo
-	PiTRTableFilter   *utils.PiTRTableFilter // generated table filter that contain all the table id that needs to restore
+	PiTRTableFilter   *utils.PiTRTableTracker // generated table filter that contain all the table id that needs to restore
 }
 
 const UnsafePITRLogRestoreStartBeforeAnyUpstreamUserDDL = "UNSAFE_PITR_LOG_RESTORE_START_BEFORE_ANY_UPSTREAM_USER_DDL"
diff --git a/br/pkg/stream/table_mapping.go b/br/pkg/stream/table_mapping.go
index bc71865575aa44..705c4dbf74985a 100644
--- a/br/pkg/stream/table_mapping.go
+++ b/br/pkg/stream/table_mapping.go
@@ -276,7 +276,7 @@ func (tm *TableMappingManager) ReplaceTemporaryIDs(
 	return nil
 }
 
-func (tm *TableMappingManager) FilterDBReplaceMap(filter *utils.PiTRTableFilter) {
+func (tm *TableMappingManager) FilterDBReplaceMap(filter *utils.PiTRTableTracker) {
 	// collect all IDs that should be kept
 	keepIDs := make(map[UpstreamID]struct{})
 
diff --git a/br/pkg/stream/table_mapping_test.go b/br/pkg/stream/table_mapping_test.go
index 451ced524a6161..bf75efea2ed847 100644
--- a/br/pkg/stream/table_mapping_test.go
+++ b/br/pkg/stream/table_mapping_test.go
@@ -364,7 +364,7 @@ func TestFilterDBReplaceMap(t *testing.T) {
 	tests := []struct {
 		name     string
 		initial  map[UpstreamID]*DBReplace
-		filter   *utils.PiTRTableFilter
+		filter   *utils.PiTRTableTracker
 		expected map[UpstreamID]*DBReplace
 	}{
 		{
@@ -378,8 +378,8 @@ func TestFilterDBReplaceMap(t *testing.T) {
 					},
 				},
 			},
-			filter: &utils.PiTRTableFilter{
-				DbIdToTable: map[int64]map[int64]struct{}{},
+			filter: &utils.PiTRTableTracker{
+				DBIdToTable: map[int64]map[int64]struct{}{},
 			},
 			expected: map[UpstreamID]*DBReplace{},
 		},
@@ -401,8 +401,8 @@ func TestFilterDBReplaceMap(t *testing.T) {
 					},
 				},
 			},
-			filter: &utils.PiTRTableFilter{
-				DbIdToTable: map[int64]map[int64]struct{}{
+			filter: &utils.PiTRTableTracker{
+				DBIdToTable: map[int64]map[int64]struct{}{
 					1: {10: struct{}{}},
 				},
 			},
@@ -429,8 +429,8 @@ func TestFilterDBReplaceMap(t *testing.T) {
 					},
 				},
 			},
-			filter: &utils.PiTRTableFilter{
-				DbIdToTable: map[int64]map[int64]struct{}{
+			filter: &utils.PiTRTableTracker{
+				DBIdToTable: map[int64]map[int64]struct{}{
 					1: {
 						10: struct{}{},
 						12: struct{}{},
@@ -474,8 +474,8 @@ func TestFilterDBReplaceMap(t *testing.T) {
 					},
 				},
 			},
-			filter: &utils.PiTRTableFilter{
-				DbIdToTable: map[int64]map[int64]struct{}{
+			filter: &utils.PiTRTableTracker{
+				DBIdToTable: map[int64]map[int64]struct{}{
 					1: {10: struct{}{}},
 				},
 			},
@@ -523,8 +523,8 @@ func TestFilterDBReplaceMap(t *testing.T) {
 					},
 				},
 			},
-			filter: &utils.PiTRTableFilter{
-				DbIdToTable: map[int64]map[int64]struct{}{
+			filter: &utils.PiTRTableTracker{
+				DBIdToTable: map[int64]map[int64]struct{}{
 					1: {10: struct{}{}},
 					2: {
 						20: struct{}{},
diff --git a/br/pkg/task/common.go b/br/pkg/task/common.go
index 426911c08da496..79c9eaef8c5ac5 100644
--- a/br/pkg/task/common.go
+++ b/br/pkg/task/common.go
@@ -256,8 +256,8 @@ type Config struct {
 	TableFilter filter.Filter `json:"-" toml:"-"`
 	// PiTRTableFilter generated from TableFilter during snapshot restore, it has all the db id and table id that needs
 	// to be restored
-	PiTRTableFilter    *utils.PiTRTableFilter `json:"-" toml:"-"`
-	SwitchModeInterval time.Duration          `json:"switch-mode-interval" toml:"switch-mode-interval"`
+	PiTRTableFilter    *utils.PiTRTableTracker `json:"-" toml:"-"`
+	SwitchModeInterval time.Duration           `json:"switch-mode-interval" toml:"switch-mode-interval"`
 	// Schemas is a database name set, to check whether the restore database has been backup
 	Schemas map[string]struct{}
 	// Tables is a table name set, to check whether the restore table has been backup
diff --git a/br/pkg/task/restore.go b/br/pkg/task/restore.go
index cc2a1e962e9562..a8c073fdee8382 100644
--- a/br/pkg/task/restore.go
+++ b/br/pkg/task/restore.go
@@ -1468,7 +1468,7 @@ func adjustTablesToRestoreAndCreateFilter(
 	newlyCreatedDBs := logBackupTableHistory.GetNewlyCreatedDBHistory()
 	for dbId, dbName := range newlyCreatedDBs {
 		if utils.MatchSchema(cfg.TableFilter, dbName) {
-			piTRTableFilter.UpdateDB(dbId)
+			piTRTableFilter.AddDB(dbId)
 		}
 	}
 
@@ -1505,7 +1505,7 @@ func adjustTablesToRestoreAndCreateFilter(
 			// put this db/table id into pitr filter as it matches with user's filter
 			// have to update filter here since table might be empty or not in snapshot so nothing will be returned .
 			// but we still need to capture this table id to restore during log restore.
-			piTRTableFilter.UpdateTable(end.DbID, tableID)
+			piTRTableFilter.AddTable(end.DbID, tableID)
 
 			// check if snapshot contains the original db/table
 			originalDB, exists := snapshotDBMap[start.DbID]
@@ -1564,7 +1564,7 @@ func adjustTablesToRestoreAndCreateFilter(
 
 func UpdatePiTRFilter(cfg *RestoreConfig, tableMap map[int64]*metautil.Table) {
 	for _, table := range tableMap {
-		cfg.PiTRTableFilter.UpdateTable(table.DB.ID, table.Info.ID)
+		cfg.PiTRTableFilter.AddTable(table.DB.ID, table.Info.ID)
 	}
 }
 
diff --git a/br/pkg/task/restore_test.go b/br/pkg/task/restore_test.go
index 5cb3c8a67abe72..86ceb3755ee097 100644
--- a/br/pkg/task/restore_test.go
+++ b/br/pkg/task/restore_test.go
@@ -308,8 +308,6 @@ func TestFilterDDLJobs(t *testing.T) {
 	ddlJobs := task.FilterDDLJobs(allDDLJobs, tables)
 	for _, job := range ddlJobs {
 		t.Logf("get ddl job: %s", job.Query)
-		t.Logf("table name: %s", job.TableName)
-		t.Logf("dbid: %s", job.SchemaName)
 	}
 	require.Equal(t, 7, len(ddlJobs))
 }
diff --git a/br/pkg/task/stream.go b/br/pkg/task/stream.go
index 62d818c989b87c..bae493fcd293ef 100644
--- a/br/pkg/task/stream.go
+++ b/br/pkg/task/stream.go
@@ -56,6 +56,7 @@ import (
 	"github.com/pingcap/tidb/br/pkg/utils"
 	"github.com/pingcap/tidb/pkg/kv"
 	"github.com/pingcap/tidb/pkg/meta/model"
+	"github.com/pingcap/tidb/pkg/parser/mysql"
 	"github.com/pingcap/tidb/pkg/util/cdcutil"
 	"github.com/spf13/pflag"
 	"github.com/tikv/client-go/v2/oracle"
@@ -1301,9 +1302,12 @@ func RunStreamRestore(
 	if err != nil {
 		return errors.Trace(err)
 	}
-	// TODO: pitr filtered restore doesn't support restore system table yet, hacky way to override the sys filter here
+	// TODO: pitr filtered restore doesn't support restore system table yet
 	if cfg.ExplicitFilter {
-		// add some check
+		if cfg.TableFilter.MatchSchema(mysql.SystemDB) {
+			return errors.Annotatef(berrors.ErrInvalidArgument,
+				"PiTR doesn't support custom filter to include system db, consider to exclude system db")
+		}
 	}
 	metaInfoProcessor := logclient.NewMetaKVInfoProcessor(logClient)
 	// only doesn't need to build if id map has been saved during log restore
@@ -1318,7 +1322,7 @@ func RunStreamRestore(
 			return errors.Trace(err)
 		}
 		dbReplace := metaInfoProcessor.GetTableMappingManager().DBReplaceMap
-		stream.LogDBReplaceMap("scanning log meta kv before snapshot restore", dbReplace)
+		stream.LogDBReplaceMap("scanned log meta kv before snapshot restore", dbReplace)
 	}
 
 	// restore full snapshot.
diff --git a/br/pkg/utils/BUILD.bazel b/br/pkg/utils/BUILD.bazel
index 06827b9435a2ad..d708078e76abfd 100644
--- a/br/pkg/utils/BUILD.bazel
+++ b/br/pkg/utils/BUILD.bazel
@@ -79,6 +79,7 @@ go_test(
         "backoff_test.go",
         "db_test.go",
         "error_handling_test.go",
+        "filter_test.go",
         "json_test.go",
         "key_test.go",
         "main_test.go",
diff --git a/br/pkg/utils/filter.go b/br/pkg/utils/filter.go
index 2d09cead17c9bf..bfe278c07556c6 100644
--- a/br/pkg/utils/filter.go
+++ b/br/pkg/utils/filter.go
@@ -22,44 +22,45 @@ import (
 	filter "github.com/pingcap/tidb/pkg/util/table-filter"
 )
 
-type PiTRTableFilter struct {
-	DbIdToTable map[int64]map[int64]struct{}
+// PiTRTableTracker tracks all the DB and table ids that need to restore in a PiTR
+type PiTRTableTracker struct {
+	DBIdToTable map[int64]map[int64]struct{}
 }
 
-func NewPiTRTableFilter() *PiTRTableFilter {
-	return &PiTRTableFilter{
-		DbIdToTable: make(map[int64]map[int64]struct{}),
+func NewPiTRTableFilter() *PiTRTableTracker {
+	return &PiTRTableTracker{
+		DBIdToTable: make(map[int64]map[int64]struct{}),
 	}
 }
 
-// UpdateTable adds a table ID to the filter for the given database ID
-func (f *PiTRTableFilter) UpdateTable(dbID, tableID int64) {
-	if f.DbIdToTable == nil {
-		f.DbIdToTable = make(map[int64]map[int64]struct{})
+// AddTable adds a table ID to the filter for the given database ID
+func (f *PiTRTableTracker) AddTable(dbID, tableID int64) {
+	if f.DBIdToTable == nil {
+		f.DBIdToTable = make(map[int64]map[int64]struct{})
 	}
 
-	if _, ok := f.DbIdToTable[dbID]; !ok {
-		f.DbIdToTable[dbID] = make(map[int64]struct{})
+	if _, ok := f.DBIdToTable[dbID]; !ok {
+		f.DBIdToTable[dbID] = make(map[int64]struct{})
 	}
 
-	f.DbIdToTable[dbID][tableID] = struct{}{}
+	f.DBIdToTable[dbID][tableID] = struct{}{}
 }
 
-// UpdateDB adds the database id
-func (f *PiTRTableFilter) UpdateDB(dbID int64) {
-	if f.DbIdToTable == nil {
-		f.DbIdToTable = make(map[int64]map[int64]struct{})
+// AddDB adds the database id
+func (f *PiTRTableTracker) AddDB(dbID int64) {
+	if f.DBIdToTable == nil {
+		f.DBIdToTable = make(map[int64]map[int64]struct{})
 	}
 
-	if _, ok := f.DbIdToTable[dbID]; !ok {
-		f.DbIdToTable[dbID] = make(map[int64]struct{})
+	if _, ok := f.DBIdToTable[dbID]; !ok {
+		f.DBIdToTable[dbID] = make(map[int64]struct{})
 	}
 }
 
 // Remove removes a table ID from the filter for the given database ID.
 // Returns true if the table was found and removed, false otherwise.
-func (f *PiTRTableFilter) Remove(dbID, tableID int64) bool {
-	if tables, ok := f.DbIdToTable[dbID]; ok {
+func (f *PiTRTableTracker) Remove(dbID, tableID int64) bool {
+	if tables, ok := f.DBIdToTable[dbID]; ok {
 		if _, exists := tables[tableID]; exists {
 			delete(tables, tableID)
 			return true
@@ -69,8 +70,8 @@ func (f *PiTRTableFilter) Remove(dbID, tableID int64) bool {
 }
 
 // ContainsTable checks if the given database ID and table ID combination exists in the filter
-func (f *PiTRTableFilter) ContainsTable(dbID, tableID int64) bool {
-	if tables, ok := f.DbIdToTable[dbID]; ok {
+func (f *PiTRTableTracker) ContainsTable(dbID, tableID int64) bool {
+	if tables, ok := f.DBIdToTable[dbID]; ok {
 		_, exists := tables[tableID]
 		return exists
 	}
@@ -78,20 +79,20 @@ func (f *PiTRTableFilter) ContainsTable(dbID, tableID int64) bool {
 }
 
 // ContainsDB checks if the given database ID exists in the filter
-func (f *PiTRTableFilter) ContainsDB(dbID int64) bool {
-	_, ok := f.DbIdToTable[dbID]
+func (f *PiTRTableTracker) ContainsDB(dbID int64) bool {
+	_, ok := f.DBIdToTable[dbID]
 	return ok
 }
 
-// String returns a string representation of the PiTRTableFilter for debugging
-func (f *PiTRTableFilter) String() string {
-	if f == nil || f.DbIdToTable == nil {
-		return "PiTRTableFilter{nil}"
+// String returns a string representation of the PiTRTableTracker for debugging
+func (f *PiTRTableTracker) String() string {
+	if f == nil || f.DBIdToTable == nil {
+		return "PiTRTableTracker{nil}"
 	}
 
 	var result strings.Builder
-	result.WriteString("PiTRTableFilter{\n")
-	for dbID, tables := range f.DbIdToTable {
+	result.WriteString("PiTRTableTracker{\n")
+	for dbID, tables := range f.DBIdToTable {
 		result.WriteString(fmt.Sprintf("  DB[%d]: {", dbID))
 		tableIDs := make([]int64, 0, len(tables))
 		for tableID := range tables {
diff --git a/br/pkg/utils/filter_test.go b/br/pkg/utils/filter_test.go
new file mode 100644
index 00000000000000..adac7e7d50da46
--- /dev/null
+++ b/br/pkg/utils/filter_test.go
@@ -0,0 +1,53 @@
+package utils
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestPiTRTableTracker(t *testing.T) {
+	t.Run("test new tracker", func(t *testing.T) {
+		tracker := NewPiTRTableFilter()
+		require.NotNil(t, tracker)
+		require.NotNil(t, tracker.DBIdToTable)
+		require.Empty(t, tracker.DBIdToTable)
+	})
+
+	t.Run("test update and contains table", func(t *testing.T) {
+		tracker := NewPiTRTableFilter()
+
+		tracker.AddDB(1)
+		tracker.AddTable(1, 100)
+		tracker.AddDB(2)
+		require.True(t, tracker.ContainsDB(1))
+		require.True(t, tracker.ContainsDB(2))
+		require.True(t, tracker.ContainsTable(1, 100))
+		require.False(t, tracker.ContainsTable(1, 101))
+		require.False(t, tracker.ContainsTable(2, 100))
+
+		tracker.AddTable(1, 101)
+		tracker.AddTable(2, 200)
+		require.True(t, tracker.ContainsTable(1, 100))
+		require.True(t, tracker.ContainsTable(1, 101))
+		require.True(t, tracker.ContainsTable(2, 200))
+
+		tracker.AddTable(3, 300)
+		require.True(t, tracker.ContainsDB(3))
+		require.True(t, tracker.ContainsTable(3, 300))
+	})
+
+	t.Run("test remove table", func(t *testing.T) {
+		tracker := NewPiTRTableFilter()
+
+		tracker.AddTable(1, 100)
+		tracker.AddTable(1, 101)
+
+		require.True(t, tracker.Remove(1, 100))
+		require.False(t, tracker.ContainsTable(1, 100))
+		require.True(t, tracker.ContainsTable(1, 101))
+
+		require.False(t, tracker.Remove(1, 102))
+		require.False(t, tracker.Remove(2, 100))
+	})
+}