diff --git a/scripts/run_dpo.py b/scripts/run_dpo.py
index 06189674..fbd084e5 100644
--- a/scripts/run_dpo.py
+++ b/scripts/run_dpo.py
@@ -173,10 +173,7 @@ def main():
     ###############
     train_result = dpo_trainer.train()
     metrics = train_result.metrics
-    max_train_samples = (
-        data_args.max_train_samples if data_args.max_train_samples is not None else len(raw_datasets["train"])
-    )
-    metrics["train_samples"] = min(max_train_samples, len(raw_datasets["train"]))
+    metrics["train_samples"] = len(raw_datasets["train"])
     dpo_trainer.log_metrics("train", metrics)
     dpo_trainer.save_metrics("train", metrics)
     dpo_trainer.save_state()
@@ -189,10 +186,7 @@ def main():
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
         metrics = dpo_trainer.evaluate()
-        max_eval_samples = (
-            data_args.max_eval_samples if data_args.max_eval_samples is not None else len(raw_datasets["test"])
-        )
-        metrics["eval_samples"] = min(max_eval_samples, len(raw_datasets["test"]))
+        metrics["eval_samples"] = len(raw_datasets["test"])
         dpo_trainer.log_metrics("eval", metrics)
         dpo_trainer.save_metrics("eval", metrics)
 
diff --git a/scripts/run_sft.py b/scripts/run_sft.py
index 97cc0515..e0d892fc 100644
--- a/scripts/run_sft.py
+++ b/scripts/run_sft.py
@@ -151,8 +151,7 @@ def main():
     logger.info("*** Train ***")
     train_result = trainer.train()
     metrics = train_result.metrics
-    max_train_samples = data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
-    metrics["train_samples"] = min(max_train_samples, len(train_dataset))
+    metrics["train_samples"] = len(train_dataset)
     trainer.log_metrics("train", metrics)
     trainer.save_metrics("train", metrics)
     trainer.save_state()
@@ -163,8 +162,7 @@ def main():
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
         metrics = trainer.evaluate()
-        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
-        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
+        metrics["eval_samples"] = len(eval_dataset)
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
 
diff --git a/src/alignment/configs.py b/src/alignment/configs.py
index 9097d94d..9ca7c8e5 100644
--- a/src/alignment/configs.py
+++ b/src/alignment/configs.py
@@ -197,24 +197,6 @@ class DataArguments:
         default_factory=lambda: ["train", "test"],
         metadata={"help": ("List of train test splits to use in the dataset")},
     )
-    max_train_samples: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": (
-                "For debugging purposes or quicker training, truncate the number of training examples to this "
-                "value if set."
-            )
-        },
-    )
-    max_eval_samples: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": (
-                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
-                "value if set."
-            )
-        },
-    )
     preprocessing_num_workers: Optional[int] = field(
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
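
Note (not part of the patch): this change removes the `max_train_samples` / `max_eval_samples` debugging options entirely, so `train_samples` and `eval_samples` now always report the full split sizes. As a minimal sketch, assuming the Hugging Face `datasets` library and a hypothetical cap of 5 examples, the old truncation behaviour can still be reproduced upstream of the trainer with `Dataset.select` before the data reaches `trainer.train()`:

from datasets import Dataset, DatasetDict

# Toy stand-in for the raw_datasets the scripts load; the contents and the
# cap below are illustrative placeholders, not values from this patch.
raw_datasets = DatasetDict(
    {
        "train": Dataset.from_dict({"text": [f"example {i}" for i in range(10)]}),
        "test": Dataset.from_dict({"text": [f"example {i}" for i in range(4)]}),
    }
)

max_train_samples = 5  # hypothetical debug cap, formerly data_args.max_train_samples

# Dataset.select keeps only the requested indices, mirroring the removed
# min(max_train_samples, len(raw_datasets["train"])) truncation.
raw_datasets["train"] = raw_datasets["train"].select(
    range(min(max_train_samples, len(raw_datasets["train"])))
)

# With the truncation applied upstream, the simplified metric in this patch,
# len(raw_datasets["train"]), reports the same count the old code computed.
assert len(raw_datasets["train"]) == 5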