Skip to content

Commit

Permalink
space some code apart
Browse files Browse the repository at this point in the history
  • Loading branch information
dc-dc-dc committed Jan 22, 2024
1 parent c9b6cf3 commit 7873b39
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions mlx/onnx/ops/op_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

def Conv(x: mx.array, weight: mx.array, bias: Optional[mx.array]=None, dilations:Optional[mx.array]=None, group=1, auto_pad="NOTSET", kernel_shape:Optional[mx.array]=None, pads:Optional[mx.array]=None, strides:Optional[mx.array]=None):
assert group == 1, f"mlx only supports 1 group, got {group}"
if dilations is not None:
assert all(x == 1 for x in dilations.tolist()), "mlx only supports dilation 1"

if isinstance(kernel_shape, mx.array):
kernel_shape = kernel_shape.tolist()
if isinstance(strides, mx.array):
Expand All @@ -14,13 +17,14 @@ def Conv(x: mx.array, weight: mx.array, bias: Optional[mx.array]=None, dilations
pads = pads.tolist()
if pads is None:
pads = [0] * len(kernel_shape)

if x.ndim < weight.ndim:
x = mx.expand_dims(x, 0)

if auto_pad != "NOTSET":
padding = convert_pad(ap(x.shape, auto_pad, strides, kernel_shape))
x = mx.pad(x, pad_width=[(0,0), (0,0)] + padding, constant_values=0)
if dilations is not None:
assert all(x == 1 for x in dilations.tolist()), "mlx only supports dilation 1"

if x.ndim == 3:
c = mx.conv1d(x.transpose(0, 2, 1), weight.transpose(0, 2, 1), padding=pads[0] if pads is not None else 0, stride=strides[0] if strides is not None else 1)
c = c + bias if bias is not None else c
Expand Down

0 comments on commit 7873b39

Please sign in to comment.