
Commit ede702b

Merge pull request advimman#63 from Sanster/add_torchscript_convert_script
add torchscript convert script
2 parents c77bcae + c390bd3 commit ede702b

File tree

1 file changed: +75 −0 lines changed


bin/to_jit.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
import os
from pathlib import Path

import hydra
import torch
import yaml
from omegaconf import OmegaConf
from torch import nn

from saicinpainting.training.trainers import load_checkpoint
from saicinpainting.utils import register_debug_signal_handlers


class JITWrapper(nn.Module):
    """Wraps the inpainting model so it can be traced with plain tensor inputs."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, image, mask):
        # The underlying model expects a batch dict and returns a dict of outputs
        batch = {
            "image": image,
            "mask": mask,
        }
        out = self.model(batch)
        return out["inpainted"]


@hydra.main(config_path="../configs/prediction", config_name="default.yaml")
def main(predict_config: OmegaConf):
    register_debug_signal_handlers()  # kill -10 <pid> will result in a traceback dumped into the log

    # Load the training config saved next to the checkpoint
    train_config_path = os.path.join(predict_config.model.path, "config.yaml")
    with open(train_config_path, "r") as f:
        train_config = OmegaConf.create(yaml.safe_load(f))

    train_config.training_model.predict_only = True
    train_config.visualizer.kind = "noop"

    checkpoint_path = os.path.join(
        predict_config.model.path, "models", predict_config.model.checkpoint
    )
    model = load_checkpoint(
        train_config, checkpoint_path, strict=False, map_location="cpu"
    )
    model.eval()
    jit_model_wrapper = JITWrapper(model)

    # Reference forward pass on CPU with dummy inputs
    image = torch.rand(1, 3, 120, 120)
    mask = torch.rand(1, 1, 120, 120)
    output = jit_model_wrapper(image, mask)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Move the wrapper and the example inputs onto the same device before tracing
    jit_model_wrapper = jit_model_wrapper.to(device)
    image = image.to(device)
    mask = mask.to(device)
    traced_model = torch.jit.trace(jit_model_wrapper, (image, mask), strict=False).to(device)

    save_path = Path(predict_config.save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Saving big-lama.pt model to {save_path}")
    traced_model.save(str(save_path))

    print("Checking jit model output...")
    jit_model = torch.jit.load(str(save_path))
    jit_output = jit_model(image, mask)
    # Compare against the reference output, bringing both tensors onto the same device first
    diff = (output.to(device) - jit_output).abs().sum()
    print(f"diff: {diff}")


if __name__ == "__main__":
    main()
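
Since a traced module is self-contained, the exported file can be loaded without the saicinpainting code base. Below is a minimal sketch of how the result of this script might be consumed downstream; the file name big-lama.pt, the 120x120 input size, and the input value ranges are assumptions based on the dummy inputs used for tracing above, not part of this commit.

import torch

# Hypothetical downstream usage of the exported TorchScript model (not part of this commit).
# Assumes the trace was saved as big-lama.pt and expects a 3-channel image plus a
# 1-channel mask with matching spatial size, as in the dummy inputs used for tracing.
jit_model = torch.jit.load("big-lama.pt", map_location="cpu")
jit_model.eval()

image = torch.rand(1, 3, 120, 120)  # placeholder image batch
mask = torch.rand(1, 1, 120, 120)   # placeholder mask batch

with torch.no_grad():
    inpainted = jit_model(image, mask)  # tensor returned by JITWrapper.forward

print(inpainted.shape)

Because JITWrapper returns a single tensor, the traced module keeps the plain (image, mask) call signature, so inference code should not need the original batch-dict interface.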

0 commit comments
