Skip to content

Commit fb91f67

Browse files
authored
Train via the Sentence Transformers Trainer from ST v3 (#554)
* Train via the Sentence Transformers Trainer from ST v3
* Simplify some init code; docstring
* Prevent breaking changes by updating TrainerCallback
* Replace ST Training Args with SetFit Training Args
* Remove unused properties
* Require 'accelerate' when training SetFit models
* Remove log in docs as it is no longer used
* Fix docs issue
* Require installing sentence-transformers[train]
* Keep not having to override metric_for_best_model by default; it'll just keep using the loss of whatever trainer you're using.
* Ensure logs directory is made in Callbacks example
* Fix outdated docstring
1 parent 72f4d1e commit fb91f67

16 files changed

+280
-441
lines changed

docs/source/en/how_to/callbacks.mdx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,15 @@ trainer.train()
5959
SetFit supports custom callbacks in the same way that `transformers` does: by subclassing [`TrainerCallback`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.TrainerCallback). This class implements a lot of `on_...` methods that can be overridden. For example, the following script shows a custom callback that saves plots of the tSNE of the training and evaluation embeddings during training.
6060

6161
```py
62+
import os
6263
import matplotlib.pyplot as plt
6364
from sklearn.manifold import TSNE
6465

6566
class EmbeddingPlotCallback(TrainerCallback):
6667
"""Simple embedding plotting callback that plots the tSNE of the training and evaluation datasets throughout training."""
68+
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
69+
os.makedirs("logs", exist_ok=True)
70+
6771
def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: SetFitModel, **kwargs):
6872
train_embeddings = model.encode(train_dataset["text"])
6973
eval_embeddings = model.encode(eval_dataset["text"])

docs/source/en/installation.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Before you start, you'll need to setup your environment and install the appropri
55

66
## pip
77

8-
The most straightforward way to install 🤗 Datasets is with pip:
8+
The most straightforward way to install 🤗 SetFit is with pip:
99

1010
```bash
1111
pip install setfit

docs/source/en/reference/trainer.mdx

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
- apply_hyperparameters
1717
- evaluate
1818
- hyperparameter_search
19-
- log
2019
- pop_callback
2120
- push_to_hub
2221
- remove_callback
@@ -31,7 +30,6 @@
3130
- apply_hyperparameters
3231
- evaluate
3332
- hyperparameter_search
34-
- log
3533
- pop_callback
3634
- push_to_hub
3735
- remove_callback

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
INTEGRATIONS_REQUIRE = ["optuna"]
1313
REQUIRED_PKGS = [
1414
"datasets>=2.15.0",
15-
"sentence-transformers>=2.2.1",
15+
"sentence-transformers[train]>=3",
1616
"transformers>=4.41.0",
1717
"evaluate>=0.3.0",
1818
"huggingface_hub>=0.23.0",

src/setfit/model_card.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,6 @@ def __init__(self, trainer: "Trainer") -> None:
3838
super().__init__()
3939
self.trainer = trainer
4040

41-
callbacks = [
42-
callback
43-
for callback in self.trainer.callback_handler.callbacks
44-
if isinstance(callback, CodeCarbonCallback)
45-
]
46-
if callbacks:
47-
trainer.model.model_card_data.code_carbon_callback = callbacks[0]
48-
4941
def on_init_end(
5042
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: "SetFitModel", **kwargs
5143
):
@@ -109,19 +101,22 @@ def on_evaluate(
109101
metrics: Dict[str, float],
110102
**kwargs,
111103
) -> None:
104+
keys = {"eval_embedding_loss", "eval_polarity_embedding_loss", "eval_aspect_embedding_loss"} & set(metrics)
105+
if not keys:
106+
return
112107
if (
113108
model.model_card_data.eval_lines_list
114109
and model.model_card_data.eval_lines_list[-1]["Step"] == state.global_step
115110
):
116-
model.model_card_data.eval_lines_list[-1]["Validation Loss"] = metrics["eval_embedding_loss"]
111+
model.model_card_data.eval_lines_list[-1]["Validation Loss"] = metrics[keys.pop()]
117112
else:
118113
model.model_card_data.eval_lines_list.append(
119114
{
120115
# "Training Loss": self.state.log_history[-1]["loss"] if "loss" in self.state.log_history[-1] else "-",
121116
"Epoch": state.epoch,
122117
"Step": state.global_step,
123118
"Training Loss": "-",
124-
"Validation Loss": metrics["eval_embedding_loss"],
119+
"Validation Loss": metrics[keys.pop()],
125120
}
126121
)
127122

src/setfit/sampler.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from itertools import zip_longest
2-
from typing import Generator, Iterable, List, Optional
2+
from typing import Dict, Generator, Iterable, List, Optional, Union
33

44
import numpy as np
55
import torch
6-
from sentence_transformers import InputExample
76
from torch.utils.data import IterableDataset
87

98
from . import logging
@@ -35,7 +34,8 @@ def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Genera
3534
class ContrastiveDataset(IterableDataset):
3635
def __init__(
3736
self,
38-
examples: List[InputExample],
37+
sentences: List[str],
38+
labels: List[Union[int, float]],
3939
multilabel: bool,
4040
num_iterations: Optional[None] = None,
4141
sampling_strategy: str = "oversampling",
@@ -44,7 +44,8 @@ def __init__(
4444
"""Generates positive and negative text pairs for contrastive learning.
4545
4646
Args:
47-
examples (InputExample): text and labels in a text transformer dataclass
47+
sentences (List[str]): text sentences to generate pairs from
48+
labels (List[Union[int, float]]): labels for each sentence
4849
multilabel: set to process "multilabel" labels array
4950
sampling_strategy: "unique", "oversampling", or "undersampling"
5051
num_iterations: if provided explicitly sets the number of pairs to be generated
@@ -57,8 +58,8 @@ def __init__(
5758
self.neg_index = 0
5859
self.pos_pairs = []
5960
self.neg_pairs = []
60-
self.sentences = np.array([s.texts[0] for s in examples])
61-
self.labels = np.array([s.label for s in examples])
61+
self.sentences = sentences
62+
self.labels = labels
6263
self.sentence_labels = list(zip(self.sentences, self.labels))
6364
self.max_pairs = max_pairs
6465

@@ -89,23 +90,23 @@ def __init__(
8990
def generate_pairs(self) -> None:
9091
for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
9192
if _label == label:
92-
self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
93+
self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0})
9394
else:
94-
self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))
95+
self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0})
9596
if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs:
9697
break
9798

9899
def generate_multilabel_pairs(self) -> None:
99100
for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
100101
if any(np.logical_and(_label, label)):
101102
# logical_and checks if labels are both set for each class
102-
self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
103+
self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0})
103104
else:
104-
self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))
105+
self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0})
105106
if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs:
106107
break
107108

108-
def get_positive_pairs(self) -> List[InputExample]:
109+
def get_positive_pairs(self) -> List[Dict[str, Union[str, float]]]:
109110
pairs = []
110111
for _ in range(self.len_pos_pairs):
111112
if self.pos_index >= len(self.pos_pairs):
@@ -114,7 +115,7 @@ def get_positive_pairs(self) -> List[InputExample]:
114115
self.pos_index += 1
115116
return pairs
116117

117-
def get_negative_pairs(self) -> List[InputExample]:
118+
def get_negative_pairs(self) -> List[Dict[str, Union[str, float]]]:
118119
pairs = []
119120
for _ in range(self.len_neg_pairs):
120121
if self.neg_index >= len(self.neg_pairs):
@@ -137,15 +138,16 @@ def __len__(self) -> int:
137138
class ContrastiveDistillationDataset(ContrastiveDataset):
138139
def __init__(
139140
self,
140-
examples: List[InputExample],
141+
sentences: List[str],
141142
cos_sim_matrix: torch.Tensor,
142143
num_iterations: Optional[None] = None,
143144
sampling_strategy: str = "oversampling",
144145
max_pairs: int = -1,
145146
) -> None:
146147
self.cos_sim_matrix = cos_sim_matrix
147148
super().__init__(
148-
examples,
149+
sentences,
150+
[0] * len(sentences),
149151
multilabel=False,
150152
num_iterations=num_iterations,
151153
sampling_strategy=sampling_strategy,
@@ -163,6 +165,8 @@ def __init__(
163165

164166
def generate_pairs(self) -> None:
165167
for (text_one, id_one), (text_two, id_two) in shuffle_combinations(self.sentence_labels):
166-
self.pos_pairs.append(InputExample(texts=[text_one, text_two], label=self.cos_sim_matrix[id_one][id_two]))
168+
self.pos_pairs.append(
169+
{"sentence_1": text_one, "sentence_2": text_two, "label": self.cos_sim_matrix[id_one][id_two]}
170+
)
167171
if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs:
168172
break

src/setfit/span/trainer.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,36 +78,35 @@ def __init__(
7878
model.aspect_model, model.polarity_model, eval_dataset
7979
)
8080

81+
# Set a better default value for the metric for best model for the aspect and polarity models
82+
aspect_args = args if args is not None else TrainingArguments()
83+
polarity_args = (polarity_args or args or TrainingArguments()).copy()
84+
if aspect_args.metric_for_best_model == "embedding_loss":
85+
aspect_args.metric_for_best_model = "aspect_embedding_loss"
86+
if polarity_args.metric_for_best_model == "embedding_loss":
87+
polarity_args.metric_for_best_model = "polarity_embedding_loss"
88+
8189
self.aspect_trainer = Trainer(
8290
model.aspect_model,
83-
args=args,
91+
args=aspect_args,
8492
train_dataset=aspect_train_dataset,
8593
eval_dataset=aspect_eval_dataset,
8694
metric=metric,
8795
metric_kwargs=metric_kwargs,
8896
callbacks=callbacks,
8997
)
90-
self.aspect_trainer._set_logs_mapper(
91-
{
92-
"eval_embedding_loss": "eval_aspect_embedding_loss",
93-
"embedding_loss": "aspect_embedding_loss",
94-
}
95-
)
98+
self.aspect_trainer._set_logs_prefix("aspect_embedding")
99+
96100
self.polarity_trainer = Trainer(
97101
model.polarity_model,
98-
args=polarity_args or args,
102+
args=polarity_args,
99103
train_dataset=polarity_train_dataset,
100104
eval_dataset=polarity_eval_dataset,
101105
metric=metric,
102106
metric_kwargs=metric_kwargs,
103107
callbacks=callbacks,
104108
)
105-
self.polarity_trainer._set_logs_mapper(
106-
{
107-
"eval_embedding_loss": "eval_polarity_embedding_loss",
108-
"embedding_loss": "polarity_embedding_loss",
109-
}
110-
)
109+
self.polarity_trainer._set_logs_prefix("polarity_embedding")
111110

112111
def preprocess_dataset(
113112
self, aspect_model: AspectModel, polarity_model: PolarityModel, dataset: Dataset

0 commit comments

Comments
 (0)