
Commit a0061a4

HIGHLIGHT: added context support for /v1/chat/completions
- Added "enable context?" toggle to Model Settings to turn on/off context (default = True) - Added n_ctx in config.json set context size if context is enabled - now load config into session_state and read from session_state instead of config
1 parent aa32cda commit a0061a4

File tree

6 files changed: +179 -88 lines


src/config.json

Lines changed: 2 additions & 1 deletion

```diff
@@ -1,4 +1,5 @@
 {
     "api_url": "http://localhost:8000",
-    "page_title": "Llama-2-7b-Chat"
+    "page_title": "Llama-2-7b-Chat",
+    "n_ctx": 2048
 }
```
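For reference, the full src/config.json after this change (indentation assumed):

```json
{
    "api_url": "http://localhost:8000",
    "page_title": "Llama-2-7b-Chat",
    "n_ctx": 2048
}
```

The new `n_ctx` key is read by the new src/session.py below and used in src/context.py as a cap on the replayed chat history.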

src/context.py

Lines changed: 99 additions & 49 deletions

```diff
@@ -1,38 +1,55 @@
 from datetime import datetime
 import streamlit as st
-
+
 # render context in app
 def render(container): # container = st.container()
     container.empty()
 
     with container.container():
-
-        for element in st.session_state['context']:
-
-            # if question
-            if 'question' in element:
-                q = element['question']
-                st.markdown(
-                    f"""
-                    <p style="
-                        background-color: #343541;
-                        color: #ececf1;
-                        margin: 0px;
-                        padding: 20px;
-                    ">
-                        {q}
-                    </p>
-                    """,
-                    unsafe_allow_html=True,
-                )
-
-            # if response
-            elif 'choices' in element and element['choices']:
+        if st.session_state['context'] != []:
+            for element in st.session_state['context']:
 
-                # if /v1/chat/completions endpoint
-                if 'message' in element['choices'][0]:
-                    if 'content' in element['choices'][0]['message']:
-                        c = element['choices'][0]['message']['content']
+                # if question
+                if 'question' in element:
+                    q = element['question']
+                    st.markdown(
+                        f"""
+                        <p style="
+                            background-color: #343541;
+                            color: #ececf1;
+                            margin: 0px;
+                            padding: 20px;
+                        ">
+                            {q}
+                        </p>
+                        """,
+                        unsafe_allow_html=True,
+                    )
+
+                # if response
+                elif 'choices' in element and element['choices']:
+
+                    # if /v1/chat/completions endpoint
+                    if 'message' in element['choices'][0]:
+                        if 'content' in element['choices'][0]['message']:
+                            c = element['choices'][0]['message']['content']
+                            st.markdown(
+                                f"""
+                                <p style="
+                                    background-color: #444654;
+                                    color: #ced2d8;
+                                    margin: 0px;
+                                    padding: 20px;
+                                ">
+                                    {c}
+                                </p>
+                                """,
+                                unsafe_allow_html=True,
+                            )
+
+                    # if /v1/completions endpoint
+                    elif 'text' in element['choices'][0]:
+                        c = element['choices'][0]['text']
                         st.markdown(
                             f"""
                             <p style="
@@ -46,31 +63,14 @@ def render(container): # container = st.container()
                             """,
                             unsafe_allow_html=True,
                         )
-
-            # if /v1/completions endpoint
-            elif 'text' in element['choices'][0]:
-                c = element['choices'][0]['text']
-                st.markdown(
-                    f"""
-                    <p style="
-                        background-color: #444654;
-                        color: #ced2d8;
-                        margin: 0px;
-                        padding: 20px;
-                    ">
-                        {c}
-                    </p>
-                    """,
-                    unsafe_allow_html=True,
-                )
 
-# append question to context
-def append_question(question): # question = string
-    if st.session_state['context'] == [] or 'question' not in st.session_state['context'][-1] or st.session_state['context'][-1]['question'] != question:
+# append user_content to context
+def append_question(user_content): # user_content = question = string
+    if st.session_state['context'] == [] or 'question' not in st.session_state['context'][-1] or st.session_state['context'][-1]['question'] != user_content:
         now = int(datetime.now().timestamp())
         st.session_state['context'].append({
             "id": 0, # todo: add question id here
-            "question": question,
+            "question": user_content,
             "created": now
         })
@@ -115,4 +115,54 @@ def append(ctx): # ctx = python dict
 
     # raise error if no context was found
     else:
-        raise Exception(f'Error: no context to append or wrong api endpoint\n\nmessage: {ctx}')
+        raise Exception(f'Error: no context to append or wrong api endpoint\n\nmessage: {ctx}')
+
+# return message from context
+def get_message(ctx_element):
+    # if question
+    if 'question' in ctx_element:
+        return "User: " + ctx_element['question'] + "\n"
+
+    # if response
+    elif 'choices' in ctx_element and ctx_element['choices']:
+
+        # if /v1/chat/completions endpoint
+        if 'message' in ctx_element['choices'][0]:
+            if 'content' in ctx_element['choices'][0]['message']:
+                return "System: " + ctx_element['choices'][0]['message']['content'] + "\n"
+
+        # if /v1/completions endpoint
+        elif 'text' in ctx_element['choices'][0]:
+            return "System: " + ctx_element['choices'][0]['text'] + "\n"
+
+# return context history
+def get_messages_history(system_content):
+    history = ""
+
+    messages = [{
+        "role": "system",
+        "content": system_content
+    }]
+
+    if st.session_state['context'] != []:
+
+        # if context is enabled
+        if st.session_state['enable_context']:
+            for ctx_element in st.session_state['context']:
+                history += get_message(ctx_element)
+
+            # cut history to n_ctx length of llama.cpp server
+            # todo: cut complete user and/or system message instead of cutting somewhere in the middle
+            n_ctx = st.session_state['n_ctx']
+            history = (history[-n_ctx:]) if len(history) >= n_ctx else history
+
+        # if context is disabled
+        else:
+            history += get_message(st.session_state['context'][-1])
+
+    messages.append({
+        "role": "user",
+        "content": history
+    })
+
+    return messages
```
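To make the new history format concrete, here is a minimal sketch (not part of the commit; the sample entries and prompts are invented, but their shapes mirror what `append_question()` and `append()` store, and `st.session_state` assumes the app runs under `streamlit run`):

```python
# sketch: what get_messages_history() builds from a two-entry context
st.session_state['context'] = [
    {"id": 0, "question": "Hi, who are you?", "created": 1700000000},
    {"choices": [{"message": {"content": "I am a helpful AI."}}]},
]
st.session_state['enable_context'] = True
st.session_state['n_ctx'] = 2048

messages = get_messages_history("A dialog, where User interacts with AI.")
# messages == [
#     {"role": "system", "content": "A dialog, where User interacts with AI."},
#     {"role": "user", "content": "User: Hi, who are you?\nSystem: I am a helpful AI.\n"},
# ]
```

Two things worth noting: the whole back-and-forth is flattened into a single `user` message rather than one `messages` entry per turn, and the `n_ctx` cutoff trims characters, not tokens, keeping only the last `n_ctx` characters of the flattened history (the todo comment acknowledges this can cut a message somewhere in the middle).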

src/request.py

Lines changed: 21 additions & 29 deletions

```diff
@@ -6,54 +6,46 @@
 
 urllib3.disable_warnings()
 
-# load config-file
-with open("src/config.json", "r", encoding="utf-8") as file:
-    config = json.load(file)
-
 # send request to API
-def send(endpoint, user_content, stream, max_tokens, temperature, top_p, top_k, repeat_penalty, stop, system_content, content_container):
+def send(user_content, content_container):
+
+    # create static json_data for all requests
     json_data = {
-        "max_tokens": max_tokens,
-        "temperature": temperature,
-        "top_p": top_p,
-        "top_k": top_k,
-        "repeat_penalty": repeat_penalty,
-        "stop": stop,
-        "stream": stream
+        "max_tokens": st.session_state['max_tokens'],
+        "temperature": st.session_state['temperature'],
+        "top_p": st.session_state['top_p'],
+        "top_k": st.session_state['top_k'],
+        "repeat_penalty": st.session_state['repeat_penalty'],
+        "stop": st.session_state['stop'],
+        "stream": st.session_state['stream']
     }
 
+    # add endpoint specific json_data
     # endpoint = /v1/chat/completions
-    if endpoint == "/v1/chat/completions":
-        json_data['messages'] = [
-            {
-                "content": system_content,
-                "role": "system"
-            },
-            {
-                "content": user_content,
-                "role": "user"
-            }
-        ]
+    if st.session_state['endpoint'] == "/v1/chat/completions":
+
+        # add previous context to messages
+        json_data['messages'] = context.get_messages_history(st.session_state['system_content'])
 
     # other endpoints
     else:
-        system_content = system_content.replace('{prompt}', user_content)
+        system_content = st.session_state['system_content'].replace('{prompt}', user_content)
         json_data['prompt'] = system_content
-
+
     # send json_data to endpoint
     try:
         s = requests.Session()
         headers = None
-        with s.post(config["api_url"] + endpoint,
+        with s.post(st.session_state["api_url"] + st.session_state['endpoint'],
                     json=json_data,
                     headers=headers,
-                    stream=stream,
+                    stream=st.session_state['stream'],
                     timeout=240,
                     verify=False
                     ) as response:
 
             # if stream is True
-            if stream:
+            if st.session_state['stream']:
 
                 # store chunks into context
                 for chunk in response.iter_lines(chunk_size=None, decode_unicode=True):
@@ -97,7 +89,7 @@ def stop(endpoint, stop):
     # send stop request to endpoint
     try:
         requests.post(
-            config["api_url"] + endpoint,
+            st.session_state["api_url"] + endpoint,
             json={"messages": stop[0]},
             verify=False
         )
```
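With the sidebar defaults from this commit, the payload that `send()` posts to `<api_url>/v1/chat/completions` looks roughly like this (a sketch; the user text is invented):

```python
# sketch of json_data as assembled by send() with default sidebar values
json_data = {
    "max_tokens": 256,
    "temperature": 0.2,
    "top_p": 0.95,
    "top_k": 40,
    "repeat_penalty": 1.1,
    "stop": ["\n", "###"],  # parsed from the sidebar's default stop string
    "stream": True,
    # from context.get_messages_history(); the whole history arrives
    # as one "user" message
    "messages": [
        {"role": "system", "content": "A dialog, where User interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits."},
        {"role": "user", "content": "User: Hi, who are you?\n"},
    ],
}
```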

src/session.py

Lines changed: 15 additions & 0 deletions

```diff
@@ -0,0 +1,15 @@
+import json
+import streamlit as st
+
+# load config-file
+def load():
+    # initialize context in session state if not present
+    if 'context' not in st.session_state:
+        st.session_state['context'] = []
+
+    # load config in session state
+    with open("src/config.json", "r", encoding="utf-8") as file:
+        config = json.load(file)
+        st.session_state['api_url'] = config['api_url'] if 'api_url' in config else "http://localhost:8000"
+        st.session_state['title'] = config['title'] if 'title' in config else "Llama-2-7b-Chat"
+        st.session_state['n_ctx'] = int(config['n_ctx']) if 'n_ctx' in config else 2048
```
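A minimal sketch of what the new module provides (assuming the app is started with `streamlit run`):

```python
import streamlit as st
import src.session as session

session.load()
st.session_state['api_url']  # "http://localhost:8000" unless set in config.json
st.session_state['title']    # falls back to "Llama-2-7b-Chat" (config.json ships "page_title", not "title")
st.session_state['n_ctx']    # 2048 unless overridden
```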

src/sidebar.py

Lines changed: 36 additions & 3 deletions

```diff
@@ -10,23 +10,56 @@ def render():
         data = json.load(file)
     endpoints = list(data.keys())
 
+    # endpoint
     endpoint = st.selectbox("endpoint", endpoints)
+    st.session_state['endpoint'] = endpoint
+
+    # sidebar_title
     st.title("Model Settings")
+    st.session_state['sidebar_title'] = "Model Settings"
+
+    # user_content
     user_content = ""
+    st.session_state['user_content'] = user_content
+
+    # enable_context
+    enable_context = st.toggle("enable context?", value=True) if endpoint == "/v1/chat/completions" else False
+    st.session_state['enable_context'] = enable_context
+
+    # stream
     stream = st.toggle("stream results?", value=True)
+    st.session_state['stream'] = stream
+
+    # max_tokens
     max_tokens = st.number_input("max_tokens", value=256, min_value=16, max_value=2048, step=1)
+    st.session_state['max_tokens'] = max_tokens
+
+    # temperature
     temperature = st.number_input("temperature", value=0.2, min_value=0.01, max_value=1.99, step=0.05)
+    st.session_state['temperature'] = temperature
+
+    # top_p
     top_p = st.number_input("top_p", value=0.95, min_value=0.0, max_value=1.0, step=0.05)
+    st.session_state['top_p'] = top_p
+
+    # top_k
     top_k = st.number_input("top_k", value=40, min_value=1, max_value=200, step=1)
+    st.session_state['top_k'] = top_k
+
+    # repeat_penalty
     repeat_penalty = st.number_input("repeat_penalty", value=1.1, min_value=1.0, max_value=1.5, step=0.05)
+    st.session_state['repeat_penalty'] = repeat_penalty
+
+    # stop
     stop = st.text_input("stop", value=r'\n, ###')
     stop = stop.encode().decode('unicode_escape')
     stop = stop.replace(" ", "").split(",")
+    st.session_state['stop'] = stop
+
     if endpoint == "/v1/chat/completions":
-        system_content = st.text_area("system_content", value="You are a helpful assistant.", height=200)
+        system_content = st.text_area("system_content", value="A dialog, where User interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits.", height=200)
     else:
         system_content = st.text_area("system_content", value=r"\n\n### Instructions:\n{prompt}\n\n### Response:\n", height=200)
         system_content = system_content.encode().decode('unicode_escape')
         st.markdown("hint: the expression `{prompt}` must exist!", unsafe_allow_html=True)
-
-    return endpoint, user_content, stream, max_tokens, temperature, top_p, top_k, repeat_penalty, stop, system_content
+    st.session_state['system_content'] = system_content
```
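One detail that is easy to miss: the stop field is typed as a literal string and decoded before being split, so the default `\n, ###` really does become a newline plus `###` (a sketch of just that transformation):

```python
# how the sidebar's stop input is turned into a stop-token list
stop = r'\n, ###'                              # default sidebar value
stop = stop.encode().decode('unicode_escape')  # literal "\n" -> real newline
stop = stop.replace(" ", "").split(",")        # -> ['\n', '###']
```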

streamlit_app.py

Lines changed: 6 additions & 6 deletions

```diff
@@ -1,18 +1,18 @@
 import streamlit as st
 import src.header as header
+import src.session as session
 import src.sidebar as sidebar
 import src.request as request
 import src.context as context
 
 # render header
 header.render()
 
-# render sidebar
-(endpoint, user_content, stream, max_tokens, temperature, top_p, top_k, repeat_penalty, stop, system_content) = sidebar.render()
+# load config
+session.load()
 
-# initialize context in session state if not present
-if 'context' not in st.session_state:
-    st.session_state['context'] = []
+# render sidebar
+sidebar.render()
 
 # render content_container
 content_container = st.empty()
@@ -36,4 +36,4 @@
     context.render(content_container)
 
     with st.spinner('Generating response...'):
-        request.send(endpoint, user_content, stream, max_tokens, temperature, top_p, top_k, repeat_penalty, stop, system_content, content_container)
+        request.send(user_content, content_container)
```
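Note the new startup order: `session.load()` now runs before `sidebar.render()`, so the sidebar and every later call (`request.send()`, `context.render()`) can read configuration and context straight from `st.session_state` instead of passing a ten-value tuple around.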
