Skip to content

Commit

Permalink
query statement filters by table name
Browse files Browse the repository at this point in the history
  • Loading branch information
Hongyu Zhou committed Feb 23, 2024
1 parent b9fd7e8 commit 34b1599
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 15 deletions.
14 changes: 7 additions & 7 deletions pkg/reflector/dml_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,21 +131,21 @@ func (source *sqlDmlSource) Next(ctx context.Context) (statement schema.DMLState
// Helper function to generate the SQL query
func generateSQLQuery(ledgerTableName, shardingFamily, shardingTable string, blocksize int) string {
if shardingFamily != "" {
familiesStr := prepareFamilyString(shardingFamily)
return sqlgen.SqlSprintf("SELECT seq, leader_ts, statement, family_name, table_name FROM $1 WHERE seq > ? AND family_name IN $2 ORDER BY seq LIMIT $4",
familiesStr := prepareString(shardingFamily)
tablesStr := prepareString(shardingTable)
return sqlgen.SqlSprintf("SELECT seq, leader_ts, statement, family_name, table_name FROM $1 WHERE seq > ? AND family_name IN $2 AND CONCAT(family_name,'___',table_name) IN $3 ORDER BY seq LIMIT $4",
ledgerTableName,
familiesStr,
shardingTable,
tablesStr,
fmt.Sprintf("%d", blocksize))
} else {
return sqlgen.SqlSprintf("SELECT seq, leader_ts, statement, family_name, table_name FROM $1 WHERE seq > ? ORDER BY seq LIMIT $3",
return sqlgen.SqlSprintf("SELECT seq, leader_ts, statement, family_name, table_name FROM $1 WHERE seq > ? ORDER BY seq LIMIT $2",
ledgerTableName,
shardingTable,
fmt.Sprintf("%d", blocksize))
}
}

// Helper function to prepare the family string for SQL query
func prepareFamilyString(families string) string {
return "(\"" + strings.ReplaceAll(families, ",", "\", \"") + "\")"
func prepareString(str string) string {
return "(\"" + strings.ReplaceAll(str, ",", "\", \"") + "\")"
}
29 changes: 21 additions & 8 deletions pkg/reflector/dml_source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func TestSqlDmlSourceWithSharding(t *testing.T) {
db: db,
ledgerTableName: "ctlstore_dml_ledger",
shardingFamily: "foo",
shardingTable: "bar",
shardingTable: "foo___bar",
queryBlockSize: queryBlockSize,
}

Expand All @@ -145,10 +145,12 @@ func TestSqlDmlSourceWithSharding(t *testing.T) {

var ststr string
var ststr1 string
var ststr2 string
// Add statements for two different families
for i := 0; i < queryBlockSize*2; i++ {
ststr = srcutil.AddStatementWithFamilyAndTable("INSERT INTO foo___bar VALUES('hi mom')", "foo", "bar")
ststr1 = srcutil.AddStatementWithFamilyAndTable("INSERT INTO foo1___bar1 VALUES('hi mom')", "foo1", "bar1")
ststr1 = srcutil.AddStatementWithFamilyAndTable("INSERT INTO foo___bar1 VALUES('hi mom')", "foo", "bar1")
ststr2 = srcutil.AddStatementWithFamilyAndTable("INSERT INTO foo1___bar1 VALUES('hi mom')", "foo1", "bar1")
}

var lastSeq int64
Expand All @@ -157,11 +159,12 @@ func TestSqlDmlSourceWithSharding(t *testing.T) {
require.NoError(t, err)
require.Equal(t, ststr, st.Statement)
require.NotEqual(t, ststr1, st.Statement)
require.NotEqual(t, ststr2, st.Statement)
require.True(t, st.Sequence.Int() > lastSeq)
lastSeq = st.Sequence.Int()
// The initial sequence number is 1,and the addition of the statements takes turn between 'foo' and 'foo1',
// so the sequence number should be even for 'foo' and odd for 'foo1'
require.Equal(t, lastSeq%2, int64(0))
// The initial sequence number is 1,and the addition of the statements takes turn among 'foo___bar',
// 'foo___bar1' and 'foo1___bar1', so the sequence number should be 2, 5, 8, 11... for 'foo___bar'
require.Equal(t, (lastSeq-2)%3, int64(0))
}

_, err = src.Next(ctx)
Expand Down Expand Up @@ -201,7 +204,7 @@ func TestSqlDmlSourceWithSharding(t *testing.T) {
}
}

func TestPrepareFamilyString(t *testing.T) {
func TestPrepareString(t *testing.T) {
tests := []struct {
name string
input string
Expand All @@ -222,13 +225,23 @@ func TestPrepareFamilyString(t *testing.T) {
input: "",
expected: "(\"\")",
},
{
name: "Sharding table",
input: "foo___bar",
expected: "(\"foo___bar\")",
},
{
name: "Multiple sharding tables",
input: "foo___bar,foo1___bar1",
expected: "(\"foo___bar\", \"foo1___bar1\")",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
output := prepareFamilyString(tt.input)
output := prepareString(tt.input)
if output != tt.expected {
t.Errorf("prepareFamilyString(%q) = %q, want %q", tt.input, output, tt.expected)
t.Errorf("prepareString(%q) = %q, want %q", tt.input, output, tt.expected)
}
})
}
Expand Down

0 comments on commit 34b1599

Please sign in to comment.