Skip to content

Commit

Permalink
update docs and docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
tanganke committed Nov 14, 2024
1 parent 69ea2da commit e222168
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/algorithms/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,5 @@ In summary, the Fusion Algorithm module is vital for the model merging operation

### References

::: fusion_bench.method.BaseAlgorithm
::: fusion_bench.method.BaseModelFusionAlgorithm
3 changes: 3 additions & 0 deletions fusion_bench/method/base_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,6 @@ def run(self, modelpool: BaseModelPool):


BaseModelFusionAlgorithm = BaseAlgorithm
"""
Alias for `BaseAlgorithm`.
"""
11 changes: 10 additions & 1 deletion fusion_bench/method/trust_region/clip_task_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,16 @@
log = logging.getLogger(__name__)


def trainable_state_dict(module: nn.Module):
def trainable_state_dict(module: nn.Module) -> StateDictType:
"""
Returns the state dictionary of the module containing only the trainable parameters.
Args:
module (nn.Module): The neural network module.
Returns:
Dict[str, Tensor]: A dictionary containing the names and values of the trainable parameters.
"""
return {
name: param for name, param in module.named_parameters() if param.requires_grad
}
Expand Down
42 changes: 39 additions & 3 deletions fusion_bench/utils/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,28 @@
# Model conversion utils


def trainable_state_dict(module: nn.Module):
return {
name: param for name, param in module.named_parameters() if param.requires_grad
def trainable_state_dict(
module: nn.Module,
prefix: str = "",
keep_vars: bool = False,
) -> StateDictType:
"""
Returns the state dictionary of the module containing only the trainable parameters.
Args:
module (nn.Module): The neural network module.
prefix (str, optional): The prefix to add to the parameter names. Defaults to "".
keep_vars (bool, optional): If True, the parameters are not detached. Defaults to False.
Returns:
Dict[str, Tensor]: A dictionary containing the names and values of the trainable parameters.
"""
state_dict = {
prefix + name: param if keep_vars else param.detach()
for name, param in module.named_parameters()
if param.requires_grad
}
return state_dict


def state_dict_to_vector(state_dict, remove_keys=[]):
Expand Down Expand Up @@ -78,6 +96,24 @@ def vector_to_state_dict(vector, state_dict, remove_keys=[]):


def human_readable(num: int) -> str:
"""
Converts a number into a human-readable string with appropriate magnitude suffix.
Examples:
```python
print(human_readable(1500))
# Output: '1.50K'
print(human_readable(1500000))
# Output: '1.50M'
```
Args:
num (int): The number to convert.
Returns:
str: The human-readable string representation of the number.
"""
if num < 1000 and isinstance(num, int):
return str(num)
magnitude = 0
Expand Down

0 comments on commit e222168

Please sign in to comment.