@@ -78,6 +78,7 @@ def list_processes(cls):
        """
        return [cls.name]

+
# ------------------------------ Specific models ---------------------------- #


@@ -381,7 +382,7 @@ def forward(self, image: torch.Tensor, text: List[str], return_labels: bool = Fa
            text = [text]
        text_original = text
        text = ['a photo of a ' + t for t in text]
-        inputs = self.processor(text=text, images=image, return_tensors="pt")  # padding="longest",
+        inputs = self.processor(text=text, images=image, return_tensors="pt")  # padding="longest",
        inputs = {k: v.to(self.dev) for k, v in inputs.items()}
        outputs = self.model(**inputs)

@@ -512,7 +513,7 @@ def compute_prediction(self, original_image, original_caption, custom_entity=Non
        tic = timeit.time.perf_counter()

        # compute predictions
-        with HiddenPrints():  # Hide some deprecated notices
+        with HiddenPrints():  # Hide some deprecated notices
            predictions = self.model(image_list, captions=[original_caption],
                                     positive_map=positive_map_label_to_token)
        predictions = [o.to(self.cpu_device) for o in predictions]
@@ -779,6 +780,8 @@ def __init__(self, gpu_number=0):
        super().__init__(gpu_number=gpu_number)
        with open(config.gpt3.qa_prompt) as f:
            self.qa_prompt = f.read().strip()
+        with open(config.gpt3.guess_prompt) as f:
+            self.guess_prompt = f.read().strip()
        self.temperature = config.gpt3.temperature
        self.n_votes = config.gpt3.n_votes
        self.model = config.gpt3.model
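
Context for the two added lines above: the constructor now also reads a guess-prompt template, so the gpt3 section of the config needs a guess_prompt path next to the existing qa_prompt. A minimal hedged sketch of that expectation, with placeholder paths rather than the repository's actual prompt files:

# Placeholder illustration only; the real values live under config.gpt3 in the project config.
from types import SimpleNamespace

gpt3_cfg = SimpleNamespace(
    qa_prompt="prompts/qa.prompt",        # already required before this change
    guess_prompt="prompts/guess.prompt",  # newly required: read into self.guess_prompt
)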
@@ -802,7 +805,40 @@ def most_frequent(answers):
        answer_counts = Counter(answers)
        return answer_counts.most_common(1)[0][0]

-    def get_qa(self, prompts, prompt_base: str = None) -> list[str]:
+    def process_guesses(self, prompts):
+        prompt_base = self.guess_prompt
+        prompts_total = []
+        for p in prompts:
+            question, guess1, _ = p
+            if len(guess1) == 1:
+                # In case only one option is given as a guess
+                guess1 = [guess1[0], guess1[0]]
+            prompts_total.append(prompt_base.format(question, guess1[0], guess1[1]))
+        response = self.process_guesses_fn(prompts_total)
+        if self.n_votes > 1:
+            response_ = []
+            for i in range(len(prompts)):
+                if self.model == 'chatgpt':
+                    resp_i = [r['message']['content'] for r in
+                              response['choices'][i * self.n_votes:(i + 1) * self.n_votes]]
+                else:
+                    resp_i = [r['text'] for r in response['choices'][i * self.n_votes:(i + 1) * self.n_votes]]
+                response_.append(self.most_frequent(resp_i).lstrip())
+            response = response_
+        else:
+            if self.model == 'chatgpt':
+                response = [r['message']['content'].lstrip() for r in response['choices']]
+            else:
+                response = [r['text'].lstrip() for r in response['choices']]
+        return response
+
+    def process_guesses_fn(self, prompt):
+        # The code is the same as get_qa_fn, but we separate in case we want to modify it later
+        response = self.query_gpt3(prompt, model=self.model, max_tokens=5, logprobs=1, stream=False,
+                                   stop=["\n", "<|endoftext|>"])
+        return response
+
+    def get_qa(self, prompts, prompt_base: str = None) -> list[str]:
        if prompt_base is None:
            prompt_base = self.qa_prompt
        prompts_total = []
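
A brief hedged sketch of what process_guesses does with each prompt tuple, using made-up values and a stand-in template (the real template comes from config.gpt3.guess_prompt and is not part of this diff):

# Illustrative values only; this mirrors the unpacking and duplication logic added above.
template = "Question: {}. Candidates: {} or {}. Short answer:"  # stand-in for self.guess_prompt
question, guesses, _extra = ("what utensil is on the plate?", ["fork"], None)
if len(guesses) == 1:                      # a single guess is duplicated into both slots
    guesses = [guesses[0], guesses[0]]
print(template.format(question, guesses[0], guesses[1]))
# -> Question: what utensil is on the plate?. Candidates: fork or fork. Short answer: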
@@ -814,8 +850,8 @@ def get_qa(self, prompts, prompt_base: str=None) -> list[str]:
            response_ = []
            for i in range(len(prompts)):
                if self.model == 'chatgpt':
-                    resp_i = [r['message']['content']
-                              for r in response['choices'][i * self.n_votes:(i + 1) * self.n_votes]]
+                    resp_i = [r['message']['content'] for r in
+                              response['choices'][i * self.n_votes:(i + 1) * self.n_votes]]
                else:
                    resp_i = [r['text'] for r in response['choices'][i * self.n_votes:(i + 1) * self.n_votes]]
                response_.append(self.most_frequent(resp_i))
@@ -891,6 +927,8 @@ def forward(self, prompt, process_name):
        if len(prompt) > 0:
            if process_name == 'gpt3_qa':
                response = self.get_qa(prompt)
+            elif process_name == 'gpt3_guess':
+                response = self.process_guesses(prompt)
            else:  # 'gpt3_general', general prompt, has to be given all of it
                response = self.get_general(prompt)
        else:
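
For orientation, a hedged sketch of the caller-side dispatch, assuming a constructed instance of this GPT-3 wrapper (instantiation is outside this diff) and invented prompt values:

# Illustrative only: the new process name routes to process_guesses; the old ones are unchanged.
prompts = [("what color is the car?", ["red", "blue"], None)]
# model.forward(prompts, process_name='gpt3_guess')  # -> self.process_guesses(prompts)
# model.forward(prompts, process_name='gpt3_qa')     # -> self.get_qa(prompts)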
@@ -911,7 +949,7 @@ def forward(self, prompt, process_name):

    @classmethod
    def list_processes(cls):
-        return ['gpt3_' + n for n in ['qa', 'general']]
+        return ['gpt3_' + n for n in ['qa', 'guess', 'general']]


# @cache.cache
@@ -924,24 +962,26 @@ def codex_helper(extended_prompt):
    if not isinstance(extended_prompt, list):
        extended_prompt = [extended_prompt]
        responses = [openai.ChatCompletion.create(
-            model=config.codex.model,
-            messages=[
-                # {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "system", "content": "Only answer with a function starting def execute_command."},
-                {"role": "user", "content": prompt}
-            ],
-            temperature=config.codex.temperature,
-            max_tokens=config.codex.max_tokens,
-            top_p=1.,
-            frequency_penalty=0,
-            presence_penalty=0,
-            # best_of=config.codex.best_of,
-            stop=["\n\n"],
-        )
-            for prompt in extended_prompt]
-        resp = [r['choices'][0]['message']['content'].replace("execute_command(image)", "execute_command(image, my_fig, time_wait_between_lines, syntax)") for r in responses]
-        # if len(resp) == 1:
-        #     resp = resp[0]
+            model=config.codex.model,
+            messages=[
+                # {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "system", "content": "Only answer with a function starting def execute_command."},
+                {"role": "user", "content": prompt}
+            ],
+            temperature=config.codex.temperature,
+            max_tokens=config.codex.max_tokens,
+            top_p=1.,
+            frequency_penalty=0,
+            presence_penalty=0,
+            # best_of=config.codex.best_of,
+            stop=["\n\n"],
+        )
+            for prompt in extended_prompt]
+        resp = [r['choices'][0]['message']['content'].replace("execute_command(image)",
+                                                              "execute_command(image, my_fig, time_wait_between_lines, syntax)")
+                for r in responses]
+        # if len(resp) == 1:
+        #     resp = resp[0]
    else:
        warnings.warn('OpenAI Codex is deprecated. Please use GPT-4 or GPT-3.5-turbo.')
        response = openai.Completion.create(
@@ -1161,7 +1201,7 @@ def caption(self, image, prompt=None):
        generated_text = [cap.strip() for cap in
                          self.processor.batch_decode(generated_ids, skip_special_tokens=True)]
        return generated_text
-
+
    def pre_question(self, question):
        # from LAVIS blip_processors
        question = re.sub(
@@ -1223,7 +1263,6 @@ class SaliencyModel(BaseModel):

    def __init__(self, gpu_number=0,
                 path_checkpoint=f'{config.path_pretrained_models}/saliency_inspyrenet_plus_ultra'):
-
        from base_models.inspyrenet.saliency_transforms import get_transform
        from base_models.inspyrenet.InSPyReNet import InSPyReNet
        from base_models.inspyrenet.backbones.SwinTransformer import SwinB