1
- import ollama
2
- import streamlit as st
3
1
import os
4
- from utils import save_chat_history ,load_chat_history_json ,get_timestamp
5
- from datetime import datetime
6
- from streamlit_mic_recorder import mic_recorder
7
- from transformers import pipeline
8
2
import yaml
3
+ from datetime import datetime
4
+
5
+ import ollama
9
6
import torch
7
+ import streamlit as st
8
+ from transformers import pipeline
9
+ from streamlit_mic_recorder import mic_recorder
10
+ from utils import save_chat_history , load_chat_history_json , get_timestamp
11
+
12
+ from prompt_templates import SYSTEM_PROMPT
10
13
from audio_transcribe import transcribe_audio
14
+ from text_processor import get_document_chunks
15
+ from chromadb_operations import ChromadbOperations
# --- Fixed page header: app title, model picker, and voice-recording widget ---
header = st.container()
header.title("Knowly")
header.write("""<div class='fixed-header'/>""", unsafe_allow_html=True)

with header:
    model_col, audio_col = st.columns(2)
    with model_col:
        # Remember the chosen model across Streamlit reruns.
        if "model" not in st.session_state:
            st.session_state["model"] = ""
        available_models = [m["name"] for m in ollama.list()["models"]]
        st.session_state["model"] = st.selectbox("Choose your model", available_models)
    with audio_col:
        st.write('Record Audio:')
        # just_once=True: the recording is handed over a single time per capture.
        voice_recording = mic_recorder(start_prompt="Start recording",
                                       stop_prompt="Stop recording",
                                       just_once=True)
        transcribed_audio_prompt = ''
        if voice_recording:
            # Turn the captured audio bytes into text usable as a chat prompt.
            transcribed_audio_prompt = transcribe_audio(voice_recording["bytes"])
# Application configuration loaded once at startup.
# Explicit encoding avoids platform-dependent decoding of the YAML file.
with open('config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
@@ -32,68 +55,112 @@ def set_session_name(session):
32
55
del st .session_state ["messages" ]
33
56
st .session_state ["messages" ] = load_chat_history_json (session )
34
57
def model_res_generator(rag: bool = False):
    """Stream the model's reply to the most recent user message.

    When ``rag`` is True, the last user prompt is temporarily replaced with a
    context-augmented prompt built from similar chunks fetched out of the
    vector database; the original prompt is restored afterwards so the visible
    chat history stays unchanged.

    Args:
        rag: enable retrieval-augmented generation for this request.

    Yields:
        str: successive content chunks of the assistant response.
    """
    prompt = st.session_state["messages"][-1]["content"]  # last user prompt
    if rag:
        # Fetch the most similar stored chunk(s) to ground the answer.
        context = st.session_state["vector_db"].query(query_text=prompt, k=1)
        paragraph = "\n".join(context)
        # Swap in the augmented prompt for this request only.
        st.session_state["messages"][-1]["content"] = formatted_prompt(query=prompt, context=paragraph)

    try:
        stream = ollama.chat(
            model=st.session_state["model"],
            messages=st.session_state["messages"],
            stream=True,
        )
    finally:
        # Restore the user's original prompt even if the chat call raises,
        # so the stored history never keeps the augmented prompt.
        if rag:
            st.session_state["messages"][-1]["content"] = prompt

    for chunk in stream:
        yield chunk["message"]["content"]
def formatted_prompt(query: str, context: str):
    """Assemble the RAG prompt: system instructions, the question, then context."""
    return f"{SYSTEM_PROMPT}Question: {query}\n\nContext: {context}"
def save_session(session_key):
    """Persist the current chat history to disk.

    A brand-new session ("new_session") is first given a timestamped
    ``.json`` filename, which also becomes the active session key.

    NOTE(review): the ``session_key`` parameter is unused — the function
    reads/writes ``st.session_state.session_key`` directly. Kept for
    interface compatibility with existing callers.
    """
    if "messages" not in st.session_state:
        return
    if st.session_state.session_key == "new_session":
        st.session_state.session_key = get_timestamp() + '.json'
    # Single save path: both branches previously duplicated this call.
    save_chat_history(st.session_state['messages'], st.session_state.session_key)
# --- Session bootstrap: ensure a message list and an active session key ---
if "messages" not in st.session_state:
    st.session_state["messages"] = []

if "session_key" not in st.session_state:
    existing_sessions = os.listdir('sessions/')  # list once, reuse below
    if existing_sessions:
        # Resume the most recently listed saved session.
        st.session_state["session_key"] = existing_sessions[-1]
        st.session_state["messages"] = load_chat_history_json(st.session_state.session_key)
    else:
        st.session_state["session_key"] = "new_session"

load_chat()
with st.sidebar:
    st.sidebar.write('**Pdf Upload:**')
    # clear_on_submit resets the uploader widget after each submission.
    with st.form("my-form", clear_on_submit=True):
        uploaded_docs = st.file_uploader(label="Upload pdf or text files",
                                         accept_multiple_files=True,
                                         key="document_uploader",
                                         type=["pdf"])
        submitted = st.form_submit_button("UPLOAD")

    if submitted:
        print("uploaded docs section is running...")
        os.makedirs("docs", exist_ok=True)
        with st.spinner("Processing documents..."):
            # Save each uploaded file under docs/ for chunking below.
            # (The with-block closes the file; the old explicit close() was redundant.)
            for file_item in uploaded_docs:
                with open(f"docs/{file_item.name}", "wb") as f:
                    f.write(file_item.getbuffer())

            # Chunk the saved documents and index them in the vector store.
            st.session_state["vector_db"] = ChromadbOperations()
            text_chunks = get_document_chunks(path="docs")
            st.session_state["vector_db"].insert_data(text_chunks)
            del st.session_state["document_uploader"]

# PDF-chat toggle: only enabled once a "vectorstore" directory exists on disk.
pdf_chat_mode = st.sidebar.toggle(label="PDF Chat",
                                  key="pdf_chat",
                                  value=False,
                                  disabled="vectorstore" not in os.listdir(os.getcwd()))

# Re-open the persisted vector database if it exists but is not yet loaded.
if pdf_chat_mode:
    if "vector_db" not in st.session_state and "vectorstore" in os.listdir(os.getcwd()):
        st.session_state["vector_db"] = ChromadbOperations()
user_prompt = st.chat_input("Enter your question:")
if user_prompt is not None or transcribed_audio_prompt != '':
    # A typed message takes precedence over a transcribed voice recording.
    prompt = user_prompt if user_prompt else transcribed_audio_prompt

    st.session_state["messages"].append({"role": "user", "content": prompt})

    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        # Stream the assistant reply; RAG is used only when PDF chat is on.
        message = st.write_stream(model_res_generator(rag=pdf_chat_mode))
        st.session_state["messages"].append({"role": "assistant", "content": message})

    save_session(st.session_state.session_key)
# --- Sidebar: chat-history section with a button to start a fresh chat ---
st.sidebar.write('**Chat History:**')

st.sidebar.button(label="New chat", on_click=create_new_chat)

# Saved sessions on disk; iterated below to render the history list.
session_list = os.listdir("sessions/")
99
166
for session in session_list :
0 commit comments