From f06e80608e396cc5f025f403d84de29fc691349f Mon Sep 17 00:00:00 2001
From: Alessandro Palla
Date: Tue, 16 Jul 2024 13:41:07 +0100
Subject: [PATCH] Remove useless code, implement `to` for tensors

---
 intel_npu_acceleration_library/device.py    |  4 ++--
 intel_npu_acceleration_library/nn/module.py | 21 ---------------------
 2 files changed, 2 insertions(+), 23 deletions(-)

diff --git a/intel_npu_acceleration_library/device.py b/intel_npu_acceleration_library/device.py
index 988c315..28e8484 100644
--- a/intel_npu_acceleration_library/device.py
+++ b/intel_npu_acceleration_library/device.py
@@ -4,6 +4,7 @@
 #
 
 from intel_npu_acceleration_library.nn.module import convert_to_npu_module
+from intel_npu_acceleration_library.backend.tensor import RemoteTensor
 from torch.overrides import TorchFunctionMode
 from functools import lru_cache
 from typing import Any, MutableMapping
@@ -165,8 +166,7 @@ def to(super_fn: Any, self: Any, *args: Any, **kwargs: Any):
     """
     npu_device, args, kwargs = parse_to_arguments(*args, **kwargs)
     if npu_device:
-        # None for now, once the remote tensor feature lands, it can be converted to a remote tensor
-        pass
+        return super_fn(RemoteTensor.from_torch(self), *args, **kwargs)
     return super_fn(self, *args, **kwargs)
 
 
diff --git a/intel_npu_acceleration_library/nn/module.py b/intel_npu_acceleration_library/nn/module.py
index ef23c8e..9c5bf6d 100644
--- a/intel_npu_acceleration_library/nn/module.py
+++ b/intel_npu_acceleration_library/nn/module.py
@@ -67,25 +67,6 @@ def compute_input_signature(
     return "_".join(signature)
 
 
-def patch_parameters(module: torch.nn.Module, model: NNFactory, recurse: bool = False):
-    """Patch the parameters of a PyTorch module with constants.
-
-    Args:
-        module (torch.nn.Module): The PyTorch module.
-        model (NNFactory): The NNFactory instance.
-        recurse (bool, optional): Recurse over all submodules. Defaults to False.
-    """
-    elements = list(module.named_parameters(recurse=recurse))
-    for name, param in elements:
-        del module._parameters[name]
-        setattr(module, name, model.constant(param.data.detach().numpy()))
-
-    buffers = list(module.named_buffers(recurse=recurse))
-    for name, param in buffers:
-        del module._buffers[name]
-        setattr(module, name, model.constant(param.data.detach().numpy()))
-
-
 def patch_modules(module: torch.nn.Module, model: NNFactory):
     """Patch the modules of a PyTorch module with constants.
 
@@ -97,7 +78,6 @@ def patch_modules(module: torch.nn.Module, model: NNFactory):
     for _, module in modules:
         if isinstance(module, Module):
             module.npu_top_level_module = False
-            # patch_parameters(module, model)
             patch_modules(module, model)
 
 
@@ -224,7 +204,6 @@ def create_kwargs_from_list(
         npu_kwargs = create_kwargs_from_list(kwargs)
 
         patch_modules(self, model)
-        # patch_parameters(self, model)
 
         _ = self.forward(*npu_args, **npu_kwargs)
         model.compile()
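
Reviewer note: a minimal, self-contained sketch of the control flow the patched `to` now follows. `parse_to_arguments` and `RemoteTensor.from_torch` are the patch's own names, but the bodies below are illustrative stand-ins under stated assumptions, not the library's actual implementations.

from typing import Any, Tuple

import torch


def parse_to_arguments(*args: Any, **kwargs: Any) -> Tuple[bool, tuple, dict]:
    # Stand-in: detect and strip an "npu" target from the .to() arguments.
    # The real helper in device.py also handles torch.device objects.
    npu_device = any(arg == "npu" for arg in args)
    remaining = tuple(arg for arg in args if arg != "npu")
    return npu_device, remaining, kwargs


class RemoteTensor:
    # Stand-in for intel_npu_acceleration_library.backend.tensor.RemoteTensor.
    @staticmethod
    def from_torch(tensor: torch.Tensor) -> torch.Tensor:
        # The real class would move the data into NPU-visible memory;
        # this sketch just hands the tensor back.
        return tensor


def to(super_fn: Any, self: torch.Tensor, *args: Any, **kwargs: Any):
    npu_device, args, kwargs = parse_to_arguments(*args, **kwargs)
    if npu_device:
        # The behavior this patch adds: convert to a remote tensor before
        # deferring to the original Tensor.to for any remaining arguments.
        return super_fn(RemoteTensor.from_torch(self), *args, **kwargs)
    return super_fn(self, *args, **kwargs)


x = torch.randn(2, 2)
y = to(torch.Tensor.to, x, "npu", torch.float16)
print(y.dtype)  # torch.float16

In the library itself, this `to` is installed as an override through the TorchFunctionMode machinery imported in device.py, so the conversion fires transparently on any tensor moved to the NPU rather than silently falling through as before.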