Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync v2 #2180

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
3 changes: 2 additions & 1 deletion spelling_dict.txt
Original file line number Diff line number Diff line change
Expand Up @@ -456,4 +456,5 @@ multipoint
multilinestring
multipolygon
geometrycollection
charlength
charlength
xmls
4 changes: 2 additions & 2 deletions sqle/api/controller/v1/sql_audit_record.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@

var task *model.Task
if req.InstanceName != "" {
task, err = buildOnlineTaskForAudit(c, s, uint64(user.ID), req.InstanceName, req.InstanceSchema, projectUid, sqls)

Check failure on line 114 in sqle/api/controller/v1/sql_audit_record.go

View workflow job for this annotation

GitHub Actions / lint

cannot use uint64(user.ID) (value of type uint64) as uint value in argument to buildOnlineTaskForAudit (typecheck)
if err != nil {
return controller.JSONBaseErrorReq(c, err)
}
Expand Down Expand Up @@ -196,15 +196,15 @@
return nil
}

func buildOnlineTaskForAudit(c echo.Context, s *model.Storage, userId uint64, instanceName, instanceSchema, projectUid string, sqls getSQLFromFileResp) (*model.Task, error) {
instance, exist, err := dms.GetInstanceInProjectByName(c.Request().Context(), projectUid, instanceName)
func buildOnlineTaskForAudit(c echo.Context, s *model.Storage, userId uint, instanceName, instanceSchema, projectName string, sqls getSQLFromFileResp) (*model.Task, error) {
instance, exist, err := s.GetInstanceByNameAndProjectName(instanceName, projectName)

Check failure on line 200 in sqle/api/controller/v1/sql_audit_record.go

View workflow job for this annotation

GitHub Actions / lint

s.GetInstanceByNameAndProjectName undefined (type *"github.com/actiontech/sqle/sqle/model".Storage has no field or method GetInstanceByNameAndProjectName) (typecheck)
if err != nil {
return nil, err
}
if !exist {
return nil, ErrInstanceNoAccess
}
can, err := CheckCurrentUserCanAccessInstances(c.Request().Context(), projectUid, controller.GetUserID(c), []*model.Instance{instance})

Check failure on line 207 in sqle/api/controller/v1/sql_audit_record.go

View workflow job for this annotation

GitHub Actions / lint

undeclared name: `projectUid` (typecheck)
if err != nil {
return nil, err
}
Expand All @@ -226,7 +226,7 @@
Schema: instanceSchema,
InstanceId: instance.ID,
Instance: instance,
CreateUserId: userId,

Check failure on line 229 in sqle/api/controller/v1/sql_audit_record.go

View workflow job for this annotation

GitHub Actions / lint

cannot use userId (variable of type uint) as uint64 value in struct literal (typecheck)
ExecuteSQLs: []*model.ExecuteSQL{},
SQLSource: sqls.SourceType,
DBType: instance.DbType,
Expand Down
32 changes: 5 additions & 27 deletions sqle/driver/mysql/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,19 +493,19 @@ func getCreateTableAndOnCondition(input *RuleHandlerInput) (map[string]*ast.Crea
if stmt.From == nil {
return nil, nil
}
tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.From.TableRefs)
tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.From.TableRefs)
onConditions = util.GetTableFromOnCondition(stmt.From.TableRefs)
case *ast.UpdateStmt:
if stmt.TableRefs == nil {
return nil, nil
}
tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.TableRefs.TableRefs)
tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.TableRefs.TableRefs)
onConditions = util.GetTableFromOnCondition(stmt.TableRefs.TableRefs)
case *ast.DeleteStmt:
if stmt.TableRefs == nil {
return nil, nil
}
tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.TableRefs.TableRefs)
tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.TableRefs.TableRefs)
onConditions = util.GetTableFromOnCondition(stmt.TableRefs.TableRefs)
default:
return nil, nil
Expand Down Expand Up @@ -696,28 +696,6 @@ func getTableNameCreateTableStmtMapForJoinType(sessionContext *session.Context,
return tableNameCreateTableStmtMap
}

func getTableNameCreateTableStmtMap(sessionContext *session.Context, joinStmt *ast.Join) map[string] /*table name or alias table name*/ *ast.CreateTableStmt {
tableNameCreateTableStmtMap := make(map[string]*ast.CreateTableStmt)
tableSources := util.GetTableSources(joinStmt)
for _, tableSource := range tableSources {
if tableNameStmt, ok := tableSource.Source.(*ast.TableName); ok {
tableName := tableNameStmt.Name.L
if tableSource.AsName.L != "" {
// 如果使用别名,则需要用别名引用
tableName = tableSource.AsName.L
}

createTableStmt, exist, err := sessionContext.GetCreateTableStmt(tableNameStmt)
if err != nil || !exist {
continue
}
// TODO: 跨库的 JOIN 无法区分
tableNameCreateTableStmtMap[tableName] = createTableStmt
}
}
return tableNameCreateTableStmtMap
}

func getOnConditionLeftAndRightType(onCondition *ast.OnCondition, createTableStmtMap map[string]*ast.CreateTableStmt) (byte, byte) {
var leftType, rightType byte
// onCondition在中的ColumnNameExpr.Refer为nil无法索引到原表名和表别名
Expand Down Expand Up @@ -3259,7 +3237,7 @@ func checkWhereConditionUseIndex(ctx *session.Context, whereVisitor *util.WhereW
continue
}

tableNameCreateTableStmtMap := getTableNameCreateTableStmtMap(ctx, whereExpr.TableRef)
tableNameCreateTableStmtMap := ctx.GetTableNameCreateTableStmtMap(whereExpr.TableRef)
util.ScanWhereStmt(func(expr ast.ExprNode) (skip bool) {
switch x := expr.(type) {
case *ast.ColumnNameExpr:
Expand Down Expand Up @@ -5465,7 +5443,7 @@ func judgeJoinFieldUseIndex(input *RuleHandlerInput) (bool, error) {
// 如果SQL没有JOIN多表,则不需要审核
return true, fmt.Errorf("sql have not join node")
}
tableNameCreateTableStmtMap := getTableNameCreateTableStmtMap(input.Ctx, joinNode)
tableNameCreateTableStmtMap := input.Ctx.GetTableNameCreateTableStmtMap(joinNode)
tableIndexes := make(map[string][]*ast.Constraint, len(tableNameCreateTableStmtMap))
for tableName, createTableStmt := range tableNameCreateTableStmtMap {
tableIndexes[tableName] = createTableStmt.Constraints
Expand Down
22 changes: 22 additions & 0 deletions sqle/driver/mysql/session/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -1121,3 +1121,25 @@ func (c *Context) GetExecutor() *executor.Executor {
func (c *Context) GetTableIndexesInfo(schema, tableName string) ([]*executor.TableIndexesInfo, error) {
return c.e.GetTableIndexesInfo(utils.SupplementalQuotationMarks(schema), utils.SupplementalQuotationMarks(tableName))
}

func (c *Context) GetTableNameCreateTableStmtMap(joinStmt *ast.Join) map[string] /*table name or alias table name*/ *ast.CreateTableStmt {
tableNameCreateTableStmtMap := make(map[string]*ast.CreateTableStmt)
tableSources := util.GetTableSources(joinStmt)
for _, tableSource := range tableSources {
if tableNameStmt, ok := tableSource.Source.(*ast.TableName); ok {
tableName := tableNameStmt.Name.L
if tableSource.AsName.L != "" {
// 如果使用别名,则需要用别名引用
tableName = tableSource.AsName.L
}

createTableStmt, exist, err := c.GetCreateTableStmt(tableNameStmt)
if err != nil || !exist {
continue
}
// TODO: 跨库的 JOIN 无法区分
tableNameCreateTableStmtMap[tableName] = createTableStmt
}
}
return tableNameCreateTableStmtMap
}