Skip to content

Commit e669e5b

Browse files
committed
qbatchnormdense and permute
1 parent 5fb1d3a commit e669e5b

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

hls4ml/converters/keras_v3/squark/_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ class SQConvHandler(SQLayerHandler, KV3ConvHandler):
139139

140140
@register
141141
class SQDenseHandler(SQLayerHandler, KV3DenseHandler):
142-
handles = ('squark.layers.core.dense.QDense',)
142+
handles = ('squark.layers.core.dense.QDense', 'squark.layers.core.dense.QBatchNormDense')
143143

144144
def handle(
145145
self,
@@ -148,6 +148,7 @@ def handle(
148148
out_tensors: Sequence['KerasTensor'],
149149
):
150150
conf = super().handle(layer, in_tensors, out_tensors)
151+
conf['class_name'] = 'Dense'
151152
in_shape: tuple[int, ...] = in_tensors[0].shape[1:] # type: ignore
152153
if len(in_shape) > 1:
153154
if hasattr(layer, 'parallelization_factor'):

hls4ml/model/optimizer/passes/bit_exact.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
Pooling2D,
2828
Reshape,
2929
Softmax,
30+
Transpose,
3031
)
3132
from hls4ml.model.optimizer import ModelOptimizerPass, OptimizerPass
3233
from hls4ml.model.optimizer.passes.hgq_proxy_model import FixedPointQuantizer, UnaryLUT
@@ -154,6 +155,17 @@ def _(layer: Concatenate):
154155
return ((k0, i0, f0), (k1, i1, f1))
155156

156157

158+
@_request_kif.register
159+
def _(layer: Transpose):
160+
k, i, f = requested_kif(layer)
161+
perm = layer.attributes['perm']
162+
inv_perm = np.argsort(perm)
163+
k = np.transpose(k, inv_perm)
164+
i = np.transpose(i, inv_perm)
165+
f = np.transpose(f, inv_perm)
166+
return ((k, i, f),)
167+
168+
157169
def requested_kif(layer: Layer) -> KIF_t:
158170
out_layers = get_output_layers(layer)
159171
out_shape = get_output_shape(layer)
@@ -319,6 +331,16 @@ def _(layer: Dense):
319331
return k.astype(np.int8), i, f
320332

321333

334+
@_produce_kif.register
335+
def _(layer: Transpose):
336+
k, i, f = get_input_kifs(layer)[0]
337+
perm = layer.attributes['perm']
338+
k = np.transpose(k, perm)
339+
i = np.transpose(i, perm)
340+
f = np.transpose(f, perm)
341+
return k, i, f
342+
343+
322344
def r_im2col(kernel_size: Sequence[int], arr: np.ndarray, buffer: np.ndarray, axis: int):
323345
w = kernel_size[0]
324346
if len(kernel_size) == 3: # 1D

0 commit comments

Comments
 (0)