
Allow specifying adapter dtype in AdapterConfig #767

Merged · 4 commits · Dec 23, 2024
8 changes: 8 additions & 0 deletions src/adapters/configuration/adapter_config.py
@@ -483,6 +483,7 @@ class LoRAConfig(AdapterConfig):
Place a trainable gating module besides the added parameter module to control module activation. This is
e.g. used for UniPELT. Defaults to False. Note that modules with use_gating=True cannot be merged using
`merge_adapter()`.
+dtype (str, optional): torch dtype for reparametrization tensors. Defaults to None.
"""

architecture: Optional[str] = "lora"
@@ -499,6 +500,7 @@ class LoRAConfig(AdapterConfig):
composition_mode: str = "add"
init_weights: str = "lora"
use_gating: bool = False
+dtype: Optional[str] = None


@dataclass(eq=False)
@@ -521,6 +523,7 @@ class IA3Config(LoRAConfig):
composition_mode: str = "scale"
init_weights: str = "ia3"
use_gating: bool = False
+dtype: Optional[str] = None


@dataclass(eq=False)
@@ -540,6 +543,7 @@ class ReftConfig(AdapterConfig):
subtract_projection (bool): If True, subtract the projection of the input.
dropout (float): The dropout rate used in the intervention layer.
non_linearity (str): The activation function used in the intervention layer.
+dtype (str, optional): torch dtype for intervention tensors. Defaults to None.
"""

layers: Union[Literal["all"], List[int]]
@@ -551,6 +555,7 @@ class ReftConfig(AdapterConfig):
subtract_projection = True
dropout: float = 0.05
non_linearity: Optional[str] = None
+dtype: Optional[str] = None

architecture: str = "reft"

@@ -569,6 +574,7 @@ class LoReftConfig(ReftConfig):
r: int = 1
orthogonality: bool = True
tied_weights: bool = False
+dtype: Optional[str] = None


@dataclass(eq=False)
@@ -583,6 +589,7 @@ class NoReftConfig(ReftConfig):
r: int = 1
orthogonality: bool = False
tied_weights: bool = False
+dtype: Optional[str] = None


@dataclass(eq=False)
@@ -598,6 +605,7 @@ class DiReftConfig(ReftConfig):
orthogonality: bool = False
tied_weights: bool = False
subtract_projection = False
+dtype: Optional[str] = None


class ConfigUnion(AdapterConfig):
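
Taken together, these additions give every LoRA-, IA³-, and ReFT-style config a `dtype` field that accepts a torch dtype name as a string. A minimal usage sketch, assuming the usual `adapters` workflow; the model checkpoint and adapter name below are placeholders:

```python
import adapters
import torch
from adapters import LoRAConfig
from transformers import AutoModel

# Load a base model in bfloat16 and create the LoRA tensors in the same dtype
# via the new `dtype` config field (resolved from the string "bfloat16").
model = AutoModel.from_pretrained("roberta-base", torch_dtype=torch.bfloat16)
adapters.init(model)

config = LoRAConfig(r=8, alpha=16, dtype="bfloat16")
model.add_adapter("my_lora", config=config)
model.set_active_adapters("my_lora")
```

Without the new field, the adapter tensors would be created in torch's default dtype regardless of the base model's precision.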
5 changes: 3 additions & 2 deletions src/adapters/methods/lora.py
@@ -51,9 +51,10 @@ def __init__(
else:
self.lora_dropout = lambda x: x

+dtype = getattr(torch, config.dtype) if config.dtype else None
# Actual trainable parameters
-self.lora_A = nn.Parameter(torch.zeros(lora_A_shape))
-self.lora_B = nn.Parameter(torch.zeros(lora_B_shape))
+self.lora_A = nn.Parameter(torch.zeros(lora_A_shape, dtype=dtype))
+self.lora_B = nn.Parameter(torch.zeros(lora_B_shape, dtype=dtype))
self.scaling = self.lora_alpha / self.r

# For compatibility with (IA)^3, allow all init_weights types here.
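
The conversion line works because torch exposes its dtypes as module attributes, so `getattr(torch, "bfloat16")` returns `torch.bfloat16`, and `None` falls through to torch's default dtype. A standalone sketch of the same pattern; the helper name and the extra validity check are illustrative and not part of the diff:

```python
from typing import Optional

import torch
import torch.nn as nn


def resolve_dtype(name: Optional[str]) -> Optional[torch.dtype]:
    """Map a dtype name such as "float16" or "bfloat16" to a torch.dtype, or return None."""
    if name is None:
        return None
    dtype = getattr(torch, name, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"'{name}' is not a valid torch dtype name")
    return dtype


# Illustrative LoRA shapes: A is (r, in_features), B is (out_features, r).
lora_A = nn.Parameter(torch.zeros((8, 768), dtype=resolve_dtype("bfloat16")))
lora_B = nn.Parameter(torch.zeros((768, 8), dtype=resolve_dtype(None)))
print(lora_A.dtype, lora_B.dtype)  # torch.bfloat16 torch.float32 (default)
```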
9 changes: 6 additions & 3 deletions src/adapters/methods/reft.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional

import torch
import torch.nn as nn
@@ -18,12 +18,13 @@ def __init__(
subtract_projection: bool = True,
non_linearity: str = None,
dropout: float = 0.0,
+dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.orthogonal = orthogonal
-self.learned_source = nn.Linear(in_dim, r_dim, bias=True)
+self.learned_source = nn.Linear(in_dim, r_dim, bias=True, dtype=dtype)

-projection = nn.Linear(in_dim, r_dim, bias=False)
+projection = nn.Linear(in_dim, r_dim, bias=False, dtype=dtype)
if orthogonal:
self.projection = nn.utils.parametrizations.orthogonal(projection)
else:
@@ -50,6 +51,7 @@ def __init__(self, in_features: int, config: ReftConfig):
self.suffix_positions = config.suffix_positions
self.tied_weights = config.tied_weights
n_units = 1 if config.tied_weights else 2
+dtype = getattr(torch, config.dtype) if config.dtype else None
self.units = nn.ModuleList(
[
ReftUnit(
@@ -59,6 +61,7 @@ def __init__(self, in_features: int, config: ReftConfig):
config.subtract_projection,
config.non_linearity,
config.dropout,
+dtype,
)
for _ in range(n_units)
]
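
In the ReFT path the resolved dtype is simply forwarded to the `nn.Linear` layers inside each unit. A stripped-down stand-in for `ReftUnit` that shows this flow; it omits the orthogonality parametrization, dropout, `subtract_projection`, and the non-linearity of the real implementation:

```python
from typing import Optional

import torch
import torch.nn as nn


class TinyReftUnit(nn.Module):
    """Simplified stand-in for ReftUnit: projection and learned source in a chosen dtype."""

    def __init__(self, in_dim: int, r_dim: int, dtype: Optional[torch.dtype] = None):
        super().__init__()
        # Both layers are created directly in the requested dtype,
        # mirroring how the diff forwards `dtype` to nn.Linear.
        self.projection = nn.Linear(in_dim, r_dim, bias=False, dtype=dtype)
        self.learned_source = nn.Linear(in_dim, r_dim, bias=True, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # ReFT-style edit: push the (source - projection) difference back
        # through the projection's transpose and add it to the input.
        diff = self.learned_source(x) - self.projection(x)
        return x + diff @ self.projection.weight


dtype = getattr(torch, "bfloat16")  # same string-to-dtype trick used in the diff
unit = TinyReftUnit(in_dim=16, r_dim=4, dtype=dtype)
x = torch.randn(2, 16, dtype=dtype)
print(unit(x).dtype)  # torch.bfloat16
```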