From 58dfeea84bbb1d8ce7a081847851d411a83ef818 Mon Sep 17 00:00:00 2001
From: Alex Yun
Date: Tue, 10 Dec 2024 14:56:33 -0600
Subject: [PATCH 1/4] add dtypes

---
 src/adapters/configuration/adapter_config.py | 7 +++++++
 src/adapters/methods/lora.py                 | 4 ++--
 src/adapters/methods/reft.py                 | 6 ++++--
 3 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py
index 0f2eec2162..20aec1cf67 100644
--- a/src/adapters/configuration/adapter_config.py
+++ b/src/adapters/configuration/adapter_config.py
@@ -3,6 +3,8 @@
 from dataclasses import FrozenInstanceError, asdict, dataclass, field, replace
 from typing import List, Literal, Optional, Union
 
+import torch
+
 from ..utils import resolve_adapter_config
 
 
@@ -499,6 +501,7 @@ class LoRAConfig(AdapterConfig):
     composition_mode: str = "add"
     init_weights: str = "lora"
     use_gating: bool = False
+    dtype: torch.dtype = torch.float32
 
 
 @dataclass(eq=False)
@@ -521,6 +524,7 @@ class IA3Config(LoRAConfig):
     composition_mode: str = "scale"
     init_weights: str = "ia3"
     use_gating: bool = False
+    dtype: torch.dtype = torch.float32
 
 
 @dataclass(eq=False)
@@ -551,6 +555,7 @@ class ReftConfig(AdapterConfig):
     subtract_projection = True
     dropout: float = 0.05
     non_linearity: Optional[str] = None
+    dtype: torch.dtype = torch.float32
 
     architecture: str = "reft"
 
@@ -583,6 +588,7 @@ class NoReftConfig(ReftConfig):
     r: int = 1
     orthogonality: bool = False
     tied_weights: bool = False
+    dtype: torch.dtype = torch.float32
 
 
 @dataclass(eq=False)
@@ -598,6 +604,7 @@ class DiReftConfig(ReftConfig):
     orthogonality: bool = False
     tied_weights: bool = False
     subtract_projection = False
+    dtype: torch.dtype = torch.float32
 
 
 class ConfigUnion(AdapterConfig):
diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py
index c62a94f265..ffd98c3700 100644
--- a/src/adapters/methods/lora.py
+++ b/src/adapters/methods/lora.py
@@ -52,8 +52,8 @@ def __init__(
             self.lora_dropout = lambda x: x
 
         # Actual trainable parameters
-        self.lora_A = nn.Parameter(torch.zeros(lora_A_shape))
-        self.lora_B = nn.Parameter(torch.zeros(lora_B_shape))
+        self.lora_A = nn.Parameter(torch.zeros(lora_A_shape, dtype=config.dtype))
+        self.lora_B = nn.Parameter(torch.zeros(lora_B_shape, dtype=config.dtype))
         self.scaling = self.lora_alpha / self.r
 
         # For compatibility with (IA)^3, allow all init_weights types here.
diff --git a/src/adapters/methods/reft.py b/src/adapters/methods/reft.py
index 0914e8d3aa..bcdbeccb5e 100644
--- a/src/adapters/methods/reft.py
+++ b/src/adapters/methods/reft.py
@@ -18,12 +18,13 @@ def __init__(
         subtract_projection: bool = True,
         non_linearity: str = None,
         dropout: float = 0.0,
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__()
         self.orthogonal = orthogonal
-        self.learned_source = nn.Linear(in_dim, r_dim, bias=True)
+        self.learned_source = nn.Linear(in_dim, r_dim, bias=True, dtype=dtype)
 
-        projection = nn.Linear(in_dim, r_dim, bias=False)
+        projection = nn.Linear(in_dim, r_dim, bias=False, dtype=dtype)
         if orthogonal:
             self.projection = nn.utils.parametrizations.orthogonal(projection)
         else:
@@ -59,6 +60,7 @@ def __init__(self, in_features: int, config: ReftConfig):
                     config.subtract_projection,
                     config.non_linearity,
                     config.dropout,
+                    config.dtype,
                 )
                 for _ in range(n_units)
             ]

From 2728f3e3c33555ba85c298d056d28875548bb561 Mon Sep 17 00:00:00 2001
From: Alex Yun
Date: Tue, 10 Dec 2024 15:22:57 -0600
Subject: [PATCH 2/4] use string for dtype

---
 src/adapters/configuration/adapter_config.py | 13 ++++++-------
 src/adapters/methods/lora.py                 |  5 +++--
 src/adapters/methods/reft.py                 |  5 +++--
 3 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py
index 20aec1cf67..7073c4c5fe 100644
--- a/src/adapters/configuration/adapter_config.py
+++ b/src/adapters/configuration/adapter_config.py
@@ -3,8 +3,6 @@
 from dataclasses import FrozenInstanceError, asdict, dataclass, field, replace
 from typing import List, Literal, Optional, Union
 
-import torch
-
 from ..utils import resolve_adapter_config
 
 
@@ -501,7 +499,7 @@ class LoRAConfig(AdapterConfig):
     composition_mode: str = "add"
     init_weights: str = "lora"
     use_gating: bool = False
-    dtype: torch.dtype = torch.float32
+    dtype: Optional[str] = None
 
 
 @dataclass(eq=False)
@@ -524,7 +522,7 @@ class IA3Config(LoRAConfig):
     composition_mode: str = "scale"
     init_weights: str = "ia3"
     use_gating: bool = False
-    dtype: torch.dtype = torch.float32
+    dtype: Optional[str] = None
 
 
 @dataclass(eq=False)
@@ -555,7 +553,7 @@ class ReftConfig(AdapterConfig):
     subtract_projection = True
     dropout: float = 0.05
     non_linearity: Optional[str] = None
-    dtype: torch.dtype = torch.float32
+    dtype: Optional[str] = None
 
     architecture: str = "reft"
 
@@ -574,6 +572,7 @@ class LoReftConfig(ReftConfig):
     r: int = 1
     orthogonality: bool = True
     tied_weights: bool = False
+    dtype: Optional[str] = None
 
 
 @dataclass(eq=False)
@@ -588,7 +587,7 @@ class NoReftConfig(ReftConfig):
     r: int = 1
     orthogonality: bool = False
     tied_weights: bool = False
-    dtype: torch.dtype = torch.float32
+    dtype: Optional[str] = None
 
 
 @dataclass(eq=False)
@@ -604,7 +603,7 @@ class DiReftConfig(ReftConfig):
     orthogonality: bool = False
     tied_weights: bool = False
     subtract_projection = False
-    dtype: torch.dtype = torch.float32
+    dtype: Optional[str] = None
 
 
 class ConfigUnion(AdapterConfig):
diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py
index ffd98c3700..d56a11a91d 100644
--- a/src/adapters/methods/lora.py
+++ b/src/adapters/methods/lora.py
@@ -51,9 +51,10 @@ def __init__(
         else:
             self.lora_dropout = lambda x: x
 
+        dtype = getattr(torch, config.dtype) if config.dtype else None
         # Actual trainable parameters
-        self.lora_A = nn.Parameter(torch.zeros(lora_A_shape, dtype=config.dtype))
-        self.lora_B = nn.Parameter(torch.zeros(lora_B_shape, dtype=config.dtype))
+        self.lora_A = nn.Parameter(torch.zeros(lora_A_shape, dtype=dtype))
+        self.lora_B = nn.Parameter(torch.zeros(lora_B_shape, dtype=dtype))
         self.scaling = self.lora_alpha / self.r
 
         # For compatibility with (IA)^3, allow all init_weights types here.
diff --git a/src/adapters/methods/reft.py b/src/adapters/methods/reft.py
index bcdbeccb5e..a736eee4ab 100644
--- a/src/adapters/methods/reft.py
+++ b/src/adapters/methods/reft.py
@@ -18,7 +18,7 @@ def __init__(
         subtract_projection: bool = True,
         non_linearity: str = None,
         dropout: float = 0.0,
-        dtype: torch.dtype = torch.float32,
+        dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
         self.orthogonal = orthogonal
@@ -51,6 +51,7 @@ def __init__(self, in_features: int, config: ReftConfig):
         self.suffix_positions = config.suffix_positions
         self.tied_weights = config.tied_weights
         n_units = 1 if config.tied_weights else 2
+        dtype = getattr(torch, config.dtype) if config.dtype else None
         self.units = nn.ModuleList(
             [
                 ReftUnit(
@@ -60,7 +61,7 @@ def __init__(self, in_features: int, config: ReftConfig):
                     config.subtract_projection,
                     config.non_linearity,
                     config.dropout,
-                    config.dtype,
+                    dtype,
                 )
                 for _ in range(n_units)
             ]

From 6ef8d804695913a5d7b71583d88f65942bc5d6ba Mon Sep 17 00:00:00 2001
From: Alex Yun
Date: Tue, 10 Dec 2024 15:43:38 -0600
Subject: [PATCH 3/4] fix import

---
 src/adapters/methods/reft.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/adapters/methods/reft.py b/src/adapters/methods/reft.py
index a736eee4ab..9c6647e399 100644
--- a/src/adapters/methods/reft.py
+++ b/src/adapters/methods/reft.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
 
 import torch
 import torch.nn as nn

From eccd5f0b37076e2cdd9951a909b496e3a4823823 Mon Sep 17 00:00:00 2001
From: Alex Yun
Date: Sun, 22 Dec 2024 11:07:26 -0800
Subject: [PATCH 4/4] add docstring to dtype

---
 src/adapters/configuration/adapter_config.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py
index 7073c4c5fe..b5249cb9f5 100644
--- a/src/adapters/configuration/adapter_config.py
+++ b/src/adapters/configuration/adapter_config.py
@@ -483,6 +483,7 @@ class LoRAConfig(AdapterConfig):
             Place a trainable gating module besides the added parameter module to control module activation. This is
             e.g. used for UniPELT. Defaults to False. Note that modules with use_gating=True cannot be merged using
             `merge_adapter()`.
+        dtype (str, optional): torch dtype for reparametrization tensors. Defaults to None.
     """
 
     architecture: Optional[str] = "lora"
@@ -542,6 +543,7 @@ class ReftConfig(AdapterConfig):
         subtract_projection (bool): If True, subtract the projection of the input.
         dropout (float): The dropout rate used in the intervention layer.
         non_linearity (str): The activation function used in the intervention layer.
+        dtype (str, optional): torch dtype for intervention tensors. Defaults to None.
     """
 
     layers: Union[Literal["all"], List[int]]
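A minimal usage sketch of the string-valued `dtype` option added by this series, assuming the `adapters` package with these patches applied. The base checkpoint ("gpt2"), adapter names, and hyperparameters below are illustrative placeholders, not part of the patch.

# Illustrative sketch only: exercise the string-based `dtype` added by this series.
import adapters
import torch
from adapters import LoRAConfig, NoReftConfig
from transformers import AutoModelForCausalLM

# Load the base model in bfloat16; "gpt2" is a placeholder checkpoint.
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)
adapters.init(model)

# The config stores the dtype as a plain string ("bfloat16"), which the LoRA/ReFT
# modules resolve via getattr(torch, config.dtype), so the adapter parameters are
# created in the base model's precision instead of defaulting to float32.
model.add_adapter("lora_bf16", config=LoRAConfig(r=8, alpha=16, dtype="bfloat16"))
model.add_adapter("reft_bf16", config=NoReftConfig(dtype="bfloat16"))

# Activate and train one adapter at a time, as usual.
model.train_adapter("lora_bf16")

Leaving `dtype` as None preserves the previous behavior (parameters created in the default float32), so existing configs are unaffected.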