Skip to content

Commit 51bd791

Browse files
author
Jin Qiao
committed
updated code
1 parent c93d7d3 commit 51bd791

File tree

3 files changed

+20
-16
lines changed

3 files changed

+20
-16
lines changed

config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
API_KEY = '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'

evaluate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,11 @@ def get_answer(answer, task):
5555

5656

5757
if __name__ == '__main__':
58+
# all GeneTuring tasks and the GeneHop 'Disease gene location' task are automatically evaluated
5859
qas = json.load(open('data/geneturing.json'))
5960
qas['Disease gene location'] = json.load(open('data/genehop.json'))['Disease gene location']
60-
61+
62+
# result dir path to evaluate
6163
folder = sys.argv[1]
6264

6365
for task in glob.glob(os.path.join(folder, '*')):

main.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
__author__ = 'qiao'
22

33
'''
4-
teach LLMs to use NCBI API
4+
GeneGPT: teach LLMs to use NCBI API
55
'''
66

77
import json
88
import openai
9-
openai.api_key = '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'
9+
import config
10+
openai.api_key = config.API_KEY
11+
1012
import os
11-
import urllib.parse
12-
import urllib.request
13-
import requests
1413
import re
15-
import time
1614
import sys
15+
import time
16+
import urllib.request
1717

1818
def call_api(url):
1919
time.sleep(1)
@@ -107,21 +107,23 @@ def get_prompt_header(mask):
107107

108108

109109
if __name__ == '__main__':
110-
invalid_tasks = set(['Gene ontology', 'Gene name extraction', 'TF regulation'])
110+
# rough number of chars for truncating
111+
# codex accepts 8k tokens ~ 18k chars
111112
cut_length = 18000
112113

114+
# str_mask is a string of six 0/1 digits, each marking whether an in-context learning component is used
115+
# six digits correspond to Dc. 1-2, Dm. 1-4
113116
str_mask = sys.argv[1]
114117
mask = [bool(int(x)) for x in str_mask]
115118
prompt = get_prompt_header(mask)
116119

120+
# results are saved in the dir of six digits
117121
if not os.path.isdir(str_mask):
118122
os.mkdir(str_mask)
119123

120124
# initialize
121125
prev_call = time.time()
122-
#qas = json.load(open('data/newbing_qa.json'))
123-
#qas = json.load(open('data/multihop_qa_into.json'))
124-
qas = json.load(open('data/blast_qa.json'))
126+
qas = json.load(open('data/geneturing.json'))
125127

126128
for task, info in qas.items():
127129
if os.path.exists(os.path.join(str_mask, f'{task}.json')):
@@ -160,6 +162,9 @@ def get_prompt_header(mask):
160162
}
161163

162164
delta = time.time() - prev_call
165+
166+
# codex has a rate limit of 20 requests / min
167+
# workaround: keep consecutive calls at least 3.1 s apart
163168
if delta < 3.1:
164169
time.sleep(3.1 - delta)
165170

@@ -176,13 +181,9 @@ def get_prompt_header(mask):
176181

177182
prompts.append([q_prompt, text])
178183

179-
#url_regex = r'\[(https?*)\]'
180184
url_regex = r'\[(https?://[^\[\]]+)\]'
181185
matches = re.findall(url_regex, text)
182186
if matches:
183-
#if text[-1:] == ']' and '[' in text and 'http' in text:
184-
#left = text.rindex('[')
185-
#url = text[left + 1: -1]
186187
url = matches[0]
187188

188189
# wait till the BLAST is done on NCBI server
@@ -203,7 +204,7 @@ def get_prompt_header(mask):
203204
output.append([question, answer, text, prompts])
204205
break
205206

206-
# prevent too many calls
207+
# prevent infinite loops
207208
if num_calls >= 10:
208209
output.append([question, answer, 'numError', prompts])
209210
break

0 commit comments

Comments
 (0)