33"""
44
55# built-in
6- from contextlib import contextmanager
6+ from contextlib import contextmanager , suppress
77from copy import copy as _copy
88from typing import Iterator as _Iterator
99from typing import NamedTuple
@@ -57,17 +57,18 @@ def asdict(self) -> _JsonObject:
5757
5858
5959T = _TypeVar ("T" , bound = "ProtocolBase" )
60+ ProtocolBuild = list [_Union [int , FieldSpec , tuple [str , int ]]]
6061
6162
62- class ProtocolBase :
63+ class ProtocolBase ( PrimitiveArray ) :
6364 """A class for defining runtime communication protocols."""
6465
6566 def __init__ (
6667 self ,
6768 enum_registry : _EnumRegistry ,
6869 names : _NameRegistry = None ,
6970 fields : BitFieldsManager = None ,
70- build : list [ _Union [ int , FieldSpec , str ]] = None ,
71+ build : ProtocolBuild = None ,
7172 identifier : int = 1 ,
7273 byte_order : _Union [_ByteOrder , _RegistryKey ] = _DEFAULT_BYTE_ORDER ,
7374 serializables : SerializableMap = None ,
@@ -86,7 +87,8 @@ def __init__(
8687 byte_order = _ByteOrder (
8788 self ._enum_registry ["ByteOrder" ].get_int (byte_order )
8889 )
89- self .array = PrimitiveArray (byte_order = byte_order )
90+
91+ super ().__init__ (byte_order = byte_order )
9092
9193 if names is None :
9294 names = _NameRegistry ()
@@ -96,11 +98,11 @@ def __init__(
9698 fields = BitFieldsManager (self .names , self ._enum_registry )
9799 self ._fields = fields
98100
99- self ._regular_fields : dict [str , _AnyPrimitive ] = {}
101+ self ._regular_fields : dict [str , list [ _AnyPrimitive ] ] = {}
100102 self ._enum_fields : dict [str , _RuntimeEnum ] = {}
101103
102104 # Keep track of the order that the protocol was created.
103- self ._build : list [ _Union [ int , FieldSpec , str ]] = []
105+ self ._build : ProtocolBuild = []
104106
105107 # Keep track of named serializables.
106108 self .serializables : SerializableMap = {}
@@ -110,34 +112,38 @@ def __init__(
110112 build = []
111113 for item in build :
112114 if isinstance (item , int ):
113- self ._add_bit_fields (self ._fields .fields [item ], track = False )
114- elif isinstance (item , str ):
115- assert serializables , (item , serializables )
116- self .add_field (item , serializable = serializables [item ])
117- del serializables [item ]
118- else :
115+ self ._add_bit_fields (self ._fields .fields [item ])
116+ elif isinstance (item , FieldSpec ):
119117 self .add_field (
120118 item .name ,
121119 item .kind ,
122120 enum = item .enum ,
123- track = False ,
124121 array_length = item .array_length ,
125122 )
123+ else :
124+ assert serializables , (item , serializables )
125+ name = item [0 ]
126+ self .add_serializable (
127+ name ,
128+ serializables [name ][0 ],
129+ array_length = None if item [1 ] == 1 else item [1 ],
130+ )
131+ del serializables [name ]
126132
127133 # Ensure all serializables were handled via build.
128134 assert not serializables , serializables
129135
130- def __copy__ (self : T ) -> T :
136+ def _copy_impl (self : T ) -> T :
131137 """Create another protocol instance from this one."""
132138
133139 return self .__class__ (
134140 self ._enum_registry ,
135141 names = self .names ,
136142 fields = _copy (self ._fields ),
137143 build = self ._build ,
138- byte_order = self .array . byte_order ,
144+ byte_order = self .byte_order ,
139145 serializables = {
140- key : val .copy_without_chain ()
146+ key : [ val [ 0 ] .copy_without_chain ()]
141147 for key , val in self .serializables .items ()
142148 },
143149 )
@@ -151,31 +157,32 @@ def register_name(self, name: str) -> int:
151157
152158 def add_serializable (
153159 self , name : str , serializable : Serializable , array_length : int = None
154- ) -> int :
160+ ) -> None :
155161 """Add a serializable instance."""
156162
157163 self .register_name (name )
158- self .serializables [name ] = serializable
159- self ._build .append (name )
160- return self .array .add_to_end (serializable , array_length = array_length )
164+
165+ instances = self .add_to_end (serializable , array_length = array_length )
166+ self ._build .append ((name , len (instances )))
167+ self .serializables [name ] = instances
161168
162169 def add_field (
163170 self ,
164171 name : str ,
165172 kind : _Primitivelike = None ,
166173 enum : _RegistryKey = None ,
167174 serializable : Serializable = None ,
168- track : bool = True ,
169175 array_length : int = None ,
170- ) -> int :
176+ ) -> None :
171177 """Add a new field to the protocol."""
172178
173179 # Add the serializable to the end of this protocol.
174180 if serializable is not None :
175181 assert kind is None and enum is None
176- return self .add_serializable (
182+ self .add_serializable (
177183 name , serializable , array_length = array_length
178184 )
185+ return
179186
180187 self .register_name (name )
181188
@@ -189,25 +196,20 @@ def add_field(
189196 kind = runtime_enum .primitive
190197
191198 assert kind is not None
192- new = _create (kind )
193-
194- result = self .array .add (new , array_length = array_length )
195- self ._regular_fields [name ] = new
196199
197- if track :
198- self ._build .append (
199- FieldSpec (name , kind , enum , array_length = array_length )
200- )
200+ self ._regular_fields [name ] = self .add (
201+ _create (kind ), array_length = array_length
202+ )
201203
202- return result
204+ self ._build .append (
205+ FieldSpec (name , kind , enum , array_length = array_length )
206+ )
203207
204- def _add_bit_fields (self , fields : _BitFields , track : bool = True ) -> None :
208+ def _add_bit_fields (self , fields : _BitFields ) -> None :
205209 """Add a bit-fields instance."""
206210
207- idx = self ._fields .add (fields )
208- self .array .add (fields .raw )
209- if track :
210- self ._build .append (idx )
211+ self ._build .append (self ._fields .add (fields ))
212+ self .add (fields .raw )
211213
212214 @contextmanager
213215 def add_bit_fields (
@@ -219,55 +221,61 @@ def add_bit_fields(
219221 yield new
220222 self ._add_bit_fields (new )
221223
222- def value (self , name : str , resolve_enum : bool = True ) -> ProtocolPrimitive :
224+ def value (
225+ self , name : str , resolve_enum : bool = True , index : int = 0
226+ ) -> ProtocolPrimitive :
223227 """Get the value of a field belonging to the protocol."""
224228
225229 val : ProtocolPrimitive = 0
226230
227231 if name in self ._regular_fields :
228- val = self ._regular_fields [name ].value
232+ val = self ._regular_fields [name ][ index ] .value
229233
230234 # Resolve the enum value.
231235 if resolve_enum and name in self ._enum_fields :
232- val = self ._enum_fields [name ].get_str (val ) # type: ignore
236+ with suppress (KeyError ):
237+ val = self ._enum_fields [name ].get_str (val ) # type: ignore
233238
234239 return val
235240
236241 return self ._fields .get (name , resolve_enum = resolve_enum )
237242
238- @property
239- def size (self ) -> int :
240- """Get this protocol's size in bytes."""
241- return self .array .length ()
242-
243243 def trace_size (self , logger : LoggerType ) -> None :
244244 """Log a size trace."""
245- logger .info ("%s: %s" , self , self .array . length_trace ())
245+ logger .info ("%s: %s" , self , self .length_trace ())
246246
247247 def __str__ (self ) -> str :
248248 """Get this instance as a string."""
249249
250- return f"({ self .size } ) " + " " .join (
251- f"{ name } ={ self [name ]} " for name in self .names .registered_order
250+ return (
251+ self .length_trace ()
252+ + f" | ({ self .length ()} ) "
253+ + " " .join (
254+ f"{ name } ={ self [name ]} " for name in self .names .registered_order
255+ )
252256 )
253257
254- def __getitem__ (self , name : str ) -> ProtocolPrimitive :
258+ def __getitem__ (self , name : str ) -> ProtocolPrimitive : # type: ignore
255259 """Get the value of a protocol field."""
256260
257261 if name in self .serializables :
258- return str (self .serializables [name ])
262+ return str (self .serializables [name ][ 0 ] )
259263
260264 return self .value (name )
261265
262- def __setitem__ (self , name : str , val : ProtocolPrimitive ) -> None :
266+ def set (self , name : str , val : ProtocolPrimitive , index : int = 0 ) -> None :
263267 """Set a value of a field belonging to the protocol."""
264268
265269 if name in self ._regular_fields :
266270 # Resolve an enum value.
267271 if isinstance (val , str ):
268272 val = self ._enum_fields [name ].get_int (val )
269- self ._regular_fields [name ].value = val
273+ self ._regular_fields [name ][ index ] .value = val
270274 elif name in self .serializables and isinstance (val , str ):
271- self .serializables [name ].update_str (val )
275+ self .serializables [name ][ index ] .update_str (val )
272276 else :
273277 self ._fields .set (name , val ) # type: ignore
278+
279+ def __setitem__ (self , name : str , val : ProtocolPrimitive ) -> None :
280+ """Set a value of a field belonging to the protocol."""
281+ self .set (name , val )
0 commit comments