Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Viame integration #3

Merged
merged 5 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion train_kwcoco_demo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,28 @@ LOG_BATCH_VIZ_TO_DISK=1 python -m yolo.lazy \
"image_size=[224,224]"


### TODO: show how to run inference
### show how to run inference

BUNDLE_DPATH=$HOME/demo-yolo-kwcoco-train
TEST_FPATH=$BUNDLE_DPATH/vidshapes_rgb_test/data.kwcoco.json
# Grab a checkpoint
CKPT_FPATH=$(python -c "if 1:
import pathlib
ckpt_dpath = pathlib.Path('$BUNDLE_DPATH') / 'training/train/kwcoco-demo/checkpoints'
checkpoints = sorted(ckpt_dpath.glob('*'))
print(checkpoints[-1])
")
echo "CKPT_FPATH = $CKPT_FPATH"

export DISABLE_RICH_HANDLER=1
export CUDA_VISIBLE_DEVICES="1,"
python yolo/lazy.py \
task.data.source="$TEST_FPATH" \
task=inference \
dataset=kwcoco-demo \
use_wandb=False \
out_path=kwcoco-demo-inference \
name=kwcoco-inference-test \
cpu_num=8 \
weight="\"$CKPT_FPATH\"" \
accelerator=auto
11 changes: 8 additions & 3 deletions yolo/tools/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,13 @@ def process_preloaded_coco(self):
for image_id in coco_dset.images():
if self.stop_event.is_set():
break
classes = coco_dset.object_categories() # todo: cache?
coco_img = coco_dset.coco_image(image_id)
file_path = coco_img.primary_image_filepath()
metadata = coco_img.img
metadata = {
'img': coco_img.img,
'classes': classes,
}
self.process_image(file_path, metadata)

def load_image_folder(self, folder):
Expand Down Expand Up @@ -402,10 +406,11 @@ def process_frame(self, frame, metadata=None):
frame, _, rev_tensor = self.transform(frame, torch.zeros(0, 5))
frame = frame[None]
rev_tensor = rev_tensor[None]
item = (frame, rev_tensor, origin_frame, metadata)
if not self.is_stream:
self.queue.put((frame, rev_tensor, origin_frame, metadata))
self.queue.put(item)
else:
self.current_frame = (frame, rev_tensor, origin_frame, metadata)
self.current_frame = item

def __iter__(self) -> Generator[Tensor, None, None]:
return self
Expand Down
33 changes: 21 additions & 12 deletions yolo/tools/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,24 @@ def configure_optimizers(self):
class InferenceModel(BaseModel):
def __init__(self, cfg: Config):
super().__init__(cfg)
import ubelt as ub
self.cfg = cfg
# TODO: Add FastModel
self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)

print(f'self.predict_loader._is_coco={self.predict_loader._is_coco}')
if self.predict_loader._is_coco:
# hack: write to coco as well
# Setup a kwcoco file to write to if the user requests it.
self.pred_dset = self.predict_loader.coco_dset.copy()
self.pred_dset.reroot(absolute=True)
...
self.pred_dset.fpath = ub.Path(self.pred_dset.fpath).augment(prefix='predict-', ext='.kwcoco.json', multidot=True)

def on_predict_end(self, *args, **kwargs):
print('[InferenceModel] on_predict_end')
dset = self.pred_dset
print(f'dset.fpath={dset.fpath}')
dset.dump()
print('Finished prediction')

def setup(self, stage):
self.vec2box = create_converter(
Expand All @@ -135,27 +144,27 @@ def predict_dataloader(self):

def predict_step(self, batch, batch_idx):

if 0:
# We can access these variables if we need to
self._trainer.predict_dataloaders
self._trainer.predict_dataloaders.coco_dset

images, rev_tensor, origin_frame, metadata = batch

assert metadata is not None
img = metadata['img']
classes = metadata['classes']
image_id = img['id']
predicts = self.post_process(self(images), rev_tensor=rev_tensor)

WRITE_TO_COCO = 1
if WRITE_TO_COCO:
from yolo.utils.kwcoco_utils import tensor_to_kwimage
dset = self.pred_dset
for yolo_annot_tensor in predicts:
pred_dets = tensor_to_kwimage(yolo_annot_tensor).numpy()
pred_dets = tensor_to_kwimage(yolo_annot_tensor, classes=classes).numpy()
pred_dets = pred_dets.non_max_supress(thresh=0.3)
for ann in list(pred_dets.to_coco()):
...
for ann in list(pred_dets.to_coco(dset=dset)):
ann['image_id'] = image_id
dset.add_annotation(**ann)

img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list)

# TODO: handle prediction to kwcoco file.

if getattr(self.predict_loader, "is_stream", None):
fps = self._display_stream(img)
else:
Expand Down
16 changes: 14 additions & 2 deletions yolo/utils/kwcoco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""


def tensor_to_kwimage(yolo_annot_tensor):
def tensor_to_kwimage(yolo_annot_tensor, classes=None):
"""
Convert a raw output tensor to a kwimage Detections object

Expand All @@ -15,6 +15,9 @@ def tensor_to_kwimage(yolo_annot_tensor):
yolo_annot_tensor[:, 5] is the objectness confidence
Other columns are the per-class confidence

classes (kwcoco.CategoryTree):
...

Example:
yolo_annot_tensor = torch.rand(1, 6)
"""
Expand All @@ -23,10 +26,19 @@ def tensor_to_kwimage(yolo_annot_tensor):
boxes = kwimage.Boxes(yolo_annot_tensor[:, 1:5], format='xyxy')
dets = kwimage.Detections(
boxes=boxes,
class_idxs=class_idxs
class_idxs=class_idxs,
classes=classes,
)

if yolo_annot_tensor.shape[1] > 5:
scores = yolo_annot_tensor[:, 5]
dets.data['scores'] = scores

if classes is not None:
if hasattr(classes, 'idx_to_id'):
# Add class-id information if that is available
import torch
idx_to_id = torch.Tensor(classes.idx_to_id).int().to(class_idxs.device)
class_ids = idx_to_id[class_idxs]
dets.data['class_ids'] = class_ids
return dets
11 changes: 11 additions & 0 deletions yolo/utils/logging_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=Tru

save_path = validate_log_directory(cfg, cfg.name)

write_config(cfg, save_path)

progress, loggers = [], []

if hasattr(cfg.task, "ema") and cfg.task.ema.enable:
Expand All @@ -345,6 +347,15 @@ def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=Tru
return progress, loggers, save_path


@rank_zero_only
def write_config(cfg, save_path):
# Dump the config to the disk in the output folder
from omegaconf import OmegaConf
config_text = OmegaConf.to_yaml(cfg)
config_fpath = save_path / f'{cfg.task.task}_config.yaml'
config_fpath.write_text(config_text)


def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]):
if isinstance(model, YOLO):
model = model.model
Expand Down