-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feat
] Add cache_dir support to CrossEncoder
#2784
Conversation
…CrossEncoder class for clarity
Hello! I'm trying to think whether it would also be sufficient to pass the Then we can avoid loading the model ourselves, and relying more on
|
Hi! Yep, you're right, it’s better to rely on Thank you for the feedback! Please let me know if there are any further adjustments needed. |
Nice! I think this is a good direction. Just like trust_remote_code: bool = False,
revision: Optional[str] = None,
local_files_only: bool = False, these would otherwise also need to be passed to all 3 configs. There's one difficult consideration that remains: the naming. @muellerzr would you be able to help me with that? Here's the setup:
I have concerns regarding keyword arguments:
And cache directory:
I'm considering a big (but soft) deprecation that introduces the same format for all parameters across the project. The old argument names would still work until e.g. v4.0.0, but they'd give a warning to upgrade to the new parameters instead. Perhaps with a decorator like this and this. I'd love to hear your thoughts @muellerzr. cc @osanseviero
|
Yeah, I totally agree with you. I've added the |
Much appreciated! Other than the naming things that I mentioned & will wait for Zach to get back to me on, I think this is almost ready to go :)
|
feat
] Add cache_dir support to CrossEncoder
Shouldn't be (too) breaking as they're far down the list of kwargs
Thanks a bunch! And my apologies for the delay. This will be included in the next release.
|
Support for
cache_dir
Argument to CrossEncoderDescription:
This pull request introduces the
cache_dir
argument to theCrossEncoder
class, enabling users to specify a directory for caching model files. This addition resolves the issue of inconsistent file locations when working withlocal_files_only=True
.Problem:
When initializing the
CrossEncoder
class withlocal_files_only=True
, the following methods download and cache files to different directories:AutoConfig.from_pretrained
AutoModelForSequenceClassification.from_pretrained
AutoTokenizer.from_pretrained
This inconsistency arises because there is no direct way to pass the
cache_dir
argument toAutoConfig.from_pretrained
, leading to config files being stored in theHF_HOME
cache directory and PyTorch model files being stored in the specifiedcache_dir
.Solution:
To address this issue, I utilized the
snapshot_download
function from thehuggingface_hub
library. This function allows for downloading a complete snapshot of a repo's files, including the model, config, and tokenizer files, and stores them in the specifiedcache_dir
. By passing thecache_dir
argument to theCrossEncoder
class, users can now ensure that all files are consistently stored in the same directory.Changes:
cache_dir
argument to theCrossEncoder
class__init__
method.snapshot_download
function to download the model files or get the path to the relevant snapshot from the cache.AutoConfig.from_pretrained
,AutoModelForSequenceClassification.from_pretrained
, andAutoTokenizer.from_pretrained
methods.Testing:
CrossEncoder
class correctly loads model files from the specifiedcache_dir
.cache_dir
argument, confirming that files are still downloaded to the default directories.