Skip to content

Commit b6a3dc3

Browse files
Return stream of tuples instead of LLVMStructs in ConstantStreamOp.
1 parent 405b0f8 commit b6a3dc3

File tree

19 files changed

+296
-292
lines changed

19 files changed

+296
-292
lines changed

experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -79,46 +79,38 @@ def Iterators_PrintOp : Iterators_Base_Op<"print", [
7979
// High-level iterators
8080
//===----------------------------------------------------------------------===//
8181

82-
/// Verifies that the element types of nested arrays in the $value array
83-
/// correspond to the types of the LLVM-struct element type of the $result
84-
/// Stream.
85-
def Iterators_ValueMatchesElementTypePred
86-
: CPred<[{$value.dyn_cast<ArrayAttr>().size() == 0 ||
87-
$result.getType().dyn_cast<StreamType>().getElementType() ==
88-
::mlir::LLVM::LLVMStructType::getLiteral(
89-
$result.getType().getContext(),
90-
::llvm::SmallVector<Type>(
91-
::llvm::map_range(
92-
$value.dyn_cast<::mlir::ArrayAttr>().begin()->dyn_cast<::mlir::ArrayAttr>(),
93-
[](Attribute attr) { return attr.cast<TypedAttr>().getType(); }
94-
)
95-
)
96-
)}]>;
97-
def Iterators_ValueMatchesElementType
98-
: PredOpTrait<"value type matches return type",
99-
Iterators_ValueMatchesElementTypePred>;
100-
101-
def Iterators_ConstantStreamOp : Iterators_Op<"constantstream",
102-
[Iterators_ValueMatchesElementType,
103-
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
82+
def Iterators_ConstantStreamOp : Iterators_Op<"constantstream", [
83+
PredOpTrait<"element type of return type must be tuple with matching types",
84+
CPred<[{
85+
$value.cast<::mlir::ArrayAttr>().size () == 0 ||
86+
TupleType::get(
87+
$value.getContext(),
88+
::llvm::SmallVector<Type>(
89+
::llvm::map_range(
90+
$value.cast<::mlir::ArrayAttr>().begin()->cast<::mlir::ArrayAttr>(),
91+
[](Attribute attr) { return attr.cast<TypedAttr>().getType(); }
92+
))) ==
93+
$result.getType().cast<StreamType>().getElementType()}]>>,
94+
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
10495
let summary = "Produce a statically defined stream of elements";
10596
let description = [{
106-
Produces a stream of LLVM structs given in the array of arrays attribute
107-
(each inner array being returned as a literal LLVM struct with the values
108-
and types of the elements of that array). The inner arrays all have to have
109-
matching types, i.e., the element at position i has to be the same for all
110-
inner arrays, and the element type of the return Stream has to be the
111-
corresponding literal LLVM struct. An empty array is allowed (in which case
112-
the return Stream does not need to match anything).
97+
Produces a stream of tuples given in the array of arrays attribute (each
98+
inner array being returned as a built-in tuple with the values and types of
99+
the elements of that array). The inner arrays all have to have matching
100+
types, i.e., the element at position i has to be the same for all inner
101+
arrays, and the element type of the return Stream has to be the
102+
corresponding tuple tpye. An empty array is allowed (in which case the
103+
return Stream does not need to match anything).
113104

114105
Example:
115106
```mlir
116107
%constantstream = "iterators.constantstream"() { value = [[42 : i32]] } :
117-
() -> (!iterators.stream<!llvm.struct<(i32)>>)
108+
() -> (!iterators.stream<tuple<i32>>)
118109
```
119110
}];
111+
// TODO(ingomueller): Devise a lowering that allows to return non-LLVM types.
120112
let arguments = (ins Iterators_HomogeneouslyTypedLLVMNumericArrayArrayAttr:$value);
121-
let results = (outs Iterators_StreamOfLLVMStructOfNumerics:$result);
113+
let results = (outs Iterators_StreamOfPrintableTuples:$result);
122114
let extraClassDefinition = [{
123115
/// Implement OpAsmOpInterface.
124116
void $cppClass::getAsmResultNames(

experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsTypes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,10 @@ class Iterators_StreamOf<Type elementType>
230230
def Iterators_StreamOfLLVMStructOfNumerics
231231
: Iterators_StreamOf<Iterators_LLVMStructOfNumerics>;
232232

233+
/// An Iterators stream of tuples of printable types.
234+
def Iterators_StreamOfPrintableTuples
235+
: Iterators_StreamOf<Iterators_TupleOfPrintableTypes>;
236+
233237
/// An Iterators stream of printable elements.
234238
def Iterators_StreamOfPrintableElements
235239
: Iterators_StreamOf<Iterators_PrintableType>;

experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -408,20 +408,22 @@ static GlobalOp buildGlobalData(ConstantStreamOp op, OpBuilder &builder,
408408
/// %0 = iterators.extractvalue %arg0[0] : !iterators.state<i32>
409409
/// %c4_i32 = arith.constant 4 : i32
410410
/// %1 = arith.cmpi slt, %0, %c4_i32 : i32
411-
/// %2:2 = scf.if %1 -> (!iterators.state<i32>, !element_type) {
411+
/// %2:2 = scf.if %1 -> (!iterators.state<i32>, !struct_tpe) {
412412
/// %c1_i32 = arith.constant 1 : i32
413413
/// %3 = arith.addi %0, %c1_i32 : i32
414414
/// %4 = iterators.insertvalue %3 into %arg0[0] : !iterators.state<i32>
415415
/// %5 = llvm.mlir.addressof @iterators.constant_stream_data.0 : !llvm.ptr
416416
/// %6 = llvm.getelementptr %5[%0, 0] :
417-
/// (!llvm.ptr<array<4 x !element_type>>, i32, i32)
418-
/// -> !llvm.ptr, !element_type
419-
/// %7 = llvm.load %6 : !llvm.ptr -> !element_type
420-
/// scf.yield %4, %7 : !iterators.state<i32>, !element_type
417+
/// (!llvm.ptr<array<4 x !struct_type>>, i32, i32)
418+
/// -> !llvm.ptr, !struct_type
419+
/// %7 = llvm.load %6 : !llvm.ptr -> !struct_type
420+
/// scf.yield %4, %7 : !iterators.state<i32>, !struct_type
421421
/// } else {
422-
/// %3 = llvm.mlir.undef : !element_type
423-
/// scf.yield %arg0, %3 : !iterators.state<i32>, !element_type
422+
/// %4 = llvm.mlir.undef : !struct_tpe
423+
/// scf.yield %arg0, %3 : !iterators.state<i32>, !struct_tpe
424424
/// }
425+
/// %3 = llvm.extractvalue %2#1[0] : !llvm.struct<(i32)>
426+
/// %tuple = tuple.from_elements %3 : tuple<i32>
425427
static llvm::SmallVector<Value, 4>
426428
buildNextBody(ConstantStreamOp op, OpBuilder &builder, Value initialState,
427429
ArrayRef<IteratorInfo> upstreamInfos, Type elementType) {
@@ -430,6 +432,8 @@ buildNextBody(ConstantStreamOp op, OpBuilder &builder, Value initialState,
430432
MLIRContext *context = builder.getContext();
431433
Type i32 = b.getI32Type();
432434
Type opaquePtrType = LLVMPointerType::get(context);
435+
auto tupleType = elementType.cast<TupleType>();
436+
auto structType = LLVMStructType::getLiteral(context, tupleType.getTypes());
433437

434438
// Extract current index.
435439
Value currentIndex = b.create<iterators::ExtractValueOp>(
@@ -456,26 +460,34 @@ buildNextBody(ConstantStreamOp op, OpBuilder &builder, Value initialState,
456460
initialState, b.getIndexAttr(0), updatedCurrentIndex);
457461

458462
// Load element from global data at current index.
459-
GlobalOp globalArray = buildGlobalData(op, b, elementType);
463+
GlobalOp globalArray = buildGlobalData(op, b, structType);
460464
Value globalPtr =
461465
b.create<AddressOfOp>(opaquePtrType, globalArray.getName());
462-
Value gep = b.create<GEPOp>(opaquePtrType, elementType, globalPtr,
466+
Value gep = b.create<GEPOp>(opaquePtrType, structType, globalPtr,
463467
ArrayRef<GEPArg>{currentIndex, 0});
464-
Value nextElement = b.create<LoadOp>(elementType, gep);
468+
Value nextStruct = b.create<LoadOp>(structType, gep);
465469

466-
b.create<scf::YieldOp>(ValueRange{updatedState, nextElement});
470+
b.create<scf::YieldOp>(ValueRange{updatedState, nextStruct});
467471
},
468472
/*elseBuilder=*/
469473
[&](OpBuilder &builder, Location loc) {
470474
ImplicitLocOpBuilder b(loc, builder);
471475

472476
// Don't modify state; return undef element.
473-
Value nextElement = b.create<UndefOp>(elementType);
474-
b.create<scf::YieldOp>(ValueRange{initialState, nextElement});
477+
Value nextStruct = b.create<UndefOp>(structType);
478+
b.create<scf::YieldOp>(ValueRange{initialState, nextStruct});
475479
});
476480

481+
// Convert LLVM struct to tuple.
482+
Value nextStruct = ifOp.getResult(1);
483+
SmallVector<Value> elements;
484+
for (auto [i, fieldType] : llvm::enumerate(structType.getBody())) {
485+
auto element = b.create<LLVM::ExtractValueOp>(fieldType, nextStruct, i);
486+
elements.push_back(element);
487+
}
488+
Value nextElement = b.create<tuple::FromElementsOp>(elementType, elements);
489+
477490
Value finalState = ifOp->getResult(0);
478-
Value nextElement = ifOp.getResult(1);
479491
return {finalState, hasNext, nextElement};
480492
}
481493

@@ -781,8 +793,22 @@ buildNextBody(MapOp op, OpBuilder &builder, Value initialState,
781793
[&](OpBuilder &builder, Location loc) {
782794
// Return undefined value.
783795
ImplicitLocOpBuilder b(loc, builder);
784-
Value undef = b.create<LLVM::UndefOp>(elementType);
785-
b.create<scf::YieldOp>(undef);
796+
// TODO(ingomueller): Find a more extensible design.
797+
Value defaultElement;
798+
if (auto tupleType = elementType.dyn_cast<TupleType>()) {
799+
// Special case for tuples: hope that field types are undef'able.
800+
SmallVector<Value> fieldValues;
801+
for (Type fieldType : tupleType.getTypes()) {
802+
auto fieldValue = b.create<LLVM::UndefOp>(fieldType);
803+
fieldValues.push_back(fieldValue);
804+
}
805+
defaultElement =
806+
b.create<tuple::FromElementsOp>(tupleType, fieldValues);
807+
} else {
808+
// Default case: hope that type is undef'able.
809+
defaultElement = b.create<LLVM::UndefOp>(elementType);
810+
}
811+
b.create<scf::YieldOp>(defaultElement);
786812
});
787813
Value mappedElement = ifOp.getResult(0);
788814

experimental/iterators/test/Conversion/IteratorsToLLVM/constant-stream.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
// RUN: iterators-opt %s -convert-iterators-to-llvm \
22
// RUN: | FileCheck --enable-var-scope %s
33

4-
!element_type = !llvm.struct<(i32)>
5-
64
// CHECK-LABEL: func private @iterators.constantstream.close.{{[0-9]+}}(%{{.*}}: !iterators.state<i32>) -> !iterators.state<i32>
75
// CHECK-NEXT: return %[[arg0:.*]] : !iterators.state<i32>
86
// CHECK-NEXT: }
@@ -28,7 +26,7 @@
2826
// CHECK-NEXT: llvm.return %[[V16]] : !llvm.array<4 x struct<(i32)>>
2927
// CHECK-NEXT: }
3028

31-
// CHECK-LABEL: func private @iterators.constantstream.next.{{[0-9]+}}(%{{.*}}: !iterators.state<i32>) -> (!iterators.state<i32>, i1, !llvm.struct<(i32)>)
29+
// CHECK-LABEL: func private @iterators.constantstream.next.{{[0-9]+}}(%{{.*}}: !iterators.state<i32>) -> (!iterators.state<i32>, i1, tuple<i32>)
3230
// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %[[arg0:.*]][0] : !iterators.state<i32>
3331
// CHECK-NEXT: %[[V1:.*]] = arith.constant 4 : i32
3432
// CHECK-NEXT: %[[V2:.*]] = arith.cmpi slt, %[[V0]], %[[V1]] : i32
@@ -44,7 +42,9 @@
4442
// CHECK-NEXT: %[[Vb:.*]] = llvm.mlir.undef : !llvm.struct<(i32)>
4543
// CHECK-NEXT: scf.yield %[[arg0]], %[[Vb]] : !iterators.state<i32>, !llvm.struct<(i32)>
4644
// CHECK-NEXT: }
47-
// CHECK-NEXT: return %[[V3]]#0, %[[V2]], %[[V3]]#1 : !iterators.state<i32>, i1, !llvm.struct<(i32)>
45+
// CHECK-NEXT: %[[Vc:.*]] = llvm.extractvalue %[[V3]]#1[0] : !llvm.struct<(i32)>
46+
// CHECK-NEXT: %[[Vd:.*]] = tuple.from_elements %[[Vc]] : tuple<i32>
47+
// CHECK-NEXT: return %[[V3]]#0, %[[V2]], %[[Vd]] : !iterators.state<i32>, i1, tuple<i32>
4848
// CHECK-NEXT: }
4949

5050
// CHECK-LABEL: func private @iterators.constantstream.open.{{[0-9]+}}(%{{.*}}: !iterators.state<i32>) -> !iterators.state<i32>
@@ -57,7 +57,7 @@ func.func @main() {
5757
// CHECK-LABEL: func.func @main()
5858
%input = "iterators.constantstream"()
5959
{ value = [[0 : i32], [1 : i32], [2 : i32], [3 : i32]] }
60-
: () -> (!iterators.stream<!element_type>)
60+
: () -> (!iterators.stream<tuple<i32>>)
6161
// CHECK-NEXT: %[[V0:.*]] = arith.constant 0 : i32
6262
// CHECK-NEXT: %[[V1:.*]] = iterators.createstate(%[[V0]]) : !iterators.state<i32>
6363
return

experimental/iterators/test/Conversion/IteratorsToLLVM/filter.mlir

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,33 @@
11
// RUN: iterators-opt %s -convert-iterators-to-llvm \
22
// RUN: | FileCheck --enable-var-scope %s
33

4-
!element_type = !llvm.struct<(i32)>
5-
64
// CHECK-LABEL: func.func private @iterators.filter.close.{{[0-9]+}}(%{{.*}}: !iterators.state<!iterators.state<i32>>) -> !iterators.state<!iterators.state<i32>> {
75
// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %[[arg0:.*]][0] : !iterators.state<!iterators.state<i32>>
86
// CHECK-NEXT: %[[V1:.*]] = call @iterators.{{[a-zA-Z]+}}.close.{{[0-9]+}}(%[[V0]]) : ([[upstreamStateType:.*]]) -> [[upstreamStateType]]
97
// CHECK-NEXT: %[[V2:.*]] = iterators.insertvalue %[[V1]] into %[[arg0]][0] : !iterators.state<!iterators.state<i32>>
108
// CHECK-NEXT: return %[[V2]] : !iterators.state<!iterators.state<i32>>
119
// CHECK-NEXT: }
1210

13-
// CHECK-LABEL: func.func private @iterators.filter.next.{{[0-9]+}}(%{{.*}}: !iterators.state<!iterators.state<i32>>) -> (!iterators.state<!iterators.state<i32>>, i1, !llvm.struct<(i32)>)
11+
// CHECK-LABEL: func.func private @iterators.filter.next.{{[0-9]+}}(%{{.*}}: !iterators.state<!iterators.state<i32>>) -> (!iterators.state<!iterators.state<i32>>, i1, tuple<i32>)
1412
// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %[[arg0:.*]][0] : !iterators.state<!iterators.state<i32>>
15-
// CHECK-NEXT: %[[V1:.*]]:3 = scf.while (%[[arg1:.*]] = %[[V0]]) : ([[upstreamStateType:.*]]) -> ([[upstreamStateType]], i1, !llvm.struct<(i32)>) {
16-
// CHECK-NEXT: %[[V3:.*]]:3 = func.call @iterators.{{[a-zA-Z]+}}.next.0(%[[arg1]]) : ([[upstreamStateType]]) -> ([[upstreamStateType]], i1, !llvm.struct<(i32)>)
13+
// CHECK-NEXT: %[[V1:.*]]:3 = scf.while (%[[arg1:.*]] = %[[V0]]) : ([[upstreamStateType:.*]]) -> ([[upstreamStateType]], i1, tuple<i32>) {
14+
// CHECK-NEXT: %[[V3:.*]]:3 = func.call @iterators.{{[a-zA-Z]+}}.next.0(%[[arg1]]) : ([[upstreamStateType]]) -> ([[upstreamStateType]], i1, tuple<i32>)
1715
// CHECK-NEXT: %[[V4:.*]] = scf.if %[[V3]]#1 -> (i1) {
18-
// CHECK-NEXT: %[[V7:.*]] = func.call @is_positive_struct(%[[V3]]#2) : (!llvm.struct<(i32)>) -> i1
16+
// CHECK-NEXT: %[[V7:.*]] = func.call @is_positive_tuple(%[[V3]]#2) : (tuple<i32>) -> i1
1917
// CHECK-NEXT: scf.yield %[[V7]] : i1
2018
// CHECK-NEXT: } else {
2119
// CHECK-NEXT: scf.yield %[[V3]]#1 : i1
2220
// CHECK-NEXT: }
2321
// CHECK-NEXT: %[[Vtrue:.*]] = arith.constant true
2422
// CHECK-NEXT: %[[V5:.*]] = arith.xori %[[V4]], %[[Vtrue]] : i1
2523
// CHECK-NEXT: %[[V6:.*]] = arith.andi %[[V3]]#1, %[[V5]] : i1
26-
// CHECK-NEXT: scf.condition(%[[V6]]) %[[V3]]#0, %[[V3]]#1, %[[V3]]#2 : [[upstreamStateType]], i1, !llvm.struct<(i32)>
24+
// CHECK-NEXT: scf.condition(%[[V6]]) %[[V3]]#0, %[[V3]]#1, %[[V3]]#2 : [[upstreamStateType]], i1, tuple<i32>
2725
// CHECK-NEXT: } do {
28-
// CHECK-NEXT: ^bb0(%[[arg2:.*]]: [[upstreamStateType]], %arg2: i1, %arg3: !llvm.struct<(i32)>):
26+
// CHECK-NEXT: ^bb0(%[[arg2:.*]]: [[upstreamStateType]], %arg2: i1, %arg3: tuple<i32>):
2927
// CHECK-NEXT: scf.yield %[[arg2]] : [[upstreamStateType]]
3028
// CHECK-NEXT: }
3129
// CHECK-NEXT: %[[V2:.*]] = iterators.insertvalue %[[V1]]#0 into %[[arg0]][0] : !iterators.state<!iterators.state<i32>>
32-
// CHECK-NEXT: return %[[V2]], %[[V1]]#1, %[[V1]]#2 : !iterators.state<!iterators.state<i32>>, i1, !llvm.struct<(i32)>
30+
// CHECK-NEXT: return %[[V2]], %[[V1]]#1, %[[V1]]#2 : !iterators.state<!iterators.state<i32>>, i1, tuple<i32>
3331
// CHECK-NEXT: }
3432

3533
// CHECK-LABEL: func.func private @iterators.filter.open.{{[0-9]+}}(%{{.*}}: !iterators.state<!iterators.state<i32>>) -> !iterators.state<!iterators.state<i32>>
@@ -39,25 +37,25 @@
3937
// CHECK-NEXT: return %[[V2]] : !iterators.state<!iterators.state<i32>>
4038
// CHECK-NEXT: }
4139

42-
func.func private @is_positive_struct(%struct : !element_type) -> i1 {
43-
// CHECK-LABEL: func.func private @is_positive_struct(%{{.*}}: !llvm.struct<(i32)>) -> i1 {
44-
%i = llvm.extractvalue %struct[0] : !element_type
45-
// CHECK-NEXT: %[[i:.*]] = llvm.extractvalue %[[struct:.*]][0] : !llvm.struct<(i32)>
40+
// CHECK-LABEL: func.func private @is_positive_tuple(
41+
// CHECK-SAME: %[[ARG0:.*]]: tuple<i32>) -> i1 {
42+
// CHECK-DAG: %[[V0:.*]] = tuple.to_elements %[[ARG0]] : tuple<i32>
43+
// CHECK-DAG: %[[V1:.*]] = arith.constant 0 : i32
44+
// CHECK-NEXT: %[[V2:.*]] = arith.cmpi sgt, %[[V0]], %[[V1]] : i32
45+
// CHECK-NEXT: return %[[V2]] : i1
46+
func.func private @is_positive_tuple(%tuple : tuple<i32>) -> i1 {
47+
%i = tuple.to_elements %tuple : tuple<i32>
4648
%zero = arith.constant 0 : i32
47-
// CHECK-NEXT: %[[zero:.*]] = arith.constant 0 : i32
4849
%cmp = arith.cmpi "sgt", %i, %zero : i32
49-
// CHECK-NEXT: %[[cmp:.*]] = arith.cmpi sgt, %[[i]], %[[zero]] : i32
5050
return %cmp : i1
51-
// CHECK-NEXT: return %[[cmp]] : i1
5251
}
53-
// CHECK-NEXT: }
5452

5553
func.func @main() {
5654
// CHECK-LABEL: func.func @main()
57-
%input = "iterators.constantstream"() { value = [] } : () -> (!iterators.stream<!element_type>)
55+
%input = "iterators.constantstream"() { value = [] } : () -> (!iterators.stream<tuple<i32>>)
5856
// CHECK: %[[V0:.*]] = iterators.createstate({{.*}}) : [[upstreamStateType:.*]]
59-
%filter = "iterators.filter"(%input) {predicateRef = @is_positive_struct}
60-
: (!iterators.stream<!element_type>) -> (!iterators.stream<!element_type>)
57+
%filter = "iterators.filter"(%input) {predicateRef = @is_positive_tuple}
58+
: (!iterators.stream<tuple<i32>>) -> (!iterators.stream<tuple<i32>>)
6159
// CHECK-NEXT: %[[V1:.*]] = iterators.createstate(%[[V0]]) : !iterators.state<[[upstreamStateType]]>
6260
return
6361
// CHECK-NEXT: return

0 commit comments

Comments
 (0)