Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pre-transpose constant MatMul operand #315

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fixup! - Move strides field to end of table to preserve backwards c…
…ompatibility
robertknight committed Jan 26, 2025
commit 894f3f9c29f9b62864f8877d081fe48b6a263563
204 changes: 102 additions & 102 deletions rten-convert/rten_convert/schema_generated.py
Original file line number Diff line number Diff line change
@@ -205,91 +205,91 @@ def OperatorAttrsCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == OperatorAttrs().ArgMaxAttrs:
if unionType == OperatorAttrs.ArgMaxAttrs:
return ArgMaxAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().AveragePoolAttrs:
if unionType == OperatorAttrs.AveragePoolAttrs:
return AveragePoolAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().BatchNormalizationAttrs:
if unionType == OperatorAttrs.BatchNormalizationAttrs:
return BatchNormalizationAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().CastAttrs:
if unionType == OperatorAttrs.CastAttrs:
return CastAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ConcatAttrs:
if unionType == OperatorAttrs.ConcatAttrs:
return ConcatAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ConstantOfShapeAttrs:
if unionType == OperatorAttrs.ConstantOfShapeAttrs:
return ConstantOfShapeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ConvAttrs:
if unionType == OperatorAttrs.ConvAttrs:
return ConvAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ConvTransposeAttrs:
if unionType == OperatorAttrs.ConvTransposeAttrs:
return ConvTransposeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().FlattenAttrs:
if unionType == OperatorAttrs.FlattenAttrs:
return FlattenAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().GatherAttrs:
if unionType == OperatorAttrs.GatherAttrs:
return GatherAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().GemmAttrs:
if unionType == OperatorAttrs.GemmAttrs:
return GemmAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().GRUAttrs:
if unionType == OperatorAttrs.GRUAttrs:
return GRUAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().LeakyReluAttrs:
if unionType == OperatorAttrs.LeakyReluAttrs:
return LeakyReluAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().LSTMAttrs:
if unionType == OperatorAttrs.LSTMAttrs:
return LSTMAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().MaxPoolAttrs:
if unionType == OperatorAttrs.MaxPoolAttrs:
return MaxPoolAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ReduceMeanAttrs:
if unionType == OperatorAttrs.ReduceMeanAttrs:
return ReduceMeanAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ReshapeAttrs:
if unionType == OperatorAttrs.ReshapeAttrs:
return ReshapeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ResizeAttrs:
if unionType == OperatorAttrs.ResizeAttrs:
return ResizeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().SplitAttrs:
if unionType == OperatorAttrs.SplitAttrs:
return SplitAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().SoftmaxAttrs:
if unionType == OperatorAttrs.SoftmaxAttrs:
return SoftmaxAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().TransposeAttrs:
if unionType == OperatorAttrs.TransposeAttrs:
return TransposeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ModAttrs:
if unionType == OperatorAttrs.ModAttrs:
return ModAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ScatterElementsAttrs:
if unionType == OperatorAttrs.ScatterElementsAttrs:
return ScatterElementsAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().OneHotAttrs:
if unionType == OperatorAttrs.OneHotAttrs:
return OneHotAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().TopKAttrs:
if unionType == OperatorAttrs.TopKAttrs:
return TopKAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().HardSigmoidAttrs:
if unionType == OperatorAttrs.HardSigmoidAttrs:
return HardSigmoidAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().TriluAttrs:
if unionType == OperatorAttrs.TriluAttrs:
return TriluAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ScatterNDAttrs:
if unionType == OperatorAttrs.ScatterNDAttrs:
return ScatterNDAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().NonMaxSuppressionAttrs:
if unionType == OperatorAttrs.NonMaxSuppressionAttrs:
return NonMaxSuppressionAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().LayerNormalizationAttrs:
if unionType == OperatorAttrs.LayerNormalizationAttrs:
return LayerNormalizationAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().RandomUniformAttrs:
if unionType == OperatorAttrs.RandomUniformAttrs:
return RandomUniformAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().EluAttrs:
if unionType == OperatorAttrs.EluAttrs:
return EluAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().RandomUniformLikeAttrs:
if unionType == OperatorAttrs.RandomUniformLikeAttrs:
return RandomUniformLikeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().RandomNormalAttrs:
if unionType == OperatorAttrs.RandomNormalAttrs:
return RandomNormalAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().RandomNormalLikeAttrs:
if unionType == OperatorAttrs.RandomNormalLikeAttrs:
return RandomNormalLikeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().GatherNDAttrs:
if unionType == OperatorAttrs.GatherNDAttrs:
return GatherNDAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().GeluAttrs:
if unionType == OperatorAttrs.GeluAttrs:
return GeluAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().EinsumAttrs:
if unionType == OperatorAttrs.EinsumAttrs:
return EinsumAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().IfAttrs:
if unionType == OperatorAttrs.IfAttrs:
return IfAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().PadAttrs:
if unionType == OperatorAttrs.PadAttrs:
return PadAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().DequantizeLinearAttrs:
if unionType == OperatorAttrs.DequantizeLinearAttrs:
return DequantizeLinearAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().QuantizeLinearAttrs:
if unionType == OperatorAttrs.QuantizeLinearAttrs:
return QuantizeLinearAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().DepthToSpaceAttrs:
if unionType == OperatorAttrs.DepthToSpaceAttrs:
return DepthToSpaceAttrsT.InitFromBuf(table.Bytes, table.Pos)
return None

@@ -308,9 +308,9 @@ def ScalarCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == Scalar().IntScalar:
if unionType == Scalar.IntScalar:
return IntScalarT.InitFromBuf(table.Bytes, table.Pos)
if unionType == Scalar().FloatScalar:
if unionType == Scalar.FloatScalar:
return FloatScalarT.InitFromBuf(table.Bytes, table.Pos)
return None

@@ -343,11 +343,11 @@ def NodeKindCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == NodeKind().OperatorNode:
if unionType == NodeKind.OperatorNode:
return OperatorNodeT.InitFromBuf(table.Bytes, table.Pos)
if unionType == NodeKind().ConstantNode:
if unionType == NodeKind.ConstantNode:
return ConstantNodeT.InitFromBuf(table.Bytes, table.Pos)
if unionType == NodeKind().ValueNode:
if unionType == NodeKind.ValueNode:
return ValueNodeT.InitFromBuf(table.Bytes, table.Pos)
return None

@@ -363,13 +363,13 @@ def ConstantDataCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == ConstantData().FloatData:
if unionType == ConstantData.FloatData:
return FloatDataT.InitFromBuf(table.Bytes, table.Pos)
if unionType == ConstantData().Int32Data:
if unionType == ConstantData.Int32Data:
return Int32DataT.InitFromBuf(table.Bytes, table.Pos)
if unionType == ConstantData().Int8Data:
if unionType == ConstantData.Int8Data:
return Int8DataT.InitFromBuf(table.Bytes, table.Pos)
if unionType == ConstantData().UInt8Data:
if unionType == ConstantData.UInt8Data:
return UInt8DataT.InitFromBuf(table.Bytes, table.Pos)
return None

@@ -5753,43 +5753,16 @@ def ShapeIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
return o == 0

# ConstantNode
def Strides(self, j):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
a = self._tab.Vector(o)
return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
return 0

# ConstantNode
def StridesAsNumpy(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
return 0

# ConstantNode
def StridesLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
return self._tab.VectorLen(o)
return 0

# ConstantNode
def StridesIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
return o == 0

# ConstantNode
def DataType(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos)
return 0

# ConstantNode
def Data(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
if o != 0:
from flatbuffers.table import Table
obj = Table(bytearray(), 0)
@@ -5799,18 +5772,45 @@ def Data(self):

# ConstantNode
def Dtype(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Uint16Flags, o + self._tab.Pos)
return None

# ConstantNode
def DataOffset(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos)
return None

# ConstantNode
def Strides(self, j):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
if o != 0:
a = self._tab.Vector(o)
return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
return 0

# ConstantNode
def StridesAsNumpy(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
if o != 0:
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
return 0

# ConstantNode
def StridesLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
if o != 0:
return self._tab.VectorLen(o)
return 0

# ConstantNode
def StridesIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
return o == 0

def ConstantNodeStart(builder):
builder.StartObject(6)

@@ -5820,23 +5820,23 @@ def ConstantNodeAddShape(builder, shape):
def ConstantNodeStartShapeVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def ConstantNodeAddStrides(builder, strides):
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(strides), 0)

def ConstantNodeStartStridesVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def ConstantNodeAddDataType(builder, dataType):
builder.PrependUint8Slot(2, dataType, 0)
builder.PrependUint8Slot(1, dataType, 0)

def ConstantNodeAddData(builder, data):
builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(data), 0)
builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(data), 0)

def ConstantNodeAddDtype(builder, dtype):
builder.PrependUint16Slot(4, dtype, None)
builder.PrependUint16Slot(3, dtype, None)

def ConstantNodeAddDataOffset(builder, dataOffset):
builder.PrependUint64Slot(5, dataOffset, None)
builder.PrependUint64Slot(4, dataOffset, None)

def ConstantNodeAddStrides(builder, strides):
builder.PrependUOffsetTRelativeSlot(5, flatbuffers.number_types.UOffsetTFlags.py_type(strides), 0)

def ConstantNodeStartStridesVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def ConstantNodeEnd(builder):
return builder.EndObject()
@@ -5852,11 +5852,11 @@ class ConstantNodeT(object):
# ConstantNodeT
def __init__(self):
self.shape = None # type: List[int]
self.strides = None # type: List[int]
self.dataType = 0 # type: int
self.data = None # type: Union[None, FloatDataT, Int32DataT, Int8DataT, UInt8DataT]
self.dtype = None # type: Optional[int]
self.dataOffset = None # type: Optional[int]
self.strides = None # type: List[int]

@classmethod
def InitFromBuf(cls, buf, pos):
@@ -5886,17 +5886,17 @@ def _UnPack(self, constantNode):
self.shape.append(constantNode.Shape(i))
else:
self.shape = constantNode.ShapeAsNumpy()
self.dataType = constantNode.DataType()
self.data = ConstantDataCreator(self.dataType, constantNode.Data())
self.dtype = constantNode.Dtype()
self.dataOffset = constantNode.DataOffset()
if not constantNode.StridesIsNone():
if np is None:
self.strides = []
for i in range(constantNode.StridesLength()):
self.strides.append(constantNode.Strides(i))
else:
self.strides = constantNode.StridesAsNumpy()
self.dataType = constantNode.DataType()
self.data = ConstantDataCreator(self.dataType, constantNode.Data())
self.dtype = constantNode.Dtype()
self.dataOffset = constantNode.DataOffset()

# ConstantNodeT
def Pack(self, builder):
@@ -5908,6 +5908,8 @@ def Pack(self, builder):
for i in reversed(range(len(self.shape))):
builder.PrependUint32(self.shape[i])
shape = builder.EndVector()
if self.data is not None:
data = self.data.Pack(builder)
if self.strides is not None:
if np is not None and type(self.strides) is np.ndarray:
strides = builder.CreateNumpyVector(self.strides)
@@ -5916,18 +5918,16 @@ def Pack(self, builder):
for i in reversed(range(len(self.strides))):
builder.PrependUint32(self.strides[i])
strides = builder.EndVector()
if self.data is not None:
data = self.data.Pack(builder)
ConstantNodeStart(builder)
if self.shape is not None:
ConstantNodeAddShape(builder, shape)
if self.strides is not None:
ConstantNodeAddStrides(builder, strides)
ConstantNodeAddDataType(builder, self.dataType)
if self.data is not None:
ConstantNodeAddData(builder, data)
ConstantNodeAddDtype(builder, self.dtype)
ConstantNodeAddDataOffset(builder, self.dataOffset)
if self.strides is not None:
ConstantNodeAddStrides(builder, strides)
constantNode = ConstantNodeEnd(builder)
return constantNode

5 changes: 4 additions & 1 deletion src/schema.fbs
Original file line number Diff line number Diff line change
@@ -537,7 +537,6 @@ enum ConstantDataType: ushort {
// Graph node for a constant tensor value, whose data is part of the model.
table ConstantNode {
shape:[uint] (required);
strides:[uint];

// Tensor data embedded within the model file.
data:ConstantData;
@@ -548,6 +547,10 @@ table ConstantNode {
// Offset of tensor data from the start of the tensor data segment in the
// model file. Null if the tensor data is stored inline.
data_offset:uint64 = null;

// Custom strides for each dimension. This enables pre-transposing weights.
// If not specified the strides default to contiguous.
strides:[uint];
}

// Dimension of a ValueNode's shape. This can be either a fixed value or a
Loading