reward_model.py
from dataclasses import dataclass
from typing import Literal, Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForSequenceClassification
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXConfig, GPTNeoXModel, GPTNeoXPreTrainedModel
from transformers.utils import ModelOutput

# Thanks to Open-Assistant for their helpful code:
# https://github.com/LAION-AI/Open-Assistant/blob/main/model/model_training/models/reward_model.py


class GPTNeoXRewardModelConfig(GPTNeoXConfig):
    """GPT-NeoX config extended with the pooling strategy used by the reward head."""

    model_type = "gpt_neox_reward_model"
    pooling: Literal["mean", "last"]

    def __init__(
        self,
        pooling: Literal["mean", "last"] = "last",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.pooling = pooling or "last"


@dataclass
class GPTNeoXRewardModelOutput(ModelOutput):
    """
    Reward model output.

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, 1)`):
            Reward score
    """

    logits: torch.FloatTensor = None


class GPTNeoXRewardModel(GPTNeoXPreTrainedModel):
    config_class = GPTNeoXRewardModelConfig

    def __init__(self, config):
        # Allow construction from a plain GPTNeoXConfig by upgrading it to the
        # reward-model config (with the default pooling strategy).
        if type(config) is GPTNeoXConfig:
            config = GPTNeoXRewardModelConfig.from_dict(config.to_dict())
        super().__init__(config)

        # GPT-NeoX backbone plus a single linear head that maps the pooled
        # hidden state to a scalar reward.
        self.gpt_neox = GPTNeoXModel(config)
        self.out_proj = nn.Linear(config.hidden_size, 1)
        self.pooling = config.pooling

    def forward(
        self,
        input_ids,
        attention_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> GPTNeoXRewardModelOutput:
        outputs = self.gpt_neox(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        if self.pooling == "mean":
            if attention_mask is None:
                pooled = hidden_states.mean(dim=1)
            else:
                # Average only over non-padded positions; attention_mask is
                # (batch, seq_len), so broadcast it over the hidden dimension.
                mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
                pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
        elif self.pooling == "last":
            if attention_mask is None:
                pooled = hidden_states[:, -1]
            else:
                # Index of the last non-padded token per sequence (assumes right padding).
                last_idx = attention_mask.cumsum(dim=1).argmax(dim=1)
                pooled = hidden_states.gather(
                    1, last_idx.view(-1, 1, 1).expand(-1, 1, hidden_states.size(-1))
                ).squeeze(1)
        else:
            raise ValueError(f"Unknown pooling method: {self.pooling}")

        logits = self.out_proj(pooled)

        if not return_dict:
            return (logits,) + outputs[1:]
        return GPTNeoXRewardModelOutput(logits=logits)


# Register the custom config/model so they can be loaded through the Auto classes.
AutoConfig.register("gpt_neox_reward_model", GPTNeoXRewardModelConfig)
AutoModelForSequenceClassification.register(GPTNeoXRewardModelConfig, GPTNeoXRewardModel)
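

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it assumes a GPT-NeoX
    # reward-model checkpoint saved with this config; the checkpoint path and the
    # example text below are hypothetical placeholders.
    from transformers import AutoTokenizer

    checkpoint = "path/to/reward-model-checkpoint"  # hypothetical path
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    model.eval()

    # Score a single prompt/response string; a higher logit means a higher reward.
    inputs = tokenizer("Example prompt followed by an example response.", return_tensors="pt")
    with torch.no_grad():
        reward = model(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        ).logits
    print(reward.shape)  # (1, 1): one scalar reward per sequence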