Update Evaluation Script for Reranking #3198
Conversation
Hello! I'm actually currently updating the CrossEncoder training flow, and one of the proposed changes that I will introduce is to make the […]. However, if you'd like to revert the […]
Hi @tomaarsen, Thank you for your quick response! I've removed the […]. Regarding my other work, I've been preparing a larger PR for the CrossEncoder that includes: […]

However, I understand you're currently updating the CrossEncoder training flow, which might conflict with or supersede my changes. I'd love to sync with you to align my updates with your improvements, to make the process more efficient and avoid redundant work. Let me know how you'd like to proceed; I'm happy to adapt my changes or collaborate further. Thanks again, and looking forward to your thoughts! Best regards,
Looks great, thank you!
@milistu The timing is indeed a bit unfortunate! I'm quite interested in your loss function, but it will take a second before my refactor is ready. To clarify, I'm going to update the CrossEncoder training to start using a CrossEncoderTrainer. I'm experimenting with some losses, but I haven't been having much luck beyond the common BCE/Cross Entropy losses. Currently, all of the changes are local, but perhaps I can try and make a branch public soon. That way you can align e.g. your loss updates & MS MARCO training script with the new flow? And this is an example of what training might look like:

```python
import logging
from datetime import datetime
from datasets import load_dataset
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation.CERerankingEvaluator import CERerankingEvaluator
from sentence_transformers.cross_encoder.losses.BCEWithLogitsLoss import BCEWithLogitsLoss
from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer
from sentence_transformers.cross_encoder.training_args import CrossEncoderTrainingArguments

def main():
    # Set the log level to INFO to get more information
    logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

    train_dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train")

    def mapper(batch):
        queries = []
        passages = []
        labels = []
        for query, passages_info in zip(batch["query"], batch["passages"]):
            for idx, is_selected in enumerate(passages_info["is_selected"]):
                queries.append(query)
                passages.append(passages_info["passage_text"][idx])
                labels.append(is_selected)
        return {"query": queries, "passage": passages, "label": labels}

    train_dataset = train_dataset.map(mapper, batched=True, remove_columns=train_dataset.column_names)
    # Split off a small eval set so that `eval_dataset` (used by the trainer below) is defined
    split_dataset = train_dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = split_dataset["train"]
    eval_dataset = split_dataset["test"]
    print(train_dataset)
    train_batch_size = 64
    num_epochs = 4
    output_dir = "output/training_ce-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # Define our CrossEncoder model.
    model = CrossEncoder("microsoft/MiniLM-L12-H384-uncased", num_labels=1)
    loss = BCEWithLogitsLoss(model)

    # Load GooAQ:
    dataset = load_dataset("tomaarsen/gooaq-hard-negatives", "triplet-5", split="train").select(range(1_000))
    samples = [
        {
            "query": sample["question"],
            "positive": [sample["answer"]],
            "negative": [
                sample["negative_1"],
                sample["negative_2"],
                sample["negative_3"],
                sample["negative_4"],
                sample["negative_5"],
            ],
        }
        for sample in dataset
    ]
    evaluator = CERerankingEvaluator(samples, name="GooAQ_1k_5_negs")
    evaluator(model)
    # 5. Define the training arguments
    args = CrossEncoderTrainingArguments(
        # Required parameter:
        output_dir=output_dir,
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=2,
        logging_steps=20,
        logging_first_step=True,
        run_name="ce-msmarco-from-MiniLM-L12",  # Will be used in W&B if `wandb` is installed
    )
    # 6. Create the trainer & start training
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()
    # 7. Save the final model
    final_output_dir = f"{output_dir}/final"
    model.save_pretrained(final_output_dir)


if __name__ == "__main__":
    main()
```

It's a bit messy, sorry for that. I'm curious what you think, given that you're also clearly working on CrossEncoders now!
@tomaarsen

On listwise loss experiments: I've been working with the ListNet loss (one of the simpler listwise approaches, sketched below) and observed mixed results on MS MARCO. The model is learning and the NDCG on evaluation is rising, but it is still not enough to surpass the previous pairwise approach. Hyperparameter tuning or combining losses (e.g., pairwise + listwise) might help. Additionally, the current ListNet loss is based on this code that I wrote, but with some changes to fit the old CE training setup better.

Collaboration offer: I'd love to align my work with your refactor! If you create a public branch for the new CrossEncoderTrainer setup, I can: […]

If you are interested, we can move this conversation elsewhere (e.g. LinkedIn) and start working on the CE update together.
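For context, here is a minimal sketch of the ListNet (top-1) idea referenced above, assuming one row of candidate scores per query; `listnet_loss` is an illustrative stand-in, not the implementation from the linked code:

```python
import torch
import torch.nn.functional as F


def listnet_loss(pred_scores: torch.Tensor, true_labels: torch.Tensor) -> torch.Tensor:
    """ListNet top-1 loss: cross-entropy between the label and score distributions.

    pred_scores, true_labels: (num_queries, list_size) tensors, one row per query.
    """
    true_dist = F.softmax(true_labels.float(), dim=-1)  # target top-1 distribution
    log_pred_dist = F.log_softmax(pred_scores, dim=-1)  # predicted log-distribution
    return -(true_dist * log_pred_dist).sum(dim=-1).mean()


# Example: 2 queries, 4 candidate passages each, binary relevance labels
scores = torch.randn(2, 4, requires_grad=True)
labels = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
loss = listnet_loss(scores, labels)
loss.backward()

# Combining losses, as suggested above, could then be a weighted sum, e.g.:
# total = pairwise_loss + 0.5 * listnet_loss(scores, labels)
```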
Update Evaluation Script for Reranking

Updates

New Feature: Added a `return_metric` argument that specifies which metric the `__call__` method returns. MRR remains the default, but the model can now be saved during training based on NDCG. A sketch of the intended behavior is shown below.

Documentation Improvement: Updated docstrings for previously missing or incomplete arguments.
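A minimal sketch of how such a `return_metric` switch could behave; the class name, metric values, and internals here are hypothetical stand-ins for illustration, not the actual evaluator code from this PR:

```python
# Hypothetical sketch: the real CERerankingEvaluator scores reranking samples
# with the model; only the metric-selection logic is the point here.
class RerankingEvaluatorSketch:
    def __init__(self, samples, name: str = "", return_metric: str = "mrr"):
        self.samples = samples
        self.name = name
        self.return_metric = return_metric  # "mrr" (default) or "ndcg"

    def __call__(self, model) -> float:
        # Stand-in values; a real evaluator would score `self.samples` with `model`
        metrics = {"mrr": 0.31, "ndcg": 0.42}
        # The returned value is what the trainer compares to keep the "best"
        # checkpoint, so returning NDCG enables saving based on NDCG instead of MRR.
        return metrics[self.return_metric]


evaluator = RerankingEvaluatorSketch(samples=[], return_metric="ndcg")
print(evaluator(model=None))  # -> 0.42
```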
Best Regards,