Skip to content

Commit 275cdbb

Browse files
committed
OOC Seq Instruction
1 parent cd4c828 commit 275cdbb

File tree

5 files changed

+237
-1
lines changed

5 files changed

+237
-1
lines changed

src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction;
2828
import org.apache.sysds.runtime.instructions.ooc.CentralMomentOOCInstruction;
2929
import org.apache.sysds.runtime.instructions.ooc.CtableOOCInstruction;
30+
import org.apache.sysds.runtime.instructions.ooc.DataGenOOCInstruction;
3031
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
3132
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
3233
import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction;
@@ -75,6 +76,8 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
7576
return CentralMomentOOCInstruction.parseInstruction(str);
7677
case Ctable:
7778
return CtableOOCInstruction.parseInstruction(str);
79+
case Rand:
80+
return DataGenOOCInstruction.parseInstruction(str);
7881

7982
default:
8083
throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package org.apache.sysds.runtime.instructions.ooc;
2+
3+
import org.apache.commons.lang3.NotImplementedException;
4+
import org.apache.sysds.common.Opcodes;
5+
import org.apache.sysds.common.Types;
6+
import org.apache.sysds.runtime.DMLRuntimeException;
7+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
8+
import org.apache.sysds.runtime.instructions.InstructionUtils;
9+
import org.apache.sysds.runtime.instructions.cp.CPOperand;
10+
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
11+
import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen;
12+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
13+
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
14+
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
15+
import org.apache.sysds.runtime.util.UtilFunctions;
16+
17+
public class DataGenOOCInstruction extends UnaryOOCInstruction {
18+
19+
private final int blen;
20+
private Types.OpOpDG method;
21+
22+
// sequence specific attributes
23+
private final CPOperand seq_from, seq_to, seq_incr;
24+
25+
public DataGenOOCInstruction(UnaryOperator op, Types.OpOpDG mthd, CPOperand in, CPOperand out, int blen, CPOperand seqFrom,
26+
CPOperand seqTo, CPOperand seqIncr, String opcode, String istr) {
27+
super(OOCType.Rand, op, in, out, opcode, istr);
28+
this.blen = blen;
29+
this.method = mthd;
30+
this.seq_from = seqFrom;
31+
this.seq_to = seqTo;
32+
this.seq_incr = seqIncr;
33+
}
34+
35+
public static DataGenOOCInstruction parseInstruction(String str) {
36+
Types.OpOpDG method = null;
37+
String[] s = InstructionUtils.getInstructionPartsWithValueType(str);
38+
String opcode = s[0];
39+
40+
if(opcode.equalsIgnoreCase(Opcodes.SEQUENCE.toString())) {
41+
method = Types.OpOpDG.SEQ;
42+
// 8 operands: rows, cols, blen, from, to, incr, outvar
43+
InstructionUtils.checkNumFields(s, 7);
44+
}
45+
else
46+
throw new NotImplementedException(); // TODO
47+
48+
CPOperand out = new CPOperand(s[s.length - 1]);
49+
UnaryOperator op = null;
50+
51+
if(method == Types.OpOpDG.SEQ) {
52+
int blen = Integer.parseInt(s[3]);
53+
CPOperand from = new CPOperand(s[4]);
54+
CPOperand to = new CPOperand(s[5]);
55+
CPOperand incr = new CPOperand(s[6]);
56+
57+
return new DataGenOOCInstruction(op, method, null, out, blen, from, to, incr, opcode, str);
58+
}
59+
else
60+
throw new NotImplementedException();
61+
}
62+
63+
@Override
64+
public void processInstruction(ExecutionContext ec) {
65+
final OOCStream<IndexedMatrixValue> qOut = createWritableStream();
66+
67+
// process specific datagen operator
68+
if(method == Types.OpOpDG.SEQ) {
69+
double lfrom = ec.getScalarInput(seq_from).getDoubleValue();
70+
double lto = ec.getScalarInput(seq_to).getDoubleValue();
71+
double lincr = ec.getScalarInput(seq_incr).getDoubleValue();
72+
73+
// handle default 1 to -1 for special case of from>to
74+
lincr = LibMatrixDatagen.updateSeqIncr(lfrom, lto, lincr);
75+
76+
if(LOG.isTraceEnabled())
77+
LOG.trace(
78+
"Process DataGenOOCInstruction seq with seqFrom=" + lfrom + ", seqTo=" + lto + ", seqIncr" + lincr);
79+
80+
final int maxK = (int) UtilFunctions.getSeqLength(lfrom, lto, lincr);
81+
final double finalLincr = lincr;
82+
83+
84+
submitOOCTask(() -> {
85+
int k = 0;
86+
double curFrom = lfrom;
87+
double curTo;
88+
MatrixBlock mb;
89+
90+
while (k < maxK) {
91+
long desiredLen = Math.min(blen, maxK - k);
92+
curTo = curFrom + (desiredLen - 1) * finalLincr;
93+
long actualLen = UtilFunctions.getSeqLength(curFrom, curTo, finalLincr);
94+
95+
if (actualLen != desiredLen) {
96+
// Then we add / subtract a small correction term
97+
curTo += (actualLen < desiredLen) ? finalLincr / 2 : -finalLincr / 2;
98+
99+
if (UtilFunctions.getSeqLength(curFrom, curTo, finalLincr) != desiredLen)
100+
throw new DMLRuntimeException("OOC seq could not construct the right number of elements.");
101+
}
102+
103+
mb = MatrixBlock.seqOperations(curFrom, curTo, finalLincr);
104+
qOut.enqueue(new IndexedMatrixValue(new MatrixIndexes(1 + k / blen, 1), mb));
105+
curFrom = mb.get(mb.getNumRows() - 1, 0) + finalLincr;
106+
k += blen;
107+
}
108+
109+
qOut.closeInput();
110+
}, qOut);
111+
}
112+
else
113+
throw new NotImplementedException();
114+
115+
ec.getMatrixObject(output).setStreamHandle(qOut);
116+
}
117+
}

src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public abstract class OOCInstruction extends Instruction {
5353
private static final AtomicInteger nextStreamId = new AtomicInteger(0);
5454

5555
public enum OOCType {
56-
Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM, Ctable, MatrixIndexing
56+
Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM, Ctable, MatrixIndexing, Rand
5757
}
5858

5959
protected final OOCInstruction.OOCType _ooctype;
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.ooc;
21+
22+
import org.apache.sysds.common.Opcodes;
23+
import org.apache.sysds.common.Types;
24+
import org.apache.sysds.runtime.instructions.Instruction;
25+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
26+
import org.apache.sysds.test.AutomatedTestBase;
27+
import org.apache.sysds.test.TestConfiguration;
28+
import org.apache.sysds.test.TestUtils;
29+
import org.junit.Assert;
30+
import org.junit.Test;
31+
32+
public class SeqTest extends AutomatedTestBase {
33+
private final static String TEST_NAME1 = "Seq";
34+
private final static String TEST_DIR = "functions/ooc/";
35+
private final static String TEST_CLASS_DIR = TEST_DIR + SeqTest.class.getSimpleName() + "/";
36+
private final static double eps = 1e-8;
37+
private static final String OUTPUT_NAME = "res";
38+
39+
@Override
40+
public void setUp() {
41+
TestUtils.clearAssertionInformation();
42+
TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
43+
addTestConfiguration(TEST_NAME1, config);
44+
}
45+
46+
@Test
47+
public void testSeq1() {
48+
runSeqTest(0, 10, 0.1);
49+
}
50+
51+
@Test
52+
public void testSeq2() {
53+
runSeqTest(0, 15.9, 0.01);
54+
}
55+
56+
private void runSeqTest(double from, double to, double incr) {
57+
Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
58+
59+
try {
60+
getAndLoadTestConfiguration(TEST_NAME1);
61+
62+
String HOME = SCRIPT_DIR + TEST_DIR;
63+
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
64+
programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", Double.toString(from), Double.toString(to), Double.toString(incr), output(OUTPUT_NAME)};
65+
66+
runTest(true, false, null, -1);
67+
68+
//check seq OOC
69+
Assert.assertTrue("OOC wasn't used for seq",
70+
heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.SEQUENCE));
71+
//compare results
72+
73+
// rerun without ooc flag
74+
programArgs = new String[] {"-explain", "-stats", "-args", Double.toString(from), Double.toString(to), Double.toString(incr), output(OUTPUT_NAME + "_target")};
75+
runTest(true, false, null, -1);
76+
77+
// compare matrices
78+
MatrixBlock ret1 = TestUtils.readBinary(output(OUTPUT_NAME));
79+
MatrixBlock ret2 = TestUtils.readBinary(output(OUTPUT_NAME + "_target"));
80+
81+
TestUtils.compareMatrices(ret1, ret2, eps);
82+
}
83+
finally {
84+
resetExecMode(platformOld);
85+
}
86+
}
87+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# Read the input matrix as a stream
23+
from = $1;
24+
to = $2;
25+
incr = $3;
26+
27+
res = seq(from, to, incr);
28+
29+
write(res, $4, format="binary");

0 commit comments

Comments
 (0)