diff --git a/ghidrecomp/__init__.py b/ghidrecomp/__init__.py index 69d3a21..2f9f5a4 100644 --- a/ghidrecomp/__init__.py +++ b/ghidrecomp/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.4.2' +__version__ = '0.4.3' __author__ = 'clearbluejar' # Expose API diff --git a/ghidrecomp/callgraph.py b/ghidrecomp/callgraph.py index a04f56d..009b238 100644 --- a/ghidrecomp/callgraph.py +++ b/ghidrecomp/callgraph.py @@ -3,6 +3,7 @@ import zlib import json import sys +import re from typing import TYPE_CHECKING from functools import lru_cache @@ -124,6 +125,10 @@ def links_count(self) -> int: count += 1 return count + + @staticmethod + def remove_bad_mermaid_chars(text: str): + return re.sub(r'`','',text) def gen_mermaid_flow_graph(self, direction=None, shaded_nodes: list = None, shade_color='#339933', max_display_depth=None, endpoint_only=False, wrap_mermaid=False) -> str: """ @@ -200,7 +205,7 @@ def gen_mermaid_flow_graph(self, direction=None, shaded_nodes: list = None, shad depth = node[1] fname = node[0] - if max_display_depth and depth > max_display_depth: + if max_display_depth is not None and depth > max_display_depth: continue if shaded_nodes and fname in shaded_nodes: @@ -238,6 +243,8 @@ def gen_mermaid_flow_graph(self, direction=None, shaded_nodes: list = None, shad mermaid_chart = mermaid_flow.format(links='\n'.join(links.keys()), direction=direction, style=style) + mermaid_chart = self.remove_bad_mermaid_chars(mermaid_chart) + if wrap_mermaid: mermaid_chart = _wrap_mermaid(mermaid_chart) @@ -266,7 +273,7 @@ def gen_mermaid_mind_map(self, max_display_depth=None, wrap_mermaid=False) -> st depth = row[1] # skip root row - if depth < 2 or max_display_depth and depth > max_display_depth: + if depth < 2 or max_display_depth is not None and depth > max_display_depth: continue if depth < last_depth: @@ -281,6 +288,8 @@ def gen_mermaid_mind_map(self, max_display_depth=None, wrap_mermaid=False) -> st mermaid_chart = mermaid_mind.format(rows='\n'.join(rows), root=self.root) + mermaid_chart = self.remove_bad_mermaid_chars(mermaid_chart) + if wrap_mermaid: mermaid_chart = _wrap_mermaid(mermaid_chart) @@ -298,7 +307,7 @@ def get_called_funcs_memo(f: "ghidra.program.model.listing.Function"): # Recursively calling to build calling graph -def get_calling(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = CallGraph(), depth: int = 0, visited: tuple = None, verbose=False, include_ns=True, start_time=None, max_run_time=None): +def get_calling(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = CallGraph(), depth: int = 0, visited: tuple = None, verbose=False, include_ns=True, start_time=None, max_run_time=None, max_depth=MAX_DEPTH): """ Build a call graph of all calling functions Traverses depth first @@ -314,12 +323,15 @@ def get_calling(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = visited = tuple() start_time = time.time() - if depth > MAX_DEPTH: + if depth == MAX_DEPTH: cgraph.add_edge(f.getName(include_ns), f'MAX_DEPTH_HIT - {depth}', depth) return cgraph if (time.time() - start_time) > float(max_run_time): - raise TimeoutError(f'time expired for {f.getName(include_ns)}') + #raise TimeoutError(f'time expired for {clean_func(f,include_ns)}') + cgraph.add_edge(f.getName(include_ns), f'MAX_TIME_HIT - time: {max_run_time} depth: {depth}', depth) + print(f'\nWarn: cg : {cgraph.root} edges: {cgraph.links_count()} depth: {depth} name: {f.name} did not complete. max_run_time: {max_run_time} Increase timeout with --max-time-cg-gen MAX_TIME_CG_GEN') + return cgraph space = (depth+2)*' ' @@ -349,7 +361,7 @@ def get_calling(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = cgraph.add_edge(c.getName(include_ns), f.getName(include_ns), depth) # Parse further functions - cgraph = get_calling(c, cgraph, depth, visited=visited, start_time=start_time, max_run_time=max_run_time) + cgraph = get_calling(c, cgraph, depth, visited=visited, start_time=start_time, max_run_time=max_run_time, max_depth=max_depth) else: if verbose: print(f'{space} - END for {f.name}') @@ -380,13 +392,15 @@ def get_called(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = C visited = tuple() start_time = time.time() - if depth > max_depth: + if depth == max_depth: cgraph.add_edge(f.getName(include_ns), f'MAX_DEPTH_HIT - {depth}', depth) return cgraph if (time.time() - start_time) > float(max_run_time): - raise TimeoutError(f'time expired for {f.getName(include_ns)}') - + cgraph.add_edge(f.getName(include_ns), f'MAX_TIME_HIT - time: {max_run_time} depth: {depth}', depth) + print(f'\nWarn: cg : {cgraph.root} edges: {cgraph.links_count()} depth: {depth} name: {f.name} did not complete. max_run_time: {max_run_time} Increase timeout with --max-time-cg-gen MAX_TIME_CG_GEN') + return cgraph + space = (depth+2)*' ' # loop check @@ -427,7 +441,7 @@ def get_called(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = C # Parse further functions cgraph = get_called(c, cgraph, depth, visited=visited, - start_time=start_time, max_run_time=max_run_time) + start_time=start_time, max_run_time=max_run_time, max_depth=max_depth) else: if verbose: diff --git a/ghidrecomp/decompile.py b/ghidrecomp/decompile.py index 4f1c1e9..9d6f790 100644 --- a/ghidrecomp/decompile.py +++ b/ghidrecomp/decompile.py @@ -106,17 +106,12 @@ def gen_callgraph(func: 'ghidra.program.model.listing.Function', max_display_dep flow = '' callgraph = None - try: - if direction == 'calling': - callgraph = get_calling(func, max_run_time=max_run_time) - elif direction == 'called': - callgraph = get_called(func, max_run_time=max_run_time) - else: - raise Exception(f'Unsupported callgraph direction {direction}') - - except TimeoutError as error: - flow = flow_ends = mind = f'\nError: {error} func: {func.name}. max_run_time: {max_run_time} Increase timeout with --max-time-cg-gen MAX_TIME_CG_GEN' - print(flow) + if direction == 'calling': + callgraph = get_calling(func, max_run_time=max_run_time) + elif direction == 'called': + callgraph = get_called(func, max_run_time=max_run_time) + else: + raise Exception(f'Unsupported callgraph direction {direction}') if callgraph is not None: flow = callgraph.gen_mermaid_flow_graph( @@ -267,8 +262,12 @@ def decompile(args: Namespace): else: directions = [args.cg_direction] + max_display_depth = None + if args.max_display_depth is not None: + max_display_depth = int(args.max_display_depth) + with concurrent.futures.ThreadPoolExecutor(max_workers=thread_count) as executor: - futures = (executor.submit(gen_callgraph, func, args.max_display_depth, direction, args.max_time_cg_gen) + futures = (executor.submit(gen_callgraph, func, max_display_depth, direction, args.max_time_cg_gen) for direction in directions for func in all_funcs if args.skip_cache or get_filename(func) not in callgraphs_completed and re.search(args.callgraph_filter, func.name) is not None) for future in concurrent.futures.as_completed(futures): diff --git a/tests/test_callgraph.py b/tests/test_callgraph.py index e8c6936..53cc348 100644 --- a/tests/test_callgraph.py +++ b/tests/test_callgraph.py @@ -43,3 +43,23 @@ def test_decomplie_afd_callgraphs_cached(shared_datadir: Path): assert compiler == 'visualstudio:unknown' assert lang_id == 'x86:LE:64:default' assert len(callgraphs) == 0 + +def test_decomplie_afd_callgraphs_called_and_calling(shared_datadir: Path): + + parser = get_parser() + + bin_path = shared_datadir / 'afd.sys.10.0.22621.1415' + + args = parser.parse_args([f"{bin_path.absolute()}", "--callgraph-filter", "AfdRe", + "--filter", "AfdRe", "--callgraphs", "--skip-cache", "--cg-direction", "both"]) + + expected_output_path = Path(args.output_path) / bin_path.name + + all_funcs, decompilations, output_path, compiler, lang_id, callgraphs = decompile(args) + + assert len(all_funcs) == 73 + assert len(decompilations) == 73 + assert output_path == expected_output_path + assert compiler == 'visualstudio:unknown' + assert lang_id == 'x86:LE:64:default' + assert len(callgraphs) == 146 \ No newline at end of file