Commit

added the models folder
PromtEngineer committed Sep 19, 2024
1 parent b26251a commit 922ecf9
Showing 6 changed files with 250 additions and 1 deletion.
1 change: 0 additions & 1 deletion .gitignore
@@ -1,7 +1,6 @@
 # Ignore vscode
 /.vscode
 /DB
-/models
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
25 changes: 25 additions & 0 deletions models/converters.py
@@ -0,0 +1,25 @@
# models/converters.py

import os
from docx2pdf import convert
from logger import get_logger

logger = get_logger(__name__)

def convert_docs_to_pdfs(folder_path):
    """
    Converts .doc and .docx files in the folder to PDFs.

    Note: docx2pdf drives Microsoft Word, so Word must be installed
    (Windows or macOS only).

    Args:
        folder_path (str): The path to the folder containing documents.
    """
    try:
        for filename in os.listdir(folder_path):
            if filename.lower().endswith(('.doc', '.docx')):
                doc_path = os.path.join(folder_path, filename)
                pdf_path = os.path.splitext(doc_path)[0] + '.pdf'
                convert(doc_path, pdf_path)  # writes the PDF next to the source file
                logger.info(f"Converted '{filename}' to PDF.")
    except Exception as e:
        logger.error(f"Error converting documents to PDFs: {e}")
        raise
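
A minimal sketch of how this converter might be called, assuming the project root is the working directory and an 'uploaded_documents' folder exists (both hypothetical):

# Hypothetical usage sketch for convert_docs_to_pdfs.
from models.converters import convert_docs_to_pdfs

convert_docs_to_pdfs('uploaded_documents')  # a .pdf appears next to each .doc/.docx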
44 changes: 44 additions & 0 deletions models/indexer.py
@@ -0,0 +1,44 @@
# models/indexer.py

import os
from byaldi import RAGMultiModalModel
from models.converters import convert_docs_to_pdfs
from logger import get_logger

logger = get_logger(__name__)

def index_documents(folder_path, index_name='document_index', index_path=None):
    """
    Indexes documents in the specified folder using Byaldi.

    Args:
        folder_path (str): The path to the folder containing documents to index.
        index_name (str): The name of the index to create or update.
        index_path (str): Currently unused; the corresponding argument to
            RAG.index is commented out below.

    Returns:
        RAGMultiModalModel: The RAG model with the indexed documents.
    """
    try:
        logger.info(f"Starting document indexing in folder: {folder_path}")
        # Convert non-PDF documents to PDFs
        convert_docs_to_pdfs(folder_path)
        logger.info("Conversion of non-PDF documents to PDFs completed.")

        # Initialize RAG model
        RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
        logger.info("RAG model loaded.")

        # Index the documents in the folder
        RAG.index(
            input_path=folder_path,
            index_name=index_name,
            # index_path=index_path,
            store_collection_with_index=True,  # keep page images with the index for retrieval
            overwrite=True
        )
        logger.info(f"Indexing completed. Index '{index_name}' saved.")
        return RAG
    except Exception as e:
        logger.error(f"Error during indexing: {e}")
        raise
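
For orientation, a sketch of driving the indexer with a hypothetical 'docs' folder; the search call is only a sanity check against the fresh index:

# Hypothetical usage sketch for index_documents.
from models.indexer import index_documents

RAG = index_documents('docs', index_name='document_index')
results = RAG.search('What does the architecture diagram show?', k=3)
for r in results:
    print(r.doc_id, r.page_num)  # attributes the retriever below also relies on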
64 changes: 64 additions & 0 deletions models/model_loader.py
@@ -0,0 +1,64 @@
# models/model_loader.py

import os
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
import torch
from logger import get_logger

logger = get_logger(__name__)

# Cache for loaded models
_model_cache = {}

def detect_device():
    """
    Detects the best available device (CUDA, MPS, or CPU).
    """
    if torch.cuda.is_available():
        return 'cuda'
    elif torch.backends.mps.is_available():
        return 'mps'
    else:
        return 'cpu'

def load_model(model_choice):
    """
    Loads and caches the specified model.
    """
    global _model_cache

    if model_choice in _model_cache:
        logger.info(f"Model '{model_choice}' loaded from cache.")
        return _model_cache[model_choice]

    if model_choice == 'qwen':
        device = detect_device()
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-7B-Instruct",
            torch_dtype=torch.float16 if device != 'cpu' else torch.float32,
            device_map="auto"  # accelerate places the model; no explicit .to(device) needed
        )
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        _model_cache[model_choice] = (model, processor, device)
        logger.info("Qwen model loaded and cached.")
        return _model_cache[model_choice]
    elif model_choice == 'gemini':
        # Load Gemini model
        import google.generativeai as genai
        genai.configure(api_key=os.environ.get('GENAI_API_KEY'))
        model = genai.GenerativeModel(model_name="gemini-1.5-pro")
        processor = None
        _model_cache[model_choice] = (model, processor)
        logger.info("Gemini model loaded and cached.")
        return _model_cache[model_choice]
    elif model_choice == 'gpt4':
        # Load OpenAI GPT-4 model; the actual request is made in models/gpt4_responder
        import openai
        openai.api_key = os.environ.get('OPENAI_API_KEY')
        _model_cache[model_choice] = (None, None)
        logger.info("GPT-4 model ready and cached.")
        return _model_cache[model_choice]
    else:
        logger.error(f"Invalid model choice: {model_choice}")
        raise ValueError("Invalid model choice.")
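
A sketch of the caching behaviour, assuming the Qwen weights are available for download; the second call should return the tuple already stored in _model_cache:

# Hypothetical usage sketch for detect_device / load_model.
from models.model_loader import detect_device, load_model

print(detect_device())                         # 'cuda', 'mps', or 'cpu'
model, processor, device = load_model('qwen')  # first call downloads and caches
model2, _, _ = load_model('qwen')              # second call is a cache hit
assert model is model2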
73 changes: 73 additions & 0 deletions models/responder.py
@@ -0,0 +1,73 @@
# models/responder.py

from models.model_loader import load_model
from logger import get_logger
import os

logger = get_logger(__name__)

def generate_response(images, query, session_id, resized_height=280, resized_width=280, model_choice='qwen'):
    """
    Generates a response using the selected model based on the query and images.
    """
    try:
        logger.info(f"Generating response using model '{model_choice}'.")
        if model_choice == 'qwen':
            from qwen_vl_utils import process_vision_info
            # Load cached model
            model, processor, device = load_model('qwen')
            # Qwen2-VL expects image dimensions that are multiples of 28
            # (14-pixel patches merged 2x2), so round down to the nearest multiple
            resized_height = (resized_height // 28) * 28
            resized_width = (resized_width // 28) * 28

            image_contents = []
            for image in images:
                image_contents.append({
                    "type": "image",
                    "image": os.path.join('static', image),
                    "resized_height": resized_height,
                    "resized_width": resized_width
                })
            messages = [
                {
                    "role": "user",
                    "content": image_contents + [{"type": "text", "text": query}],
                }
            ]
            text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(device)
            generated_ids = model.generate(**inputs, max_new_tokens=128)
            # Strip the prompt tokens so only the newly generated answer is decoded
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            logger.info("Response generated using Qwen model.")
            return output_text[0]
        elif model_choice == 'gemini':
            from models.gemini_responder import generate_gemini_response
            model, processor = load_model('gemini')
            response = generate_gemini_response(images, query, model, processor)
            logger.info("Response generated using Gemini model.")
            return response
        elif model_choice == 'gpt4':
            from models.gpt4_responder import generate_gpt4_response
            model, _ = load_model('gpt4')
            response = generate_gpt4_response(images, query, model)
            logger.info("Response generated using GPT-4 model.")
            return response
        else:
            logger.error(f"Invalid model choice: {model_choice}")
            return "Invalid model selected."
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return "An error occurred while generating the response."
44 changes: 44 additions & 0 deletions models/retriever.py
@@ -0,0 +1,44 @@
# models/retriever.py

import base64
import os
from logger import get_logger

logger = get_logger(__name__)

def retrieve_documents(RAG, query, session_id, k=3):
    """
    Retrieves relevant documents based on the user query using Byaldi.

    Args:
        RAG (RAGMultiModalModel): The RAG model with the indexed documents.
        query (str): The user's query.
        session_id (str): The session ID used to store images in a per-session folder.
        k (int): The number of documents to retrieve.

    Returns:
        list: A list of image paths (relative to 'static/') for the retrieved pages.
    """
    try:
        logger.info(f"Retrieving documents for query: {query}")
        results = RAG.search(query, k=k)
        images = []
        session_images_folder = os.path.join('static', 'images', session_id)
        os.makedirs(session_images_folder, exist_ok=True)
        for result in results:
            if result.base64:
                image_data = base64.b64decode(result.base64)
                image_filename = f"retrieved_{result.doc_id}_{result.page_num}.png"
                image_path = os.path.join(session_images_folder, image_filename)
                with open(image_path, 'wb') as f:
                    f.write(image_data)
                images.append(os.path.join('images', session_id, image_filename))
                logger.debug(f"Retrieved and saved image: {image_filename}")
            else:
                # Handle cases where base64 data is not available
                logger.warning(f"No base64 data for document {result.doc_id}, page {result.page_num}")
        logger.info(f"Retrieved {len(images)} document images in total.")
        return images
    except Exception as e:
        logger.error(f"Error retrieving documents: {e}")
        return []
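
Tying the pieces together, a hypothetical end-to-end pass under the same assumptions as the sketches above: index a folder, retrieve page images for a query, then answer over them:

# Hypothetical end-to-end sketch: index -> retrieve -> respond.
from models.indexer import index_documents
from models.responder import generate_response
from models.retriever import retrieve_documents

RAG = index_documents('docs')
query = 'Where is the revenue table?'
images = retrieve_documents(RAG, query, session_id='session-1', k=3)
print(generate_response(images, query, session_id='session-1'))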
