Skip to content

Commit f86d781

Browse files
committed
Got this thing disassembled, I think
1 parent 1a8b034 commit f86d781

File tree

2 files changed

+29
-36
lines changed

2 files changed

+29
-36
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ PatterNet/
44
results/
55
models/
66
metrics.csv
7-
*.csv
7+
*.csv
8+
.vscode/launch.json

evaluate.py

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212

1313
from utils import *
1414

15+
MODEL_NAME = "clip"
16+
DATASET_PATH = "/home/ryan/rscir/PatterNet"
17+
METHODS = ["Weighted Similarities Norm"]
18+
LAMBDAS = [0.5]
19+
1520

1621
# Function to read features from a pickle file
1722
def read_dataset_features(pickle_dir):
@@ -140,51 +145,56 @@ def calculate_rankings(method, query_features, text_features, database_features,
140145
args = parser.parse_args()
141146

142147
# Convert lambdas argument to a list of floats
143-
lambdas = list(map(float, args.lambdas.split(',')))
148+
# lambdas = list(map(float, args.lambdas.split(',')))
144149
# For lambda ablation, uncomment the line:
145150
# lambdas = [x/10 for x in range(0, 11, 1)]
151+
lambdas = LAMBDAS
146152

147153
# Load model and tokenizer
148-
model, _, tokenizer = load_model(args.model_name, args.model_type)
154+
model, _, tokenizer = load_model(MODEL_NAME, args.model_type)
149155

150156
# Read features from the specified dataset
151157
if args.dataset == 'patternnet':
152158
print('Reading features...')
153-
features, labels, paths = read_dataset_features(os.path.join(args.dataset_path, 'features', f'patternnet_{args.model_name}.pkl'))
159+
features, labels, paths = read_dataset_features(os.path.join(DATASET_PATH, 'features', f'patternnet_{MODEL_NAME}.pkl'))
154160
print('Features are loaded!')
155161
at = [5, 10, 15, 20]
156162

157163
# Initialize metrics storage
158-
metrics_final = create_metrics_final(at, args.methods)
164+
metrics_final = create_metrics_final(at, METHODS)
159165

160166
if args.dataset == 'patternnet':
161167
for lam in lambdas:
162168
for attribute in args.attributes:
163-
metrics_final = create_metrics_final(at, args.methods)
169+
metrics_final = create_metrics_final(at, METHODS)
164170
start = time.time()
165171

166172
# Read query data from CSV file
167-
query_filenames, attributes, attribute_values = read_csv(os.path.join(args.dataset_path, 'PatternCom', f'{attribute}.csv'))
173+
query_filenames, attributes, attribute_values = read_csv(os.path.join(DATASET_PATH, 'PatternCom', f'{attribute}.csv'))
168174
query_labels = [re.split(r'\d', path)[0] for path in query_filenames] # or something like labels[relative_indices], should give the same
169175

170176
# Fix query attribute labels
171177
query_attributelabels = [x + query_labels[ii] for ii, x in enumerate(attributes)]
172178
query_attributelabels = fix_query_attributelabels(attribute, query_attributelabels)
173179

174-
# Pair attribute labels with attribute values
180+
# Pair attribute labels with attribute values | 0000 = ('colortenniscourt', 'blue')...
175181
paired = list(zip(query_attributelabels, attribute_values))
176182

177183
# Create prompts based on paired data
178-
prompts = create_prompts(paired)
179-
relative_indices = find_relative_indices(query_filenames, paths)
180-
filename_to_index_map = {filename: i for i, filename in enumerate(query_filenames)}
184+
prompts = create_prompts(paired) # 0000 = ['brown', 'green', 'gray', 'red']
185+
relative_indices = find_relative_indices(query_filenames, paths) # 0000 = 1106
186+
filename_to_index_map = {filename: i for i, filename in enumerate(query_filenames)} # 'tenniscourt723.jpg' = 0
187+
index_to_filename_map = filename_to_index_map = {i: filename for i, filename in enumerate(query_filenames)} # 0 = 'tenniscourt723.jpg'
181188

182189
# Cache text features
183190
text_feature_cache = {}
184191
for i, idx in enumerate(tqdm(relative_indices, desc="Processing queries")):
185192
query_feature = features[idx]
193+
186194
query_class = query_labels[i] # Get the original class of the query image
195+
187196
for prompt in tqdm(prompts[i], desc="Processing prompts", leave=False):
197+
188198
# Check if the text feature for this prompt is already computed
189199
if prompt not in text_feature_cache:
190200
# If not, compute and cache it
@@ -195,29 +205,11 @@ def calculate_rankings(method, query_features, text_features, database_features,
195205
else:
196206
# If already computed, retrieve from cache
197207
text_feature = text_feature_cache[prompt]
198-
for method in args.methods:
208+
209+
210+
for method in METHODS:
211+
print(f"Querying image: {index_to_filename_map[i]} | Querying text: {prompt}\n")
199212
rankings = calculate_rankings(method, query_feature, text_feature, features, lam)
200-
temp_metrics = metrics_calc(rankings, prompt, paths, filename_to_index_map, attribute_values, at, query_class, query_labels)
201-
202-
# Accumulate metrics for each method
203-
for k in at:
204-
metrics_final[method][f"R@{k}"].append(temp_metrics[f"R@{k}"])
205-
metrics_final[method][f"P@{k}"].append(temp_metrics[f"P@{k}"])
206-
metrics_final[method]["AP"].append(temp_metrics["AP"])
207-
208-
# Calculate average metrics
209-
for method in metrics_final:
210-
for metric in metrics_final[method]:
211-
metrics_final[method][metric] = round(sum(metrics_final[method][metric]) / len(metrics_final[method][metric]) if metrics_final[method][metric] else 0, 2)
212-
213-
print(metrics_final)
214-
end = time.time()
215-
timer(start, end)
216-
217-
# Save metrics to CSV file
218-
print('Writing results to CSV file...')
219-
results_dir = 'results'
220-
if not os.path.exists(results_dir):
221-
os.makedirs(results_dir)
222-
results_file_path = os.path.join(results_dir, f'{args.dataset}_metrics_{args.model_name}_lambda{lam}_{attribute}.csv')
223-
dict_to_csv(metrics_final, results_file_path)
213+
best_match = os.path.basename(paths[rankings[0].item()])
214+
print(f"Best match: {best_match}")
215+
print()

0 commit comments

Comments
 (0)