Skip to content

Commit eade853

Browse files
author
xcssa
committed
Support RWKV4 Raven 1B5-14B (CPU/GPU)
1 parent 071bc3a commit eade853

File tree

9 files changed

+330
-105
lines changed

9 files changed

+330
-105
lines changed

CRWKV.sln

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,19 @@ EndProject
88
Global
99
GlobalSection(SolutionConfigurationPlatforms) = preSolution
1010
Debug|Any CPU = Debug|Any CPU
11+
Debug|x64 = Debug|x64
1112
Release|Any CPU = Release|Any CPU
13+
Release|x64 = Release|x64
1214
EndGlobalSection
1315
GlobalSection(ProjectConfigurationPlatforms) = postSolution
1416
{D2C9B375-1A91-47C9-9953-5D9B7289AEBB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
1517
{D2C9B375-1A91-47C9-9953-5D9B7289AEBB}.Debug|Any CPU.Build.0 = Debug|Any CPU
18+
{D2C9B375-1A91-47C9-9953-5D9B7289AEBB}.Debug|x64.ActiveCfg = Debug|x64
19+
{D2C9B375-1A91-47C9-9953-5D9B7289AEBB}.Debug|x64.Build.0 = Debug|x64
1620
{D2C9B375-1A91-47C9-9953-5D9B7289AEBB}.Release|Any CPU.ActiveCfg = Release|Any CPU
1721
{D2C9B375-1A91-47C9-9953-5D9B7289AEBB}.Release|Any CPU.Build.0 = Release|Any CPU
22+
{D2C9B375-1A91-47C9-9953-5D9B7289AEBB}.Release|x64.ActiveCfg = Release|x64
23+
{D2C9B375-1A91-47C9-9953-5D9B7289AEBB}.Release|x64.Build.0 = Release|x64
1824
EndGlobalSection
1925
GlobalSection(SolutionProperties) = preSolution
2026
HideSolutionNode = FALSE

CRWKV/CRWKV.csproj

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1-
<Project Sdk="Microsoft.NET.Sdk">
1+
<Project Sdk="Microsoft.NET.Sdk">
22

3-
<PropertyGroup>
4-
<OutputType>Exe</OutputType>
5-
<TargetFramework>net6.0</TargetFramework>
6-
<ImplicitUsings>enable</ImplicitUsings>
7-
<Nullable>enable</Nullable>
8-
<IncludeNativeLibrariesForSelfExtract>true</IncludeNativeLibrariesForSelfExtract>
9-
</PropertyGroup>
3+
<PropertyGroup>
4+
<OutputType>Exe</OutputType>
5+
<TargetFramework>net6.0</TargetFramework>
6+
<ImplicitUsings>enable</ImplicitUsings>
7+
<Nullable>enable</Nullable>
8+
<IncludeNativeLibrariesForSelfExtract>true</IncludeNativeLibrariesForSelfExtract>
9+
<Platforms>AnyCPU;x64</Platforms>
10+
</PropertyGroup>
1011

11-
<ItemGroup>
12-
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.14.1" />
13-
<PackageReference Include="Seq2SeqSharp" Version="2.5.0" />
14-
</ItemGroup>
12+
<ItemGroup>
13+
<PackageReference Include="MathNet.Numerics" Version="5.0.0" />
14+
<PackageReference Include="Microsoft.ML.OnnxRuntime.Gpu" Version="1.14.1" />
15+
<PackageReference Include="Seq2SeqSharp" Version="2.5.0" />
16+
</ItemGroup>
1517

1618
</Project>

CRWKV/Program.cs

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,20 @@
11
using RWKV;
22

3-
Console.Write("Input Model Name(rwkv-4-pile-169m-uint8.onnx): ");
3+
Console.Write("Input Model Name(RWKV_32_2560_16.onnx): ");
44
var modelName = Console.ReadLine();
55
if (string.IsNullOrEmpty(modelName))
6-
modelName = "rwkv-4-pile-169m-uint8.onnx";
6+
modelName = "RWKV_32_2560_16.onnx";
77

8-
Console.Write("ctx_len(1024): ");
9-
var ctx_len = 1024;
10-
var ctx_len_str = Console.ReadLine();
11-
if (!string.IsNullOrEmpty(ctx_len_str))
12-
ctx_len = int.Parse(ctx_len_str);
8+
var modelNames = modelName.Split("_");
9+
var n_layer = int.Parse(modelNames[1]);
10+
var n_embd = int.Parse(modelNames[2]);
1311

14-
Console.Write("n_layer(12): ");
15-
var n_layer = 12;
16-
var n_layer_str = Console.ReadLine();
17-
if (!string.IsNullOrEmpty(n_layer_str))
18-
n_layer = int.Parse(n_layer_str);
12+
Console.WriteLine($"Loading...");
1913

20-
Console.Write("n_embd(768): ");
21-
var n_embd = 768;
22-
var n_embd_str = Console.ReadLine();
23-
if (!string.IsNullOrEmpty(n_embd_str))
24-
n_embd = int.Parse(n_embd_str);
14+
var rf = new RunnerFactory(modelName, n_layer, n_embd);
15+
rf.Init();
16+
var r = rf.NewRunner();
2517

26-
Console.WriteLine($"Loading({modelName})[{ctx_len},{n_layer},{n_embd}]...");
27-
var r = new Runner(modelName, ctx_len, n_layer, n_embd);
28-
r.Init();
2918
while (true)
3019
{
3120
Console.Write(">");

CRWKV/RWKV/OnnxModel.cs

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
using Microsoft.ML.OnnxRuntime;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
4+
namespace RWKV
5+
{
6+
/// <summary>
/// Element precision of the tensors exchanged with a loaded ONNX model.
/// </summary>
public enum OnnxModelType
{
    /// <summary>The model's state tensors use <c>Float16</c> elements.</summary>
    FP16 = 0,

    /// <summary>The model's state tensors use <c>float</c> elements.</summary>
    FP32 = 1
}
11+
12+
public class OnnxModel
13+
{
14+
private InferenceSession _inferenceSession;
15+
private Type _type;
16+
private int _embed;
17+
private int _layers;
18+
private List<string> _input_names;
19+
private List<string> _output_names;
20+
private List<NamedOnnxValue> _inputs;
21+
private OnnxModelType _modelType;
22+
23+
public OnnxModelType ModelType => _modelType;
24+
25+
/// <summary>
/// Opens an ONNX Runtime session for <paramref name="model"/> and caches the input/output
/// names and element precision needed to run it.
/// </summary>
/// <param name="model">Model path handed to <see cref="InferenceSession"/>.</param>
/// <param name="embed">Length of every state tensor created by <c>GetEmptyStates</c>.</param>
/// <param name="layers">Number of layers; <c>GetEmptyStates</c> creates five state tensors per layer.</param>
/// <exception cref="NotSupportedException">
/// The element type of the model's <c>instate0</c> input is neither <c>Float16</c> nor <c>float</c>.
/// </exception>
public OnnxModel(string model, int embed, int layers)
{
    // SessionOptions owns a native handle that is only needed while the session is created,
    // so release it deterministically instead of leaving it to the finalizer.
    using var options = new SessionOptions();

    // NOTE(review): CPU is registered before CUDA. ONNX Runtime gives priority to providers in
    // registration order, so confirm this ordering is what is intended for GPU runs.
    options.AppendExecutionProvider_CPU();
    options.AppendExecutionProvider_CUDA();

    _inferenceSession = new InferenceSession(model, options);
    try
    {
        // The precision of the first state input decides how states and logits are typed.
        _type = _inferenceSession.InputMetadata["instate0"].ElementType;
        _embed = embed;
        _layers = layers;
        _input_names = _inferenceSession.InputMetadata.Select(x => x.Key).ToList();
        _output_names = _inferenceSession.OutputMetadata.Select(x => x.Key).ToList();
        _inputs = new List<NamedOnnxValue>();

        if (_type == typeof(Float16))
        {
            _modelType = OnnxModelType.FP16;
        }
        else if (_type == typeof(float))
        {
            _modelType = OnnxModelType.FP32;
        }
        else
        {
            throw new NotSupportedException();
        }
    }
    catch
    {
        // Construction failed after the native session was opened: do not leak it.
        _inferenceSession.Dispose();
        throw;
    }
}
51+
52+
/// <summary>
/// Builds the initial state list for the loaded model.
/// </summary>
/// <returns>
/// A <c>List&lt;Tensor&lt;Float16&gt;&gt;</c> for FP16 models or a <c>List&lt;Tensor&lt;float&gt;&gt;</c>
/// for FP32 models. Each layer contributes five tensors: four filled with zero followed by
/// one filled with the negative-infinity value.
/// </returns>
/// <exception cref="NotSupportedException">The model type is neither FP16 nor FP32.</exception>
public object GetEmptyStates()
{
    // Shared builder: the two precisions differ only in element type and fill values.
    List<Tensor<T>> BuildStates<T>(T zero, T lowest)
    {
        var tensors = new List<Tensor<T>>();
        for (int layer = 0; layer < _layers; layer++)
        {
            for (int slot = 0; slot < 4; slot++)
            {
                tensors.Add(GDenseTensor<T>(zero));
            }
            tensors.Add(GDenseTensor<T>(lowest));
        }
        return tensors;
    }

    if (_modelType == OnnxModelType.FP16)
    {
        // 64512 == 0xFC00, which is negative infinity when read as IEEE 754 half-precision bits
        // (assumes Float16 wraps the raw bit pattern - TODO confirm).
        return BuildStates<Float16>(0, 64512);
    }

    if (_modelType == OnnxModelType.FP32)
    {
        return BuildStates<float>(0, float.NegativeInfinity);
    }

    throw new NotSupportedException();
}
86+
87+
/// <summary>
/// Runs a single step of the model, dispatching to the FP16 or FP32 implementation.
/// </summary>
/// <param name="xi">Integer fed to the model's first input (presumably the token id).</param>
/// <param name="state">
/// State list of the kind returned by <c>GetEmptyStates</c> or by a previous call; it is cast to
/// the list type that matches the model's precision.
/// </param>
/// <returns>The logits as single-precision values and the state produced by this step.</returns>
/// <exception cref="NotSupportedException">The model type is neither FP16 nor FP32.</exception>
public (IEnumerable<float> logits, object state) Forward(int xi, object state)
{
    if (_modelType == OnnxModelType.FP16)
    {
        var (halfLogits, halfState) = Forward_FP16(xi, (List<Tensor<Float16>>)state);

        // Deferred: each half-precision logit is widened only when the caller enumerates.
        IEnumerable<float> widened = halfLogits.Select(bits => HalfToSinglePrecision(bits));
        return (widened, halfState);
    }

    if (_modelType == OnnxModelType.FP32)
    {
        var (singleLogits, singleState) = Forward_FP32(xi, (List<Tensor<float>>)state);
        return (singleLogits, singleState);
    }

    throw new NotSupportedException();
}
105+
106+
/// <summary>
/// Runs one inference step for a half-precision model.
/// </summary>
/// <param name="xi">Integer placed in a one-element tensor and bound to the first model input.</param>
/// <param name="state">State tensors, bound in order to the remaining model inputs.</param>
/// <returns>The first session output as logits and the remaining outputs as the new state.</returns>
private (Tensor<Float16> logits, IList<Tensor<Float16>> state) Forward_FP16(int xi, List<Tensor<Float16>> state)
{
    // The input list is a reused field, so start from a clean slate on every call.
    _inputs.Clear();

    var token = new DenseTensor<int>(new[] { xi }, new[] { 1 });
    _inputs.Add(NamedOnnxValue.CreateFromTensor(_input_names.First(), token));

    // Every input after the first is paired positionally with a state tensor.
    int slot = 0;
    foreach (var inputName in _input_names.Skip(1))
    {
        _inputs.Add(NamedOnnxValue.CreateFromTensor(inputName, state[slot]));
        slot++;
    }

    var outputs = _inferenceSession.Run(_inputs);
    var logits = outputs.First().AsTensor<Float16>();
    var nextState = outputs.Skip(1).Select(output => output.AsTensor<Float16>()).ToList();
    return (logits, nextState);
}
118+
119+
/// <summary>
/// Runs one inference step for a single-precision model.
/// </summary>
/// <param name="xi">Integer placed in a one-element tensor and bound to the first model input.</param>
/// <param name="state">State tensors, bound in order to the remaining model inputs.</param>
/// <returns>The first session output as logits and the remaining outputs as the new state.</returns>
private (Tensor<float> logits, IList<Tensor<float>> state) Forward_FP32(int xi, IList<Tensor<float>> state)
{
    // The input list is a reused field, so start from a clean slate on every call.
    _inputs.Clear();

    var token = new DenseTensor<int>(new[] { xi }, new[] { 1 });
    _inputs.Add(NamedOnnxValue.CreateFromTensor(_input_names.First(), token));

    // Every input after the first is paired positionally with a state tensor.
    int slot = 0;
    foreach (var inputName in _input_names.Skip(1))
    {
        _inputs.Add(NamedOnnxValue.CreateFromTensor(inputName, state[slot]));
        slot++;
    }

    var outputs = _inferenceSession.Run(_inputs);
    var logits = outputs.First().AsTensor<float>();
    var nextState = outputs.Skip(1).Select(output => output.AsTensor<float>()).ToList();
    return (logits, nextState);
}
131+
132+
/// <summary>
/// Converts an IEEE 754 half-precision bit pattern to the single-precision value it encodes.
/// </summary>
/// <remarks>
/// Uses the runtime's <see cref="System.Half"/> conversion. Merely re-biasing the exponent is
/// only valid for normal numbers: it maps zero to 2^-15, mis-scales subnormals, and turns
/// infinity/NaN (exponent bits all ones) into finite values.
/// </remarks>
/// <param name="half">Raw 16 bits: 1 sign bit, 5 exponent bits, 10 mantissa bits.</param>
/// <returns>The <see cref="float"/> with the same numeric value (or infinity/NaN).</returns>
private static float HalfToSinglePrecision(ushort half)
{
    // Reinterpret the bits as a Half; unchecked keeps patterns with the sign bit set
    // from overflowing the ushort -> short cast in a checked build.
    var value = BitConverter.Int16BitsToHalf(unchecked((short)half));
    return (float)value;
}
147+
148+
/// <summary>
/// Creates a tensor of length <c>_embed</c> with every element set to <paramref name="value"/>.
/// </summary>
/// <typeparam name="T">Element type of the tensor.</typeparam>
/// <param name="value">Value written to each element.</param>
/// <returns>A newly allocated, fully initialised tensor.</returns>
private DenseTensor<T> GDenseTensor<T>(T value)
{
    var filled = new DenseTensor<T>(_embed);

    // Write the value across the tensor's backing buffer in one pass.
    filled.Buffer.Span.Fill(value);
    return filled;
}
157+
}
158+
}

0 commit comments

Comments
 (0)