Skip to content

Commit cc0967d

Browse files
committed
Update data split to 40%
1 parent a0364db commit cc0967d

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

train_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ def main(key: PRNGKeyArray):
5050

5151
if args.tune_hyperparams:
5252
args.group = 'Sweeps' if args.baseline else 'Sweeps_5i'
53-
54-
trainloader = train_dataset.create_dataloader('20%')
55-
valloader = val_dataset.create_dataloader('20%')
53+
54+
trainloader = train_dataset.create_dataloader("40%")
55+
valloader = val_dataset.create_dataloader("40%")
5656

5757
trainloader = train_dataset.create_dataloader("40%")
5858
valloader = val_dataset.create_dataloader("40%")
@@ -136,7 +136,7 @@ def main(key: PRNGKeyArray):
136136
def kickoff_optuna(trial, **trainer_kwargs):
137137
args = trainer_kwargs['args']
138138

139-
args.epochs = 2
139+
args.epochs = 1
140140

141141
args.lr = trial.suggest_float('lr', 1e-4, 1e-2, step=1e-4)
142142
args.drop_rate = trial.suggest_float('drop_rate', 0.0, 0.1, step=0.01)

0 commit comments

Comments
 (0)