[Feature] memory release action in stream (to reduce memory usage) #1481

Open
kaeru-shigure opened this issue Oct 12, 2024 · 3 comments

@kaeru-shigure

I want something like this:

import mlx.core as mx
import mlx.nn as nn

mx.metal.set_cache_limit(0)

def dload(x, *args, **kwargs):
  # Load the weights and return (x tagged with a "load" marker, the weights,
  # and a "release" marker that should free the weights once it is evaluated).
  la = mx.load(*args, **kwargs)

  @mx.custom_function
  def dload(x):
    print(f"forward_load({args[0]})")
    return x
  @dload.vjp
  def dload_vjp(primals, cotangents, outputs):
    # la -> allocator::free
    print(f"backward_unload({args[0]})")
    return cotangents

  @mx.custom_function
  def drelease(x):
    # la -> allocator::free
    print(f"forward_unload({args[0]})")
    return x
  @drelease.vjp
  def drelease_vjp(primals, cotangents, outputs):
    print(f"backward_load({args[0]})")
    # la = mx.load(*args, **kwargs)
    # (la -> allocator::malloc)
    return cotangents

  return dload(x), la, drelease

def linear_down(x):
  x, la, release = dload(x, "down.safetensors")
  x = x @ la["w"] + la["b"]
  return release(x)  # same value as x, but should run allocator::free during eval

def linear_up(x):
  x, la, release = dload(x, "up.safetensors")
  x = x @ la["w"] + la["b"]
  return release(x)  # same value as x, but should run allocator::free during eval

def proc(x):
  for _ in range(10):
    x = linear_down(x)
    x = linear_up(x)
  return x.mean()

# init
mx.save_safetensors("down.safetensors", {"w": mx.random.normal([1, 1, 48, 64]), "b": mx.random.normal([1, 1, 1, 64])})
mx.save_safetensors("up.safetensors", {"w": mx.random.normal([1, 1, 64, 48]), "b": mx.random.normal([1, 1, 1, 48])})
mx.random.seed(3)
x = mx.random.normal([1, 3, 32, 48])

# run
r = mx.value_and_grad(proc)(x)
mx.eval(*r)
print(mx.metal.get_peak_memory())
# peak should only need one copy of la["w"].nbytes + la["b"].nbytes at a time, plus activations and other buffers
Expected log:

forward_load(down.safetensors)
forward_unload(down.safetensors)
forward_load(up.safetensors)
forward_unload(up.safetensors)
..
backward_load(up.safetensors)
backward_unload(up.safetensors)
backward_load(down.safetensors)
backward_unload(down.safetensors)
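
For contrast, here is a minimal sketch (not part of the proposed API, just to show the current behaviour) of the same loop written with plain mx.load, reusing the files and input defined above. Because the lazily built graph keeps every loaded copy of the weights alive until the whole value_and_grad evaluation finishes, peak memory is expected to grow with the number of load calls instead of staying at roughly one layer's worth:

def linear_plain(x, path):
  la = mx.load(path)
  return x @ la["w"] + la["b"]

def proc_plain(x):
  for _ in range(10):
    x = linear_plain(x, "down.safetensors")
    x = linear_plain(x, "up.safetensors")
  return x.mean()

r = mx.value_and_grad(proc_plain)(x)
mx.eval(*r)
print(mx.metal.get_peak_memory())  # expected: grows with every load call
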
@kaeru-shigure (Author)

The dynamically loaded weights do not need to be learnable (because LoRA is used), so they never need gradients of their own.
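
For example (a rough sketch with made-up adapter shapes, not a real model): with LoRA, value_and_grad only has to differentiate with respect to the small adapter matrices, and the streamed base weights can be treated as constants:

import mlx.core as mx

base = mx.load("down.safetensors")             # streamed, frozen base weights
adapters = {
  "a": mx.random.normal([48, 4]),              # hypothetical rank-4 LoRA factors
  "b": mx.random.normal([4, 64]),
}

def loss_fn(adapters, x):
  w = base["w"] + adapters["a"] @ adapters["b"]   # base + low-rank update
  return (x @ w + base["b"]).mean()

x = mx.random.normal([1, 3, 32, 48])
loss, grads = mx.value_and_grad(loss_fn)(adapters, x)  # grads only for the adapters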

@awni (Member) commented Oct 12, 2024

It looks like you already implemented what you want? Is it not working / broken?

@kaeru-shigure (Author)

No, look closely... that code only prints the comments as output; nothing is actually freed.
There is currently no way to free memory within a stream.
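
To illustrate (this is my understanding of the current behaviour, so treat it as an assumption): for a forward-only pass the weights can already be released by evaluating eagerly and dropping the reference, but under value_and_grad the loaded arrays stay referenced by the retained graph until the whole backward pass has run, which is exactly the gap this request is about:

import mlx.core as mx

mx.metal.set_cache_limit(0)

def forward_only(x):
  # forward-only workaround: evaluate eagerly, then drop the reference
  for path in ["down.safetensors", "up.safetensors"] * 10:
    la = mx.load(path)
    x = x @ la["w"] + la["b"]
    mx.eval(x)   # materialize x now, while la is still needed
    del la       # last reference gone -> buffers returned (cache limit is 0)
  return x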
