How to randomly select a column for preconditioning for each epoch? #40
Comments
Unfortunately, our framework does not support it yet, but we will incorporate random preconditioning for each epoch in a future release. If someone wants to contribute, we will be happy to merge a PR.
Hi @unnir, I'm working together with @BelenGarciaPascual. We are also planning to make a proper PR. This is the workaround we are using at the moment:
```python
from pathlib import Path
from shutil import rmtree

from be_great import GReaT

# `data` is the training DataFrame (California housing in our case);
# `base` and `llm` hold the HuggingFace model name and are defined earlier in our notebook.
batch_size = 32
steps = len(data) // batch_size
epochs = [0, 1, 2, 3, 4, 5, 6, 7]
columns = data.columns

# One extra epoch per inner iteration: every pass over the data preconditions on a different column.
for epoch in epochs:
    for idx, column in enumerate(columns):
        print(f'{epoch=} -> {column=}')
        great = GReaT(base,                                       # Name of the large language model used (see HuggingFace for more options)
                      batch_size=batch_size,
                      epochs=epoch * len(data.columns) + idx + 1, # Number of epochs to train (only one epoch for demonstration)
                      save_steps=steps,                           # Save model weights every x steps
                      logging_steps=steps,                        # Log the loss and learning rate every x steps
                      experiment_dir=f"aleks_{llm}_trainer",      # Name of the directory where all intermediate steps are saved
                      )
        if epoch == 0 and idx == 0:
            trainer = great.fit(data, conditional_col=column)
        else:
            trainer = great.fit(data, conditional_col=column, resume_from_checkpoint=True)
        rmtree(Path(f"aleks_{llm}_trainer") / f"checkpoint-{epoch * len(data.columns) * steps + idx * steps}")

great.save(f"aleks_california_{llm}")

for path in Path(f"aleks_{llm}_trainer").iterdir():
    if path.is_dir():
        print(f'{path=}')
```
Cool! Thank you for the update. I would recommend training the model longer, at least 10+ epochs, to get even better results. Also, to speed up training you can pass fp16=True to the GReaT method; it should be at least 2 times faster.
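A minimal sketch of how that might look, assuming GReaT forwards extra keyword arguments such as fp16 to the underlying HuggingFace TrainingArguments (the other values are just taken from the snippet above):
```python
# Sketch only: fp16 is assumed to be forwarded to the HuggingFace
# TrainingArguments by the GReaT constructor.
great = GReaT(base,
              batch_size=batch_size,
              epochs=10,        # train longer than a single epoch
              fp16=True,        # mixed-precision training, roughly 2x faster on a GPU
              save_steps=steps,
              logging_steps=steps,
              experiment_dir=f"aleks_{llm}_trainer")
trainer = great.fit(data)
```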
@unnir Just curious, how long does it take for you to fine-tune the model on your hardware? (I assume you are using some GPU.)
Hi @BelenGarciaPascual and @unnir, I am interested in this discussion. However, as described in the paper, after converting each row to a textual encoding, GReaT permutes the sequence to ignore the order. So I don't understand what you meant by saying the last column was used in training/fine-tuning the model. In my understanding, the last column is only used in the sampling phase if you don't specify the precondition.
Hi @nphdang, you are correct: GReaT permutes the sequences on its own during fine-tuning. Our initial suggestion there was wrong.
@kontramind thanks for the clarification. Yes, doing the permutation in the sampling phase is simpler: we just need to iterate over each column and set it as the precondition. I tried this step and it slightly improved the downstream classification task.
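For illustration, a minimal sketch of what that iteration could look like. It assumes a fitted GReaT instance `great` whose sample() method accepts start_col and start_col_dist arguments, and a real table `data`; these are assumptions about the API, not a confirmed recipe:
```python
import pandas as pd

# Sketch: draw an equal share of synthetic rows with each column as the
# precondition, then concatenate the pieces into one synthetic table.
n_total = 1000                                   # assumed total number of synthetic rows
parts = []
for col in data.columns:
    # Empirical distribution of the preconditioning column (assumed input format).
    dist = data[col].value_counts(normalize=True).to_dict()
    parts.append(great.sample(n_total // len(data.columns),
                              start_col=col,
                              start_col_dist=dist))
synthetic = pd.concat(parts, ignore_index=True)
```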
As explained in the code documentation, when training/fine-tuning with the .fit() function, if no column of the tabular data is specified, the last column is used for preconditioning.
We have used several metrics from the SDMetrics library to compare the real California housing dataset with several synthetic tabular datasets generated by GReaT. From this we noticed that the default preconditioning is not ideal: the synthetic version of that last column is almost an exact match to the real one, while the remaining columns show much more variability.
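As an illustration of the kind of per-column check we ran (a sketch only; `real` and `synthetic` are assumed to be the real and GReaT-generated DataFrames, and KSComplement is one of the SDMetrics single-column metrics):
```python
from sdmetrics.single_column import KSComplement

# Sketch: per-column similarity between real and synthetic marginals.
# KSComplement is 1.0 for identical distributions, so a near-perfect score
# on the default preconditioning column stands out against the other columns.
for col in real.columns:
    score = KSComplement.compute(real_data=real[col], synthetic_data=synthetic[col])
    print(f"{col}: {score:.3f}")
```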
We would really like to mitigate this effect by selecting a column at random, where every column has equal probability of being selected, and repeating this selection every time the data is revisited, i.e., every epoch. This way, all columns would be used for preconditioning during the fitting process.
Is it possible, through any argument of GReaT() or .fit(), to specify this random selection of columns, switching every epoch? Or does one have to rewrite the code rather than use the be_great package directly? For clarity, a minimal sketch of the behaviour we have in mind follows below.
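This is purely hypothetical and only uses the .fit() arguments already shown in this thread; `great` is assumed to be a GReaT instance and `n_epochs` the desired total number of epochs:
```python
import random

# Hypothetical workaround: fit one epoch at a time, drawing a fresh
# preconditioning column uniformly at random before every pass over the data.
columns = list(data.columns)
for epoch in range(n_epochs):
    col = random.choice(columns)          # every column equally likely
    great.fit(data,
              conditional_col=col,
              resume_from_checkpoint=(epoch > 0))
```
In practice, with the current package one would still need the epoch bookkeeping from the snippet earlier in this thread so that each resumed call actually trains for one more epoch.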
Many thanks beforehand!