Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP,POC] Faster functional modules #983

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

vmoens
Copy link

@vmoens vmoens commented Jul 24, 2022

Proposes a new method to load weights in FunctionalModule and FunctionalModuleWithBuffers.

A map module <-> param_name <-> param_value is created and used to set attributes.

Test:
The following test runs twice as fast on CPU than current implementation:

if __name__ == "__main__":
    # module with high param allocation cost but few operations
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 1),
        torch.nn.Linear(1, 1),
        torch.nn.Sequential(
            torch.nn.Linear(1, 1),
            torch.nn.Linear(1, 1),
            torch.nn.Linear(1, 1),
            torch.nn.Linear(1, 1),
        )
    )

    fnet, params = make_functional(net)
    x = torch.randn(1)
    print(timeit.timeit("fnet(params, x)", globals={"fnet": fnet, "x": x, "params": params}, number=10000))
    # 1.7 sec with new, 3.8 with old
    

    # the implementation supports serialization
    import tempfile
    with tempfile.NamedTemporaryFile() as file:
        torch.save(fnet2, file.name)
        loaded_fnet = torch.load(file.name)
        assert torch.isclose(fnet2(params, x), loaded_fnet(params, x))

Other metrics:
On torchrl's DDPG, the new in a full forward-backard pass, the old implementation of _swap_state takes approx. 20% of the runtime with small neural nets (2 layers MLP with 256 cells) on CPU. The new implementation takes approx. 6% of runtime.

@zou3519 zou3519 self-requested a review July 25, 2022 21:39
Copy link
Contributor

@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the approach -- it's awesome that it speeds up small neural nets a lot. I'm curious about your thoughts on parameter tying

Comment on lines +59 to +63
for module_name, m in model.named_modules():
for param_name, p in list(m.named_parameters(recurse=False)):
delattr(m, param_name)
setattr(m, param_name, None)
yield (module_name, m, param_name, p)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason we previously used create_names_map was for parameter tying. If someone creates a module that looks like:

class Foo(nn.Module):
   def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.randn(3))
        self.linear = nn.Linear(3, 3)
        self.linear.bias = self.bias

then fmodel, params = make_functional(Foo()) returns 2 Tensors (self.linear.weight and self.bias) instead of 3 Tensors. When the user calls fmodel([w, b], x), then b gets loaded to self.bias and self.linear.bias and w gets loaded to self.linear.weight.

Under the new strategy, it seems like params would have 3 tensors: [self.bias, self.linear.weight, self.linear.bias].

In general I'm not really sure what the interaction between parameter tying and make_functional should be. Thoughts?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point. If we want to keep things as they are we could link params to a list of modules and a list of names (instead of a single module and a single name). That will come with a slight overhead though...

It's the kind of design choice where you will always make someone unhappy (there will be someone out there that wants multiple copies of the same param), but it's probably not the majority of users.

Comment on lines +255 to +262
old_states = _swap_state(
self.param_modules + self.buffer_modules,
self.param_names + self.buffer_names,
list(params) + list(buffers)
)
old_params = old_states[:len(self.param_modules)]
old_buffers = old_states[len(self.param_modules):]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to make sure I understand why this is faster: is it because we no longer need to traverse through the module to find the submodules; we've already made the submodules directly available to swap their parameters out?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly!
Instead of going through a tree of param names, we just flatten it and go through a single list of modules, one-level names and values.

Comment on lines +229 to +231
param_module_names, param_modules, param_names, params = zip(*param_container)
else:
param_module_names, param_modules, param_names, params = tuple(), tuple(), tuple(), tuple()
Copy link
Contributor

@zou3519 zou3519 Jul 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, we guaranteed that params is returned in the same order as what gets returned by original_model.parameters(). After this change, is that still true?

(Side note) To be honest, we've been thinking of changing the API so that params isn't returned as a flat list; instead we probably want to return some sort of dictionary or object so that one can easily figure out which params corresponds to which parameters on the original module. This is something that a couple of users have asked us for. If we returned a dictionary then it doesn't matter that params isn't the same as what gets returned by original_module.parameters()

Copy link
Author

@vmoens vmoens Aug 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh man this would be so great! I could definitely use that feature.

To be honest, I was thinking about using TensorDict from torchrl to pass params to functorch stateless modules. We could nest the dicts (eg d["module"]["param"] to d["module.param"]), expand the params, change device or whatever, in batch and with little or no effort since all those ops are built-in tensordict methods. I think there's a good synergy that we could get from TensorDict functorch. At the moment, TensorDict isn't torchscriptable though, I don't know how much trouble it is for you.
@nairbv @shagunsodhani

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants