diff --git a/package-lock.json b/package-lock.json index 8fb9792..dfcf7dc 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,5 +1,5 @@ { - "name": "slitherin", + "name": "custom_detectors", "lockfileVersion": 2, "requires": true, "packages": { diff --git a/slither_pess/__init__.py b/slither_pess/__init__.py new file mode 100644 index 0000000..f54d75d --- /dev/null +++ b/slither_pess/__init__.py @@ -0,0 +1,60 @@ +from slither_pess.detectors.arbitrary_call import ArbitraryCall +from slither_pess.detectors.double_entry_token_possibility import ( + DoubleEntryTokenPossiblity, +) +from slither_pess.detectors.dubious_typecast import DubiousTypecast +from slither_pess.detectors.falsy_only_eoa_modifier import OnlyEOACheck +from slither_pess.detectors.magic_number import MagicNumber +from slither_pess.detectors.strange_setter import StrangeSetter +from slither_pess.detectors.unprotected_setter import UnprotectedSetter +from slither_pess.detectors.nft_approve_warning import NftApproveWarning +from slither_pess.detectors.inconsistent_nonreentrant import InconsistentNonreentrant +from slither_pess.detectors.call_forward_to_protected import CallForwardToProtected +from slither_pess.detectors.multiple_storage_read import MultipleStorageRead +from slither_pess.detectors.timelock_controller import TimelockController +from slither_pess.detectors.tx_gasprice_warning import TxGaspriceWarning +from slither_pess.detectors.unprotected_initialize import UnprotectedInitialize +from slither_pess.detectors.readonly_reentrancy.read_only_reentrancy import ( + ReadOnlyReentrancy, +) +from slither_pess.detectors.event_setter import EventSetter +from slither_pess.detectors.before_token_transfer import BeforeTokenTransfer +from slither_pess.detectors.uni_v2 import UniswapV2 +from slither_pess.detectors.token_fallback import TokenFallback +from slither_pess.detectors.for_continue_increment import ForContinueIncrement +from slither_pess.detectors.ecrecover import Ecrecover +from slither_pess.detectors.public_vs_external import PublicVsExternal +from slither_pess.detectors.readonly_reentrancy.balancer_readonly_reentrancy import ( + BalancerReadonlyReentrancy, +) + + +def make_plugin(): + plugin_detectors = [ + DoubleEntryTokenPossiblity, + UnprotectedSetter, + NftApproveWarning, + InconsistentNonreentrant, + StrangeSetter, + OnlyEOACheck, + MagicNumber, + DubiousTypecast, + CallForwardToProtected, + MultipleStorageRead, + TimelockController, + TxGaspriceWarning, + UnprotectedInitialize, + ReadOnlyReentrancy, + EventSetter, + BeforeTokenTransfer, + UniswapV2, + TokenFallback, + ForContinueIncrement, + ArbitraryCall, + Ecrecover, + PublicVsExternal, + BalancerReadonlyReentrancy, + ] + plugin_printers = [] + + return plugin_detectors, plugin_printers diff --git a/slither_pess/detectors/readonly_reentrancy/__init__.py b/slither_pess/detectors/readonly_reentrancy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/slither_pess/detectors/readonly_reentrancy/balancer_readonly_reentrancy.py b/slither_pess/detectors/readonly_reentrancy/balancer_readonly_reentrancy.py new file mode 100644 index 0000000..97f548e --- /dev/null +++ b/slither_pess/detectors/readonly_reentrancy/balancer_readonly_reentrancy.py @@ -0,0 +1,96 @@ +from typing import List +from slither.utils.output import Output +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.core.declarations import Function, Contract +from slither.core.cfg.node import Node + + +class BalancerReadonlyReentrancy(AbstractDetector): + """ + Sees if a contract has a beforeTokenTransfer function. + """ + + ARGUMENT = "pess-balancer-readonly-reentrancy" # slither will launch the detector with slither.py --detect mydetector + HELP = "beforeTokenTransfer function does not follow OZ documentation" + IMPACT = DetectorClassification.LOW + CONFIDENCE = DetectorClassification.HIGH + + WIKI = ( + "https://docs.openzeppelin.com/contracts/4.x/extending-contracts#rules_of_hooks" + ) + WIKI_TITLE = "Before Token Transfer" + WIKI_DESCRIPTION = "Follow OZ documentation using their contracts" + WIKI_EXPLOIT_SCENARIO = "-" + WIKI_RECOMMENDATION = ( + "Make sure that beforeTokenTransfer function is used in the correct way." + ) + + VULNERABLE_FUNCTION_CALLS = ["getRate", "getPoolTokens"] + visited = [] + contains_reentrancy_check = {} + + def is_balancer_integration(self, c: Contract) -> bool: + """ + Iterates over all external function calls, and checks the interface/contract name + for a specific keywords to decide if the contract integrates with balancer + """ + for ( + fcontract, + _, + ) in c.all_high_level_calls: + contract_name = fcontract.name.lower() + if any(map(lambda x: x in contract_name, ["balancer", "ivault", "pool"])): + return True + + def _has_reentrancy_check(self, node: Node) -> bool: + if node in self.visited: + return self.contains_reentrancy_check[node] + + self.visited.append(node) + self.contains_reentrancy_check[node] = False + + for c, n in node.high_level_calls: + if isinstance(n, Function): + if ( + n.name == "ensureNotInVaultContext" + and c.name == "VaultReentrancyLib" + ) or ( + n.name == "manageUserBalance" + ): # TODO check if errors out + self.contains_reentrancy_check[node] = True + return True + + has_check = False + for internal_call in node.internal_calls: + if isinstance(internal_call, Function): + has_check |= self._has_reentrancy_check(internal_call) + # self.contains_reentrancy_check[internal_call] |= has_check + + self.contains_reentrancy_check[node] = has_check + return has_check + + def _check_function(self, function: Function) -> list: + has_dangerous_call = False + dangerous_call = None + for n in function.nodes: + for c, fc in n.high_level_calls: + if isinstance(fc, Function): + if fc.name in self.VULNERABLE_FUNCTION_CALLS: + dangerous_call = fc + has_dangerous_call = True + break + + if has_dangerous_call and not any( + [self._has_reentrancy_check(node) for node in function.nodes] + ): + print("READONLY_REENTRANCY!!!") + + def _detect(self) -> List[Output]: + """Main function""" + res = [] + for contract in self.compilation_unit.contracts_derived: + if not self.is_balancer_integration(contract): + continue + for f in contract.functions_and_modifiers_declared: + self._check_function(f) + return res diff --git a/slither_pess/detectors/readonly_reentrancy/read_only_reentrancy.py b/slither_pess/detectors/readonly_reentrancy/read_only_reentrancy.py new file mode 100644 index 0000000..7b17c61 --- /dev/null +++ b/slither_pess/detectors/readonly_reentrancy/read_only_reentrancy.py @@ -0,0 +1,434 @@ +"""" + Re-entrancy detection + Based on heuristics, it may lead to FP and FN + Iterate over all the nodes of the graph until reaching a fixpoint +""" +from collections import namedtuple, defaultdict +from typing import Dict, List, Set +from slither.core.variables.variable import Variable +from slither.core.declarations import Function +from slither.core.cfg.node import NodeType, Node, Contract +from slither.detectors.abstract_detector import DetectorClassification +from .reentrancy.reentrancy import ( + Reentrancy, + to_hashable, + AbstractState, + union_dict, + _filter_if, + is_subset, +) +from slither.slithir.operations import EventCall + +FindingKey = namedtuple("FindingKey", ["function", "calls"]) +FindingValue = namedtuple("FindingValue", ["variable", "written_at", "node", "nodes"]) + + +def are_same_contract(a: Contract, b: Contract) -> bool: + """ + Checks if A==B or A inherits from B or otherwise + """ + return a == b or (b in a.inheritance) or (b in a.derived_contracts) + + +class ReadOnlyReentrancyState(AbstractState): + def __init__(self): + super().__init__() + self._reads_external: Dict[Variable, Set[Node]] = defaultdict(set) + self._reads_external_contract_list: Dict[Variable, Set[Contract]] = defaultdict( + set + ) + self._written_external: Dict[Variable, Set[Node]] = defaultdict(set) + self._written: Dict[Variable, Set[Node]] = defaultdict(set) + + @property + def reads_external(self) -> Dict[Variable, Set[Node]]: + return self._reads_external + + @property + def reads_external_contract_list(self) -> Dict[Variable, Set[Contract]]: + return self._reads_external_contract_list + + @property + def written_external(self) -> Dict[Variable, Set[Node]]: + return self._written_external + + @property + def written(self) -> Dict[Variable, Set[Node]]: + return self._written + + def add(self, fathers): + super().add(fathers) + self._reads_external = union_dict(self._reads_external, fathers.reads_external) + self._reads_external_contract_list = union_dict( + self._reads_external_contract_list, fathers.reads_external_contract_list + ) + + def does_not_bring_new_info(self, new_info): + return ( + super().does_not_bring_new_info(new_info) + and is_subset(new_info.reads_external, self._reads_external) + and is_subset( + new_info.reads_external_contract_list, + self._reads_external_contract_list, + ) + ) + + def merge_fathers(self, node, skip_father, detector): + for father in node.fathers: + if detector.KEY in father.context: + self._send_eth = union_dict( + self._send_eth, + { + key: values + for key, values in father.context[detector.KEY].send_eth.items() + if key != skip_father + }, + ) + self._calls = union_dict( + self._calls, + { + key: values + for key, values in father.context[detector.KEY].calls.items() + if key != skip_father + }, + ) + self._reads = union_dict( + self._reads, father.context[detector.KEY].reads + ) + self._reads_external = union_dict( + self._reads_external, father.context[detector.KEY].reads + ) + self._written_external = union_dict( + self._written_external, father.context[detector.KEY].reads + ) + + def analyze_node(self, node: Node, detector): + state_vars_read: Dict[Variable, Set[Node]] = defaultdict( + set, {v: {node} for v in node.state_variables_read} + ) + + # All the state variables written + state_vars_written: Dict[Variable, Set[Node]] = defaultdict( + set, {v: {node} for v in node.state_variables_written} + ) + + external_state_vars_read: Dict[Variable, Set[Node]] = defaultdict(set) + external_state_vars_written: Dict[Variable, Set[Node]] = defaultdict(set) + external_state_vars_read_contract_list: Dict[ + Variable, Set[Contract] + ] = defaultdict(set) + + slithir_operations = [] + # Add the state variables written in internal calls + for internal_call in node.internal_calls: + # Filter to Function, as internal_call can be a solidity call + if isinstance(internal_call, Function): + for internal_node in internal_call.all_nodes(): + for read in internal_node.state_variables_read: + state_vars_read[read].add(internal_node) + for write in internal_node.state_variables_written: + state_vars_written[write].add(internal_node) + slithir_operations += internal_call.all_slithir_operations() + + for contract, v in node.high_level_calls: + if isinstance(v, Function): + for internal_node in v.all_nodes(): + for read in internal_node.state_variables_read: + external_state_vars_read[read].add(internal_node) + external_state_vars_read_contract_list[read].add(contract) + + if internal_node.context.get(detector.KEY): + for r in internal_node.context[detector.KEY].reads_external: + external_state_vars_read[r].add(internal_node) + external_state_vars_read_contract_list[r].add(contract) + for write in internal_node.state_variables_written: + external_state_vars_written[write].add(internal_node) + + contains_call = False + + self._written = state_vars_written + self._written_external = external_state_vars_written + for ir in node.irs + slithir_operations: + if detector.can_callback(ir): + self._calls[node] |= {ir.node} + contains_call = True + + if detector.can_send_eth(ir): + self._send_eth[node] |= {ir.node} + + if isinstance(ir, EventCall): + self._events[ir] |= {ir.node, node} + + self._reads = union_dict(self._reads, state_vars_read) + self._reads_external = union_dict( + self._reads_external, external_state_vars_read + ) + self._reads_external_contract_list = union_dict( + self._reads_external_contract_list, external_state_vars_read_contract_list + ) + + return contains_call + + +class ReadOnlyReentrancy(Reentrancy): + ARGUMENT = "pess-readonly-reentrancy" + HELP = "Read-only reentrancy vulnerabilities" + IMPACT = DetectorClassification.HIGH + CONFIDENCE = DetectorClassification.LOW + + WIKI = "https://github.com/pessimistic-io/slitherin/blob/master/docs/readonly_reentrancy.md" + WIKI_TITLE = "Read-only reentrancy vulnerabilities" + WIKI_DESCRIPTION = "Check docs" + STANDARD_JSON = False + KEY = "readonly_reentrancy" + + contracts_read_variable: Dict[Variable, Set[Contract]] = defaultdict(set) + contracts_written_variable_after_reentrancy: Dict[ + Variable, Set[Contract] + ] = defaultdict(set) + + def _explore(self, node, visited, skip_father=None): + if node in visited: + return + + visited = visited + [node] + + fathers_context = ReadOnlyReentrancyState() + fathers_context.merge_fathers(node, skip_father, self) + + # Exclude path that dont bring further information + if node in self.visited_all_paths: + if self.visited_all_paths[node].does_not_bring_new_info(fathers_context): + return + else: + self.visited_all_paths[node] = ReadOnlyReentrancyState() + + self.visited_all_paths[node].add(fathers_context) + + node.context[self.KEY] = fathers_context + + contains_call = fathers_context.analyze_node(node, self) + node.context[self.KEY] = fathers_context + + sons = node.sons + if contains_call and node.type in [NodeType.IF, NodeType.IFLOOP]: + if _filter_if(node): + son = sons[0] + self._explore(son, visited, node) + sons = sons[1:] + else: + son = sons[1] + self._explore(son, visited, node) + sons = [sons[0]] + + for son in sons: + self._explore(son, visited) + + def find_writes_after_reentrancy(self): + written_after_reentrancy: Dict[Variable, Set[Node]] = defaultdict(set) + written_after_reentrancy_external: Dict[Variable, Set[Node]] = defaultdict(set) + for contract in self.contracts: + for f in contract.functions_and_modifiers_declared: + for node in f.nodes: + # dead code + if self.KEY not in node.context: + continue + if node.context[self.KEY].calls: + if not any(n != node for n in node.context[self.KEY].calls): + continue + # TODO: check if written items exist + for v, nodes in node.context[self.KEY].written.items(): + written_after_reentrancy[v].add(node) + self.contracts_written_variable_after_reentrancy[v].add( + contract + ) + for v, nodes in node.context[self.KEY].written_external.items(): + written_after_reentrancy_external[v].add(node) + self.contracts_written_variable_after_reentrancy[v].add( + contract + ) + + return written_after_reentrancy, written_after_reentrancy_external + + # IMPORTANT: + # FOR the external reads, that var should be external written in the same contract + def get_readonly_reentrancies(self): + ( + written_after_reentrancy, + written_after_reentrancy_external, + ) = self.find_writes_after_reentrancy() + result = defaultdict(set) + + warnings = defaultdict(set) + + for contract in self.contracts: + for f in contract.functions_and_modifiers_declared: + for node in f.nodes: + + if self.KEY not in node.context: + continue + vulnerable_variables = set() + warning_variables = set() + for r, nodes in node.context[self.KEY].reads.items(): + + if r in written_after_reentrancy: + finding_value = FindingValue( + r, + tuple( + sorted( + list(written_after_reentrancy[r]), + key=lambda x: x.node_id, + ) + ), + node, + tuple(sorted(nodes, key=lambda x: x.node_id)), + ) + if are_same_contract(r.contract, f.contract): + if f.view and f.visibility in ["public", "external"]: + warning_variables.add(finding_value) + else: + vulnerable_variables.add(finding_value) + + for r, nodes in node.context[self.KEY].reads_external.items(): + if are_same_contract(r.contract, f.contract): + # TODO(yhtiyar): In case f.view we can notify the user that the given + # method could be vulnerable if other contract will use it + continue + if r in written_after_reentrancy_external: + isVulnerable = any( + c in self.contracts_written_variable_after_reentrancy[r] + for c in node.context[ + self.KEY + ].reads_external_contract_list[r] + ) + if isVulnerable: + vulnerable_variables.add( + FindingValue( + r, + tuple( + sorted( + list( + written_after_reentrancy_external[r] + ), + key=lambda x: x.node_id, + ) + ), + node, + tuple(sorted(nodes, key=lambda x: x.node_id)), + ) + ) + + if r in written_after_reentrancy: + vulnerable_variables.add( + FindingValue( + r, + tuple( + sorted( + list(written_after_reentrancy[r]), + key=lambda x: x.node_id, + ) + ), + node, + tuple(sorted(nodes, key=lambda x: x.node_id)), + ) + ) + + if vulnerable_variables: + finding_key = FindingKey( + function=f, calls=to_hashable(node.context[self.KEY].calls) + ) + result[finding_key] |= vulnerable_variables + if warning_variables: + finding_key = FindingKey( + function=f, calls=to_hashable(node.context[self.KEY].calls) + ) + warnings[finding_key] |= warning_variables + return result, warnings + + def _gen_results(self, raw_results, info_text): + results = [] + + result_sorted = sorted( + list(raw_results.items()), key=lambda x: x[0].function.name + ) + + varsRead: List[FindingValue] + for (func, calls), varsRead in result_sorted: + + varsRead = sorted(varsRead, key=lambda x: (x.variable.name, x.node.node_id)) + + info = [f"{info_text} ", func, ":\n"] + + info += [ + "\tState variables read that were written after the external call(s):\n" + ] + for finding_value in varsRead: + info += [ + "\t- ", + finding_value.variable, + " was read at ", + finding_value.node, + "\n", + ] + # info += ["\t- ", finding_value.node, "\n"] + + # for other_node in finding_value.nodes: + # if other_node != finding_value.node: + # info += ["\t\t- ", other_node, "\n"] + + # TODO: currently we are not printing the whole call-stack of variable + # it wasn't working properly, so I am removing it for now to avoid confusion + + info += ["\t\t This variable was written at (after external call):\n"] + for other_node in finding_value.written_at: + # info += ["\t- ", call_info, "\n"] + if other_node != finding_value.node: + info += ["\t\t\t- ", other_node, "\n"] + + # Create our JSON result + res = self.generate_result(info) + + res.add(func) + + # Add all variables written via nodes which write them. + for finding_value in varsRead: + res.add( + finding_value.node, + { + "underlying_type": "variables_written", + "variable_name": finding_value.variable.name, + }, + ) + for other_node in finding_value.nodes: + if other_node != finding_value.node: + res.add( + other_node, + { + "underlying_type": "variables_written", + "variable_name": finding_value.variable.name, + }, + ) + + # Append our result + results.append(res) + + return results + + def _detect(self): # pylint: disable=too-many-branches + results = [] + try: + super()._detect() + reentrancies, warnings = self.get_readonly_reentrancies() + results += self._gen_results(reentrancies, "Readonly-reentrancy in ") + results += self._gen_results( + warnings, + "Potential vulnerable to readonly-reentrancy function (if read in other function)", + ) + except Exception as e: + info = [ + "Error during detection of readonly-reentrancy:\n", + "Please inform this to Yhtyyar\n", + f"error details:", + e, + ] + return results diff --git a/slither_pess/detectors/readonly_reentrancy/reentrancy/__init__.py b/slither_pess/detectors/readonly_reentrancy/reentrancy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/slither_pess/detectors/readonly_reentrancy/reentrancy/reentrancy.py b/slither_pess/detectors/readonly_reentrancy/reentrancy/reentrancy.py new file mode 100644 index 0000000..5a40ffc --- /dev/null +++ b/slither_pess/detectors/readonly_reentrancy/reentrancy/reentrancy.py @@ -0,0 +1,314 @@ +"""" + This is a copy paste from the original slither reentrancy v0.9.1 + + + + Re-entrancy detection + Based on heuristics, it may lead to FP and FN + Iterate over all the nodes of the graph until reaching a fixpoint +""" +from collections import defaultdict +from typing import Set, Dict, Union + +from slither.core.cfg.node import NodeType, Node +from slither.core.declarations import Function +from slither.core.expressions import UnaryOperation, UnaryOperationType +from slither.core.variables.variable import Variable +from slither.detectors.abstract_detector import AbstractDetector +from slither.slithir.operations import Call, EventCall + + +def union_dict(d1, d2): + d3 = { + k: d1.get(k, set()) | d2.get(k, set()) + for k in set(list(d1.keys()) + list(d2.keys())) + } + return defaultdict(set, d3) + + +def dict_are_equal(d1, d2): + if set(list(d1.keys())) != set(list(d2.keys())): + return False + return all(set(d1[k]) == set(d2[k]) for k in d1.keys()) + + +def is_subset( + new_info: Dict[Union[Variable, Node], Set[Node]], + old_info: Dict[Union[Variable, Node], Set[Node]], +): + for k in new_info.keys(): + if k not in old_info: + return False + if not new_info[k].issubset(old_info[k]): + return False + return True + + +def to_hashable(d: Dict[Node, Set[Node]]): + list_tuple = list( + tuple((k, tuple(sorted(values, key=lambda x: x.node_id)))) + for k, values in d.items() + ) + return tuple(sorted(list_tuple, key=lambda x: x[0].node_id)) + + +class AbstractState: + def __init__(self): + # send_eth returns the list of calls sending value + # calls returns the list of calls that can callback + # read returns the variable read + # read_prior_calls returns the variable read prior a call + self._send_eth: Dict[Node, Set[Node]] = defaultdict(set) + self._calls: Dict[Node, Set[Node]] = defaultdict(set) + self._reads: Dict[Variable, Set[Node]] = defaultdict(set) + self._reads_prior_calls: Dict[Node, Set[Variable]] = defaultdict(set) + self._events: Dict[EventCall, Set[Node]] = defaultdict(set) + self._written: Dict[Variable, Set[Node]] = defaultdict(set) + + @property + def send_eth(self) -> Dict[Node, Set[Node]]: + """ + Return the list of calls sending value + :return: + """ + return self._send_eth + + @property + def calls(self) -> Dict[Node, Set[Node]]: + """ + Return the list of calls that can callback + :return: + """ + return self._calls + + @property + def reads(self) -> Dict[Variable, Set[Node]]: + """ + Return of variables that are read + :return: + """ + return self._reads + + @property + def written(self) -> Dict[Variable, Set[Node]]: + """ + Return of variables that are written + :return: + """ + return self._written + + @property + def reads_prior_calls(self) -> Dict[Node, Set[Variable]]: + """ + Return the dictionary node -> variables read before any call + :return: + """ + return self._reads_prior_calls + + @property + def events(self) -> Dict[EventCall, Set[Node]]: + """ + Return the list of events + :return: + """ + return self._events + + def merge_fathers(self, node, skip_father, detector): + for father in node.fathers: + if detector.KEY in father.context: + self._send_eth = union_dict( + self._send_eth, + { + key: values + for key, values in father.context[detector.KEY].send_eth.items() + if key != skip_father + }, + ) + self._calls = union_dict( + self._calls, + { + key: values + for key, values in father.context[detector.KEY].calls.items() + if key != skip_father + }, + ) + self._reads = union_dict( + self._reads, father.context[detector.KEY].reads + ) + self._reads_prior_calls = union_dict( + self.reads_prior_calls, + father.context[detector.KEY].reads_prior_calls, + ) + + def analyze_node(self, node, detector): + state_vars_read: Dict[Variable, Set[Node]] = defaultdict( + set, {v: {node} for v in node.state_variables_read} + ) + + # All the state variables written + state_vars_written: Dict[Variable, Set[Node]] = defaultdict( + set, {v: {node} for v in node.state_variables_written} + ) + slithir_operations = [] + # Add the state variables written in internal calls + for internal_call in node.internal_calls: + # Filter to Function, as internal_call can be a solidity call + if isinstance(internal_call, Function): + for internal_node in internal_call.all_nodes(): + for read in internal_node.state_variables_read: + state_vars_read[read].add(internal_node) + for write in internal_node.state_variables_written: + state_vars_written[write].add(internal_node) + slithir_operations += internal_call.all_slithir_operations() + + contains_call = False + + self._written = state_vars_written + for ir in node.irs + slithir_operations: + if detector.can_callback(ir): + self._calls[node] |= {ir.node} + self._reads_prior_calls[node] = set( + self._reads_prior_calls.get(node, set()) + | set(node.context[detector.KEY].reads.keys()) + | set(state_vars_read.keys()) + ) + contains_call = True + + if detector.can_send_eth(ir): + self._send_eth[node] |= {ir.node} + + if isinstance(ir, EventCall): + self._events[ir] |= {ir.node, node} + + self._reads = union_dict(self._reads, state_vars_read) + + return contains_call + + def add(self, fathers): + self._send_eth = union_dict(self._send_eth, fathers.send_eth) + self._calls = union_dict(self._calls, fathers.calls) + self._reads = union_dict(self._reads, fathers.reads) + self._reads_prior_calls = union_dict( + self._reads_prior_calls, fathers.reads_prior_calls + ) + + def does_not_bring_new_info(self, new_info): + if is_subset(new_info.calls, self.calls): + if is_subset(new_info.send_eth, self.send_eth): + if is_subset(new_info.reads, self.reads): + if dict_are_equal( + new_info.reads_prior_calls, self.reads_prior_calls + ): + return True + return False + + +def _filter_if(node): + """ + Check if the node is a condtional node where + there is an external call checked + Heuristic: + - The call is a IF node + - It contains a, external call + - The condition is the negation (!) + This will work only on naive implementation + """ + return ( + isinstance(node.expression, UnaryOperation) + and node.expression.type == UnaryOperationType.BANG + ) + + +class Reentrancy(AbstractDetector): + KEY = "REENTRANCY" + WIKI_EXPLOIT_SCENARIO = "Check original reentrancy" + WIKI_RECOMMENDATION = "Check original reentrancy" + + # can_callback and can_send_eth are static method + # allowing inherited classes to define different behaviors + # For example reentrancy_no_gas consider Send and Transfer as reentrant functions + @staticmethod + def can_callback(ir): + """ + Detect if the node contains a call that can + be used to re-entrance + Consider as valid target: + - low level call + - high level call + """ + return isinstance(ir, Call) and ir.can_reenter() + + @staticmethod + def can_send_eth(ir): + """ + Detect if the node can send eth + """ + return isinstance(ir, Call) and ir.can_send_eth() + + def _explore(self, node, visited, skip_father=None): + """ + Explore the CFG and look for re-entrancy + Heuristic: There is a re-entrancy if a state variable is written + after an external call + node.context will contains the external calls executed + It contains the calls executed in father nodes + if node.context is not empty, and variables are written, a re-entrancy is possible + """ + if node in visited: + return + + visited = visited + [node] + + fathers_context = AbstractState() + fathers_context.merge_fathers(node, skip_father, self) + + # Exclude path that dont bring further information + if node in self.visited_all_paths: + if self.visited_all_paths[node].does_not_bring_new_info(fathers_context): + return + else: + self.visited_all_paths[node] = AbstractState() + + self.visited_all_paths[node].add(fathers_context) + + node.context[self.KEY] = fathers_context + + contains_call = fathers_context.analyze_node(node, self) + node.context[self.KEY] = fathers_context + + sons = node.sons + if contains_call and node.type in [NodeType.IF, NodeType.IFLOOP]: + if _filter_if(node): + son = sons[0] + self._explore(son, visited, node) + sons = sons[1:] + else: + son = sons[1] + self._explore(son, visited, node) + sons = [sons[0]] + + for son in sons: + self._explore(son, visited) + + def detect_reentrancy(self, contract): + for function in contract.functions_and_modifiers_declared: + if not function.is_constructor: + if function.is_implemented: + if self.KEY in function.context: + continue + self._explore(function.entry_point, []) + function.context[self.KEY] = True + + def _detect(self): + """""" + # if a node was already visited by another path + # we will only explore it if the traversal brings + # new variables written + # This speedup the exploration through a light fixpoint + # Its particular useful on 'complex' functions with several loops and conditions + self.visited_all_paths = {} # pylint: disable=attribute-defined-outside-init + + for c in self.contracts: + self.detect_reentrancy(c) + + return []