diff --git a/front.py b/front.py
new file mode 100644
index 0000000..a8f8560
--- /dev/null
+++ b/front.py
@@ -0,0 +1,165 @@
+import streamlit as st
+from llm_code import *
+
+# Frontend code
+st.set_page_config(page_title="🦙💬 Enchantty Chatbot")
+st.title("Enchantty")
+st.markdown("<p style='font-size: 20px'>Enter your Hugging Face API token "
+            "and chat with your PDFs and websites</p>", unsafe_allow_html=True)
+
+with st.sidebar:
+    st.markdown("*Getting a Hugging Face token*")
+    st.markdown("Steps:")
+    st.markdown("1. Navigate to [Hugging Face](https://huggingface.co/settings/tokens)")
+    st.markdown("2. Create a write token and copy it to your clipboard")
+    st.markdown("3. Paste the token in the input field below")
+    st.markdown("<br>", unsafe_allow_html=True)
+    st.markdown("*Hugging Face API Token*")
+    token_placeholder = st.empty()
+    HF_token = token_placeholder.text_input("Enter your Hugging Face API token", type="password")
+
+    # Swap the token input for a success message once a token is provided
+    if HF_token:
+        token_placeholder.empty()
+        st.success('API key provided!', icon='✅')
+    else:
+        st.warning('Please enter your Hugging Face API token!', icon='⚠️')
+
+    # The rest of the sidebar only appears once a token is available
+    if HF_token:
+
+        st.markdown("<h3>Choose Model</h3>", unsafe_allow_html=True)
+
+        # Add a dropdown menu to select the LLM model
+        selected_model = st.selectbox("Select LLM Model", list(llm_models.keys()))
+
+        # Display a warning for models with gated access
+        if "(Gated Access)" in selected_model:
+            st.warning("Access to this model requires authorization from Hugging Face.")
+
+        file_or_url = st.radio("Choose Input Type", ("PDF File", "Website"))
+
+        if file_or_url == "PDF File":
+            upload_placeholder = st.empty()
+            uploaded_file = upload_placeholder.file_uploader('Upload your .pdf file', type="pdf")
+            if uploaded_file is not None:
+                # Replace the PDF upload field with the success message
+                upload_placeholder.empty()
+                st.success('PDF file uploaded successfully!', icon='✅')
+                # Save the uploaded file to a temporary location and load it
+                with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
+                    tmp_file.write(uploaded_file.getvalue())
+                content = PDFPlumberLoader(tmp_file.name).load()
+
+        elif file_or_url == "Website":
+            url_placeholder = st.empty()
+            url = url_placeholder.text_input("Enter the URL")
+            if url.strip():
+                # Replace the URL input field with the success message
+                url_placeholder.empty()
+                st.success('URL entered successfully!', icon='✅')
+                # Fetch and parse the page
+                content = WebBaseLoader(url).load()
+
+        st.markdown("<h3>Advanced Features</h3>", unsafe_allow_html=True)
+        max_length = st.slider("Token Max Length", min_value=256, max_value=1024, value=256, step=128)
+        temp = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.1, step=0.1)
+
+
+# Build the RAG pipeline once a document or URL has been loaded
+if 'content' in locals():
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
+    chunking = text_splitter.split_documents(content)
+    embeddings = HuggingFaceInferenceAPIEmbeddings(
+        api_key=HF_token,
+        model_name="sentence-transformers/all-MiniLM-L6-v2"
+    )
+    vectorstore = FAISS.from_documents(chunking, embeddings)
+    prompt = hub.pull("rlm/rag-prompt", api_url="https://api.hub.langchain.com")
+
+    # Get the selected LLM model ID
+    selected_model_id = llm_models[selected_model]
+
+    def model(user_query, max_length, temp):
+        record_timing()  # Record time before generating a response
+        llm = HuggingFaceHub(
+            repo_id=selected_model_id,
+            huggingfacehub_api_token=HF_token,
+            model_kwargs={"max_length": max_length, "temperature": temp}
+        )
+        # MMR retrieval: the 2 most relevant, mutually diverse chunks
+        retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 2})
+        qa = RetrievalQA.from_chain_type(llm=llm,
+                                         chain_type="stuff",
+                                         retriever=retriever,
+                                         return_source_documents=True,
+                                         verbose=True,
+                                         chain_type_kwargs={"prompt": prompt})
+        response = qa(user_query)["result"]
+
+        # The rag-prompt asks for "Answer: ..."; keep only what follows it
+        answer_start = response.find("Answer:")
+        if answer_start != -1:
+            return response[answer_start + len("Answer:"):].strip()
+        return "Sorry, I couldn't find the answer."
+
+    # Reset the chat when the model selection changes
+    if "selected_model" in st.session_state and st.session_state.selected_model != selected_model:
+        st.session_state.messages = []
+    st.session_state.selected_model = selected_model
+
+    # CSS styling for the text input (illustrative rules: pin it near the bottom)
+    styl = """
+    <style>
+        .stTextInput {
+            position: fixed;
+            bottom: 3rem;
+        }
+    </style>
+    """
+    st.markdown(styl, unsafe_allow_html=True)
+
+    if "messages" not in st.session_state:
+        st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
+
+    # Display the chat history
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.write(message["content"])
+
+    def clear_chat_history():
+        st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
+    st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
+
+    if user_prompt := st.chat_input("Enter your query"):
+        st.session_state.messages.append({"role": "user", "content": user_prompt})
+        with st.chat_message("user"):
+            st.write(user_prompt)
+
+    # Generate a reply whenever the last message is from the user
+    if st.session_state.messages[-1]["role"] != "assistant":
+        with st.chat_message("assistant"):
+            with st.spinner("Thinking..."):
+                response = model(user_prompt, max_length, temp)
+                placeholder = st.empty()
+                full_response = ''
+                # Simulate streaming by rendering the answer incrementally
+                for item in response:
+                    full_response += item
+                    placeholder.markdown(full_response)
+                placeholder.markdown(full_response)
+        message = {"role": "assistant", "content": full_response}
+        st.session_state.messages.append(message)
diff --git a/llm_code.py b/llm_code.py
new file mode 100644
index 0000000..2e6bf3f
--- /dev/null
+++ b/llm_code.py
@@ -0,0 +1,33 @@
+import time
+import tempfile
+from langchain.chains import RetrievalQA
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores import FAISS
+from langchain_community.llms import HuggingFaceHub
+from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
+from langchain_community.document_loaders import PDFPlumberLoader, WebBaseLoader
+from langchain import hub
+
+# Record the elapsed time between consecutive query-response pairs
+time_start = 0
+def record_timing():
+    global time_start
+    if time_start != 0:
+        duration = time.time() - time_start
+        print(f"Time taken for query-response pair: {duration:.2f} seconds")
+    time_start = time.time()
+
+# Available LLM models: display name -> Hugging Face repo id
+llm_models = {
+    "Mistral-7B-Instruct-v0.2": "mistralai/Mistral-7B-Instruct-v0.2",
+    "Mixtral-8x7B-Instruct-v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    "Gemma-7B (Gated Access)": "google/gemma-7b",
+    "Gemma-7B-it (Gated Access)": "google/gemma-7b-it",
+    "Zephyr": "HuggingFaceH4/zephyr-7b-beta",
+    "Gemma-2B-it (Gated Access)": "google/gemma-2b-it"
+}
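+
+# Usage sketch: the display names above map to Hugging Face repo ids, e.g.
+#   llm_models["Zephyr"]  ->  "HuggingFaceH4/zephyr-7b-beta"
+# The app itself is launched with:  streamlit run front.py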