Bert4rec #197
base: experimental/sasrec
Conversation
self.shuffle_train = shuffle_train
# TODO: add SequenceDatasetType for fit and recommend

def process_dataset_train(self, dataset: Dataset) -> Dataset:
If this method is the same for SASRec and BERT4Rec, just move it to SessionEncoderDataPreparatorBase. Do this for all of the methods that don't have any differences.
Ideally, only the 2 collate fn methods will differ.
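Rough sketch of that split (class and method names follow this PR, signatures are placeholders): the base class owns everything shared, and only the collate fns stay model-specific.

```python
import typing as tp

from rectools.dataset import Dataset


class SessionEncoderDataPreparatorBase:
    """All preprocessing shared by SASRec and BERT4Rec lives here, implemented once."""

    def __init__(self, session_max_len: int, shuffle_train: bool = True) -> None:
        self.session_max_len = session_max_len
        self.shuffle_train = shuffle_train

    def process_dataset_train(self, dataset: Dataset) -> Dataset:
        ...  # identical filtering / id-mapping logic, body omitted in this sketch

    def _collate_fn_train(self, batch: tp.List[tp.Any]) -> tp.Any:
        raise NotImplementedError  # model-specific

    def _collate_fn_recommend(self, batch: tp.List[tp.Any]) -> tp.Any:
        raise NotImplementedError  # model-specific


class SASRecDataPreparator(SessionEncoderDataPreparatorBase):
    def _collate_fn_train(self, batch: tp.List[tp.Any]) -> tp.Any:
        ...  # shift-by-one targets


class BERT4RecDataPreparator(SessionEncoderDataPreparatorBase):
    def _collate_fn_train(self, batch: tp.List[tp.Any]) -> tp.Any:
        ...  # masked-item targets drawn with mask_prob
```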
    torch.ones((session_max_len, session_max_len), dtype=torch.bool, device=sessions.device)
)
timeline_mask = sessions != 0
attn_mask = ~timeline_mask.unsqueeze(1).repeat(self.n_heads, timeline_mask.squeeze(-1).shape[1], 1)
What exactly is timeline_mask.squeeze(-1).shape[1] when timeline_mask has shape [batch_size, session_max_len]?
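For reference, a quick shape check of that expression (shapes taken from the diff context; the sizes here are just illustrative):

```python
import torch

batch_size, session_max_len, n_heads = 4, 32, 2
sessions = torch.randint(0, 10, (batch_size, session_max_len))

timeline_mask = sessions != 0              # [batch_size, session_max_len]
# squeeze(-1) is a no-op on a 2-D tensor, so shape[1] is simply session_max_len
print(timeline_mask.squeeze(-1).shape[1])  # 32

attn_mask = ~timeline_mask.unsqueeze(1).repeat(n_heads, session_max_len, 1)
print(attn_mask.shape)                     # [batch_size * n_heads, session_max_len, session_max_len]
```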
# #### -------------- Session Encoder -------------- #### #


class TransformerBasedSessionEncoder(torch.nn.Module):
We need one common class for all models. If there is some custom logic, just create flags and pass them from the models during TransformerBasedSessionEncoder initialization.
Since we haven't decided on the masks yet, let's make that a flag too, and pass it to model initialization as well. This will simplify these experiments for us.
TransformerBasedSessionEncoder should be imported from sasrec.py (until we move everything to the correct modules).
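A minimal sketch of the flag-based version (the flag name use_causal_attn_mask is an assumption, not the final API):

```python
import typing as tp

import torch


class TransformerBasedSessionEncoder(torch.nn.Module):
    """One encoder class shared by SASRec and BERT4Rec; custom logic sits behind flags."""

    def __init__(self, n_heads: int, use_causal_attn_mask: bool = True) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.use_causal_attn_mask = use_causal_attn_mask  # True for SASRec, False for BERT4Rec

    def _build_attn_mask(self, sessions: torch.Tensor) -> tp.Optional[torch.Tensor]:
        if not self.use_causal_attn_mask:
            return None
        session_max_len = sessions.shape[1]
        return ~torch.tril(
            torch.ones((session_max_len, session_max_len), dtype=torch.bool, device=sessions.device)
        )

    def forward(self, sessions: torch.Tensor) -> torch.Tensor:
        attn_mask = self._build_attn_mask(sessions)
        # embeddings + transformer layers go here and receive attn_mask (possibly None)
        return sessions
```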
def on_train_start(self) -> None:
    """TODO"""
    self._truncated_normal_init()
If this is the only difference, then let's just use the one that is used in SASRec without overriding the class.
Just import SessionEncoderLightningModule from sasrec.py until we move everything to the correct modules.
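In code that's just a re-use, no new subclass (the module path here is assumed from the current experimental layout):

```python
# reuse the SASRec lightning module as-is until modules are reorganized
from rectools.models.sasrec import SessionEncoderLightningModule
```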
lr: float = 0.01,
dataloader_num_workers: int = 0,
train_min_user_interaction: int = 2,
mask_prob: float = 0.15,
As far as I can see, the only difference between this model and SASRec is receiving mask_prob and passing it to data_preparator_type.
What you need to do (see the sketch below):
- Create a TransformerModelBase class with all of the common methods already implemented. Look at https://github.com/MobileTeleSystems/RecTools/blob/main/rectools/models/base.py for an example. Carefully select which arguments must be passed to init: mask_prob shouldn't be there at all, transformer_layers_type and data_preparator_type should not have default values, and self.data_preparator should not be initialized but only declared, like self.data_preparator: SessionEncoderDataPreparatorBase.
- In SASRecModel init, create default values for transformer_layers_type and data_preparator_type. Initialize self.data_preparator.
- In BERT4RecModel init, create default values for transformer_layers_type and data_preparator_type. Initialize self.data_preparator.
If something fails for the linters, let's discuss. But at first glance this should work.
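A condensed sketch of that layout (signatures are illustrative; the point is where mask_prob lives and where defaults are set):

```python
import typing as tp


class SessionEncoderDataPreparatorBase:  # stub, see the data preparator discussion above
    ...


class SASRecDataPreparator(SessionEncoderDataPreparatorBase):
    ...


class BERT4RecDataPreparator(SessionEncoderDataPreparatorBase):
    def __init__(self, mask_prob: float = 0.15) -> None:
        self.mask_prob = mask_prob


class TransformerModelBase:
    """Common fit/recommend machinery, analogous to rectools/models/base.py."""

    def __init__(
        self,
        transformer_layers_type: tp.Type,  # no default in the base class
        data_preparator_type: tp.Type,     # no default in the base class
        lr: float = 0.01,
        dataloader_num_workers: int = 0,
    ) -> None:
        self.transformer_layers_type = transformer_layers_type
        self.data_preparator_type = data_preparator_type
        self.lr = lr
        self.dataloader_num_workers = dataloader_num_workers
        # declared but not initialized - each subclass builds its own preparator
        self.data_preparator: SessionEncoderDataPreparatorBase


class SASRecModel(TransformerModelBase):
    def __init__(self, lr: float = 0.01) -> None:
        super().__init__(
            transformer_layers_type=object,  # placeholder for SASRecTransformerLayers
            data_preparator_type=SASRecDataPreparator,
            lr=lr,
        )
        self.data_preparator = SASRecDataPreparator()


class BERT4RecModel(TransformerModelBase):
    def __init__(self, mask_prob: float = 0.15, lr: float = 0.01) -> None:
        super().__init__(
            transformer_layers_type=object,  # placeholder for BERT4RecTransformerLayers
            data_preparator_type=BERT4RecDataPreparator,
            lr=lr,
        )
        # mask_prob never reaches TransformerModelBase; it only configures the preparator
        self.data_preparator = BERT4RecDataPreparator(mask_prob=mask_prob)
```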
return recommend_dataloader


class PointWiseFeedForward(nn.Module):
Now about the feed-forward differences. There are a few:
- The order of dropout and activation. For ReLU it makes no difference in the result, so let's change the order in the SASRec variant and it will match BERT4Rec.
- The activation function. Let's just add an argument to the PointWiseFeedForward init specifying which activation to use.
- Placing the second dropout inside PointWiseFeedForward.forward vs. moving it to TransformerLayers.forward. We can move it to TransformerLayers in SASRecTransformerLayers so that it looks more like the classic picture of the transformer architecture.
As far as I can see, after all of this we will have one PointWiseFeedForward class for both models, which is good.
Please make this change as a separate commit and check the SASRec metrics.
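A possible unified block after those changes (hidden sizes in the usage lines are only examples; the second dropout is assumed to move into TransformerLayers.forward):

```python
import torch
from torch import nn


class PointWiseFeedForward(nn.Module):
    """Shared feed-forward block; the activation is injected instead of hard-coded."""

    def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float, activation: nn.Module) -> None:
        super().__init__()
        self.ff_linear1 = nn.Linear(n_factors, n_factors_ff)
        self.ff_activation = activation
        self.ff_dropout1 = nn.Dropout(dropout_rate)
        self.ff_linear2 = nn.Linear(n_factors_ff, n_factors)

    def forward(self, seqs: torch.Tensor) -> torch.Tensor:
        # activation before dropout: for ReLU the result is the same either way,
        # and this order matches the BERT4Rec variant
        output = self.ff_dropout1(self.ff_activation(self.ff_linear1(seqs)))
        return self.ff_linear2(output)


# usage (illustrative): ReLU for SASRec, GELU for BERT4Rec
sasrec_ff = PointWiseFeedForward(64, 64, 0.2, nn.ReLU())
bert4rec_ff = PointWiseFeedForward(64, 256, 0.2, nn.GELU())
```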
for i in range(self.n_blocks):
    mha_input = self.layer_norm1[i](seqs)
    # mha_output, _ =
    # self.multi_head_attn[i](mha_input, mha_input, mha_input, attn_mask=attn_mask, need_weights=False)
This is not handy at all :)
Here, just always pass the attn_mask you received in forward. And create a flag in the model controlling whether to create this mask or not. If not, pass it as None.
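Sketched out (names are illustrative): the layers always consume whatever mask forward received, and the model decides whether to build one.

```python
import typing as tp

import torch
from torch import nn


class TransformerLayers(nn.Module):
    def __init__(self, n_blocks: int, n_factors: int, n_heads: int, dropout_rate: float) -> None:
        super().__init__()
        self.n_blocks = n_blocks
        self.layer_norm1 = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)])
        self.multi_head_attn = nn.ModuleList(
            [nn.MultiheadAttention(n_factors, n_heads, dropout_rate, batch_first=True) for _ in range(n_blocks)]
        )

    def forward(self, seqs: torch.Tensor, attn_mask: tp.Optional[torch.Tensor]) -> torch.Tensor:
        # attn_mask comes straight from the model: a causal mask when its flag is set
        # (SASRec), None otherwise (BERT4Rec). No mask construction in the layers.
        for i in range(self.n_blocks):
            mha_input = self.layer_norm1[i](seqs)
            mha_output, _ = self.multi_head_attn[i](
                mha_input, mha_input, mha_input, attn_mask=attn_mask, need_weights=False
            )
            seqs = seqs + mha_output
        return seqs
```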
ff_output = self.feed_forward[i](ff_input)
seqs = seqs + self.dropout2[i](ff_output)
seqs = seqs * timeline_mask
# seqs = self.dropout3[i](seqs)
We don't need it, do we? Is there a reference for this?
Added BERT4Rec model