Skip to content

Commit 344b94c

Browse files
committed
organize
1 parent 7271301 commit 344b94c

19 files changed

+86
-11
lines changed

curricula/__init__.py

Whitespace-only changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

curricula/mentornet.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
class MentorNet(nn.Module):
5+
def __init__(self, num_labels, num_epochs,
6+
decay = 0.9, percentile = 0.7,
7+
label_embed_dim = 2, epoch_embed_dim = 5,
8+
fc_dim = 20, lstm_dim = 1):
9+
super().__init__()
10+
11+
self.register_buffer('avg', None)
12+
self.decay = decay
13+
self.percentile = percentile
14+
15+
self.lstm = nn.LSTM(2, lstm_dim, batch_first = True, bidirectional = True)
16+
17+
self.label_embed = nn.Embedding(num_labels, label_embed_dim)
18+
self.epoch_embed = nn.Embedding(num_epochs, epoch_embed_dim)
19+
self.epoch_embed.weight.requires_grad = False
20+
21+
self.fc = nn.Sequential(
22+
nn.Linear(2*lstm_dim + label_embed_dim + epoch_embed_dim, fc_dim),
23+
nn.Tanh(),
24+
nn.Linear(fc_dim, 1),
25+
nn.Sigmoid()
26+
)
27+
28+
def forward(self, loss, labels, epoch, *args):
29+
with torch.no_grad():
30+
if self.avg is None:
31+
self.avg = torch.quantile(loss, self.percentile)
32+
else:
33+
self.avg = self.decay * self.avg + (1 - self.decay) * torch.quantile(loss, self.percentile)
34+
35+
lossdiff = loss - self.avg
36+
37+
lstm_input = torch.stack([loss, lossdiff], 1).unsqueeze(1)
38+
39+
_, (h, _) = self.lstm(lstm_input)
40+
h = torch.cat([d for d in h], -1)
41+
42+
epochs = torch.ones_like(loss).long() * epoch
43+
epoch_embed = self.epoch_embed(epochs)
44+
label_embed = self.label_embed(labels)
45+
46+
feats = torch.cat([h, epoch_embed, label_embed], 1)
47+
48+
confs = self.fc(feats)
49+
50+
return confs
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)