From 6d484c63b426d0aabe9b7adf0ff47c3b6d6a938a Mon Sep 17 00:00:00 2001 From: hsfzxjy Date: Thu, 22 Aug 2024 20:13:21 +0800 Subject: [PATCH 1/3] Pre-transpose MatMul's RHS operands --- rten-convert/rten_convert/converter.py | 15 ++++ rten-convert/rten_convert/schema_generated.py | 69 ++++++++++++++++--- src/model.rs | 46 +++++++++++-- src/model_builder.rs | 2 + src/schema.fbs | 1 + src/schema_generated.rs | 38 ++++++++-- 6 files changed, 152 insertions(+), 19 deletions(-) diff --git a/rten-convert/rten_convert/converter.py b/rten-convert/rten_convert/converter.py index 8f606e55..d452bd69 100644 --- a/rten-convert/rten_convert/converter.py +++ b/rten-convert/rten_convert/converter.py @@ -55,6 +55,7 @@ class ConstantNode(Node): """ shape: list[int] + strides: Optional[list[int]] data: np.ndarray def __init__(self, name: str, shape: list[int], data: np.ndarray): @@ -861,6 +862,12 @@ def op_node_from_onnx_operator( op_reader.check_attr("input_forget", "int", 0) op_reader.check_attr("layout", "int", 0) + case "MatMul": + b = constant_nodes.get(onnx_op.input[-1]) + if b and len(b.shape) == 2 and b.shape[-1] > 1: + b.data = np.ascontiguousarray(b.data.transpose()) + b.strides = [1, b.shape[0]] + case "MaxPool": attrs = sg.MaxPoolAttrsT() kernel_shape = op_reader.require_attr("kernel_shape", "ints") @@ -1202,6 +1209,12 @@ def build_constant_node( shape_vec = write_vec( builder, sg.ConstantNodeStartShapeVector, constant.shape, "u32" ) + if getattr(constant, "strides", None): + strides_vec = write_vec( + builder, sg.ConstantNodeStartStridesVector, constant.strides, "u32" + ) + else: + strides_vec = None n_elems = reduce(mul, constant.shape, 1) assert n_elems == constant.data.size, "constant shape does not match element count" @@ -1261,6 +1274,8 @@ def build_constant_node( sg.ConstantNodeStart(builder) sg.ConstantNodeAddShape(builder, shape_vec) sg.ConstantNodeAddDtype(builder, dtype) + if strides_vec: + sg.ConstantNodeAddStrides(builder, strides_vec) if inline_data: sg.ConstantNodeAddDataType(builder, inline_data_type) diff --git a/rten-convert/rten_convert/schema_generated.py b/rten-convert/rten_convert/schema_generated.py index bed08145..145e54a0 100644 --- a/rten-convert/rten_convert/schema_generated.py +++ b/rten-convert/rten_convert/schema_generated.py @@ -5754,15 +5754,42 @@ def ShapeIsNone(self): return o == 0 # ConstantNode - def DataType(self): + def Strides(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # ConstantNode + def StridesAsNumpy(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o) + return 0 + + # ConstantNode + def StridesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # ConstantNode + def StridesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # ConstantNode + def DataType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) if o != 0: return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos) return 0 # ConstantNode def Data(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) if o != 0: from flatbuffers.table import Table obj = Table(bytearray(), 0) @@ -5772,20 +5799,20 @@ def Data(self): # ConstantNode def Dtype(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) if o != 0: return self._tab.Get(flatbuffers.number_types.Uint16Flags, o + self._tab.Pos) return None # ConstantNode def DataOffset(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) if o != 0: return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos) return None def ConstantNodeStart(builder): - builder.StartObject(5) + builder.StartObject(6) def ConstantNodeAddShape(builder, shape): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0) @@ -5793,17 +5820,23 @@ def ConstantNodeAddShape(builder, shape): def ConstantNodeStartShapeVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def ConstantNodeAddStrides(builder, strides): + builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(strides), 0) + +def ConstantNodeStartStridesVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + def ConstantNodeAddDataType(builder, dataType): - builder.PrependUint8Slot(1, dataType, 0) + builder.PrependUint8Slot(2, dataType, 0) def ConstantNodeAddData(builder, data): - builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(data), 0) + builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(data), 0) def ConstantNodeAddDtype(builder, dtype): - builder.PrependUint16Slot(3, dtype, None) + builder.PrependUint16Slot(4, dtype, None) def ConstantNodeAddDataOffset(builder, dataOffset): - builder.PrependUint64Slot(4, dataOffset, None) + builder.PrependUint64Slot(5, dataOffset, None) def ConstantNodeEnd(builder): return builder.EndObject() @@ -5819,6 +5852,7 @@ class ConstantNodeT(object): # ConstantNodeT def __init__(self): self.shape = None # type: List[int] + self.strides = None # type: List[int] self.dataType = 0 # type: int self.data = None # type: Union[None, FloatDataT, Int32DataT, Int8DataT, UInt8DataT] self.dtype = None # type: Optional[int] @@ -5852,6 +5886,13 @@ def _UnPack(self, constantNode): self.shape.append(constantNode.Shape(i)) else: self.shape = constantNode.ShapeAsNumpy() + if not constantNode.StridesIsNone(): + if np is None: + self.strides = [] + for i in range(constantNode.StridesLength()): + self.strides.append(constantNode.Strides(i)) + else: + self.strides = constantNode.StridesAsNumpy() self.dataType = constantNode.DataType() self.data = ConstantDataCreator(self.dataType, constantNode.Data()) self.dtype = constantNode.Dtype() @@ -5867,11 +5908,21 @@ def Pack(self, builder): for i in reversed(range(len(self.shape))): builder.PrependUint32(self.shape[i]) shape = builder.EndVector() + if self.strides is not None: + if np is not None and type(self.strides) is np.ndarray: + strides = builder.CreateNumpyVector(self.strides) + else: + ConstantNodeStartStridesVector(builder, len(self.strides)) + for i in reversed(range(len(self.strides))): + builder.PrependUint32(self.strides[i]) + strides = builder.EndVector() if self.data is not None: data = self.data.Pack(builder) ConstantNodeStart(builder) if self.shape is not None: ConstantNodeAddShape(builder, shape) + if self.strides is not None: + ConstantNodeAddStrides(builder, strides) ConstantNodeAddDataType(builder, self.dataType) if self.data is not None: ConstantNodeAddData(builder, data) diff --git a/src/model.rs b/src/model.rs index 398048da..af47febe 100644 --- a/src/model.rs +++ b/src/model.rs @@ -578,6 +578,9 @@ impl Model { tensor_data_offset: Option, ) -> Result { let shape: Vec = constant.shape().iter().map(|x| x as usize).collect(); + let strides: Option> = constant + .strides() + .map(|strides| strides.iter().map(|x| x as usize).collect()); if let Some(data_offset) = constant.data_offset() { // Constant data is stored outside the model buffer, in the same file. @@ -591,13 +594,21 @@ impl Model { let graph_node = match constant.dtype() { Some(sg::ConstantDataType::Int32) => { - let const_data = - constant_data_from_storage_offset::(storage, &shape, data_offset)?; + let const_data = constant_data_from_storage_offset::( + storage, + &shape, + strides.as_deref(), + data_offset, + )?; graph.add_constant(name, const_data) } Some(sg::ConstantDataType::Float32) => { - let const_data = - constant_data_from_storage_offset::(storage, &shape, data_offset)?; + let const_data = constant_data_from_storage_offset::( + storage, + &shape, + strides.as_deref(), + data_offset, + )?; graph.add_constant(name, const_data) } Some(sg::ConstantDataType::Int8) => { @@ -830,6 +841,7 @@ fn cast_le_bytes(bytes: &[u8]) -> Option<&[T]> { fn constant_data_from_storage_offset( storage: &Arc, shape: &[usize], + strides: Option<&[usize]>, offset: usize, ) -> Result, ModelLoadError> { let n_elements: usize = shape.iter().product(); @@ -844,14 +856,36 @@ fn constant_data_from_storage_offset( if let Some(elements) = cast_le_bytes(bytes) { let storage = ArcSlice::new(storage.clone(), elements).expect("storage does not contain data"); - let const_data: ConstantNodeData = ArcTensorView::from_data(shape, storage).into(); + let const_data: ConstantNodeData = if let Some(strides) = strides { + ArcTensorView::from_data_with_strides(shape, storage, strides) + .map_err(|_| { + ModelLoadError::GraphError(format!( + "bad strides = {:?}, shape = {:?}", + strides, shape + )) + })? + .into() + } else { + ArcTensorView::from_data(shape, storage).into() + }; Ok(const_data) } else { let data: Vec = bytes .chunks(std::mem::size_of::()) .map(|chunk| T::from_le_bytes(chunk.try_into().unwrap())) .collect(); - Ok(Tensor::from_data(shape, data).into()) + Ok(if let Some(strides) = strides { + Tensor::from_data_with_strides(shape, data, strides) + .map_err(|_| { + ModelLoadError::GraphError(format!( + "bad strides = {:?}, shape = {:?}", + strides, shape + )) + })? + .into() + } else { + Tensor::from_data(shape, data).into() + }) } } diff --git a/src/model_builder.rs b/src/model_builder.rs index b66b7251..fe6b8251 100644 --- a/src/model_builder.rs +++ b/src/model_builder.rs @@ -301,6 +301,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { sg::ConstantNodeArgs { shape: Some(shape_vec), + strides: None, data_type: sg::ConstantData::NONE, data: None, data_offset: Some(offset), @@ -312,6 +313,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { sg::ConstantNodeArgs { shape: Some(shape_vec), + strides: None, data_type: inline_dtype, data: Some(data), data_offset: None, diff --git a/src/schema.fbs b/src/schema.fbs index 44929602..49903039 100644 --- a/src/schema.fbs +++ b/src/schema.fbs @@ -537,6 +537,7 @@ enum ConstantDataType: ushort { // Graph node for a constant tensor value, whose data is part of the model. table ConstantNode { shape:[uint] (required); + strides:[uint]; // Tensor data embedded within the model file. data:ConstantData; diff --git a/src/schema_generated.rs b/src/schema_generated.rs index 4880368f..71c26459 100644 --- a/src/schema_generated.rs +++ b/src/schema_generated.rs @@ -9778,10 +9778,11 @@ impl<'a> flatbuffers::Follow<'a> for ConstantNode<'a> { impl<'a> ConstantNode<'a> { pub const VT_SHAPE: flatbuffers::VOffsetT = 4; - pub const VT_DATA_TYPE: flatbuffers::VOffsetT = 6; - pub const VT_DATA: flatbuffers::VOffsetT = 8; - pub const VT_DTYPE: flatbuffers::VOffsetT = 10; - pub const VT_DATA_OFFSET: flatbuffers::VOffsetT = 12; + pub const VT_STRIDES: flatbuffers::VOffsetT = 6; + pub const VT_DATA_TYPE: flatbuffers::VOffsetT = 8; + pub const VT_DATA: flatbuffers::VOffsetT = 10; + pub const VT_DTYPE: flatbuffers::VOffsetT = 12; + pub const VT_DATA_OFFSET: flatbuffers::VOffsetT = 14; #[inline] pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { @@ -9799,6 +9800,9 @@ impl<'a> ConstantNode<'a> { if let Some(x) = args.data { builder.add_data(x); } + if let Some(x) = args.strides { + builder.add_strides(x); + } if let Some(x) = args.shape { builder.add_shape(x); } @@ -9824,6 +9828,19 @@ impl<'a> ConstantNode<'a> { } } #[inline] + pub fn strides(&self) -> Option> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>>( + ConstantNode::VT_STRIDES, + None, + ) + } + } + #[inline] pub fn data_type(&self) -> ConstantData { // Safety: // Created from valid Table for this object @@ -9938,6 +9955,11 @@ impl flatbuffers::Verifiable for ConstantNode<'_> { Self::VT_SHAPE, true, )? + .visit_field::>>( + "strides", + Self::VT_STRIDES, + false, + )? .visit_union::( "data_type", Self::VT_DATA_TYPE, @@ -9976,6 +9998,7 @@ impl flatbuffers::Verifiable for ConstantNode<'_> { } pub struct ConstantNodeArgs<'a> { pub shape: Option>>, + pub strides: Option>>, pub data_type: ConstantData, pub data: Option>, pub dtype: Option, @@ -9986,6 +10009,7 @@ impl<'a> Default for ConstantNodeArgs<'a> { fn default() -> Self { ConstantNodeArgs { shape: None, // required field + strides: None, data_type: ConstantData::NONE, data: None, dtype: None, @@ -10005,6 +10029,11 @@ impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> ConstantNodeBuilder<'a, 'b, A> .push_slot_always::>(ConstantNode::VT_SHAPE, shape); } #[inline] + pub fn add_strides(&mut self, strides: flatbuffers::WIPOffset>) { + self.fbb_ + .push_slot_always::>(ConstantNode::VT_STRIDES, strides); + } + #[inline] pub fn add_data_type(&mut self, data_type: ConstantData) { self.fbb_.push_slot::( ConstantNode::VT_DATA_TYPE, @@ -10049,6 +10078,7 @@ impl core::fmt::Debug for ConstantNode<'_> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let mut ds = f.debug_struct("ConstantNode"); ds.field("shape", &self.shape()); + ds.field("strides", &self.strides()); ds.field("data_type", &self.data_type()); match self.data_type() { ConstantData::FloatData => { From 894f3f9c29f9b62864f8877d081fe48b6a263563 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Sun, 26 Jan 2025 07:05:57 +0000 Subject: [PATCH 2/3] fixup! - Move `strides` field to end of table to preserve backwards compatibility --- rten-convert/rten_convert/schema_generated.py | 204 +++++++++--------- src/schema.fbs | 5 +- src/schema_generated.rs | 68 +++--- 3 files changed, 140 insertions(+), 137 deletions(-) diff --git a/rten-convert/rten_convert/schema_generated.py b/rten-convert/rten_convert/schema_generated.py index 145e54a0..be60d78b 100644 --- a/rten-convert/rten_convert/schema_generated.py +++ b/rten-convert/rten_convert/schema_generated.py @@ -205,91 +205,91 @@ def OperatorAttrsCreator(unionType, table): from flatbuffers.table import Table if not isinstance(table, Table): return None - if unionType == OperatorAttrs().ArgMaxAttrs: + if unionType == OperatorAttrs.ArgMaxAttrs: return ArgMaxAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().AveragePoolAttrs: + if unionType == OperatorAttrs.AveragePoolAttrs: return AveragePoolAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().BatchNormalizationAttrs: + if unionType == OperatorAttrs.BatchNormalizationAttrs: return BatchNormalizationAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().CastAttrs: + if unionType == OperatorAttrs.CastAttrs: return CastAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().ConcatAttrs: + if unionType == OperatorAttrs.ConcatAttrs: return ConcatAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().ConstantOfShapeAttrs: + if unionType == OperatorAttrs.ConstantOfShapeAttrs: return ConstantOfShapeAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().ConvAttrs: + if unionType == OperatorAttrs.ConvAttrs: return ConvAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().ConvTransposeAttrs: + if unionType == OperatorAttrs.ConvTransposeAttrs: return ConvTransposeAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().FlattenAttrs: + if unionType == OperatorAttrs.FlattenAttrs: return FlattenAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().GatherAttrs: + if unionType == OperatorAttrs.GatherAttrs: return GatherAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().GemmAttrs: + if unionType == OperatorAttrs.GemmAttrs: return GemmAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().GRUAttrs: + if unionType == OperatorAttrs.GRUAttrs: return GRUAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().LeakyReluAttrs: + if unionType == OperatorAttrs.LeakyReluAttrs: return LeakyReluAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().LSTMAttrs: + if unionType == OperatorAttrs.LSTMAttrs: return LSTMAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().MaxPoolAttrs: + if unionType == OperatorAttrs.MaxPoolAttrs: return MaxPoolAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().ReduceMeanAttrs: + if unionType == OperatorAttrs.ReduceMeanAttrs: return ReduceMeanAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().ReshapeAttrs: + if unionType == OperatorAttrs.ReshapeAttrs: return ReshapeAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().ResizeAttrs: + if unionType == OperatorAttrs.ResizeAttrs: return ResizeAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().SplitAttrs: + if unionType == OperatorAttrs.SplitAttrs: return SplitAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().SoftmaxAttrs: + if unionType == OperatorAttrs.SoftmaxAttrs: return SoftmaxAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().TransposeAttrs: + if unionType == OperatorAttrs.TransposeAttrs: return TransposeAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().ModAttrs: + if unionType == OperatorAttrs.ModAttrs: return ModAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().ScatterElementsAttrs: + if unionType == OperatorAttrs.ScatterElementsAttrs: return ScatterElementsAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().OneHotAttrs: + if unionType == OperatorAttrs.OneHotAttrs: return OneHotAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().TopKAttrs: + if unionType == OperatorAttrs.TopKAttrs: return TopKAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().HardSigmoidAttrs: + if unionType == OperatorAttrs.HardSigmoidAttrs: return HardSigmoidAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().TriluAttrs: + if unionType == OperatorAttrs.TriluAttrs: return TriluAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().ScatterNDAttrs: + if unionType == OperatorAttrs.ScatterNDAttrs: return ScatterNDAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().NonMaxSuppressionAttrs: + if unionType == OperatorAttrs.NonMaxSuppressionAttrs: return NonMaxSuppressionAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().LayerNormalizationAttrs: + if unionType == OperatorAttrs.LayerNormalizationAttrs: return LayerNormalizationAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().RandomUniformAttrs: + if unionType == OperatorAttrs.RandomUniformAttrs: return RandomUniformAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().EluAttrs: + if unionType == OperatorAttrs.EluAttrs: return EluAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().RandomUniformLikeAttrs: + if unionType == OperatorAttrs.RandomUniformLikeAttrs: return RandomUniformLikeAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().RandomNormalAttrs: + if unionType == OperatorAttrs.RandomNormalAttrs: return RandomNormalAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().RandomNormalLikeAttrs: + if unionType == OperatorAttrs.RandomNormalLikeAttrs: return RandomNormalLikeAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().GatherNDAttrs: + if unionType == OperatorAttrs.GatherNDAttrs: return GatherNDAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().GeluAttrs: + if unionType == OperatorAttrs.GeluAttrs: return GeluAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().EinsumAttrs: + if unionType == OperatorAttrs.EinsumAttrs: return EinsumAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().IfAttrs: + if unionType == OperatorAttrs.IfAttrs: return IfAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().PadAttrs: + if unionType == OperatorAttrs.PadAttrs: return PadAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().DequantizeLinearAttrs: + if unionType == OperatorAttrs.DequantizeLinearAttrs: return DequantizeLinearAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().QuantizeLinearAttrs: + if unionType == OperatorAttrs.QuantizeLinearAttrs: return QuantizeLinearAttrsT.InitFromBuf(table.Bytes, table.Pos) - if unionType == OperatorAttrs().DepthToSpaceAttrs: + if unionType == OperatorAttrs.DepthToSpaceAttrs: return DepthToSpaceAttrsT.InitFromBuf(table.Bytes, table.Pos) return None @@ -308,9 +308,9 @@ def ScalarCreator(unionType, table): from flatbuffers.table import Table if not isinstance(table, Table): return None - if unionType == Scalar().IntScalar: + if unionType == Scalar.IntScalar: return IntScalarT.InitFromBuf(table.Bytes, table.Pos) - if unionType == Scalar().FloatScalar: + if unionType == Scalar.FloatScalar: return FloatScalarT.InitFromBuf(table.Bytes, table.Pos) return None @@ -343,11 +343,11 @@ def NodeKindCreator(unionType, table): from flatbuffers.table import Table if not isinstance(table, Table): return None - if unionType == NodeKind().OperatorNode: + if unionType == NodeKind.OperatorNode: return OperatorNodeT.InitFromBuf(table.Bytes, table.Pos) - if unionType == NodeKind().ConstantNode: + if unionType == NodeKind.ConstantNode: return ConstantNodeT.InitFromBuf(table.Bytes, table.Pos) - if unionType == NodeKind().ValueNode: + if unionType == NodeKind.ValueNode: return ValueNodeT.InitFromBuf(table.Bytes, table.Pos) return None @@ -363,13 +363,13 @@ def ConstantDataCreator(unionType, table): from flatbuffers.table import Table if not isinstance(table, Table): return None - if unionType == ConstantData().FloatData: + if unionType == ConstantData.FloatData: return FloatDataT.InitFromBuf(table.Bytes, table.Pos) - if unionType == ConstantData().Int32Data: + if unionType == ConstantData.Int32Data: return Int32DataT.InitFromBuf(table.Bytes, table.Pos) - if unionType == ConstantData().Int8Data: + if unionType == ConstantData.Int8Data: return Int8DataT.InitFromBuf(table.Bytes, table.Pos) - if unionType == ConstantData().UInt8Data: + if unionType == ConstantData.UInt8Data: return UInt8DataT.InitFromBuf(table.Bytes, table.Pos) return None @@ -5753,43 +5753,16 @@ def ShapeIsNone(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) return o == 0 - # ConstantNode - def Strides(self, j): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) - if o != 0: - a = self._tab.Vector(o) - return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) - return 0 - - # ConstantNode - def StridesAsNumpy(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) - if o != 0: - return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o) - return 0 - - # ConstantNode - def StridesLength(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) - if o != 0: - return self._tab.VectorLen(o) - return 0 - - # ConstantNode - def StridesIsNone(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) - return o == 0 - # ConstantNode def DataType(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) if o != 0: return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos) return 0 # ConstantNode def Data(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) if o != 0: from flatbuffers.table import Table obj = Table(bytearray(), 0) @@ -5799,18 +5772,45 @@ def Data(self): # ConstantNode def Dtype(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) if o != 0: return self._tab.Get(flatbuffers.number_types.Uint16Flags, o + self._tab.Pos) return None # ConstantNode def DataOffset(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) if o != 0: return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos) return None + # ConstantNode + def Strides(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # ConstantNode + def StridesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o) + return 0 + + # ConstantNode + def StridesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # ConstantNode + def StridesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + return o == 0 + def ConstantNodeStart(builder): builder.StartObject(6) @@ -5820,23 +5820,23 @@ def ConstantNodeAddShape(builder, shape): def ConstantNodeStartShapeVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def ConstantNodeAddStrides(builder, strides): - builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(strides), 0) - -def ConstantNodeStartStridesVector(builder, numElems): - return builder.StartVector(4, numElems, 4) - def ConstantNodeAddDataType(builder, dataType): - builder.PrependUint8Slot(2, dataType, 0) + builder.PrependUint8Slot(1, dataType, 0) def ConstantNodeAddData(builder, data): - builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(data), 0) + builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(data), 0) def ConstantNodeAddDtype(builder, dtype): - builder.PrependUint16Slot(4, dtype, None) + builder.PrependUint16Slot(3, dtype, None) def ConstantNodeAddDataOffset(builder, dataOffset): - builder.PrependUint64Slot(5, dataOffset, None) + builder.PrependUint64Slot(4, dataOffset, None) + +def ConstantNodeAddStrides(builder, strides): + builder.PrependUOffsetTRelativeSlot(5, flatbuffers.number_types.UOffsetTFlags.py_type(strides), 0) + +def ConstantNodeStartStridesVector(builder, numElems): + return builder.StartVector(4, numElems, 4) def ConstantNodeEnd(builder): return builder.EndObject() @@ -5852,11 +5852,11 @@ class ConstantNodeT(object): # ConstantNodeT def __init__(self): self.shape = None # type: List[int] - self.strides = None # type: List[int] self.dataType = 0 # type: int self.data = None # type: Union[None, FloatDataT, Int32DataT, Int8DataT, UInt8DataT] self.dtype = None # type: Optional[int] self.dataOffset = None # type: Optional[int] + self.strides = None # type: List[int] @classmethod def InitFromBuf(cls, buf, pos): @@ -5886,6 +5886,10 @@ def _UnPack(self, constantNode): self.shape.append(constantNode.Shape(i)) else: self.shape = constantNode.ShapeAsNumpy() + self.dataType = constantNode.DataType() + self.data = ConstantDataCreator(self.dataType, constantNode.Data()) + self.dtype = constantNode.Dtype() + self.dataOffset = constantNode.DataOffset() if not constantNode.StridesIsNone(): if np is None: self.strides = [] @@ -5893,10 +5897,6 @@ def _UnPack(self, constantNode): self.strides.append(constantNode.Strides(i)) else: self.strides = constantNode.StridesAsNumpy() - self.dataType = constantNode.DataType() - self.data = ConstantDataCreator(self.dataType, constantNode.Data()) - self.dtype = constantNode.Dtype() - self.dataOffset = constantNode.DataOffset() # ConstantNodeT def Pack(self, builder): @@ -5908,6 +5908,8 @@ def Pack(self, builder): for i in reversed(range(len(self.shape))): builder.PrependUint32(self.shape[i]) shape = builder.EndVector() + if self.data is not None: + data = self.data.Pack(builder) if self.strides is not None: if np is not None and type(self.strides) is np.ndarray: strides = builder.CreateNumpyVector(self.strides) @@ -5916,18 +5918,16 @@ def Pack(self, builder): for i in reversed(range(len(self.strides))): builder.PrependUint32(self.strides[i]) strides = builder.EndVector() - if self.data is not None: - data = self.data.Pack(builder) ConstantNodeStart(builder) if self.shape is not None: ConstantNodeAddShape(builder, shape) - if self.strides is not None: - ConstantNodeAddStrides(builder, strides) ConstantNodeAddDataType(builder, self.dataType) if self.data is not None: ConstantNodeAddData(builder, data) ConstantNodeAddDtype(builder, self.dtype) ConstantNodeAddDataOffset(builder, self.dataOffset) + if self.strides is not None: + ConstantNodeAddStrides(builder, strides) constantNode = ConstantNodeEnd(builder) return constantNode diff --git a/src/schema.fbs b/src/schema.fbs index 49903039..678e1b2b 100644 --- a/src/schema.fbs +++ b/src/schema.fbs @@ -537,7 +537,6 @@ enum ConstantDataType: ushort { // Graph node for a constant tensor value, whose data is part of the model. table ConstantNode { shape:[uint] (required); - strides:[uint]; // Tensor data embedded within the model file. data:ConstantData; @@ -548,6 +547,10 @@ table ConstantNode { // Offset of tensor data from the start of the tensor data segment in the // model file. Null if the tensor data is stored inline. data_offset:uint64 = null; + + // Custom strides for each dimension. This enables pre-transposing weights. + // If not specified the strides default to contiguous. + strides:[uint]; } // Dimension of a ValueNode's shape. This can be either a fixed value or a diff --git a/src/schema_generated.rs b/src/schema_generated.rs index 71c26459..3b141ba3 100644 --- a/src/schema_generated.rs +++ b/src/schema_generated.rs @@ -9778,11 +9778,11 @@ impl<'a> flatbuffers::Follow<'a> for ConstantNode<'a> { impl<'a> ConstantNode<'a> { pub const VT_SHAPE: flatbuffers::VOffsetT = 4; - pub const VT_STRIDES: flatbuffers::VOffsetT = 6; - pub const VT_DATA_TYPE: flatbuffers::VOffsetT = 8; - pub const VT_DATA: flatbuffers::VOffsetT = 10; - pub const VT_DTYPE: flatbuffers::VOffsetT = 12; - pub const VT_DATA_OFFSET: flatbuffers::VOffsetT = 14; + pub const VT_DATA_TYPE: flatbuffers::VOffsetT = 6; + pub const VT_DATA: flatbuffers::VOffsetT = 8; + pub const VT_DTYPE: flatbuffers::VOffsetT = 10; + pub const VT_DATA_OFFSET: flatbuffers::VOffsetT = 12; + pub const VT_STRIDES: flatbuffers::VOffsetT = 14; #[inline] pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { @@ -9797,12 +9797,12 @@ impl<'a> ConstantNode<'a> { if let Some(x) = args.data_offset { builder.add_data_offset(x); } - if let Some(x) = args.data { - builder.add_data(x); - } if let Some(x) = args.strides { builder.add_strides(x); } + if let Some(x) = args.data { + builder.add_data(x); + } if let Some(x) = args.shape { builder.add_shape(x); } @@ -9828,19 +9828,6 @@ impl<'a> ConstantNode<'a> { } } #[inline] - pub fn strides(&self) -> Option> { - // Safety: - // Created from valid Table for this object - // which contains a valid value in this slot - unsafe { - self._tab - .get::>>( - ConstantNode::VT_STRIDES, - None, - ) - } - } - #[inline] pub fn data_type(&self) -> ConstantData { // Safety: // Created from valid Table for this object @@ -9882,6 +9869,19 @@ impl<'a> ConstantNode<'a> { unsafe { self._tab.get::(ConstantNode::VT_DATA_OFFSET, None) } } #[inline] + pub fn strides(&self) -> Option> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>>( + ConstantNode::VT_STRIDES, + None, + ) + } + } + #[inline] #[allow(non_snake_case)] pub fn data_as_float_data(&self) -> Option> { if self.data_type() == ConstantData::FloatData { @@ -9955,11 +9955,6 @@ impl flatbuffers::Verifiable for ConstantNode<'_> { Self::VT_SHAPE, true, )? - .visit_field::>>( - "strides", - Self::VT_STRIDES, - false, - )? .visit_union::( "data_type", Self::VT_DATA_TYPE, @@ -9992,28 +9987,33 @@ impl flatbuffers::Verifiable for ConstantNode<'_> { )? .visit_field::("dtype", Self::VT_DTYPE, false)? .visit_field::("data_offset", Self::VT_DATA_OFFSET, false)? + .visit_field::>>( + "strides", + Self::VT_STRIDES, + false, + )? .finish(); Ok(()) } } pub struct ConstantNodeArgs<'a> { pub shape: Option>>, - pub strides: Option>>, pub data_type: ConstantData, pub data: Option>, pub dtype: Option, pub data_offset: Option, + pub strides: Option>>, } impl<'a> Default for ConstantNodeArgs<'a> { #[inline] fn default() -> Self { ConstantNodeArgs { shape: None, // required field - strides: None, data_type: ConstantData::NONE, data: None, dtype: None, data_offset: None, + strides: None, } } } @@ -10029,11 +10029,6 @@ impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> ConstantNodeBuilder<'a, 'b, A> .push_slot_always::>(ConstantNode::VT_SHAPE, shape); } #[inline] - pub fn add_strides(&mut self, strides: flatbuffers::WIPOffset>) { - self.fbb_ - .push_slot_always::>(ConstantNode::VT_STRIDES, strides); - } - #[inline] pub fn add_data_type(&mut self, data_type: ConstantData) { self.fbb_.push_slot::( ConstantNode::VT_DATA_TYPE, @@ -10057,6 +10052,11 @@ impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> ConstantNodeBuilder<'a, 'b, A> .push_slot_always::(ConstantNode::VT_DATA_OFFSET, data_offset); } #[inline] + pub fn add_strides(&mut self, strides: flatbuffers::WIPOffset>) { + self.fbb_ + .push_slot_always::>(ConstantNode::VT_STRIDES, strides); + } + #[inline] pub fn new( _fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>, ) -> ConstantNodeBuilder<'a, 'b, A> { @@ -10078,7 +10078,6 @@ impl core::fmt::Debug for ConstantNode<'_> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let mut ds = f.debug_struct("ConstantNode"); ds.field("shape", &self.shape()); - ds.field("strides", &self.strides()); ds.field("data_type", &self.data_type()); match self.data_type() { ConstantData::FloatData => { @@ -10128,6 +10127,7 @@ impl core::fmt::Debug for ConstantNode<'_> { }; ds.field("dtype", &self.dtype()); ds.field("data_offset", &self.data_offset()); + ds.field("strides", &self.strides()); ds.finish() } } From 0ab23a36bfcc2c3a4b040a482dd78f76664f405f Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Sun, 26 Jan 2025 07:06:32 +0000 Subject: [PATCH 3/3] fixup! - Update for API changes --- src/model.rs | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/model.rs b/src/model.rs index af47febe..705aaca1 100644 --- a/src/model.rs +++ b/src/model.rs @@ -612,13 +612,21 @@ impl Model { graph.add_constant(name, const_data) } Some(sg::ConstantDataType::Int8) => { - let const_data = - constant_data_from_storage_offset::(storage, &shape, data_offset)?; + let const_data = constant_data_from_storage_offset::( + storage, + &shape, + strides.as_deref(), + data_offset, + )?; graph.add_constant(name, const_data) } Some(sg::ConstantDataType::UInt8) => { - let const_data = - constant_data_from_storage_offset::(storage, &shape, data_offset)?; + let const_data = constant_data_from_storage_offset::( + storage, + &shape, + strides.as_deref(), + data_offset, + )?; graph.add_constant(name, const_data) } _ => {