Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EfficientPhys onnx转ncnn模型转换报错 #5452

Closed
408550969 opened this issue May 9, 2024 · 3 comments
Closed

EfficientPhys onnx转ncnn模型转换报错 #5452

408550969 opened this issue May 9, 2024 · 3 comments
Labels

Comments

@408550969
Copy link

408550969 commented May 9, 2024

error log | 日志或报错信息 | ログ

model | 模型 | モデル

  1. original model
"""EfficientPhys: Enabling Simple, Fast and Accurate Camera-Based Vitals Measurement
Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV 2023)
Xin Liu, Brial Hill, Ziheng Jiang, Shwetak Patel, Daniel McDuff
"""

import torch
import torch.nn as nn


class Attention_mask(nn.Module):
    def __init__(self):
        super(Attention_mask, self).__init__()

    def forward(self, x):
        xsum = torch.sum(x, dim=2, keepdim=True)
        xsum = torch.sum(xsum, dim=3, keepdim=True)
        xshape = tuple(x.size())
        return x / xsum * xshape[2] * xshape[3] * 0.5

    def get_config(self):
        """May be generated manually. """
        config = super(Attention_mask, self).get_config()
        return config


class TSM(nn.Module):
    def __init__(self, n_segment=10, fold_div=3):
        super(TSM, self).__init__()
        self.n_segment = n_segment
        self.fold_div = fold_div

    def forward(self, x):
        nt, c, h, w = x.size()
        n_batch = nt // self.n_segment
        x = x.view(n_batch, self.n_segment, c, h, w)
        fold = c // self.fold_div
        out = torch.zeros_like(x)
        out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
        out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
        out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shift
        return out.view(nt, c, h, w)


class EfficientPhys(nn.Module):

    def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3, dropout_rate1=0.25,
                 dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128, frame_depth=20, img_size=36, channel='raw'):
        super(EfficientPhys, self).__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.dropout_rate1 = dropout_rate1
        self.dropout_rate2 = dropout_rate2
        self.pool_size = pool_size
        self.nb_filters1 = nb_filters1
        self.nb_filters2 = nb_filters2
        self.nb_dense = nb_dense
        # TSM layers
        self.TSM_1 = TSM(n_segment=frame_depth)
        self.TSM_2 = TSM(n_segment=frame_depth)
        self.TSM_3 = TSM(n_segment=frame_depth)
        self.TSM_4 = TSM(n_segment=frame_depth)
        # Motion branch convs
        self.motion_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1),
                                  bias=True)
        self.motion_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True)
        self.motion_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1),
                                  bias=True)
        self.motion_conv4 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True)
        # Attention layers
        self.apperance_att_conv1 = nn.Conv2d(self.nb_filters1, 1, kernel_size=1, padding=(0, 0), bias=True)
        self.attn_mask_1 = Attention_mask()
        self.apperance_att_conv2 = nn.Conv2d(self.nb_filters2, 1, kernel_size=1, padding=(0, 0), bias=True)
        self.attn_mask_2 = Attention_mask()
        # Avg pooling
        self.avg_pooling_1 = nn.AvgPool2d(self.pool_size)
        self.avg_pooling_2 = nn.AvgPool2d(self.pool_size)
        self.avg_pooling_3 = nn.AvgPool2d(self.pool_size)
        # Dropout layers
        self.dropout_1 = nn.Dropout(self.dropout_rate1)
        self.dropout_2 = nn.Dropout(self.dropout_rate1)
        self.dropout_3 = nn.Dropout(self.dropout_rate1)
        self.dropout_4 = nn.Dropout(self.dropout_rate2)
        # Dense layers
        if img_size == 36:
            self.final_dense_1 = nn.Linear(3136, self.nb_dense, bias=True)
        elif img_size == 72:
            self.final_dense_1 = nn.Linear(16384, self.nb_dense, bias=True)
        elif img_size == 96:
            self.final_dense_1 = nn.Linear(30976, self.nb_dense, bias=True)
        else:
            raise Exception('Unsupported image size')
        self.final_dense_2 = nn.Linear(self.nb_dense, 1, bias=True)
        self.batch_norm = nn.BatchNorm2d(3)
        self.channel = channel

    def forward(self, inputs, params=None):

        #inputs = torch.diff(inputs, dim=0)
        inputs = self.batch_norm(inputs)

        network_input = self.TSM_1(inputs)
        d1 = torch.tanh(self.motion_conv1(network_input))
        d1 = self.TSM_2(d1)
        d2 = torch.tanh(self.motion_conv2(d1))

        g1 = torch.sigmoid(self.apperance_att_conv1(d2))
        g1 = self.attn_mask_1(g1)
        gated1 = d2 * g1

        d3 = self.avg_pooling_1(gated1)
        d4 = self.dropout_1(d3)

        d4 = self.TSM_3(d4)
        d5 = torch.tanh(self.motion_conv3(d4))
        d5 = self.TSM_4(d5)
        d6 = torch.tanh(self.motion_conv4(d5))

        g2 = torch.sigmoid(self.apperance_att_conv2(d6))
        g2 = self.attn_mask_2(g2)
        gated2 = d6 * g2

        d7 = self.avg_pooling_3(gated2)
        d8 = self.dropout_3(d7)
        d9 = d8.view(d8.size(0), -1)
        d10 = torch.tanh(self.final_dense_1(d9))
        d11 = self.dropout_4(d10)
        out = self.final_dense_2(d11)

        return out

how to reproduce | 复现步骤 | 再現方法

转onnx时无异常,在onnx转ncnn时报如下错:

Shape not supported yet!
ConstantOfShape not supported yet!
  # value 4
Shape not supported yet!
Expand not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
Shape not supported yet!
Unknown data type 0
ScatterND not supported yet!
Shape not supported yet!
Expand not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
Shape not supported yet!
Unknown data type 0
ScatterND not supported yet!
Shape not supported yet!
Expand not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
Shape not supported yet!
Unknown data type 0
ScatterND not supported yet!
Shape not supported yet!
ConstantOfShape not supported yet!
  # value 4
Shape not supported yet!
Expand not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
Shape not supported yet!
Unknown data type 0
ScatterND not supported yet!
Shape not supported yet!
Expand not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
Shape not supported yet!
Unknown data type 0
ScatterND not supported yet!
Shape not supported yet!
Expand not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
Shape not supported yet!
Unknown data type 0
ScatterND not supported yet!
Shape not supported yet!
ConstantOfShape not supported yet!
  # value 4
Shape not supported yet!
Expand not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
Shape not supported yet!
Unknown data type 0
ScatterND not supported yet!
Shape not supported yet!
Expand not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
Shape not supported yet!
Unknown data type 0
ScatterND not supported yet!
Shape not supported yet!
Expand not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
Shape not supported yet!
Unknown data type 0
ScatterND not supported yet!
Shape not supported yet!
ConstantOfShape not supported yet!
  # value 4
Shape not supported yet!
Expand not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
Shape not supported yet!
Unknown data type 0
ScatterND not supported yet!
Shape not supported yet!
Expand not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
Shape not supported yet!
Unknown data type 0
ScatterND not supported yet!
Shape not supported yet!
Expand not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Range not supported yet!
Shape not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
ConstantOfShape not supported yet!
  # value 4
Equal not supported yet!
Where not supported yet!
Expand not supported yet!
Shape not supported yet!
Unknown data type 0
ScatterND not supported yet!
Shape not supported yet!
Gather not supported yet!
  # axis=0
Unknown data type 0

直接使用pnnx也会报错:
pnnxparam = ./model.pnnx.param
pnnxbin = ./model.pnnx.bin
pnnxpy = ./model_pnnx.py
pnnxonnx = ./model.pnnx.onnx
ncnnparam = ./model.ncnn.param
ncnnbin = ./model.ncnn.bin
ncnnpy = ./model_ncnn.py
fp16 = 1
optlevel = 2
device = cpu
inputshape = [180,3,72,72]f32
inputshape2 =
customop =
moduleop =
############# pass_level0
inline module = EfficientPhys.Attention_mask
inline module = EfficientPhys.TSM
inline module = EfficientPhys.Attention_mask
inline module = EfficientPhys.TSM


############# pass_level1
############# pass_level2
############# pass_level3
############# pass_level4
############# pass_level5
############# pass_ncnn
fallback batch axis 233 for operand 0
生成的1kb的bin和param

@408550969 408550969 changed the title onnx转ncnn模型转换报错 EfficientPhys onnx转ncnn模型转换报错 May 9, 2024
@408550969 408550969 reopened this May 9, 2024
@nihui nihui added the bug label May 9, 2024
@nihui
Copy link
Member

nihui commented May 9, 2024

可复现

@nihui
Copy link
Member

nihui commented May 10, 2024

#5455

pnnx 转换修好了,但是模型中有5维的操作,这些ncnn不支持

@408550969
Copy link
Author

感谢nihui,今年的绩效就靠大佬了ヽ(=^・ω・^=)丿

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

No branches or pull requests

2 participants