Skip to content

Commit a9fb933

Browse files
authored
Add WinRT wrapper (microsoft#360)
* add winrt projects * add gitignore for VS files * retarget to vc142 * make neighborCount uint32 * api takes byte[] for metadata * enable CFG and disable incremental linking to make BinSkim pass * format * remove edit and continue /ZI since it's incompatible for CFG * remove arm/arm64 platforms
1 parent a5bd48e commit a9fb933

15 files changed

+1331
-268
lines changed

.gitignore

Lines changed: 397 additions & 1 deletion
Large diffs are not rendered by default.

SPTAG.sln

Lines changed: 290 additions & 267 deletions
Large diffs are not rendered by default.

Test/WinRTTest/WinRTTest.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#include <iostream>
2+
#include <winrt/SPTAG.h>
3+
#include <winrt/Windows.Foundation.Collections.h>
4+
#include <winrt/Windows.Data.Json.h>
5+
#include <winrt/Windows.Storage.h>
6+
#include <winrt/Windows.Security.Cryptography.h>
7+
#include <winrt/Windows.Storage.Streams.h>
8+
9+
extern "C" __declspec(dllexport) winrt::SPTAG::LogLevel SPTAG_GetLoggerLevel() { return winrt::SPTAG::LogLevel::Empty; }
10+
11+
using namespace winrt;
12+
using namespace Windows::Security::Cryptography;
13+
using namespace Windows::Data::Json;
14+
15+
JsonObject CreateJsonObject(bool v) { auto js = JsonObject{}; js.Insert(L"value", JsonValue::CreateBooleanValue(v)); return js; }
16+
JsonObject CreateJsonObject(std::wstring_view v) { auto js = JsonObject{}; js.Insert(L"value", JsonValue::CreateStringValue(v)); return js; }
17+
template<size_t N> JsonObject CreateJsonObject(const wchar_t(&v)[N]) { return CreateJsonObject(winrt::hstring(v)); }
18+
JsonObject CreateJsonObject(double v) { auto js = JsonObject{}; js.Insert(L"value", JsonValue::CreateNumberValue(v)); return js; }
19+
JsonObject CreateJsonObject(nullptr_t) { auto js = JsonObject{}; js.Insert(L"value", JsonValue::CreateNullValue()); return js; }
20+
21+
template<typename T>
22+
winrt::array_view<uint8_t> Serialize(const T& value) {
23+
JsonObject json = CreateJsonObject(value);
24+
auto str = json.Stringify();
25+
auto ibuffer = CryptographicBuffer::ConvertStringToBinary(str, BinaryStringEncoding::Utf16LE);
26+
auto start = ibuffer.data();
27+
winrt::com_array<uint8_t> _array(start, start + ibuffer.Length());
28+
return _array;
29+
}
30+
31+
winrt::hstring Deserialize(winrt::array_view<uint8_t> v) {
32+
std::wstring_view sv { reinterpret_cast<wchar_t*>(v.begin()), v.size() / sizeof(wchar_t) };
33+
std::wstring wstr {sv};
34+
JsonObject js;
35+
if (JsonObject::TryParse(wstr, js)) {
36+
auto value = js.GetNamedValue(L"value");
37+
return value.Stringify();
38+
} else {
39+
return winrt::hstring{ wstr };
40+
}
41+
}
42+
43+
int main()
44+
{
45+
SPTAG::AnnIndex idx;
46+
using embedding_t = std::array<float, 1024>;
47+
auto b = CryptographicBuffer::ConvertStringToBinary(L"first one", BinaryStringEncoding::Utf16LE);
48+
idx.AddWithMetadata(embedding_t{ 1, 0 }, winrt::com_array<uint8_t>(b.data(), b.data() + b.Length()));
49+
idx.AddWithMetadata(embedding_t{ 0, 1, 0 }, Serialize(L"second one"));
50+
idx.AddWithMetadata(embedding_t{ 0, 0.5f, 0.7f, 0 }, Serialize(3.14));
51+
idx.AddWithMetadata(embedding_t{ 0, 0.7f, 0.5f, 0 }, Serialize(true));
52+
idx.AddWithMetadata(embedding_t{ 0, 0.7f, 0.8f, 0 }, Serialize(L"fifth"));
53+
54+
auto res = idx.Search(embedding_t{ 0.f, 0.99f, 0.01f }, 12);
55+
for (const winrt::SPTAG::SearchResult& r : res) {
56+
std::wcout << Deserialize(r.Metadata());
57+
std::wcout << L" -- " << r.Distance() << L"\n";
58+
}
59+
60+
auto folder = winrt::Windows::Storage::KnownFolders::DocumentsLibrary();
61+
auto file = folder.CreateFileAsync(L"vector_index", winrt::Windows::Storage::CreationCollisionOption::ReplaceExisting).get();
62+
idx.Save(file);
63+
64+
SPTAG::AnnIndex idx2;
65+
idx2.Load(file);
66+
67+
}
68+

Test/WinRTTest/WinRTTest.vcxproj

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
<?xml version="1.0" encoding="utf-8"?>
2+
<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
3+
<Import Project="..\..\packages\Microsoft.Windows.CppWinRT.2.0.230207.1\build\native\Microsoft.Windows.CppWinRT.props" Condition="Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.230207.1\build\native\Microsoft.Windows.CppWinRT.props')" />
4+
<ItemGroup Label="ProjectConfigurations">
5+
<ProjectConfiguration Include="Debug|Win32">
6+
<Configuration>Debug</Configuration>
7+
<Platform>Win32</Platform>
8+
</ProjectConfiguration>
9+
<ProjectConfiguration Include="Release|Win32">
10+
<Configuration>Release</Configuration>
11+
<Platform>Win32</Platform>
12+
</ProjectConfiguration>
13+
<ProjectConfiguration Include="Debug|x64">
14+
<Configuration>Debug</Configuration>
15+
<Platform>x64</Platform>
16+
</ProjectConfiguration>
17+
<ProjectConfiguration Include="Release|x64">
18+
<Configuration>Release</Configuration>
19+
<Platform>x64</Platform>
20+
</ProjectConfiguration>
21+
</ItemGroup>
22+
<PropertyGroup Label="Globals">
23+
<VCProjectVersion>16.0</VCProjectVersion>
24+
<Keyword>Win32Proj</Keyword>
25+
<ProjectGuid>{c9df3099-a142-4aa7-b936-1541816a1f21}</ProjectGuid>
26+
<RootNamespace>WinRTTest</RootNamespace>
27+
<WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
28+
</PropertyGroup>
29+
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
30+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
31+
<ConfigurationType>Application</ConfigurationType>
32+
<UseDebugLibraries>true</UseDebugLibraries>
33+
<PlatformToolset>v142</PlatformToolset>
34+
<CharacterSet>Unicode</CharacterSet>
35+
</PropertyGroup>
36+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
37+
<ConfigurationType>Application</ConfigurationType>
38+
<UseDebugLibraries>false</UseDebugLibraries>
39+
<PlatformToolset>v142</PlatformToolset>
40+
<WholeProgramOptimization>true</WholeProgramOptimization>
41+
<CharacterSet>Unicode</CharacterSet>
42+
</PropertyGroup>
43+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
44+
<ConfigurationType>Application</ConfigurationType>
45+
<UseDebugLibraries>true</UseDebugLibraries>
46+
<PlatformToolset>v142</PlatformToolset>
47+
<CharacterSet>Unicode</CharacterSet>
48+
</PropertyGroup>
49+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
50+
<ConfigurationType>Application</ConfigurationType>
51+
<UseDebugLibraries>false</UseDebugLibraries>
52+
<PlatformToolset>v142</PlatformToolset>
53+
<WholeProgramOptimization>true</WholeProgramOptimization>
54+
<CharacterSet>Unicode</CharacterSet>
55+
</PropertyGroup>
56+
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
57+
<ImportGroup Label="ExtensionSettings">
58+
</ImportGroup>
59+
<ImportGroup Label="Shared">
60+
</ImportGroup>
61+
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
62+
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
63+
</ImportGroup>
64+
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
65+
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
66+
</ImportGroup>
67+
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
68+
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
69+
</ImportGroup>
70+
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
71+
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
72+
</ImportGroup>
73+
<PropertyGroup Label="UserMacros" />
74+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
75+
<LinkIncremental>false</LinkIncremental>
76+
</PropertyGroup>
77+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
78+
<LinkIncremental>false</LinkIncremental>
79+
</PropertyGroup>
80+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
81+
<LinkIncremental>false</LinkIncremental>
82+
</PropertyGroup>
83+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
84+
<LinkIncremental>false</LinkIncremental>
85+
</PropertyGroup>
86+
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
87+
<ClCompile>
88+
<WarningLevel>Level3</WarningLevel>
89+
<SDLCheck>true</SDLCheck>
90+
<PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
91+
<ConformanceMode>true</ConformanceMode>
92+
<ControlFlowGuard>Guard</ControlFlowGuard>
93+
<DebugInformationFormat>ProgramDatabase</DebugInformationFormat>
94+
</ClCompile>
95+
<Link>
96+
<SubSystem>Console</SubSystem>
97+
</Link>
98+
</ItemDefinitionGroup>
99+
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
100+
<ClCompile>
101+
<WarningLevel>Level3</WarningLevel>
102+
<FunctionLevelLinking>true</FunctionLevelLinking>
103+
<IntrinsicFunctions>true</IntrinsicFunctions>
104+
<SDLCheck>true</SDLCheck>
105+
<PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
106+
<ConformanceMode>true</ConformanceMode>
107+
<ControlFlowGuard>Guard</ControlFlowGuard>
108+
</ClCompile>
109+
<Link>
110+
<SubSystem>Console</SubSystem>
111+
<EnableCOMDATFolding>true</EnableCOMDATFolding>
112+
<OptimizeReferences>true</OptimizeReferences>
113+
</Link>
114+
</ItemDefinitionGroup>
115+
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
116+
<ClCompile>
117+
<WarningLevel>Level3</WarningLevel>
118+
<SDLCheck>true</SDLCheck>
119+
<PreprocessorDefinitions>_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
120+
<ConformanceMode>true</ConformanceMode>
121+
<ControlFlowGuard>Guard</ControlFlowGuard>
122+
<DebugInformationFormat>ProgramDatabase</DebugInformationFormat>
123+
</ClCompile>
124+
<Link>
125+
<SubSystem>Console</SubSystem>
126+
</Link>
127+
</ItemDefinitionGroup>
128+
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
129+
<ClCompile>
130+
<WarningLevel>Level3</WarningLevel>
131+
<FunctionLevelLinking>true</FunctionLevelLinking>
132+
<IntrinsicFunctions>true</IntrinsicFunctions>
133+
<SDLCheck>true</SDLCheck>
134+
<PreprocessorDefinitions>NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
135+
<ConformanceMode>true</ConformanceMode>
136+
<ControlFlowGuard>Guard</ControlFlowGuard>
137+
</ClCompile>
138+
<Link>
139+
<SubSystem>Console</SubSystem>
140+
<EnableCOMDATFolding>true</EnableCOMDATFolding>
141+
<OptimizeReferences>true</OptimizeReferences>
142+
</Link>
143+
</ItemDefinitionGroup>
144+
<ItemGroup>
145+
<ClCompile Include="WinRTTest.cpp" />
146+
</ItemGroup>
147+
<ItemGroup>
148+
<None Include="packages.config" />
149+
</ItemGroup>
150+
<ItemGroup>
151+
<ProjectReference Include="..\..\Wrappers\WinRT\SPTAG.WinRT.vcxproj">
152+
<Project>{8dc74c33-6e15-43ed-9300-2a140589e3da}</Project>
153+
</ProjectReference>
154+
</ItemGroup>
155+
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
156+
<ImportGroup Label="ExtensionTargets">
157+
<Import Project="..\..\packages\Microsoft.Windows.CppWinRT.2.0.230207.1\build\native\Microsoft.Windows.CppWinRT.targets" Condition="Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.230207.1\build\native\Microsoft.Windows.CppWinRT.targets')" />
158+
</ImportGroup>
159+
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
160+
<PropertyGroup>
161+
<ErrorText>This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}.</ErrorText>
162+
</PropertyGroup>
163+
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.230207.1\build\native\Microsoft.Windows.CppWinRT.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.2.0.230207.1\build\native\Microsoft.Windows.CppWinRT.props'))" />
164+
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.230207.1\build\native\Microsoft.Windows.CppWinRT.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.2.0.230207.1\build\native\Microsoft.Windows.CppWinRT.targets'))" />
165+
</Target>
166+
</Project>
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
<?xml version="1.0" encoding="utf-8"?>
2+
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
3+
<ItemGroup>
4+
<Filter Include="Source Files">
5+
<UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier>
6+
<Extensions>cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx</Extensions>
7+
</Filter>
8+
<Filter Include="Header Files">
9+
<UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier>
10+
<Extensions>h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd</Extensions>
11+
</Filter>
12+
<Filter Include="Resource Files">
13+
<UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier>
14+
<Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms</Extensions>
15+
</Filter>
16+
</ItemGroup>
17+
<ItemGroup>
18+
<ClCompile Include="WinRTTest.cpp">
19+
<Filter>Source Files</Filter>
20+
</ClCompile>
21+
</ItemGroup>
22+
<ItemGroup>
23+
<None Include="packages.config" />
24+
</ItemGroup>
25+
</Project>

Test/WinRTTest/packages.config

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<?xml version="1.0" encoding="utf-8"?>
2+
<packages>
3+
<package id="Microsoft.Windows.CppWinRT" version="2.0.230207.1" targetFramework="native" />
4+
</packages>

Wrappers/WinRT/AnnIndex.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#include "pch.h"
2+
#include "AnnIndex.h"
3+
#if __has_include("AnnIndex.g.cpp")
4+
#include "AnnIndex.g.cpp"
5+
#endif
6+
#include "SearchResult.g.cpp"
7+
8+
namespace winrt::SPTAG::implementation
9+
{
10+
template<typename T, typename = std::enable_if_t<std::is_pod_v<T>>>
11+
sptag::ByteArray GetByteArray(const winrt::array_view<const T>& data) {
12+
auto copy = new T[data.size()];
13+
for (auto i = 0u; i < data.size(); i++) copy[i] = data[i];
14+
const auto byteSize = data.size() * sizeof(data.at(0)) / sizeof(byte);
15+
auto byteArray = sptag::ByteArray(reinterpret_cast<byte*>(copy), byteSize, true);
16+
return byteArray;
17+
}
18+
19+
SPTAG::SearchResult AnnIndex::GetResultFromMetadata(const sptag::BasicResult& r) const {
20+
return winrt::make<SearchResult>(r.Meta.Data(), r.Meta.Length(), r.Dist);
21+
}
22+
23+
24+
winrt::Windows::Foundation::Collections::IVector<SPTAG::SearchResult> AnnIndex::Search(EmbeddingVector p_data, uint32_t p_resultNum) const {
25+
auto vec = std::vector<SPTAG::SearchResult>{};
26+
p_resultNum = (std::min)(static_cast<sptag::SizeType>(p_resultNum), m_index->GetNumSamples());
27+
auto results = std::make_shared<sptag::QueryResult>(p_data.data(), p_resultNum, true);
28+
29+
if (nullptr != m_index) {
30+
m_index->SearchIndex(*results);
31+
}
32+
for (const auto& r : *results) {
33+
auto sr = GetResultFromMetadata(r);
34+
vec.push_back(sr);
35+
}
36+
return winrt::single_threaded_vector<SPTAG::SearchResult>(std::move(vec));
37+
}
38+
39+
void AnnIndex::Load(winrt::Windows::Storage::StorageFile file) {
40+
auto path = file.Path();
41+
if (sptag::ErrorCode::Success != sptag::VectorIndex::LoadIndexFromFile(winrt::to_string(path), m_index)) {
42+
throw winrt::hresult_error{};
43+
}
44+
}
45+
46+
void AnnIndex::Save(winrt::Windows::Storage::StorageFile file) {
47+
auto path = file.Path();
48+
if (sptag::ErrorCode::Success != m_index->SaveIndexToFile(winrt::to_string(path))) {
49+
throw winrt::hresult_error{};
50+
}
51+
}
52+
53+
54+
template<typename T>
55+
void AnnIndex::_AddWithMetadataImpl(EmbeddingVector p_data, T metadata) {
56+
if (m_dimension == 0) {
57+
m_dimension = p_data.size();
58+
} else if (m_dimension != static_cast<decltype(m_dimension)>(p_data.size())) {
59+
throw winrt::hresult_invalid_argument{};
60+
}
61+
int p_num{ 1 };
62+
auto byteArray = GetByteArray(p_data);
63+
std::shared_ptr<sptag::VectorSet> vectors(new sptag::BasicVectorSet(byteArray,
64+
m_inputValueType,
65+
static_cast<sptag::DimensionType>(m_dimension),
66+
static_cast<sptag::SizeType>(p_num)));
67+
68+
69+
sptag::ByteArray p_meta = GetByteArray(metadata);
70+
bool p_withMetaIndex{ true };
71+
bool p_normalized{ true };
72+
73+
std::uint64_t* offsets = new std::uint64_t[p_num + 1]{ 0 };
74+
if (!sptag::MetadataSet::GetMetadataOffsets(p_meta.Data(), p_meta.Length(), offsets, p_num + 1, '\n')) throw winrt::hresult_invalid_argument{};
75+
std::shared_ptr<sptag::MetadataSet> meta(new sptag::MemMetadataSet(p_meta, sptag::ByteArray((std::uint8_t*)offsets, (p_num + 1) * sizeof(std::uint64_t), true), (sptag::SizeType)p_num));
76+
if (sptag::ErrorCode::Success != m_index->AddIndex(vectors, meta, p_withMetaIndex, p_normalized)) {
77+
throw winrt::hresult_error(E_UNEXPECTED);
78+
}
79+
}
80+
81+
void AnnIndex::AddWithMetadata(array_view<float const> data, array_view<uint8_t const> metadata) {
82+
_AddWithMetadataImpl(data, metadata);
83+
}
84+
}

0 commit comments

Comments
 (0)