-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
1120 lines (912 loc) · 39.1 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import asyncio
import datetime
import json
import os
import re
import argparse
import logging
import numpy as np
import requests
from asknews_sdk import AskNewsSDK
import typeguard
from litellm import acompletion
from litellm.files.main import ModelResponse
from litellm.types.utils import Choices
import litellm
from pydantic import BaseModel
import forecasting_tools
from forecasting_tools.ai_models.resource_managers.refreshing_bucket_rate_limiter import RefreshingBucketRateLimiter
# Add this after imports, before CONSTANTS section
logging.getLogger("LiteLLM").setLevel(logging.ERROR)
######################### CONSTANTS #########################
# Constants
SUBMIT_PREDICTION = True  # set to True to publish your predictions to Metaculus
USE_EXAMPLE_QUESTIONS = False  # set to True to forecast example questions rather than the tournament questions
NUM_RUNS_PER_QUESTION = 5  # The median forecast is taken between NUM_RUNS_PER_QUESTION runs
SKIP_PREVIOUSLY_FORECASTED_QUESTIONS = True  # skip questions this bot has already submitted a forecast on
GET_NEWS = True  # set to True to enable AskNews after entering ASKNEWS secrets
# litellm model id; presumably assigned by the entry point (e.g. from CLI args)
# before any forecasting runs — call_llm asserts it is not None. TODO confirm.
LLM_MODEL_NAME: str | None = None
CALL_VERY_SLOWLY = False  # when True, LLM calls are throttled through rate_limiter
# Environment variables (the `or None` normalizes empty strings to None)
METACULUS_TOKEN = os.getenv("METACULUS_TOKEN") or None
PERPLEXITY_API_KEY = os.getenv("PERPLEXITY_API_KEY") or None
ASKNEWS_CLIENT_ID = os.getenv("ASKNEWS_CLIENT_ID") or None
ASKNEWS_SECRET = os.getenv("ASKNEWS_SECRET") or None
EXA_API_KEY = os.getenv("EXA_API_KEY") or None
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") or None  # You'll need the OpenAI API Key if you want to use the Exa Smart Searcher
# The tournament IDs below can be used for testing your bot.
Q4_2024_AI_BENCHMARKING_ID = 32506
Q1_2025_AI_BENCHMARKING_ID = 32627
Q4_2024_QUARTERLY_CUP_ID = 3672
Q1_2025_QUARTERLY_CUP_ID = 32630
AXC_2025_TOURNAMENT_ID = 32564
GIVEWELL_ID = 3600
RESPIRATORY_OUTLOOK_ID = 3411
TOURNAMENT_ID = Q1_2025_AI_BENCHMARKING_ID  # the tournament this bot actually forecasts on
# The example questions can be used for testing your bot. (note that question and post id are not always the same)
EXAMPLE_QUESTIONS = [  # (question_id, post_id)
    (578, 578),  # Human Extinction - Binary - https://www.metaculus.com/questions/578/human-extinction-by-2100/
    (14333, 14333),  # Age of Oldest Human - Numeric - https://www.metaculus.com/questions/14333/age-of-oldest-human-as-of-2100/
    (22427, 22427),  # Number of New Leading AI Labs - Multiple Choice - https://www.metaculus.com/questions/22427/number-of-new-leading-ai-labs/
]
######################### HELPER FUNCTIONS #########################
# @title Helper functions
# Shared keyword arguments carrying the auth header for every Metaculus API call.
AUTH_HEADERS = {"headers": {"Authorization": f"Token {METACULUS_TOKEN}"}}
API_BASE_URL = "https://www.metaculus.com/api"
def post_question_comment(post_id: int, comment_text: str) -> None:
    """
    Post a comment on the question page as the bot user.

    The comment is private and has the bot's current forecast attached.
    Raises RuntimeError on any non-2xx API response.
    """
    comment_payload = {
        "text": comment_text,
        "parent": None,
        "included_forecast": True,
        "is_private": True,
        "on_post": post_id,
    }
    resp = requests.post(
        f"{API_BASE_URL}/comments/create/",
        json=comment_payload,
        **AUTH_HEADERS,  # type: ignore
    )
    if not resp.ok:
        raise RuntimeError(resp.text)
def post_question_prediction(question_id: int, forecast_payload: dict) -> None:
    """
    Post a forecast on a question.

    The API accepts a list of forecasts; we submit a single-element list.
    Raises RuntimeError on any non-2xx API response.
    """
    body = [{"question": question_id, **forecast_payload}]
    resp = requests.post(
        f"{API_BASE_URL}/questions/forecast/",
        json=body,
        **AUTH_HEADERS,  # type: ignore
    )
    print(f"Response: {resp.status_code}")
    if not resp.ok:
        raise RuntimeError(resp.text)
def create_forecast_payload(
forecast: float | dict[str, float] | list[float],
question_type: str,
) -> dict:
"""
Accepts a forecast and generates the api payload in the correct format.
If the question is binary, forecast must be a float.
If the question is multiple choice, forecast must be a dictionary that
maps question.options labels to floats.
If the question is numeric, forecast must be a dictionary that maps
quartiles or percentiles to datetimes, or a 201 value cdf.
"""
if question_type == "binary":
return {
"probability_yes": forecast,
"probability_yes_per_category": None,
"continuous_cdf": None,
}
if question_type == "multiple_choice":
return {
"probability_yes": None,
"probability_yes_per_category": forecast,
"continuous_cdf": None,
}
# numeric or date
return {
"probability_yes": None,
"probability_yes_per_category": None,
"continuous_cdf": forecast,
}
def list_posts_from_tournament(
    tournament_id: int, offset: int = 0, count: int = 50
) -> dict:
    """
    List (all details) {count} posts from the {tournament_id}.

    Returns the raw paginated API response dict; the posts themselves are
    under its "results" key.  (The previous ``-> list[dict]`` annotation was
    wrong: the endpoint returns a pagination object, and the caller indexes
    the result with ``["results"]``.)

    Raises:
        Exception: if the API responds with a non-2xx status.
    """
    url_qparams = {
        "limit": count,
        "offset": offset,
        "order_by": "-hotness",
        # Restrict to the question types this bot knows how to forecast.
        "forecast_type": ",".join(
            [
                "binary",
                "multiple_choice",
                "numeric",
            ]
        ),
        "tournaments": [tournament_id],
        "statuses": "open",
        "include_description": "true",
    }
    url = f"{API_BASE_URL}/posts/"
    response = requests.get(url, **AUTH_HEADERS, params=url_qparams)  # type: ignore
    if not response.ok:
        raise Exception(response.text)
    return json.loads(response.content)
def get_open_question_ids_from_tournament(tournament_id: int) -> list[tuple[int, int]]:
    """Return (question_id, post_id) pairs for every open question in the tournament."""
    posts = list_posts_from_tournament(tournament_id)
    questions_by_post: dict[int, list[dict]] = {}
    for post in posts["results"]:  # type: ignore
        question = post.get("question")
        if question:
            # single question post
            questions_by_post[post["id"]] = [question]
    open_question_id_post_id: list[tuple[int, int]] = []  # [(question_id, post_id)]
    for post_id, questions in questions_by_post.items():
        for question in questions:
            if question.get("status") != "open":
                continue
            print(
                f"ID: {question['id']}\nQ: {question['title']}\nCloses: "
                f"{question['scheduled_close_time']}"
            )
            open_question_id_post_id.append((question["id"], post_id))
    return open_question_id_post_id
def get_post_details(post_id: int) -> dict:
    """
    Get all details about a post from the Metaculus API.

    Raises Exception on any non-2xx API response.
    """
    url = f"{API_BASE_URL}/posts/{post_id}/"
    print(f"Getting details for {url}")
    resp = requests.get(url, **AUTH_HEADERS)  # type: ignore
    if not resp.ok:
        raise Exception(resp.text)
    return json.loads(resp.content)
# Limits how many LLM requests run concurrently.  None at import time;
# presumably created inside the event loop by the entry point before any
# forecasting happens — call_llm asserts it has been set.  TODO confirm
# where it is initialized (not visible in this chunk).
llm_concurrency_semaphore: asyncio.Semaphore | None = None
# Token bucket consulted only when CALL_VERY_SLOWLY is True: capacity of 1,
# refreshed at 0.05 tokens/second (i.e. roughly one LLM call every 20s).
rate_limiter = RefreshingBucketRateLimiter(
    capacity=1,
    refresh_rate=0.05,
)
async def call_llm(prompt: str, temperature: float = 0.3) -> str:
    """
    Send *prompt* to the configured model and return the text of its reply.

    Requires LLM_MODEL_NAME and llm_concurrency_semaphore to have been set
    at module level.  Honors the global rate limiter when CALL_VERY_SLOWLY
    is enabled, and holds the concurrency semaphore for the duration of the
    request.
    """
    assert LLM_MODEL_NAME is not None
    assert llm_concurrency_semaphore is not None
    # Silently drop params unsupported by the chosen provider.
    litellm.drop_params = True
    # This experimental Gemini model needs the prompt hard-truncated.
    if LLM_MODEL_NAME == "gemini/gemini-exp-1206":
        prompt = prompt[:100000]
    if CALL_VERY_SLOWLY:
        await rate_limiter.wait_till_able_to_acquire_resources(1)
    async with llm_concurrency_semaphore:
        completion = await acompletion(
            model=LLM_MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            stream=False,
        )
        assert isinstance(completion, ModelResponse)
        checked_choices = typeguard.check_type(completion.choices, list[Choices])
        answer = checked_choices[0].message.content
        assert isinstance(answer, str)
        return answer
def run_research(question: str) -> str:
    """
    Gather news context for *question* from the first configured provider.

    Provider priority: AskNews, then Exa, then Perplexity.  Returns the
    research text, or "No research done" when GET_NEWS is disabled.

    Raises:
        ValueError: if news is enabled but no provider credentials are set.
    """
    # Guard clause replaces the non-idiomatic `if GET_NEWS == True` chain.
    if not GET_NEWS:
        return "No research done"
    if ASKNEWS_CLIENT_ID and ASKNEWS_SECRET:
        return call_asknews(question)
    if EXA_API_KEY:
        return call_exa_smart_searcher(question)
    if PERPLEXITY_API_KEY:
        return call_perplexity(question)
    raise ValueError("No API key provided")
def call_perplexity(question: str) -> str:
    """
    Ask Perplexity's sonar-pro model for a news rundown relevant to *question*.

    Raises Exception on any non-2xx API response.
    """
    endpoint = "https://api.perplexity.ai/chat/completions"
    request_headers = {
        "accept": "application/json",
        "authorization": f"Bearer {PERPLEXITY_API_KEY}",
        "content-type": "application/json",
    }
    # System prompt steering the assistant toward research, not forecasting.
    system_message = {
        "role": "system",
        "content": """
You are an assistant to a superforecaster.
The superforecaster will give you a question they intend to forecast on.
To be a great assistant, you generate a concise but detailed rundown of the most relevant news, including if the question would resolve Yes or No based on current information.
You do not produce forecasts yourself.
""",
    }
    # The actual prompt we ask the perplexity assistant to answer.
    user_message = {"role": "user", "content": question}
    body = {
        "model": "sonar-pro",
        "messages": [system_message, user_message],
    }
    response = requests.post(url=endpoint, json=body, headers=request_headers)
    if not response.ok:
        raise Exception(response.text)
    return response.json()["choices"][0]["message"]["content"]
def call_exa_smart_searcher(question: str) -> str:
    """
    Research *question* with Exa.

    Without an OpenAI key, returns the top raw Exa highlights; with one,
    runs forecasting_tools.SmartSearcher to produce a synthesized rundown.
    """
    if OPENAI_API_KEY is None:
        searcher = forecasting_tools.ExaSearcher(
            include_highlights=True,
            num_results=10,
        )
        highlights = asyncio.run(
            searcher.invoke_for_highlights_in_relevance_order(question)
        )
        combined_highlights = ""
        for i, highlight in enumerate(highlights[:10]):
            combined_highlights += (
                f'[Highlight {i+1}]:\nTitle: {highlight.source.title}\n'
                f'URL: {highlight.source.url}\nText: "{highlight.highlight_text}"\n\n'
            )
        return combined_highlights
    searcher = forecasting_tools.SmartSearcher(
        temperature=0,
        num_searches_to_run=2,
        num_sites_per_search=10,
    )
    # BUG FIX: the original implicit string concatenation was missing spaces
    # at the join points, producing e.g. "will giveyou a question" and
    # "you generatea concise" in the live prompt.
    prompt = (
        "You are an assistant to a superforecaster. The superforecaster will give "
        "you a question they intend to forecast on. To be a great assistant, you generate "
        "a concise but detailed rundown of the most relevant news, including if the question "
        "would resolve Yes or No based on current information. You do not produce forecasts yourself."
        f"\n\nThe question is: {question}"
    )
    return asyncio.run(searcher.invoke(prompt))
def _format_asknews_articles(articles) -> str:
    """Render AskNews article objects as markdown, newest first."""
    article_dicts = [article.__dict__ for article in articles]
    article_dicts = sorted(article_dicts, key=lambda x: x["pub_date"], reverse=True)
    formatted = ""
    for article in article_dicts:
        pub_date = article["pub_date"].strftime("%B %d, %Y %I:%M %p")
        formatted += f"**{article['eng_title']}**\n{article['summary']}\nOriginal language: {article['language']}\nPublish date: {pub_date}\nSource:[{article['source_id']}]({article['article_url']})\n\n"
    return formatted


def call_asknews(question: str) -> str:
    """
    Use the AskNews `news` endpoint to get news context for your query.
    The full API reference can be found here: https://docs.asknews.app/en/reference#get-/v1/news/search

    Combines the latest news (past 48 hours) with the historical archive,
    each sorted newest-first.  (The duplicated formatting loop and the
    redundant double return of the original were factored into
    _format_asknews_articles.)
    """
    ask = AskNewsSDK(
        client_id=ASKNEWS_CLIENT_ID, client_secret=ASKNEWS_SECRET, scopes=set(["news"])
    )
    # get the latest news related to the query (within the past 48 hours)
    hot_response = ask.news.search_news(
        query=question,  # your natural language query
        n_articles=6,  # control the number of articles to include in the context, originally 5
        return_type="both",
        strategy="latest news",  # enforces looking at the latest news only
    )
    # get context from the "historical" database that contains a news archive going back to 2023
    historical_response = ask.news.search_news(
        query=question,
        n_articles=10,
        return_type="both",
        strategy="news knowledge",  # looks for relevant news within the past 60 days
    )
    hot_articles = hot_response.as_dicts
    historical_articles = historical_response.as_dicts
    formatted_articles = "Here are the relevant news articles:\n\n"
    if hot_articles:
        formatted_articles += _format_asknews_articles(hot_articles)
    if historical_articles:
        formatted_articles += _format_asknews_articles(historical_articles)
    if not hot_articles and not historical_articles:
        formatted_articles += "No articles were found.\n\n"
    return formatted_articles
class ReasonedPrediction(BaseModel):
    """One LLM run's parsed forecast plus the raw rationale it came from."""

    # float (probability of Yes) for binary questions, dict mapping option
    # label -> probability for multiple choice, or a 201-value CDF list for
    # numeric questions.
    forecast: float | dict[str, float] | list[float]
    # The full text of the LLM response the forecast was extracted from.
    rationale: str
class AggregatePrediction(BaseModel):
    """Aggregate of several ReasonedPrediction runs for a single question."""

    # Median (binary/numeric) or per-option mean (multiple choice) of the runs.
    forecast: float | dict[str, float] | list[float]
    # The individual runs the aggregate was computed from.
    sub_predictions: list[ReasonedPrediction]
    # The research summary that was embedded in every prompt.
    news: str
############### BINARY ###############
# @title Binary prompt & functions
# This section includes functionality for binary questions.
# Prompt for binary questions.  Placeholders are filled by
# get_binary_gpt_prediction; the final "Probability: ZZ%" line it requests
# is parsed by extract_probability_from_response_as_percentage_not_decimal.
BINARY_PROMPT_TEMPLATE = """
You are a professional forecaster interviewing for a job.
Your interview question is:
{title}
Question background:
{background}
This question's outcome will be determined by the specific criteria below. These criteria have not yet been satisfied:
{resolution_criteria}
{fine_print}
Your research assistant says:
{summary_report}
Today is {today}.
Before answering you write:
(a) The time left until the outcome to the question is known.
(b) The status quo outcome if nothing changed.
(c) A brief description of a scenario that results in a No outcome.
(d) A brief description of a scenario that results in a Yes outcome.
You write your rationale remembering that good forecasters put extra weight on the status quo outcome since the world changes slowly most of the time.
The last thing you write is your final answer as: "Probability: ZZ%", 0-100
"""
def extract_probability_from_response_as_percentage_not_decimal(
    forecast_text: str,
) -> float:
    """
    Pull the final "ZZ%" probability (as a percentage, not a decimal) out of
    an LLM response.

    The last number followed by '%' is used, matching the prompt's
    instruction to put the answer last.  The value is clamped to [1, 99] so
    the forecast is never fully confident.  (The original returned an int
    despite the ``-> float`` annotation; it now returns a float.)

    Raises:
        ValueError: if no percentage appears in the text.
    """
    matches = re.findall(r"(\d+)%", forecast_text)
    if not matches:
        raise ValueError(f"Could not extract prediction from response: {forecast_text}")
    # Return the last number found before a '%'
    number = int(matches[-1])
    return float(min(99, max(1, number)))  # clamp the number between 1 and 99
async def get_binary_gpt_prediction(
    question_details: dict, num_runs: int
) -> AggregatePrediction:
    """Run *num_runs* binary forecasts concurrently and aggregate by median."""
    today = datetime.datetime.now().strftime("%Y-%m-%d")
    summary_report = run_research(question_details["title"])
    prompt = BINARY_PROMPT_TEMPLATE.format(
        title=question_details["title"],
        today=today,
        background=question_details["description"],
        resolution_criteria=question_details["resolution_criteria"],
        fine_print=question_details["fine_print"],
        summary_report=summary_report,
    )

    async def _single_run(text: str) -> ReasonedPrediction:
        # One LLM call -> one probability in [0.01, 0.99].
        rationale = await call_llm(text)
        percent = extract_probability_from_response_as_percentage_not_decimal(rationale)
        return ReasonedPrediction(forecast=percent / 100, rationale=rationale)

    sub_predictions = await asyncio.gather(
        *(_single_run(prompt) for _ in range(num_runs))
    )
    probabilities = typeguard.check_type(
        [p.forecast for p in sub_predictions], list[float]
    )
    print(f"Generated {len(sub_predictions)} sub-predictions")
    return AggregatePrediction(
        forecast=float(np.median(probabilities)),
        sub_predictions=sub_predictions,
        news=summary_report,
    )
####################### NUMERIC ###############
# @title Numeric prompt & functions
# Prompt for numeric questions.  Placeholders are filled by
# get_numeric_gpt_prediction; the percentile lines it requests are parsed by
# extract_percentiles_from_response.  NOTE: the "unkowns" typo below is part
# of the live prompt string and is deliberately left untouched.
NUMERIC_PROMPT_TEMPLATE = """
You are a professional forecaster interviewing for a job.
Your interview question is:
{title}
Background:
{background}
{resolution_criteria}
{fine_print}
Your research assistant says:
{summary_report}
Today is {today}.
{lower_bound_message}
{upper_bound_message}
Formatting Instructions:
- Please notice the units requested (e.g. whether you represent a number as 1,000,000 or 1m).
- Never use scientific notation.
- Always start with a smaller number (more negative if negative) and then increase from there
Before answering you write:
(a) The time left until the outcome to the question is known.
(b) The outcome if nothing changed.
(c) The outcome if the current trend continued.
(d) The expectations of experts and markets.
(e) A brief description of an unexpected scenario that results in a low outcome.
(f) A brief description of an unexpected scenario that results in a high outcome.
You remind yourself that good forecasters are humble and set wide 90/10 confidence intervals to account for unknown unkowns.
The last thing you write is your final answer as:
"
Percentile 10: XX
Percentile 20: XX
Percentile 40: XX
Percentile 60: XX
Percentile 80: XX
Percentile 90: XX
"
"""
def extract_percentiles_from_response(forecast_text: str) -> dict:
    """
    Parse "Percentile NN: value" lines out of an LLM response.

    Returns a dict mapping percentile number (first number on the line) to
    the declared value (last number on the line, negated when a '-' appears
    after the final ':').  Raises ValueError when no percentile lines with
    at least two numbers are found.
    """
    # Helper function that returns a list of tuples with numbers for all lines with Percentile
    def extract_percentile_numbers(text: str) -> dict:
        # Any line mentioning "Percentile"/"percentile" is a candidate.
        pattern = r"^.*(?:P|p)ercentile.*$"
        # Two alternatives: a number preceded by a '-' (with optional junk in
        # between, e.g. "- $1,000"), or a bare number.  Both allow thousands
        # separators and decimals; exactly one capture group matches per hit.
        number_pattern = r"-\s*(?:[^\d\-]*\s*)?(\d+(?:,\d{3})*(?:\.\d+)?)|(\d+(?:,\d{3})*(?:\.\d+)?)"
        results = []
        for line in text.split("\n"):
            if re.match(pattern, line):
                numbers = re.findall(number_pattern, line)
                # Each findall match is a 2-tuple; keep whichever group hit.
                numbers_no_commas = [
                    next(num for num in match if num).replace(",", "")
                    for match in numbers
                ]
                numbers = [
                    float(num) if "." in num else int(num)
                    for num in numbers_no_commas
                ]
                # Need at least the percentile itself plus a value.
                if len(numbers) > 1:
                    first_number = numbers[0]
                    last_number = numbers[-1]
                    # Check if the original line had a negative sign before the last number
                    if "-" in line.split(":")[-1]:
                        last_number = -abs(last_number)
                    results.append((first_number, last_number))
        # Convert results to dictionary
        percentile_values = {}
        for first_num, second_num in results:
            key = first_num
            percentile_values[key] = second_num
        return percentile_values
    percentile_values = extract_percentile_numbers(forecast_text)
    if len(percentile_values) > 0:
        return percentile_values
    else:
        raise ValueError(f"Could not extract prediction from response: {forecast_text}")
def generate_continuous_cdf(
    percentile_values: dict,
    question_type: str,
    open_upper_bound: bool,
    open_lower_bound: bool,
    upper_bound: float,
    lower_bound: float,
    zero_point: float | None,
) -> list[float]:
    """
    Returns: list[float]: A list of 201 float values representing the CDF.

    Builds a Metaculus-format CDF from sparse {percentile: value} pairs:
    clamps declared values inside closed bounds, pins the tails depending
    on whether each bound is open, then linearly interpolates over 201
    x-axis locations (log-scaled when zero_point is set).

    NOTE: mutates the percentile_values dict passed in by the caller, and
    the question_type parameter is currently unused.
    """
    percentile_max = max(float(key) for key in percentile_values.keys())
    percentile_min = min(float(key) for key in percentile_values.keys())
    range_min = lower_bound
    range_max = upper_bound
    range_size = range_max - range_min
    # Small margin keeping values strictly inside closed bounds.
    buffer = 1 if range_size > 100 else 0.01 * range_size
    # Adjust any values that are exactly at the bounds
    for percentile, value in list(percentile_values.items()):
        if not open_lower_bound and value <= range_min + buffer:
            percentile_values[percentile] = range_min + buffer
        if not open_upper_bound and value >= range_max - buffer:
            percentile_values[percentile] = range_max - buffer
    # Set cdf values outside range
    if open_upper_bound:
        if range_max > percentile_values[percentile_max]:
            # Place the range max halfway between the top declared
            # percentile and 100, leaving mass above the range.
            percentile_values[int(100 - (0.5 * (100 - percentile_max)))] = range_max
    else:
        # Closed upper bound: CDF must reach 1 exactly at range_max.
        percentile_values[100] = range_max
    # Set cdf values outside range
    if open_lower_bound:
        if range_min < percentile_values[percentile_min]:
            # Mirror of the upper-tail handling for the lower tail.
            percentile_values[int(0.5 * percentile_min)] = range_min
    else:
        # Closed lower bound: CDF starts at 0 exactly at range_min.
        percentile_values[0] = range_min
    sorted_percentile_values = dict(sorted(percentile_values.items()))
    # Normalize percentile keys
    normalized_percentile_values = {}
    for key, value in sorted_percentile_values.items():
        percentile = float(key) / 100
        normalized_percentile_values[percentile] = value
    # Invert to value -> percentile for interpolation on the value axis.
    # NOTE(review): duplicate values collapse to a single entry here.
    value_percentiles = {
        value: key for key, value in normalized_percentile_values.items()
    }
    # function for log scaled questions
    def generate_cdf_locations(range_min, range_max, zero_point):
        # Linear spacing when zero_point is None, otherwise the log-like
        # spacing Metaculus uses for log-scaled questions.
        if zero_point is None:
            scale = lambda x: range_min + (range_max - range_min) * x
        else:
            deriv_ratio = (range_max - zero_point) / (range_min - zero_point)
            scale = lambda x: range_min + (range_max - range_min) * (
                deriv_ratio**x - 1
            ) / (deriv_ratio - 1)
        xaxis = [scale(x) for x in np.linspace(0, 1, 201)]
        return xaxis
    cdf_xaxis = generate_cdf_locations(range_min, range_max, zero_point)
    def linear_interpolation(x_values, xy_pairs):
        # Piecewise-linear interpolation with flat extrapolation beyond the
        # first/last known point.
        # Sort the xy_pairs by x-values
        sorted_pairs = sorted(xy_pairs.items())
        # Extract sorted x and y values
        known_x = [pair[0] for pair in sorted_pairs]
        known_y = [pair[1] for pair in sorted_pairs]
        # Initialize the result list
        y_values = []
        for x in x_values:
            # Check if x is exactly in the known x values
            if x in known_x:
                y_values.append(known_y[known_x.index(x)])
            else:
                # Find the indices of the two nearest known x-values
                i = 0
                while i < len(known_x) and known_x[i] < x:
                    i += 1
                # If x is outside the range of known x-values, use the nearest endpoint
                if i == 0:
                    y_values.append(known_y[0])
                elif i == len(known_x):
                    y_values.append(known_y[-1])
                else:
                    # Perform linear interpolation
                    x0, x1 = known_x[i - 1], known_x[i]
                    y0, y1 = known_y[i - 1], known_y[i]
                    # Linear interpolation formula
                    y = y0 + (x - x0) * (y1 - y0) / (x1 - x0)
                    y_values.append(y)
        return y_values
    continuous_cdf = linear_interpolation(cdf_xaxis, value_percentiles)
    return continuous_cdf
async def get_numeric_gpt_prediction(
    question_details: dict, num_runs: int
) -> AggregatePrediction:
    """Run *num_runs* numeric forecasts and aggregate the CDFs elementwise by median."""
    today = datetime.datetime.now().strftime("%Y-%m-%d")
    question_type = question_details["type"]
    scaling = question_details["scaling"]
    open_upper_bound = question_details["open_upper_bound"]
    open_lower_bound = question_details["open_lower_bound"]
    upper_bound = scaling["range_max"]
    lower_bound = scaling["range_min"]
    zero_point = scaling["zero_point"]
    # Only closed bounds are stated in the prompt; open bounds get no message.
    upper_bound_message = (
        "" if open_upper_bound else f"The outcome can not be higher than {upper_bound}."
    )
    lower_bound_message = (
        "" if open_lower_bound else f"The outcome can not be lower than {lower_bound}."
    )
    summary_report = run_research(question_details["title"])
    prompt = NUMERIC_PROMPT_TEMPLATE.format(
        title=question_details["title"],
        today=today,
        background=question_details["description"],
        resolution_criteria=question_details["resolution_criteria"],
        fine_print=question_details["fine_print"],
        summary_report=summary_report,
        lower_bound_message=lower_bound_message,
        upper_bound_message=upper_bound_message,
    )

    async def _single_cdf(text: str) -> ReasonedPrediction:
        # One LLM call -> one 201-point CDF.
        rationale = await call_llm(text)
        percentile_values = extract_percentiles_from_response(rationale)
        cdf = generate_continuous_cdf(
            percentile_values,
            question_type,
            open_upper_bound,
            open_lower_bound,
            upper_bound,
            lower_bound,
            zero_point,
        )
        return ReasonedPrediction(forecast=cdf, rationale=rationale)

    sub_predictions = await asyncio.gather(
        *(_single_cdf(prompt) for _ in range(num_runs))
    )
    stacked_cdfs = np.array([p.forecast for p in sub_predictions])
    return AggregatePrediction(
        forecast=np.median(stacked_cdfs, axis=0).tolist(),
        sub_predictions=sub_predictions,
        news=summary_report,
    )
########################## MULTIPLE CHOICE ###############
# @title Multiple Choice prompt & functions
# Prompt for multiple-choice questions.  Placeholders are filled by
# get_multiple_choice_gpt_prediction; the "Option: probability" lines it
# requests are parsed by extract_option_probabilities_from_response.
MULTIPLE_CHOICE_PROMPT_TEMPLATE = """
You are a professional forecaster interviewing for a job.
Your interview question is:
{title}
The options are: {options}
Background:
{background}
{resolution_criteria}
{fine_print}
Your research assistant says:
{summary_report}
Today is {today}.
Before answering you write:
(a) The time left until the outcome to the question is known.
(b) The status quo outcome if nothing changed.
(c) A description of an scenario that results in an unexpected outcome.
You write your rationale remembering that (1) good forecasters put extra weight on the status quo outcome since the world changes slowly most of the time, and (2) good forecasters leave some moderate probability on most options to account for unexpected outcomes.
The last thing you write is your final probabilities for the N options in this order {options} as:
Option_A: Probability_A
Option_B: Probability_B
...
Option_N: Probability_N
"""
def extract_option_probabilities_from_response(forecast_text: str, options: list[str]) -> list[float]:
    """
    Pull the final per-option probabilities out of an LLM response.

    Scans every line for numbers and keeps the last number on each line;
    the last len(options) of those are returned, in the option order the
    prompt requested.  Raises ValueError if the text contains no numbers.
    """
    # Matches integers or decimals, optional sign and thousands separators.
    number_pattern = r"-?\d+(?:,\d{3})*(?:\.\d+)?"
    last_number_per_line: list[float] = []
    for line in forecast_text.split("\n"):
        tokens = [match.replace(",", "") for match in re.findall(number_pattern, line)]
        parsed = [float(tok) if "." in tok else int(tok) for tok in tokens]
        if parsed:
            last_number_per_line.append(parsed[-1])
    if not last_number_per_line:
        raise ValueError(f"Could not extract prediction from response: {forecast_text}")
    # The answer lines come last, so keep the trailing len(options) numbers.
    return last_number_per_line[-len(options):]  # type: ignore
def generate_multiple_choice_forecast(options, option_probabilities) -> dict:
    """
    Returns: dict corresponding to the probabilities of each option.

    Probabilities are rescaled to sum to 1 (so percentage inputs work),
    clamped to [0.01, 0.99], renormalized, and the last option absorbs any
    residual floating-point error.  Raises ValueError on length mismatch.
    """
    # confirm that there is a probability for each option
    if len(options) != len(option_probabilities):
        raise ValueError(
            f"Number of options ({len(options)}) does not match number of probabilities ({len(option_probabilities)})"
        )
    # Rescale to decimals (handles inputs given as percentages).
    raw_total = sum(option_probabilities)
    decimals = [p / raw_total for p in option_probabilities]
    # Clamp so no option is ever 0 or 1, then renormalize.
    clamped = [min(max(p, 0.01), 0.99) for p in decimals]
    clamped_total = sum(clamped)
    normalized = [p / clamped_total for p in clamped]
    # Push any residual floating-point error into the final option.
    normalized[-1] += 1.0 - sum(normalized)
    return {option: prob for option, prob in zip(options, normalized)}
async def get_multiple_choice_gpt_prediction(
    question_details: dict,
    num_runs: int,
) -> AggregatePrediction:
    """Run *num_runs* multiple-choice forecasts and average the per-option probabilities."""
    today = datetime.datetime.now().strftime("%Y-%m-%d")
    options = question_details["options"]
    summary_report = run_research(question_details["title"])
    prompt = MULTIPLE_CHOICE_PROMPT_TEMPLATE.format(
        title=question_details["title"],
        today=today,
        background=question_details["description"],
        resolution_criteria=question_details["resolution_criteria"],
        fine_print=question_details["fine_print"],
        summary_report=summary_report,
        options=options,
    )

    async def _single_run(text: str) -> ReasonedPrediction:
        # One LLM call -> one normalized option->probability dict.
        rationale = await call_llm(text)
        raw_probabilities = extract_option_probabilities_from_response(rationale, options)
        return ReasonedPrediction(
            forecast=generate_multiple_choice_forecast(options, raw_probabilities),
            rationale=rationale,
        )

    sub_predictions = await asyncio.gather(
        *(_single_run(prompt) for _ in range(num_runs))
    )
    option_forecasts = typeguard.check_type(
        [prediction.forecast for prediction in sub_predictions],
        list[dict[str, float]],
    )
    average_probability_yes_per_category: dict[str, float] = {
        option: sum(forecast[option] for forecast in option_forecasts)
        / len(option_forecasts)
        for option in options
    }
    return AggregatePrediction(
        forecast=average_probability_yes_per_category,
        sub_predictions=sub_predictions,
        news=summary_report,
    )
################### FORECASTING ###################
def forecast_is_already_made(post_details: dict) -> bool:
    """
    Check if a forecast has already been made by looking at my_forecasts in the question data.

    question.my_forecasts.latest.forecast_values has the following values for each question type:
        Binary: [probability for no, probability for yes]
        Numeric: [cdf value 1, cdf value 2, ..., cdf value 201]
        Multiple Choice: [probability for option 1, probability for option 2, ...]
    """
    try:
        forecast_values = post_details["question"]["my_forecasts"]["latest"][
            "forecast_values"
        ]
    except (KeyError, TypeError):
        # Narrowed from a bare `except Exception`: a missing key or a None
        # intermediate both mean "no forecast yet"; anything else is a bug
        # that should surface.
        return False
    return forecast_values is not None
async def run_prediction_function(question_details: dict, num_runs_per_question: int) -> AggregatePrediction:
    """
    Dispatch to the forecaster matching this question's type and log the result.

    Raises:
        ValueError: for an unknown question type.
    """
    question_type = question_details["type"]
    if question_type == "binary":
        prediction = await get_binary_gpt_prediction(
            question_details, num_runs_per_question
        )
    elif question_type == "numeric":
        prediction = await get_numeric_gpt_prediction(
            question_details, num_runs_per_question
        )
    elif question_type == "multiple_choice":
        prediction = await get_multiple_choice_gpt_prediction(
            question_details, num_runs_per_question
        )
    else:
        raise ValueError(f"Unknown question type: {question_type}")
    # Plain string literal: the original used an f-string with no placeholders.
    print("----------------------------------------------\n")
    print(f"Question: {question_details['title']}")
    print(f"Forecast: {prediction.forecast}")
    for sub_prediction in prediction.sub_predictions:
        print(f"Sub-Prediction: {sub_prediction.forecast}")
        print(f"Rationale: {sub_prediction.rationale}")
    print(f"News: {prediction.news}")
    return prediction
async def forecast_individual_question(
question_id: int,
post_id: int,
submit_prediction: bool,
num_runs_per_question: int,
skip_previously_forecasted_questions: bool,
) -> str:
post_details = get_post_details(post_id)
question_details = post_details["question"]
title = question_details["title"]
question_type = question_details["type"]
summary_of_forecast = ""
summary_of_forecast += f"----------\nQuestion: {title}\n"
summary_of_forecast += f"URL: https://www.metaculus.com/questions/{post_id}/\n"
if question_type == "multiple_choice":
options = question_details["options"]
summary_of_forecast += f"options: {options}\n"
if (
forecast_is_already_made(post_details)
and skip_previously_forecasted_questions == True
):