-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathtmp.py
71 lines (49 loc) · 3.31 KB
/
tmp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
from torch.utils.data import DataLoader
from model.encoder import Encoder
from model.decoder import Decoder
from datautil.waymo_dataset import WaymoDataset, waymo_collate_fn
def _to_device(obj, device):
    """Return ``obj`` with every tensor moved to ``device``.

    ``obj`` may be a tensor or a (possibly nested) tuple/list of tensors, as in
    the collated batch below whose 4th entry is itself a tuple of masks.
    Containers are rebuilt as plain tuples so the result can be unpacked
    positionally in the same way as the original batch.
    """
    if isinstance(obj, (tuple, list)):
        return tuple(_to_device(item, device) for item in obj)
    return obj.to(device)


dataset = WaymoDataset('./data/tfrecords', './data/idxs')
# Pass the collate function directly: a `lambda x: waymo_collate_fn(x)` wrapper
# adds nothing, and lambdas cannot be pickled, which breaks
# DataLoader(num_workers > 0) under the "spawn"/"forkserver" start methods.
dataloader = DataLoader(dataset, batch_size=1, collate_fn=waymo_collate_fn)
data0 = next(iter(dataloader))  # a single batch for this smoke test
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Move the whole batch to `device` in one pass, then unpack it once. This avoids
# two parallel 10-name unpack lists that had to be kept in sync by hand.
(states_batch, agents_batch_mask, states_padding_mask_batch,
 (states_hidden_mask_BP, states_hidden_mask_CBP, states_hidden_mask_GDP),
 roadgraph_feat_batch, roadgraph_valid_batch,
 traffic_light_feat_batch, traffic_light_valid_batch,
 agent_rg_mask, agent_traffic_mask) = _to_device(data0, device)

# Input feature sizes are read off the last axis of the corresponding tensors.
encoder = Encoder(device, in_feat_dim=states_batch.shape[-1],
                  in_dynamic_rg_dim=traffic_light_feat_batch.shape[-1],
                  in_static_rg_dim=roadgraph_feat_batch.shape[-1])
encoder = encoder.to(device)
decoder = Decoder(device)
decoder = decoder.to(device)

# TODO : randomly select hidden mask
states_hidden_mask_batch = states_hidden_mask_BP

# Keep only agents with at least one timestep that is set in the padding mask
# and NOT set in the hidden mask (presumably "real and visible" -- confirm the
# mask polarity against waymo_collate_fn). Assumes the masks are laid out
# (agent, time) so that dim=-1 reduces over time -- TODO confirm.
# NOTE(review): `~` is a logical NOT only for bool tensors; on integer masks it
# is a bitwise NOT, which would make this filter wrong.
no_nonpad_mask = torch.sum(states_padding_mask_batch * ~states_hidden_mask_batch, dim=-1) != 0
states_batch = states_batch[no_nonpad_mask]
# agent-to-agent mask: apply the same agent selection to rows, then columns
agents_batch_mask = agents_batch_mask[no_nonpad_mask][:, no_nonpad_mask]
states_padding_mask_batch = states_padding_mask_batch[no_nonpad_mask]
states_hidden_mask_batch = states_hidden_mask_batch[no_nonpad_mask]
agent_rg_mask = agent_rg_mask[no_nonpad_mask]
agent_traffic_mask = agent_traffic_mask[no_nonpad_mask]

encodings = encoder(states_batch, agents_batch_mask, states_padding_mask_batch, states_hidden_mask_batch,
                    roadgraph_feat_batch, roadgraph_valid_batch, traffic_light_feat_batch, traffic_light_valid_batch,
                    agent_rg_mask, agent_traffic_mask)
print(encodings.shape)
decoding = decoder(encodings, agents_batch_mask, states_padding_mask_batch,
                   states_hidden_mask_batch)
print(decoding.shape)

# Supervise only the steps set in both the padding mask and the hidden mask.
to_predict_mask = states_padding_mask_batch * states_hidden_mask_batch
gt = states_batch[:, :, :6][to_predict_mask]  # 6 channel output : x, y, bbox_yaw, velocity_x, velocity_y, vel_yaw
# NOTE(review): the permute presumably moves the decoder output's agent and time
# axes to the front so they line up with to_predict_mask -- confirm against
# Decoder's output layout.
prediction = decoding.permute(1, 2, 0, 3)[to_predict_mask]
def some_loss_function(*args):
    """Placeholder loss used until a real training objective exists.

    Accepts any number of positional arguments (for example ground truth and
    prediction) and ignores all of them.

    Returns:
        The integer ``0``, regardless of the inputs.
    """
    del args  # intentionally unused in this stub
    return 0
# Score the masked predictions against the masked ground truth. As written,
# some_loss_function is a stub that returns the constant integer 0, so `loss`
# is not yet a tensor and cannot be back-propagated.
loss = some_loss_function(gt, prediction)
# TODO: training code