Merge pull request #28 from tanganke/add-docstrings
Add docstrings to some classes and methods
tanganke authored Nov 13, 2024
2 parents c7ca936 + 64c1a21 commit b894b2b
Showing 2 changed files with 103 additions and 0 deletions.
3 changes: 3 additions & 0 deletions fusion_bench/compat/method/base_algorithm.py
@@ -42,6 +42,9 @@ def run(self, modelpool):
        Args:
            modelpool: The pool of models to fuse.

        Returns:
            The fused model.

        Examples:
            >>> algorithm = SimpleAverageAlgorithm()
            >>> modelpool = ModelPool()
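The docstring's example is cut off by the collapsed diff. For context, a minimal sketch of the documented call pattern (the import paths are assumptions; this diff shows only the docstring, not the surrounding modules):

# Hedged sketch of the `run` call pattern documented above; the import
# paths are assumed, as the diff does not show them.
from fusion_bench.method import SimpleAverageAlgorithm
from fusion_bench.compat.modelpool import ModelPool  # assumed location

algorithm = SimpleAverageAlgorithm()
modelpool = ModelPool()  # in practice, populated from a fusion_bench config
fused_model = algorithm.run(modelpool)  # run() returns the fused model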
100 changes: 100 additions & 0 deletions fusion_bench/dataset/gpt2_glue.py
@@ -99,14 +99,42 @@ def qqp_tokenize_function(examples, tokenizer):


class TokenizedGLUE:
    """
    A class to load and cache GLUE datasets for GPT-2 models.

    This class provides methods to load various GLUE datasets and tokenize them
    using a provided tokenizer. The datasets are cached to disk to avoid
    reloading and tokenizing them multiple times.

    Attributes:
        tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenizing the datasets.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer):
        """
        Initialize the TokenizedGLUE class with a tokenizer.

        Args:
            tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenizing the datasets.
        """
        super().__init__()
        self.tokenizer = tokenizer

    def load_dataset(
        self, name: Literal["mrpc", "mnli", "cola", "sst2", "qnli", "qqp", "rte"]
    ):
        """
        Load and tokenize a GLUE dataset.

        This method loads a specified GLUE dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Args:
            name (Literal["mrpc", "mnli", "cola", "sst2", "qnli", "qqp", "rte"]): The name of the GLUE dataset to load.

        Returns:
            Dataset: The tokenized GLUE dataset.
        """
        glue_dataset_loaders = {
            "mrpc": self.load_mrpc_dataset,
            "mnli": self.load_mnli_dataset,
@@ -121,6 +149,15 @@ def load_dataset(

    @cache_dataset
    def load_mrpc_dataset(self):
        """
        Load and tokenize the MRPC dataset.

        This method loads the MRPC dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Returns:
            Dataset: The tokenized MRPC dataset.
        """
        dataset = load_dataset("glue", "mrpc")
        dataset = dataset.map(
            partial(mrpc_tokenize_function, tokenizer=self.tokenizer),
@@ -131,6 +168,15 @@ def load_mrpc_dataset(self):

    @cache_dataset
    def load_rte_dataset(self):
        """
        Load and tokenize the RTE dataset.

        This method loads the RTE dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Returns:
            Dataset: The tokenized RTE dataset.
        """
        dataset = load_dataset("glue", "rte")
        dataset = dataset.map(
            # RTE has the same format as MRPC
@@ -142,6 +188,15 @@ def load_rte_dataset(self):

    @cache_dataset
    def load_wnli_dataset(self):
        """
        Load and tokenize the WNLI dataset.

        This method loads the WNLI dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Returns:
            Dataset: The tokenized WNLI dataset.
        """
        dataset = load_dataset("glue", "wnli")
        dataset = dataset.map(
            partial(mrpc_tokenize_function, tokenizer=self.tokenizer),
@@ -152,6 +207,15 @@ def load_wnli_dataset(self):

    @cache_dataset
    def load_qqp_dataset(self):
        """
        Load and tokenize the QQP dataset.

        This method loads the QQP dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Returns:
            Dataset: The tokenized QQP dataset.
        """
        dataset = load_dataset("glue", "qqp")
        dataset = dataset.map(
            partial(qqp_tokenize_function, tokenizer=self.tokenizer),
@@ -162,6 +226,15 @@ def load_qqp_dataset(self):

    @cache_dataset
    def load_mnli_dataset(self):
        """
        Load and tokenize the MNLI dataset.

        This method loads the MNLI dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Returns:
            Dataset: The tokenized MNLI dataset.
        """
        dataset = load_dataset("glue", "mnli")
        dataset = dataset.map(
            partial(mnli_tokenize_function, tokenizer=self.tokenizer),
@@ -172,6 +245,15 @@ def load_mnli_dataset(self):

    @cache_dataset
    def load_cola_dataset(self):
        """
        Load and tokenize the CoLA dataset.

        This method loads the CoLA dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Returns:
            Dataset: The tokenized CoLA dataset.
        """
        dataset = load_dataset("glue", "cola")
        dataset = dataset.map(
            partial(cola_tokenize_function, tokenizer=self.tokenizer),
@@ -182,6 +264,15 @@ def load_cola_dataset(self):

    @cache_dataset
    def load_sst2_dataset(self):
        """
        Load and tokenize the SST-2 dataset.

        This method loads the SST-2 dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Returns:
            Dataset: The tokenized SST-2 dataset.
        """
        dataset = load_dataset("glue", "sst2")
        dataset = dataset.map(
            partial(cola_tokenize_function, tokenizer=self.tokenizer),
@@ -192,6 +283,15 @@ def load_sst2_dataset(self):

    @cache_dataset
    def load_qnli_dataset(self):
        """
        Load and tokenize the QNLI dataset.

        This method loads the QNLI dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Returns:
            Dataset: The tokenized QNLI dataset.
        """
        dataset = load_dataset("glue", "qnli")
        dataset = dataset.map(
            partial(qnli_tokenize_function, tokenizer=self.tokenizer),
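Taken together, the new docstrings document one pattern: construct TokenizedGLUE with a tokenizer, then call load_dataset with a task name to get a tokenized, disk-cached dataset. A minimal usage sketch (the pad-token line is an assumption; GPT-2's tokenizer ships without a pad token, and padded batching generally needs one):

from transformers import AutoTokenizer

from fusion_bench.dataset.gpt2_glue import TokenizedGLUE

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Reusing EOS as the pad token is a common GPT-2 workaround
# (an assumption here, not something this diff prescribes).
tokenizer.pad_token = tokenizer.eos_token

glue = TokenizedGLUE(tokenizer)
mrpc = glue.load_dataset("mrpc")  # tokenized once, then cached to disk
print(mrpc["train"][0].keys())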

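The @cache_dataset decorator that every loader relies on sits outside the visible hunks. Purely as an illustrative sketch of the caching behavior the docstrings describe (the cache location and naming scheme are assumptions; the real decorator may differ):

import functools
import os

from datasets import load_from_disk

def cache_dataset(func, cache_dir="outputs/cache"):
    """Illustrative sketch: cache the dataset returned by `func` on disk."""

    @functools.wraps(func)
    def wrapper(self):
        path = os.path.join(cache_dir, func.__name__)
        if os.path.exists(path):
            return load_from_disk(path)  # reuse the cached copy
        dataset = func(self)
        dataset.save_to_disk(path)  # cache for subsequent runs
        return dataset

    return wrapper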