Skip to content

Commit 68075ea

Browse files
authored
set default strides for outputs (#105)
1 parent c1477cd commit 68075ea

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

src/kernel/output.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,22 @@
2020
namespace mirage {
2121
namespace kernel {
2222

23+
std::vector<size_t> get_default_strides(DTensor const &A) {
24+
std::vector<size_t> strides(A.num_dims);
25+
size_t stride = 1;
26+
for (int i = A.num_dims - 1; i >= 0; --i) {
27+
strides[i] = stride;
28+
stride *= A.dim[i];
29+
}
30+
return strides;
31+
}
32+
2333
void Graph::mark_output(DTensor const &A) {
24-
std::vector<size_t> strides;
25-
return mark_output(A);
34+
return mark_output(A, get_default_strides(A));
2635
}
2736

2837
void Graph::mark_output(DTensor const *A) {
29-
std::vector<size_t> strides;
30-
return mark_output(A);
38+
return mark_output(A, get_default_strides(*A));
3139
}
3240

3341
void Graph::mark_output(DTensor const &A, std::vector<size_t> const &strides) {

0 commit comments

Comments
 (0)