Skip to content

Commit 254ff6a

Browse files
authored
Merge pull request #496 from zc277584121/milvus
Add Milvus vectorstore support
2 parents e46b2f1 + d23bd5e commit 254ff6a

File tree

5 files changed

+338
-1
lines changed

5 files changed

+338
-1
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
3333
snowflake = ["snowflake-connector-python"]
3434
duckdb = ["duckdb"]
3535
google = ["google-generativeai", "google-cloud-aiplatform"]
36-
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client"]
36+
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]"]
3737
test = ["tox"]
3838
chromadb = ["chromadb"]
3939
openai = ["openai"]
@@ -48,3 +48,4 @@ vllm = ["vllm"]
4848
pinecone = ["pinecone-client", "fastembed"]
4949
opensearch = ["opensearch-py", "opensearch-dsl"]
5050
hf = ["transformers"]
51+
milvus = ["pymilvus[model]"]

src/vanna/milvus/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .milvus_vector import Milvus_VectorStore

src/vanna/milvus/milvus_vector.py

Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
import uuid
2+
from typing import List
3+
4+
import pandas as pd
5+
from pymilvus import DataType, MilvusClient, model
6+
7+
from ..base import VannaBase
8+
9+
# Setting the URI as a local file, e.g.`./milvus.db`,
10+
# is the most convenient method, as it automatically utilizes Milvus Lite
11+
# to store all data in this file.
12+
#
13+
# If you have a large amount of data, such as more than a million docs, we
14+
# recommend setting up a more performant Milvus server on docker or kubernetes.
15+
# When using this setup, please use the server URI,
16+
# e.g.`http://localhost:19530`, as your URI.
17+
18+
DEFAULT_MILVUS_URI = "./milvus.db"
19+
# DEFAULT_MILVUS_URI = "http://localhost:19530"
20+
21+
MAX_LIMIT_SIZE = 10_000
22+
23+
24+
class Milvus_VectorStore(VannaBase):
    """
    Vectorstore implementation using Milvus - https://milvus.io/docs/quickstart.md

    Training data is kept in three collections, one per kind of record:
    `vannasql` (question/SQL pairs), `vannaddl` (DDL statements) and
    `vannadoc` (documentation). Record ids end in `-sql`, `-ddl` or `-doc`
    so that `remove_training_data` can tell which collection owns them.

    Args:
        - config (dict, optional): Dictionary of `Milvus_VectorStore config` options. Defaults to `None`.
            - milvus_client: A `pymilvus.MilvusClient` instance.
                Defaults to `MilvusClient(uri=DEFAULT_MILVUS_URI)`.
            - embedding_function:
                A `milvus_model.base.BaseEmbeddingFunction` instance. Defaults to `DefaultEmbeddingFunction()`.
                For more models, please refer to:
                https://milvus.io/docs/embeddings.md
            - n_results (int): Maximum number of hits returned by the
                similarity searches. Defaults to 10.
    """

    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)

        # `config` is optional: fall back to an empty dict so the key lookups
        # below do not raise a TypeError when it is left as `None`.
        if config is None:
            config = {}

        if "milvus_client" in config:
            self.milvus_client = config["milvus_client"]
        else:
            self.milvus_client = MilvusClient(uri=DEFAULT_MILVUS_URI)

        if "embedding_function" in config:
            self.embedding_function = config["embedding_function"]
        else:
            self.embedding_function = model.DefaultEmbeddingFunction()

        # Embed a throwaway document once to learn the vector dimension the
        # collection schemas have to be created with.
        self._embedding_dim = self.embedding_function.encode_documents(["foo"])[0].shape[0]
        self._create_collections()
        self.n_results = config.get("n_results", 10)

    def _create_collections(self):
        """Create the three training-data collections if they do not exist yet."""
        self._create_sql_collection("vannasql")
        self._create_ddl_collection("vannaddl")
        self._create_doc_collection("vannadoc")

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """
        Embed a single piece of text.

        Args:
            data (str): The text to embed.

        Returns:
            List[float]: The embedding vector of `data`.
        """
        # `encode_documents` works on a batch, exactly as at every other call
        # site in this class: wrap the text in a list and unwrap the single
        # resulting vector.
        return self.embedding_function.encode_documents([data])[0].tolist()

    def _create_collection(self, name: str, text_fields: List[str]):
        """
        Create collection `name` unless it already exists.

        The schema is a VARCHAR primary key `id`, one VARCHAR field per entry
        of `text_fields` (in the given order) and a FLOAT_VECTOR field
        `vector` indexed with AUTOINDEX / L2.
        """
        if self.milvus_client.has_collection(collection_name=name):
            return

        schema = MilvusClient.create_schema(
            auto_id=False,
            enable_dynamic_field=False,
        )
        schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True)
        for field_name in text_fields:
            schema.add_field(field_name=field_name, datatype=DataType.VARCHAR, max_length=65535)
        schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self._embedding_dim)

        index_params = self.milvus_client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_name="vector",
            index_type="AUTOINDEX",
            metric_type="L2",
        )
        self.milvus_client.create_collection(
            collection_name=name,
            schema=schema,
            index_params=index_params,
            consistency_level="Strong"
        )

    def _create_sql_collection(self, name: str):
        """Create the question/SQL collection (`text` and `sql` fields)."""
        self._create_collection(name, ["text", "sql"])

    def _create_ddl_collection(self, name: str):
        """Create the DDL collection (`ddl` field)."""
        self._create_collection(name, ["ddl"])

    def _create_doc_collection(self, name: str):
        """Create the documentation collection (`doc` field)."""
        self._create_collection(name, ["doc"])

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """
        Store a question/SQL pair, embedded by its question.

        Returns:
            str: The id of the new record (a UUID suffixed with `-sql`).

        Raises:
            Exception: If `question` or `sql` is empty.
        """
        if not question or not sql:
            raise Exception("pair of question and sql can not be null")
        _id = str(uuid.uuid4()) + "-sql"
        embedding = self.embedding_function.encode_documents([question])[0]
        self.milvus_client.insert(
            collection_name="vannasql",
            data={
                "id": _id,
                "text": question,
                "sql": sql,
                "vector": embedding
            }
        )
        return _id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """
        Store a DDL statement.

        Returns:
            str: The id of the new record (a UUID suffixed with `-ddl`).

        Raises:
            Exception: If `ddl` is empty.
        """
        if not ddl:
            raise Exception("ddl can not be null")
        _id = str(uuid.uuid4()) + "-ddl"
        embedding = self.embedding_function.encode_documents([ddl])[0]
        self.milvus_client.insert(
            collection_name="vannaddl",
            data={
                "id": _id,
                "ddl": ddl,
                "vector": embedding
            }
        )
        return _id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        """
        Store a piece of documentation.

        Returns:
            str: The id of the new record (a UUID suffixed with `-doc`).

        Raises:
            Exception: If `documentation` is empty.
        """
        if not documentation:
            raise Exception("documentation can not be null")
        _id = str(uuid.uuid4()) + "-doc"
        embedding = self.embedding_function.encode_documents([documentation])[0]
        self.milvus_client.insert(
            collection_name="vannadoc",
            data={
                "id": _id,
                "doc": documentation,
                "vector": embedding
            }
        )
        return _id

    def _training_frame(self, collection_name: str, content_field: str, question_field=None) -> pd.DataFrame:
        """
        Query up to `MAX_LIMIT_SIZE` records of one collection and shape them
        into the `id` / `question` / `content` layout of `get_training_data`.

        `question` is `None` for every row when `question_field` is not given.
        """
        records = self.milvus_client.query(
            collection_name=collection_name,
            output_fields=["*"],
            limit=MAX_LIMIT_SIZE,
        )
        return pd.DataFrame(
            {
                "id": [record["id"] for record in records],
                "question": [
                    None if question_field is None else record[question_field]
                    for record in records
                ],
                "content": [record[content_field] for record in records],
            }
        )

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """
        Return the stored training data of all three collections.

        At most `MAX_LIMIT_SIZE` records are read per collection.

        Returns:
            pd.DataFrame: Columns `id`, `question` and `content`; SQL rows
            first, then DDL rows, then documentation rows.
        """
        df = pd.DataFrame()
        df = pd.concat([df, self._training_frame("vannasql", "sql", question_field="text")])
        df = pd.concat([df, self._training_frame("vannaddl", "ddl")])
        df = pd.concat([df, self._training_frame("vannadoc", "doc")])
        return df

    def _search(self, collection_name: str, question: str, output_fields: List[str]) -> list:
        """
        Run an L2 similarity search for `question` in `collection_name`.

        Returns the hits of the single query, limited to `self.n_results`;
        each hit exposes the requested `output_fields` under its `"entity"` key.
        """
        search_params = {
            "metric_type": "L2",
            "params": {"nprobe": 128},
        }
        embeddings = self.embedding_function.encode_queries([question])
        res = self.milvus_client.search(
            collection_name=collection_name,
            anns_field="vector",
            data=embeddings,
            limit=self.n_results,
            output_fields=output_fields,
            search_params=search_params
        )
        # One query vector was sent, so only the first result set is relevant.
        return res[0]

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """
        Find stored question/SQL pairs similar to `question`.

        Returns:
            list: Dicts with the keys `question` and `sql`.
        """
        hits = self._search("vannasql", question, ["text", "sql"])
        return [
            {"question": hit["entity"]["text"], "sql": hit["entity"]["sql"]}
            for hit in hits
        ]

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """
        Find stored DDL statements related to `question`.

        Returns:
            list: The matching DDL strings.
        """
        return [hit["entity"]["ddl"] for hit in self._search("vannaddl", question, ["ddl"])]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """
        Find stored documentation related to `question`.

        Returns:
            list: The matching documentation strings.
        """
        return [hit["entity"]["doc"] for hit in self._search("vannadoc", question, ["doc"])]

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """
        Delete one training record.

        The owning collection is derived from the id suffix written by the
        `add_*` methods.

        Returns:
            bool: `True` if the suffix was recognised and a delete was issued,
            `False` otherwise.
        """
        suffix_to_collection = (
            ("-sql", "vannasql"),
            ("-ddl", "vannaddl"),
            ("-doc", "vannadoc"),
        )
        for suffix, collection_name in suffix_to_collection:
            if id.endswith(suffix):
                self.milvus_client.delete(collection_name=collection_name, ids=[id])
                return True
        return False

tests/test_imports.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ def test_regular_imports():
77
from vanna.hf.hf import Hf
88
from vanna.local import LocalContext_OpenAI
99
from vanna.marqo.marqo import Marqo_VectorStore
10+
from vanna.milvus.milvus_vector import Milvus_VectorStore
1011
from vanna.mistral.mistral import Mistral
1112
from vanna.ollama.ollama import Ollama
1213
from vanna.openai.openai_chat import OpenAI_Chat
@@ -24,6 +25,7 @@ def test_shortcut_imports():
2425
from vanna.chromadb import ChromaDB_VectorStore
2526
from vanna.hf import Hf
2627
from vanna.marqo import Marqo_VectorStore
28+
from vanna.milvus import Milvus_VectorStore
2729
from vanna.mistral import Mistral
2830
from vanna.ollama import Ollama
2931
from vanna.openai import OpenAI_Chat, OpenAI_Embeddings

tests/test_vanna.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,34 @@ def test_vn_chroma():
111111
df = vn_chroma.run_sql(sql)
112112
assert len(df) == 7
113113

114+
115+
from vanna.milvus import Milvus_VectorStore
116+
117+
118+
class VannaMilvus(Milvus_VectorStore, OpenAI_Chat):
    """Vanna test double: Milvus as the vector store, OpenAI as the chat model."""

    def __init__(self, config=None):
        # Initialise each base explicitly with the same config,
        # vector store first and chat model second.
        for base in (Milvus_VectorStore, OpenAI_Chat):
            base.__init__(self, config=config)
122+
123+
# Module-level Milvus-backed Vanna instance used by `test_vn_milvus`.
vn_milvus = VannaMilvus(
    config={
        'api_key': OPENAI_API_KEY,
        'model': 'gpt-3.5-turbo',
    },
)
# Point it at the Chinook SQLite sample database.
vn_milvus.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
125+
126+
def test_vn_milvus():
    """Train the Milvus-backed instance on the database DDL and check a generated query."""
    # Start from a clean slate: remove whatever earlier runs left behind.
    # Iterating an empty frame is a no-op, so no emptiness check is needed.
    leftovers = vn_milvus.get_training_data()
    for _, row in leftovers.iterrows():
        vn_milvus.remove_training_data(row['id'])

    # Train on every DDL statement found in the SQLite catalogue.
    schema_rows = vn_milvus.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
    for statement in schema_rows['sql'].to_list():
        vn_milvus.train(ddl=statement)

    generated_sql = vn_milvus.generate_sql("What are the top 7 customers by sales?")
    result = vn_milvus.run_sql(generated_sql)
    assert len(result) == 7
140+
141+
114142
class VannaNumResults(ChromaDB_VectorStore, OpenAI_Chat):
115143
def __init__(self, config=None):
116144
ChromaDB_VectorStore.__init__(self, config=config)

0 commit comments

Comments
 (0)