
Commit 353ce3b

Add files via upload
1 parent 442ce36 commit 353ce3b

7 files changed (+1494, -2 lines changed)

README.md

Lines changed: 46 additions & 2 deletions
@@ -1,2 +1,46 @@
-# GrIPS
-This repository will contain the code for our paper: **GrIPS: Gradient-free, Edit-based Instruction Search for Prompting Large Language Models**. Complete code will be out shortly.
# GrIPS: Gradient-free, Edit-based Instruction Search for Prompting Large Language Models

* Authors: [Archiki Prasad](https://archiki.github.io), [Peter Hase](https://peterbhase.github.io/), [Xiang Zhou](https://owenzx.github.io/), and [Mohit Bansal](https://www.cs.unc.edu/~mbansal/) (UNC Chapel Hill)
* [Paper](https://archiki.github.io/GrIPS.html)
* **Note:** This is a preliminary version of our code. The complete code to run all experiments in the paper will be added shortly.

<img src="./assets/Main Pipeline.png" alt="teaser image" width="7500"/>

## Dependencies

This code is written using PyTorch and [HuggingFace's Transformers repo](https://github.com/huggingface/pytorch-transformers). Running GrIPS with GPT-2 models requires access to GPUs; the search is quite lightweight (no model training involved), so one GPU should suffice. Running GrIPS with InstructGPT or GPT-3 models instead requires an OpenAI API key. Please add your key to the `openai_key.txt` file.
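A minimal sketch of consuming that key file from Python (this assumes the pre-1.0 `openai` package interface; the loading code in `nat_inst_gpt3.py` may do this differently):

```python
import openai

# openai_key.txt should contain only your API key on a single line.
with open("openai_key.txt") as f:
    openai.api_key = f.read().strip()
```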
## Installation

The simplest way to run our code is to start with a fresh environment.

```
conda create -n GrIPS python=3.9
source activate GrIPS
pip install -r requirements.txt
```
## Running Search

* `run_search.py` contains the implementation of GrIPS.
* By default, we use the InstructGPT Babbage model. To use a different GPT-3 model from the API, change `model_name` in `nat_inst_gpt3.py`.
* To switch to GPT-2 models, import `nat_inst_gpt2.py` and use an appropriate model.
* `expanded_encodeinstructions.py` is a data loader that interfaces with the task datasets provided in Natural Instructions.
* Here is an example command to run GrIPS (with the default InstructGPT Babbage model); a sketch for sweeping it over several tasks and seeds follows the command:

```
python run_search.py --mode "Instruction Only" --task-idx 0 --train-seed 0 \
    --num-compose 1 --num-candidates 5 --num-iters 10 --patience 2 --write-preds \
    --meta-dir "logs/" --meta-name "babbage_all_edits_l_1_m_5_n_10@seed_0.txt"
```
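To sweep the search over several tasks and seeds, one option is a small driver script along the lines of the sketch below (illustrative only; the flag values simply mirror the example command above, and the ranges are placeholders):

```python
# Illustrative sketch: launch run_search.py over several task indices and train seeds.
import subprocess

for task_idx in range(3):        # indices into the Natural Instructions task list
    for seed in range(2):        # train seeds
        meta_name = f"babbage_all_edits_l_1_m_5_n_10_task_{task_idx}@seed_{seed}.txt"
        subprocess.run([
            "python", "run_search.py",
            "--mode", "Instruction Only",
            "--task-idx", str(task_idx),
            "--train-seed", str(seed),
            "--num-compose", "1", "--num-candidates", "5", "--num-iters", "10",
            "--patience", "2", "--write-preds",
            "--meta-dir", "logs/",
            "--meta-name", meta_name,
        ], check=True)
```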
## Acknowledgments

We thank the authors and contributors of [Calibrate Before Use](https://github.com/tonyzhaozh/few-shot-learning) and [Natural-Instructions](https://github.com/allenai/natural-instructions) for their public code releases.

## Reference

Please cite our paper if you use our code in your work:

```bibtex
@article{Prasad2022GrIPS,
  title         = {GrIPS: Gradient-free, Edit-based Instruction Search for Prompting Large Language Models},
  author        = {Archiki Prasad and Peter Hase and Xiang Zhou and Mohit Bansal},
  year          = {2022},
  archivePrefix = {arXiv},
  primaryClass  = {cs.CL},
  eprint        = {2202.xxxx}
}
```

expanded_encodeinstruction.py

Lines changed: 299 additions & 0 deletions
@@ -0,0 +1,299 @@
import json
import os
import random
import math
import pdb

from transformers import GPT2Tokenizer


def lowercase_list(lst):
    return [l.lower() for l in lst]


tokenizer = GPT2Tokenizer.from_pretrained('gpt2')


def one_token(label):
    # Truncate a label string to its first GPT-2 token (decoded back to text),
    # so gold answers are always single-token strings.
    return tokenizer.decode(tokenizer.encode(label, return_tensors='pt')[0][0])
def encodeinstruction(task,
                      instruction_structure=['Definition', 'Prompt', 'Things to Avoid', 'Emphasis & Caution',
                                             'Negative Examples Full Explanations', 'Positive Examples Full Explanations'],
                      number_of_examples=0, number_of_instances=100, null_word=None, seed=0, modified={}):
    random.seed(0)
    with open('data/ExpandedNaturalInstructions/' + task) as json_file:
        data = json.load(json_file)
    labels = list(set([data["Instances"][i]["output"][0] for i in range(len(data["Instances"]))]))
    labels.sort()

    assert len(labels) < 25, "Check {} is a classification task.".format(task)
    # Build a label-balanced evaluation pool of number_of_instances instances.
    instances_per_label = number_of_instances // len(labels)
    remainder = number_of_instances % len(labels)
    instance_pools = {label: {'indices': []} for label in labels}
    for i, inst in enumerate(data["Instances"]):
        label = inst['output'][0]
        instance_pools[label]['indices'].append(i)
    remaining = 0
    test_pools = {}

    for l, label in enumerate(labels):
        if len(instance_pools[label]['indices']) >= 4 + instances_per_label:  # leave out some examples for Definition + Examples (hard-coded)
            num = instances_per_label
            if l < remainder: num += 1

            test_pools[label] = random.sample(instance_pools[label]['indices'], num)
            instance_pools[label]['indices'] = [i for i in instance_pools[label]['indices'] if i not in test_pools[label]]
        else:
            num = len(instance_pools[label]['indices']) - 4
            remaining += instances_per_label - num

            test_pools[label] = random.sample(instance_pools[label]['indices'], num)
            instance_pools[label]['indices'] = [i for i in instance_pools[label]['indices'] if i not in test_pools[label]]

    # Top up under-represented labels with whatever instances remain.
    all_remaining_indices = []
    remaining = number_of_instances - sum([len(t) for t in test_pools.values()])
    for label in labels: all_remaining_indices.extend(instance_pools[label]['indices'])
    remaining_test = random.sample(all_remaining_indices, remaining)

    for t in remaining_test:
        label = data['Instances'][t]['output'][0]
        test_pools[label].append(t)
        instance_pools[label]['indices'].remove(t)

    indexlist = []
    for label in labels: indexlist.extend(test_pools[label])
    assert len(indexlist) == number_of_instances, pdb.set_trace()  # drops into pdb only if the assertion fails
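    # At this point, test_pools maps each label to a list of instance indices, balanced across
    # labels as far as the data allows, and indexlist is their concatenation
    # (number_of_instances indices in total) used as the evaluation set below.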
    random.seed(seed)
    # Select in-context demonstrations: number_of_examples per label, or a single
    # example from one randomly chosen label when number_of_examples == -1.
    if number_of_examples == -1: total_num_examples = 1
    else: total_num_examples = number_of_examples * len(labels)
    pos_examples = {label: [] for label in labels}
    for eg in data["Positive Examples"]:
        label = eg['output']
        try: pos_examples[label].append(eg)
        except: pdb.set_trace()
    for label in labels:
        for id in instance_pools[label]['indices']:
            inst = data["Instances"][id]
            inst['output'] = inst['output'][0]  # flatten the output list to a plain string for non-test instances
            pos_examples[label].append(inst)

    chosen_examples = []
    if number_of_examples > 0:
        for label in labels: chosen_examples.extend(random.sample(pos_examples[label], number_of_examples))
    elif number_of_examples == -1:
        label = random.sample(labels, 1)[0]
        chosen_examples.extend(random.sample(pos_examples[label], 1))
    assert len(chosen_examples) == total_num_examples
    random.shuffle(chosen_examples)
    # Assemble the instruction text from the requested fields, with any edits supplied
    # through `modified` overriding the original field text.
    generic_instruction = ''
    for i in instruction_structure:
        if i != 'Positive Examples Full Only' and i != 'Positive Examples Full Explanations' and i != 'Negative Examples Full Explanations':
            if data[i] != '-':
                if i in modified.keys():
                    data[i] = modified[i]
                data[i] = data[i].replace('\n' + 'Things to avoid: -', '')
                data[i] = data[i].replace('\n' + 'Emphasis & Caution: -', '')
                if generic_instruction == '':
                    generic_instruction = generic_instruction + i + ': ' + data[i].strip()
                else:
                    generic_instruction = generic_instruction + "\n" + i + ': ' + data[i].strip()
        elif i == 'Positive Examples Full Only':
            for j in range(total_num_examples):
                if 'examples' in modified.keys():
                    if generic_instruction != '':
                        generic_instruction = generic_instruction + "\n" + 'input: ' + modified['examples'][j]['input'] + "\n" + 'output: ' + one_token(modified['examples'][j]['output'])
                    else:
                        generic_instruction = generic_instruction + 'input: ' + modified['examples'][j]['input'] + "\n" + 'output: ' + one_token(modified['examples'][j]['output'])
                else:
                    if generic_instruction != '':
                        generic_instruction = generic_instruction + "\n" + 'input: ' + chosen_examples[j]['input'] + "\n" + 'output: ' + one_token(chosen_examples[j]['output'])
                    else:
                        generic_instruction = generic_instruction + 'input: ' + chosen_examples[j]['input'] + "\n" + 'output: ' + one_token(chosen_examples[j]['output'])
        elif i == 'Positive Examples Full Explanations':  # this mode of Natural Instructions is not supported
            assert False
        elif i == 'Negative Examples Full Explanations':  # this mode of Natural Instructions is not supported
            assert False
    # Build one prompt/answer pair per evaluation instance.
    promptlist = []
    answerlist = []

    for i in range(number_of_instances):
        if null_word is None:
            if 'input' in modified.keys():
                if generic_instruction != '': prompt = generic_instruction + "\n" + 'input: ' + data['Instances'][indexlist[i]]['input'] + " " + modified['input'] + "\n" + "output:"
                else: prompt = 'input: ' + data['Instances'][indexlist[i]]['input'] + "\n" + "output:"
            else:
                if generic_instruction != '': prompt = generic_instruction + "\n" + 'input: ' + data['Instances'][indexlist[i]]['input'] + "\n" + "output:"
                else: prompt = 'input: ' + data['Instances'][indexlist[i]]['input'] + "\n" + "output:"
        else:
            if generic_instruction != '': prompt = generic_instruction + "\n" + 'input: ' + null_word + "\n" + "output:"
            else: prompt = 'input: ' + null_word + "\n" + "output:"
        if 'Completion' in labels[0]:
            # For multiple-choice style tasks whose labels look like "Completion A", the shared
            # "Completion" prefix goes into the prompt and is stripped from the answers.
            prompt = prompt + ' Completion'
        promptlist.append(prompt)
        answer = data['Instances'][indexlist[i]]['output'][0].strip(".").replace('Completion ', '')
        answer = one_token(answer)
        answerlist.append(answer)

    return promptlist, answerlist, indexlist
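
# encodeinstruction() returns (promptlist, answerlist, indexlist): the evaluation prompts,
# their gold answers truncated to one GPT-2 token, and the corresponding indices into the
# task's "Instances" list. A sketch of a call (the file name below is a placeholder for any
# classification task file under data/ExpandedNaturalInstructions/):
#
#   prompts, answers, idxs = encodeinstruction(
#       "example_task.json",
#       instruction_structure=['Definition', 'Positive Examples Full Only'],
#       number_of_examples=1, number_of_instances=100, seed=0)
#
# training_encodeinstruction() below builds the same evaluation split and additionally
# returns prompts/answers/indices for train and dev splits drawn from the remaining instances.
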
def training_encodeinstruction(task,
                               instruction_structure=['Definition', 'Prompt', 'Things to Avoid', 'Emphasis & Caution',
                                                      'Negative Examples Full Explanations', 'Positive Examples Full Explanations'],
                               number_of_examples=0, number_of_instances=100, null_word=None, seed=0, modified={}):
    random.seed(0)
    with open('data/ExpandedNaturalInstructions/' + task) as json_file:
        data = json.load(json_file)
    labels = list(set([data["Instances"][i]["output"][0] for i in range(len(data["Instances"]))]))
    labels.sort()
    assert len(labels) < 25, "Check {} is a classification task.".format(task)
    instances_per_label = number_of_instances // len(labels)
    remainder = number_of_instances % len(labels)
    instance_pools = {label: {'indices': []} for label in labels}
    for i, inst in enumerate(data["Instances"]):
        label = inst['output'][0]
        instance_pools[label]['indices'].append(i)
    remaining = 0
    test_pools = {}

    for l, label in enumerate(labels):
        if len(instance_pools[label]['indices']) >= 4 + instances_per_label:  # see comment in function above
            num = instances_per_label
            if l < remainder: num += 1

            test_pools[label] = random.sample(instance_pools[label]['indices'], num)
            instance_pools[label]['indices'] = [i for i in instance_pools[label]['indices'] if i not in test_pools[label]]
        else:
            num = len(instance_pools[label]['indices']) - 4
            remaining += instances_per_label - num

            test_pools[label] = random.sample(instance_pools[label]['indices'], num)
            instance_pools[label]['indices'] = [i for i in instance_pools[label]['indices'] if i not in test_pools[label]]

    all_remaining_indices = []
    remaining = number_of_instances - sum([len(t) for t in test_pools.values()])
    for label in labels: all_remaining_indices.extend(instance_pools[label]['indices'])
    remaining_test = random.sample(all_remaining_indices, remaining)

    for t in remaining_test:
        label = data['Instances'][t]['output'][0]
        test_pools[label].append(t)
        instance_pools[label]['indices'].remove(t)

    indexlist = []
    for label in labels: indexlist.extend(test_pools[label])
    assert len(indexlist) == number_of_instances, pdb.set_trace()
    random.seed(seed)
    if number_of_examples == -1: total_num_examples = 1
    else: total_num_examples = number_of_examples * len(labels)
    pos_examples = {label: [] for label in labels}
    for eg in data["Positive Examples"]:
        label = eg['output']
        pos_examples[label].append(eg)
    for label in labels:
        for id in instance_pools[label]['indices']:
            inst = data["Instances"][id]
            inst['output'] = inst['output'][0]  # flatten the output list to a plain string for non-test instances
            pos_examples[label].append(inst)

    chosen_examples = []
    if number_of_examples > 0:
        for label in labels: chosen_examples.extend(random.sample(pos_examples[label], number_of_examples))
    elif number_of_examples == -1:
        label = random.sample(labels, 1)[0]
        chosen_examples.extend(random.sample(pos_examples[label], 1))
    assert len(chosen_examples) == total_num_examples
    random.shuffle(chosen_examples)

    # Everything not used for evaluation or as an in-context example is split ~90/10 into train/dev.
    train_indexlist = list(range(len(data['Instances'])))
    train_indexlist = [i for i in train_indexlist if i not in indexlist and data['Instances'][i] not in chosen_examples]

    dev_len = round(0.1 * len(train_indexlist))
    dev_indexlist = random.sample(train_indexlist, dev_len)
    train_indexlist = [i for i in train_indexlist if i not in dev_indexlist]
    generic_instruction = ''
    for i in instruction_structure:
        if i != 'Positive Examples Full Only' and i != 'Positive Examples Full Explanations' and i != 'Negative Examples Full Explanations':
            if data[i] != '-':
                if i in modified.keys():
                    data[i] = modified[i]
                data[i] = data[i].replace('\n' + 'Things to avoid: -', '')
                data[i] = data[i].replace('\n' + 'Emphasis & Caution: -', '')
                # pdb.set_trace()
                if generic_instruction == '':
                    generic_instruction = generic_instruction + i + ': ' + data[i].strip()
                else:
                    generic_instruction = generic_instruction + "\n" + i + ': ' + data[i].strip()
        elif i == 'Positive Examples Full Only':
            for j in range(total_num_examples):
                if generic_instruction != '':
                    generic_instruction = generic_instruction + "\n" + 'input: ' + chosen_examples[j]['input'] + "\n" + 'output: ' + one_token(chosen_examples[j]['output'])
                else:
                    generic_instruction = generic_instruction + 'input: ' + chosen_examples[j]['input'] + "\n" + 'output: ' + one_token(chosen_examples[j]['output'])
        elif i == 'Positive Examples Full Explanations':  # this mode of Natural Instructions is not supported
            assert False
        elif i == 'Negative Examples Full Explanations':  # this mode of Natural Instructions is not supported
            assert False
    # Evaluation prompts/answers (same format as in encodeinstruction above).
    promptlist = []
    answerlist = []

    for i in range(number_of_instances):
        if null_word is None:
            if generic_instruction != '': prompt = generic_instruction + "\n" + 'input: ' + data['Instances'][indexlist[i]]['input'] + "\n" + "output:"
            else: prompt = 'input: ' + data['Instances'][indexlist[i]]['input'] + "\n" + "output:"
        else:
            if generic_instruction != '': prompt = generic_instruction + "\n" + 'input: ' + null_word + "\n" + "output:"
            else: prompt = 'input: ' + null_word + "\n" + "output:"
        if 'Completion' in labels[0]:
            prompt = prompt + ' Completion'
        promptlist.append(prompt)
        answer = data['Instances'][indexlist[i]]['output'][0].strip(".").replace('Completion ', '')
        answer = one_token(answer)
        answerlist.append(answer)

    # Train prompts/answers. Note: train/dev instances had their 'output' field flattened
    # to a plain string above, so no [0] indexing is needed here.
    train_promptlist = []
    train_answerlist = []

    for i in range(len(train_indexlist)):
        if null_word is None:
            if generic_instruction != '': prompt = generic_instruction + "\n" + 'input: ' + data['Instances'][train_indexlist[i]]['input'] + "\n" + "output:"
            else: prompt = 'input: ' + data['Instances'][train_indexlist[i]]['input'] + "\n" + "output:"
        else:
            if generic_instruction != '': prompt = generic_instruction + "\n" + 'input: ' + null_word + "\n" + "output:"
            else: prompt = 'input: ' + null_word + "\n" + "output:"
        if 'Completion' in labels[0]:
            prompt = prompt + ' Completion'
        train_promptlist.append(prompt)
        train_answer = data['Instances'][train_indexlist[i]]['output'].strip(".").replace('Completion ', '')
        train_answer = one_token(train_answer)
        train_answerlist.append(train_answer)

    # Dev prompts/answers.
    dev_promptlist = []
    dev_answerlist = []

    for i in range(len(dev_indexlist)):
        if null_word is None:
            if generic_instruction != '': prompt = generic_instruction + "\n" + 'input: ' + data['Instances'][dev_indexlist[i]]['input'] + "\n" + "output:"
            else: prompt = 'input: ' + data['Instances'][dev_indexlist[i]]['input'] + "\n" + "output:"
        else:
            if generic_instruction != '': prompt = generic_instruction + "\n" + 'input: ' + null_word + "\n" + "output:"
            else: prompt = 'input: ' + null_word + "\n" + "output:"
        if 'Completion' in labels[0]:
            prompt = prompt + ' Completion'
        dev_promptlist.append(prompt)
        dev_answer = data['Instances'][dev_indexlist[i]]['output'].strip(".").replace('Completion ', '')
        dev_answer = one_token(dev_answer)
        dev_answerlist.append(dev_answer)

    return promptlist, answerlist, indexlist, train_promptlist, train_answerlist, train_indexlist, dev_promptlist, dev_answerlist, dev_indexlist
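

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the GrIPS search pipeline). The file name is a
    # placeholder: substitute any classification task file from data/ExpandedNaturalInstructions/.
    TASK_FILE = "example_task.json"

    (test_prompts, test_answers, test_idx,
     train_prompts, train_answers, train_idx,
     dev_prompts, dev_answers, dev_idx) = training_encodeinstruction(
        TASK_FILE,
        instruction_structure=['Definition'],
        number_of_examples=0,
        number_of_instances=100,
        seed=0,
    )
    print(len(train_prompts), "train /", len(dev_prompts), "dev /", len(test_prompts), "test prompts")
    print("Example prompt:\n" + test_prompts[0])
    print("Gold answer:", test_answers[0])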
