Skip to content

Commit 5dd8a0c

Browse files
Extend test script with logging and debug output.
1 parent 333ebef commit 5dd8a0c

File tree

1 file changed

+44
-1
lines changed
  • experimental/iterators/test/python/dialects/iterators

1 file changed

+44
-1
lines changed

experimental/iterators/test/python/dialects/iterators/arrow.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

3+
import argparse
34
import ctypes
5+
import logging
46
import os
57
import sys
68
from tempfile import NamedTemporaryFile
@@ -18,6 +20,32 @@
1820
from mlir_iterators.execution_engine import ExecutionEngine
1921
from mlir_iterators.ir import Context, Module
2022

23+
# Set up logging.
24+
LOGLEVELS = {
25+
logging.getLevelName(l): l
26+
for l in (logging.CRITICAL, logging.ERROR, logging.WARNING, logging.INFO,
27+
logging.DEBUG)
28+
}
29+
30+
# Parse command line arguments for interactive testing/debugging.
31+
parser = argparse.ArgumentParser(
32+
description='Integration tests for iterators related to Apache Arrow.')
33+
parser.add_argument('--log-level',
34+
type=LOGLEVELS.__getitem__,
35+
default=logging.ERROR,
36+
help='Set the log level by name')
37+
parser.add_argument('--enable-ir-printing',
38+
action='store_true',
39+
help='Enable printing IR after every pass')
40+
args = parser.parse_args()
41+
42+
logging.getLogger().setLevel(args.log_level)
43+
44+
45+
def format_code(code: str) -> str:
46+
return '\n'.join(
47+
(f'{i:>4}: {l}' for i, l in enumerate(str(code).splitlines())))
48+
2149

2250
def run(f):
2351
print("\nTEST:", f.__name__)
@@ -85,7 +113,12 @@ def to_mlir_type(t: pa.DataType) -> str:
85113

86114
# Compiles the given code and wraps it into an execution engine.
87115
def build_and_create_engine(code: str) -> ExecutionEngine:
88-
mod = Module.parse(ARROW_STRUCT_DEFINITIONS_MLIR + code)
116+
# Assemble, log, and parse input IR.
117+
code = ARROW_STRUCT_DEFINITIONS_MLIR + code
118+
logging.info("Input IR:\n\n%s\n", format_code(code))
119+
mod = Module.parse(code)
120+
121+
# Assemble and log pass pipeline.
89122
pm = PassManager.parse('builtin.module('
90123
'convert-iterators-to-llvm,'
91124
'convert-tabular-to-llvm,'
@@ -99,7 +132,17 @@ def build_and_create_engine(code: str) -> ExecutionEngine:
99132
'reconcile-unrealized-casts,'
100133
'convert-scf-to-cf,'
101134
'convert-cf-to-llvm)')
135+
logging.info("Pass pipeline:\n\n%s\n", pm)
136+
137+
# Enable printing of intermediate IR if requested.
138+
if args.enable_ir_printing:
139+
mod.context.enable_multithreading(False)
140+
pm.enable_ir_printing()
141+
142+
# Run pipeline.
102143
pm.run(mod.operation)
144+
145+
# Create and return engine.
103146
runtime_lib = os.environ['ITERATORS_RUNTIME_LIBRARY_PATH']
104147
engine = ExecutionEngine(mod, shared_libs=[runtime_lib])
105148
return engine

0 commit comments

Comments
 (0)