Add argument documentation to functions
Add missing argument explanations to various functions across multiple files.

* **`fusion_bench/compat/taskpool/clip_image_classification.py`**:
  - Add explanations for the `clip_model` and `model` arguments in the file's two `evaluate` methods.

* **`fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py`**:
  - Add explanation for `model` argument in `evaluate` function.

* **`fusion_bench/dataset/gsm8k.py`**:
  - Add explanation for `dataset_name` argument in `load_gsm8k_question_label_data` function.

* **`fusion_bench/dataset/nyuv2.py`**:
  - Add explanation for `index` argument in `__getitem__` function.

* **`fusion_bench/method/adamerging/task_wise_adamerging.py`**:
  - Add explanation for `module`, `batch`, and `task` arguments in `compute_logits` function.

* **`fusion_bench/method/dawe/dawe_for_clip.py`**:
  - Add explanation for `pretrained_model_name_or_path` argument in `load_resnet_processor` function.

* **`fusion_bench/method/dummy.py`**:
  - Add explanation for `modelpool` argument in `run` function.

* **`fusion_bench/method/ensemble.py`**:
  - Add explanation for `modelpool` argument in `run` function.

* **`fusion_bench/method/linear/expo.py`**:
  - Add explanation for `modelpool` argument in `run` function.

* **`fusion_bench/method/linear/linear_interpolation.py`**:
  - Add explanation for `modelpool` argument in `run` function.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/tanganke/fusion_bench?shareId=XXXX-XXXX-XXXX-XXXX).
tanganke committed Oct 31, 2024
1 parent e08c570 commit 82bdd3a
Showing 10 changed files with 94 additions and 2 deletions.
12 changes: 12 additions & 0 deletions fusion_bench/compat/taskpool/clip_image_classification.py
@@ -83,6 +83,12 @@ def test_loader(self):
    def evaluate(self, clip_model: CLIPModel):
        """
        Evaluate the model on the image classification task.

        Args:
            clip_model (CLIPModel): The CLIP model to evaluate.

        Returns:
            dict: A dictionary containing the evaluation results.
        """
        classifier = HFCLIPClassifier(
            clip_model=clip_model, processor=self._clip_processor
@@ -151,6 +157,12 @@ def load_task(self, task_name_or_config: str | DictConfig):
    def evaluate(self, model: CLIPVisionModel):
        """
        Evaluate the model on the image classification task.

        Args:
            model (CLIPVisionModel): The vision model to evaluate.

        Returns:
            dict: A dictionary containing the evaluation results for each task.
        """
        # if the fabric is not set, and we have a GPU, create a fabric instance
        if self._fabric is None and torch.cuda.is_available():
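As a quick illustration of the signature being documented here, a minimal usage sketch; the checkpoint name and the pre-built `taskpool` object are assumptions, since this hunk does not show how the task pool is configured:

```python
from transformers import CLIPModel


def evaluate_clip_checkpoint(taskpool, checkpoint: str = "openai/clip-vit-base-patch32") -> dict:
    """Load a CLIP checkpoint and hand it to the task pool's evaluate() method.

    `taskpool` is assumed to be an already-configured task pool exposing the
    evaluate(clip_model) method documented above; the checkpoint name is only
    an example.
    """
    clip_model = CLIPModel.from_pretrained(checkpoint)
    return taskpool.evaluate(clip_model)  # dict of evaluation results
```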
9 changes: 9 additions & 0 deletions fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py
@@ -149,6 +149,15 @@ def load_task(self, task_name_or_config: str | DictConfig):
            raise ValueError(f"Unknown task {task_config.name}")

    def evaluate(self, model: T5ForConditionalGeneration):
        """
        Evaluate the model on the FlanT5 GLUE text generation tasks.

        Args:
            model (T5ForConditionalGeneration): The model to evaluate.

        Returns:
            dict: A dictionary containing the evaluation results for each task.
        """
        if not isinstance(model, T5ForConditionalGeneration):
            log.warning(
                f"Model is not an instance of T5ForConditionalGeneration, but {type(model)}"
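A similar hedged usage sketch for this task pool; `taskpool`, the helper name, and the checkpoint are assumptions, only the `evaluate(model)` signature comes from the hunk above:

```python
from transformers import T5ForConditionalGeneration


def evaluate_flan_t5_checkpoint(taskpool, checkpoint: str = "google/flan-t5-base") -> dict:
    """Load a Flan-T5 checkpoint and evaluate it on the pool's GLUE text-generation tasks.

    Passing a non-T5 model only triggers the warning shown in the hunk above.
    """
    model = T5ForConditionalGeneration.from_pretrained(checkpoint)
    return taskpool.evaluate(model)
```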
3 changes: 3 additions & 0 deletions fusion_bench/dataset/gsm8k.py
@@ -16,6 +16,9 @@ def load_gsm8k_question_label_data(
        {'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
         'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'}

    Args:
        dataset_name (Literal["train", "test", "train_socratic", "test_socratic"]): The name of the dataset to load.

    Returns:
        questions (List[str]): List of questions.
        labels (List[float]): List of labels. For example, the label for the above example is `72.0`.
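Since the documented return pairs each question with a float label taken from the final `#### ` marker of the answer, a small self-contained sketch of that parsing step (the helper name is illustrative and not part of fusion_bench):

```python
def parse_gsm8k_label(answer: str) -> float:
    """Extract the final numeric answer that follows the '#### ' marker.

    Mirrors the documented behaviour where the example answer above maps to 72.0.
    """
    final = answer.split("####")[-1].strip().replace(",", "")
    return float(final)


example_answer = (
    "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\n"
    "Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n"
    "#### 72"
)
assert parse_gsm8k_label(example_answer) == 72.0
```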
9 changes: 9 additions & 0 deletions fusion_bench/dataset/nyuv2.py
@@ -69,6 +69,15 @@ def __init__(
        self.noise = torch.rand(self.data_len, 1, 288, 384)

    def __getitem__(self, index):
        """
        Retrieve an item from the dataset.

        Args:
            index (int): The index of the item to retrieve.

        Returns:
            tuple: A tuple containing the image and a dictionary of task-specific outputs.
        """
        # load data from the pre-processed npy files
        image = torch.from_numpy(
            np.moveaxis(
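A hedged sketch of consuming the documented (image, task-dict) return through a DataLoader; the helper and the already-constructed `dataset` argument are assumptions, since the NYUv2 constructor arguments are not shown in this hunk:

```python
from torch.utils.data import DataLoader, Dataset


def peek_first_batch(dataset: Dataset, batch_size: int = 8) -> dict:
    """Collate one batch and report tensor shapes for the image and each task target.

    Relies only on the documented return structure: a tuple of (image, dict of
    task-specific outputs); everything else about the dataset is assumed.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    images, targets = next(iter(loader))
    shapes = {"image": tuple(images.shape)}
    shapes.update({task: tuple(t.shape) for task, t in targets.items()})
    return shapes
```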
13 changes: 12 additions & 1 deletion fusion_bench/method/adamerging/task_wise_adamerging.py
@@ -112,7 +112,18 @@ def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        pass

    @abstractmethod
-    def compute_logits(self, module, batch, task) -> Tensor:
+    def compute_logits(self, module: nn.Module, batch, task: str) -> Tensor:
        """
        Compute the logits for the given batch and task.

        Args:
            module (nn.Module): The model module.
            batch (tuple): A batch of input data.
            task (str): The name of the task.

        Returns:
            Tensor: The classification logits for the batch.
        """
        pass

    def test_time_adaptation(self, module: TaskWiseMergedModel):
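Since `compute_logits` is the abstract hook a concrete AdaMerging class must override, here is a hedged sketch of what such an override could look like for a plain image classifier; the class name, the (images, labels) batch layout, and the assumption that the merged module maps images directly to logits are all illustrative rather than taken from this hunk:

```python
from torch import Tensor, nn


class ToyImageClassificationAdaMerging:
    """Stand-in for a concrete subclass of the abstract algorithm above."""

    def compute_logits(self, module: nn.Module, batch, task: str) -> Tensor:
        # Assumes each batch is an (images, labels) tuple and that the merged
        # module maps images straight to classification logits for `task`.
        images, _labels = batch
        return module(images)
```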
9 changes: 9 additions & 0 deletions fusion_bench/method/dawe/dawe_for_clip.py
@@ -38,6 +38,15 @@ def convert_to_rgb(image: Image | list[Image]) -> Image | list[Image]:


def load_resnet_processor(pretrained_model_name_or_path: str):
    """
    Load a ResNet processor for image preprocessing.

    Args:
        pretrained_model_name_or_path (str): The path or name of the pretrained ResNet model.

    Returns:
        function: A function that processes images using the ResNet processor.
    """
    processor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)
    return lambda img: processor(
        images=convert_to_rgb(img), return_tensors="pt", do_rescale=False
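A hedged usage sketch for the documented wrapper; the import path is inferred from the file location, and the "microsoft/resnet-50" checkpoint and single-image flow are assumptions:

```python
from PIL import Image

from fusion_bench.method.dawe.dawe_for_clip import load_resnet_processor  # path inferred from the diff


def preprocess_image(image_path: str, checkpoint: str = "microsoft/resnet-50"):
    """Run one image through the callable returned by load_resnet_processor.

    The callable yields a dict of PyTorch tensors (e.g. pixel_values), as
    implied by return_tensors="pt" in the hunk above.
    """
    process = load_resnet_processor(checkpoint)
    return process(Image.open(image_path))
```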
3 changes: 3 additions & 0 deletions fusion_bench/method/dummy.py
@@ -18,6 +18,9 @@ def run(self, modelpool: BaseModelPool):
        This method returns the pretrained model from the model pool.
        If the pretrained model is not available, it returns the first model from the model pool.

        Args:
            modelpool (BaseModelPool): The pool of models to fuse.

        Raises:
            AssertionError: If the model is not found in the model pool.
        """
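The documented behaviour boils down to a simple fallback rule; a generic restatement using a plain dict in place of the real BaseModelPool (the key name "pretrained" is illustrative):

```python
from torch import nn


def select_base_model(models: "dict[str, nn.Module]") -> nn.Module:
    """Prefer the pretrained entry if present, otherwise take the first
    registered model; this mirrors the documented fallback over a plain dict
    rather than fusion_bench's BaseModelPool."""
    if "pretrained" in models:
        return models["pretrained"]
    return next(iter(models.values()))
```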
27 changes: 27 additions & 0 deletions fusion_bench/method/ensemble.py
@@ -19,6 +19,15 @@
class SimpleEnsembleAlgorithm(BaseModelFusionAlgorithm):
    @torch.no_grad()
    def run(self, modelpool: BaseModelPool | List[nn.Module]):
        """
        Run the simple ensemble algorithm on the given model pool.

        Args:
            modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.

        Returns:
            EnsembleModule: The ensembled model.
        """
        log.info(f"Running ensemble algorithm with {len(modelpool)} models")

        models = [modelpool.load_model(m) for m in modelpool.model_names]
@@ -40,6 +49,15 @@ def __init__(self, normalize: bool, weights: List[float], **kwargs):

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool | List[nn.Module]):
        """
        Run the weighted ensemble algorithm on the given model pool.

        Args:
            modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.

        Returns:
            WeightedEnsembleModule: The weighted ensembled model.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(models=modelpool)
@@ -61,6 +79,15 @@ def run(self, modelpool: BaseModelPool | List[nn.Module]):
class MaxModelPredictorAlgorithm(BaseModelFusionAlgorithm):
    @torch.no_grad()
    def run(self, modelpool: BaseModelPool | List[nn.Module]):
        """
        Run the max model predictor algorithm on the given model pool.

        Args:
            modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.

        Returns:
            MaxModelPredictor: The max model predictor ensembled model.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(models=modelpool)
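To make the weighted variant concrete, a hedged stand-alone sketch of the idea in plain PyTorch; this illustrates weighted output averaging and is not the actual EnsembleModule / WeightedEnsembleModule implementations returned by these algorithms:

```python
from typing import List

import torch
from torch import nn


class ToyWeightedEnsemble(nn.Module):
    """Combine member predictions with fixed weights (optionally normalized to sum to 1)."""

    def __init__(self, models: List[nn.Module], weights: List[float], normalize: bool = True):
        super().__init__()
        self.models = nn.ModuleList(models)
        w = torch.as_tensor(weights, dtype=torch.float32)
        self.register_buffer("weights", w / w.sum() if normalize else w)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        stacked = torch.stack([m(x) for m in self.models])  # (n_models, batch, ...)
        w = self.weights.view(-1, *([1] * (stacked.dim() - 1)))
        return (w * stacked).sum(dim=0)
```

An unweighted ensemble corresponds to equal weights; the max-predictor variant presumably replaces the weighted sum with a maximum over member predictions, though its exact reduction is not shown in this diff.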
9 changes: 9 additions & 0 deletions fusion_bench/method/linear/expo.py
@@ -34,6 +34,15 @@ def __init__(self, extrapolation_factor: float, **kwargs):
        super().__init__(**kwargs)

    def run(self, modelpool: BaseModelPool):
        """
        Run the ExPO merge algorithm.

        Args:
            modelpool (BaseModelPool): The pool of models to merge.

        Returns:
            nn.Module: The merged model.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)
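For context, a hedged sketch of the extrapolation step that gives ExPO its name, written over plain state dicts; which pool entry plays the base versus the aligned role, and the exact update convention used by this class, are assumptions not shown in the hunk:

```python
import torch


@torch.no_grad()
def expo_extrapolate(base_sd: dict, aligned_sd: dict, extrapolation_factor: float) -> dict:
    """Push the aligned weights further away from the base weights:
    merged = aligned + factor * (aligned - base). Illustrative only."""
    return {
        name: aligned_sd[name] + extrapolation_factor * (aligned_sd[name] - base_sd[name])
        for name in aligned_sd
    }
```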
2 changes: 1 addition & 1 deletion fusion_bench/method/linear/linear_interpolation.py
@@ -39,7 +39,7 @@ def run(self, modelpool: BaseModelPool):
        and returns a model with the interpolated state dict.

        Args:
-            modelpool (BaseModelPool): The pool of models to interpolate.
+            modelpool (BaseModelPool): The pool of models to interpolate. Must contain exactly two models.

        Returns:
            nn.Module: The model with the interpolated state dict.
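A hedged sketch of the interpolation itself over two state dicts, matching the documented two-model requirement; the helper and the parameter name `t` are illustrative, and the algorithm's actual config key for the interpolation weight is not shown here:

```python
import torch
from torch import nn


@torch.no_grad()
def interpolate_state_dicts(model_a: nn.Module, model_b: nn.Module, t: float) -> dict:
    """Return {name: (1 - t) * a + t * b} for every entry in the first model's
    state dict, which is assumed to match the second model's."""
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    return {name: (1.0 - t) * sd_a[name] + t * sd_b[name] for name in sd_a}
```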
