Update to llama-3.3-70b-versatile and Codebase Enhancements #17

Open · wants to merge 1 commit into base: main

57 changes: 43 additions & 14 deletions cohere_reranking.py
@@ -1,29 +1,54 @@
import os
import logging
from typing import List

import cohere

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# use ENV variables
# Environment Variables
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
MODEL = "rerank-multilingual-v3.0"
MODEL = os.getenv("COHERE_RERANK_MODEL", "rerank-multilingual-v3.0")

co = cohere.Client(api_key=COHERE_API_KEY)
if not COHERE_API_KEY:
logger.error("COHERE_API_KEY is not set in environment variables.")
raise ValueError("COHERE_API_KEY is required but not set.")

# Initialize Cohere Client
try:
co = cohere.Client(api_key=COHERE_API_KEY)
except Exception as e:
logger.exception(f"Failed to initialize Cohere client: {e}")
raise

def get_reranking_cohere(docs, query, top_res):

def get_reranking_cohere(docs: List[str], query: str, top_res: int) -> List[str]:
"""
Re-ranks a list of documents based on a query using Cohere's reranking API.

Args:
docs (list of str): List of documents to be re-ranked.
query (str): Query string to rank the documents against.
top_res (int): Number of top results to return.
docs (List[str]): List of documents to be re-ranked.
query (str): Query string to rank the documents against.
top_res (int): Number of top results to return.

Returns:
list of str: Top re-ranked documents based on the query.
List[str]: Top re-ranked documents based on the query.
"""
if not docs:
logger.warning("No documents provided for reranking.")
return []

if not query:
logger.warning("Empty query provided for reranking.")
return []

if top_res <= 0:
logger.warning("Invalid top_res value provided. Must be greater than 0.")
return []

try:
# Call the Cohere rerank API
response = co.rerank(
model=MODEL,
query=query,
@@ -32,10 +57,14 @@ def get_reranking_cohere(docs, query, top_res):
return_documents=True
)

# Extract and return the texts of the top documents
return [item.document.text for item in response.results]
reranked_docs = [item.document.text for item in response.results]
if not reranked_docs:
logger.warning("Cohere rerank returned no results.")
return reranked_docs

except cohere.CohereError as e:
logger.error(f"Cohere API error during reranking: {e}")
except Exception as e:
# Log the error and handle it as needed
print(f"An error occurred: {e}")
return []
logger.exception(f"An unexpected error occurred during reranking: {e}")

return []
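
A minimal usage sketch for the updated get_reranking_cohere, assuming COHERE_API_KEY (and optionally COHERE_RERANK_MODEL) is set in the environment; the documents and query below are illustrative placeholders, not part of this PR:

# Illustrative example: rerank a few snippets for a query and print the top results.
from cohere_reranking import get_reranking_cohere

docs = [
    "Paris is the capital of France.",
    "The Eiffel Tower is in Paris.",
    "Berlin is the capital of Germany.",
]

top_docs = get_reranking_cohere(docs, query="What is the capital of France?", top_res=2)
for rank, text in enumerate(top_docs, start=1):
    print(rank, text)
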
46 changes: 32 additions & 14 deletions extract_content_from_website.py
@@ -1,32 +1,50 @@
import logging
from typing import Optional

from langchain_community.document_loaders import WebBaseLoader

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Constants
MAX_CHARACTERS = 4000
MIN_LENGTH = 200


def extract_website_content(url):
def extract_website_content(url: str) -> str:
"""
Extracts and cleans the main content from a given website URL.

Args:
url (str): The URL of the website from which to extract content.
url (str): The URL of the website from which to extract content.

Returns:
str: The first 4000 characters of the cleaned main content if it is sufficiently long, otherwise an empty string.
str: The first 4000 characters of the cleaned main content if it is sufficiently long; otherwise, an empty string.
"""
if not url or not isinstance(url, str):
logger.error("Invalid URL provided for content extraction.")
return ""

try:
clean_text = []
loader = WebBaseLoader(url)
data = loader.load()

# Aggregate content using a list to avoid inefficient string concatenation in the loop
clean_text = []
for doc in data:
if doc.page_content: # Check if page_content is not None or empty
clean_text.append(doc.page_content.replace("\n", ""))

# Join all parts into a single string after processing
clean_text = "".join(clean_text)

# Return up to the first 4000 characters if the content is sufficiently long
return clean_text[:4000] if len(clean_text) > 200 else ""
content = doc.page_content
if content:
cleaned = content.replace("\n", " ").strip()
if cleaned:
clean_text.append(cleaned)

combined_text = " ".join(clean_text)
if len(combined_text) > MIN_LENGTH:
return combined_text[:MAX_CHARACTERS]
else:
logger.warning(f"Extracted content is too short ({len(combined_text)} characters).")
return ""

except Exception as error:
print('Error extracting main content:', error)
logger.exception(f"Error extracting main content from {url}: {error}")
return ""
83 changes: 58 additions & 25 deletions groq_api.py
@@ -1,23 +1,40 @@
import json
import os
import logging
from typing import Generator, Dict, Any
from groq import Groq
from langchain_core.prompts import PromptTemplate
from prompts import search_prompt_system, relevant_prompt_system

# use ENV variables
MODEL = "llama3-70b-8192"
api_key_groq = os.getenv("GROQ_API_KEY")
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Environment Variables
MODEL = "llama-3.3-70b-versatile"
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

client = Groq()
if not GROQ_API_KEY:
logger.error("GROQ_API_KEY is not set in environment variables.")
raise ValueError("GROQ_API_KEY is required but not set.")

client = Groq(api_key=GROQ_API_KEY)

def get_answer(query, contexts, date_context):
system_prompt_search = PromptTemplate(input_variables=["date_today"], template=search_prompt_system)

def get_answer(query: str, contexts: str, date_context: str) -> Generator[str, None, None]:
"""
Generate an answer based on the query and contexts using Groq API.

:param query: User's search query.
:param contexts: Contextual information related to the query.
:param date_context: Current date and time context.
:return: Generator yielding chunks of the answer.
"""
system_prompt = PromptTemplate(input_variables=["date_today"], template=search_prompt_system)

messages = [
{"role": "system", "content": system_prompt_search.format(date_today=date_context)},
{"role": "user", "content": "User Question : " + query + "\n\n CONTEXTS :\n\n" + contexts}
{"role": "system", "content": system_prompt.format(date_today=date_context)},
{"role": "user", "content": f"User Question: {query}\n\nCONTEXTS:\n\n{contexts}"}
]

try:
@@ -29,30 +46,46 @@ def get_answer(query, contexts, date_context):
)

for chunk in stream:
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
content = chunk.choices[0].delta.content
if content:
yield content

except Exception as e:
print(f"Error during get_answer_groq call: {e}")
yield "data:" + json.dumps(
{'type': 'error', 'data': "We are currently experiencing some issues. Please try again later."}) + "\n\n"
logger.exception(f"Error during get_answer_groq call: {e}")
error_response = json.dumps({
'type': 'error',
'data': "We are currently experiencing some issues. Please try again later."
})
yield f"data:{error_response}\n\n"


def get_relevant_questions(contexts: str, query: str) -> Dict[str, Any]:
"""
Generate relevant follow-up questions based on the query and contexts using Groq API.

:param contexts: Contextual information related to the query.
:param query: User's search query.
:return: Dictionary containing follow-up questions.
"""
messages = [
{"role": "system", "content": relevant_prompt_system},
{"role": "user", "content": f"User Query: {query}\n\nContexts:\n{contexts}\n"}
]

def get_relevant_questions(contexts, query):
try:
response = client.chat.completions.create(
model=MODEL,
messages=[
{"role": "system",
"content": relevant_prompt_system
},
{"role": "user",
"content": "User Query: " + query + "\n\n" + "Contexts: " + "\n" + contexts + "\n"}
],
response_format={"type": "json_object"},
messages=messages,
            response_format={"type": "json_object"},
)

return response.choices[0].message.content
content = response.choices[0].message.content
follow_up = json.loads(content)
return follow_up

except json.JSONDecodeError as e:
logger.error(f"JSON decode error in get_relevant_questions: {e}")
except Exception as e:
print(f"Error during RELEVANT GROQ ***************: {e}")
return {}
logger.exception(f"Error during get_relevant_questions: {e}")

return {}
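
A sketch of how the two Groq helpers compose, assuming GROQ_API_KEY is set; get_answer streams text chunks and get_relevant_questions now returns a parsed dict (empty on failure). The query and contexts below are placeholders:

# Illustrative example: stream an answer, then request follow-up questions.
from datetime import datetime
from groq_api import get_answer, get_relevant_questions

contexts = "Context snippet 1...\nContext snippet 2..."
query = "What changed in the latest model update?"
date_context = datetime.now().strftime("%Y-%m-%d %H:%M")

for chunk in get_answer(query, contexts, date_context):
    print(chunk, end="", flush=True)

follow_up = get_relevant_questions(contexts, query)
print(follow_up)  # {} if the API call or JSON parsing failed
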
32 changes: 22 additions & 10 deletions jina_rerank.py
@@ -12,34 +12,46 @@
API_URL = "https://api.jina.ai/v1/rerank"
API_KEY = os.getenv("JINA_API_KEY")
MODEL = "jina-reranker-v2-base-multilingual"

if not API_KEY:
logger.error("JINA_API_KEY is not set in environment variables.")
raise ValueError("JINA_API_KEY is required but not set.")

HEADERS = {
"Content-Type": "application/json",
"Authorization": f"Bearer {API_KEY}"
}

session = requests.Session()
session.headers.update(HEADERS)

def get_reranking_jina(docs: List[str], query: str, top_res: int) -> List[str]:

def get_reranking_jina(docs: List[str], query: str, top_res: int, timeout: int = 10) -> List[str]:
"""
Get reranked documents using Jina AI API.

:param docs: List of documents to rerank
:param query: Query string
:param top_res: Number of top results to return
:param timeout: Request timeout in seconds
:return: List of reranked documents
"""
data = {
"model": MODEL,
"query": query,
"documents": docs,
"top_n": top_res
}

try:
data = {
"model": MODEL,
"query": query,
"documents": docs,
"top_n": top_res
}

response = requests.post(API_URL, headers=HEADERS, json=data, timeout=10)
response = session.post(API_URL, json=data, timeout=timeout)
response.raise_for_status()
response_data = response.json()

return [item['document']['text'] for item in response_data.get('results', [])]
reranked_docs = [item['document']['text'] for item in response_data.get('results', [])]
if not reranked_docs:
logger.warning("No reranked results returned.")
return reranked_docs

except RequestException as e:
logger.error(f"HTTP error occurred while reranking: {e}")
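
A minimal sketch of calling the updated get_reranking_jina, assuming JINA_API_KEY is set; the timeout argument defaults to 10 seconds when omitted, and the documents and query are placeholders:

# Illustrative example: rerank candidate passages via the Jina API with a shorter timeout.
from jina_rerank import get_reranking_jina

docs = ["First candidate passage.", "Second candidate passage."]
results = get_reranking_jina(docs, query="candidate passage", top_res=1, timeout=5)
print(results)  # [] on HTTP errors or when no results are returned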