-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_model.py
133 lines (107 loc) · 4.91 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
import time
import torch
from model_copy import SingleViewto3D
from r2n2_custom import R2N2
from pytorch3d.datasets.r2n2.utils import collate_batched_R2N2
import dataset_location
from pytorch3d.ops import sample_points_from_meshes
import losses
def get_args_parser():
    """Build the command-line option parser for single-view-to-3D training.

    The parser is created with ``add_help=False`` so that it can be supplied
    through ``parents=[...]`` to another ``ArgumentParser`` without declaring
    a second ``-h`` option.

    Returns:
        argparse.ArgumentParser: parser declaring the architecture,
        optimisation, data-loading, loss-weight, checkpoint and device
        options.
    """
    parser = argparse.ArgumentParser('Singleto3D', add_help=False)
    # Model parameters
    parser.add_argument('--arch', default='resnet18', type=str)
    # Numeric options need numeric converters: argparse applies ``type`` only
    # to values given on the command line, so with ``type=str`` the defaults
    # stayed numeric but any user-supplied value arrived as a string.
    parser.add_argument('--lr', default=4e-4, type=float)
    parser.add_argument('--max_iter', default=10000, type=int)
    parser.add_argument('--log_freq', default=1000, type=int)
    parser.add_argument('--batch_size', default=2, type=int)
    parser.add_argument('--num_workers', default=0, type=int)
    parser.add_argument('--type', default='vox', choices=['vox', 'point', 'mesh'], type=str)
    parser.add_argument('--n_points', default=15000, type=int)
    parser.add_argument('--w_chamfer', default=1.0, type=float)
    parser.add_argument('--w_smooth', default=0.1, type=float)
    parser.add_argument('--save_freq', default=10000, type=int)
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--load_feat', action='store_true')
    parser.add_argument('--load_checkpoint', action='store_true')
    return parser
def preprocess(feed_dict, args):
    """Split a collated batch into the network input and the 3D target.

    Args:
        feed_dict: batch mapping. ``'voxels'`` is read when ``args.type`` is
            ``'vox'`` and ``'mesh'`` when it is ``'point'`` or ``'mesh'``;
            ``'feats'`` is read when ``args.load_feat`` is set, otherwise
            ``'images'``.
        args: namespace providing ``type``, ``n_points``, ``load_feat`` and
            ``device``.

    Returns:
        tuple: ``(model_input, ground_truth_3d)``, both moved to
        ``args.device``.

    Raises:
        ValueError: if ``args.type`` is not ``'vox'``, ``'point'`` or
            ``'mesh'``.
    """
    if args.type == "vox":
        ground_truth_3d = feed_dict['voxels'].float()
    elif args.type == "point":
        # Point supervision is obtained by sampling the ground-truth mesh.
        ground_truth_3d = sample_points_from_meshes(feed_dict['mesh'], args.n_points)
    elif args.type == "mesh":
        ground_truth_3d = feed_dict["mesh"]
    else:
        # Fail here with a clear message instead of an UnboundLocalError at
        # the return statement below.
        raise ValueError(
            f"Unsupported type {args.type!r}; expected 'vox', 'point' or 'mesh'"
        )

    if args.load_feat:
        model_input = torch.stack(feed_dict['feats'])
    else:
        # squeeze(1) drops dimension 1 when it has size one.
        # NOTE(review): presumably the per-object view axis - confirm against
        # the collate function's output.
        model_input = feed_dict['images'].squeeze(1)

    return model_input.to(args.device), ground_truth_3d.to(args.device)
def calculate_loss(predictions, ground_truth, args):
    """Compute the training loss for the selected 3D representation.

    Args:
        predictions: model output for the current batch.
        ground_truth: supervision target for the current batch.
        args: namespace providing ``type`` and, for the mesh case,
            ``n_points``, ``w_chamfer`` and ``w_smooth``.

    Returns:
        The value returned by ``losses.voxel_loss`` (``'vox'``),
        ``losses.chamfer_loss`` (``'point'``), or, for ``'mesh'``, the
        ``w_chamfer``-weighted chamfer loss between points sampled from both
        meshes plus the ``w_smooth``-weighted ``losses.smoothness_loss`` of
        the prediction.

    Raises:
        ValueError: if ``args.type`` is not ``'vox'``, ``'point'`` or
            ``'mesh'``.
    """
    if args.type == 'vox':
        return losses.voxel_loss(predictions, ground_truth)

    if args.type == 'point':
        return losses.chamfer_loss(predictions, ground_truth)

    if args.type == 'mesh':
        # Meshes are compared through point sets sampled from their surfaces.
        sample_trg = sample_points_from_meshes(ground_truth, args.n_points)
        sample_pred = sample_points_from_meshes(predictions, args.n_points)
        loss_reg = losses.chamfer_loss(sample_pred, sample_trg)
        loss_smooth = losses.smoothness_loss(predictions)
        return args.w_chamfer * loss_reg + args.w_smooth * loss_smooth

    # Previously an unknown type fell through to ``return loss`` and failed
    # with an UnboundLocalError; report the real problem instead.
    raise ValueError(
        f"Unsupported type {args.type!r}; expected 'vox', 'point' or 'mesh'"
    )
def train_model(args):
    """Train a ``SingleViewto3D`` model on the R2N2 ``"train"`` split.

    Builds the dataset and data loader, creates the model and an Adam
    optimiser, optionally restores both from
    ``checkpoint_{args.type}_Aditi.pth``, then runs optimisation steps until
    ``args.max_iter`` is reached. A progress line is printed on every
    iteration; ``args.log_freq`` is not consulted here.

    Args:
        args: namespace providing ``load_feat``, ``batch_size``,
            ``num_workers``, ``device``, ``lr``, ``type``,
            ``load_checkpoint``, ``max_iter`` and ``save_freq``, plus
            whatever ``SingleViewto3D``, ``preprocess`` and
            ``calculate_loss`` read from it.

    Returns:
        None. Checkpoints are written to the current working directory.
    """
    r2n2_dataset = R2N2(
        "train",
        dataset_location.SHAPENET_PATH,
        dataset_location.R2N2_PATH,
        dataset_location.SPLITS_PATH,
        return_voxels=True,
        return_feats=args.load_feat,
    )

    loader = torch.utils.data.DataLoader(
        r2n2_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=collate_batched_R2N2,
        pin_memory=True,
        drop_last=True,
    )
    train_loader = iter(loader)

    model = SingleViewto3D(args)
    model.to(args.device)
    model.train()

    # ============ preparing optimizer ... ============
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)  # to use with ViTs
    checkpoint_path = f'checkpoint_{args.type}_Aditi.pth'
    start_iter = 0
    start_time = time.time()

    if args.load_checkpoint:
        # map_location allows a checkpoint saved on one device (e.g. CUDA) to
        # be restored when training on another (e.g. CPU).
        checkpoint = torch.load(checkpoint_path, map_location=args.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_iter = checkpoint['step']
        print(f"Succesfully loaded iter {start_iter}")

    print("Starting training !")
    for step in range(start_iter, args.max_iter):
        iter_start_time = time.time()

        read_start_time = time.time()
        try:
            feed_dict = next(train_loader)
        except StopIteration:
            # restart after one epoch
            train_loader = iter(loader)
            feed_dict = next(train_loader)

        images_gt, ground_truth_3d = preprocess(feed_dict, args)
        read_time = time.time() - read_start_time

        prediction_3d = model(images_gt, args)
        loss = calculate_loss(prediction_3d, ground_truth_3d, args)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_time = time.time() - start_time
        iter_time = time.time() - iter_start_time
        loss_vis = loss.item()

        # Save after every ``save_freq`` completed steps and after the final
        # one. The old ``step % save_freq == 0`` test fired at step 0 and,
        # with max_iter == save_freq (the defaults), never again, so the
        # trained weights were never written.
        completed_steps = step + 1
        if completed_steps % args.save_freq == 0 or completed_steps == args.max_iter:
            torch.save({
                # Stored as the next step to run, so resuming does not repeat
                # an iteration that has already been applied.
                'step': completed_steps,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, checkpoint_path)

        print("[%4d/%4d]; ttime: %.0f (%.2f, %.2f); loss: %.3f" % (step, args.max_iter, total_time, read_time, iter_time, loss_vis))

    print('Done!')
if __name__ == '__main__':
    # Wrap the shared option definitions in a top-level parser; the parent
    # parser is built with add_help=False, so -h/--help is declared only here.
    args = argparse.ArgumentParser(
        'Singleto3D', parents=[get_args_parser()]
    ).parse_args()
    train_model(args)