
Commit

Clean deprecated max_samples arguments (#89)
kirill-fedyanin authored Jan 4, 2024
1 parent e316174 · commit 98fe28f
Showing 3 changed files with 4 additions and 30 deletions.
10 changes: 2 additions & 8 deletions scripts/run_dpo.py
@@ -173,10 +173,7 @@ def main():
     ###############
     train_result = dpo_trainer.train()
     metrics = train_result.metrics
-    max_train_samples = (
-        data_args.max_train_samples if data_args.max_train_samples is not None else len(raw_datasets["train"])
-    )
-    metrics["train_samples"] = min(max_train_samples, len(raw_datasets["train"]))
+    metrics["train_samples"] = len(raw_datasets["train"])
     dpo_trainer.log_metrics("train", metrics)
     dpo_trainer.save_metrics("train", metrics)
     dpo_trainer.save_state()
@@ -189,10 +186,7 @@ def main():
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
         metrics = dpo_trainer.evaluate()
-        max_eval_samples = (
-            data_args.max_eval_samples if data_args.max_eval_samples is not None else len(raw_datasets["test"])
-        )
-        metrics["eval_samples"] = min(max_eval_samples, len(raw_datasets["test"]))
+        metrics["eval_samples"] = len(raw_datasets["test"])
         dpo_trainer.log_metrics("eval", metrics)
         dpo_trainer.save_metrics("eval", metrics)

6 changes: 2 additions & 4 deletions scripts/run_sft.py
@@ -151,8 +151,7 @@ def main():
     logger.info("*** Train ***")
     train_result = trainer.train()
     metrics = train_result.metrics
-    max_train_samples = data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
-    metrics["train_samples"] = min(max_train_samples, len(train_dataset))
+    metrics["train_samples"] = len(train_dataset)
     trainer.log_metrics("train", metrics)
     trainer.save_metrics("train", metrics)
     trainer.save_state()
@@ -163,8 +162,7 @@ def main():
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
         metrics = trainer.evaluate()
-        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
-        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
+        metrics["eval_samples"] = len(eval_dataset)
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)

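Both scripts now report sample counts the same way. A minimal sketch of the shared metric-reporting flow, assuming the Hugging Face transformers Trainer API; the trainer and train_dataset names are illustrative:

train_result = trainer.train()
metrics = train_result.metrics
# The count is now simply the dataset length; the old
# min(max_train_samples, ...) guard went away with the argument.
metrics["train_samples"] = len(train_dataset)
trainer.log_metrics("train", metrics)   # formats and prints the metrics
trainer.save_metrics("train", metrics)  # writes train_results.json to output_dir
trainer.save_state()                    # saves trainer_state.json for resuming
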
18 changes: 0 additions & 18 deletions src/alignment/configs.py
@@ -197,24 +197,6 @@ class DataArguments:
         default_factory=lambda: ["train", "test"],
         metadata={"help": ("List of train test splits to use in the dataset")},
     )
-    max_train_samples: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": (
-                "For debugging purposes or quicker training, truncate the number of training examples to this "
-                "value if set."
-            )
-        },
-    )
-    max_eval_samples: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": (
-                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
-                "value if set."
-            )
-        },
-    )
     preprocessing_num_workers: Optional[int] = field(
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
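With max_train_samples and max_eval_samples removed from DataArguments, quick debugging runs can instead truncate a split before it reaches the trainer. A hedged sketch, assuming the Hugging Face datasets library; the dataset id and n_debug cap are placeholders, not part of this commit:

from datasets import load_dataset

raw_datasets = load_dataset("org/dataset-name")  # placeholder dataset id
n_debug = 1000  # illustrative cap, standing in for the removed max_train_samples
train_split = raw_datasets["train"]
# Dataset.select keeps only the rows at the given indices.
train_split = train_split.select(range(min(n_debug, len(train_split))))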
