Skip to content

Commit e4638b1

Browse files
authored
Merge pull request #631 from m7mdhka/main
Integrating FAISS in VannaAI
2 parents f571c09 + fd8a928 commit e4638b1

File tree

2 files changed

+177
-0
lines changed

2 files changed

+177
-0
lines changed

src/vanna/faiss/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .faiss import FAISS

src/vanna/faiss/faiss.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import os
2+
import json
3+
import uuid
4+
from typing import List, Dict, Any
5+
6+
import faiss
7+
import numpy as np
8+
import pandas as pd
9+
10+
from ..base import VannaBase
11+
from ..exceptions import DependencyError
12+
13+
class FAISS(VannaBase):
    """Vector store for Vanna training data backed by FAISS indexes.

    Three parallel collections (question/SQL pairs, DDL, documentation) are
    each kept as a ``faiss.IndexFlatL2`` plus a JSON metadata sidecar list,
    where position ``i`` in the index corresponds to ``metadata_list[i]``.
    Storage mode is selected by ``config["client"]``:

    - ``"persistent"`` (default): indexes/metadata are read from and written
      to files under ``config["path"]``.
    - ``"in-memory"``: fresh empty indexes, nothing is saved.
    - a list of three pre-built ``faiss.Index`` objects (sql, ddl, doc).
    """

    def __init__(self, config=None):
        if config is None:
            config = {}

        VannaBase.__init__(self, config=config)

        # Import lazily so the package can be installed without FAISS.
        try:
            import faiss
        except ImportError:
            raise DependencyError(
                "FAISS is not installed. Please install it with 'pip install faiss-cpu' or 'pip install faiss-gpu'"
            )

        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise DependencyError(
                "SentenceTransformer is not installed. Please install it with 'pip install sentence-transformers'."
            )

        # Directory for persistent index/metadata files.
        self.path = config.get("path", ".")
        # Must match the embedding model's output dimension (384 for all-MiniLM-L6-v2).
        self.embedding_dim = config.get('embedding_dim', 384)
        # Per-collection result counts, each falling back to a shared "n_results".
        self.n_results_sql = config.get('n_results_sql', config.get("n_results", 10))
        self.n_results_ddl = config.get('n_results_ddl', config.get("n_results", 10))
        self.n_results_documentation = config.get('n_results_documentation', config.get("n_results", 10))
        self.curr_client = config.get("client", "persistent")

        if self.curr_client == 'persistent':
            self.sql_index = self._load_or_create_index('sql_index.faiss')
            self.ddl_index = self._load_or_create_index('ddl_index.faiss')
            self.doc_index = self._load_or_create_index('doc_index.faiss')
        elif self.curr_client == 'in-memory':
            self.sql_index = faiss.IndexFlatL2(self.embedding_dim)
            self.ddl_index = faiss.IndexFlatL2(self.embedding_dim)
            self.doc_index = faiss.IndexFlatL2(self.embedding_dim)
        elif isinstance(self.curr_client, list) and len(self.curr_client) == 3 and all(isinstance(idx, faiss.Index) for idx in self.curr_client):
            # Caller supplied ready-made indexes in (sql, ddl, doc) order.
            self.sql_index = self.curr_client[0]
            self.ddl_index = self.curr_client[1]
            self.doc_index = self.curr_client[2]
        else:
            raise ValueError(f"Unsupported storage type was set in config: {self.curr_client}")

        # Metadata lists are positionally aligned with their index's vectors.
        self.sql_metadata: List[Dict[str, Any]] = self._load_or_create_metadata('sql_metadata.json')
        self.ddl_metadata: List[Dict[str, str]] = self._load_or_create_metadata('ddl_metadata.json')
        self.doc_metadata: List[Dict[str, str]] = self._load_or_create_metadata('doc_metadata.json')

        model_name = config.get('embedding_model', 'all-MiniLM-L6-v2')
        self.embedding_model = SentenceTransformer(model_name)

    def _load_or_create_index(self, filename):
        """Read a FAISS index from ``self.path``; create an empty one if absent."""
        filepath = os.path.join(self.path, filename)
        if os.path.exists(filepath):
            return faiss.read_index(filepath)
        return faiss.IndexFlatL2(self.embedding_dim)

    def _load_or_create_metadata(self, filename):
        """Read a JSON metadata list from ``self.path``; default to empty list."""
        filepath = os.path.join(self.path, filename)
        if os.path.exists(filepath):
            with open(filepath, 'r') as f:
                return json.load(f)
        return []

    def _save_index(self, index, filename):
        """Persist an index to disk (no-op unless in persistent mode)."""
        if self.curr_client == 'persistent':
            filepath = os.path.join(self.path, filename)
            faiss.write_index(index, filepath)

    def _save_metadata(self, metadata, filename):
        """Persist a metadata list to disk (no-op unless in persistent mode)."""
        if self.curr_client == 'persistent':
            filepath = os.path.join(self.path, filename)
            with open(filepath, 'w') as f:
                json.dump(metadata, f)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Embed *data* with the configured SentenceTransformer model.

        Raises:
            AssertionError: if the model's output dimension does not match
                ``self.embedding_dim`` (misconfigured ``embedding_dim``).
        """
        embedding = self.embedding_model.encode(data)
        assert embedding.shape[0] == self.embedding_dim, \
            f"Embedding dimension mismatch: expected {self.embedding_dim}, got {embedding.shape[0]}"
        return embedding.tolist()

    def _entry_text(self, metadata: Dict[str, Any]) -> str:
        """Reconstruct the exact text that was embedded when *metadata* was added.

        Must mirror the ``text`` arguments used in ``add_question_sql`` /
        ``add_ddl`` / ``add_documentation`` so rebuilt indexes stay consistent
        with future similarity lookups.
        """
        if "question" in metadata and "sql" in metadata:
            return metadata["question"] + " " + metadata["sql"]
        if "ddl" in metadata:
            return metadata["ddl"]
        return metadata.get("documentation", "")

    def _add_to_index(self, index, metadata_list, text, extra_metadata=None) -> str:
        """Embed *text*, append it to *index*, and record metadata; return the new id."""
        embedding = self.generate_embedding(text)
        index.add(np.array([embedding], dtype=np.float32))
        entry_id = str(uuid.uuid4())
        metadata_list.append({"id": entry_id, **(extra_metadata or {})})
        return entry_id

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL training pair; returns its generated id."""
        entry_id = self._add_to_index(self.sql_index, self.sql_metadata, question + " " + sql, {"question": question, "sql": sql})
        self._save_index(self.sql_index, 'sql_index.faiss')
        self._save_metadata(self.sql_metadata, 'sql_metadata.json')
        return entry_id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement; returns its generated id."""
        entry_id = self._add_to_index(self.ddl_index, self.ddl_metadata, ddl, {"ddl": ddl})
        self._save_index(self.ddl_index, 'ddl_index.faiss')
        self._save_metadata(self.ddl_metadata, 'ddl_metadata.json')
        return entry_id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        """Store a documentation snippet; returns its generated id."""
        entry_id = self._add_to_index(self.doc_index, self.doc_metadata, documentation, {"documentation": documentation})
        self._save_index(self.doc_index, 'doc_index.faiss')
        self._save_metadata(self.doc_metadata, 'doc_metadata.json')
        return entry_id

    def _get_similar(self, index, metadata_list, text, n_results) -> list:
        """Return metadata entries nearest to *text*, at most *n_results*."""
        embedding = self.generate_embedding(text)
        D, I = index.search(np.array([embedding], dtype=np.float32), k=n_results)
        # FAISS pads result slots with -1 when the index holds fewer than k
        # vectors; skip every padded slot, not just a -1 in the first position,
        # to avoid metadata_list[-1] returning a wrong entry.
        return [metadata_list[i] for i in I[0] if i != -1]

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        return self._get_similar(self.sql_index, self.sql_metadata, question, self.n_results_sql)

    def get_related_ddl(self, question: str, **kwargs) -> list:
        return [metadata["ddl"] for metadata in self._get_similar(self.ddl_index, self.ddl_metadata, question, self.n_results_ddl)]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        return [metadata["documentation"] for metadata in self._get_similar(self.doc_index, self.doc_metadata, question, self.n_results_documentation)]

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return all stored training data as one DataFrame tagged by type."""
        sql_data = pd.DataFrame(self.sql_metadata)
        sql_data['training_data_type'] = 'sql'

        ddl_data = pd.DataFrame(self.ddl_metadata)
        ddl_data['training_data_type'] = 'ddl'

        doc_data = pd.DataFrame(self.doc_metadata)
        doc_data['training_data_type'] = 'documentation'

        return pd.concat([sql_data, ddl_data, doc_data], ignore_index=True)

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Remove the entry with *id* from whichever collection holds it.

        ``IndexFlatL2`` has no per-vector removal, so the affected index is
        rebuilt from the remaining entries' ORIGINAL texts (via
        ``_entry_text``) — embedding ``json.dumps`` of the metadata would
        produce vectors that no longer match query embeddings.

        Returns:
            True if an entry was removed, False if *id* was not found.
        """
        for metadata_list, attr_name, index_file, metadata_file in [
            (self.sql_metadata, 'sql_index', 'sql_index.faiss', 'sql_metadata.json'),
            (self.ddl_metadata, 'ddl_index', 'ddl_index.faiss', 'ddl_metadata.json'),
            (self.doc_metadata, 'doc_index', 'doc_index.faiss', 'doc_metadata.json'),
        ]:
            for i, item in enumerate(metadata_list):
                if item['id'] == id:
                    del metadata_list[i]
                    new_index = faiss.IndexFlatL2(self.embedding_dim)
                    embeddings = [self.generate_embedding(self._entry_text(m)) for m in metadata_list]
                    if embeddings:
                        new_index.add(np.array(embeddings, dtype=np.float32))
                    setattr(self, attr_name, new_index)

                    # Save under the SAME filenames __init__ loads from
                    # (e.g. sql_metadata.json, not sql_index_metadata.json),
                    # otherwise removed entries reappear after a restart.
                    self._save_index(new_index, index_file)
                    self._save_metadata(metadata_list, metadata_file)

                    return True
        return False

    def remove_collection(self, collection_name: str) -> bool:
        """Reset an entire collection ("sql", "ddl", or "documentation").

        Returns:
            True if *collection_name* is valid and was cleared, else False.
        """
        # Map the public collection name to the internal attribute/file prefix:
        # the documentation collection lives in doc_index/doc_metadata, not
        # documentation_index/documentation_metadata.
        prefix = {"sql": "sql", "ddl": "ddl", "documentation": "doc"}.get(collection_name)
        if prefix is None:
            return False

        setattr(self, f"{prefix}_index", faiss.IndexFlatL2(self.embedding_dim))
        setattr(self, f"{prefix}_metadata", [])

        self._save_index(getattr(self, f"{prefix}_index"), f"{prefix}_index.faiss")
        self._save_metadata([], f"{prefix}_metadata.json")

        return True

0 commit comments

Comments
 (0)