/*!
 * Copyright (c) 2018 by Contributors
 * \file amsoftmax-inl.h
 * \brief AmSoftmax from <TODO>
 * \author Jia Guo
 */
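/*
 * Background (editor's note, not from the original source): this operator
 * appears to implement the additive-margin ("AM-Softmax") logit rule.  With
 * L2-normalized features x_i and class weights w_j, the usual formulation is
 *
 *   out[i][j]   = s * cos(theta_ij)             for j != y_i
 *   out[i][y_i] = s * (cos(theta_i,y_i) - m)    for the target class y_i
 *
 * where m is `margin` and s is the scale `s` declared below.  The actual
 * kernels (AmSoftmaxForward / AmSoftmaxBackward) live in the .cc/.cu files,
 * so this is an assumption about their contract, not a statement of it.
 */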
#ifndef MXNET_OPERATOR_AMSOFTMAX_INL_H_
#define MXNET_OPERATOR_AMSOFTMAX_INL_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <cmath>
#include <iostream>
#include <map>
#include <vector>
#include <string>
#include "./operator_common.h"

namespace mxnet {
namespace op {

namespace amsoftmax_enum {
enum AmSoftmaxOpInputs {kData, kWeight, kLabel};
enum AmSoftmaxOpOutputs {kOut, kOOut};
enum AmSoftmaxResource {kTempSpace};
}  // namespace amsoftmax_enum
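// The operator takes three inputs (features, the class-weight matrix, and the
// integer labels) and produces two outputs: the (batch, num_hidden) logits and
// an auxiliary (batch, 1) buffer.  Only the first output is exposed to the
// user (NumVisibleOutputs() == 1 below); the second is read again in Backward.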

struct AmSoftmaxParam : public dmlc::Parameter<AmSoftmaxParam> {
  float margin;
  float s;
  int num_hidden;
  int verbose;
  float eps;
  DMLC_DECLARE_PARAMETER(AmSoftmaxParam) {
    DMLC_DECLARE_FIELD(margin).set_default(0.5).set_lower_bound(0.0)
    .describe("Additive margin subtracted from the target-class score.");
    DMLC_DECLARE_FIELD(s).set_default(64.0).set_lower_bound(1.0)
    .describe("Scale factor applied to the normalized logits.");
    DMLC_DECLARE_FIELD(num_hidden).set_lower_bound(1)
    .describe("Number of hidden nodes (classes) of the output.");
    DMLC_DECLARE_FIELD(verbose).set_default(0)
    .describe("If greater than 0, log gradient norms every `verbose` backward passes.");
    DMLC_DECLARE_FIELD(eps).set_default(1e-10f)
    .describe("Epsilon used for numerical stability in L2 norms.");
  }
};
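// How the two hyper-parameters are typically used (editor's sketch, assuming
// the standard AM-Softmax loss, with a SoftmaxOutput-style cross-entropy
// applied on top of this operator's first output):
//
//   L_i = -log( exp(s*(cos(theta_{i,y_i}) - margin)) /
//               (exp(s*(cos(theta_{i,y_i}) - margin)) + sum_{j != y_i} exp(s*cos(theta_{i,j}))) )
//
// Larger `margin` enforces a bigger gap between the target class and the rest;
// larger `s` sharpens the softmax over the normalized logits.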

template<typename xpu, typename DType>
class AmSoftmaxOp : public Operator {
 public:
  explicit AmSoftmaxOp(AmSoftmaxParam param) {
    this->param_ = param;
    count_ = 0;
  }

  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(in_data.size(), 3);
    CHECK_EQ(out_data.size(), 2);
    CHECK_EQ(req.size(), 2);
    CHECK_EQ(req[amsoftmax_enum::kOut], kWriteTo);
    Stream<xpu> *stream = ctx.get_stream<xpu>();
    const int n = in_data[amsoftmax_enum::kData].size(0);    // batch size
    const int m = in_data[amsoftmax_enum::kWeight].size(0);  // num classes
    Tensor<xpu, 2, DType> x = in_data[amsoftmax_enum::kData].FlatTo2D<xpu, DType>(stream);
    Tensor<xpu, 2, DType> w = in_data[amsoftmax_enum::kWeight].FlatTo2D<xpu, DType>(stream);
    Tensor<xpu, 1, DType> label = in_data[amsoftmax_enum::kLabel].get_with_shape<xpu, 1, DType>(Shape1(n), stream);
    Tensor<xpu, 2, DType> out = out_data[amsoftmax_enum::kOut].FlatTo2D<xpu, DType>(stream);
    Tensor<xpu, 2, DType> oout = out_data[amsoftmax_enum::kOOut].get_with_shape<xpu, 2, DType>(Shape2(n, 1), stream);
    //Tensor<xpu, 2, DType> workspace = ctx.requested[amsoftmax_enum::kTempSpace].get_space_typed<xpu, 2, DType>(Shape2(n, 1), stream);
#if defined(__CUDACC__)
    CHECK_EQ(stream->blas_handle_ownership_, Stream<xpu>::OwnHandle)
        << "Must init CuBLAS handle in stream";
#endif
    // original fully connected
    out = dot(x, w.T());
    if (ctx.is_train) {
      const DType margin = static_cast<DType>(param_.margin);
      const DType s = static_cast<DType>(param_.s);
      AmSoftmaxForward(x, w, label, out, oout, margin, s);
    }
  }
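  // Note (assumption, not verified against the kernels): AmSoftmaxForward is
  // expected to overwrite the target-class entries of `out` with the
  // margin-adjusted, normalized scores and to stash per-sample bookkeeping
  // (one value per row) in `oout` for reuse in Backward.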

  //virtual void GradNorm(mshadow::Tensor<xpu, 2, DType> grad, mshadow::Stream<xpu>* s) {
  //  using namespace mshadow;
  //  using namespace mshadow::expr;
  //  Tensor<cpu, 2, DType> grad_cpu(grad.shape_);
  //  AllocSpace(&grad_cpu);
  //  Copy(grad_cpu, grad, s);
  //  DType grad_norm = param_.eps;
  //  for(uint32_t i=0;i<grad_cpu.shape_[0];i++) {
  //    for(uint32_t j=0;j<grad_cpu.shape_[1];j++) {
  //      grad_norm += grad_cpu[i][j]*grad_cpu[i][j];
  //    }
  //  }
  //  grad_norm = sqrt(grad_norm);
  //  grad_cpu /= grad_norm;
  //  Copy(grad, grad_cpu, s);
  //  FreeSpace(&grad_cpu);
  //}

  // Copy a gradient tensor to CPU and return its L2 norm (used only for verbose logging).
  virtual DType GradNorm(mshadow::Tensor<xpu, 2, DType> grad, mshadow::Stream<xpu>* s) {
    using namespace mshadow;
    using namespace mshadow::expr;
    Tensor<cpu, 2, DType> grad_cpu(grad.shape_);
    AllocSpace(&grad_cpu);
    Copy(grad_cpu, grad, s);
    DType grad_norm = param_.eps;
    for (uint32_t i = 0; i < grad_cpu.shape_[0]; i++) {
      for (uint32_t j = 0; j < grad_cpu.shape_[1]; j++) {
        grad_norm += grad_cpu[i][j] * grad_cpu[i][j];
      }
    }
    grad_norm = std::sqrt(grad_norm);
    //grad_cpu /= grad_norm;
    //Copy(grad, grad_cpu, s);
    FreeSpace(&grad_cpu);
    return grad_norm;
  }

  // Debug helper: copy a tensor to CPU and print it row by row.
  virtual void Print(mshadow::Tensor<xpu, 2, DType> tensor, mshadow::Stream<xpu>* s) {
    using namespace mshadow;
    using namespace mshadow::expr;
    Tensor<cpu, 2, DType> tensor_cpu(tensor.shape_);
    AllocSpace(&tensor_cpu);
    Copy(tensor_cpu, tensor, s);
    for (uint32_t i = 0; i < tensor_cpu.shape_[0]; i++) {
      for (uint32_t j = 0; j < tensor_cpu.shape_[1]; j++) {
        std::cout << tensor_cpu[i][j] << ",";
      }
      std::cout << std::endl;
    }
    FreeSpace(&tensor_cpu);
  }

  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(out_grad.size(), 1);
    CHECK_EQ(in_data.size(), 3);
    CHECK_EQ(out_data.size(), 2);
    CHECK_GE(in_grad.size(), 2);
    CHECK_GE(req.size(), 2);
    CHECK_EQ(req[amsoftmax_enum::kData], kWriteTo);
    CHECK_EQ(req[amsoftmax_enum::kWeight], kWriteTo);
    Stream<xpu> *stream = ctx.get_stream<xpu>();
    const int n = in_data[amsoftmax_enum::kData].size(0);
    const int m = in_data[amsoftmax_enum::kWeight].size(0);
    Tensor<xpu, 2, DType> x = in_data[amsoftmax_enum::kData].FlatTo2D<xpu, DType>(stream);
    Tensor<xpu, 2, DType> w = in_data[amsoftmax_enum::kWeight].FlatTo2D<xpu, DType>(stream);
    Tensor<xpu, 1, DType> label = in_data[amsoftmax_enum::kLabel].get_with_shape<xpu, 1, DType>(Shape1(n), stream);
    Tensor<xpu, 2, DType> out = out_data[amsoftmax_enum::kOut].FlatTo2D<xpu, DType>(stream);
    Tensor<xpu, 2, DType> oout = out_data[amsoftmax_enum::kOOut].get_with_shape<xpu, 2, DType>(Shape2(n, 1), stream);
    Tensor<xpu, 2, DType> o_grad = out_grad[amsoftmax_enum::kOut].FlatTo2D<xpu, DType>(stream);
    Tensor<xpu, 2, DType> x_grad = in_grad[amsoftmax_enum::kData].FlatTo2D<xpu, DType>(stream);
    Tensor<xpu, 2, DType> w_grad = in_grad[amsoftmax_enum::kWeight].FlatTo2D<xpu, DType>(stream);
    Tensor<xpu, 2, DType> workspace = ctx.requested[amsoftmax_enum::kTempSpace].get_space_typed<xpu, 2, DType>(Shape2(n, 1), stream);
#if defined(__CUDACC__)
    CHECK_EQ(stream->blas_handle_ownership_, Stream<xpu>::OwnHandle)
        << "Must init CuBLAS handle in stream";
#endif
    // original fully connected
    x_grad = dot(o_grad, w);
    w_grad = dot(o_grad.T(), x);
    // large margin fully connected
    const DType margin = static_cast<DType>(param_.margin);
    const DType s = static_cast<DType>(param_.s);
    AmSoftmaxBackward(x, w, label, out, oout, o_grad, x_grad, w_grad, workspace, margin, s);
    count_ += 1;
    if (param_.verbose) {
      if (count_ % param_.verbose == 0) {
        DType gnorm = GradNorm(x_grad, stream);
        LOG(INFO) << "x_grad norm:" << gnorm;
        gnorm = GradNorm(w_grad, stream);
        LOG(INFO) << "w_grad norm:" << gnorm;
        //Print(oout, stream);
      }
    }
  }
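  // Note (assumption): AmSoftmaxBackward is expected to adjust the plain
  // fully-connected gradients computed above for the normalized,
  // margin-modified target-class logits, using `out`/`oout` cached from the
  // forward pass and `workspace` as per-sample scratch space.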

 private:
  AmSoftmaxParam param_;
  uint32_t count_;  // number of backward calls, used for verbose logging
};
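
// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the original operator): a minimal,
// self-contained CPU reference of the logit rule that AmSoftmaxForward is
// assumed to implement.  The name AmSoftmaxForwardReference is hypothetical;
// the real kernels are defined in the accompanying .cc/.cu files and operate
// on mshadow tensors, not std::vector.
inline void AmSoftmaxForwardReference(const std::vector<std::vector<float> > &x,  // n x d features
                                      const std::vector<std::vector<float> > &w,  // m x d class weights
                                      const std::vector<int> &label,              // n labels in [0, m)
                                      float margin, float s, float eps,
                                      std::vector<std::vector<float> > *out) {    // n x m logits
  const size_t n = x.size();
  const size_t m = w.size();
  out->assign(n, std::vector<float>(m, 0.0f));
  for (size_t i = 0; i < n; ++i) {
    // L2 norm of the i-th feature row (eps keeps the division finite).
    float xnorm = eps;
    for (size_t k = 0; k < x[i].size(); ++k) xnorm += x[i][k] * x[i][k];
    xnorm = std::sqrt(xnorm);
    for (size_t j = 0; j < m; ++j) {
      // L2 norm of the j-th class weight and its dot product with x_i.
      float wnorm = eps;
      float dotv = 0.0f;
      for (size_t k = 0; k < w[j].size(); ++k) {
        wnorm += w[j][k] * w[j][k];
        dotv += x[i][k] * w[j][k];
      }
      wnorm = std::sqrt(wnorm);
      float cos_t = dotv / (xnorm * wnorm);            // cosine similarity
      if (static_cast<int>(j) == label[i]) {
        cos_t -= margin;                               // additive margin on the target class
      }
      (*out)[i][j] = s * cos_t;                        // scaled logit
    }
  }
}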

template<typename xpu>
Operator *CreateOp(AmSoftmaxParam param, int dtype);

#if DMLC_USE_CXX11
class AmSoftmaxProp : public OperatorProperty {
 public:
  void Init(const std::vector<std::pair<std::string, std::string> > &kwargs) override {
    param_.Init(kwargs);
  }

  std::map<std::string, std::string> GetParams() const override {
    return param_.__DICT__();
  }

  std::vector<std::string> ListArguments() const override {
    return {"data", "weight", "label"};
  }

  std::vector<std::string> ListOutputs() const override {
    return {"output", "ooutput"};
  }

  int NumOutputs() const override {
    return 2;
  }

  int NumVisibleOutputs() const override {
    return 1;
  }

  bool InferShape(std::vector<TShape> *in_shape,
                  std::vector<TShape> *out_shape,
                  std::vector<TShape> *aux_shape) const override {
    using namespace mshadow;
    CHECK_EQ(in_shape->size(), 3) << "Input:[data, weight, label]";
    const TShape &dshape = in_shape->at(amsoftmax_enum::kData);
    const TShape &lshape = in_shape->at(amsoftmax_enum::kLabel);
    CHECK_EQ(dshape.ndim(), 2) << "data shape should be (batch_size, feature_dim)";
    CHECK_EQ(lshape.ndim(), 1) << "label shape should be (batch_size,)";
    const int n = dshape[0];
    const int feature_dim = dshape[1];
    const int m = param_.num_hidden;
    SHAPE_ASSIGN_CHECK(*in_shape, amsoftmax_enum::kWeight, Shape2(m, feature_dim));
    out_shape->clear();
    out_shape->push_back(Shape2(n, m));  // output
    out_shape->push_back(Shape2(n, 1));  // ooutput
    aux_shape->clear();
    return true;
  }
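
  // Shape summary implied by InferShape above: with batch size n, feature
  // dimension d, and num_hidden = m classes,
  //   data: (n, d),  weight: (m, d),  label: (n,)
  //   output: (n, m),  ooutput: (n, 1)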

  std::vector<ResourceRequest> BackwardResource(
      const std::vector<TShape> &in_shape) const override {
    return {ResourceRequest::kTempSpace};
  }

  std::vector<int> DeclareBackwardDependency(
      const std::vector<int> &out_grad,
      const std::vector<int> &in_data,
      const std::vector<int> &out_data) const override {
    return {out_grad[amsoftmax_enum::kOut],
            in_data[amsoftmax_enum::kData],
            in_data[amsoftmax_enum::kWeight], in_data[amsoftmax_enum::kLabel]};
  }

  std::string TypeString() const override {
    return "AmSoftmax";
  }

  OperatorProperty *Copy() const override {
    auto ptr = new AmSoftmaxProp();
    ptr->param_ = param_;
    return ptr;
  }

  Operator *CreateOperator(Context ctx) const override {
    LOG(FATAL) << "Not Implemented.";
    return NULL;
  }

  Operator *CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
                             std::vector<int> *in_type) const override;

 private:
  AmSoftmaxParam param_;
};
#endif  // DMLC_USE_CXX11

}  // namespace op
}  // namespace mxnet

#endif  // MXNET_OPERATOR_AMSOFTMAX_INL_H_