Commit 3dda818

added audio transcription feature
Merge: 2 parents 9cdd323 + 409bdd0

File tree: 7 files changed, 175 additions and 50 deletions

.gitignore
app.py
audio_transcribe.py
chromadb_operations.py
prompt_templates.py
style.css
text_processor.py

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@ __pycache__/
 .vscode/
 docs
 vectorstore
+sessions
+docs

 # environment
 venv

app.py

Lines changed: 98 additions & 31 deletions
@@ -1,13 +1,36 @@
-import ollama
-import streamlit as st
 import os
-from utils import save_chat_history,load_chat_history_json,get_timestamp
-from datetime import datetime
-from streamlit_mic_recorder import mic_recorder
-from transformers import pipeline
 import yaml
+from datetime import datetime
+
+import ollama
 import torch
+import streamlit as st
+from transformers import pipeline
+from streamlit_mic_recorder import mic_recorder
+from utils import save_chat_history, load_chat_history_json, get_timestamp
+
+from prompt_templates import SYSTEM_PROMPT
 from audio_transcribe import transcribe_audio
+from text_processor import get_document_chunks
+from chromadb_operations import ChromadbOperations
+
+header = st.container()
+header.title("Knowly")
+header.write("""<div class='fixed-header'/>""", unsafe_allow_html=True)
+
+with header:
+    col1, col2 = st.columns(2)
+    with col1:
+        if "model" not in st.session_state:
+            st.session_state["model"] = ""
+        models = [model["name"] for model in ollama.list()["models"]]
+        st.session_state["model"] = st.selectbox("Choose your model", models)
+    with col2:
+        st.write('Record Audio:')
+        voice_recording = mic_recorder(start_prompt="Start recording", stop_prompt="Stop recording", just_once=True)
+        transcribed_audio_prompt = ''
+        if voice_recording:
+            transcribed_audio_prompt = transcribe_audio(voice_recording["bytes"])

 with open('config.yaml', 'r') as f:
     config = yaml.safe_load(f)
@@ -32,68 +55,112 @@ def set_session_name(session):
         del st.session_state["messages"]
     st.session_state["messages"] = load_chat_history_json(session)

-def model_res_generator():
+def model_res_generator(rag:bool=False):
+    prompt = st.session_state["messages"][-1]["content"]  # extracting last user prompt
+    if rag:
+        context = st.session_state["vector_db"].query(query_text=prompt, k=1)  # fetching similar contexts from vector database
+
+        # creating paragraph of contexts
+        paragraph = ""
+        for i, item in enumerate(context):
+            paragraph += item
+            if i != len(context)-1:
+                paragraph += "\n"
+
+        # replacing user prompt with augmented prompt
+        st.session_state["messages"][-1]["content"] = formatted_prompt(query=prompt, context=paragraph)
+
     stream = ollama.chat(
         model=st.session_state["model"],
         messages=st.session_state["messages"],
         stream=True,
     )
+
+    # replacing augmented prompt with actual user prompt
+    if rag:
+        st.session_state["messages"][-1]["content"] = prompt
     for chunk in stream:
         yield chunk["message"]["content"]

-st.title("Knowly")
-st.sidebar.title("Chat sessions")
+def formatted_prompt(query:str, context:str):
+    return SYSTEM_PROMPT + f"Question: {query}" + f"\n\nContext: {context}"
+
+def save_session(session_key):
+    if "messages" in st.session_state:
+        if st.session_state.session_key == "new_session":
+            st.session_state.session_key = get_timestamp() + '.json'
+            save_chat_history(st.session_state['messages'], st.session_state.session_key)
+        else:
+            save_chat_history(st.session_state['messages'], st.session_state.session_key)

 if "messages" not in st.session_state:
     st.session_state["messages"] = []

-if "model" not in st.session_state:
-    st.session_state["model"] = ""
-
 if "session_key" not in st.session_state:
     if len(os.listdir('sessions/')) != 0:
         st.session_state["session_key"] = os.listdir('sessions/')[-1]
         st.session_state["messages"] = load_chat_history_json(st.session_state.session_key)
     else:
         st.session_state["session_key"] = "new_session"

-def save_session(session_key):
-    if "messages" in st.session_state:
-        if st.session_state.session_key == "new_session":
-            st.session_state.session_key = get_timestamp() + '.json'
-            save_chat_history(st.session_state['messages'],st.session_state.session_key)
-        else:
-            save_chat_history(st.session_state['messages'],st.session_state.session_key)
-
-models = [model["name"] for model in ollama.list()["models"]]
-st.session_state["model"] = st.selectbox("choose you model", models)
-
 load_chat()

-voice_recording = mic_recorder(start_prompt="Start recording", stop_prompt="Stop recording", just_once=True)
-transcribed_audio_prompt = ''
-if voice_recording:
-    transcribed_audio_prompt = transcribe_audio(voice_recording["bytes"])
+with st.sidebar:
+    st.sidebar.write('**Pdf Upload:**')
+    with st.form("my-form", clear_on_submit=True):
+        uploaded_docs = st.file_uploader(label="Upload pdf or text files",
+                                         accept_multiple_files=True,
+                                         key="document_uploader",
+                                         type=["pdf"])
+        submitted = st.form_submit_button("UPLOAD")
+
+        if submitted:
+            print("uploaded docs section is running...")
+            os.makedirs("docs", exist_ok=True)
+            with st.spinner("Processing documents..."):
+                # saving the uploaded files in directory
+                for file_item in uploaded_docs:
+                    with open(f"docs/{file_item.name}", "wb") as f:
+                        f.write(file_item.getbuffer())
+                        f.close()

+                st.session_state["vector_db"] = ChromadbOperations()
+                text_chunks = get_document_chunks(path="docs")
+                st.session_state["vector_db"].insert_data(text_chunks)
+                del st.session_state["document_uploader"]
+
+    # pdf chat
+    pdf_chat_mode = st.sidebar.toggle(label="PDF Chat",
+                                      key="pdf_chat",
+                                      value=False,
+                                      disabled=True if "vectorstore" not in os.listdir(str(os.getcwd())) else False)
+
+    # load the current vector database if exists
+    if pdf_chat_mode:
+        if "vector_db" not in st.session_state.keys() and "vectorstore" in os.listdir(str(os.getcwd())):
+            st.session_state["vector_db"] = ChromadbOperations()
+
 user_prompt = st.chat_input("Enter your question:")
 if user_prompt is not None or transcribed_audio_prompt != '':
     if user_prompt:
         prompt = user_prompt
     else:
         prompt = transcribed_audio_prompt
-
+
     st.session_state["messages"].append({"role" : "user", "content": prompt})

     with st.chat_message("user"):
         st.markdown(prompt)

     with st.chat_message("assistant"):
-        message = st.write_stream(model_res_generator())
-        st.session_state["messages"].append({"role":"assistant", "content" : message})
+        message = st.write_stream(model_res_generator(rag=pdf_chat_mode))
+        st.session_state["messages"].append({"role": "assistant", "content": message})

     save_session(st.session_state.session_key)

-st.sidebar.button(label="new chat", on_click=create_new_chat)
+st.sidebar.write('**Chat History:**')
+
+st.sidebar.button(label="New chat", on_click=create_new_chat)

 session_list = os.listdir("sessions/")
 for session in session_list:
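The retrieval path that model_res_generator adds can be checked outside Streamlit. Below is a minimal sketch of the same round trip, assuming the ChromadbOperations class and SYSTEM_PROMPT introduced in this commit and at least one already-indexed document; the question string is hypothetical:

    from chromadb_operations import ChromadbOperations
    from prompt_templates import SYSTEM_PROMPT

    def formatted_prompt(query: str, context: str) -> str:
        # same composition as in app.py: template, then question, then context
        return SYSTEM_PROMPT + f"Question: {query}" + f"\n\nContext: {context}"

    query = "What does the uploaded paper conclude?"  # hypothetical question
    vector_db = ChromadbOperations()
    chunks = vector_db.query(query_text=query, k=1)   # top-1 similar chunk, matching app.py
    print(formatted_prompt(query=query, context="\n".join(chunks)))

Note that model_res_generator swaps the augmented prompt back to the raw user prompt before streaming the reply, so the saved chat history never contains the injected context.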

audio_transcribe.py

Lines changed: 3 additions & 3 deletions
@@ -1,9 +1,9 @@
-import torch
-from transformers import pipeline
-import librosa
 import io
 import yaml

+import torch
+import librosa
+from transformers import pipeline

 with open('config.yaml', 'r') as f:
     config = yaml.safe_load(f)
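The body of transcribe_audio is untouched by this commit and not shown here; judging from these imports, it presumably decodes the recorded bytes with librosa and runs a transformers ASR pipeline. A minimal sketch along those lines, with the model name as a placeholder for whatever config.yaml actually specifies:

    import io

    import librosa
    import torch
    from transformers import pipeline

    def transcribe_audio(audio_bytes: bytes) -> str:
        # decode the in-memory recording to mono float samples at Whisper's 16 kHz rate
        speech, sr = librosa.load(io.BytesIO(audio_bytes), sr=16000)
        asr = pipeline("automatic-speech-recognition",
                       model="openai/whisper-tiny",  # placeholder; the real model comes from config.yaml
                       device=0 if torch.cuda.is_available() else -1)
        return asr({"raw": speech, "sampling_rate": sr})["text"]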

chromadb_operations.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import chromadb
+from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2
+import os
+import shutil
+
+
+class ChromadbOperations:
+    def __init__(self):
+        self.client = chromadb.PersistentClient(path=str('vectorstore'))
+        self.collection = self.client.get_or_create_collection(name="text_collection",
+                                                               embedding_function=ONNXMiniLM_L6_V2())
+
+    def insert_data(self, texts_chunks):
+        embedding_count = len(self.collection.get()['ids'])
+        if embedding_count == 0:
+            ids = [str(i) for i in range(1, len(texts_chunks)+1)]
+        else:
+            ids = [str(i) for i in range(embedding_count+1, embedding_count+1+len(texts_chunks))]
+        self.collection.add(documents=texts_chunks, ids=ids)
+
+    def count(self):
+        return self.collection.count()
+
+    def query(self, query_text, k):
+        response = self.collection.query(query_texts=[query_text], n_results=k)
+        return response['documents'][0]
+
+    def delete_vector_storage(self):
+        if len(self.client.list_collections()) != 0:
+            database_contents = os.listdir(f"{os.getcwd()}/vectorstore")
+            self.client.delete_collection(name="text_collection")
+            for name in database_contents:
+                if os.path.isdir(f"vectorstore/{name}"):
+                    shutil.rmtree(f"vectorstore/{name}")
+                else:
+                    os.remove(f"vectorstore/{name}")
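The class can be exercised on its own; a short sketch with toy strings standing in for real PDF chunks:

    from chromadb_operations import ChromadbOperations

    store = ChromadbOperations()  # opens or creates the persistent ./vectorstore database
    store.insert_data(["Ollama serves local LLMs.", "Streamlit renders the chat UI."])
    print(store.count())                                          # -> 2
    print(store.query(query_text="local language models", k=1))   # top-1 similar chunk

Because insert_data continues the sequential ids from the current collection size, repeated uploads append new chunks rather than overwrite existing ones.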

prompt_templates.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+SYSTEM_PROMPT = """Answer the following question only using the context provided, being as concise as possible.
+If you're unsure, just say that you don't know.
+
+"""

style.css

Lines changed: 11 additions & 16 deletions
@@ -1,19 +1,14 @@
-/* div.element-container.st-emotion-cache-1xanlfj.e1f1d6gn4 {
-    position: fixed;
-    bottom: 8rem;
+[data-testid="column"] {
+    box-shadow: rgb(0 0 0 / 20%) 0px 2px 1px -1px, rgb(0 0 0 / 14%) 0px 1px 1px 0px, rgb(0 0 0 / 12%) 0px 1px 3px 0px;
+    border-radius: 15px;
+    padding: 1% 1% 1% 1%;
 }

-button.st-emotion-cache-7ym5gk.ef3psqc12 {
-    position: fixed;
-    bottom: 7rem;
-    right: 300px;
-
-} */
-
-button.myButton {
-    position: fixed;
-    bottom: 7rem;
-    right: 300px;
-
+div[data-testid="stVerticalBlock"] div:has(div.fixed-header) {
+    position: sticky;
+    border-radius: 15px;
+    background: rgb(101, 105, 109);
+    top: 2.875rem;
+    z-index: 999;
+    text-align: center;
 }
-
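The commit does not show how style.css is attached to the page; a common Streamlit pattern (an assumption here, not part of this diff) is to inject it near the top of app.py, after which the div.fixed-header marker written by header.write lets the :has() rule above pin the header while the chat scrolls:

    import streamlit as st

    # hypothetical loader: read style.css and inject it into the rendered page
    with open("style.css") as css_file:
        st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)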

text_processor.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+from langchain_community.document_loaders import DirectoryLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+
+def get_document_chunks(path):
+    text_splitter = RecursiveCharacterTextSplitter(
+        separators=["\n\n","\n"],
+        chunk_size=2000,
+        chunk_overlap=100,
+        length_function=len
+    )
+
+    all_chunks = []
+
+    # loading all pdf documents at once
+    pdf_loader = DirectoryLoader(path=str(path), glob="**/*.pdf")
+    pdf_documents = pdf_loader.load()
+    for single_chunk in text_splitter.split_documents(documents=pdf_documents):
+        all_chunks.append(single_chunk.page_content)
+
+    return all_chunks
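End to end, the upload handler in app.py reduces to the two calls below; a minimal sketch assuming a docs/ directory that already holds PDFs (DirectoryLoader's default file parser relies on the unstructured package):

    from chromadb_operations import ChromadbOperations
    from text_processor import get_document_chunks

    chunks = get_document_chunks(path="docs")  # split every PDF into ~2000-character chunks
    store = ChromadbOperations()
    store.insert_data(chunks)                  # embed with MiniLM and persist in ./vectorstore
    print(store.count(), "chunks indexed")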
