added support for llama3.2
PromtEngineer committed Sep 26, 2024
1 parent 12e70d6 commit 0b98927
Showing 4 changed files with 62 additions and 4 deletions.
10 changes: 9 additions & 1 deletion .gitignore
@@ -165,4 +165,12 @@ cython_debug/

#MacOS
.DS_Store
SOURCE_DOCUMENTS/.DS_Store




.byaldi/
sessions/
static/images/
uploaded_documents/
23 changes: 22 additions & 1 deletion models/model_loader.py
@@ -1,8 +1,10 @@
# models/model_loader.py

import os
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from transformers import MllamaForConditionalGeneration

from logger import get_logger

logger = get_logger(__name__)
@@ -43,6 +45,7 @@ def load_model(model_choice):
_model_cache[model_choice] = (model, processor, device)
logger.info("Qwen model loaded and cached.")
return _model_cache[model_choice]

elif model_choice == 'gemini':
# Load Gemini model
import genai
@@ -52,13 +55,31 @@ def load_model(model_choice):
_model_cache[model_choice] = (model, processor)
logger.info("Gemini model loaded and cached.")
return _model_cache[model_choice]

elif model_choice == 'gpt4':
# Load OpenAI GPT-4 model
import openai
openai.api_key = os.environ.get('OPENAI_API_KEY')
_model_cache[model_choice] = (None, None)
logger.info("GPT-4 model ready and cached.")
return _model_cache[model_choice]

elif model_choice == 'llama-vision':
# Load Llama-Vision model
device = detect_device()
# model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model_id = "alpindale/Llama-3.2-11B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float16 if device != 'cpu' else torch.float32,
device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)
# device_map="auto" already dispatches the weights via accelerate, so an explicit
# model.to(device) call is unnecessary here (and can raise an error on a dispatched model)
_model_cache[model_choice] = (model, processor, device)
logger.info("Llama-Vision model loaded and cached.")
return _model_cache[model_choice]

else:
logger.error(f"Invalid model choice: {model_choice}")
raise ValueError("Invalid model choice.")
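
As a quick illustration (a minimal sketch, not part of this commit), the new 'llama-vision' choice plugs into the same cached loader as the existing backends. The snippet below and its print call are hypothetical and assume the module layout shown in this diff.

# Hypothetical usage sketch; not part of the commit.
from models.model_loader import load_model

# The first call downloads and dispatches the checkpoint; later calls are served from _model_cache.
model, processor, device = load_model('llama-vision')
print(type(model).__name__, device)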
31 changes: 29 additions & 2 deletions models/responder.py
@@ -2,6 +2,7 @@

from models.model_loader import load_model
from logger import get_logger
from PIL import Image
import os

logger = get_logger(__name__)
@@ -65,9 +66,35 @@ def generate_response(images, query, session_id, resized_height=280, resized_wid
response = generate_gpt4_response(images, query, model)
logger.info("Response generated using GPT-4 model.")
return response

elif model_choice == 'llama-vision':
# Load model, processor, and device
model, processor, device = load_model('llama-vision')

# Process images
image_paths = [os.path.join('static', image) for image in images]
# For simplicity, use the first image
image_path = image_paths[0]
image = Image.open(image_path).convert('RGB')

# Prepare messages
messages = [
{"role": "user", "content": [
{"type": "image"},
{"type": "text", "text": query}
]}
]
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(image, input_text, return_tensors="pt").to(device)

# Generate response
output = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(output[0], skip_special_tokens=True)
return response

else:
logger.error(f"Invalid model choice: {model_choice}")
return "Invalid model selected."
logger.error(f"Invalid model choice: {model_choice}")
return "Invalid model selected."
except Exception as e:
logger.error(f"Error generating response: {e}")
return "An error occurred while generating the response."
2 changes: 2 additions & 0 deletions templates/settings.html
@@ -25,6 +25,8 @@ <h2>Settings</h2>
<option value="qwen" {% if model_choice == 'qwen' %}selected{% endif %}>Qwen2-VL-7B-Instruct</option>
<option value="gemini" {% if model_choice == 'gemini' %}selected{% endif %}>Google Gemini</option>
<option value="gpt4" {% if model_choice == 'gpt4' %}selected{% endif %}>OpenAI GPT-4</option>
<option value="llama-vision" {% if model_choice == 'llama-vision' %}selected{% endif %}>Llama-Vision</option>

</select>
</div>
<div class="form-group">
