I have been trying to refactor my physics-based Transformer code so that all calls to NumPy are removed and everything is implemented in MLX. I am almost there: a pure MLX implementation. 😀
In several places in the code I have been using `np.random.shuffle()` to randomize the generation of datasets (boundary conditions) and also the loading of batches. My MLX solution is the following, but I am wondering whether it is the most efficient option or whether I am missing a built-in one. Any insights from @awni would be appreciated.
```python
import mlx.core as mx

# Number of samples along the first axis (shared by all arrays)
num_samples = left_bcs.shape[0]

if synchronized_shuffling:
    # Stack the arrays along axis 1 (side by side, as columns)
    combined = mx.stack([left_bcs, right_bcs, top_bcs, bottom_bcs, alphas], axis=1)
    # Random permutation of the sample indices: argsort of uniform keys.
    # (mx.random.randint samples with replacement, so rows could repeat or vanish.)
    indices = mx.argsort(mx.random.uniform(shape=(num_samples,)))
    # Reorder the combined array based on the shuffled indices
    combined = combined[indices]  # Shuffle the rows (samples)
    # Use split to "unstack" the combined array (5 arrays, one for each boundary condition)
    left_bcs, right_bcs, top_bcs, bottom_bcs, alphas = mx.split(combined, 5, axis=1)
    # Squeeze only axis 1, the dimension added by mx.stack
    left_bcs = left_bcs.squeeze(axis=1)
    right_bcs = right_bcs.squeeze(axis=1)
    top_bcs = top_bcs.squeeze(axis=1)
    bottom_bcs = bottom_bcs.squeeze(axis=1)
    alphas = alphas.squeeze(axis=1)
else:
    # Generate an independent random permutation for each array
    indicesL = mx.argsort(mx.random.uniform(shape=(num_samples,)))
    indicesR = mx.argsort(mx.random.uniform(shape=(num_samples,)))
    indicesT = mx.argsort(mx.random.uniform(shape=(num_samples,)))
    indicesB = mx.argsort(mx.random.uniform(shape=(num_samples,)))
    indicesA = mx.argsort(mx.random.uniform(shape=(num_samples,)))
    # Reorder arrays based on the shuffled indices
    left_bcs = left_bcs[indicesL]
    right_bcs = right_bcs[indicesR]
    top_bcs = top_bcs[indicesT]
    bottom_bcs = bottom_bcs[indicesB]
    alphas = alphas[indicesA]
```