diff --git a/global.json b/global.json index eafb435..586eb6e 100644 --- a/global.json +++ b/global.json @@ -1,5 +1,6 @@ { "sdk": { + "version": "9.0.100-rc.2.24474.11", "rollForward": "latestMajor", "allowPrerelease": false } diff --git a/src/libs/Tiktoken.Core/CoreBPE.cs b/src/libs/Tiktoken.Core/CoreBPE.cs index 88de98c..c15146f 100644 --- a/src/libs/Tiktoken.Core/CoreBPE.cs +++ b/src/libs/Tiktoken.Core/CoreBPE.cs @@ -1,4 +1,5 @@ using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using System.Text; using System.Text.RegularExpressions; @@ -20,7 +21,11 @@ public class CoreBpe internal bool EnableCache { get; set; } = true; private ConcurrentDictionary> FastCache { get; set; } = new(); - private ConcurrentDictionary FastCacheCounts { get; set; } = new(); + private ConcurrentDictionary FastCacheCounts { get; set; } = new( +#if NET9_0_OR_GREATER + new AlternateStringComparer() +#endif + ); private Regex SpecialRegex { get; set; } private Regex Regex { get; set; } @@ -59,7 +64,11 @@ public CoreBpe( #else static x => new string(x.Key.Select(static y => (char) y).ToArray()), #endif - static x => x.Value); + static x => x.Value +#if NET9_0_OR_GREATER + , new AlternateStringComparer() +#endif + ); SpecialTokensEncoder = specialTokensEncoder; Regex = new Regex(pattern, RegexOptions.Compiled); @@ -89,23 +98,40 @@ public int CountTokensNative(string text) var textSpan = text.AsSpan(); Span pieceBytes = stackalloc byte[128]; #endif +#if NET9_0_OR_GREATER + var fastEncoderLookup = FastEncoder.GetAlternateLookup>(); + var fastCacheCountLookup = FastCacheCounts.GetAlternateLookup>(); +#endif #if NET7_0_OR_GREATER foreach (var match in Regex.EnumerateMatches(textSpan)) { +#if NET9_0_OR_GREATER + var fastKey = textSpan.Slice(match.Index, match.Length); +#else var fastKey = new string(textSpan.Slice(match.Index, match.Length)); +#endif #else foreach (Match match in Regex.Matches(text)) { var matchValue = match.Value; var fastKey = matchValue; #endif + +#if NET9_0_OR_GREATER + if (fastEncoderLookup.ContainsKey(fastKey)) +#else if (FastEncoder.ContainsKey(fastKey)) +#endif { tokens++; continue; } +#if NET9_0_OR_GREATER + if (EnableCache && fastCacheCountLookup.TryGetValue(fastKey, out var fastNumberOfTokens)) +#else if (EnableCache && FastCacheCounts.TryGetValue(fastKey, out var fastNumberOfTokens)) +#endif { tokens += fastNumberOfTokens; continue; @@ -127,7 +153,11 @@ public int CountTokensNative(string text) if (EnableCache) { +#if NET9_0_OR_GREATER + fastCacheCountLookup[fastKey] = numberOfTokens; +#else FastCacheCounts[fastKey] = numberOfTokens; +#endif } } @@ -569,4 +599,39 @@ private static byte[] GetUtf8Bytes(ReadOnlySpan text, Span scratch) } } #endif + +#if NET9_0_OR_GREATER + private sealed class AlternateStringComparer : IEqualityComparer, + IAlternateEqualityComparer, string> + { + public string Create(ReadOnlySpan alternate) + { + return new(alternate); + } + + public bool Equals(string? x, string? y) + { + return string.Equals(x, y, StringComparison.Ordinal); + } + + public bool Equals(ReadOnlySpan alternate, string other) + { + return other?.AsSpan().SequenceEqual(alternate) ?? false; + } + + public int GetHashCode([DisallowNull] string str) + { + return str is null ? 0 : GetHashCode(str.AsSpan()); + } + + public int GetHashCode(ReadOnlySpan alternate) + { + // use the djb2 hash function for simplicity: http://www.cse.yorku.ca/~oz/hash.html + uint hash = 5381; + foreach (var ch in alternate) + hash = hash * 33u + ch; + return (int)hash; + } + } +#endif } \ No newline at end of file diff --git a/src/libs/Tiktoken.Core/Tiktoken.Core.csproj b/src/libs/Tiktoken.Core/Tiktoken.Core.csproj index bc62e6c..aa41460 100644 --- a/src/libs/Tiktoken.Core/Tiktoken.Core.csproj +++ b/src/libs/Tiktoken.Core/Tiktoken.Core.csproj @@ -1,7 +1,7 @@ - net4.6.2;netstandard2.0;netstandard2.1;net6.0;net8.0 + net4.6.2;netstandard2.0;netstandard2.1;net6.0;net8.0;net9.0 true $(NoWarn);CA1724 Tiktoken