Skip to content

Commit

Permalink
add topk
Browse files Browse the repository at this point in the history
  • Loading branch information
dc-dc-dc committed Jan 22, 2024
1 parent 7873b39 commit 9e393dc
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
1 change: 1 addition & 0 deletions mlx/onnx/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .op_pool import MaxPool, AveragePool
from .op_conv import Conv
from .op_slice import Slice
from .op_topk import TopK

# Reference Docs: https://onnx.ai/onnx/operators/

Expand Down
28 changes: 28 additions & 0 deletions mlx/onnx/ops/op_topk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import mlx.core as mx

def TopK(x: mx.array, k: mx.array, axis=-1, largest=1, sorted=1):
if isinstance(k, mx.array):
k = k.item()
if x.ndim == 2 and axis == 1:
sample = mx.arange(x.shape[0])[:, None]
if largest == 0:
sorted_indices = mx.argpartition(x, kth=k - 1, axis=axis)
sorted_indices = sorted_indices[:, :k]
sorted_indices = sorted_indices[sample, mx.argsort(x[sample, sorted_indices])]
else:
sorted_indices = mx.argpartition(-x, kth=k-1, axis=axis)
sorted_indices = sorted_indices[:, :k]
sorted_indices = sorted_indices[sample, mx.argsort(-x[sample, sorted_indices])]
sorted_distances = x[sample, sorted_indices]
return (sorted_distances, sorted_indices.astype(mx.int64))

if largest == 0:
sorted_indices = mx.argsort(x, axis=axis)
sorted_values = mx.sort(x, axis=axis)
else:
sorted_indices = mx.argsort(-x, axis=axis)
sorted_values = -mx.sort(-x, axis=axis)
ark = mx.arange(k)
topk_sorted_indices = mx.take(sorted_indices, ark, axis=axis)
topk_sorted_values = mx.take(sorted_values, ark, axis=axis)
return topk_sorted_values, topk_sorted_indices.astype(mx.int64)
1 change: 0 additions & 1 deletion tests/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def supports_device(cls, device: str) -> bool:
btest.exclude("test_convtranspose_*")

btest.exclude("test_PReLU_*")
btest.exclude("test_topk*")

# TODO: Implement dilations / col format
btest.exclude("test_averagepool_2d_dilations_cpu")
Expand Down

0 comments on commit 9e393dc

Please sign in to comment.