Skip to content

Commit db91196

Browse files
Add test case that accumulates a stream as a memref using memref.realloc.
1 parent d49ed91 commit db91196

File tree

1 file changed

+53
-1
lines changed
  • experimental/iterators/test/Integration/Dialect/Iterators/CPU

1 file changed

+53
-1
lines changed

experimental/iterators/test/Integration/Dialect/Iterators/CPU/accumulate.mlir

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,15 @@
22
// RUN: -convert-iterators-to-llvm \
33
// RUN: -decompose-iterator-states \
44
// RUN: -decompose-tuples \
5+
// RUN: -convert-tabular-to-llvm \
6+
// RUN: -inline -canonicalize \
7+
// RUN: -arith-bufferize \
8+
// RUN: -expand-strided-metadata \
9+
// RUN: -finalize-memref-to-llvm \
510
// RUN: -convert-func-to-llvm \
6-
// RUN: -convert-scf-to-cf -convert-cf-to-llvm \
11+
// RUN: -convert-scf-to-cf \
12+
// RUN: -convert-cf-to-llvm \
13+
// RUN: -reconcile-unrealized-casts \
714
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
815
// RUN: | FileCheck %s
916

@@ -72,8 +79,53 @@ func.func @test_accumulate_avg_tuple() {
7279
return
7380
}
7481

82+
!memref_i32 = memref<?xi32>
83+
84+
func.func private @accumulate_realloc(
85+
%acc : !memref_i32, %val : tuple<i32>) -> !memref_i32 {
86+
%zero = arith.constant 0 : index
87+
%one = arith.constant 1 : index
88+
%dim = memref.dim %acc, %zero : !memref_i32
89+
%new_dim = arith.addi %one, %dim : index
90+
%realloced = memref.realloc %acc (%new_dim) : !memref_i32 to !memref_i32
91+
%vali = tuple.to_elements %val : tuple<i32>
92+
memref.store %vali, %realloced[%dim] : !memref_i32
93+
return %realloced : !memref_i32
94+
}
95+
96+
// CHECK-LABEL: test_accumulate_realloc
97+
// CHECK-NEXT: (9)
98+
// CHECK-NEXT: (8)
99+
// CHECK-NEXT: (7)
100+
// CHECK-NEXT: -
101+
func.func @test_accumulate_realloc() {
102+
iterators.print("test_accumulate_realloc")
103+
%tensor = arith.constant dense<[9, 8, 7]> : tensor<3xi32>
104+
%memref = bufferization.to_memref %tensor : memref<3xi32>
105+
%view = "tabular.view_as_tabular"(%memref)
106+
: (memref<3xi32>) -> !tabular.tabular_view<i32>
107+
%stream = iterators.tabular_view_to_stream %view
108+
to !iterators.stream<tuple<i32>>
109+
%zero = arith.constant 0 : index
110+
%alloced = memref.alloc (%zero) : !memref_i32
111+
%accumulated = iterators.accumulate(%stream, %alloced)
112+
with @accumulate_realloc
113+
: (!iterators.stream<tuple<i32>>) -> !iterators.stream<!memref_i32>
114+
%result:2 = iterators.stream_to_value %accumulated :
115+
!iterators.stream<!memref_i32>
116+
scf.if %result#1 {
117+
%result_view = "tabular.view_as_tabular"(%result#0)
118+
: (memref<?xi32>) -> !tabular.tabular_view<i32>
119+
%result_stream = iterators.tabular_view_to_stream %result_view
120+
to !iterators.stream<tuple<i32>>
121+
"iterators.sink"(%result_stream) : (!iterators.stream<tuple<i32>>) -> ()
122+
}
123+
return
124+
}
125+
75126
func.func @main() {
76127
call @test_accumulate_sum_tuple() : () -> ()
77128
call @test_accumulate_avg_tuple() : () -> ()
129+
call @test_accumulate_realloc() : () -> ()
78130
return
79131
}

0 commit comments

Comments
 (0)