Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 80 additions & 52 deletions clippy_lints/src/methods/unnecessary_fold.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use clippy_utils::diagnostics::span_lint_and_sugg;
use clippy_utils::res::{MaybeDef, MaybeResPath, MaybeTypeckRes};
use clippy_utils::res::{MaybeDef, MaybeQPath, MaybeResPath, MaybeTypeckRes};
use clippy_utils::source::snippet_with_applicability;
use clippy_utils::{peel_blocks, strip_pat_refs};
use rustc_ast::ast;
use rustc_data_structures::packed::Pu128;
use rustc_errors::Applicability;
use rustc_hir as hir;
use rustc_hir::PatKind;
use rustc_hir::def::{DefKind, Res};
use rustc_lint::LateContext;
use rustc_middle::ty;
use rustc_span::{Span, sym};
use rustc_span::{Span, Symbol, sym};

use super::UNNECESSARY_FOLD;

Expand Down Expand Up @@ -41,6 +42,18 @@ fn needs_turbofish(cx: &LateContext<'_>, expr: &hir::Expr<'_>) -> bool {
return false;
}

// - the final expression in the body of a function with a simple return type
if let hir::Node::Block(block) = parent
&& let mut parents = cx.tcx.hir_parent_iter(block.hir_id).map(|(_, def_id)| def_id)
&& let Some(hir::Node::Expr(_)) = parents.next()
&& let Some(hir::Node::Item(enclosing_item)) = parents.next()
&& let hir::ItemKind::Fn { sig, .. } = enclosing_item.kind
&& let hir::FnRetTy::Return(fn_return_ty) = sig.decl.output
&& matches!(fn_return_ty.kind, hir::TyKind::Path(..))
{
return false;
}
Comment on lines +45 to +55
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and the previous checks are a subset of what using expr_use_ctxt would get you. You want something like:

if let Some(use_cx) = expr_use_ctxt(cx, expr)
  && use_cx.is_same_ctxt
  && let Some(ty) = use_cx.use_node(cx).defined_ty(cx)
{
  // check ty
}


// if it's neither of those, stay on the safe side and suggest turbofish,
// even if it could work!
true
Expand All @@ -60,7 +73,7 @@ fn check_fold_with_op(
fold_span: Span,
op: hir::BinOpKind,
replacement: Replacement,
) {
) -> bool {
if let hir::ExprKind::Closure(&hir::Closure { body, .. }) = acc.kind
// Extract the body of the closure passed to fold
&& let closure_body = cx.tcx.hir_body(body)
Expand Down Expand Up @@ -93,7 +106,7 @@ fn check_fold_with_op(
r = snippet_with_applicability(cx, right_expr.span, "EXPR", &mut applicability),
)
} else {
format!("{method}{turbofish}()", method = replacement.method_name,)
format!("{method}{turbofish}()", method = replacement.method_name)
};

span_lint_and_sugg(
Expand All @@ -105,6 +118,41 @@ fn check_fold_with_op(
sugg,
applicability,
);
return true;
}
false
}

fn check_fold_with_method(
cx: &LateContext<'_>,
expr: &hir::Expr<'_>,
acc: &hir::Expr<'_>,
fold_span: Span,
method: Symbol,
replacement: Replacement,
) {
// Extract the name of the function passed to `fold`
if let Res::Def(DefKind::AssocFn, fn_did) = acc.res_if_named(cx, method)
// Check if the function belongs to the operator
&& cx.tcx.is_diagnostic_item(method, fn_did)
{
let applicability = Applicability::MachineApplicable;

let turbofish = if replacement.has_generic_return {
format!("::<{}>", cx.typeck_results().expr_ty(expr))
} else {
String::new()
};

span_lint_and_sugg(
cx,
UNNECESSARY_FOLD,
fold_span.with_hi(expr.span.hi()),
"this `.fold` can be written more succinctly using another method",
"try",
format!("{method}{turbofish}()", method = replacement.method_name),
applicability,
);
}
}

Expand All @@ -124,60 +172,40 @@ pub(super) fn check(
if let hir::ExprKind::Lit(lit) = init.kind {
match lit.node {
ast::LitKind::Bool(false) => {
check_fold_with_op(
cx,
expr,
acc,
fold_span,
hir::BinOpKind::Or,
Replacement {
method_name: "any",
has_args: true,
has_generic_return: false,
},
);
let replacement = Replacement {
method_name: "any",
has_args: true,
has_generic_return: false,
};
check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Or, replacement);
},
ast::LitKind::Bool(true) => {
check_fold_with_op(
cx,
expr,
acc,
fold_span,
hir::BinOpKind::And,
Replacement {
method_name: "all",
has_args: true,
has_generic_return: false,
},
);
let replacement = Replacement {
method_name: "all",
has_args: true,
has_generic_return: false,
};
check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::And, replacement);
},
ast::LitKind::Int(Pu128(0), _) => {
check_fold_with_op(
cx,
expr,
acc,
fold_span,
hir::BinOpKind::Add,
Replacement {
method_name: "sum",
has_args: false,
has_generic_return: needs_turbofish(cx, expr),
},
);
let replacement = Replacement {
method_name: "sum",
has_args: false,
has_generic_return: needs_turbofish(cx, expr),
};
if !check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Add, replacement) {
check_fold_with_method(cx, expr, acc, fold_span, sym::add, replacement);
}
},
ast::LitKind::Int(Pu128(1), _) => {
check_fold_with_op(
cx,
expr,
acc,
fold_span,
hir::BinOpKind::Mul,
Replacement {
method_name: "product",
has_args: false,
has_generic_return: needs_turbofish(cx, expr),
},
);
let replacement = Replacement {
method_name: "product",
has_args: false,
has_generic_return: needs_turbofish(cx, expr),
};
if !check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Mul, replacement) {
check_fold_with_method(cx, expr, acc, fold_span, sym::mul, replacement);
}
},
_ => (),
}
Expand Down
85 changes: 85 additions & 0 deletions tests/ui/unnecessary_fold.fixed
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,35 @@ fn is_any(acc: bool, x: usize) -> bool {

/// Calls which should trigger the `UNNECESSARY_FOLD` lint
fn unnecessary_fold() {
use std::ops::{Add, Mul};

// Can be replaced by .any
let _ = (0..3).any(|x| x > 2);
//~^ unnecessary_fold

// Can be replaced by .any (checking suggestion)
let _ = (0..3).fold(false, is_any);
//~^ redundant_closure

// Can be replaced by .all
let _ = (0..3).all(|x| x > 2);
//~^ unnecessary_fold

// Can be replaced by .sum
let _: i32 = (0..3).sum();
//~^ unnecessary_fold
let _: i32 = (0..3).sum();
//~^ unnecessary_fold
let _: i32 = (0..3).sum();
//~^ unnecessary_fold

// Can be replaced by .product
let _: i32 = (0..3).product();
//~^ unnecessary_fold
let _: i32 = (0..3).product();
//~^ unnecessary_fold
let _: i32 = (0..3).product();
//~^ unnecessary_fold
}

/// Should trigger the `UNNECESSARY_FOLD` lint, with an error span including exactly `.fold(...)`
Expand All @@ -37,6 +51,43 @@ fn unnecessary_fold_should_ignore() {
let _ = (0..3).fold(0, |acc, x| acc * x);
let _ = (0..3).fold(0, |acc, x| 1 + acc + x);

struct Adder;
impl Adder {
fn add(lhs: i32, rhs: i32) -> i32 {
unimplemented!()
}
fn mul(lhs: i32, rhs: i32) -> i32 {
unimplemented!()
}
}
// `add`/`mul` are inherent methods
let _: i32 = (0..3).fold(0, Adder::add);
let _: i32 = (0..3).fold(1, Adder::mul);

trait FakeAdd<Rhs = Self> {
type Output;
fn add(self, other: Rhs) -> Self::Output;
}
impl FakeAdd for i32 {
type Output = Self;
fn add(self, other: i32) -> Self::Output {
self + other
}
}
trait FakeMul<Rhs = Self> {
type Output;
fn mul(self, other: Rhs) -> Self::Output;
}
impl FakeMul for i32 {
type Output = Self;
fn mul(self, other: i32) -> Self::Output {
self * other
}
}
// `add`/`mul` come from an unrelated trait
let _: i32 = (0..3).fold(0, FakeAdd::add);
let _: i32 = (0..3).fold(1, FakeMul::mul);

// We only match against an accumulator on the left
// hand side. We could lint for .sum and .product when
// it's on the right, but don't for now (and this wouldn't
Expand All @@ -63,6 +114,7 @@ fn unnecessary_fold_over_multiple_lines() {
fn issue10000() {
use std::collections::HashMap;
use std::hash::BuildHasher;
use std::ops::{Add, Mul};

fn anything<T>(_: T) {}
fn num(_: i32) {}
Expand All @@ -74,23 +126,56 @@ fn issue10000() {
// more cases:
let _ = map.values().sum::<i32>();
//~^ unnecessary_fold
let _ = map.values().sum::<i32>();
//~^ unnecessary_fold
let _ = map.values().product::<i32>();
//~^ unnecessary_fold
let _ = map.values().product::<i32>();
//~^ unnecessary_fold
let _: i32 = map.values().sum();
//~^ unnecessary_fold
let _: i32 = map.values().sum();
//~^ unnecessary_fold
let _: i32 = map.values().product();
//~^ unnecessary_fold
let _: i32 = map.values().product();
//~^ unnecessary_fold
anything(map.values().sum::<i32>());
//~^ unnecessary_fold
anything(map.values().sum::<i32>());
//~^ unnecessary_fold
anything(map.values().product::<i32>());
//~^ unnecessary_fold
anything(map.values().product::<i32>());
//~^ unnecessary_fold
num(map.values().sum());
//~^ unnecessary_fold
num(map.values().sum());
//~^ unnecessary_fold
num(map.values().product());
//~^ unnecessary_fold
num(map.values().product());
//~^ unnecessary_fold
}

smoketest_map(HashMap::new());

fn add_turbofish_not_necessary() -> i32 {
(0..3).sum()
//~^ unnecessary_fold
}
fn mul_turbofish_not_necessary() -> i32 {
(0..3).product()
//~^ unnecessary_fold
}
fn add_turbofish_necessary() -> impl Add {
(0..3).sum::<i32>()
//~^ unnecessary_fold
}
fn mul_turbofish_necessary() -> impl Mul {
(0..3).product::<i32>()
//~^ unnecessary_fold
}
}

fn main() {}
Loading