Skip to content

Commit 0b1c926

Browse files
committed
remove z-loss mess
1 parent feba465 commit 0b1c926

File tree

1 file changed: +6 -77 lines changed

1 file changed

+6
-77
lines changed

open_flamingo/train/losses.py

Lines changed: 6 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,12 @@
11
from open_flamingo.src.vlm import VLM
22
import torch
3-
from torch import Tensor
4-
from torch.nn import CrossEntropyLoss
53

6-
SUPPORTED_LOSSES = ["next_token_prediction", "next_token_prediction_with_z_loss"]
4+
SUPPORTED_LOSSES = ["next_token_prediction"]
75

86

97
def get_loss_fn(loss_name):
108
if loss_name == "next_token_prediction":
119
return NextTokenPrediction()
12-
elif loss_name == "next_token_prediction_with_z_loss":
13-
return NextTokenPredictionWithZLoss()
1410
else:
1511
raise ValueError(
1612
f"Loss {loss_name} not supported. Supported losses: {SUPPORTED_LOSSES}"
@@ -47,10 +43,10 @@ def __call__(
4743
raise NotImplementedError
4844

4945

50-
class NextTokenPredictionWithZLoss(Loss):
46+
class NextTokenPrediction(Loss):
5147
@property
5248
def name(self):
53-
return "next_token_prediction_with_z_loss"
49+
return "next_token_prediction"
5450

5551
def __call__(
5652
self,
@@ -60,7 +56,6 @@ def __call__(
6056
input_ids: torch.Tensor,
6157
attention_mask: torch.Tensor,
6258
autocast: callable,
63-
z_loss_eps: float = 1e-4,
6459
):
6560
# set up labels; language model is expected to handle shifting
6661
labels = input_ids.clone()
@@ -74,55 +69,15 @@ def __call__(
7469

7570
# call forward
7671
with autocast():
77-
logits = model(
72+
loss = model(
7873
vision_x=images,
7974
lang_x=input_ids,
8075
attention_mask=attention_mask,
8176
labels=labels,
82-
)[1]
83-
84-
logits = logits.float()
85-
86-
# Shift so that tokens < n predict n
87-
shift_logits = logits[..., :-1, :].contiguous()
88-
shift_labels = labels[..., 1:].contiguous()
89-
# Flatten the tokens
90-
loss_fct = CrossEntropyLossWithZLoss(eps=z_loss_eps)
91-
shift_logits = shift_logits.view(-1, unwrap_model(model).lang_model.config.vocab_size)
92-
shift_labels = shift_labels.view(-1)
93-
# Enable model parallelism
94-
shift_labels = shift_labels.to(shift_logits.device)
95-
loss = loss_fct(shift_logits, shift_labels)
96-
77+
)[0]
9778
return loss
9879

9980

100-
class NextTokenPrediction(NextTokenPredictionWithZLoss):
101-
# same as NextTokenPredictionWithZLoss, but with z_loss_eps = 0
102-
@property
103-
def name(self):
104-
return "next_token_prediction"
105-
106-
def __call__(
107-
self,
108-
model: VLM,
109-
tokenizer,
110-
images: torch.Tensor,
111-
input_ids: torch.Tensor,
112-
attention_mask: torch.Tensor,
113-
autocast: callable,
114-
):
115-
return super().__call__(
116-
model=model,
117-
tokenizer=tokenizer,
118-
images=images,
119-
input_ids=input_ids,
120-
attention_mask=attention_mask,
121-
autocast=autocast,
122-
z_loss_eps=0,
123-
)
124-
125-
12681
def unwrap_model(model):
12782
"""
12883
Unwrap a model from a DataParallel or DistributedDataParallel wrapper.
@@ -132,30 +87,4 @@ def unwrap_model(model):
13287
):
13388
return model.module
13489
else:
135-
return model
136-
137-
138-
# From OpenLM (https://github.com/mlfoundations/open_lm/blob/main/open_lm/losses.py)
139-
class CrossEntropyLossWithZLoss(CrossEntropyLoss):
140-
def __init__(
141-
self,
142-
eps: float = 1e-4,
143-
weight: Tensor = None,
144-
size_average=None,
145-
ignore_index: int = -100,
146-
reduce=None,
147-
reduction: str = "mean",
148-
label_smoothing: float = 0,
149-
) -> None:
150-
super().__init__(
151-
weight, size_average, ignore_index, reduce, reduction, label_smoothing
152-
)
153-
self.eps = eps
154-
155-
def forward(self, input: Tensor, target: Tensor) -> Tensor:
156-
if self.eps == 0:
157-
return super().forward(input, target)
158-
159-
return super().forward(input, target) + self.eps * torch.square(
160-
torch.logsumexp(input, dim=-1).mean()
161-
)
90+
return model

0 commit comments

Comments
 (0)