Convert _VALID_DICT_FIELDS to class attribute for shared dict parsing in subclasses #36736
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the |
Hi @SunMarc, any suggestions for this? |
Since this is Trainer/TrainingArguments, leaving it to @SunMarc and @muellerzr! |
c3e825a to f6beb59
Thanks, if it alleviates a painpoint for TRL/other downstream libs, I don't see an issue here!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
f6beb59 to bb2d303
@SunMarc, kindly request for review🤗 |
SGTM !
00e159b to 66f36a8
Hi all, kindly request for merging |
Hi @Tavish9, we're rebasing/merging your branch onto main to get tests to pass, because of CI issues on the previous commit that your branch was based off. As soon as we can get the CI to pass, we'll merge your branch into |
Hi @Tavish9 after rebasing, I think some of the failing tests are actually caused by this PR! Can you check the tests in |
OK, I'm glad to help with the tests. |
Hi, @Rocketknight1, I found that the original logic in tests_hub is wrong. I will fix them all.🤗 |
Small testcase:

    import unittest
    from dataclasses import dataclass, field
    from typing import Optional, Union, get_origin

    from transformers import HfArgumentParser


    @dataclass
    class Test:
        accelerator_config: Optional[Union[dict, str]] = field(
            default=None,
            metadata={
                "help": (
                    "Config to be used with the internal Accelerator object initialization. The value is either a "
                    "accelerator json config file (e.g., `accelerator_config.json`) or an already loaded json file as `dict`."
                )
            },
        )


    class HfArgumentParserTest(unittest.TestCase):
        def test_a(self):
            # First find any annotations that contain `dict`
            fields = Test.__dataclass_fields__
            print(fields["accelerator_config"].type)
            parser = HfArgumentParser(Test)

        def test_b(self):
            # First find any annotations that contain `dict`
            fields = Test.__dataclass_fields__
            print(fields["accelerator_config"].type)
            print(get_origin(fields["accelerator_config"].type))

Output:

    typing.Union[dict, str, NoneType]
    .<class 'str'>
    None

After We either, move |
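For readers following the test case above, here is a minimal standard-library sketch of how one might check whether an annotation such as Optional[Union[dict, str]] admits dict. The helper name annotation_allows_dict is hypothetical and is not taken from transformers; it only illustrates what get_origin and get_args report for these annotations.

```python
from typing import Optional, Union, get_args, get_origin


def annotation_allows_dict(annotation) -> bool:
    """Return True if `dict` appears in the annotation (hypothetical helper)."""
    # Bare `dict` or a parameterized form such as Dict[str, int].
    if annotation is dict or get_origin(annotation) is dict:
        return True
    # For a Union, check each member recursively.
    if get_origin(annotation) is Union:
        return any(annotation_allows_dict(arg) for arg in get_args(annotation))
    return False


annotation = Optional[Union[dict, str]]
print(annotation_allows_dict(annotation))  # the Union contains dict
print(annotation_allows_dict(str))         # a plain str annotation does not
print(get_origin(str))                     # plain classes have no origin
```

A check like this only gives the expected answer while the field still carries its original Union annotation; once the stored type is a plain str, get_origin returns None, which matches the third line of the output in the test case.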
744b304 to 013f74a
013f74a to 3d3581e
3d3581e to 863c289
cc @SunMarc @muellerzr can you take another look now that the tests have been updated as well? You can merge if you're happy |
Thanks !
Please check the tests carefully, as we are not aligned with the testing pipeline. Merge if you think there is no problem :) |
Any updates? @SunMarc, @Rocketknight1 |
Good for me regarding the test. A tiny question: would the change break current use cases? I see you mentioned
Would that code have to access the module-level |
Potentially but this is a private attribute so it shouldn't be too big of an issue |
Good question! However, from a naming perspective (semi-private), it doesn't make sense for a subclass to access the parent class's _VALID_DICT_FIELDS. That said, the current logic can be implemented with the following code:

    from transformers.training_args import _VALID_DICT_FIELDS
    _VALID_DICT_FIELDS.append("my_dict_field")

It works, but I think no one does that. You can also refer to this discussion for reference. |
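As a side note on the append pattern discussed above, the following plain-Python sketch (no transformers import; the classes Parent, OverridingChild and AppendingChild are hypothetical stand-ins) shows one reason in-place mutation of a shared list is fragile: the change is visible to every class that shares the list, whereas a subclass that defines its own list leaves the parent untouched.

```python
class Parent:
    # Stand-in for a base class holding the list of dict-like field names.
    _VALID_DICT_FIELDS = ["accelerator_config"]


class OverridingChild(Parent):
    # Builds a new list, so Parent's list is not modified.
    _VALID_DICT_FIELDS = Parent._VALID_DICT_FIELDS + ["my_dict_field"]


class AppendingChild(Parent):
    pass


parent_before = list(Parent._VALID_DICT_FIELDS)

# Appending through the subclass mutates the one list object shared with Parent.
AppendingChild._VALID_DICT_FIELDS.append("another_field")

print(parent_before)                       # snapshot taken before the append
print(Parent._VALID_DICT_FIELDS)           # now also contains "another_field"
print(OverridingChild._VALID_DICT_FIELDS)  # unaffected by the append
```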
…ng in subclasses (huggingface#36736) * make _VALID_DICT_FIELDS as a class attribute * fix test case about TrainingArguments
What does this PR do?
This PR refactors the _VALID_DICT_FIELDS attribute into a class member variable within the base class, enabling subclasses of the training argument class to inherit and reuse a unified dictionary parsing logic in their __post_init__ methods.
Subclasses no longer need to reimplement repetitive __post_init__ logic for dictionary fields.
Fixes huggingface/trl#3082
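The idea can be sketched with plain dataclasses. This is a simplified illustration under stated assumptions, not the actual transformers implementation: BaseArgs and ChildArgs are hypothetical stand-ins for the training-argument base class and a downstream subclass, my_dict_field is an illustrative field name, and the JSON-string check is an assumed parsing rule.

```python
import json
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class BaseArgs:
    accelerator_config: Optional[Union[dict, str]] = None

    # Class attribute: it has no annotation, so dataclass does not turn it into a field.
    _VALID_DICT_FIELDS = ["accelerator_config"]

    def __post_init__(self):
        # Shared parsing: convert JSON strings to dicts for every listed field.
        for name in self._VALID_DICT_FIELDS:
            value = getattr(self, name)
            if isinstance(value, str) and value.lstrip().startswith("{"):
                setattr(self, name, json.loads(value))


@dataclass
class ChildArgs(BaseArgs):
    my_dict_field: Optional[Union[dict, str]] = None

    # Extend the inherited list instead of rewriting __post_init__.
    _VALID_DICT_FIELDS = BaseArgs._VALID_DICT_FIELDS + ["my_dict_field"]


args = ChildArgs(
    accelerator_config='{"split_batches": true}',
    my_dict_field='{"a": 1}',
)
print(args.accelerator_config)
print(args.my_dict_field)
```

Because the loop reads self._VALID_DICT_FIELDS, the subclass only declares which extra fields hold dictionaries and inherits the parsing unchanged.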
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@muellerzr and @SunMarc