Skip to content

Commit

Permalink
HIGHLIGHT: added context support for /v1/chat/completions
Browse files Browse the repository at this point in the history
- Added "enable context?" toggle to Model Settings to turn on/off context (default = True)
- Added n_ctx in config.json set context size if context is enabled
- now load config into session_state and read from session_state instead of config
  • Loading branch information
3x3cut0r committed Nov 7, 2023
1 parent aa32cda commit a0061a4
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 88 deletions.
3 changes: 2 additions & 1 deletion src/config.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"api_url": "http://localhost:8000",
"page_title": "Llama-2-7b-Chat"
"page_title": "Llama-2-7b-Chat",
"n_ctx": 2048
}
148 changes: 99 additions & 49 deletions src/context.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,55 @@
from datetime import datetime
import streamlit as st

# render context in app
def render(container): # container = st.container()
container.empty()

with container.container():

for element in st.session_state['context']:

# if question
if 'question' in element:
q = element['question']
st.markdown(
f"""
<p style="
background-color: #343541;
color: #ececf1;
margin: 0px;
padding: 20px;
">
{q}
</p>
""",
unsafe_allow_html=True,
)

# if response
elif 'choices' in element and element['choices']:
if st.session_state['context'] != []:
for element in st.session_state['context']:

# if /v1/chat/completions endpoint
if 'message' in element['choices'][0]:
if 'content' in element['choices'][0]['message']:
c = element['choices'][0]['message']['content']
# if question
if 'question' in element:
q = element['question']
st.markdown(
f"""
<p style="
background-color: #343541;
color: #ececf1;
margin: 0px;
padding: 20px;
">
{q}
</p>
""",
unsafe_allow_html=True,
)

# if response
elif 'choices' in element and element['choices']:

# if /v1/chat/completions endpoint
if 'message' in element['choices'][0]:
if 'content' in element['choices'][0]['message']:
c = element['choices'][0]['message']['content']
st.markdown(
f"""
<p style="
background-color: #444654;
color: #ced2d8;
margin: 0px;
padding: 20px;
">
{c}
</p>
""",
unsafe_allow_html=True,
)

# if /v1/completions endpoint
elif 'text' in element['choices'][0]:
c = element['choices'][0]['text']
st.markdown(
f"""
<p style="
Expand All @@ -46,31 +63,14 @@ def render(container): # container = st.container()
""",
unsafe_allow_html=True,
)

# if /v1/completions endpoint
elif 'text' in element['choices'][0]:
c = element['choices'][0]['text']
st.markdown(
f"""
<p style="
background-color: #444654;
color: #ced2d8;
margin: 0px;
padding: 20px;
">
{c}
</p>
""",
unsafe_allow_html=True,
)

# append question to context
def append_question(question): # question = string
if st.session_state['context'] == [] or 'question' not in st.session_state['context'][-1] or st.session_state['context'][-1]['question'] != question:
# append user_content to context
def append_question(user_content): # user_content = question = string
if st.session_state['context'] == [] or 'question' not in st.session_state['context'][-1] or st.session_state['context'][-1]['question'] != user_content:
now = int(datetime.now().timestamp())
st.session_state['context'].append({
"id": 0, # todo: add question id here
"question": question,
"question": user_content,
"created": now
})

Expand Down Expand Up @@ -115,4 +115,54 @@ def append(ctx): # ctx = python dict

# raise error if no context was found
else:
raise Exception(f'Error: no context to append or wrong api endpoint\n\nmessage: {ctx}')

# return message from context
def get_message(ctx_element):
# if question
if 'question' in ctx_element:
return "User: " + ctx_element['question'] + "\n"

# if response
elif 'choices' in ctx_element and ctx_element['choices']:

# if /v1/chat/completions endpoint
if 'message' in ctx_element['choices'][0]:
if 'content' in ctx_element['choices'][0]['message']:
return "System: " + ctx_element['choices'][0]['message']['content'] + "\n"

# if /v1/completions entpoint
elif 'text' in ctx_element['choices'][0]:
return "System: " + ctx_element['choices'][0]['text'] + "\n"

# return context history
def get_messages_history(system_content):
history = ""

messages = [{
"role": "system",
"content": system_content
}]

if st.session_state['context'] != []:

# if context is enabled
if st.session_state['enable_context']:
for ctx_element in st.session_state['context']:
history += get_message(ctx_element)

# cut history to n_ctx length of llama.cpp server
# todo: cut complete user and/or system message instead of cutting somewhere in the middle
n_ctx = st.session_state['n_ctx']
history = (history[-n_ctx:]) if len(history) >= n_ctx else history

# if context is disabled
else:
history += get_message(st.session_state['context'][-1])

messages.append({
"role": "user",
"content": history
})

return messages
50 changes: 21 additions & 29 deletions src/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,54 +6,46 @@

urllib3.disable_warnings()

# load config-file
with open("src/config.json", "r", encoding="utf-8") as file:
config = json.load(file)

# send request to API
def send(endpoint, user_content, stream, max_tokens, temperature, top_p, top_k, repeat_penalty, stop, system_content, content_container):
def send(user_content, content_container):

# create static json_data for all requests
json_data = {
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repeat_penalty": repeat_penalty,
"stop": stop,
"stream": stream
"max_tokens": st.session_state['max_tokens'],
"temperature": st.session_state['temperature'],
"top_p": st.session_state['top_p'],
"top_k": st.session_state['top_k'],
"repeat_penalty": st.session_state['repeat_penalty'],
"stop": st.session_state['stop'],
"stream": st.session_state['stream']
}

# add endpoint specific json_data
# endpoint = /v1/chat/completions
if endpoint == "/v1/chat/completions":
json_data['messages'] = [
{
"content": system_content,
"role": "system"
},
{
"content": user_content,
"role": "user"
}
]
if st.session_state['endpoint'] == "/v1/chat/completions":

# add previous context to messages
json_data['messages'] = context.get_messages_history(st.session_state['system_content'])

# other endpoints
else:
system_content = system_content.replace('{prompt}', user_content)
system_content = st.session_state['system_content'].replace('{prompt}', user_content)
json_data['prompt'] = system_content

# send json_data to endpoint
try:
s = requests.Session()
headers = None
with s.post(config["api_url"] + endpoint,
with s.post(st.session_state["api_url"] + st.session_state['endpoint'],
json=json_data,
headers=headers,
stream=stream,
stream=st.session_state['stream'],
timeout=240,
verify=False
) as response:

# if stream is True
if stream:
if st.session_state['stream']:

# store chunks into context
for chunk in response.iter_lines(chunk_size=None, decode_unicode=True):
Expand Down Expand Up @@ -97,7 +89,7 @@ def stop(endpoint, stop):
# send stop request to endpoint
try:
requests.post(
config["api_url"] + endpoint,
st.session_state["api_url"] + endpoint,
json={"messages": stop[0]},
verify=False
)
Expand Down
15 changes: 15 additions & 0 deletions src/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import json
import streamlit as st

# load config-file
def load():
# initialize context in session state if not present
if 'context' not in st.session_state:
st.session_state['context'] = []

# load config in session state
with open("src/config.json", "r", encoding="utf-8") as file:
config = json.load(file)
st.session_state['api_url'] = config['api_url'] if 'api_url' in config else "http://localhost:8000"
st.session_state['title'] = config['title'] if 'title' in config else "Llama-2-7b-Chat"
st.session_state['n_ctx'] = int(config['n_ctx']) if 'n_ctx' in config else 2048
39 changes: 36 additions & 3 deletions src/sidebar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,56 @@ def render():
data = json.load(file)
endpoints = list(data.keys())

# endpoint
endpoint = st.selectbox("endpoint", endpoints)
st.session_state['endpoint'] = endpoint

# sidebar_title
st.title("Model Settings")
st.session_state['sidebar_title'] = "Model Settings"

# user_content
user_content = ""
st.session_state['user_content'] = user_content

# enable_context
enable_context = st.toggle("enable context?", value=True) if endpoint == "/v1/chat/completions" else False
st.session_state['enable_context'] = enable_context

# stream
stream = st.toggle("stream results?", value=True)
st.session_state['stream'] = stream

# max_tokens
max_tokens = st.number_input("max_tokens", value=256, min_value=16, max_value=2048, step=1)
st.session_state['max_tokens'] = max_tokens

# temperature
temperature = st.number_input("temperature", value=0.2, min_value=0.01, max_value=1.99, step=0.05)
st.session_state['temperature'] = temperature

# top_p
top_p = st.number_input("top_p", value=0.95, min_value=0.0, max_value=1.0, step=0.05)
st.session_state['top_p'] = top_p

# top_k
top_k = st.number_input("top_k", value=40, min_value=1, max_value=200, step=1)
st.session_state['top_k'] = top_k

# repeat_penalty
repeat_penalty = st.number_input("repeat_penalty", value=1.1, min_value=1.0, max_value=1.5, step=0.05)
st.session_state['repeat_penalty'] = repeat_penalty

# stop
stop = st.text_input("stop", value=r'\n, ###')
stop = stop.encode().decode('unicode_escape')
stop = stop.replace(" ", "").split(",")
st.session_state['stop'] = stop

if endpoint == "/v1/chat/completions":
system_content = st.text_area("system_content", value="You are a helpful assistant.", height=200)
system_content = st.text_area("system_content", value="A dialog, where User interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits.", height=200)
else:
system_content = st.text_area("system_content", value=r"\n\n### Instructions:\n{prompt}\n\n### Response:\n", height=200)
system_content = system_content.encode().decode('unicode_escape')
st.markdown("hint: the expression `{prompt}` must exist!", unsafe_allow_html=True)

return endpoint, user_content, stream, max_tokens, temperature, top_p, top_k, repeat_penalty, stop, system_content
st.session_state['system_content'] = system_content
12 changes: 6 additions & 6 deletions streamlit_app.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import streamlit as st
import src.header as header
import src.session as session
import src.sidebar as sidebar
import src.request as request
import src.context as context

# render header
header.render()

# render sidebar
(endpoint, user_content, stream, max_tokens, temperature, top_p, top_k, repeat_penalty, stop, system_content) = sidebar.render()
# load config
session.load()

# initialize context in session state if not present
if 'context' not in st.session_state:
st.session_state['context'] = []
# render sidebar
sidebar.render()

# render content_container
content_container = st.empty()
Expand All @@ -36,4 +36,4 @@
context.render(content_container)

with st.spinner('Generating response...'):
request.send(endpoint, user_content, stream, max_tokens, temperature, top_p, top_k, repeat_penalty, stop, system_content, content_container)
request.send(user_content, content_container)

0 comments on commit a0061a4

Please sign in to comment.