
batching over model parameters #1094

Open
LeanderK opened this issue Jan 4, 2023 · 2 comments

Comments

@LeanderK

LeanderK commented Jan 4, 2023

I have a use-case for functorch. I would like to evaluate many variations of a model's parameters very efficiently (I want to eliminate the loop). Here's example code for a simplified case where I got it working:

import torch
from functorch import vmap

linear = torch.nn.Linear(10, 2)
default_weight = linear.weight.data
sample_input = torch.rand(3, 10)
sample_add = torch.rand_like(default_weight)

def interpolate_weights(alpha):
    with torch.no_grad():
        # swap in an interpolated weight, then run the forward pass
        res_weight = torch.nn.Parameter(default_weight + alpha*sample_add)
        linear.weight = res_weight
        return linear(sample_input)

Now I could do for alpha in np.linspace(0.0, 1.0, 100), but I want to vectorise this loop since my code is prohibitively slow. Is functorch applicable here? Executing:

alphas = torch.linspace(0.0, 1.0, 100)
vmap(interpolate_weights)(alphas)

works, but doing something similar for a simple ResNet does not. I've tried using load_state_dict, but that's not working:

from torchvision import models
model_resnet = models.resnet18(pretrained=True)

named_params = list(model_resnet.named_parameters())
named_params_data = [(n,p.data.clone()) for (n,p) in named_params]

sample_data = torch.rand(10,3,224,244)

def test_resnet(new_params):
    def interpolate(alpha):
        with torch.no_grad():
            p_dict = {name:(old + alpha*new_params[i]) for i,(name, old) in enumerate(named_params_data)}
            model_resnet.load_state_dict(p_dict, strict=False)
            out = model_resnet(sample_data)
            return out
    return interpolate

rand_tensor = [torch.rand_like(p) for n,p in named_params_data]

to_vmap_resnet = test_resnet(rand_tensor)
vmap(to_vmap_resnet)(alphas)

results in:

While copying the parameter named "fc.bias", whose dimensions in the model are torch.Size([1000]) and whose dimensions in the checkpoint are torch.Size([1000]), an exception occurred: ('vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.',).
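
A likely explanation: load_state_dict copies each checkpoint value into the existing parameter with an in-place copy, and under vmap the interpolated values carry the batched alpha dimension while the destination parameters do not, which is exactly the in-place pattern vmap rejects. A minimal sketch of the same failure mode (plain and write_inplace are hypothetical names, not from the code above):

import torch
from functorch import vmap

plain = torch.zeros(3)                 # not vmapped, like the model's actual parameters

def write_inplace(batched_value):
    plain.copy_(batched_value)         # in-place write, roughly what load_state_dict does internally
    return plain.sum()

# vmap(write_inplace)(torch.rand(5, 3))  # raises a similar vmap in-place error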

@LeanderK
Author

LeanderK commented Jan 4, 2023

Is this a legal way to solve this? It doesn't give me an error, but I am very unsure why it now works.

def test_resnet_2(new_params):
    def interpolate(alpha):
        with torch.no_grad():
            # swap every parameter for an interpolated copy via setattr
            for i, (name, old_p) in enumerate(named_params_data):
                new_p = new_params[i]
                param_names = name.split(".")
                current = model_resnet
                for p in param_names[:-1]:
                    current = getattr(current, p)
                setattr(current, param_names[-1], torch.nn.Parameter(old_p + alpha*new_p))

            out = model_resnet(sample_data)

            # restore the original parameters
            for i, (name, old_p) in enumerate(named_params_data):
                param_names = name.split(".")
                current = model_resnet
                for p in param_names[:-1]:
                    current = getattr(current, p)
                setattr(current, param_names[-1], torch.nn.Parameter(old_p))
            return out
    return interpolate

model_resnet.eval()
to_vmap_resnet = test_resnet_2(rand_tensor)
test_out2 = vmap(to_vmap_resnet)(alphas)

EDIT: found an even simpler solution. This is the correct approach, right?

from functorch import make_functional_with_buffers

def test_resnet_4(new_params, sample_data, model_resnet):
    func_model, params, buff = make_functional_with_buffers(model_resnet, disable_autograd_tracking=True)
    def interpolate(alpha):
        with torch.no_grad():
            # interpolated parameters are built out-of-place, so vmap is fine with them
            interpol_params = [torch.nn.Parameter(old_p + alpha*new_params[i]) for i, old_p in enumerate(params)]

            out = func_model(interpol_params, buff, sample_data)
            return out
    return interpolate

model_resnet.eval()
to_vmap_resnet = test_resnet_4(rand_tensor, sample_data, model_resnet)
test_out2 = vmap(to_vmap_resnet)(alphas)
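
As a quick sanity check, a minimal sketch (assuming the vmapped result should match an explicit Python loop over a few alphas):

# compare the vmapped output against a plain loop for the first few alphas
with torch.no_grad():
    out_loop = torch.stack([to_vmap_resnet(a) for a in alphas[:3]])
print(torch.allclose(test_out2[:3], out_loop, atol=1e-5))  # expected: True, up to numerical noise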

@samdow
Contributor

samdow commented Jan 4, 2023

Hi @LeanderK! Thanks for the interesting issue! Since it sounds like this works, that's a totally fine way of doing it!

One thing that might come up: if you do N runs of this model (instead of 1), it will be faster to do something similar to the ensembling API. In your version you would be building the new parameters N times; this way you only build them once and then combine them. This is also useful if you want to train the model (batch norm should work with the ensemble).

For this use case, since it looks like you want very specific initializations, it might be better to riff on the idea of the ensemble API:

def test_resnet_4(func_model, buff, sample_data):
  def interpolate(interpol_params):
      with torch.no_grad():
          out = func_model(interpol_params, buff, sample_data)
          return out
  return interpolate

model_resnet.eval()

func_model, params, buff = make_functional_with_buffers(model_resnet, disable_autograd_tracking=True)
interpol_params = [[torch.nn.Parameter(old_p + alpha*rand_tensor[i]) for i, old_p in enumerate(params)] for alpha in alphas]
interpol_params = [torch.stack(i) for i in zip(*interpol_params)] # this is basically what the ensemble API is doing
to_vmap_resnet = test_resnet_4(func_model, buff, sample_data)
test_out2 = vmap(to_vmap_resnet)(interpol_params)

Then, if you want to train, you can also expand the buffers and vmap across them along with interpol_params so that batch norm works.
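
A minimal sketch of that buffer expansion (stacked_buffers and forward_with_buffers are placeholder names, not part of the ensembling API):

# give every ensemble member its own copy of the buffers so batch norm's
# running-stat updates stay per-member when vmapped
stacked_buffers = [torch.stack([b.clone() for _ in alphas]) for b in buff]

def forward_with_buffers(interpol_params, member_buffers):
    return func_model(interpol_params, member_buffers, sample_data)

out = vmap(forward_with_buffers)(interpol_params, stacked_buffers)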

Hope that helps! We are also looking at changing the module API to help rationalize some of the functorch API with the PyTorch API soon. If you're using the nightly build, I can point you to the new API if you're curious.
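
If you end up on a newer PyTorch, a rough sketch of the same interpolation with the torch.func API (assuming torch.func.functional_call and torch.func.vmap are available; this may not be exactly the API referred to above, and base_params and directions are placeholder names):

import torch
from torch.func import functional_call, vmap

model_resnet.eval()
base_params = {k: v.detach() for k, v in model_resnet.named_parameters()}
buffers = {k: v.detach() for k, v in model_resnet.named_buffers()}
directions = {k: torch.rand_like(v) for k, v in base_params.items()}

def interpolate(alpha):
    # build interpolated parameters out-of-place and call the module functionally
    interp = {k: base_params[k] + alpha * directions[k] for k in base_params}
    return functional_call(model_resnet, {**interp, **buffers}, (sample_data,))

with torch.no_grad():
    out = vmap(interpolate)(torch.linspace(0.0, 1.0, 100))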
