Skip to content

Commit 19cad99

Browse files
committed
fix max-depth-display, test for called graphs, remove bad chars from mermaidjs, remove timeout exception
1 parent 6d3b04c commit 19cad99

File tree

4 files changed

+56
-23
lines changed

4 files changed

+56
-23
lines changed

ghidrecomp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = '0.4.2'
1+
__version__ = '0.4.3'
22
__author__ = 'clearbluejar'
33

44
# Expose API

ghidrecomp/callgraph.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import zlib
44
import json
55
import sys
6+
import re
67

78
from typing import TYPE_CHECKING
89
from functools import lru_cache
@@ -124,6 +125,10 @@ def links_count(self) -> int:
124125
count += 1
125126

126127
return count
128+
129+
@staticmethod
130+
def remove_bad_mermaid_chars(text: str):
131+
return re.sub(r'`','',text)
127132

128133
def gen_mermaid_flow_graph(self, direction=None, shaded_nodes: list = None, shade_color='#339933', max_display_depth=None, endpoint_only=False, wrap_mermaid=False) -> str:
129134
"""
@@ -200,7 +205,7 @@ def gen_mermaid_flow_graph(self, direction=None, shaded_nodes: list = None, shad
200205
depth = node[1]
201206
fname = node[0]
202207

203-
if max_display_depth and depth > max_display_depth:
208+
if max_display_depth is not None and depth > max_display_depth:
204209
continue
205210

206211
if shaded_nodes and fname in shaded_nodes:
@@ -238,6 +243,8 @@ def gen_mermaid_flow_graph(self, direction=None, shaded_nodes: list = None, shad
238243

239244
mermaid_chart = mermaid_flow.format(links='\n'.join(links.keys()), direction=direction, style=style)
240245

246+
mermaid_chart = self.remove_bad_mermaid_chars(mermaid_chart)
247+
241248
if wrap_mermaid:
242249
mermaid_chart = _wrap_mermaid(mermaid_chart)
243250

@@ -266,7 +273,7 @@ def gen_mermaid_mind_map(self, max_display_depth=None, wrap_mermaid=False) -> st
266273
depth = row[1]
267274

268275
# skip root row
269-
if depth < 2 or max_display_depth and depth > max_display_depth:
276+
if depth < 2 or max_display_depth is not None and depth > max_display_depth:
270277
continue
271278

272279
if depth < last_depth:
@@ -281,6 +288,8 @@ def gen_mermaid_mind_map(self, max_display_depth=None, wrap_mermaid=False) -> st
281288

282289
mermaid_chart = mermaid_mind.format(rows='\n'.join(rows), root=self.root)
283290

291+
mermaid_chart = self.remove_bad_mermaid_chars(mermaid_chart)
292+
284293
if wrap_mermaid:
285294
mermaid_chart = _wrap_mermaid(mermaid_chart)
286295

@@ -298,7 +307,7 @@ def get_called_funcs_memo(f: "ghidra.program.model.listing.Function"):
298307

299308

300309
# Recursively calling to build calling graph
301-
def get_calling(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = CallGraph(), depth: int = 0, visited: tuple = None, verbose=False, include_ns=True, start_time=None, max_run_time=None):
310+
def get_calling(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = CallGraph(), depth: int = 0, visited: tuple = None, verbose=False, include_ns=True, start_time=None, max_run_time=None, max_depth=MAX_DEPTH):
302311
"""
303312
Build a call graph of all calling functions
304313
Traverses depth first
@@ -314,12 +323,15 @@ def get_calling(f: "ghidra.program.model.listing.Function", cgraph: CallGraph =
314323
visited = tuple()
315324
start_time = time.time()
316325

317-
if depth > MAX_DEPTH:
326+
if depth == MAX_DEPTH:
318327
cgraph.add_edge(f.getName(include_ns), f'MAX_DEPTH_HIT - {depth}', depth)
319328
return cgraph
320329

321330
if (time.time() - start_time) > float(max_run_time):
322-
raise TimeoutError(f'time expired for {f.getName(include_ns)}')
331+
#raise TimeoutError(f'time expired for {clean_func(f,include_ns)}')
332+
cgraph.add_edge(f.getName(include_ns), f'MAX_TIME_HIT - time: {max_run_time} depth: {depth}', depth)
333+
print(f'\nWarn: cg : {cgraph.root} edges: {cgraph.links_count()} depth: {depth} name: {f.name} did not complete. max_run_time: {max_run_time} Increase timeout with --max-time-cg-gen MAX_TIME_CG_GEN')
334+
return cgraph
323335

324336
space = (depth+2)*' '
325337

@@ -349,7 +361,7 @@ def get_calling(f: "ghidra.program.model.listing.Function", cgraph: CallGraph =
349361
cgraph.add_edge(c.getName(include_ns), f.getName(include_ns), depth)
350362

351363
# Parse further functions
352-
cgraph = get_calling(c, cgraph, depth, visited=visited, start_time=start_time, max_run_time=max_run_time)
364+
cgraph = get_calling(c, cgraph, depth, visited=visited, start_time=start_time, max_run_time=max_run_time, max_depth=max_depth)
353365
else:
354366
if verbose:
355367
print(f'{space} - END for {f.name}')
@@ -380,13 +392,15 @@ def get_called(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = C
380392
visited = tuple()
381393
start_time = time.time()
382394

383-
if depth > max_depth:
395+
if depth == max_depth:
384396
cgraph.add_edge(f.getName(include_ns), f'MAX_DEPTH_HIT - {depth}', depth)
385397
return cgraph
386398

387399
if (time.time() - start_time) > float(max_run_time):
388-
raise TimeoutError(f'time expired for {f.getName(include_ns)}')
389-
400+
cgraph.add_edge(f.getName(include_ns), f'MAX_TIME_HIT - time: {max_run_time} depth: {depth}', depth)
401+
print(f'\nWarn: cg : {cgraph.root} edges: {cgraph.links_count()} depth: {depth} name: {f.name} did not complete. max_run_time: {max_run_time} Increase timeout with --max-time-cg-gen MAX_TIME_CG_GEN')
402+
return cgraph
403+
390404
space = (depth+2)*' '
391405

392406
# loop check
@@ -427,7 +441,7 @@ def get_called(f: "ghidra.program.model.listing.Function", cgraph: CallGraph = C
427441

428442
# Parse further functions
429443
cgraph = get_called(c, cgraph, depth, visited=visited,
430-
start_time=start_time, max_run_time=max_run_time)
444+
start_time=start_time, max_run_time=max_run_time, max_depth=max_depth)
431445

432446
else:
433447
if verbose:

ghidrecomp/decompile.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -106,17 +106,12 @@ def gen_callgraph(func: 'ghidra.program.model.listing.Function', max_display_dep
106106
flow = ''
107107
callgraph = None
108108

109-
try:
110-
if direction == 'calling':
111-
callgraph = get_calling(func, max_run_time=max_run_time)
112-
elif direction == 'called':
113-
callgraph = get_called(func, max_run_time=max_run_time)
114-
else:
115-
raise Exception(f'Unsupported callgraph direction {direction}')
116-
117-
except TimeoutError as error:
118-
flow = flow_ends = mind = f'\nError: {error} func: {func.name}. max_run_time: {max_run_time} Increase timeout with --max-time-cg-gen MAX_TIME_CG_GEN'
119-
print(flow)
109+
if direction == 'calling':
110+
callgraph = get_calling(func, max_run_time=max_run_time)
111+
elif direction == 'called':
112+
callgraph = get_called(func, max_run_time=max_run_time)
113+
else:
114+
raise Exception(f'Unsupported callgraph direction {direction}')
120115

121116
if callgraph is not None:
122117
flow = callgraph.gen_mermaid_flow_graph(
@@ -267,8 +262,12 @@ def decompile(args: Namespace):
267262
else:
268263
directions = [args.cg_direction]
269264

265+
max_display_depth = None
266+
if args.max_display_depth is not None:
267+
max_display_depth = int(args.max_display_depth)
268+
270269
with concurrent.futures.ThreadPoolExecutor(max_workers=thread_count) as executor:
271-
futures = (executor.submit(gen_callgraph, func, args.max_display_depth, direction, args.max_time_cg_gen)
270+
futures = (executor.submit(gen_callgraph, func, max_display_depth, direction, args.max_time_cg_gen)
272271
for direction in directions for func in all_funcs if args.skip_cache or get_filename(func) not in callgraphs_completed and re.search(args.callgraph_filter, func.name) is not None)
273272

274273
for future in concurrent.futures.as_completed(futures):

tests/test_callgraph.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,23 @@ def test_decomplie_afd_callgraphs_cached(shared_datadir: Path):
4343
assert compiler == 'visualstudio:unknown'
4444
assert lang_id == 'x86:LE:64:default'
4545
assert len(callgraphs) == 0
46+
47+
def test_decomplie_afd_callgraphs_called_and_calling(shared_datadir: Path):
48+
49+
parser = get_parser()
50+
51+
bin_path = shared_datadir / 'afd.sys.10.0.22621.1415'
52+
53+
args = parser.parse_args([f"{bin_path.absolute()}", "--callgraph-filter", "AfdRe",
54+
"--filter", "AfdRe", "--callgraphs", "--skip-cache", "--cg-direction", "both"])
55+
56+
expected_output_path = Path(args.output_path) / bin_path.name
57+
58+
all_funcs, decompilations, output_path, compiler, lang_id, callgraphs = decompile(args)
59+
60+
assert len(all_funcs) == 73
61+
assert len(decompilations) == 73
62+
assert output_path == expected_output_path
63+
assert compiler == 'visualstudio:unknown'
64+
assert lang_id == 'x86:LE:64:default'
65+
assert len(callgraphs) == 146

0 commit comments

Comments
 (0)