Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit 160eeaf
Author: tanganke <[email protected]>
Date:   Fri Nov 1 20:29:05 2024 +0800

    	modified:   fusion_bench/taskpool/llama/test_generation.py

commit 574447f
Author: tanganke <[email protected]>
Date:   Fri Nov 1 19:41:52 2024 +0800

    update expo.py

commit 4e613a4
Author: tanganke <[email protected]>
Date:   Fri Nov 1 18:32:17 2024 +0800

    update task vector cos similarity visualization

commit 201c81b
Author: Anke Tang <[email protected]>
Date:   Fri Nov 1 15:42:34 2024 +0800

    update docstring

commit 50b5974
Author: Anke Tang <[email protected]>
Date:   Fri Nov 1 15:41:09 2024 +0800

    fix bugs

commit 92548f6
Author: Anke Tang <[email protected]>
Date:   Fri Nov 1 15:34:51 2024 +0800

    update docs and test generation for llama taskpool

commit 7a4ec5d
Author: Anke Tang <[email protected]>
Date:   Fri Nov 1 15:27:04 2024 +0800

    update LlamaTestGenerationTaskPool

commit 04048f5
Author: Anke Tang <[email protected]>
Date:   Fri Nov 1 15:25:53 2024 +0800

    add LlamaTestGenerationTaskPool

commit f22ae6d
Author: tanganke <[email protected]>
Date:   Fri Nov 1 12:14:02 2024 +0800

    update functions related to llama models
    	new file:   fusion_bench/dataset/llama/openai.py
    	modified:   fusion_bench/modelpool/base_pool.py
    	modified:   fusion_bench/modelpool/causal_lm/causal_lm.py
    	new file:   fusion_bench/taskpool/llama/__init__.py
    	new file:   fusion_bench/taskpool/llama/test_generation.py

commit 378efd2
Author: tanganke <[email protected]>
Date:   Fri Nov 1 09:15:58 2024 +0800

    update
    	modified:   config/method/adamerging/llama_sft.yaml
    	modified:   fusion_bench/dataset/llama/alpaca.py

commit ea0cfa5
Author: tanganke <[email protected]>
Date:   Thu Oct 31 21:07:35 2024 +0800

    update

commit 5d3e2f8
Author: tanganke <[email protected]>
Date:   Thu Oct 31 20:57:44 2024 +0800

    remove unused files

commit aacdc97
Author: tanganke <[email protected]>
Date:   Thu Oct 31 20:56:00 2024 +0800

    update sharegpt.py

commit d0acaaa
Author: tanganke <[email protected]>
Date:   Thu Oct 31 20:26:00 2024 +0800

    add sharegpt.py

commit 37e1ab3
Author: tanganke <[email protected]>
Date:   Thu Oct 31 18:27:07 2024 +0800

    	new file:   fusion_bench/dataset/llama/alpaca.py

commit 67688c3
Author: tanganke <[email protected]>
Date:   Thu Oct 31 17:38:33 2024 +0800

    rename squad.py and wikitext.py

commit c53288c
Author: tanganke <[email protected]>
Date:   Thu Oct 31 17:35:28 2024 +0800

    add AdamW mixin

commit 2ec72e4
Author: tanganke <[email protected]>
Date:   Thu Oct 31 17:17:01 2024 +0800

    Add squad.py, update docs

commit aa3b318
Author: tanganke <[email protected]>
Date:   Thu Oct 31 16:03:40 2024 +0800

    optimize task_vector_violin_plot

commit 8a5806f
Author: tanganke <[email protected]>
Date:   Thu Oct 31 15:49:24 2024 +0800

    add TaskVectorViolinPlot
    	new file:   config/method/analysis/task_vector_violin_plot.yaml
    	new file:   fusion_bench/method/analysis/task_vector_violin_plot.py

commit 418a802
Author: tanganke <[email protected]>
Date:   Thu Oct 31 14:30:10 2024 +0800

    fix bugs

commit 855460a
Merge: 18717f0 b37a63a
Author: tanganke <[email protected]>
Date:   Thu Oct 31 14:12:07 2024 +0800

    Merge branch 'llama' of https://github.com/tanganke/fusion_bench into llama

commit b37a63a
Author: Anke Tang <[email protected]>
Date:   Thu Oct 31 11:34:58 2024 +0800

    Update llama_sft.yaml, llama_adamerging.py, layer_wise_fusion.py.

commit e306e38
Author: Anke Tang <[email protected]>
Date:   Thu Oct 31 11:10:23 2024 +0800

    refactor: rename BaseModelFusionAlgorithm to BaseAlgorithm

    **Key changes**:
    - Renamed BaseModelFusionAlgorithm to BaseAlgorithm for better clarity and consistency
    - Updated all references to BaseModelFusionAlgorithm across the codebase
    - Added BaseModelAlgorithm as an alias for BaseAlgorithm for backward compatibility
    - Added/improved docstrings for models/wrappers/ensemble.py
    - No functional changes - purely refactoring of class names

commit 8268d7a
Author: tanganke <[email protected]>
Date:   Wed Oct 30 20:50:39 2024 +0800

    update code

commit 18717f0
Author: tanganke <[email protected]>
Date:   Wed Oct 30 20:50:39 2024 +0800

    update code

commit be0442e
Author: tanganke <[email protected]>
Date:   Wed Oct 30 20:20:56 2024 +0800

    update config

commit 46ae4b6
Author: tanganke <[email protected]>
Date:   Wed Oct 30 20:20:56 2024 +0800

    update config

commit 3538534
Author: tanganke <[email protected]>
Date:   Wed Oct 30 20:16:55 2024 +0800

    update llama_sft adamerging

commit a4a6c89
Author: tanganke <[email protected]>
Date:   Wed Oct 30 20:16:55 2024 +0800

    update llama_sft adamerging

commit c68b7d2
Author: tanganke <[email protected]>
Date:   Wed Oct 30 14:52:22 2024 +0800

    update llama adamerging

commit a58cb45
Author: tanganke <[email protected]>
Date:   Wed Oct 30 10:49:37 2024 +0800

    update data for llama models

commit ccfc3c2
Author: tanganke <[email protected]>
Date:   Wed Oct 30 14:52:22 2024 +0800

    update llama adamerging

commit bcefd49
Author: tanganke <[email protected]>
Date:   Wed Oct 30 10:49:37 2024 +0800

    update data for llama models

commit a66f396
Author: tanganke <[email protected]>
Date:   Wed Oct 30 10:07:00 2024 +0800

    update data utils for llama models

commit d0b0884
Author: tanganke <[email protected]>
Date:   Wed Oct 30 09:54:05 2024 +0800

    add collator and template

commit d722c1a
Author: tanganke <[email protected]>
Date:   Wed Oct 30 10:07:00 2024 +0800

    update data utils for llama models

commit 9bad0eb
Author: tanganke <[email protected]>
Date:   Wed Oct 30 09:54:05 2024 +0800

    add collator and template

commit 71a9291
Author: Anke Tang <[email protected]>
Date:   Wed Oct 30 00:39:26 2024 +0800

    Update utils modified: fusion_bench/utils/devices.py modified: fusion_bench/utils/parameters.py

commit dcdeb1b
Author: Anke Tang <[email protected]>
Date:   Wed Oct 30 00:05:59 2024 +0800

    Enhance LLaMA model integration

commit d09ca14
Author: Anke Tang <[email protected]>
Date:   Wed Oct 30 00:39:26 2024 +0800

    Update utils
    	modified:   fusion_bench/utils/devices.py
    	modified:   fusion_bench/utils/parameters.py

commit a4cc970
Author: Anke Tang <[email protected]>
Date:   Wed Oct 30 00:05:59 2024 +0800

    Enhance LLaMA model integration

commit 68ee6be
Author: tanganke <[email protected]>
Date:   Tue Oct 29 21:38:08 2024 +0800

    partial update

commit 72c21bb
Author: tanganke <[email protected]>
Date:   Tue Oct 29 21:38:08 2024 +0800

    partial update

commit 7026685
Author: tanganke <[email protected]>
Date:   Tue Oct 29 20:13:32 2024 +0800

    add Dare simple_average
  • Loading branch information
tanganke committed Nov 1, 2024
1 parent 16dda8e commit e35649a
Show file tree
Hide file tree
Showing 14 changed files with 546 additions and 50 deletions.
2 changes: 1 addition & 1 deletion config/method/adamerging/llama_sft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ average_attntion: true
start_layer_idx: 0.3
# learning rate
optimizer: adam
lr: 1e-3
lr: 1e-4
init_values: 0.5
# if `clamp_weights` is true, the weights will be clamped to [0, 1]
clamp_weights: false
Expand Down
7 changes: 5 additions & 2 deletions config/method/analysis/task_vector_cos_similarity.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
_target_: fusion_bench.method.TaskVectorCosSimilarity
csv_save_path: null
plot_heatmap: false
plot_heatmap: true
trainable_only: true
max_points_per_model: null
output_path: null

7 changes: 7 additions & 0 deletions docs/taskpool/LlamaTestGenerationTaskPool.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# LlamaTestGenerationTaskPool

The `LlamaTestGenerationTaskPool` class is used to evaluate a language model on a set of prompts. It can also be used in an interactive mode for debugging purposes.

## References

::: fusion_bench.taskpool.llama.test_generation
37 changes: 34 additions & 3 deletions fusion_bench/dataset/llama/alpaca.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import List, Dict
from datasets import Dataset
from transformers import PreTrainedTokenizer
import logging
import os
from typing import Any, Dict, List, Optional

from datasets import Dataset, load_dataset, load_from_disk
from transformers import PreTrainedTokenizer

import fusion_bench

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -109,3 +113,30 @@ def prepare_samples(samples: Dict[str, List[str]]) -> Dict[str, List[List[int]]]
)

return tokenized_dataset


def load_tokenized_alpaca_dataset_from_json(
data_files: str,
tokenizer: PreTrainedTokenizer,
max_length: int,
split: Optional[str] = "train",
cache_path: Optional[str] = None,
):
if cache_path is not None and fusion_bench.utils.path.path_is_dir_and_not_empty(
cache_path
):
datasets = load_from_disk(cache_path)
if split is None:
return datasets
else:
return datasets[split]
else:
assert (
tokenizer is not None
), "Cached dataset not found. Need tokenizer to process the raw data."

dataset = load_dataset("json", data_files=data_files)
if split is not None:
dataset = dataset[split]
dataset = tokenize_alpaca_dataset(dataset, tokenizer, max_length=max_length)
return dataset
160 changes: 160 additions & 0 deletions fusion_bench/dataset/llama/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import logging
from typing import Dict, List

from datasets import Dataset
from transformers import PreTrainedTokenizer

log = logging.getLogger(__name__)


def tokenize_messages_dataset(
dataset: Dataset,
tokenizer: PreTrainedTokenizer,
max_length: int = 2048,
padding: bool = True,
system_template: str = "### System: {message}\n",
user_template: str = "## User: {message}\n",
assistant_template: str = "## Assistant: {message}\n",
) -> Dataset:
R"""
Tokenize dataset with messages format supporting loss calculation flags.
write a script to tokenizer datasets with the following format:
Examples:
```json
{
"messages": [
{
"role": "system",
"content": "XXX",
"calculate_loss": 0
},
{
"role": "system",
"content": "XXX",
"calculate_loss": 0
},
{
"role": "user",
"content": "XXX",
"calculate_loss": 0
},
{
"role": "assistant",
"content": "XXX",
"calculate_loss": 1
}
],
"create_info": [
{
"date": "20240830",
"owner": "l00470783",
"within_source_id": 0,
"describe": "...",
"source": [
"..."
],
"language": "zh"
}
],
"feature_info": {
"domain": "...",
"tags": [
"..."
]
},
"source_file": "..."
}
```
Args:
dataset: Input dataset with messages format
tokenizer: The tokenizer to use
max_length: Maximum sequence length
system_template: Template for system messages
user_template: Template for user messages
assistant_template: Template for assistant messages
Returns:
Tokenized dataset
"""

def build_prompt(messages: List[Dict[str, str]]) -> tuple[str, str]:
"""
Build prompt and get response that needs loss calculation.
Returns conversation history and the response to calculate loss on.
"""
history = ""
response = ""

for message in messages:
role = message["role"]
content = message["content"].strip()
calculate_loss = message.get("calculate_loss", 0)

# Build conversation history
if role == "system":
history += system_template.format(message=content)
elif role == "user":
history += user_template.format(message=content)
elif role == "assistant":
if calculate_loss:
# If this assistant message needs loss calculation,
# save it as response and don't add to history
response = content
else:
# Otherwise add to conversation history
history += assistant_template.format(message=content)

return history, response

def prepare_sample(sample: dict) -> dict:
# Get conversation history and response
history, response = build_prompt(sample["messages"])

# Tokenize prompt and response
prompt_tokens = tokenizer.encode(history, add_special_tokens=False)
response_tokens = tokenizer.encode(response, add_special_tokens=False)

# Create input_ids with EOS token
input_ids = prompt_tokens + response_tokens + [tokenizer.eos_token_id]

# Create attention mask
attention_mask = [1] * len(input_ids)

# Create labels: -100 for prompt, actual tokens for response
labels = (
[-100] * len(prompt_tokens) + response_tokens + [tokenizer.eos_token_id]
)

# Truncate if exceeds max length
if len(input_ids) > max_length:
input_ids = input_ids[:max_length]
attention_mask = attention_mask[:max_length]
labels = labels[:max_length]

# Pad if necessary
if padding:
padding_length = max_length - len(input_ids)
if padding_length > 0:
input_ids.extend([tokenizer.pad_token_id] * padding_length)
attention_mask.extend([0] * padding_length)
labels.extend([-100] * padding_length)

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}

if tokenizer.pad_token is None:
log.warning("Tokenizer does not have a `pad_token`. Set it the `eos_token`.")
tokenizer.pad_token = tokenizer.eos_token
# Process the dataset
tokenized_dataset = dataset.map(
prepare_sample, remove_columns=dataset.column_names, desc="Tokenizing dataset"
)

return tokenized_dataset
118 changes: 96 additions & 22 deletions fusion_bench/method/analysis/task_vector_cos_similarity.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,75 @@
import logging
import os
from typing import Dict, List, Optional, cast

import numpy as np
import pandas as pd
import torch
import torch.utils
from numpy.typing import NDArray
from torch import nn
from tqdm.auto import tqdm

from fusion_bench.method import BaseAlgorithm
from fusion_bench.mixins import LightningFabricMixin
from fusion_bench.modelpool import BaseModelPool
from fusion_bench.utils.parameters import (
StateDictType,
state_dict_to_vector,
trainable_state_dict,
)
from fusion_bench.utils.state_dict_arithmetic import state_dict_sub

log = logging.getLogger(__name__)


class TaskVectorCosSimilarity(BaseAlgorithm):
class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
"""
This class is similar to the Dummy algorithm,
but it also print (or save) the cosine similarity matrix between the task vectors of the models in the model pool.
"""

_config_mapping = BaseAlgorithm._config_mapping | {
"csv_save_path": "csv_save_path",
"plot_heatmap": "plot_heatmap",
"_output_path": "output_path",
}

def __init__(self, csv_save_path: str, plot_heatmap: bool, **kwargs):
self.csv_save_path = csv_save_path
def __init__(
self,
plot_heatmap: bool,
trainable_only: bool = True,
max_points_per_model: Optional[int] = None,
output_path: Optional[str] = None,
**kwargs,
):
self.plot_heatmap = plot_heatmap
self.trainable_only = trainable_only
self.max_points_per_model = max_points_per_model
self._output_path = output_path
super().__init__(**kwargs)

@property
def output_path(self):
if self._output_path is None:
return self.fabric.logger.log_dir
else:
return self._output_path

@torch.no_grad()
def run(self, modelpool: BaseModelPool):
pretrained_model = modelpool.load_model("_pretrained_")
pretrained_sd = torch.nn.utils.parameters_to_vector(
pretrained_model.parameters()
)

task_vectors = torch.empty(len(modelpool), pretrained_sd.size(0))
for model_idx, model_name in enumerate(modelpool.model_names):
model = modelpool.load_model(model_name)
model_sd = torch.nn.utils.parameters_to_vector(model.parameters())
task_vectors[model_idx] = model_sd - pretrained_sd
# convert the task vectors to float64
task_vectors = task_vectors.to(dtype=torch.float64)
pretrained_model = modelpool.load_pretrained_model()

task_vectors = []
for name, finetuned_model in tqdm(
modelpool.named_models(), total=len(modelpool)
):
print(f"computing task vectors for {name}")
task_vectors.append(
self.get_task_vector(pretrained_model, finetuned_model).to(
torch.float64
)
)
task_vectors = torch.stack(task_vectors, dim=0)

cos_sim_matrix = torch.zeros(
len(modelpool), len(modelpool), dtype=torch.float64
Expand All @@ -56,15 +90,18 @@ def run(self, modelpool: BaseModelPool):
)

print(cos_sim_df)
if self.csv_save_path is not None:
cos_sim_df.to_csv(self.csv_save_path)
if self.output_path is not None:
os.makedirs(self.output_path, exist_ok=True)
cos_sim_df.to_csv(
os.path.join(self.output_path, "task_vector_cos_similarity.csv")
)

if self.plot_heatmap:
self.plot_and_show_heatmap(self, cos_sim_df)
self._plot_heatmap(cos_sim_df)

return pretrained_model

def plot_and_show_heatmap(self, data: pd.DataFrame, figsize=(4, 3)):
def _plot_heatmap(self, data: pd.DataFrame):
"""
This function plots a heatmap of the provided data using seaborn.
Expand All @@ -79,7 +116,7 @@ def plot_and_show_heatmap(self, data: pd.DataFrame, figsize=(4, 3)):
import seaborn as sns

# Create a heatmap using seaborn
plt.figure(figsize=figsize)
plt.figure()
sns.heatmap(
data,
annot=True,
Expand All @@ -95,4 +132,41 @@ def plot_and_show_heatmap(self, data: pd.DataFrame, figsize=(4, 3)):
plt.yticks(rotation=45)

# Show plot
plt.show()
plt.savefig(
os.path.join(self.output_path, "task_vector_cos_similarity.pdf"),
bbox_inches="tight",
)
plt.close()

def get_task_vector(
self, pretrained_model: nn.Module, finetuned_model: nn.Module
) -> torch.Tensor:
task_vector = state_dict_sub(
self.get_state_dict(finetuned_model),
self.get_state_dict(pretrained_model),
)
task_vector = state_dict_to_vector(task_vector)

task_vector = task_vector.cpu().float().numpy()
# downsample if necessary
if (
self.max_points_per_model is not None
and self.max_points_per_model > 0
and task_vector.shape[0] > self.max_points_per_model
):
log.info(
f"Downsampling task vectors to {self.max_points_per_model} points."
)
indices = np.random.choice(
task_vector.shape[0], self.max_points_per_model, replace=False
)
task_vector = task_vector[indices].copy()

task_vector = torch.from_numpy(task_vector)
return task_vector

def get_state_dict(self, model: nn.Module):
if self.trainable_only:
return trainable_state_dict(model)
else:
return model.state_dict()
Loading

0 comments on commit e35649a

Please sign in to comment.