1
1
# RUN: %PYTHON %s | FileCheck %s
2
2
3
+ import argparse
3
4
import ctypes
5
+ import logging
4
6
import os
5
7
import sys
6
8
from tempfile import NamedTemporaryFile
18
20
from mlir_iterators .execution_engine import ExecutionEngine
19
21
from mlir_iterators .ir import Context , Module
20
22
23
+ # Set up logging.
24
+ LOGLEVELS = {
25
+ logging .getLevelName (l ): l
26
+ for l in (logging .CRITICAL , logging .ERROR , logging .WARNING , logging .INFO ,
27
+ logging .DEBUG )
28
+ }
29
+
30
+ # Parse command line arguments for interactive testing/debugging.
31
+ parser = argparse .ArgumentParser (
32
+ description = 'Integration tests for iterators related to Apache Arrow.' )
33
+ parser .add_argument ('--log-level' ,
34
+ type = LOGLEVELS .__getitem__ ,
35
+ default = logging .ERROR ,
36
+ help = 'Set the log level by name' )
37
+ parser .add_argument ('--enable-ir-printing' ,
38
+ action = 'store_true' ,
39
+ help = 'Enable printing IR after every pass' )
40
+ args = parser .parse_args ()
41
+
42
+ logging .getLogger ().setLevel (args .log_level )
43
+
44
+
45
+ def format_code (code : str ) -> str :
46
+ return '\n ' .join (
47
+ (f'{ i :>4} : { l } ' for i , l in enumerate (str (code ).splitlines ())))
48
+
21
49
22
50
def run (f ):
23
51
print ("\n TEST:" , f .__name__ )
@@ -85,7 +113,12 @@ def to_mlir_type(t: pa.DataType) -> str:
85
113
86
114
# Compiles the given code and wraps it into an execution engine.
87
115
def build_and_create_engine (code : str ) -> ExecutionEngine :
88
- mod = Module .parse (ARROW_STRUCT_DEFINITIONS_MLIR + code )
116
+ # Assemble, log, and parse input IR.
117
+ code = ARROW_STRUCT_DEFINITIONS_MLIR + code
118
+ logging .info ("Input IR:\n \n %s\n " , format_code (code ))
119
+ mod = Module .parse (code )
120
+
121
+ # Assemble and log pass pipeline.
89
122
pm = PassManager .parse ('builtin.module('
90
123
'convert-iterators-to-llvm,'
91
124
'convert-tabular-to-llvm,'
@@ -99,7 +132,17 @@ def build_and_create_engine(code: str) -> ExecutionEngine:
99
132
'reconcile-unrealized-casts,'
100
133
'convert-scf-to-cf,'
101
134
'convert-cf-to-llvm)' )
135
+ logging .info ("Pass pipeline:\n \n %s\n " , pm )
136
+
137
+ # Enable printing of intermediate IR if requested.
138
+ if args .enable_ir_printing :
139
+ mod .context .enable_multithreading (False )
140
+ pm .enable_ir_printing ()
141
+
142
+ # Run pipeline.
102
143
pm .run (mod .operation )
144
+
145
+ # Create and return engine.
103
146
runtime_lib = os .environ ['ITERATORS_RUNTIME_LIBRARY_PATH' ]
104
147
engine = ExecutionEngine (mod , shared_libs = [runtime_lib ])
105
148
return engine
0 commit comments