
Commit

Merge branch 'llama'
tanganke committed Nov 12, 2024
2 parents 42dff55 + 11787d2 commit 5d14b16
Showing 11 changed files with 306 additions and 36 deletions.
1 change: 1 addition & 0 deletions config/method/dare/task_arithmetic.yaml
@@ -3,3 +3,4 @@ _target_: fusion_bench.method.DareTaskArithmetic
scaling_factor: 0.3
sparsity_ratio: 0.5
only_on_linear_weights: false
rescale: true
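
The new rescale flag exposes DARE's drop-and-rescale behaviour in the config. As a minimal sketch (the helper below is illustrative, not the fusion_bench utility), each entry of a task vector is dropped with probability sparsity_ratio, and when rescale is true the surviving entries are divided by the keep probability so the tensor's expected value is preserved:

import torch

def dare_random_drop(delta: torch.Tensor, sparsity_ratio: float, rescale: bool = True) -> torch.Tensor:
    # keep each entry with probability (1 - sparsity_ratio)
    mask = torch.bernoulli(torch.full_like(delta, 1.0 - sparsity_ratio))
    dropped = delta * mask
    if rescale:
        # divide by the keep probability so the expected value matches the original delta
        dropped = dropped / (1.0 - sparsity_ratio)
    return dropped
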
19 changes: 19 additions & 0 deletions config/method/linear/llama_expo.yaml
@@ -0,0 +1,19 @@
# This algorithm merges a pretrained model with a finetuned model.
#
# $$\theta_{merged} = \theta_{ft} + \alpha (\theta_{ft} - \theta_{pre})$$
#
# where $\theta_{merged}$ is the merged model, $\theta_{ft}$ is the finetuned model (medium-aligned model),
# $\theta_{pre}$ is the pretrained model (base model), and $\alpha$ is the extrapolation factor.
_target_: fusion_bench.method.ExPOAlgorithmForLlama
extrapolation_factor: 0.1
attention_scaling_factor: 1.0

only_on_backbone: true
on_linear_weights: true
on_linear_bias: false
on_embedding: false

fix_last_n_layers: 0
fix_first_n_layers: 0

magnitude_sparsity_ratio: null
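
A minimal sketch of the extrapolation step described in the comment above, written over plain state dicts (illustrative only; the actual ExPOAlgorithmForLlama additionally handles the attention scaling, backbone/embedding filters, layer fixing, and sparsity options listed in this config):

import torch

@torch.no_grad()
def expo_merge(pretrained_sd: dict, finetuned_sd: dict, alpha: float = 0.1) -> dict:
    # theta_merged = theta_ft + alpha * (theta_ft - theta_pre)
    merged = {}
    for name, theta_ft in finetuned_sd.items():
        theta_pre = pretrained_sd[name]
        merged[name] = theta_ft + alpha * (theta_ft - theta_pre)
    return merged
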
19 changes: 19 additions & 0 deletions config/method/linear/llama_expo_with_dare.yaml
@@ -0,0 +1,19 @@
_target_: fusion_bench.method.linear.llama_expo.ExPOWithDareForLLama

extrapolation_factor: 0.1
attention_scaling_factor: 1.0

only_on_backbone: true
on_linear_weights: true
on_linear_bias: false
on_embedding: false

fix_last_n_layers: 0
fix_first_n_layers: 0

magnitude_sparsity_ratio: null

# dare arguments
dare_sparsity_ratio: 0.5
dare_only_on_linear_weights: true
dare_rescale: true
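
As a rough mental model for how these options fit together (a sketch under assumptions, not the ExPOWithDareForLLama implementation), the delta between the finetuned and pretrained weights can first be sparsified with DARE's random drop-and-rescale, and the extrapolation then applied along the sparsified direction:

import torch

@torch.no_grad()
def expo_with_dare(pretrained_sd, finetuned_sd, alpha=0.1, dare_sparsity=0.5, dare_rescale=True):
    merged = {}
    for name, theta_ft in finetuned_sd.items():
        delta = theta_ft - pretrained_sd[name]
        # DARE: randomly drop entries of the delta, optionally rescaling the survivors
        mask = torch.bernoulli(torch.full_like(delta, 1.0 - dare_sparsity))
        delta = delta * mask
        if dare_rescale:
            delta = delta / (1.0 - dare_sparsity)
        # ExPO: extrapolate past the finetuned model along the (sparsified) delta
        merged[name] = theta_ft + alpha * delta
    return merged
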
4 changes: 4 additions & 0 deletions config/method/pruning/magnitude_diff_pruning.yaml
@@ -0,0 +1,4 @@
_target_: fusion_bench.method.MagnitudeDiffPruningAlgorithm
prune_ratio: 0.5
rescale: false
extract_names: null
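
For orientation, magnitude-diff pruning keeps only the largest-magnitude fraction of each parameter difference; a minimal sketch of that idea (illustrative helper, not the MagnitudeDiffPruningAlgorithm code):

import torch

def magnitude_prune(delta: torch.Tensor, prune_ratio: float, rescale: bool = False) -> torch.Tensor:
    # keep the top (1 - prune_ratio) fraction of entries by absolute value
    k = max(1, int(delta.numel() * (1.0 - prune_ratio)))
    threshold = delta.abs().flatten().topk(k).values.min()
    mask = (delta.abs() >= threshold).to(delta.dtype)
    pruned = delta * mask
    if rescale:
        # optionally rescale the kept entries to compensate for the pruned mass
        pruned = pruned / (1.0 - prune_ratio)
    return pruned
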
2 changes: 2 additions & 0 deletions fusion_bench/method/__init__.py
@@ -20,6 +20,7 @@
# model merging methods
"linear": [
"ExPOAlgorithm",
"ExPOAlgorithmForLlama",
"SimpleAverageForLlama",
"TaskArithmeticForLlama",
"LinearInterpolationAlgorithm",
@@ -107,6 +108,7 @@
ExPOAlgorithm,
LinearInterpolationAlgorithm,
SimpleAverageForLlama,
ExPOAlgorithmForLlama,
TaskArithmeticForLlama,
)
from .mixture_of_experts import (
31 changes: 7 additions & 24 deletions fusion_bench/method/dare/simple_average.py
@@ -3,7 +3,7 @@
from fusion_bench import BaseAlgorithm, BaseModelPool
from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_mul

from .utils import module_random_drop_, trainable_state_dict
from .task_arithmetic import DareTaskArithmetic

log = logging.getLogger(__name__)

@@ -23,26 +23,9 @@ def __init__(
super().__init__(**kwargs)

def run(self, modelpool: BaseModelPool):
if not isinstance(modelpool, BaseModelPool):
modelpool = BaseModelPool(modelpool)

if modelpool.has_pretrained:
log.warning("Pretrained model provided but not used")

sum_state_dict = None
for model in modelpool.models():
if sum_state_dict is None:
sum_state_dict = trainable_state_dict(model)
sum_state_dict = module_random_drop_(
sum_state_dict, self.sparsity_ratio, rescale=self.rescale
)
else:
state_dict = trainable_state_dict(model)
state_dict = module_random_drop_(
model, self.sparsity_ratio, rescale=self.rescale
)
sum_state_dict = state_dict_add(sum_state_dict, state_dict)
state_dict = state_dict_mul(sum_state_dict, len(modelpool))

model.load_state_dict(state_dict, strict=False)
return model
return DareTaskArithmetic(
scaling_factor=1 / len(modelpool),
sparsity_ratio=self.sparsity_ratio,
only_on_linear_weights=self.only_on_linear_weight,
rescale=self.rescale,
).run(modelpool)
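
The refactor leans on the identity that a simple average of the finetuned models equals task arithmetic on their deltas with scaling factor $1/N$:

$$\frac{1}{N}\sum_{i=1}^{N} \theta_i = \theta_{pre} + \frac{1}{N}\sum_{i=1}^{N} (\theta_i - \theta_{pre})$$

so DareSimpleAverage can delegate to DareTaskArithmetic with scaling_factor set to 1 / len(modelpool) instead of re-implementing its own drop-and-sum loop.
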
21 changes: 14 additions & 7 deletions fusion_bench/method/dare/task_arithmetic.py
@@ -5,9 +5,10 @@
from fusion_bench.utils.state_dict_arithmetic import state_dict_sum

from .utils import (
module_sub_,
module_random_drop_,
module_sub_,
param_random_drop_,
trainable_state_dict,
)


@@ -23,11 +24,13 @@ def __init__(
scaling_factor: float,
sparsity_ratio: float,
only_on_linear_weights: bool,
rescale: bool = True,
**kwargs,
):
self.scaling_factor = scaling_factor
self.sparsity_ratio = sparsity_ratio
self.only_on_linear_weights = only_on_linear_weights
self.rescale = rescale
super().__init__(**kwargs)

@torch.no_grad()
@@ -41,24 +44,28 @@ def run(self, modelpool: BaseModelPool):
for model_name in modelpool.model_names
}
task_vectors = {
model_name: module_sub_(finetuned_models, pretrained_model)
model_name: module_sub_(finetuned_models[model_name], pretrained_model)
for model_name in finetuned_models
}
del finetuned_models

# drop and rescale task vectors
for tv in task_vectors.values():
for model_name, tv in task_vectors.items():
if self.only_on_linear_weights:
for module in tv.modules():
for module_name, module in tv.named_modules():
if isinstance(module, nn.Linear):
print(f"pruning model: `{model_name}`, layer: {module_name}.")
param_random_drop_(
module.weight, self.sparsity_ratio, rescale=True
module.weight, self.sparsity_ratio, rescale=self.rescale
)
else:
module_random_drop_(tv, self.sparsity_ratio, rescale=True)
print(f"pruning model: `{model_name}`")
module_random_drop_(tv, self.sparsity_ratio, rescale=self.rescale)

# merge task vectors
task_vector_sum = state_dict_sum(task_vectors.values())
task_vector_sum = state_dict_sum(
[trainable_state_dict(tv) for tv in task_vectors.values()]
)

# scale the task vector and add it to the pretrained model
for name, delta in task_vector_sum.items():
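
Putting the pieces of this file together, a condensed sketch of the DARE task-arithmetic flow (helper names are illustrative, not the module's actual utilities):

import torch

@torch.no_grad()
def dare_task_arithmetic(pretrained_sd, finetuned_sds, scaling_factor, sparsity_ratio, rescale=True):
    # 1) task vectors: finetuned - pretrained
    task_vectors = [{k: sd[k] - pretrained_sd[k] for k in sd} for sd in finetuned_sds]
    # 2) DARE: randomly drop entries of each task vector, optionally rescaling the survivors
    for tv in task_vectors:
        for k, delta in tv.items():
            mask = torch.bernoulli(torch.full_like(delta, 1.0 - sparsity_ratio))
            tv[k] = delta * mask / (1.0 - sparsity_ratio) if rescale else delta * mask
    # 3) sum the task vectors, scale, and add the result back to the pretrained weights
    merged = {k: v.clone() for k, v in pretrained_sd.items()}
    for k in merged:
        merged[k] = merged[k] + scaling_factor * sum(tv[k] for tv in task_vectors)
    return merged
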
3 changes: 2 additions & 1 deletion fusion_bench/method/linear/__init__.py
@@ -1,5 +1,6 @@
# flake8: noqa F401
from .expo import ExPOAlgorithm
from .linear_interpolation import LinearInterpolationAlgorithm
from .llama_expo import ExPOAlgorithmForLlama
from .simple_average_for_llama import SimpleAverageForLlama
from .task_arithmetic_for_llama import TaskArithmeticForLlama
from .expo import ExPOAlgorithm
6 changes: 3 additions & 3 deletions fusion_bench/method/linear/expo.py
@@ -1,5 +1,5 @@
"""
This module contains the implementation of ExPO merge.
This module contains the implementation of ExPO merge for general nn.Modules.
Reference:
- Zheng et al. Weak-to-Strong Extrapolation Expedites Alignment.
@@ -75,5 +75,5 @@ def run(self, modelpool: BaseModelPool):
state_dict_mul(delta_parameters, scalar=self.extrapolation_factor),
)

sft_model.load_state_dict(merged_sd)
return sft_model
rlhf_model.load_state_dict(merged_sd)
return rlhf_model