using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

namespace RWKV
{
    public enum OnnxModelType
    {
        FP16,
        FP32
    }

    // Wraps an ONNX Runtime inference session for an RWKV model and hides the
    // FP16/FP32 precision difference behind a single Forward() entry point.
    public class OnnxModel
    {
        private InferenceSession _inferenceSession;
        private Type _type;
        private int _embed;
        private int _layers;
        private List<string> _input_names;
        private List<string> _output_names;
        private List<NamedOnnxValue> _inputs;
        private OnnxModelType _modelType;

        public OnnxModelType ModelType => _modelType;

        public OnnxModel(string model, int embed, int layers)
        {
            var options = new SessionOptions();
            // Register CUDA before CPU: ONNX Runtime tries execution providers
            // in registration order, so this prefers the GPU and falls back to
            // the CPU when CUDA is unavailable.
            options.AppendExecutionProvider_CUDA();
            options.AppendExecutionProvider_CPU();
            _inferenceSession = new InferenceSession(model, options);
            // Infer the model precision from the element type of the first state input.
            _type = _inferenceSession.InputMetadata["instate0"].ElementType;
            _embed = embed;
            _layers = layers;
            _input_names = _inferenceSession.InputMetadata.Select(x => x.Key).ToList();
            _output_names = _inferenceSession.OutputMetadata.Select(x => x.Key).ToList();
            _inputs = new List<NamedOnnxValue>();

            if (_type == typeof(Float16))
            {
                _modelType = OnnxModelType.FP16;
            }
            else if (_type == typeof(float))
            {
                _modelType = OnnxModelType.FP32;
            }
            else
            {
                throw new NotSupportedException($"Unsupported state element type: {_type}");
            }
        }

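        // RWKV carries five recurrent state tensors per layer, each of width
        // _embed. Four start at zero; the one initialized to negative infinity
        // is presumably RWKV's running max-exponent accumulator ("pp"), which
        // must start at -inf so the first observed value dominates the running
        // max in the numerically stable WKV computation.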
        public object GetEmptyStates()
        {
            switch (_modelType)
            {
                case OnnxModelType.FP16:
                {
                    var state = new List<Tensor<Float16>>();
                    for (int i = 0; i < _layers; i++)
                    {
                        state.Add(GDenseTensor<Float16>(0));
                        state.Add(GDenseTensor<Float16>(0));
                        state.Add(GDenseTensor<Float16>(0));
                        state.Add(GDenseTensor<Float16>(0));
                        // 64512 == 0xFC00, the FP16 bit pattern for negative
                        // infinity, mirroring float.NegativeInfinity below.
                        state.Add(GDenseTensor<Float16>(64512));
                    }
                    return state;
                }
                case OnnxModelType.FP32:
                {
                    var state = new List<Tensor<float>>();
                    for (int i = 0; i < _layers; i++)
                    {
                        state.Add(GDenseTensor<float>(0));
                        state.Add(GDenseTensor<float>(0));
                        state.Add(GDenseTensor<float>(0));
                        state.Add(GDenseTensor<float>(0));
                        state.Add(GDenseTensor<float>(float.NegativeInfinity));
                    }
                    return state;
                }
                default:
                    throw new NotSupportedException();
            }
        }

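        // Both precision paths follow the same binding convention, inferred
        // from how the session is driven below: the first graph input is the
        // token id and the remaining inputs take the per-layer state tensors
        // in order; the first output is the logits vector and the remaining
        // outputs are the updated state.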
        public (IEnumerable<float> logits, object state) Forward(int xi, object state)
        {
            switch (_modelType)
            {
                case OnnxModelType.FP16:
                {
                    var ret = Forward_FP16(xi, (List<Tensor<Float16>>)state);
                    // Widen the FP16 logits to float so callers see one type.
                    return (ret.logits.Select(x => HalfToSinglePrecision(x)).AsEnumerable(), ret.state);
                }
                case OnnxModelType.FP32:
                {
                    var ret = Forward_FP32(xi, (List<Tensor<float>>)state);
                    return (ret.logits.AsEnumerable(), ret.state);
                }
                default:
                    throw new NotSupportedException();
            }
        }

        private (Tensor<Float16> logits, IList<Tensor<Float16>> state) Forward_FP16(int xi, List<Tensor<Float16>> state)
        {
            _inputs.Clear();
            var input = new DenseTensor<int>(new[] { xi }, new[] { 1 });
            _inputs.Add(NamedOnnxValue.CreateFromTensor(_input_names.First(), input));
            for (int i = 1; i < _input_names.Count; i++)
            {
                _inputs.Add(NamedOnnxValue.CreateFromTensor(_input_names[i], state[i - 1]));
            }
            var data = _inferenceSession.Run(_inputs);
            return (data.First().AsTensor<Float16>(), data.Skip(1).Select(x => x.AsTensor<Float16>()).ToList());
        }

        private (Tensor<float> logits, IList<Tensor<float>> state) Forward_FP32(int xi, IList<Tensor<float>> state)
        {
            _inputs.Clear();
            var input = new DenseTensor<int>(new[] { xi }, new[] { 1 });
            _inputs.Add(NamedOnnxValue.CreateFromTensor(_input_names.First(), input));
            for (int i = 1; i < _input_names.Count; i++)
            {
                _inputs.Add(NamedOnnxValue.CreateFromTensor(_input_names[i], state[i - 1]));
            }
            var data = _inferenceSession.Run(_inputs);
            return (data.First().AsTensor<float>(), data.Skip(1).Select(x => x.AsTensor<float>()).ToList());
        }

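        // IEEE 754 half precision packs 1 sign, 5 exponent (bias 15) and 10
        // mantissa bits into a ushort; single precision uses 1/8/23 with bias
        // 127. Conversion re-biases the exponent and widens the mantissa, with
        // special cases for zero, subnormals, infinity and NaN.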
        private float HalfToSinglePrecision(ushort half)
        {
            uint sign = (uint)(half >> 15) << 31;
            uint exponent = (uint)((half & 0x7C00) >> 10);
            uint mantissa = (uint)(half & 0x03FF);

            uint bits;
            if (exponent == 0x1F)       // infinity / NaN: exponent stays all ones
                bits = sign | (0xFFu << 23) | (mantissa << 13);
            else if (exponent != 0)     // normal: re-bias the exponent from 15 to 127
                bits = sign | ((exponent + 127 - 15) << 23) | (mantissa << 13);
            else if (mantissa == 0)     // signed zero
                bits = sign;
            else                        // subnormal: renormalize to a single-precision normal
            {
                int shift = 0;
                while ((mantissa & 0x0400) == 0) { mantissa <<= 1; shift++; }
                bits = sign | ((uint)(127 - 14 - shift) << 23) | ((mantissa & 0x03FF) << 13);
            }
            return BitConverter.ToSingle(BitConverter.GetBytes(bits), 0);
        }

        // Builds a 1-D tensor of length _embed filled with a single value.
        private DenseTensor<T> GDenseTensor<T>(T value)
        {
            var tvalue = new DenseTensor<T>(_embed);
            for (int i = 0; i < _embed; i++)
            {
                tvalue[i] = value;
            }
            return tvalue;
        }
    }
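
    // Example usage (illustrative sketch, not part of the original class). The
    // model path, embedding width, layer count and token ids below are all
    // hypothetical placeholders; substitute the values matching your RWKV ONNX export.
    public static class OnnxModelExample
    {
        public static void Run()
        {
            var model = new OnnxModel("rwkv4.onnx", embed: 768, layers: 12);
            object state = model.GetEmptyStates();

            IEnumerable<float> logits = Enumerable.Empty<float>();
            foreach (int token in new[] { 12092, 1533, 2088 }) // hypothetical token ids
            {
                // Thread the recurrent state through successive calls.
                (logits, state) = model.Forward(token, state);
            }

            // Greedy decoding: pick the index of the largest logit.
            int nextToken = logits.Select((value, index) => (value, index))
                                  .Aggregate((a, b) => b.value > a.value ? b : a)
                                  .index;
        }
    }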
}