diff --git a/checkers/testdata/_integration/check_main_only/linttest.golden b/checkers/testdata/_integration/check_main_only/linttest.golden index a9d5cd713..214b733c9 100644 --- a/checkers/testdata/_integration/check_main_only/linttest.golden +++ b/checkers/testdata/_integration/check_main_only/linttest.golden @@ -38,6 +38,7 @@ exit status 1 ./main.go:168:9: sloppyLen: len(xs) < 0 is always false ./main.go:173:5: sloppyReassign: re-assignment to `err` can be replaced with `err := (point{})` ./main.go:180:2: switchTrue: replace 'switch true {}' with 'switch {}' +./main.go:260:2: typeAssertChain: rewrite if-else to type switch statement ./main.go:189:2: typeSwitchVar: case 0 can benefit from type switch with assignment ./main.go:189:2: typeSwitchVar: case 1 can benefit from type switch with assignment ./main.go:200:8: typeUnparen: could simplify (func()) to func() diff --git a/checkers/testdata/_integration/check_main_only/main.go b/checkers/testdata/_integration/check_main_only/main.go index 8740f59af..5f4289f92 100644 --- a/checkers/testdata/_integration/check_main_only/main.go +++ b/checkers/testdata/_integration/check_main_only/main.go @@ -255,5 +255,14 @@ func exitAfterDefer() { log.Fatal(123) } +func typeAssertChain() { + var x interface{} + if v, ok := x.(int8); ok { + _ = v + } else if v, ok := x.(int16); ok { + _ = v + } +} + func main() { } diff --git a/checkers/testdata/typeAssertChain/negative_tests.go b/checkers/testdata/typeAssertChain/negative_tests.go new file mode 100644 index 000000000..47a77b0ec --- /dev/null +++ b/checkers/testdata/typeAssertChain/negative_tests.go @@ -0,0 +1,62 @@ +package checker_test + +func negativeTests() { + var x interface{} + + switch v := x.(type) { + case int8: + _ = v + case int16: + _ = v + } + + switch v := x.(type) { + case int8: + _ = v + case int16: + _ = v + case int32: + _ = v + } + + // Not a type assertion chain. + if true { + } else if true { + } else if false { + } + + // Only a single type assertion. + if v, ok := x.(int8); ok { + _ = v + } + + // Duplicated types. + if v, ok := x.(int8); ok { + _ = v + } else if v, ok := x.(int8); ok { + _ = v + } + + // Non-matching condition. + if v, ok := x.(int8); ok { + _ = v + } else if v, ok := x.(int8); true { + _ = v + _ = ok + } + if v, ok := x.(int8); ok { + _ = v + } else if v, ok := x.(int8); !ok { + _ = v + } + + var y interface{} + // Mixed type-asserted values. + if v1, ok := x.(int8); ok { + _ = v1 + } else if v2, ok := x.(int16); ok { + _ = v2 + } else if v3, ok := y.(int32); ok { + _ = v3 + } +} diff --git a/checkers/testdata/typeAssertChain/positive_tests.go b/checkers/testdata/typeAssertChain/positive_tests.go new file mode 100644 index 000000000..5887b0162 --- /dev/null +++ b/checkers/testdata/typeAssertChain/positive_tests.go @@ -0,0 +1,30 @@ +package checker_test + +func suggestTypeSwitch() { + var x interface{} + + /*! rewrite if-else to type switch statement */ + if v, ok := x.(int8); ok { + _ = v + } else if v, ok := x.(int16); ok { + _ = v + } + + /*! rewrite if-else to type switch statement */ + if v, ok := x.(int8); ok { + _ = v + } else if v, ok := x.(int16); ok { + _ = v + } else if v, ok := x.(int32); ok { + _ = v + } + + /*! rewrite if-else to type switch statement */ + if v1, ok := x.(int8); ok { + _ = v1 + } else if v2, ok := x.(int16); ok { + _ = v2 + } else if v3, ok := x.(int32); ok { + _ = v3 + } +} diff --git a/checkers/typeAssertChain_checker.go b/checkers/typeAssertChain_checker.go new file mode 100644 index 000000000..c0c42e351 --- /dev/null +++ b/checkers/typeAssertChain_checker.go @@ -0,0 +1,132 @@ +package checkers + +import ( + "go/ast" + "go/token" + + "github.com/go-critic/go-critic/checkers/internal/lintutil" + "github.com/go-lintpack/lintpack" + "github.com/go-lintpack/lintpack/astwalk" + "github.com/go-toolsmith/astcast" + "github.com/go-toolsmith/astequal" + "github.com/go-toolsmith/astp" +) + +func init() { + var info lintpack.CheckerInfo + info.Name = "typeAssertChain" + info.Tags = []string{"style", "experimental"} + info.Summary = "Detects repeated type assertions and suggests to replace them with type switch statement" + info.Before = ` +if x, ok := v.(T1); ok { + // Code A, uses x. +} else if x, ok := v.(T2); ok { + // Code B, uses x. +} else if x, ok := v.(T3); ok { + // Code C, uses x. +}` + info.After = ` +switch x := v.(T1) { +case cond1: + // Code A, uses x. +case cond2: + // Code B, uses x. +default: + // Code C, uses x. +}` + + collection.AddChecker(&info, func(ctx *lintpack.CheckerContext) lintpack.FileWalker { + return astwalk.WalkerForStmt(&typeAssertChainChecker{ctx: ctx}) + }) +} + +type typeAssertChainChecker struct { + astwalk.WalkHandler + ctx *lintpack.CheckerContext + + cause *ast.IfStmt + visited map[*ast.IfStmt]bool + typeSet lintutil.AstSet +} + +func (c *typeAssertChainChecker) EnterFunc(fn *ast.FuncDecl) bool { + if fn.Body == nil { + return false + } + c.visited = make(map[*ast.IfStmt]bool) + return true +} + +func (c *typeAssertChainChecker) VisitStmt(stmt ast.Stmt) { + ifstmt, ok := stmt.(*ast.IfStmt) + if !ok || c.visited[ifstmt] || ifstmt.Init == nil { + return + } + assertion := c.getTypeAssert(ifstmt) + if assertion == nil { + return + } + c.cause = ifstmt + c.checkIfStmt(ifstmt, assertion) +} + +func (c *typeAssertChainChecker) getTypeAssert(ifstmt *ast.IfStmt) *ast.TypeAssertExpr { + assign := astcast.ToAssignStmt(ifstmt.Init) + if len(assign.Lhs) != 2 || len(assign.Rhs) != 1 { + return nil + } + if !astp.IsIdent(assign.Lhs[0]) || assign.Tok != token.DEFINE { + return nil + } + if !astequal.Expr(assign.Lhs[1], ifstmt.Cond) { + return nil + } + + assertion, ok := assign.Rhs[0].(*ast.TypeAssertExpr) + if !ok { + return nil + } + return assertion +} + +func (c *typeAssertChainChecker) checkIfStmt(stmt *ast.IfStmt, assertion *ast.TypeAssertExpr) { + if c.countTypeAssertions(stmt, assertion) >= 2 { + c.warn() + } +} + +func (c *typeAssertChainChecker) countTypeAssertions(stmt *ast.IfStmt, assertion *ast.TypeAssertExpr) int { + c.typeSet.Clear() + + count := 1 + x := assertion.X + c.typeSet.Insert(assertion.Type) + for { + e, ok := stmt.Else.(*ast.IfStmt) + if !ok { + return count + } + assertion = c.getTypeAssert(e) + if assertion == nil { + return count + } + if !c.typeSet.Insert(assertion.Type) { + // Asserted type is duplicated. + // Type switch does not permit duplicate cases, + // so give up. + return 0 + } + if !astequal.Expr(x, assertion.X) { + // Mixed type asserting chain. + // Can't be easily translated to a type switch. + return 0 + } + stmt = e + count++ + c.visited[e] = true + } +} + +func (c *typeAssertChainChecker) warn() { + c.ctx.Warn(c.cause, "rewrite if-else to type switch statement") +}