Skip to content

Commit

Permalink
Bring back PR#52 (#198)
Browse files Browse the repository at this point in the history
* Draft of localization eval, using RefCOCO as first target

* Other RefCOCO instances for grounding/REC

---------

Co-authored-by: Hunter Heidenreich <[email protected]>
  • Loading branch information
kcz358 and hunterheiden authored Aug 20, 2024
1 parent 6195d44 commit c2f73de
Show file tree
Hide file tree
Showing 16 changed files with 810 additions and 13 deletions.
22 changes: 9 additions & 13 deletions lmms_eval/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,13 +874,15 @@ def concat_tar_parts(tar_parts, output_tar):
download_config=download_config,
**dataset_kwargs if dataset_kwargs is not None else {},
)
self.dataset_no_image = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
download_config=download_config,
**dataset_kwargs if dataset_kwargs is not None else {},
)
if self.config.process_docs is not None:
for split in self.dataset:
if split in [
self.config.training_split, self.config.validation_split, self.config.test_split, self.config.fewshot_split
]:
self.dataset[split] = self.config.process_docs(self.dataset[split])

# copy dataset, remove image features
self.dataset_no_image = self.dataset.copy()
for doc_name in self.dataset_no_image:
remove_cols = []
features = self.dataset_no_image[doc_name].features
Expand Down Expand Up @@ -913,20 +915,14 @@ def has_test_docs(self) -> bool:

def training_docs(self) -> datasets.Dataset:
if self.has_training_docs():
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.training_split])
return self.dataset[self.config.training_split]

def validation_docs(self) -> datasets.Dataset:
if self.has_validation_docs():
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.validation_split])
return self.dataset[self.config.validation_split]

def test_docs(self) -> datasets.Dataset:
if self.has_test_docs():
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.test_split])
return self.dataset[self.config.test_split]

def fewshot_docs(self):
Expand Down
34 changes: 34 additions & 0 deletions lmms_eval/tasks/refcoco+/_default_template_bbox_rec_yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
dataset_path: lmms-lab/RefCOCOPlus
output_type: generate_until
process_docs: !function utils_rec.refcoco_bbox_rec_preprocess_dataset
doc_to_visual: !function utils_rec.refcoco_bbox_rec_doc_to_visual
doc_to_text: !function utils_rec.refcoco_bbox_rec_doc_to_text
doc_to_target: "bbox"
generation_kwargs:
until:
- "ASSISTANT:"
process_results: !function utils_rec.refcoco_bbox_rec_process_result
metric_list:
- metric: refcoco_IoU
aggregation : !function utils_rec.refcoco_bbox_rec_iou
higher_is_better : true
- metric: [email protected]
aggregation : !function utils_rec.refcoco_bbox_rec_acc01
higher_is_better : true
- metric: [email protected]
aggregation : !function utils_rec.refcoco_bbox_rec_acc03
higher_is_better : true
- metric: [email protected]
aggregation : !function utils_rec.refcoco_bbox_rec_acc05
higher_is_better : true
- metric: [email protected]
aggregation : !function utils_rec.refcoco_bbox_rec_acc07
higher_is_better : true
- metric: [email protected]
aggregation : !function utils_rec.refcoco_bbox_rec_acc09
higher_is_better : true
- metric: refcoco_Center_ACC
aggregation : !function utils_rec.refcoco_bbox_rec_center_acc
higher_is_better : true
metadata:
version: '0.0'
4 changes: 4 additions & 0 deletions lmms_eval/tasks/refcoco+/refcoco+_bbox_rec_testA.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
group: refcoco+_bbox_rec
task: refcoco+_bbox_rec_testA
include: _default_template_bbox_rec_yaml
test_split: testA
4 changes: 4 additions & 0 deletions lmms_eval/tasks/refcoco+/refcoco+_bbox_rec_testB.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
group: refcoco+_bbox_rec
task: refcoco+_bbox_rec_testB
include: _default_template_bbox_rec_yaml
test_split: testB
4 changes: 4 additions & 0 deletions lmms_eval/tasks/refcoco+/refcoco+_bbox_rec_val.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
group: refcoco+_bbox_rec
task: refcoco+_bbox_rec_val
include: _default_template_bbox_rec_yaml
test_split: val
221 changes: 221 additions & 0 deletions lmms_eval/tasks/refcoco+/utils_rec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import re
import logging
from datasets import Dataset

eval_logger = logging.getLogger("lmms-eval")

COCO_REC_METRICS = ["IoU", "[email protected]", "[email protected]", "[email protected]", "[email protected]", "[email protected]", "Center_ACC"]


def refcoco_bbox_rec_preprocess_dataset(dataset: Dataset):
# PIL image stored in dataset['image']
# add `image_width` and `image_height` to the dataset
dataset = dataset.map(lambda x: {"image_width": x["image"].width, "image_height": x["image"].height})

# Original bbox format (top x, top y, width, height)
# Convert to (top-left x, top-left y, bottom-right x, bottom-right y)
# Normalize the bounding box coordinates to be between 0 and 1
# using the image width and height
dataset = dataset.map(
lambda x: {"bbox": [x["bbox"][0] / x["image_width"],
x["bbox"][1] / x["image_height"],
(x["bbox"][0] + x["bbox"][2]) / x["image_width"],
(x["bbox"][1] + x["bbox"][3]) / x["image_height"]]}
)

# currently, the dataset has `answer` as a list of strings
# each answer should be its own row
# we will explode the dataset to have one row per answer
# duplicate the other columns
def explode_answers(example):
answers = example.pop('answer')
return [{'answer': answer, **example} for answer in answers]

# Apply the function to each element, collecting the results
exploded_rows = []
for example in dataset:
exploded_rows.extend(explode_answers(example))

# Create a new dataset from the exploded rows
new_dataset = Dataset.from_list(exploded_rows)
print(f"Exploded dataset from {len(dataset)} to {len(new_dataset)} rows")

return new_dataset


def refcoco_bbox_rec_doc_to_visual(doc):
# Image is presented as is
image = doc["image"].convert("RGB")
return [image.convert("RGB")]


def refcoco_bbox_rec_doc_to_text(doc):
assert isinstance(doc['answer'], str), "Answer must be a string"
return "Bounding box coordinates are specified in the format (top-left x, top-left y, bottom-right x, bottom-right y). All values are floating point numbers bounded between 0 and 1. Please provide the bounding box coordinate of the region this sentence describes: " + doc['answer']


def parse_float_sequence_within(input_str):
"""
Extract the first sequence of four floating-point numbers within square brackets from a string.
Args:
input_str (str): A string that may contain a sequence of four floats within square brackets.
Returns:
list: A list of four floats if the pattern is found, or a list of four zeros if the pattern is not found.
"""
# Define the regex pattern to find the first instance of four floats within square brackets
pattern = r'\[\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?)\s*\]'

# Use re.search to find the first match of the pattern in the input string
match = re.search(pattern, input_str)

# If a match is found, convert the captured groups into a list of floats
if match:
return [float(match.group(i)) for i in range(1, 5)]

# If the input does not contain the pattern, return the null float sequence
return [0, 0, 0, 0]


def refcoco_bbox_rec_process_result(doc, result):
"""
Args:
doc: a instance of the eval dataset
results: [pred]
Returns:
a dictionary with key: metric name, value: metric value
"""
pred = result[0] if len(result) > 0 else ""
pred = parse_float_sequence_within(pred)
ann_id = doc["question_id"]
data_dict = {"answer": doc["answer"], "pred": pred, "ann_id": ann_id, 'bbox': doc['bbox']}
return {f"refcoco_{metric}": data_dict for metric in COCO_REC_METRICS}


def compute_iou(box1, box2):
"""
Compute the Intersection over Union (IoU) of two bounding boxes.
Parameters:
- box1 (list of float): Bounding box [x_min, y_min, x_max, y_max].
- box2 (list of float): Bounding box [x_min, y_min, x_max, y_max].
Returns:
- float: IoU of box1 and box2.
"""
# Determine the coordinates of the intersection rectangle
x_left = max(box1[0], box2[0])
y_top = max(box1[1], box2[1])
x_right = min(box1[2], box2[2])
y_bottom = min(box1[3], box2[3])

# Compute the area of intersection
intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)

# Compute the area of both bounding boxes
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])

# Compute the area of the union
union_area = box1_area + box2_area - intersection_area

# Compute the Intersection over Union
iou = intersection_area / union_area

return iou


def compute_accuracy(box1, box2, threshold=0.5):
"""
Compute the accuracy of two bounding boxes based on a specified threshold.
Parameters:
- box1 (list of float): Bounding box [x_min, y_min, x_max, y_max].
- box2 (list of float): Bounding box [x_min, y_min, x_max, y_max].
- threshold (float): Threshold for the IoU to consider the prediction correct.
Returns:
- float: Accuracy of the prediction based on the IoU threshold.
"""
iou = compute_iou(box1, box2)
return iou >= threshold


def compute_center_accuracy(box1, box2):
"""
Compute if the center point of box 2 is within box 1.
Parameters:
- box1 (list of float): Bounding box [x_min, y_min, x_max, y_max].
- box2 (list of float): Bounding box [x_min, y_min, x_max, y_max].
Returns:
- bool: True if the center point of box 2 is within box 1, False otherwise.
"""
# Compute the center point of box 2
center_x = (box2[0] + box2[2]) / 2
center_y = (box2[1] + box2[3]) / 2

# Check if the center point is within box 1
return box1[0] <= center_x <= box1[2] and box1[1] <= center_y <= box1[3]


def refcoco_bbox_rec_aggregation_result(results, metric):
"""
Aggregate the results of the RefCOCO evaluation task using the specified metric.
Args:
- results (list of dict): List of result dictionaries.
- metric (str): Metric to use for aggregation.
Returns:
- dict: Dictionary containing the aggregated results for the specified metric.
"""
scorers = {
'IoU': compute_iou,
'[email protected]': lambda x, y: compute_accuracy(x, y, 0.1),
'[email protected]': lambda x, y: compute_accuracy(x, y, 0.3),
'[email protected]': lambda x, y: compute_accuracy(x, y, 0.5),
'[email protected]': lambda x, y: compute_accuracy(x, y, 0.7),
'[email protected]': lambda x, y: compute_accuracy(x, y, 0.9),
'Center_ACC': compute_center_accuracy
}
results_dict = {metric: []}
for result in results:
# Extract the ground truth and predicted bounding boxes
gt_bbox = result['bbox']
pred_bbox = result['pred']
# Compute the specified metric between the ground truth and predicted bounding boxes
score = scorers[metric](gt_bbox, pred_bbox)
results_dict[metric].append(score)
results_dict[metric] = sum(results_dict[metric]) / len(results_dict[metric])
print(f"Aggregated {metric} score: {results_dict[metric]}")
return results_dict[metric]


def refcoco_bbox_rec_iou(results):
return refcoco_bbox_rec_aggregation_result(results, "IoU")


def refcoco_bbox_rec_acc01(results):
return refcoco_bbox_rec_aggregation_result(results, "[email protected]")

def refcoco_bbox_rec_acc03(results):
return refcoco_bbox_rec_aggregation_result(results, "[email protected]")


def refcoco_bbox_rec_acc05(results):
return refcoco_bbox_rec_aggregation_result(results, "[email protected]")


def refcoco_bbox_rec_acc07(results):
return refcoco_bbox_rec_aggregation_result(results, "[email protected]")


def refcoco_bbox_rec_acc09(results):
return refcoco_bbox_rec_aggregation_result(results, "[email protected]")


def refcoco_bbox_rec_center_acc(results):
return refcoco_bbox_rec_aggregation_result(results, "Center_ACC")
34 changes: 34 additions & 0 deletions lmms_eval/tasks/refcoco/_default_template_bbox_rec_yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
dataset_path: lmms-lab/RefCOCO
output_type: generate_until
process_docs: !function utils_rec.refcoco_bbox_rec_preprocess_dataset
doc_to_visual: !function utils_rec.refcoco_bbox_rec_doc_to_visual
doc_to_text: !function utils_rec.refcoco_bbox_rec_doc_to_text
doc_to_target: "bbox"
generation_kwargs:
until:
- "ASSISTANT:"
process_results: !function utils_rec.refcoco_bbox_rec_process_result
metric_list:
- metric: refcoco_IoU
aggregation : !function utils_rec.refcoco_bbox_rec_iou
higher_is_better : true
- metric: [email protected]
aggregation : !function utils_rec.refcoco_bbox_rec_acc01
higher_is_better : true
- metric: [email protected]
aggregation : !function utils_rec.refcoco_bbox_rec_acc03
higher_is_better : true
- metric: [email protected]
aggregation : !function utils_rec.refcoco_bbox_rec_acc05
higher_is_better : true
- metric: [email protected]
aggregation : !function utils_rec.refcoco_bbox_rec_acc07
higher_is_better : true
- metric: [email protected]
aggregation : !function utils_rec.refcoco_bbox_rec_acc09
higher_is_better : true
- metric: refcoco_Center_ACC
aggregation : !function utils_rec.refcoco_bbox_rec_center_acc
higher_is_better : true
metadata:
version: '0.0'
4 changes: 4 additions & 0 deletions lmms_eval/tasks/refcoco/refcoco_bbox_rec_test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
group: refcoco_bbox_rec
task: refcoco_bbox_rec_test
test_split: test
include: _default_template_bbox_rec_yaml
4 changes: 4 additions & 0 deletions lmms_eval/tasks/refcoco/refcoco_bbox_rec_testA.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
group: refcoco_bbox_rec
task: refcoco_bbox_rec_testA
test_split: testA
include: _default_template_bbox_rec_yaml
4 changes: 4 additions & 0 deletions lmms_eval/tasks/refcoco/refcoco_bbox_rec_testB.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
group: refcoco_bbox_rec
task: refcoco_bbox_rec_testB
test_split: testB
include: _default_template_bbox_rec_yaml
4 changes: 4 additions & 0 deletions lmms_eval/tasks/refcoco/refcoco_bbox_rec_val.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
group: refcoco_bbox_rec
task: refcoco_bbox_rec_val
test_split: val
include: _default_template_bbox_rec_yaml
Loading

0 comments on commit c2f73de

Please sign in to comment.