diff --git a/mlx/onnx/ops/__init__.py b/mlx/onnx/ops/__init__.py index 46e2429..3f89927 100644 --- a/mlx/onnx/ops/__init__.py +++ b/mlx/onnx/ops/__init__.py @@ -13,6 +13,7 @@ from .op_pool import MaxPool, AveragePool from .op_conv import Conv from .op_slice import Slice +from .op_topk import TopK # Reference Docs: https://onnx.ai/onnx/operators/ diff --git a/mlx/onnx/ops/op_topk.py b/mlx/onnx/ops/op_topk.py new file mode 100644 index 0000000..efb650a --- /dev/null +++ b/mlx/onnx/ops/op_topk.py @@ -0,0 +1,28 @@ +import mlx.core as mx + +def TopK(x: mx.array, k: mx.array, axis=-1, largest=1, sorted=1): + if isinstance(k, mx.array): + k = k.item() + if x.ndim == 2 and axis == 1: + sample = mx.arange(x.shape[0])[:, None] + if largest == 0: + sorted_indices = mx.argpartition(x, kth=k - 1, axis=axis) + sorted_indices = sorted_indices[:, :k] + sorted_indices = sorted_indices[sample, mx.argsort(x[sample, sorted_indices])] + else: + sorted_indices = mx.argpartition(-x, kth=k-1, axis=axis) + sorted_indices = sorted_indices[:, :k] + sorted_indices = sorted_indices[sample, mx.argsort(-x[sample, sorted_indices])] + sorted_distances = x[sample, sorted_indices] + return (sorted_distances, sorted_indices.astype(mx.int64)) + + if largest == 0: + sorted_indices = mx.argsort(x, axis=axis) + sorted_values = mx.sort(x, axis=axis) + else: + sorted_indices = mx.argsort(-x, axis=axis) + sorted_values = -mx.sort(-x, axis=axis) + ark = mx.arange(k) + topk_sorted_indices = mx.take(sorted_indices, ark, axis=axis) + topk_sorted_values = mx.take(sorted_values, ark, axis=axis) + return topk_sorted_values, topk_sorted_indices.astype(mx.int64) diff --git a/tests/test_onnx.py b/tests/test_onnx.py index 09314e5..2d54092 100644 --- a/tests/test_onnx.py +++ b/tests/test_onnx.py @@ -65,7 +65,6 @@ def supports_device(cls, device: str) -> bool: btest.exclude("test_convtranspose_*") btest.exclude("test_PReLU_*") -btest.exclude("test_topk*") # TODO: Implement dilations / col format btest.exclude("test_averagepool_2d_dilations_cpu")