diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 26370d2..448d567 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -3,7 +3,7 @@ { "name": "ghidrecomp", // image from https://github.com/clearbluejar/ghidra-python - "image": "ghcr.io/clearbluejar/ghidra-python:latest", + "image": "ghcr.io/clearbluejar/ghidra-python:10.4ghidra3.11python-bookworm", // Configure tool-specific properties. "customizations": { // Configure properties specific to VS Code. diff --git a/.github/workflows/pytest-devcontainer.yml b/.github/workflows/pytest-devcontainer.yml index 891793d..b6b5652 100644 --- a/.github/workflows/pytest-devcontainer.yml +++ b/.github/workflows/pytest-devcontainer.yml @@ -13,20 +13,29 @@ on: jobs: test: - runs-on: ubuntu-latest + runs-on: ${{ matrix.runner }} strategy: fail-fast: false matrix: python-version: ["3.11"] + runner: [ ubuntu-latest ] steps: - uses: actions/checkout@v3 + # - name: Set up QEMU for multi-architecture builds + # uses: docker/setup-qemu-action@v2 + + # - name: Setup Docker buildx for multi-architecture builds + # uses: docker/setup-buildx-action@v2 + # with: + # use: true - name: Test with pytest on devcontainer uses: devcontainers/ci@v0.3 with: - imageName: ghcr.io/clearbluejar/ghidra-python cacheFrom: ghcr.io/clearbluejar/ghidra-python + push: never + # platform: linux/amd64,linux/arm64 runCmd: | pip install --upgrade pip if [ -f requirements.txt ]; then pip install -r requirements.txt; fi diff --git a/ghidrecomp/__init__.py b/ghidrecomp/__init__.py index 69d3a21..2f9f5a4 100644 --- a/ghidrecomp/__init__.py +++ b/ghidrecomp/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.4.2' +__version__ = '0.4.3' __author__ = 'clearbluejar' # Expose API diff --git a/ghidrecomp/callgraph.py b/ghidrecomp/callgraph.py index a04f56d..009b238 100644 --- a/ghidrecomp/callgraph.py +++ b/ghidrecomp/callgraph.py @@ -3,6 +3,7 @@ import zlib import json import sys +import re from typing import TYPE_CHECKING from functools import lru_cache @@ -124,6 +125,10 @@ def links_count(self) -> int: count += 1 return count + + @staticmethod + def remove_bad_mermaid_chars(text: str): + return re.sub(r'`','',text) def gen_mermaid_flow_graph(self, direction=None, shaded_nodes: list = None, shade_color='#339933', max_display_depth=None, endpoint_only=False, wrap_mermaid=False) -> str: """ @@ -200,7 +205,7 @@ def gen_mermaid_flow_graph(self, direction=None, shaded_nodes: list = None, shad depth = node[1] fname = node[0] - if max_display_depth and depth > max_display_depth: + if max_display_depth is not None and depth > max_display_depth: continue if shaded_nodes and fname in shaded_nodes: @@ -238,6 +243,8 @@ def gen_mermaid_flow_graph(self, direction=None, shaded_nodes: list = None, shad mermaid_chart = mermaid_flow.format(links='\n'.join(links.keys()), direction=direction, style=style) + mermaid_chart = self.remove_bad_mermaid_chars(mermaid_chart) + if wrap_mermaid: mermaid_chart = _wrap_mermaid(mermaid_chart) @@ -266,7 +273,7 @@ def gen_mermaid_mind_map(self, max_display_depth=None, wrap_mermaid=False) -> st depth = row[1] # skip root row - if depth < 2 or max_display_depth and depth > max_display_depth: + if depth < 2 or max_display_depth is not None and depth > max_display_depth: continue if depth < last_depth: @@ -281,6 +288,8 @@ def gen_mermaid_mind_map(self, max_display_depth=None, wrap_mermaid=False) -> st mermaid_chart = mermaid_mind.format(rows='\n'.join(rows), root=self.root) + mermaid_chart = self.remove_bad_mermaid_chars(mermaid_chart) + if wrap_mermaid: mermaid_chart = _wrap_mermaid(mermaid_chart) @@ -298,7 +307,7 @@ def get_called_funcs_memo(f: "ghidra.program.model.listing.Function"): # Recursively calling to build calling graph -def get_calling(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = CallGraph(), depth: int = 0, visited: tuple = None, verbose=False, include_ns=True, start_time=None, max_run_time=None): +def get_calling(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = CallGraph(), depth: int = 0, visited: tuple = None, verbose=False, include_ns=True, start_time=None, max_run_time=None, max_depth=MAX_DEPTH): """ Build a call graph of all calling functions Traverses depth first @@ -314,12 +323,15 @@ def get_calling(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = visited = tuple() start_time = time.time() - if depth > MAX_DEPTH: + if depth == MAX_DEPTH: cgraph.add_edge(f.getName(include_ns), f'MAX_DEPTH_HIT - {depth}', depth) return cgraph if (time.time() - start_time) > float(max_run_time): - raise TimeoutError(f'time expired for {f.getName(include_ns)}') + #raise TimeoutError(f'time expired for {clean_func(f,include_ns)}') + cgraph.add_edge(f.getName(include_ns), f'MAX_TIME_HIT - time: {max_run_time} depth: {depth}', depth) + print(f'\nWarn: cg : {cgraph.root} edges: {cgraph.links_count()} depth: {depth} name: {f.name} did not complete. max_run_time: {max_run_time} Increase timeout with --max-time-cg-gen MAX_TIME_CG_GEN') + return cgraph space = (depth+2)*' ' @@ -349,7 +361,7 @@ def get_calling(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = cgraph.add_edge(c.getName(include_ns), f.getName(include_ns), depth) # Parse further functions - cgraph = get_calling(c, cgraph, depth, visited=visited, start_time=start_time, max_run_time=max_run_time) + cgraph = get_calling(c, cgraph, depth, visited=visited, start_time=start_time, max_run_time=max_run_time, max_depth=max_depth) else: if verbose: print(f'{space} - END for {f.name}') @@ -380,13 +392,15 @@ def get_called(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = C visited = tuple() start_time = time.time() - if depth > max_depth: + if depth == max_depth: cgraph.add_edge(f.getName(include_ns), f'MAX_DEPTH_HIT - {depth}', depth) return cgraph if (time.time() - start_time) > float(max_run_time): - raise TimeoutError(f'time expired for {f.getName(include_ns)}') - + cgraph.add_edge(f.getName(include_ns), f'MAX_TIME_HIT - time: {max_run_time} depth: {depth}', depth) + print(f'\nWarn: cg : {cgraph.root} edges: {cgraph.links_count()} depth: {depth} name: {f.name} did not complete. max_run_time: {max_run_time} Increase timeout with --max-time-cg-gen MAX_TIME_CG_GEN') + return cgraph + space = (depth+2)*' ' # loop check @@ -427,7 +441,7 @@ def get_called(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = C # Parse further functions cgraph = get_called(c, cgraph, depth, visited=visited, - start_time=start_time, max_run_time=max_run_time) + start_time=start_time, max_run_time=max_run_time, max_depth=max_depth) else: if verbose: diff --git a/ghidrecomp/decompile.py b/ghidrecomp/decompile.py index 4f1c1e9..9d6f790 100644 --- a/ghidrecomp/decompile.py +++ b/ghidrecomp/decompile.py @@ -106,17 +106,12 @@ def gen_callgraph(func: 'ghidra.program.model.listing.Function', max_display_dep flow = '' callgraph = None - try: - if direction == 'calling': - callgraph = get_calling(func, max_run_time=max_run_time) - elif direction == 'called': - callgraph = get_called(func, max_run_time=max_run_time) - else: - raise Exception(f'Unsupported callgraph direction {direction}') - - except TimeoutError as error: - flow = flow_ends = mind = f'\nError: {error} func: {func.name}. max_run_time: {max_run_time} Increase timeout with --max-time-cg-gen MAX_TIME_CG_GEN' - print(flow) + if direction == 'calling': + callgraph = get_calling(func, max_run_time=max_run_time) + elif direction == 'called': + callgraph = get_called(func, max_run_time=max_run_time) + else: + raise Exception(f'Unsupported callgraph direction {direction}') if callgraph is not None: flow = callgraph.gen_mermaid_flow_graph( @@ -267,8 +262,12 @@ def decompile(args: Namespace): else: directions = [args.cg_direction] + max_display_depth = None + if args.max_display_depth is not None: + max_display_depth = int(args.max_display_depth) + with concurrent.futures.ThreadPoolExecutor(max_workers=thread_count) as executor: - futures = (executor.submit(gen_callgraph, func, args.max_display_depth, direction, args.max_time_cg_gen) + futures = (executor.submit(gen_callgraph, func, max_display_depth, direction, args.max_time_cg_gen) for direction in directions for func in all_funcs if args.skip_cache or get_filename(func) not in callgraphs_completed and re.search(args.callgraph_filter, func.name) is not None) for future in concurrent.futures.as_completed(futures): diff --git a/tests/test_callgraph.py b/tests/test_callgraph.py index e8c6936..53cc348 100644 --- a/tests/test_callgraph.py +++ b/tests/test_callgraph.py @@ -43,3 +43,23 @@ def test_decomplie_afd_callgraphs_cached(shared_datadir: Path): assert compiler == 'visualstudio:unknown' assert lang_id == 'x86:LE:64:default' assert len(callgraphs) == 0 + +def test_decomplie_afd_callgraphs_called_and_calling(shared_datadir: Path): + + parser = get_parser() + + bin_path = shared_datadir / 'afd.sys.10.0.22621.1415' + + args = parser.parse_args([f"{bin_path.absolute()}", "--callgraph-filter", "AfdRe", + "--filter", "AfdRe", "--callgraphs", "--skip-cache", "--cg-direction", "both"]) + + expected_output_path = Path(args.output_path) / bin_path.name + + all_funcs, decompilations, output_path, compiler, lang_id, callgraphs = decompile(args) + + assert len(all_funcs) == 73 + assert len(decompilations) == 73 + assert output_path == expected_output_path + assert compiler == 'visualstudio:unknown' + assert lang_id == 'x86:LE:64:default' + assert len(callgraphs) == 146 \ No newline at end of file diff --git a/tests/test_ghidrecomp.py b/tests/test_ghidrecomp.py index 71255fd..dc1cf40 100644 --- a/tests/test_ghidrecomp.py +++ b/tests/test_ghidrecomp.py @@ -71,8 +71,8 @@ def test_decomplie_afd(shared_datadir: Path): all_funcs, decompilations, output_path, compiler, lang_id, callgraphs = decompile(args) - assert len(all_funcs) == 1275 - assert len(decompilations) == 1275 + assert (len(all_funcs) == 1275 or len(all_funcs) == 1273) + assert (len(decompilations) == 1275 or len(decompilations) == 1273) assert output_path == expected_output_path assert compiler == 'visualstudio:unknown' assert lang_id == 'x86:LE:64:default' @@ -91,7 +91,7 @@ def test_decomplie_afd_cached(shared_datadir: Path): all_funcs, decompilations, output_path, compiler, lang_id, callgraphs = decompile(args) - assert len(all_funcs) == 1275 + assert (len(all_funcs) == 1275 or len(all_funcs) == 1273) assert len(decompilations) == 0 assert output_path == expected_output_path assert compiler == 'visualstudio:unknown'