
Bert4rec #197

Draft

spirinamayya wants to merge 1 commit into base: experimental/sasrec
Conversation

spirinamayya (Contributor)

Added BERT4Rec model.

self.shuffle_train = shuffle_train
# TODO: add SequenceDatasetType for fit and recommend

def process_dataset_train(self, dataset: Dataset) -> Dataset:
Collaborator:
If this method is the same for SASRec and BERT4Rec, just move it to SessionEncoderDataPreparatorBase. Do this for all of the methods that don't have any differences.
Ideally, only the two collate fn methods will differ between the models.
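For illustration, a minimal sketch of that layout, assuming the shared preprocessing already lives in the base class. The subclass names are hypothetical; "Dataset" refers to the RecTools dataset class used in the PR.

```python
import typing as tp


class SessionEncoderDataPreparatorBase:
    """Everything shared by SASRec and BERT4Rec lives here; only collation differs."""

    def process_dataset_train(self, dataset: "Dataset") -> "Dataset":
        # identical for both models, so it is implemented once in the base class
        ...

    def _collate_fn_train(self, batch: tp.List[tp.Any]) -> tp.Any:
        # model-specific training targets, implemented by subclasses
        raise NotImplementedError

    def _collate_fn_recommend(self, batch: tp.List[tp.Any]) -> tp.Any:
        # model-specific inference batches, implemented by subclasses
        raise NotImplementedError


class SASRecDataPreparator(SessionEncoderDataPreparatorBase):    # hypothetical name
    def _collate_fn_train(self, batch: tp.List[tp.Any]) -> tp.Any:
        ...  # shifted next-item targets


class BERT4RecDataPreparator(SessionEncoderDataPreparatorBase):  # hypothetical name
    def _collate_fn_train(self, batch: tp.List[tp.Any]) -> tp.Any:
        ...  # cloze-style masked targets drawn with mask_prob
```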

torch.ones((session_max_len, session_max_len), dtype=torch.bool, device=sessions.device)
)
timeline_mask = sessions != 0
attn_mask = ~timeline_mask.unsqueeze(1).repeat(self.n_heads, timeline_mask.squeeze(-1).shape[1], 1)
Collaborator:
What is timeline_mask.squeeze(-1).shape[1] exactly, when timeline_mask is of shape [batch_size, session_maxlen]?
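For reference (not part of the PR), a quick standalone shape check of that expression, assuming timeline_mask is [batch_size, session_max_len]:

```python
import torch

batch_size, session_max_len, n_heads = 2, 4, 2
sessions = torch.tensor([[0, 0, 5, 7],
                         [3, 9, 4, 1]])  # 0 is the padding id
timeline_mask = sessions != 0            # shape: [batch_size, session_max_len]

# squeeze(-1) is a no-op on a 2-D mask, so .shape[1] is simply session_max_len
assert timeline_mask.squeeze(-1).shape[1] == session_max_len

# The repeat then builds a [batch_size * n_heads, L, S] mask, which is the
# 3-D attn_mask layout nn.MultiheadAttention expects
attn_mask = ~timeline_mask.unsqueeze(1).repeat(n_heads, session_max_len, 1)
assert attn_mask.shape == (batch_size * n_heads, session_max_len, session_max_len)
```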

# #### -------------- Session Encoder -------------- #### #


class TransformerBasedSessionEncoder(torch.nn.Module):
Collaborator:
We need one common class for all models. If there is some custom logic, just create flags and pass them from the models during TransformerBasedSessionEncoder initialization.
Since we haven't decided on the masks yet, let's also make that a flag and pass it to model initialization as well. This will simplify these experiments for us.
TransformerBasedSessionEncoder should be imported from sasrec.py (until we move everything to the correct modules).
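A rough sketch of the flag-driven variant; the flag names here are hypothetical placeholders, not an agreed API:

```python
import torch


class TransformerBasedSessionEncoder(torch.nn.Module):
    """Single encoder shared by SASRec and BERT4Rec; differences are driven by flags."""

    def __init__(
        self,
        n_blocks: int,
        n_heads: int,
        n_factors: int,
        use_causal_attn: bool,       # hypothetical flag: True for SASRec, False for BERT4Rec
        use_key_padding_mask: bool,  # hypothetical flag for the undecided mask experiments
        dropout_rate: float = 0.2,
    ) -> None:
        super().__init__()
        self.use_causal_attn = use_causal_attn
        self.use_key_padding_mask = use_key_padding_mask
        # ... shared embeddings and transformer blocks are constructed here ...


# Each model passes its own flags when it builds the encoder, e.g. for SASRec:
sasrec_encoder = TransformerBasedSessionEncoder(
    n_blocks=2, n_heads=4, n_factors=256, use_causal_attn=True, use_key_padding_mask=False
)
```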


def on_train_start(self) -> None:
    """TODO"""
    self._truncated_normal_init()
Collaborator:
If this is the only difference, then let's just use the one used in SASRec without overriding the class.
Just import SessionEncoderLightningModule from sasrec.py until we move everything to the correct modules.
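In other words, something like the following; the import path is a placeholder and should point at wherever sasrec.py actually lives in this branch:

```python
# Reuse the SASRec lightning module as-is instead of subclassing it only to
# override on_train_start (the truncated-normal init is shared anyway).
from rectools.models.sasrec import SessionEncoderLightningModule  # placeholder path

lightning_module_type = SessionEncoderLightningModule
```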

lr: float = 0.01,
dataloader_num_workers: int = 0,
train_min_user_interaction: int = 2,
mask_prob: float = 0.15,
blondered (Collaborator), Oct 31, 2024:
As far as I can see, the only difference between this model and SASRec is receiving mask_prob and passing it to data_preparator_type.
What you need to do:

  1. Create a TransformerModelBase class with all of the common methods already implemented. Look at https://github.com/MobileTeleSystems/RecTools/blob/main/rectools/models/base.py for an example. Carefully select which of the arguments must be passed to init. mask_prob shouldn't be there at all. transformer_layers_type and data_preparator_type should not have default values. self.data_preparator should not be initialized but should be declared, like self.data_preparator: SessionEncoderDataPreparatorBase.
  2. In SASRecModel init, create default values for transformer_layers_type and data_preparator_type. Initialize self.data_preparator.
  3. In BERT4RecModel init, create default values for transformer_layers_type and data_preparator_type. Initialize self.data_preparator.

If something fails the linters, let's discuss. But at first glance this should work.
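To make the structure concrete, here is a condensed, illustrative sketch of that layout. The stub classes at the top stand in for the real ones in this PR; BERT4RecDataPreparator and BERT4RecTransformerLayers are hypothetical names, and the signatures are not the final API.

```python
import typing as tp


# Placeholder stubs standing in for the real classes in this PR.
class SessionEncoderDataPreparatorBase: ...
class SASRecDataPreparator(SessionEncoderDataPreparatorBase): ...
class BERT4RecDataPreparator(SessionEncoderDataPreparatorBase):  # hypothetical name
    def __init__(self, mask_prob: float) -> None:
        self.mask_prob = mask_prob
class TransformerLayersBase: ...
class SASRecTransformerLayers(TransformerLayersBase): ...
class BERT4RecTransformerLayers(TransformerLayersBase): ...      # hypothetical name


class TransformerModelBase:
    """Common fit/recommend logic for both models (see rectools/models/base.py)."""

    def __init__(
        self,
        transformer_layers_type: tp.Type[TransformerLayersBase],          # no default
        data_preparator_type: tp.Type[SessionEncoderDataPreparatorBase],  # no default
        lr: float = 0.01,
        dataloader_num_workers: int = 0,
        train_min_user_interaction: int = 2,
    ) -> None:
        self.lr = lr
        self.dataloader_num_workers = dataloader_num_workers
        self.train_min_user_interaction = train_min_user_interaction
        self.transformer_layers_type = transformer_layers_type
        self.data_preparator_type = data_preparator_type
        # declared only; each subclass creates the actual instance
        self.data_preparator: SessionEncoderDataPreparatorBase


class SASRecModel(TransformerModelBase):
    def __init__(self, lr: float = 0.01) -> None:
        super().__init__(
            transformer_layers_type=SASRecTransformerLayers,  # SASRec default
            data_preparator_type=SASRecDataPreparator,        # SASRec default
            lr=lr,
        )
        self.data_preparator = SASRecDataPreparator()


class BERT4RecModel(TransformerModelBase):
    def __init__(self, lr: float = 0.01, mask_prob: float = 0.15) -> None:
        super().__init__(
            transformer_layers_type=BERT4RecTransformerLayers,  # BERT4Rec default
            data_preparator_type=BERT4RecDataPreparator,        # BERT4Rec default
            lr=lr,
        )
        # mask_prob stays a BERT4Rec-only argument and only reaches the data preparator
        self.data_preparator = BERT4RecDataPreparator(mask_prob=mask_prob)
```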

return recommend_dataloader


class PointWiseFeedForward(nn.Module):
Collaborator:
Now about the feed-forward differences. There are a few:

  1. The order of dropout and activation. For ReLU it makes no difference in the result, so let's change the order in the SASRec variant and it will be the same as BERT4Rec.
  2. The activation function. Let's just add an argument to the PointWiseFeedForward init specifying which activation to use.
  3. Placing the second dropout inside PointWiseFeedForward.forward or moving it to TransformerLayers.forward. We can move it to TransformerLayers in SASRecTransformerLayers so that it looks more like the classic picture of the transformer architecture.

As far as I can see, after all of this we will have one PointWiseFeedForward class for both models, which is good.
Please make this change as a separate commit and check the SASRec metrics.
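A sketch of what the unified block could look like, with the activation passed in as an argument and the second dropout left to the caller; attribute names are illustrative:

```python
import torch
from torch import nn


class PointWiseFeedForward(nn.Module):
    """One feed-forward block for both models; the activation is configurable."""

    def __init__(
        self, n_factors: int, n_factors_ff: int, dropout_rate: float, activation: nn.Module
    ) -> None:
        super().__init__()
        self.ff_linear1 = nn.Linear(n_factors, n_factors_ff)
        self.ff_dropout1 = nn.Dropout(dropout_rate)
        self.ff_activation = activation
        self.ff_linear2 = nn.Linear(n_factors_ff, n_factors)

    def forward(self, seqs: torch.Tensor) -> torch.Tensor:
        # activation before dropout: for ReLU this order gives the same result,
        # so SASRec can adopt it and match BERT4Rec
        output = self.ff_linear2(self.ff_dropout1(self.ff_activation(self.ff_linear1(seqs))))
        # the second dropout is applied by the caller (TransformerLayers.forward)
        return output


# SASRec keeps ReLU, BERT4Rec uses GELU:
sasrec_ff = PointWiseFeedForward(64, 256, 0.2, nn.ReLU())
bert4rec_ff = PointWiseFeedForward(64, 256, 0.2, nn.GELU())
```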

for i in range(self.n_blocks):
    mha_input = self.layer_norm1[i](seqs)
    # mha_output, _ = self.multi_head_attn[i](
    #     mha_input, mha_input, mha_input, attn_mask=attn_mask, need_weights=False
    # )
Collaborator:
This is not handy at all :)
Here, just always pass the mask you received in forward.
And create a flag in the model for whether to create this mask or not; if not, pass it as None.
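For example, something along these lines (a sketch; build_attn_mask is a hypothetical helper name):

```python
import typing as tp

import torch


def build_attn_mask(
    sessions: torch.Tensor, use_causal_mask: bool
) -> tp.Optional[torch.Tensor]:
    """Return a causal attention mask, or None when the model flag disables masking."""
    if not use_causal_mask:
        return None
    session_max_len = sessions.shape[1]
    return ~torch.tril(
        torch.ones((session_max_len, session_max_len), dtype=torch.bool, device=sessions.device)
    )


# Inside TransformerLayers.forward the mask is then simply forwarded as received:
#     mha_output, _ = self.multi_head_attn[i](
#         mha_input, mha_input, mha_input, attn_mask=attn_mask, need_weights=False
#     )
```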

ff_output = self.feed_forward[i](ff_input)
seqs = seqs + self.dropout2[i](ff_output)
seqs = seqs * timeline_mask
# seqs = self.dropout3[i](seqs)
Collaborator:
We don't need it, do we? Is there a reference for this?
