
Commit

format code using isort and black
tanganke committed Nov 14, 2024
1 parent e222168 commit 071eef5
Showing 34 changed files with 129 additions and 63 deletions.
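
For readers unfamiliar with the two formatters: isort sorts imports and groups them into standard-library, third-party, and first-party blocks, while black normalizes layout (line wrapping, spacing, quote style). A small synthetic illustration of the conventions applied throughout this commit; the snippet below is not taken from the repository:

# Before (synthetic example): unordered, ungrouped imports.
import torch
import os
from fusion_bench.utils import instantiate
import logging

# After isort: alphabetized and grouped stdlib / third-party / first-party.
import logging
import os

import torch

from fusion_bench.utils import instantiate

# black additionally wraps over-long signatures and calls, as visible in the
# fisher_merging.py and embedding.py hunks further down this page.
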
1 change: 1 addition & 0 deletions fusion_bench/compat/taskpool/base_pool.py
@@ -14,6 +14,7 @@ class TaskPool:
config (DictConfig): The configuration for the task pool.
_all_task_names (List[str]): A list of all task names in the task pool.
"""

_program = None

def __init__(self, taskpool_config: DictConfig):
2 changes: 1 addition & 1 deletion fusion_bench/dataset/__init__.py
@@ -1,7 +1,7 @@
# flake8: noqa F401
from datasets import load_dataset
from omegaconf import DictConfig, open_dict

from datasets import load_dataset
from fusion_bench.utils import instantiate

from .clip_dataset import CLIPDataset
3 changes: 1 addition & 2 deletions fusion_bench/dataset/gpt2_glue.py
@@ -16,9 +16,8 @@
from pathlib import Path
from typing import Literal

from transformers import PreTrainedTokenizer

from datasets import load_dataset, load_from_disk
from transformers import PreTrainedTokenizer


def cache_dataset(
6 changes: 3 additions & 3 deletions fusion_bench/dataset/imdb.py
@@ -1,11 +1,11 @@
import logging
import os
from typing import Any, Dict, List, Optional

from datasets import load_dataset, load_from_disk
from transformers import PreTrainedTokenizer
from trl import SFTConfig, SFTTrainer

import fusion_bench
import os
import logging
from trl import SFTConfig, SFTTrainer

log = logging.getLogger(__name__)
6 changes: 3 additions & 3 deletions fusion_bench/dataset/llama/squad.py
@@ -1,11 +1,11 @@
from typing import Any, Dict, List, Optional, Literal
import logging
import os
from typing import Any, Dict, List, Literal, Optional

from datasets import load_dataset, load_from_disk
from transformers import PreTrainedTokenizer

import fusion_bench
import os
import logging

log = logging.getLogger(__name__)

4 changes: 2 additions & 2 deletions fusion_bench/dataset/llama/wikitext.py
@@ -1,11 +1,11 @@
import logging
import os
from typing import Any, Dict, List, Optional

from datasets import load_dataset, load_from_disk
from transformers import PreTrainedTokenizer

import fusion_bench
import os
import logging

log = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion fusion_bench/method/__init__.py
@@ -106,9 +106,9 @@
from .fisher_merging import FisherMergingForCLIPVisionModel
from .linear import (
ExPOAlgorithm,
ExPOAlgorithmForLlama,
LinearInterpolationAlgorithm,
SimpleAverageForLlama,
ExPOAlgorithmForLlama,
TaskArithmeticForLlama,
)
from .mixture_of_experts import (
4 changes: 2 additions & 2 deletions fusion_bench/method/adamerging/entropy_loss.py
@@ -17,9 +17,9 @@ def entropy_loss(logits: Tensor, eps: float = 1e-8) -> Tensor:
assert (
logits.dim() == 2
), f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}"

# Compute the softmax probabilities
probs = torch.softmax(logits, dim=-1)

# Compute the entropy loss
return -torch.sum(probs * torch.log(probs + eps), dim=-1).mean()
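
The hunk above contains the whole entropy computation; for reference, a self-contained version that runs on its own (only the imports are added here):

import torch
from torch import Tensor


def entropy_loss(logits: Tensor, eps: float = 1e-8) -> Tensor:
    # logits: (batch_size, num_classes)
    assert logits.dim() == 2, f"Expected 2 dimensions, found {logits.dim()}"
    probs = torch.softmax(logits, dim=-1)  # softmax probabilities
    return -torch.sum(probs * torch.log(probs + eps), dim=-1).mean()


# Example: uniform logits give the maximum entropy log(num_classes).
print(entropy_loss(torch.zeros(4, 10)))  # ~2.3026 = log(10)
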
2 changes: 1 addition & 1 deletion fusion_bench/method/dare/__init__.py
@@ -1,3 +1,3 @@
# flake8: noqa F401
from .task_arithmetic import DareTaskArithmetic
from .simple_average import DareSimpleAverage
from .task_arithmetic import DareTaskArithmetic
1 change: 1 addition & 0 deletions fusion_bench/method/dawe/warppers/__init__.py
@@ -8,5 +8,6 @@
Classes:
DataAdaptiveWeightEnsemblingModel: A class for data-adaptive weight ensembling.
"""

# flake8: noqa F401
from .dawe_model import DataAdaptiveWeightEnsemblingModel
4 changes: 3 additions & 1 deletion fusion_bench/method/fisher_merging/fisher_merging.py
@@ -43,7 +43,9 @@ def get_param_names_to_merge(
return param_names_to_merge


def get_param_squared_gradients(model: nn.Module, param_names_to_merge: List[str]) -> Dict[str, Tensor]:
def get_param_squared_gradients(
model: nn.Module, param_names_to_merge: List[str]
) -> Dict[str, Tensor]:
"""
Get the squared gradients of parameters.
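
Only the reformatted signature is visible above; the body is collapsed in this view. A plausible minimal sketch of such a helper, under the assumption that gradients have already been populated by a backward pass (this is illustrative, not the file's actual body):

from typing import Dict, List

from torch import Tensor, nn


def get_param_squared_gradients_sketch(
    model: nn.Module, param_names_to_merge: List[str]
) -> Dict[str, Tensor]:
    # Assumes loss.backward() has been called; collects grad**2 per selected
    # parameter, the per-parameter quantity accumulated into a diagonal Fisher estimate.
    return {
        name: param.grad.detach() ** 2
        for name, param in model.named_parameters()
        if name in param_names_to_merge and param.grad is not None
    }
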
1 change: 1 addition & 0 deletions fusion_bench/method/fisher_merging/gpt2_fisher_merging.py
@@ -39,6 +39,7 @@ class FisherMergingAlgorithmForGPT2(
batch_size (int): Batch size for data loading.
num_workers (int): Number of workers for data loading.
"""

classifiers = {}
modelpool: HuggingFaceGPT2ClassificationPool = None
_config_mapping = FisherMergingAlgorithm._config_mapping | {
15 changes: 15 additions & 0 deletions fusion_bench/method/lm_finetune/causal_lm_instruct.py
@@ -0,0 +1,15 @@
from fusion_bench import BaseAlgorithm
from fusion_bench.modelpool import CausalLMPool


class CausalLMInstructionFineTune(BaseAlgorithm):

def run(self, modelpool: CausalLMPool):
tokenizer = modelpool.load_tokenizer()
model = modelpool.load_model()
optimizer = modelpool.load_optimizer(model)
scheduler = modelpool.load_scheduler(optimizer)
dataloader = modelpool.load_dataloader(tokenizer)
model = modelpool.train_model(model, dataloader, optimizer, scheduler)
modelpool.save_model(model)
return modelpool.evaluate_model(model)
1 change: 1 addition & 0 deletions fusion_bench/method/pruning/llama_random_prune.py
@@ -83,6 +83,7 @@ class RandomPruningForLlama(BaseAlgorithm, SimpleProfilerMixin):
n (int): The number of weights to be pruned in each group (for semistructured pruning).
m (int): The total number of weights in each group (for semistructured pruning).
"""

_config_mapping = BaseAlgorithm._config_mapping | {
"prune_type": "prune_type",
"sparsity_ratio": "sparsity_ratio",
3 changes: 1 addition & 2 deletions fusion_bench/method/pruning/wanda_utils/data.py
@@ -3,12 +3,11 @@
import random
from typing import List, Optional, Tuple, cast # noqa: F401

from datasets import load_dataset
from torch import Tensor
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizer

from datasets import load_dataset


# Wrapper for tokenized input IDs
class TokenizerWrapper:
2 changes: 1 addition & 1 deletion fusion_bench/method/pruning/wanda_utils/eval.py
@@ -102,7 +102,7 @@ def eval_ppl_wikitext_train(model, trainloader, bs=1, device=None):


# Function to evaluate perplexity (ppl) specifically on the wikitext dataset
def eval_ppl_wikitext(model, testenc, bs : int =1, device=None):
def eval_ppl_wikitext(model, testenc, bs: int = 1, device=None):
"""
Evaluate perplexity (ppl) specifically on the wikitext dataset.
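
The change above is whitespace-only (black normalizes the bs: int = 1 annotation). For orientation, a hedged sketch of the kind of computation a WikiText perplexity evaluator performs, sliding a fixed-length window over a pre-tokenized corpus; the function and argument names below are illustrative, not the file's actual body:

import torch


@torch.no_grad()
def eval_ppl_sketch(model, input_ids: torch.Tensor, seqlen: int = 2048, device="cuda"):
    # input_ids: shape (1, num_tokens), the whole test corpus tokenized once.
    model = model.to(device).eval()
    nsamples = input_ids.numel() // seqlen
    nlls = []
    for i in range(nsamples):
        batch = input_ids[:, i * seqlen : (i + 1) * seqlen].to(device)
        # Hugging Face causal LMs return the mean token NLL as .loss.
        loss = model(input_ids=batch, labels=batch).loss
        nlls.append(loss.float() * seqlen)
    # Perplexity = exp(total NLL / total tokens).
    return torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen))
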
2 changes: 1 addition & 1 deletion fusion_bench/method/trust_region/utils.py
@@ -3,9 +3,9 @@

from torch import nn


# Model conversion utils


def state_dict_to_vector(state_dict, remove_keys=[]):
"""
Convert a state dictionary to a vector.
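
The docstring is visible above but the body is collapsed. A minimal sketch of what such a conversion usually does, assuming every retained entry is a tensor (names here are illustrative, not the file's actual code):

import torch


def state_dict_to_vector_sketch(state_dict, remove_keys=[]):
    # Drop excluded keys, then flatten the remaining tensors (in dict order)
    # into a single 1-D vector for task-vector style arithmetic.
    kept = {k: v for k, v in state_dict.items() if k not in remove_keys}
    return torch.cat([v.reshape(-1) for v in kept.values()])


# Illustrative use:
# vec = state_dict_to_vector_sketch(model.state_dict(), remove_keys=["lm_head.weight"])
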
1 change: 1 addition & 0 deletions fusion_bench/method/we_moe/clip_we_moe.py
@@ -35,6 +35,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
Attributes:
modelpool (CLIPVisionModelPool): The model pool containing the CLIP models.
"""

modelpool: CLIPVisionModelPool = None

def load_checkpoint(self, model, checkpoint):
1 change: 1 addition & 0 deletions fusion_bench/method/we_moe/we_moe.py
@@ -46,6 +46,7 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
modelpool (ModelPool): The pool of models to be fused.
profiler (SimpleProfiler): The profiler for measuring performance.
"""

_fabric: L.Fabric = None
modelpool: ModelPool = None

1 change: 1 addition & 0 deletions fusion_bench/metrics/text_to_image_generation/__init__.py
@@ -2,6 +2,7 @@
In this module, we implement some metrics for text-to-image generation tasks.
Including reward functions for alignment and Reinforcement Learning with Human Feedback training (RLHF).
"""

# flake8: noqa F401
from .aesthetic_scorer import aesthetic_scorer
from .compressibility import jpeg_compressibility_scorer, jpeg_incompressibility_scorer
4 changes: 2 additions & 2 deletions fusion_bench/mixins/optim/adamw_with_warmup.py
@@ -1,8 +1,8 @@
import math
import torch

from abc import abstractmethod

import torch


def warmup_cosine_schedule(warmup_steps: int, total_steps: int, min_lr: float = 0):
def lr_lambda(current_step):
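
The hunk above shows the reordered imports and the opening lines of warmup_cosine_schedule; the body is collapsed. A standard warmup-then-cosine multiplier of the kind this helper returns, offered as a sketch rather than the file's actual code:

import math


def warmup_cosine_schedule_sketch(warmup_steps: int, total_steps: int, min_lr: float = 0):
    def lr_lambda(current_step: int) -> float:
        if current_step < warmup_steps:
            # Linear warmup from 0 to 1.
            return current_step / max(1, warmup_steps)
        # Cosine decay from 1 down to min_lr over the remaining steps.
        progress = (current_step - warmup_steps) / max(1, total_steps - warmup_steps)
        return min_lr + 0.5 * (1 - min_lr) * (1.0 + math.cos(math.pi * min(1.0, progress)))

    return lr_lambda


# Typical use: torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_schedule_sketch(500, 10000))
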
2 changes: 1 addition & 1 deletion fusion_bench/models/llama/__init__.py
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .tokenizer_loader import load_config, load_tokenizer
from .model_utils.misc import find_all_linear_modules
from .tokenizer_loader import load_config, load_tokenizer
33 changes: 25 additions & 8 deletions fusion_bench/models/llama/model_utils/embedding.py
@@ -27,23 +27,30 @@
logger = logging.getLogger(__name__)


def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
def _noisy_mean_initialization(
embed_weight: "torch.Tensor", num_new_tokens: int
) -> None:
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight


def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
def resize_embedding_layer(
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"
) -> None:
r"""
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore

params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
if (
model.get_output_embeddings() is not None
and not model.config.tie_word_embeddings
):
params.append(model.get_output_embeddings().weight)

context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
@@ -58,13 +65,23 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
raise ValueError("Cannot resize embedding layers of a quantized model.")

if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
raise ValueError("Current model does not support resizing embedding layers.")
raise ValueError(
"Current model does not support resizing embedding layers."
)

model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)

logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
_noisy_mean_initialization(
model.get_input_embeddings().weight.data, num_new_tokens
)
_noisy_mean_initialization(
model.get_output_embeddings().weight.data, num_new_tokens
)

logger.info(
"Resized token embeddings from {} to {}.".format(
current_embedding_size, new_embedding_size
)
)
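
A hedged usage sketch of the helper reformatted above; the model and tokenizer choices below are illustrative and not part of this commit:

# Illustrative usage: add a new special token, then resize the embedding matrix
# so the new rows receive the noisy-mean initialization defined above.
from transformers import AutoModelForCausalLM, AutoTokenizer

from fusion_bench.models.llama.model_utils.embedding import resize_embedding_layer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

tokenizer.add_special_tokens({"pad_token": "<pad>"})
resize_embedding_layer(model, tokenizer)  # pads the vocab size to a multiple of 64
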
2 changes: 1 addition & 1 deletion fusion_bench/models/llama/model_utils/misc.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List
import logging
from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
3 changes: 1 addition & 2 deletions fusion_bench/models/llama/model_utils/visual.py
@@ -15,15 +15,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union

import torch
import transformers.models
from transformers.activations import ACT2FN
from transformers.utils import logging

import logging

if TYPE_CHECKING:
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel

4 changes: 1 addition & 3 deletions fusion_bench/models/llama/patcher.py
@@ -23,9 +23,7 @@

import torch
from peft import PeftModel
from transformers import (
PreTrainedTokenizerBase,
)
from transformers import PreTrainedTokenizerBase

from .model_utils.visual import (
get_image_seqlen,
6 changes: 1 addition & 5 deletions fusion_bench/models/llama/tokenizer_loader.py
@@ -24,11 +24,7 @@
from .patcher import patch_processor_, patch_tokenizer_

if TYPE_CHECKING:
from transformers import (
PretrainedConfig,
PreTrainedTokenizer,
ProcessorMixin,
)
from transformers import PretrainedConfig, PreTrainedTokenizer, ProcessorMixin


logger = logging.getLogger(__name__)