1
1
__author__ = 'qiao'
2
2
3
3
'''
4
- teach LLMs to use NCBI API
4
+ GeneGPT: teach LLMs to use NCBI API
5
5
'''
6
6
7
7
import json
8
8
import openai
9
- openai .api_key = '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'
9
+ import config
10
+ openai .api_key = config .API_KEY
11
+
10
12
import os
11
- import urllib .parse
12
- import urllib .request
13
- import requests
14
13
import re
15
- import time
16
14
import sys
15
+ import time
16
+ import urllib .request
17
17
18
18
def call_api (url ):
19
19
time .sleep (1 )
@@ -107,21 +107,23 @@ def get_prompt_header(mask):
107
107
108
108
109
109
if __name__ == '__main__' :
110
- invalid_tasks = set (['Gene ontology' , 'Gene name extraction' , 'TF regulation' ])
110
+ # rough number of chars for truncating
111
+ # codex accepts 8k tokens ~ 18k chars
111
112
cut_length = 18000
112
113
114
+ # str_mask is a string of six 0/1 digits, each marking whether an in-context learning component is used
115
+ # six digits correspond to Dc. 1-2, Dm. 1-4
113
116
str_mask = sys .argv [1 ]
114
117
mask = [bool (int (x )) for x in str_mask ]
115
118
prompt = get_prompt_header (mask )
116
119
120
+ # results are saved in a directory named after the six-digit mask
117
121
if not os .path .isdir (str_mask ):
118
122
os .mkdir (str_mask )
119
123
120
124
# initialize
121
125
prev_call = time .time ()
122
- #qas = json.load(open('data/newbing_qa.json'))
123
- #qas = json.load(open('data/multihop_qa_into.json'))
124
- qas = json .load (open ('data/blast_qa.json' ))
126
+ qas = json .load (open ('data/geneturing.json' ))
125
127
126
128
for task , info in qas .items ():
127
129
if os .path .exists (os .path .join (str_mask , f'{ task } .json' )):
@@ -160,6 +162,9 @@ def get_prompt_header(mask):
160
162
}
161
163
162
164
delta = time .time () - prev_call
165
+
166
+ # codex has a rate limit of 20 requests / min
167
+ # as a workaround, wait until at least 3.1 s have passed since the previous call
163
168
if delta < 3.1 :
164
169
time .sleep (3.1 - delta )
165
170
@@ -176,13 +181,9 @@ def get_prompt_header(mask):
176
181
177
182
prompts .append ([q_prompt , text ])
178
183
179
- #url_regex = r'\[(https?*)\]'
180
184
url_regex = r'\[(https?://[^\[\]]+)\]'
181
185
matches = re .findall (url_regex , text )
182
186
if matches :
183
- #if text[-1:] == ']' and '[' in text and 'http' in text:
184
- #left = text.rindex('[')
185
- #url = text[left + 1: -1]
186
187
url = matches [0 ]
187
188
188
189
# wait till the BLAST is done on NCBI server
@@ -203,7 +204,7 @@ def get_prompt_header(mask):
203
204
output .append ([question , answer , text , prompts ])
204
205
break
205
206
206
- # prevent too many calls
207
+ # prevent infinite loops: give up after 10 calls
207
208
if num_calls >= 10 :
208
209
output .append ([question , answer , 'numError' , prompts ])
209
210
break
0 commit comments