Skip to content

Commit

Permalink
Merge branch 'main' into cai
Browse files Browse the repository at this point in the history
  • Loading branch information
vwxyzjn authored Feb 1, 2024
2 parents 45f90a3 + 5ad6db0 commit b364259
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/alignment/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from accelerate import Accelerator
from huggingface_hub import list_repo_files
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from peft import LoraConfig, PeftConfig

Expand Down Expand Up @@ -106,7 +107,7 @@ def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
try:
# Try first if model on a Hub repo
repo_files = list_repo_files(model_name_or_path, revision=revision)
except HFValidationError:
except (HFValidationError, RepositoryNotFoundError):
# If not, check local repo
repo_files = os.listdir(model_name_or_path)
return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files
Expand Down
16 changes: 15 additions & 1 deletion tests/test_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
import torch
from transformers import AutoTokenizer

from alignment import DataArguments, ModelArguments, get_peft_config, get_quantization_config, get_tokenizer
from alignment import (
DataArguments,
ModelArguments,
get_peft_config,
get_quantization_config,
get_tokenizer,
is_adapter_model,
)
from alignment.data import DEFAULT_CHAT_TEMPLATE


Expand Down Expand Up @@ -88,3 +95,10 @@ def test_no_peft_config(self):
model_args = ModelArguments(use_peft=False)
peft_config = get_peft_config(model_args)
self.assertIsNone(peft_config)


class IsAdapterModelTest(unittest.TestCase):
def test_is_adapter_model_calls_listdir(self):
# Assert that for an invalid repo name it gets to the point where it calls os.listdir,
# which is expected to raise a FileNotFoundError
self.assertRaises(FileNotFoundError, is_adapter_model, "nonexistent/model")

0 comments on commit b364259

Please sign in to comment.