// File: instanceNormalizationPlugin.h
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TRT_INSTANCE_NORMALIZATION_PLUGIN_H
#define TRT_INSTANCE_NORMALIZATION_PLUGIN_H
#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "common/plugin.h"
#include "common/serialize.hpp"
#include "instanceNormalizationPlugin/instanceNormFwd.h"
#include <cuda_fp16.h>
#include <iostream>
#include <string>
#include <vector>
// 16-bit unsigned alias declared at global scope.
// NOTE(review): the name suggests it carries raw half-precision bit patterns — confirm at use sites.
using half_type = uint16_t;
namespace nvinfer1
{
namespace plugin
{
// Instance-normalization plugin written against the TensorRT V3 plugin API.
//
// The class derives from IPluginV3 and from the three "One" capability interfaces
// (core, build, runtime), so a single object serves every capability.
// Only declarations live in this header; the definitions are in a source file that is
// not visible here. Comments below therefore describe what the signatures show, and
// anything about runtime behavior is marked as an assumption to verify.
class InstanceNormalizationV3Plugin : public IPluginV3,
public IPluginV3OneCore,
public IPluginV3OneBuild,
public IPluginV3OneRuntime
{
public:
// Construct from an epsilon value plus scale and bias given as nvinfer1::Weights.
// `relu` defaults to 0 and `alpha` defaults to 0.F.
// NOTE(review): presumably `relu` toggles a fused activation and `alpha` is its
// parameter — confirm in the implementation.
InstanceNormalizationV3Plugin(float epsilon, nvinfer1::Weights const& scale, nvinfer1::Weights const& bias,
int32_t relu = 0, float alpha = 0.F);
// Same as above, but scale and bias are supplied as host-side float vectors.
InstanceNormalizationV3Plugin(float epsilon, std::vector<float> const& scale, std::vector<float> const& bias,
int32_t relu = 0, float alpha = 0.F);
// Construct from a raw byte buffer of `serialLength` bytes.
// NOTE(review): presumably the buffer is a previously serialized plugin state; the
// expected layout is defined by the implementation, not by this header.
InstanceNormalizationV3Plugin(void const* serialData, size_t serialLength);
// No default construction: one of the parameterized constructors must be used.
InstanceNormalizationV3Plugin() = delete;
// Member-wise copy construction.
// NOTE(review): this copies the raw pointers mDeviceScale/mDeviceBias and the cuDNN
// handle/descriptor members as-is. Ownership is not visible from this header — confirm
// that a copy and its source never both release the same resources.
InstanceNormalizationV3Plugin(InstanceNormalizationV3Plugin const&) = default;
// Overrides the base-class destructor; defined out of line.
~InstanceNormalizationV3Plugin() override;
// Interface override; const and noexcept. By its name, reports the number of output tensors.
int32_t getNbOutputs() const noexcept override;
// Interface override; returns a pointer to the capability interface matching `type`.
// NOTE(review): the result for an unsupported `type` is decided in the implementation.
IPluginCapability* getCapabilityInterface(PluginCapabilityType type) noexcept override;
// Interface override with a covariant return type (the concrete plugin class).
// NOTE(review): who owns the returned object is not expressed by the signature.
InstanceNormalizationV3Plugin* clone() noexcept override;
// Interface override; returns the plugin name as a C string.
char const* getPluginName() const noexcept override;
// Interface override; returns the plugin namespace as a C string.
// NOTE(review): presumably backed by mPluginNamespace — confirm.
char const* getPluginNamespace() const noexcept override;
// Interface override; returns a workspace size in bytes (size_t) for the given
// input/output tensor descriptions.
size_t getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override;
// Interface override; receives tensor descriptors, input/output buffer arrays, a
// workspace pointer and a CUDA stream, and returns an int32_t status.
// NOTE(review): presumably 0 means success, per TensorRT convention — confirm.
int32_t enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs,
void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
// Interface override; reports whether the tensor at index `pos` of `inOut` has a
// supported type/format combination.
// NOTE(review): the previous comment here read "DynamicExt plugin supportsFormat update",
// which looks like a leftover from an earlier DynamicExt-style plugin — confirm and drop.
bool supportsFormatCombination(
int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override;
// Interface override; returns the plugin version as a C string.
char const* getPluginVersion() const noexcept override;
// Plugin-specific setter; it is not marked `override`, so it is not matched against any
// base-class virtual in this header.
void setPluginNamespace(char const* pluginNamespace) noexcept;
// Interface override; writes `nbOutputs` entries into `outputTypes` based on
// `inputTypes`, and returns an int32_t status.
int32_t getOutputDataTypes(
DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept override;
// Interface override; fills `outputs` with symbolic output dimensions, using
// `exprBuilder` to create dimension expressions, and returns an int32_t status.
int32_t getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs,
int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept override;
// Interface override; returns an IPluginV3* associated with the given resource context.
// NOTE(review): presumably returns a new clone for the execution context — confirm.
IPluginV3* attachToContext(IPluginResourceContext* context) noexcept override;
// Interface override; receives the dynamic input/output tensor descriptions and
// returns an int32_t status.
int32_t configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out,
int32_t nbOutputs) noexcept override;
// Interface override; receives the concrete input/output tensor descriptions and
// returns an int32_t status.
int32_t onShapeChange(
PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
// Interface override; returns the collection of plugin fields to be serialized.
// NOTE(review): presumably points at mFCToSerialize / mDataToSerialize — confirm.
PluginFieldCollection const* getFieldsToSerialize() noexcept override;
// Public, non-virtual, and not declared noexcept (unlike the interface overrides above).
// NOTE(review): presumably creates the cuDNN handle/descriptors and device buffers and
// returns 0 on success — confirm in the implementation.
int32_t initializeContext();
protected:
// NOTE(review): presumably the teardown counterpart of initializeContext() — confirm.
void exitContext();
private:
// Scalar parameters; all are value-initialized to zero until a constructor sets them.
float mEpsilon{};
float mAlpha{};
int32_t mRelu{};
// NOTE(review): presumably the channel count — confirm where it is assigned.
int32_t mNchan{};
// Host-side copies of the scale and bias coefficients.
std::vector<float> mHostScale;
std::vector<float> mHostBias;
// Raw float pointers, null by default.
// NOTE(review): presumably device allocations mirroring mHostScale/mHostBias — confirm.
float* mDeviceScale{nullptr};
float* mDeviceBias{nullptr};
// cuDNN handle from the plugin-internal cuDNN wrapper layer, null by default.
nvinfer1::pluginInternal::cudnnHandle_t mCudnnHandle{nullptr};
// Reference to the wrapper returned by getCudnnWrapper(), bound at construction.
// Being a reference member, it makes the implicit copy-assignment operator unavailable.
nvinfer1::pluginInternal::CudnnWrapper& mCudnnWrapper = nvinfer1::pluginInternal::getCudnnWrapper();
// cuDNN tensor descriptors, null by default.
// NOTE(review): by their names, X = input, Y = output, B = scale/bias — confirm.
nvinfer1::pluginInternal::cudnnTensorDescriptor_t mXDescriptor{nullptr};
nvinfer1::pluginInternal::cudnnTensorDescriptor_t mYDescriptor{nullptr};
nvinfer1::pluginInternal::cudnnTensorDescriptor_t mBDescriptor{nullptr};
std::string mPluginNamespace;
// Starts false.
// NOTE(review): presumably set once initializeContext() has succeeded — confirm.
bool mInitialized{false};
// Starts at -1.
// NOTE(review): presumably -1 means "not queried yet" — confirm.
int32_t mCudaDriverVersion{-1};
// Storage for the serialization field list.
// NOTE(review): presumably mFCToSerialize refers into mDataToSerialize — confirm.
std::vector<nvinfer1::PluginField> mDataToSerialize;
nvinfer1::PluginFieldCollection mFCToSerialize;
// NDHWC implementation
// NOTE(review): presumably the context used for the NDHWC code path — confirm.
instance_norm_impl::InstanceNormFwdContext mContext;
};
// Creator (factory) class for the instance-normalization V3 plugin, implementing
// nvinfer1::IPluginCreatorV3One. Only declarations are visible in this header; the
// definitions, including those of the static data members, live in a source file.
class InstanceNormalizationV3PluginCreator : public nvinfer1::IPluginCreatorV3One
{
public:
// Defined out of line.
// NOTE(review): presumably populates mPluginAttributes and mFC — confirm.
InstanceNormalizationV3PluginCreator();
~InstanceNormalizationV3PluginCreator() override = default;
// Interface override; returns the plugin name as a C string.
char const* getPluginName() const noexcept override;
// Interface override; returns the plugin version as a C string.
char const* getPluginVersion() const noexcept override;
// Interface override; returns the collection describing the fields this creator accepts.
// NOTE(review): presumably a pointer to the static mFC — confirm.
PluginFieldCollection const* getFieldNames() noexcept override;
// Interface override; builds a plugin from the field collection `fc` for the given
// TensorRT `phase` and returns it as an IPluginV3*.
// NOTE(review): ownership of the returned pointer, and the result on invalid input
// (presumably nullptr), are not visible from this header.
IPluginV3* createPlugin(char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept override;
// Creator-specific setter; it is not marked `override`, so it is not matched against any
// base-class virtual in this header.
void setPluginNamespace(char const* libNamespace) noexcept;
// Interface override; returns the namespace as a C string.
// NOTE(review): presumably backed by mNamespace — confirm.
char const* getPluginNamespace() const noexcept override;
private:
// Static members shared by every creator instance; only declared here, so they must be
// defined in a source file.
static PluginFieldCollection mFC;
static std::vector<PluginField> mPluginAttributes;
std::string mNamespace;
};
} // namespace plugin
} // namespace nvinfer1
#endif // TRT_INSTANCE_NORMALIZATION_PLUGIN_H