Skip to content

Commit ac2fa26

Browse files
committed
process guess
1 parent 12782b6 commit ac2fa26

File tree

4 files changed

+191
-26
lines changed

4 files changed

+191
-26
lines changed

configs/base_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ best_match_model: xvlm # Which model to use for bes
4343
gpt3: # GPT-3 configuration
4444
n_votes: 1 # Number of tries to use for GPT-3. Use with temperature > 0
4545
qa_prompt: ./prompts/gpt3/gpt3_qa.txt
46+
guess_prompt: ./prompts/gpt3/gpt3_process_guess.txt
4647
temperature: 0. # Temperature for GPT-3. Almost deterministic if 0
4748
model: text-davinci-003 # See openai.Model.list() for available models
4849

image_patch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,10 @@ def llm_query(query, context=None, long_answer=True, queues=None):
434434
return forward(model_name='gpt3_qa', prompt=[query, context], queues=queues)
435435

436436

437+
def process_guesses(prompt, guess1=None, guess2=None, queues=None):
438+
return forward(model_name='gpt3_guess', prompt=[prompt, guess1, guess2], queues=queues)
439+
440+
437441
def coerce_to_numeric(string, no_string=False):
438442
"""
439443
This function takes a string as input and returns a numeric value after removing any non-numeric characters.

prompts/gpt3/gpt3_process_guess.txt

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
Please answer the following questions using the given guesses.
2+
If a unique answer cannot be determined, choose only one of the possible answers.
3+
Aim to reply in ONE word (at MOST 2).
4+
5+
Question: What kind of flowers are these?
6+
Guess 1: these flowers are purple, so lavender, lilac, iris, and hyacinth
7+
Guess 2: purple flowers
8+
Answer: lilac
9+
10+
Question: What do these people on the bikes normally write and give out?
11+
Guess 1: the people on bikes are police, so Tickets
12+
Guess 2: tickets
13+
Answer: tickets
14+
15+
Question: What kind of cold meet is this?
16+
Guess 1: what kind of meat is this is beef, so roast beef
17+
Guess 2: beef
18+
Answer: beef
19+
20+
Question: Can you guess the place shown in this picture?
21+
Guess 1: the place is tourist attraction, so the Eiffel Tower in Paris, France
22+
Guess 2: big ben
23+
Answer: big ben
24+
25+
Question: When was this type of vehicle with two equal sized wheels invented?
26+
Guess 1: the vehicle is a bicycle, so 19th century
27+
Guess 2: 1819
28+
Answer: 1800s
29+
30+
Question: What is the flavor of the pink topping on this dessert?
31+
Guess 1: the topping is whipped cream, so strawberry, vanilla, chocolate, and raspberry
32+
Guess 2: strawberry
33+
Answer: strawberry
34+
35+
Question: How are these festive lights held in place?
36+
Guess 1: these festive lights are christmas lights, so with hooks clips
37+
Guess 2: string
38+
Answer: string
39+
40+
Question: Who is famous for allegedly doing this in a lightning storm?
41+
Guess 1: what is being done is flying a kite, so Benjamin Franklin
42+
Guess 2: Charles Manson
43+
Answer: Benjamin Franklin
44+
45+
Question: What is the object atop the skier's head used for?
46+
Guess 1: the object atop the skier's head is helmet, so protection from head injuries
47+
Guess 2: sunglasses
48+
Answer: protection
49+
50+
Question: What rank is the man on the right?
51+
Guess 1: who is the man on the right is sailor, so seaman
52+
Guess 2: captain
53+
Answer: captain
54+
55+
Question: Chemically what kind of water is in the picture?
56+
Guess 1: the water in the picture is waves, so salt water
57+
Guess 2: salt water
58+
Answer: salt
59+
60+
Question: Is the material tweed or canvas?
61+
Guess 1: the material is fabric, so fabric
62+
Guess 2: canvas
63+
Answer: canvas
64+
65+
Question: Which type of meat are in the photo?
66+
Guess 1: the meat in the photo is sausage, so pork
67+
Guess 2: hot dogs
68+
Answer: hotdogs
69+
70+
Question: What sort of predator might there be in an area like this?
71+
Guess 1: this area is mountains, so predators like wolves fox
72+
Guess 2: shark
73+
Answer: shark
74+
75+
Question: Can you name a sport this person could be a part of?
76+
Guess 1: this person is a racer, so racing such as auto
77+
Guess 2: motorcycle racing
78+
Answer: racing
79+
80+
Question: Who makes the yellow top worn in this photograph?
81+
Guess 1: the top is red, so brand is unknown
82+
Guess 2: Burton
83+
Answer: Burton
84+
85+
Question: Is the athlete right or left handed?
86+
Guess 1: what is the athlete doing is playing baseball, so unclear
87+
Guess 2: right handed
88+
Answer: right handed
89+
90+
Question: Is this food high or low on fat?
91+
Guess 1: what kind of food is this is sandwich, so depends on ingredients
92+
Guess 2: high
93+
Answer: high
94+
95+
Question: What wood are those cabinets made of?
96+
Guess 1: what kind of cabinets are these is kitchen cabinets, so typically wood such as oak
97+
Guess 2: maple
98+
99+
Question: Which objects shown are typically associated with small children?
100+
Guess 1: what objects are shown are stuffed animals, so toys
101+
Guess 2: teddy bears
102+
Answer: teddy bears
103+
104+
Question: What small appliance is that stuffed animal inside?
105+
Guess 1: the stuffed animal is a teddy bear, so vacuum cleaner
106+
Guess 2: microwave
107+
Answer: microwave
108+
109+
Question: What is this made with?
110+
Guess 1: what is this is muffin, so flour sugar eggs
111+
Guess 2: oats
112+
Answer: flour
113+
114+
Question: What is the position name of the player squatting down?
115+
Guess 1: who is squatting down is the batter, so hitter
116+
Guess 2: catcher
117+
118+
Question: {}
119+
Guess 1: {}
120+
Guess 2: {}
121+
Answer (remember, only 1-2 words):

vision_models.py

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def list_processes(cls):
7878
"""
7979
return [cls.name]
8080

81+
8182
# ------------------------------ Specific models ---------------------------- #
8283

8384

@@ -381,7 +382,7 @@ def forward(self, image: torch.Tensor, text: List[str], return_labels: bool = Fa
381382
text = [text]
382383
text_original = text
383384
text = ['a photo of a ' + t for t in text]
384-
inputs = self.processor(text=text, images=image, return_tensors="pt") # padding="longest",
385+
inputs = self.processor(text=text, images=image, return_tensors="pt") # padding="longest",
385386
inputs = {k: v.to(self.dev) for k, v in inputs.items()}
386387
outputs = self.model(**inputs)
387388

@@ -512,7 +513,7 @@ def compute_prediction(self, original_image, original_caption, custom_entity=Non
512513
tic = timeit.time.perf_counter()
513514

514515
# compute predictions
515-
with HiddenPrints(): # Hide some deprecated notices
516+
with HiddenPrints(): # Hide some deprecated notices
516517
predictions = self.model(image_list, captions=[original_caption],
517518
positive_map=positive_map_label_to_token)
518519
predictions = [o.to(self.cpu_device) for o in predictions]
@@ -779,6 +780,8 @@ def __init__(self, gpu_number=0):
779780
super().__init__(gpu_number=gpu_number)
780781
with open(config.gpt3.qa_prompt) as f:
781782
self.qa_prompt = f.read().strip()
783+
with open(config.gpt3.guess_prompt) as f:
784+
self.guess_prompt = f.read().strip()
782785
self.temperature = config.gpt3.temperature
783786
self.n_votes = config.gpt3.n_votes
784787
self.model = config.gpt3.model
@@ -802,7 +805,40 @@ def most_frequent(answers):
802805
answer_counts = Counter(answers)
803806
return answer_counts.most_common(1)[0][0]
804807

805-
def get_qa(self, prompts, prompt_base: str=None) -> list[str]:
808+
def process_guesses(self, prompts):
809+
prompt_base = self.guess_prompt
810+
prompts_total = []
811+
for p in prompts:
812+
question, guess1, _ = p
813+
if len(guess1) == 1:
814+
# In case only one option is given as a guess
815+
guess1 = [guess1[0], guess1[0]]
816+
prompts_total.append(prompt_base.format(question, guess1[0], guess1[1]))
817+
response = self.process_guesses_fn(prompts_total)
818+
if self.n_votes > 1:
819+
response_ = []
820+
for i in range(len(prompts)):
821+
if self.model == 'chatgpt':
822+
resp_i = [r['message']['content'] for r in
823+
response['choices'][i * self.n_votes:(i + 1) * self.n_votes]]
824+
else:
825+
resp_i = [r['text'] for r in response['choices'][i * self.n_votes:(i + 1) * self.n_votes]]
826+
response_.append(self.most_frequent(resp_i).lstrip())
827+
response = response_
828+
else:
829+
if self.model == 'chatgpt':
830+
response = [r['message']['content'].lstrip() for r in response['choices']]
831+
else:
832+
response = [r['text'].lstrip() for r in response['choices']]
833+
return response
834+
835+
def process_guesses_fn(self, prompt):
836+
# The code is the same as get_qa_fn, but we separate in case we want to modify it later
837+
response = self.query_gpt3(prompt, model=self.model, max_tokens=5, logprobs=1, stream=False,
838+
stop=["\n", "<|endoftext|>"])
839+
return response
840+
841+
def get_qa(self, prompts, prompt_base: str = None) -> list[str]:
806842
if prompt_base is None:
807843
prompt_base = self.qa_prompt
808844
prompts_total = []
@@ -814,8 +850,8 @@ def get_qa(self, prompts, prompt_base: str=None) -> list[str]:
814850
response_ = []
815851
for i in range(len(prompts)):
816852
if self.model == 'chatgpt':
817-
resp_i = [r['message']['content']
818-
for r in response['choices'][i * self.n_votes:(i + 1) * self.n_votes]]
853+
resp_i = [r['message']['content'] for r in
854+
response['choices'][i * self.n_votes:(i + 1) * self.n_votes]]
819855
else:
820856
resp_i = [r['text'] for r in response['choices'][i * self.n_votes:(i + 1) * self.n_votes]]
821857
response_.append(self.most_frequent(resp_i))
@@ -891,6 +927,8 @@ def forward(self, prompt, process_name):
891927
if len(prompt) > 0:
892928
if process_name == 'gpt3_qa':
893929
response = self.get_qa(prompt)
930+
elif process_name == 'gpt3_guess':
931+
response = self.process_guesses(prompt)
894932
else: # 'gpt3_general', general prompt, has to be given all of it
895933
response = self.get_general(prompt)
896934
else:
@@ -911,7 +949,7 @@ def forward(self, prompt, process_name):
911949

912950
@classmethod
913951
def list_processes(cls):
914-
return ['gpt3_' + n for n in ['qa', 'general']]
952+
return ['gpt3_' + n for n in ['qa', 'guess', 'general']]
915953

916954

917955
# @cache.cache
@@ -924,24 +962,26 @@ def codex_helper(extended_prompt):
924962
if not isinstance(extended_prompt, list):
925963
extended_prompt = [extended_prompt]
926964
responses = [openai.ChatCompletion.create(
927-
model=config.codex.model,
928-
messages=[
929-
# {"role": "system", "content": "You are a helpful assistant."},
930-
{"role": "system", "content": "Only answer with a function starting def execute_command."},
931-
{"role": "user", "content": prompt}
932-
],
933-
temperature=config.codex.temperature,
934-
max_tokens=config.codex.max_tokens,
935-
top_p = 1.,
936-
frequency_penalty=0,
937-
presence_penalty=0,
938-
# best_of=config.codex.best_of,
939-
stop=["\n\n"],
940-
)
941-
for prompt in extended_prompt]
942-
resp = [r['choices'][0]['message']['content'].replace("execute_command(image)", "execute_command(image, my_fig, time_wait_between_lines, syntax)") for r in responses]
943-
# if len(resp) == 1:
944-
# resp = resp[0]
965+
model=config.codex.model,
966+
messages=[
967+
# {"role": "system", "content": "You are a helpful assistant."},
968+
{"role": "system", "content": "Only answer with a function starting def execute_command."},
969+
{"role": "user", "content": prompt}
970+
],
971+
temperature=config.codex.temperature,
972+
max_tokens=config.codex.max_tokens,
973+
top_p=1.,
974+
frequency_penalty=0,
975+
presence_penalty=0,
976+
# best_of=config.codex.best_of,
977+
stop=["\n\n"],
978+
)
979+
for prompt in extended_prompt]
980+
resp = [r['choices'][0]['message']['content'].replace("execute_command(image)",
981+
"execute_command(image, my_fig, time_wait_between_lines, syntax)")
982+
for r in responses]
983+
# if len(resp) == 1:
984+
# resp = resp[0]
945985
else:
946986
warnings.warn('OpenAI Codex is deprecated. Please use GPT-4 or GPT-3.5-turbo.')
947987
response = openai.Completion.create(
@@ -1161,7 +1201,7 @@ def caption(self, image, prompt=None):
11611201
generated_text = [cap.strip() for cap in
11621202
self.processor.batch_decode(generated_ids, skip_special_tokens=True)]
11631203
return generated_text
1164-
1204+
11651205
def pre_question(self, question):
11661206
# from LAVIS blip_processors
11671207
question = re.sub(
@@ -1223,7 +1263,6 @@ class SaliencyModel(BaseModel):
12231263

12241264
def __init__(self, gpu_number=0,
12251265
path_checkpoint=f'{config.path_pretrained_models}/saliency_inspyrenet_plus_ultra'):
1226-
12271266
from base_models.inspyrenet.saliency_transforms import get_transform
12281267
from base_models.inspyrenet.InSPyReNet import InSPyReNet
12291268
from base_models.inspyrenet.backbones.SwinTransformer import SwinB

0 commit comments

Comments
 (0)