1
1
from open_flamingo .src .vlm import VLM
2
2
import torch
3
- from torch import Tensor
4
- from torch .nn import CrossEntropyLoss
5
3
6
- SUPPORTED_LOSSES = ["next_token_prediction" , "next_token_prediction_with_z_loss" ]
4
+ SUPPORTED_LOSSES = ["next_token_prediction" ]
7
5
8
6
9
7
def get_loss_fn (loss_name ):
10
8
if loss_name == "next_token_prediction" :
11
9
return NextTokenPrediction ()
12
- elif loss_name == "next_token_prediction_with_z_loss" :
13
- return NextTokenPredictionWithZLoss ()
14
10
else :
15
11
raise ValueError (
16
12
f"Loss { loss_name } not supported. Supported losses: { SUPPORTED_LOSSES } "
@@ -47,10 +43,10 @@ def __call__(
47
43
raise NotImplementedError
48
44
49
45
50
- class NextTokenPredictionWithZLoss (Loss ):
46
+ class NextTokenPrediction (Loss ):
51
47
@property
52
48
def name (self ):
53
- return "next_token_prediction_with_z_loss "
49
+ return "next_token_prediction "
54
50
55
51
def __call__ (
56
52
self ,
@@ -60,7 +56,6 @@ def __call__(
60
56
input_ids : torch .Tensor ,
61
57
attention_mask : torch .Tensor ,
62
58
autocast : callable ,
63
- z_loss_eps : float = 1e-4 ,
64
59
):
65
60
# set up labels; language model is expected to handle shifting
66
61
labels = input_ids .clone ()
@@ -74,55 +69,15 @@ def __call__(
74
69
75
70
# call forward
76
71
with autocast ():
77
- logits = model (
72
+ loss = model (
78
73
vision_x = images ,
79
74
lang_x = input_ids ,
80
75
attention_mask = attention_mask ,
81
76
labels = labels ,
82
- )[1 ]
83
-
84
- logits = logits .float ()
85
-
86
- # Shift so that tokens < n predict n
87
- shift_logits = logits [..., :- 1 , :].contiguous ()
88
- shift_labels = labels [..., 1 :].contiguous ()
89
- # Flatten the tokens
90
- loss_fct = CrossEntropyLossWithZLoss (eps = z_loss_eps )
91
- shift_logits = shift_logits .view (- 1 , unwrap_model (model ).lang_model .config .vocab_size )
92
- shift_labels = shift_labels .view (- 1 )
93
- # Enable model parallelism
94
- shift_labels = shift_labels .to (shift_logits .device )
95
- loss = loss_fct (shift_logits , shift_labels )
96
-
77
+ )[0 ]
97
78
return loss
98
79
99
80
100
- class NextTokenPrediction (NextTokenPredictionWithZLoss ):
101
- # same as NextTokenPredictionWithZLoss, but with z_loss_eps = 0
102
- @property
103
- def name (self ):
104
- return "next_token_prediction"
105
-
106
- def __call__ (
107
- self ,
108
- model : VLM ,
109
- tokenizer ,
110
- images : torch .Tensor ,
111
- input_ids : torch .Tensor ,
112
- attention_mask : torch .Tensor ,
113
- autocast : callable ,
114
- ):
115
- return super ().__call__ (
116
- model = model ,
117
- tokenizer = tokenizer ,
118
- images = images ,
119
- input_ids = input_ids ,
120
- attention_mask = attention_mask ,
121
- autocast = autocast ,
122
- z_loss_eps = 0 ,
123
- )
124
-
125
-
126
81
def unwrap_model (model ):
127
82
"""
128
83
Unwrap a model from a DataParallel or DistributedDataParallel wrapper.
@@ -132,30 +87,4 @@ def unwrap_model(model):
132
87
):
133
88
return model .module
134
89
else :
135
- return model
136
-
137
-
138
- # From OpenLM (https://github.com/mlfoundations/open_lm/blob/main/open_lm/losses.py)
139
- class CrossEntropyLossWithZLoss (CrossEntropyLoss ):
140
- def __init__ (
141
- self ,
142
- eps : float = 1e-4 ,
143
- weight : Tensor = None ,
144
- size_average = None ,
145
- ignore_index : int = - 100 ,
146
- reduce = None ,
147
- reduction : str = "mean" ,
148
- label_smoothing : float = 0 ,
149
- ) -> None :
150
- super ().__init__ (
151
- weight , size_average , ignore_index , reduce , reduction , label_smoothing
152
- )
153
- self .eps = eps
154
-
155
- def forward (self , input : Tensor , target : Tensor ) -> Tensor :
156
- if self .eps == 0 :
157
- return super ().forward (input , target )
158
-
159
- return super ().forward (input , target ) + self .eps * torch .square (
160
- torch .logsumexp (input , dim = - 1 ).mean ()
161
- )
90
+ return model
0 commit comments