12
12
13
13
from utils import *
14
14
15
+ MODEL_NAME = "clip"
16
+ DATASET_PATH = "/home/ryan/rscir/PatterNet"
17
+ METHODS = ["Weighted Similarities Norm" ]
18
+ LAMBDAS = [0.5 ]
19
+
15
20
16
21
# Function to read features from a pickle file
17
22
def read_dataset_features (pickle_dir ):
@@ -140,51 +145,56 @@ def calculate_rankings(method, query_features, text_features, database_features,
140
145
args = parser .parse_args ()
141
146
142
147
# Convert lambdas argument to a list of floats
143
- lambdas = list (map (float , args .lambdas .split (',' )))
148
+ # lambdas = list(map(float, args.lambdas.split(',')))
144
149
# For lambda ablation, uncomment the line:
145
150
# lambdas = [x/10 for x in range(0, 11, 1)]
151
+ lambdas = LAMBDAS
146
152
147
153
# Load model and tokenizer
148
- model , _ , tokenizer = load_model (args . model_name , args .model_type )
154
+ model , _ , tokenizer = load_model (MODEL_NAME , args .model_type )
149
155
150
156
# Read features from the specified dataset
151
157
if args .dataset == 'patternnet' :
152
158
print ('Reading features...' )
153
- features , labels , paths = read_dataset_features (os .path .join (args . dataset_path , 'features' , f'patternnet_{ args . model_name } .pkl' ))
159
+ features , labels , paths = read_dataset_features (os .path .join (DATASET_PATH , 'features' , f'patternnet_{ MODEL_NAME } .pkl' ))
154
160
print ('Features are loaded!' )
155
161
at = [5 , 10 , 15 , 20 ]
156
162
157
163
# Initialize metrics storage
158
- metrics_final = create_metrics_final (at , args . methods )
164
+ metrics_final = create_metrics_final (at , METHODS )
159
165
160
166
if args .dataset == 'patternnet' :
161
167
for lam in lambdas :
162
168
for attribute in args .attributes :
163
- metrics_final = create_metrics_final (at , args . methods )
169
+ metrics_final = create_metrics_final (at , METHODS )
164
170
start = time .time ()
165
171
166
172
# Read query data from CSV file
167
- query_filenames , attributes , attribute_values = read_csv (os .path .join (args . dataset_path , 'PatternCom' , f'{ attribute } .csv' ))
173
+ query_filenames , attributes , attribute_values = read_csv (os .path .join (DATASET_PATH , 'PatternCom' , f'{ attribute } .csv' ))
168
174
query_labels = [re .split (r'\d' , path )[0 ] for path in query_filenames ] # or something like labels[relative_indices], should give the same
169
175
170
176
# Fix query attribute labels
171
177
query_attributelabels = [x + query_labels [ii ] for ii , x in enumerate (attributes )]
172
178
query_attributelabels = fix_query_attributelabels (attribute , query_attributelabels )
173
179
174
- # Pair attribute labels with attribute values
180
+ # Pair attribute labels with attribute values | 0000 = ('colortenniscourt', 'blue')...
175
181
paired = list (zip (query_attributelabels , attribute_values ))
176
182
177
183
# Create prompts based on paired data
178
- prompts = create_prompts (paired )
179
- relative_indices = find_relative_indices (query_filenames , paths )
180
- filename_to_index_map = {filename : i for i , filename in enumerate (query_filenames )}
184
+ prompts = create_prompts (paired ) # 0000 = ['brown', 'green', 'gray', 'red']
185
+ relative_indices = find_relative_indices (query_filenames , paths ) # 0000 = 1106
186
+ filename_to_index_map = {filename : i for i , filename in enumerate (query_filenames )} # 'tenniscourt723.jpg' = 0
187
+ index_to_filename_map = filename_to_index_map = {i : filename for i , filename in enumerate (query_filenames )} # 0 = 'tenniscourt723.jpg'
181
188
182
189
# Cache text features
183
190
text_feature_cache = {}
184
191
for i , idx in enumerate (tqdm (relative_indices , desc = "Processing queries" )):
185
192
query_feature = features [idx ]
193
+
186
194
query_class = query_labels [i ] # Get the original class of the query image
195
+
187
196
for prompt in tqdm (prompts [i ], desc = "Processing prompts" , leave = False ):
197
+
188
198
# Check if the text feature for this prompt is already computed
189
199
if prompt not in text_feature_cache :
190
200
# If not, compute and cache it
@@ -195,29 +205,11 @@ def calculate_rankings(method, query_features, text_features, database_features,
195
205
else :
196
206
# If already computed, retrieve from cache
197
207
text_feature = text_feature_cache [prompt ]
198
- for method in args .methods :
208
+
209
+
210
+ for method in METHODS :
211
+ print (f"Querying image: { index_to_filename_map [i ]} | Querying text: { prompt } \n " )
199
212
rankings = calculate_rankings (method , query_feature , text_feature , features , lam )
200
- temp_metrics = metrics_calc (rankings , prompt , paths , filename_to_index_map , attribute_values , at , query_class , query_labels )
201
-
202
- # Accumulate metrics for each method
203
- for k in at :
204
- metrics_final [method ][f"R@{ k } " ].append (temp_metrics [f"R@{ k } " ])
205
- metrics_final [method ][f"P@{ k } " ].append (temp_metrics [f"P@{ k } " ])
206
- metrics_final [method ]["AP" ].append (temp_metrics ["AP" ])
207
-
208
- # Calculate average metrics
209
- for method in metrics_final :
210
- for metric in metrics_final [method ]:
211
- metrics_final [method ][metric ] = round (sum (metrics_final [method ][metric ]) / len (metrics_final [method ][metric ]) if metrics_final [method ][metric ] else 0 , 2 )
212
-
213
- print (metrics_final )
214
- end = time .time ()
215
- timer (start , end )
216
-
217
- # Save metrics to CSV file
218
- print ('Writing results to CSV file...' )
219
- results_dir = 'results'
220
- if not os .path .exists (results_dir ):
221
- os .makedirs (results_dir )
222
- results_file_path = os .path .join (results_dir , f'{ args .dataset } _metrics_{ args .model_name } _lambda{ lam } _{ attribute } .csv' )
223
- dict_to_csv (metrics_final , results_file_path )
213
+ best_match = os .path .basename (paths [rankings [0 ].item ()])
214
+ print (f"Best match: { best_match } " )
215
+ print ()
0 commit comments