1
+ import os
2
+ import json
3
+ import uuid
4
+ from typing import List , Dict , Any
5
+
6
+ import faiss
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ from ..base import VannaBase
11
+ from ..exceptions import DependencyError
12
+
13
class FAISS(VannaBase):
    """Vector store backed by FAISS indexes with JSON metadata sidecars.

    Maintains three parallel stores — question/SQL pairs, DDL statements and
    documentation snippets. Each store is a ``faiss.IndexFlatL2`` plus a list
    of metadata dicts whose list positions mirror the index rows, so row i of
    an index corresponds to entry i of its metadata list.

    Recognised ``config`` keys:
        path: directory for persisted index/metadata files (default ".").
        embedding_dim: dimensionality of sentence embeddings (default 384).
        n_results / n_results_sql / n_results_ddl / n_results_documentation:
            neighbours returned per query (default 10).
        client: "persistent" (default, files under ``path``), "in-memory",
            or a list of three pre-built ``faiss.Index`` objects
            ordered [sql, ddl, doc].
        embedding_model: SentenceTransformer model name
            (default "all-MiniLM-L6-v2").
    """

    def __init__(self, config=None):
        if config is None:
            config = {}

        VannaBase.__init__(self, config=config)

        # Import inside __init__ so the module can be imported without the
        # optional dependencies installed; fail with a clear message here.
        try:
            import faiss
        except ImportError:
            raise DependencyError(
                "FAISS is not installed. Please install it with 'pip install faiss-cpu' or 'pip install faiss-gpu'"
            )

        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise DependencyError(
                "SentenceTransformer is not installed. Please install it with 'pip install sentence-transformers'."
            )

        self.path = config.get("path", ".")
        self.embedding_dim = config.get('embedding_dim', 384)
        # Per-store result counts fall back to a shared "n_results" key.
        self.n_results_sql = config.get('n_results_sql', config.get("n_results", 10))
        self.n_results_ddl = config.get('n_results_ddl', config.get("n_results", 10))
        self.n_results_documentation = config.get('n_results_documentation', config.get("n_results", 10))
        self.curr_client = config.get("client", "persistent")

        if self.curr_client == 'persistent':
            self.sql_index = self._load_or_create_index('sql_index.faiss')
            self.ddl_index = self._load_or_create_index('ddl_index.faiss')
            self.doc_index = self._load_or_create_index('doc_index.faiss')
        elif self.curr_client == 'in-memory':
            self.sql_index = faiss.IndexFlatL2(self.embedding_dim)
            self.ddl_index = faiss.IndexFlatL2(self.embedding_dim)
            self.doc_index = faiss.IndexFlatL2(self.embedding_dim)
        elif isinstance(self.curr_client, list) and len(self.curr_client) == 3 and all(isinstance(idx, faiss.Index) for idx in self.curr_client):
            self.sql_index = self.curr_client[0]
            self.ddl_index = self.curr_client[1]
            self.doc_index = self.curr_client[2]
        else:
            raise ValueError(f"Unsupported storage type was set in config: {self.curr_client}")

        self.sql_metadata: List[Dict[str, Any]] = self._load_or_create_metadata('sql_metadata.json')
        self.ddl_metadata: List[Dict[str, str]] = self._load_or_create_metadata('ddl_metadata.json')
        self.doc_metadata: List[Dict[str, str]] = self._load_or_create_metadata('doc_metadata.json')

        model_name = config.get('embedding_model', 'all-MiniLM-L6-v2')
        self.embedding_model = SentenceTransformer(model_name)

    def _load_or_create_index(self, filename):
        """Read a FAISS index from ``path``/``filename`` or create an empty one."""
        filepath = os.path.join(self.path, filename)
        if os.path.exists(filepath):
            return faiss.read_index(filepath)
        return faiss.IndexFlatL2(self.embedding_dim)

    def _load_or_create_metadata(self, filename):
        """Read a JSON metadata list from ``path``/``filename`` or return []."""
        filepath = os.path.join(self.path, filename)
        if os.path.exists(filepath):
            with open(filepath, 'r') as f:
                return json.load(f)
        return []

    def _save_index(self, index, filename):
        """Persist a FAISS index; no-op unless the client is 'persistent'."""
        if self.curr_client == 'persistent':
            filepath = os.path.join(self.path, filename)
            faiss.write_index(index, filepath)

    def _save_metadata(self, metadata, filename):
        """Persist a metadata list as JSON; no-op unless client is 'persistent'."""
        if self.curr_client == 'persistent':
            filepath = os.path.join(self.path, filename)
            with open(filepath, 'w') as f:
                json.dump(metadata, f)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Embed ``data`` with the configured SentenceTransformer model.

        Raises:
            ValueError: if the model's output dimensionality does not match
                ``embedding_dim``. (Was an ``assert``, which is silently
                stripped under ``python -O``.)
        """
        embedding = self.embedding_model.encode(data)
        if embedding.shape[0] != self.embedding_dim:
            raise ValueError(
                f"Embedding dimension mismatch: expected {self.embedding_dim}, got {embedding.shape[0]}"
            )
        return embedding.tolist()

    @staticmethod
    def _entry_text(metadata: Dict[str, Any]) -> str:
        """Reconstruct the exact text that was embedded when this entry was
        added, so a rebuilt index stays consistent with the original one."""
        if "sql" in metadata:
            return metadata["question"] + " " + metadata["sql"]
        if "ddl" in metadata:
            return metadata["ddl"]
        return metadata["documentation"]

    def _add_to_index(self, index, metadata_list, text, extra_metadata=None) -> str:
        """Embed ``text``, append it to ``index``, record metadata; return id."""
        embedding = self.generate_embedding(text)
        index.add(np.array([embedding], dtype=np.float32))
        entry_id = str(uuid.uuid4())
        metadata_list.append({"id": entry_id, **(extra_metadata or {})})
        return entry_id

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL training pair; returns the new entry id."""
        entry_id = self._add_to_index(self.sql_index, self.sql_metadata, question + " " + sql, {"question": question, "sql": sql})
        self._save_index(self.sql_index, 'sql_index.faiss')
        self._save_metadata(self.sql_metadata, 'sql_metadata.json')
        return entry_id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement; returns the new entry id."""
        entry_id = self._add_to_index(self.ddl_index, self.ddl_metadata, ddl, {"ddl": ddl})
        self._save_index(self.ddl_index, 'ddl_index.faiss')
        self._save_metadata(self.ddl_metadata, 'ddl_metadata.json')
        return entry_id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        """Store a documentation snippet; returns the new entry id."""
        entry_id = self._add_to_index(self.doc_index, self.doc_metadata, documentation, {"documentation": documentation})
        self._save_index(self.doc_index, 'doc_index.faiss')
        self._save_metadata(self.doc_metadata, 'doc_metadata.json')
        return entry_id

    def _get_similar(self, index, metadata_list, text, n_results) -> list:
        """Return metadata entries for the nearest neighbours of ``text``."""
        embedding = self.generate_embedding(text)
        D, I = index.search(np.array([embedding], dtype=np.float32), k=n_results)
        # FAISS pads results with -1 when the index holds fewer than k
        # vectors; the original only checked position 0, so a later -1 could
        # index metadata_list[-1] and return a wrong entry. Filter them all.
        return [metadata_list[i] for i in I[0] if i != -1]

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        return self._get_similar(self.sql_index, self.sql_metadata, question, self.n_results_sql)

    def get_related_ddl(self, question: str, **kwargs) -> list:
        return [metadata["ddl"] for metadata in self._get_similar(self.ddl_index, self.ddl_metadata, question, self.n_results_ddl)]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        return [metadata["documentation"] for metadata in self._get_similar(self.doc_index, self.doc_metadata, question, self.n_results_documentation)]

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return all stored training entries as a single DataFrame with a
        'training_data_type' column distinguishing sql/ddl/documentation."""
        sql_data = pd.DataFrame(self.sql_metadata)
        sql_data['training_data_type'] = 'sql'

        ddl_data = pd.DataFrame(self.ddl_metadata)
        ddl_data['training_data_type'] = 'ddl'

        doc_data = pd.DataFrame(self.doc_metadata)
        doc_data['training_data_type'] = 'documentation'

        return pd.concat([sql_data, ddl_data, doc_data], ignore_index=True)

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Remove the entry with the given id from whichever store holds it.

        IndexFlatL2 cannot delete single rows, so the affected index is
        rebuilt from the remaining entries. Returns True if an entry was
        removed, False if the id was not found.
        """
        for metadata_list, prefix in [
            (self.sql_metadata, 'sql'),
            (self.ddl_metadata, 'ddl'),
            (self.doc_metadata, 'doc'),
        ]:
            for i, item in enumerate(metadata_list):
                if item['id'] == id:
                    del metadata_list[i]
                    # Rebuild embeddings from the same text used at add time.
                    # The original embedded json.dumps(metadata), which
                    # silently changed search behaviour after any removal.
                    new_index = faiss.IndexFlatL2(self.embedding_dim)
                    embeddings = [self.generate_embedding(self._entry_text(m)) for m in metadata_list]
                    if embeddings:
                        new_index.add(np.array(embeddings, dtype=np.float32))
                    setattr(self, f"{prefix}_index", new_index)

                    if self.curr_client == 'persistent':
                        self._save_index(new_index, f"{prefix}_index.faiss")
                        # Fixed: the original wrote "<name>_index_metadata.json",
                        # which never matched the "<name>_metadata.json" file
                        # read back at startup, so removals did not persist.
                        self._save_metadata(metadata_list, f"{prefix}_metadata.json")

                    return True
        return False

    def remove_collection(self, collection_name: str) -> bool:
        """Reset an entire store ("sql", "ddl" or "documentation") to empty.

        Returns True if the name was recognised, False otherwise.
        """
        # Map the public collection name to the internal attribute prefix:
        # "documentation" is stored under "doc_*". The original set a
        # non-existent "documentation_index" attribute, leaving the real
        # doc store untouched.
        prefixes = {"sql": "sql", "ddl": "ddl", "documentation": "doc"}
        prefix = prefixes.get(collection_name)
        if prefix is None:
            return False

        setattr(self, f"{prefix}_index", faiss.IndexFlatL2(self.embedding_dim))
        setattr(self, f"{prefix}_metadata", [])

        if self.curr_client == 'persistent':
            self._save_index(getattr(self, f"{prefix}_index"), f"{prefix}_index.faiss")
            self._save_metadata([], f"{prefix}_metadata.json")

        return True
0 commit comments