Skip to content

Commit c1bc61b

Browse files
committed
OpenAI AssistantGPT
Interface with Streamlit, Investment recommendation with assistant API
1 parent 3753e13 commit c1bc61b

File tree

3 files changed

+289
-0
lines changed

3 files changed

+289
-0
lines changed

main.py renamed to ChefGPT.py

File renamed without changes.

__pycache__/main.cpython-311.pyc

0 Bytes
Binary file not shown.

pages/07_AssistantGPT.py

Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
import yfinance as yf
2+
from duckduckgo_search import DDGS
3+
from itertools import islice
4+
from openai import OpenAI
5+
import streamlit as st
6+
import re
7+
import json
8+
9+
# Modified DuckDuckGoSearchAPIWrapper class compatible
10+
11+
12+
class DuckDuckGoSearchAPIWrapper:
13+
def __init__(self, region='wt-wt', safesearch='Moderate', timelimit='y', backend='api'):
14+
self.region = region
15+
self.safesearch = safesearch
16+
self.timelimit = timelimit
17+
self.backend = backend
18+
19+
def run(self, query, max_results=5):
20+
try:
21+
ddgs = DDGS()
22+
results = ddgs.text(query, region=self.region, safesearch=self.safesearch,
23+
timelimit=self.timelimit, backend=self.backend)
24+
limited_results = list(islice(results, max_results))
25+
26+
if not limited_results:
27+
return "No good DuckDuckGo Search Result was found"
28+
29+
# Extract the ticker symbol using regex
30+
for result in limited_results:
31+
match = re.search(r'\b[A-Z]{1,5}\b', result['title'])
32+
if match:
33+
ticker_symbol = match.group(0)
34+
if ticker_symbol not in ["NYSE", "NASDAQ"]:
35+
return {"ticker": ticker_symbol}
36+
except Exception as e:
37+
return f"Error retrieving data: {str(e)}"
38+
39+
40+
# Describe function tools
41+
functions = [
42+
{
43+
"type": "function",
44+
"function": {
45+
"name": "get_ticker",
46+
"description": "Given the name of a company returns its ticker symbol",
47+
"parameters": {
48+
"type": "object",
49+
"properties": {
50+
"company_name": {
51+
"type": "string",
52+
"description": "The name of the company",
53+
}
54+
},
55+
"required": ["company_name"],
56+
},
57+
},
58+
},
59+
{
60+
"type": "function",
61+
"function": {
62+
"name": "get_income_statement",
63+
"description": "Given a ticker symbol (i.e AAPL) returns the company's income statement.",
64+
"parameters": {
65+
"type": "object",
66+
"properties": {
67+
"ticker": {
68+
"type": "string",
69+
"description": "Ticker symbol of the company",
70+
},
71+
},
72+
"required": ["ticker"],
73+
},
74+
},
75+
},
76+
{
77+
"type": "function",
78+
"function": {
79+
"name": "get_balance_sheet",
80+
"description": "Given a ticker symbol (i.e AAPL) returns the company's balance sheet.",
81+
"parameters": {
82+
"type": "object",
83+
"properties": {
84+
"ticker": {
85+
"type": "string",
86+
"description": "Ticker symbol of the company",
87+
},
88+
},
89+
"required": ["ticker"],
90+
},
91+
},
92+
},
93+
{
94+
"type": "function",
95+
"function": {
96+
"name": "get_daily_stock_performance",
97+
"description": "Given a ticker symbol (i.e AAPL) returns the performance of the stock for the last 100 days.",
98+
"parameters": {
99+
"type": "object",
100+
"properties": {
101+
"ticker": {
102+
"type": "string",
103+
"description": "Ticker symbol of the company",
104+
},
105+
},
106+
"required": ["ticker"],
107+
},
108+
},
109+
},
110+
]
111+
112+
# Define tool functions
113+
114+
115+
def get_ticker(inputs):
116+
# inputs will look like this: {"company_name":"Apple"}
117+
ddg = DuckDuckGoSearchAPIWrapper()
118+
company_name = inputs["company_name"]
119+
return ddg.run(f"Ticker symbol of {company_name}")
120+
121+
122+
def get_income_statement(inputs):
123+
ticker = inputs['ticker']
124+
stock = yf.Ticker(ticker)
125+
income_statement = stock.income_stmt
126+
return json.dumps(income_statement.to_json())
127+
128+
129+
def get_balance_sheet(inputs):
130+
ticker = inputs['ticker']
131+
stock = yf.Ticker(ticker)
132+
balance_sheet = stock.balance_sheet
133+
return json.dumps(balance_sheet.to_json())
134+
135+
136+
def get_daily_stock_performance(inputs):
137+
ticker = inputs['ticker']
138+
stock = yf.Ticker(ticker)
139+
history = stock.history(period="3mo")
140+
return json.dumps(history.to_json())
141+
142+
143+
# Mapping the functions
144+
functions_map = {
145+
"get_ticker": get_ticker,
146+
"get_income_statement": get_income_statement,
147+
"get_balance_sheet": get_balance_sheet,
148+
"get_daily_stock_performance": get_daily_stock_performance,
149+
}
150+
151+
# Create an assistant and get assistant ID
152+
client = OpenAI()
153+
154+
assistant_id = "asst_vC1F6Bt2TKZ0EVUf9tA6B9p8"
155+
156+
# Define functions for message handling
157+
158+
159+
def get_run(run_id, thread_id):
160+
return client.beta.threads.runs.retrieve(
161+
run_id=run_id,
162+
thread_id=thread_id,
163+
)
164+
165+
166+
def send_message(thread_id, content):
167+
return client.beta.threads.messages.create(
168+
thread_id=thread_id,
169+
role="user",
170+
content=content,
171+
)
172+
173+
174+
def get_messages(thread_id):
175+
messages = client.beta.threads.messages.list(
176+
thread_id=thread_id
177+
)
178+
messages = list(messages)
179+
return messages
180+
181+
182+
def get_tool_outputs(run_id, thread_id):
183+
run = get_run(run_id, thread_id)
184+
outputs = []
185+
for action in run.required_action.submit_tool_outputs.tool_calls:
186+
action_id = action.id
187+
function = action.function
188+
# because function.arguments just brings str, so convert it to json so that the function can actually use it.
189+
function_args = json.loads(function.arguments)
190+
print(
191+
f'Calling function:{function.name} with arg {function.arguments}')
192+
output = functions_map[function.name](function_args)
193+
output_str = json.dumps(output)
194+
outputs.append(
195+
{
196+
"tool_call_id": action_id,
197+
"output": output_str,
198+
}
199+
)
200+
return outputs
201+
202+
203+
def submit_tool_outputs(run_id, thread_id):
204+
outputs = get_tool_outputs(run_id, thread_id)
205+
return client.beta.threads.runs.submit_tool_outputs(
206+
run_id=run_id,
207+
thread_id=thread_id,
208+
tool_outputs=outputs
209+
)
210+
211+
212+
def save_message(message, role):
213+
st.session_state["messages"].append(
214+
{"message": message, "role": role}
215+
)
216+
217+
218+
def write_message(message, role, save=True):
219+
# shows messages in the beginning, and save them
220+
with st.chat_message(role):
221+
st.markdown(message)
222+
if save:
223+
# Note that the messages are stored in a dictionary form
224+
save_message(message, role)
225+
226+
# Displaying messages without saving them: display saved messages
227+
228+
229+
def paint_history():
230+
for message in st.session_state["messages"]:
231+
write_message(message["message"], message["role"], save=False)
232+
233+
234+
# ========================================================================================
235+
st.set_page_config(
236+
page_title="AssistantGPT",
237+
page_icon="💻",
238+
)
239+
st.markdown(
240+
"""
241+
# AssistantGPT
242+
243+
Welcome to AssistantGPT.
244+
245+
AssistantGPT will provide financial insights for the companies of your intrest for stock investment.
246+
247+
Provide the name of the company to begin with.
248+
"""
249+
)
250+
251+
query = st.text_input("Write the name of the company you are interested in.")
252+
253+
if query:
254+
paint_history()
255+
write_message(query, "human")
256+
if not st.session_state.get("thread"):
257+
thread = client.beta.threads.create(
258+
messages=[
259+
{
260+
"role": "user",
261+
"content": query
262+
}
263+
]
264+
)
265+
st.session_state["thread"] = [thread]
266+
else:
267+
thread = st.session_state["thread"][0]
268+
send_message(thread.id, query)
269+
run = client.beta.threads.runs.create(
270+
thread_id=thread.id,
271+
assistant_id=assistant_id,
272+
)
273+
with st.chat_message("ai"):
274+
with st.spinner("Creating an answer..."):
275+
while get_run(run.id, thread.id).status in [
276+
"queued",
277+
"in_progress",
278+
"requires_action",
279+
]:
280+
if get_run(run.id, thread.id).status == "requires_action":
281+
submit_tool_outputs(run.id, thread.id)
282+
message = get_messages(thread.id)[
283+
0].content[0].text.value.replace("$", "\$")
284+
save_message(message, "ai")
285+
st.markdown(message)
286+
287+
else:
288+
st.session_state["messages"] = []
289+
st.session_state["thread"] = []

0 commit comments

Comments
 (0)