Skip to content

Commit

Permalink
DeformConv (test), GroupNormalization (Dynamo)
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Apr 20, 2024
1 parent ce12bd6 commit c9e1e7b
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 2 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@ saved_model/
node_modules/

.act.secrets
.act.env
.act.env

*.sarif
37 changes: 37 additions & 0 deletions make_test_op/make_DeformConv2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#! /usr/bin/env python

import torch
import torch.nn as nn
import numpy as np
import onnx
from onnxsim import simplify
import numpy as np
np.random.seed(0)
from torchvision.ops import deform_conv2d

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x):
kh, kw = 3, 3
weight = torch.rand(5, 3, kh, kw)
offset = torch.rand(4, 2 * kh * kw, 8, 8)
mask = torch.rand(4, kh * kw, 8, 8)
return deform_conv2d(input, offset, weight, mask=mask)

if __name__ == "__main__":
OPSET=19
MODEL = f'DeformConv'
model = Model()
onnx_file = f"{MODEL}_{OPSET}.onnx"
x = torch.randn(4, 3, 10, 10)
onnx_program = torch.onnx.dynamo_export(model, x)
onnx_program.save(onnx_file)

# model_onnx1 = onnx.load(onnx_file)
# model_onnx1 = onnx.shape_inference.infer_shapes(model_onnx1)
# onnx.save(model_onnx1, onnx_file)
# model_onnx2 = onnx.load(onnx_file)
# model_simp, check = simplify(model_onnx2)
# onnx.save(model_simp, onnx_file)
6 changes: 5 additions & 1 deletion make_test_op/make_GroupNormalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,8 @@ def forward(self, x):
onnx.save(model_onnx1, onnx_file)
model_onnx2 = onnx.load(onnx_file)
model_simp, check = simplify(model_onnx2)
onnx.save(model_simp, onnx_file)
onnx.save(model_simp, onnx_file)

onnx_file = f"{MODEL}_{OPSET}_dynamo.onnx"
onnx_program = torch.onnx.dynamo_export(model, x)
onnx_program.save(onnx_file)

0 comments on commit c9e1e7b

Please sign in to comment.