Skip to content

Commit

Permalink
fix max-depth-display, test for called graphs, remove bad chars from …
Browse files Browse the repository at this point in the history
…mermaidjs, remove timeout exception
  • Loading branch information
clearbluejar committed Nov 7, 2023
1 parent 6d3b04c commit 19cad99
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 23 deletions.
2 changes: 1 addition & 1 deletion ghidrecomp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.4.2'
__version__ = '0.4.3'
__author__ = 'clearbluejar'

# Expose API
Expand Down
34 changes: 24 additions & 10 deletions ghidrecomp/callgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import zlib
import json
import sys
import re

from typing import TYPE_CHECKING
from functools import lru_cache
Expand Down Expand Up @@ -124,6 +125,10 @@ def links_count(self) -> int:
count += 1

return count

@staticmethod
def remove_bad_mermaid_chars(text: str):
return re.sub(r'`','',text)

def gen_mermaid_flow_graph(self, direction=None, shaded_nodes: list = None, shade_color='#339933', max_display_depth=None, endpoint_only=False, wrap_mermaid=False) -> str:
"""
Expand Down Expand Up @@ -200,7 +205,7 @@ def gen_mermaid_flow_graph(self, direction=None, shaded_nodes: list = None, shad
depth = node[1]
fname = node[0]

if max_display_depth and depth > max_display_depth:
if max_display_depth is not None and depth > max_display_depth:
continue

if shaded_nodes and fname in shaded_nodes:
Expand Down Expand Up @@ -238,6 +243,8 @@ def gen_mermaid_flow_graph(self, direction=None, shaded_nodes: list = None, shad

mermaid_chart = mermaid_flow.format(links='\n'.join(links.keys()), direction=direction, style=style)

mermaid_chart = self.remove_bad_mermaid_chars(mermaid_chart)

if wrap_mermaid:
mermaid_chart = _wrap_mermaid(mermaid_chart)

Expand Down Expand Up @@ -266,7 +273,7 @@ def gen_mermaid_mind_map(self, max_display_depth=None, wrap_mermaid=False) -> st
depth = row[1]

# skip root row
if depth < 2 or max_display_depth and depth > max_display_depth:
if depth < 2 or max_display_depth is not None and depth > max_display_depth:
continue

if depth < last_depth:
Expand All @@ -281,6 +288,8 @@ def gen_mermaid_mind_map(self, max_display_depth=None, wrap_mermaid=False) -> st

mermaid_chart = mermaid_mind.format(rows='\n'.join(rows), root=self.root)

mermaid_chart = self.remove_bad_mermaid_chars(mermaid_chart)

if wrap_mermaid:
mermaid_chart = _wrap_mermaid(mermaid_chart)

Expand All @@ -298,7 +307,7 @@ def get_called_funcs_memo(f: "ghidra.program.model.listing.Function"):


# Recursively calling to build calling graph
def get_calling(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = CallGraph(), depth: int = 0, visited: tuple = None, verbose=False, include_ns=True, start_time=None, max_run_time=None):
def get_calling(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = CallGraph(), depth: int = 0, visited: tuple = None, verbose=False, include_ns=True, start_time=None, max_run_time=None, max_depth=MAX_DEPTH):
"""
Build a call graph of all calling functions
Traverses depth first
Expand All @@ -314,12 +323,15 @@ def get_calling(f: "ghidra.program.model.listing.Function", cgraph: CallGraph =
visited = tuple()
start_time = time.time()

if depth > MAX_DEPTH:
if depth == MAX_DEPTH:
cgraph.add_edge(f.getName(include_ns), f'MAX_DEPTH_HIT - {depth}', depth)
return cgraph

if (time.time() - start_time) > float(max_run_time):
raise TimeoutError(f'time expired for {f.getName(include_ns)}')
#raise TimeoutError(f'time expired for {clean_func(f,include_ns)}')
cgraph.add_edge(f.getName(include_ns), f'MAX_TIME_HIT - time: {max_run_time} depth: {depth}', depth)
print(f'\nWarn: cg : {cgraph.root} edges: {cgraph.links_count()} depth: {depth} name: {f.name} did not complete. max_run_time: {max_run_time} Increase timeout with --max-time-cg-gen MAX_TIME_CG_GEN')
return cgraph

space = (depth+2)*' '

Expand Down Expand Up @@ -349,7 +361,7 @@ def get_calling(f: "ghidra.program.model.listing.Function", cgraph: CallGraph =
cgraph.add_edge(c.getName(include_ns), f.getName(include_ns), depth)

# Parse further functions
cgraph = get_calling(c, cgraph, depth, visited=visited, start_time=start_time, max_run_time=max_run_time)
cgraph = get_calling(c, cgraph, depth, visited=visited, start_time=start_time, max_run_time=max_run_time, max_depth=max_depth)
else:
if verbose:
print(f'{space} - END for {f.name}')
Expand Down Expand Up @@ -380,13 +392,15 @@ def get_called(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = C
visited = tuple()
start_time = time.time()

if depth > max_depth:
if depth == max_depth:
cgraph.add_edge(f.getName(include_ns), f'MAX_DEPTH_HIT - {depth}', depth)
return cgraph

if (time.time() - start_time) > float(max_run_time):
raise TimeoutError(f'time expired for {f.getName(include_ns)}')

cgraph.add_edge(f.getName(include_ns), f'MAX_TIME_HIT - time: {max_run_time} depth: {depth}', depth)
print(f'\nWarn: cg : {cgraph.root} edges: {cgraph.links_count()} depth: {depth} name: {f.name} did not complete. max_run_time: {max_run_time} Increase timeout with --max-time-cg-gen MAX_TIME_CG_GEN')
return cgraph

space = (depth+2)*' '

# loop check
Expand Down Expand Up @@ -427,7 +441,7 @@ def get_called(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = C

# Parse further functions
cgraph = get_called(c, cgraph, depth, visited=visited,
start_time=start_time, max_run_time=max_run_time)
start_time=start_time, max_run_time=max_run_time, max_depth=max_depth)

else:
if verbose:
Expand Down
23 changes: 11 additions & 12 deletions ghidrecomp/decompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,12 @@ def gen_callgraph(func: 'ghidra.program.model.listing.Function', max_display_dep
flow = ''
callgraph = None

try:
if direction == 'calling':
callgraph = get_calling(func, max_run_time=max_run_time)
elif direction == 'called':
callgraph = get_called(func, max_run_time=max_run_time)
else:
raise Exception(f'Unsupported callgraph direction {direction}')

except TimeoutError as error:
flow = flow_ends = mind = f'\nError: {error} func: {func.name}. max_run_time: {max_run_time} Increase timeout with --max-time-cg-gen MAX_TIME_CG_GEN'
print(flow)
if direction == 'calling':
callgraph = get_calling(func, max_run_time=max_run_time)
elif direction == 'called':
callgraph = get_called(func, max_run_time=max_run_time)
else:
raise Exception(f'Unsupported callgraph direction {direction}')

if callgraph is not None:
flow = callgraph.gen_mermaid_flow_graph(
Expand Down Expand Up @@ -267,8 +262,12 @@ def decompile(args: Namespace):
else:
directions = [args.cg_direction]

max_display_depth = None
if args.max_display_depth is not None:
max_display_depth = int(args.max_display_depth)

with concurrent.futures.ThreadPoolExecutor(max_workers=thread_count) as executor:
futures = (executor.submit(gen_callgraph, func, args.max_display_depth, direction, args.max_time_cg_gen)
futures = (executor.submit(gen_callgraph, func, max_display_depth, direction, args.max_time_cg_gen)
for direction in directions for func in all_funcs if args.skip_cache or get_filename(func) not in callgraphs_completed and re.search(args.callgraph_filter, func.name) is not None)

for future in concurrent.futures.as_completed(futures):
Expand Down
20 changes: 20 additions & 0 deletions tests/test_callgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,23 @@ def test_decomplie_afd_callgraphs_cached(shared_datadir: Path):
assert compiler == 'visualstudio:unknown'
assert lang_id == 'x86:LE:64:default'
assert len(callgraphs) == 0

def test_decomplie_afd_callgraphs_called_and_calling(shared_datadir: Path):

parser = get_parser()

bin_path = shared_datadir / 'afd.sys.10.0.22621.1415'

args = parser.parse_args([f"{bin_path.absolute()}", "--callgraph-filter", "AfdRe",
"--filter", "AfdRe", "--callgraphs", "--skip-cache", "--cg-direction", "both"])

expected_output_path = Path(args.output_path) / bin_path.name

all_funcs, decompilations, output_path, compiler, lang_id, callgraphs = decompile(args)

assert len(all_funcs) == 73
assert len(decompilations) == 73
assert output_path == expected_output_path
assert compiler == 'visualstudio:unknown'
assert lang_id == 'x86:LE:64:default'
assert len(callgraphs) == 146

0 comments on commit 19cad99

Please sign in to comment.