-
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b26251a
commit 922ecf9
Showing
6 changed files
with
250 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,6 @@ | ||
# Ignore vscode | ||
/.vscode | ||
/DB | ||
/models | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# models/converters.py | ||
|
||
import os | ||
from docx2pdf import convert | ||
from logger import get_logger | ||
|
||
logger = get_logger(__name__) | ||
|
||
def convert_docs_to_pdfs(folder_path): | ||
""" | ||
Converts .doc and .docx files in the folder to PDFs. | ||
Args: | ||
folder_path (str): The path to the folder containing documents. | ||
""" | ||
try: | ||
for filename in os.listdir(folder_path): | ||
if filename.lower().endswith(('.doc', '.docx')): | ||
doc_path = os.path.join(folder_path, filename) | ||
pdf_path = os.path.splitext(doc_path)[0] + '.pdf' | ||
convert(doc_path, pdf_path) | ||
logger.info(f"Converted '{filename}' to PDF.") | ||
except Exception as e: | ||
logger.error(f"Error converting documents to PDFs: {e}") | ||
raise |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# models/indexer.py | ||
|
||
import os | ||
from byaldi import RAGMultiModalModel | ||
from models.converters import convert_docs_to_pdfs | ||
from logger import get_logger | ||
|
||
logger = get_logger(__name__) | ||
|
||
def index_documents(folder_path, index_name='document_index', index_path=None): | ||
""" | ||
Indexes documents in the specified folder using Byaldi. | ||
Args: | ||
folder_path (str): The path to the folder containing documents to index. | ||
index_name (str): The name of the index to create or update. | ||
index_path (str): The path where the index should be saved. | ||
Returns: | ||
RAGMultiModalModel: The RAG model with the indexed documents. | ||
""" | ||
try: | ||
logger.info(f"Starting document indexing in folder: {folder_path}") | ||
# Convert non-PDF documents to PDFs | ||
convert_docs_to_pdfs(folder_path) | ||
logger.info("Conversion of non-PDF documents to PDFs completed.") | ||
|
||
# Initialize RAG model | ||
RAG = RAGMultiModalModel.from_pretrained("vidore/colpali") | ||
logger.info("RAG model loaded.") | ||
|
||
# Index the documents in the folder | ||
RAG.index( | ||
input_path=folder_path, | ||
index_name=index_name, | ||
# index_path=index_path, | ||
store_collection_with_index=True, | ||
overwrite=True | ||
) | ||
logger.info(f"Indexing completed. Index saved at '{index_path}'.") | ||
return RAG | ||
except Exception as e: | ||
logger.error(f"Error during indexing: {e}") | ||
raise |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# models/model_loader.py | ||
|
||
import os | ||
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor | ||
import torch | ||
from logger import get_logger | ||
|
||
logger = get_logger(__name__) | ||
|
||
# Cache for loaded models | ||
_model_cache = {} | ||
|
||
def detect_device(): | ||
""" | ||
Detects the best available device (CUDA, MPS, or CPU). | ||
""" | ||
if torch.cuda.is_available(): | ||
return 'cuda' | ||
elif torch.backends.mps.is_available(): | ||
return 'mps' | ||
else: | ||
return 'cpu' | ||
|
||
def load_model(model_choice): | ||
""" | ||
Loads and caches the specified model. | ||
""" | ||
global _model_cache | ||
|
||
if model_choice in _model_cache: | ||
logger.info(f"Model '{model_choice}' loaded from cache.") | ||
return _model_cache[model_choice] | ||
|
||
if model_choice == 'qwen': | ||
device = detect_device() | ||
model = Qwen2VLForConditionalGeneration.from_pretrained( | ||
"Qwen/Qwen2-VL-7B-Instruct", | ||
torch_dtype=torch.float16 if device != 'cpu' else torch.float32, | ||
device_map="auto" | ||
) | ||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") | ||
model.to(device) | ||
_model_cache[model_choice] = (model, processor, device) | ||
logger.info("Qwen model loaded and cached.") | ||
return _model_cache[model_choice] | ||
elif model_choice == 'gemini': | ||
# Load Gemini model | ||
import genai | ||
genai.api_key = os.environ.get('GENAI_API_KEY') | ||
model = genai.GenerativeModel(model_name="gemini-1.5-pro") | ||
processor = None | ||
_model_cache[model_choice] = (model, processor) | ||
logger.info("Gemini model loaded and cached.") | ||
return _model_cache[model_choice] | ||
elif model_choice == 'gpt4': | ||
# Load OpenAI GPT-4 model | ||
import openai | ||
openai.api_key = os.environ.get('OPENAI_API_KEY') | ||
_model_cache[model_choice] = (None, None) | ||
logger.info("GPT-4 model ready and cached.") | ||
return _model_cache[model_choice] | ||
else: | ||
logger.error(f"Invalid model choice: {model_choice}") | ||
raise ValueError("Invalid model choice.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# models/responder.py | ||
|
||
from models.model_loader import load_model | ||
from logger import get_logger | ||
import os | ||
|
||
logger = get_logger(__name__) | ||
|
||
def generate_response(images, query, session_id, resized_height=280, resized_width=280, model_choice='qwen'): | ||
""" | ||
Generates a response using the selected model based on the query and images. | ||
""" | ||
try: | ||
logger.info(f"Generating response using model '{model_choice}'.") | ||
if model_choice == 'qwen': | ||
from qwen_vl_utils import process_vision_info | ||
# Load cached model | ||
model, processor, device = load_model('qwen') | ||
# Ensure dimensions are multiples of 28 | ||
resized_height = (resized_height // 28) * 28 | ||
resized_width = (resized_width // 28) * 28 | ||
|
||
image_contents = [] | ||
for image in images: | ||
image_contents.append({ | ||
"type": "image", | ||
"image": os.path.join('static', image), | ||
"resized_height": resized_height, | ||
"resized_width": resized_width | ||
}) | ||
messages = [ | ||
{ | ||
"role": "user", | ||
"content": image_contents + [{"type": "text", "text": query}], | ||
} | ||
] | ||
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) | ||
image_inputs, video_inputs = process_vision_info(messages) | ||
inputs = processor( | ||
text=[text], | ||
images=image_inputs, | ||
videos=video_inputs, | ||
padding=True, | ||
return_tensors="pt", | ||
) | ||
inputs = inputs.to(device) | ||
generated_ids = model.generate(**inputs, max_new_tokens=128) | ||
generated_ids_trimmed = [ | ||
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) | ||
] | ||
output_text = processor.batch_decode( | ||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False | ||
) | ||
logger.info("Response generated using Qwen model.") | ||
return output_text[0] | ||
elif model_choice == 'gemini': | ||
from models.gemini_responder import generate_gemini_response | ||
model, processor = load_model('gemini') | ||
response = generate_gemini_response(images, query, model, processor) | ||
logger.info("Response generated using Gemini model.") | ||
return response | ||
elif model_choice == 'gpt4': | ||
from models.gpt4_responder import generate_gpt4_response | ||
model, _ = load_model('gpt4') | ||
response = generate_gpt4_response(images, query, model) | ||
logger.info("Response generated using GPT-4 model.") | ||
return response | ||
else: | ||
logger.error(f"Invalid model choice: {model_choice}") | ||
return "Invalid model selected." | ||
except Exception as e: | ||
logger.error(f"Error generating response: {e}") | ||
return "An error occurred while generating the response." |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# models/retriever.py | ||
|
||
import base64 | ||
import os | ||
from logger import get_logger | ||
|
||
logger = get_logger(__name__) | ||
|
||
def retrieve_documents(RAG, query, session_id, k=3): | ||
""" | ||
Retrieves relevant documents based on the user query using Byaldi. | ||
Args: | ||
RAG (RAGMultiModalModel): The RAG model with the indexed documents. | ||
query (str): The user's query. | ||
session_id (str): The session ID to store images in per-session folder. | ||
k (int): The number of documents to retrieve. | ||
Returns: | ||
list: A list of image filenames corresponding to the retrieved documents. | ||
""" | ||
try: | ||
logger.info(f"Retrieving documents for query: {query}") | ||
results = RAG.search(query, k=k) | ||
images = [] | ||
session_images_folder = os.path.join('static', 'images', session_id) | ||
os.makedirs(session_images_folder, exist_ok=True) | ||
for result in results: | ||
if result.base64: | ||
image_data = base64.b64decode(result.base64) | ||
image_filename = f"retrieved_{result.doc_id}_{result.page_num}.png" | ||
image_path = os.path.join(session_images_folder, image_filename) | ||
with open(image_path, 'wb') as f: | ||
f.write(image_data) | ||
images.append(os.path.join('images', session_id, image_filename)) | ||
logger.debug(f"Retrieved and saved image: {image_filename}") | ||
else: | ||
# Handle cases where base64 data is not available | ||
logger.warning(f"No base64 data for document {result.doc_id}, page {result.page_num}") | ||
logger.info(f"Total {len(images)} documents retrieved.") | ||
return images | ||
except Exception as e: | ||
logger.error(f"Error retrieving documents: {e}") | ||
return [] |