Add --validate-after-epochs training flag #5496

Open
mcognetta opened this issue May 3, 2024 · 0 comments

🚀 Feature Request

Add a --validate-after-epochs training flag that is a companion flag to --validate-after-updates.

Note: I already have a PR for this ready that I can contribute if this gets approved.

Motivation

When your task is configured to run validation after each epoch, --validate-after-updates can be difficult to use, since you might not know how many updates are in an epoch. This would add a companion flag that allows you to delay validation until N epochs have passed, without having to know in advance how many batches are included in a single epoch.

There is already precedent for parallel epoch-based and update-based validation flags (e.g., --validate-interval vs --validate-interval-updates), so this wouldn't be an unusual addition.

Pitch

Add a --validate-after-epochs flag to configs.py, next to the existing validation options:

    validate_interval: int = field(
        default=1, metadata={"help": "validate every N epochs"}
    )
    validate_interval_updates: int = field(
        default=0, metadata={"help": "validate every N updates"}
    )
    validate_after_updates: int = field(
        default=0, metadata={"help": "dont validate until reaching this many updates"}
    )

and to the validation check in fairseq_cli/train.py:

    do_validate = (
        (
            (not end_of_epoch and do_save)  # validate during mid-epoch saves
            or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
            or should_stop
            or (
                cfg.dataset.validate_interval_updates > 0
                and num_updates > 0
                and num_updates % cfg.dataset.validate_interval_updates == 0
            )
        )
        and not cfg.dataset.disable_validation
        and num_updates >= cfg.dataset.validate_after_updates
    )
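
To make the proposal concrete, here is a minimal standalone sketch of how the extra gate could look. It is not the prepared PR. `ValidationConfig` and `should_validate` are hypothetical stand-ins for the dataset config and the `do_validate` expression, and the `validate_after_epochs` field, its default of 0, and its help text are assumptions modeled on `validate_after_updates`. It also assumes the epoch counter is 1-based, so "N epochs have passed" is read as `epoch >= N` at the end of epoch N.

```python
from dataclasses import dataclass, field


@dataclass
class ValidationConfig:
    """Hypothetical stand-in for the validation fields of the dataset config."""

    validate_interval: int = field(
        default=1, metadata={"help": "validate every N epochs"}
    )
    validate_interval_updates: int = field(
        default=0, metadata={"help": "validate every N updates"}
    )
    validate_after_updates: int = field(
        default=0, metadata={"help": "dont validate until reaching this many updates"}
    )
    # Assumed shape of the proposed field; name and default are illustrative.
    validate_after_epochs: int = field(
        default=0, metadata={"help": "dont validate until reaching this many epochs"}
    )
    disable_validation: bool = False


def should_validate(
    cfg: ValidationConfig,
    epoch: int,
    num_updates: int,
    end_of_epoch: bool,
    do_save: bool = False,
    should_stop: bool = False,
) -> bool:
    """Mirror the do_validate expression, plus the proposed epoch gate."""
    triggered = (
        (not end_of_epoch and do_save)  # validate during mid-epoch saves
        or (end_of_epoch and epoch % cfg.validate_interval == 0)
        or should_stop
        or (
            cfg.validate_interval_updates > 0
            and num_updates > 0
            and num_updates % cfg.validate_interval_updates == 0
        )
    )
    return bool(
        triggered
        and not cfg.disable_validation
        and num_updates >= cfg.validate_after_updates
        # Proposed companion gate: skip validation until enough epochs have run.
        and epoch >= cfg.validate_after_epochs
    )
```

With the default of 0 the new gate is always satisfied, so existing configurations would behave as before under these assumptions.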

Alternatives

The workaround is to estimate how many batches are in an epoch, or to start a task, let it run for one update so that you can see the batches-per-epoch count, then restart it with the correct value set.
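
For reference, the arithmetic behind that workaround can be sketched as follows. Both helper names are hypothetical, and the sketch assumes you already know the per-worker number of batches in an epoch (which is exactly the value that is awkward to obtain up front) and that gradient accumulation groups `update_freq` batches into one update, with a final partial group still counting as an update.

```python
import math


def updates_per_epoch(num_batches: int, update_freq: int = 1) -> int:
    """Number of optimizer updates in one epoch (assumed grouping rule)."""
    return math.ceil(num_batches / update_freq)


def validate_after_updates_for(
    epochs: int, num_batches: int, update_freq: int = 1
) -> int:
    """Value to pass to --validate-after-updates to skip `epochs` epochs."""
    return epochs * updates_per_epoch(num_batches, update_freq)
```

For example, `validate_after_updates_for(3, 1000, update_freq=4)` gives the update count reached at the end of the third epoch, so validation would first run there.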

Additional context

I already have a PR prepared for this (it's like a 5 line change), but my understanding is that things like this need to be approved via issues first.
