Skip to content

Commit 0dd1227

Browse files
authored
Merge pull request #97 from pnnl/debug-3D-CSF-format
[3D] Fixed bug in lowering CSF format for 3D tensor (issue #80)
2 parents ccf03b9 + 4bfb87b commit 0dd1227

File tree

2 files changed

+34
-34
lines changed

2 files changed

+34
-34
lines changed

lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,7 +1110,7 @@ Value insertSparseTensorDeclOp(PatternRewriter & rewriter,
11101110
Value sptensor;
11111111
if (rank_size == 2)
11121112
{
1113-
Value dims = rewriter.create<tensor::FromElementsOp>(loc, ValueRange{array_sizes[9], array_sizes[10]});
1113+
Value dims = rewriter.create<tensor::FromElementsOp>(loc, ValueRange{array_sizes[9], array_sizes[10]}); /// I, J
11141114
sptensor = rewriter.create<tensorAlgebra::SparseTensorConstructOp>(loc, ty,
11151115
dims, /// Dim sizes
11161116
ValueRange{
@@ -1134,7 +1134,7 @@ Value insertSparseTensorDeclOp(PatternRewriter & rewriter,
11341134
}
11351135
else if (rank_size == 3)
11361136
{
1137-
Value dims = rewriter.create<tensor::FromElementsOp>(loc, ValueRange{array_sizes[16], array_sizes[17], array_sizes[18]});
1137+
Value dims = rewriter.create<tensor::FromElementsOp>(loc, ValueRange{array_sizes[13], array_sizes[14], array_sizes[15]}); /// I, J, K
11381138

11391139
sptensor = rewriter.create<tensorAlgebra::SparseTensorConstructOp>(loc, ty, dims,
11401140
ValueRange {

lib/ExecutionEngine/SparseUtils.cpp

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2118,45 +2118,45 @@ void read_input_sizes_3D(int32_t fileID,
21182118
/// std::cout << "CSF format\n";
21192119
Csf3DTensor<T> csf_3dtensor(FileReader.coo_3dtensor);
21202120

2121-
desc_sizes->data[0] = csf_3dtensor.A1pos_size;
2122-
desc_sizes->data[1] = csf_3dtensor.A1crd_size;
2123-
desc_sizes->data[2] = 0;
2124-
desc_sizes->data[3] = 0;
2125-
desc_sizes->data[4] = csf_3dtensor.A2pos_size;
2126-
desc_sizes->data[5] = csf_3dtensor.A2crd_size;
2127-
desc_sizes->data[6] = 0;
2128-
desc_sizes->data[7] = 0;
2129-
desc_sizes->data[8] = csf_3dtensor.A3pos_size;
2130-
desc_sizes->data[9] = csf_3dtensor.A3crd_size;
2131-
desc_sizes->data[10] = 0;
2132-
desc_sizes->data[11] = 0;
2133-
desc_sizes->data[12] = csf_3dtensor.Aval_size;
2134-
desc_sizes->data[13] = csf_3dtensor.num_index_i;
2135-
desc_sizes->data[14] = csf_3dtensor.num_index_j;
2136-
desc_sizes->data[15] = csf_3dtensor.num_index_k;
2121+
desc_sizes->data[0] = csf_3dtensor.A1pos_size; /// A1pos
2122+
desc_sizes->data[1] = csf_3dtensor.A1crd_size; /// A1crd
2123+
desc_sizes->data[2] = 0; /// A1_tile_pos
2124+
desc_sizes->data[3] = 0; /// A1_tile_crd
2125+
desc_sizes->data[4] = csf_3dtensor.A2pos_size; /// A2pos
2126+
desc_sizes->data[5] = csf_3dtensor.A2crd_size; /// A2crd
2127+
desc_sizes->data[6] = 0; /// A2_tile_pos
2128+
desc_sizes->data[7] = 0; /// A2_tile_crd
2129+
desc_sizes->data[8] = csf_3dtensor.A3pos_size; /// A3pos
2130+
desc_sizes->data[9] = csf_3dtensor.A3crd_size; /// A3crd
2131+
desc_sizes->data[10] = 0; /// A3_tile_pos
2132+
desc_sizes->data[11] = 0; /// A3_tile_crd
2133+
desc_sizes->data[12] = csf_3dtensor.Aval_size; /// Aval
2134+
desc_sizes->data[13] = csf_3dtensor.num_index_i; /// I
2135+
desc_sizes->data[14] = csf_3dtensor.num_index_j; /// J
2136+
desc_sizes->data[15] = csf_3dtensor.num_index_k; /// K
21372137
}
21382138
/// Mode-Generic
21392139
else if (A1format == Compressed_nonunique && A2format == singleton && A3format == Dense)
21402140
{
21412141
/// std::cout << "Mode-Generic format\n";
21422142
Mg3DTensor<T> mg_3dtensor(FileReader.coo_3dtensor);
21432143

2144-
desc_sizes->data[0] = mg_3dtensor.A1pos_size;
2145-
desc_sizes->data[1] = mg_3dtensor.A1crd_size;
2146-
desc_sizes->data[2] = 0;
2147-
desc_sizes->data[3] = 0;
2148-
desc_sizes->data[4] = mg_3dtensor.A2pos_size;
2149-
desc_sizes->data[5] = mg_3dtensor.A2crd_size;
2150-
desc_sizes->data[6] = 0;
2151-
desc_sizes->data[7] = 0;
2152-
desc_sizes->data[8] = mg_3dtensor.A3pos_size;
2153-
desc_sizes->data[9] = mg_3dtensor.A3crd_size;
2154-
desc_sizes->data[10] = 0;
2155-
desc_sizes->data[11] = 0;
2156-
desc_sizes->data[12] = mg_3dtensor.Aval_size;
2157-
desc_sizes->data[13] = mg_3dtensor.num_index_i;
2158-
desc_sizes->data[14] = mg_3dtensor.num_index_j;
2159-
desc_sizes->data[15] = mg_3dtensor.num_index_k;
2144+
desc_sizes->data[0] = mg_3dtensor.A1pos_size; /// A1pos
2145+
desc_sizes->data[1] = mg_3dtensor.A1crd_size; /// A1crd
2146+
desc_sizes->data[2] = 0; /// A1_tile_pos
2147+
desc_sizes->data[3] = 0; /// A1_tile_crd
2148+
desc_sizes->data[4] = mg_3dtensor.A2pos_size; /// A2pos
2149+
desc_sizes->data[5] = mg_3dtensor.A2crd_size; /// A2crd
2150+
desc_sizes->data[6] = 0; /// A2_tile_pos
2151+
desc_sizes->data[7] = 0; /// A2_tile_crd
2152+
desc_sizes->data[8] = mg_3dtensor.A3pos_size; /// A3pos
2153+
desc_sizes->data[9] = mg_3dtensor.A3crd_size; /// A3crd
2154+
desc_sizes->data[10] = 0; /// A3_tile_pos
2155+
desc_sizes->data[11] = 0; /// A3_tile_crd
2156+
desc_sizes->data[12] = mg_3dtensor.Aval_size; /// Aval
2157+
desc_sizes->data[13] = mg_3dtensor.num_index_i; /// I
2158+
desc_sizes->data[14] = mg_3dtensor.num_index_j; /// J
2159+
desc_sizes->data[15] = mg_3dtensor.num_index_k; /// K
21602160
}
21612161
else
21622162
{

0 commit comments

Comments
 (0)