Remove useless code, implement to for tensors
alessandropalla committed Jul 16, 2024
1 parent 9760d57 commit f06e806
Showing 2 changed files with 2 additions and 23 deletions.
4 changes: 2 additions & 2 deletions intel_npu_acceleration_library/device.py
@@ -4,6 +4,7 @@
 #
 
 from intel_npu_acceleration_library.nn.module import convert_to_npu_module
+from intel_npu_acceleration_library.backend.tensor import RemoteTensor
 from torch.overrides import TorchFunctionMode
 from functools import lru_cache
 from typing import Any, MutableMapping
@@ -165,8 +166,7 @@ def to(super_fn: Any, self: Any, *args: Any, **kwargs: Any):
     """
     npu_device, args, kwargs = parse_to_arguments(*args, **kwargs)
     if npu_device:
-        # None for now, once the remote tensor feature lands, it can be converted to a remote tensor
-        pass
+        return super_fn(RemoteTensor.from_torch(self), *args, **kwargs)
     return super_fn(self, *args, **kwargs)


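For context, the rewritten `to` above is dispatched through a TorchFunctionMode (already imported at the top of device.py), so a tensor move to the NPU device now passes a RemoteTensor built by RemoteTensor.from_torch(self) to the original function instead of falling through unchanged. Below is a minimal, self-contained sketch of that interception pattern only; the mode class and the print are illustrative stand-ins, not this library's implementation (parse_to_arguments and RemoteTensor are NPU-specific and omitted):

import torch
from torch.overrides import TorchFunctionMode


class ToInterceptor(TorchFunctionMode):
    """Illustrative mode: observe Tensor.to calls the way device.py's wrapper does."""

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.Tensor.to:
            # device.py branches here: when the target is the NPU device, it calls
            # super_fn(RemoteTensor.from_torch(self), ...) instead of the plain tensor.
            print(f"intercepted .to with args {args[1:]}")
        return func(*args, **kwargs)


with ToInterceptor():
    x = torch.randn(2, 2).to(torch.float16)  # prints the intercepted call
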
21 changes: 0 additions & 21 deletions intel_npu_acceleration_library/nn/module.py
@@ -67,25 +67,6 @@ def compute_input_signature(
     return "_".join(signature)
 
 
-def patch_parameters(module: torch.nn.Module, model: NNFactory, recurse: bool = False):
-    """Patch the parameters of a PyTorch module with constants.
-
-    Args:
-        module (torch.nn.Module): The PyTorch module.
-        model (NNFactory): The NNFactory instance.
-        recurse (bool, optional): Recurse over all submodules. Defaults to False.
-    """
-    elements = list(module.named_parameters(recurse=recurse))
-    for name, param in elements:
-        del module._parameters[name]
-        setattr(module, name, model.constant(param.data.detach().numpy()))
-
-    buffers = list(module.named_buffers(recurse=recurse))
-    for name, param in buffers:
-        del module._buffers[name]
-        setattr(module, name, model.constant(param.data.detach().numpy()))
-
-
 def patch_modules(module: torch.nn.Module, model: NNFactory):
     """Patch the modules of a PyTorch module with constants.
@@ -97,7 +78,6 @@ def patch_modules(module: torch.nn.Module, model: NNFactory):
     for _, module in modules:
         if isinstance(module, Module):
            module.npu_top_level_module = False
-            # patch_parameters(module, model)
             patch_modules(module, model)
@@ -224,7 +204,6 @@ def create_kwargs_from_list(
         npu_kwargs = create_kwargs_from_list(kwargs)
 
         patch_modules(self, model)
-        # patch_parameters(self, model)
 
         _ = self.forward(*npu_args, **npu_kwargs)
         model.compile()
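The deleted patch_parameters relied on a subtle mechanic worth spelling out: removing a name from a module's _parameters dict and then setattr-ing it rebinds the name to a plain attribute, so it disappears from named_parameters() (here it was rebound to an NNFactory graph constant). A minimal sketch of just that mechanic, using a plain tensor in place of the NPU-specific model.constant(...):

import torch

lin = torch.nn.Linear(2, 2)
print([name for name, _ in lin.named_parameters()])  # ['weight', 'bias']

# Same unregister-then-rebind step the removed helper performed:
plain = lin._parameters.pop("weight").data
setattr(lin, "weight", plain)  # now a plain tensor attribute, not a Parameter

print([name for name, _ in lin.named_parameters()])  # ['bias']
lin(torch.randn(1, 2))  # forward still works; self.weight resolves to the attribute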
