Commit 2faec6d

Merge branch 'fabiofelix:main' into main
2 parents: 5ca9ad6 + ac9b358

File tree

3 files changed: +109 -2 lines

README.md

Lines changed: 3 additions & 1 deletion
@@ -5,7 +5,9 @@ This is the code for training and evaluation of the perception models built on t
 It can process videos and predict task (skill) steps such as the ones related to [tactical field care](https://www.ncbi.nlm.nih.gov/books/NBK532260/).
 
 > [!NOTE]
-> These are the used skills: Trauma Assessment (M1), Apply tourniquet (M2), Pressure Dressing (M3), X-Stat (M5), and Apply Chest seal (R18)
+> These are the used skills:
+> (June/2024 demo) Apply tourniquet (M2), Pressure Dressing (M3), X-Stat (M5), and Apply Chest seal (R18)
+> (December/2024 demo) Nasopharyngeal Airway (NPA) (A8), Wound Packing (M4), Ventilate (BVM) (R16), Needle Chest Decompression (R19)
 
 ## **Install**
 
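The skill steps listed in this note are consumed downstream as cfg.SKILLS by the new loader in step_recog/datasets/milly.py (below). A minimal sketch of that mapping, assuming cfg.SKILLS is a list of dicts carrying a STEPS key; the skill and step names here are illustrative, not taken from the repo config:

# Hypothetical config fragment; only the 'STEPS' key is read by the new loader.
SKILLS = [
  {"NAME": "Apply tourniquet (M2)",
   "STEPS": ["place tourniquet above wound", "tighten windlass", "secure and record time"]},
]

# Same index mapping Milly_multifeature_v5 builds, plus a trailing background class.
label2idx = {step: idx for skill in SKILLS for idx, step in enumerate(skill["STEPS"])}
label2idx["No step"] = len(label2idx)
print(label2idx)
# {'place tourniquet above wound': 0, 'tighten windlass': 1,
#  'secure and record time': 2, 'No step': 3}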

step_recog/datasets/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-from .milly import collate_fn, Milly_multifeature_v4
+from .milly import collate_fn, Milly_multifeature_v4, Milly_multifeature_v5
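With the new export, the v5 dataset can be constructed like v4. A usage sketch, assuming the Milly_multifeature base constructor takes a config object and a split name (its signature is not part of this diff) and that each datapoint is a whole video:

from torch.utils.data import DataLoader
from step_recog.datasets import collate_fn, Milly_multifeature_v5

# cfg is a loaded experiment config; the (cfg, split) signature is an assumption.
dataset = Milly_multifeature_v5(cfg, split="train")
loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)  # one video per item

for video_act, video_obj, video_frame, video_sound, steps, positions, stops, video_id in loader:
  print(video_id, steps.shape)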

step_recog/datasets/milly.py

Lines changed: 105 additions & 0 deletions
@@ -726,3 +726,108 @@ def __getitem__(self, index):
 
 
     return video_act, video_obj, video_frame, video_sound, window_step_label, window_position_label, window_stop_frame, video_id
+
+class Milly_multifeature_v5(Milly_multifeature):
+  def _construct_loader(self, split):
+    video_ids = []
+
+    for v in ["alabama", "bbnlabs"]:
+      video_ids.extend( glob.glob(os.path.join(self.cfg.DATASET.LOCATION, v, "*")) )
+
+    if self.data_filter is None:
+      annotations = pd.read_csv(self.annotations_file, usecols=['video_id','start_frame','stop_frame','narration','verb_class','video_fps'])
+      video_ids = [f for f in video_ids if os.path.basename(f) in annotations.video_id.unique()]
+    else:
+      video_ids = [f for f in video_ids if os.path.basename(f) in self.data_filter]
+
+    self.datapoints = {}
+    self.class_histogram = [0] * (self.cfg.MODEL.OUTPUT_DIM + 1)
+    ipoint = 0
+    total_window = 0
+    video_ids = sorted(video_ids)
+
+    if split == "train":
+      self.rng.shuffle(video_ids)
+
+    win_size_sec = [1, 2, 4] if self.time_augs else [2]
+    hop_size_perc = [0.125, 0.25, 0.5] if self.time_augs else [0.5]
+
+    progress = tqdm.tqdm(video_ids, total=len(video_ids), desc = "Video")
+
+    label2idx = {step: idx for skill in self.cfg.SKILLS for idx, step in enumerate(skill['STEPS'])}
+    label2idx["No step"] = len(label2idx)
+
+    for v in video_ids:
+      progress.update(1)
+
+      embeddings = glob.glob(os.path.join(v, "features", "*.npy"))
+      embeddings = np.load(embeddings[0], allow_pickle=True)
+      video_frames = [f[-1] for f in embeddings[()]["frames"]]
+
+      label = np.load(os.path.join(v, "normalized_frame_labels.npy"), allow_pickle=True)
+      label = label[video_frames]
+
+      win_size = self.rng.integers(len(win_size_sec))
+      hop_size = self.rng.integers(len(hop_size_perc))
+
+      for l in label:
+        self.class_histogram[label2idx[l]] += 1
+
+      self.datapoints[ipoint] = {
+        'video_id': v,
+        'win_size': win_size_sec[win_size],
+        'hop_size': int(hop_size_perc[hop_size] * win_size_sec[win_size]),
+        'label': [ label2idx[l] for l in label ]
+      }
+      ipoint += 1
+      total_window += label.shape[0]
+      progress.set_postfix({"window total": total_window, "padded videos": 0})
+
+  def __getitem__(self, index):
+    video = self.datapoints[index]
+    video_obj = []
+    video_frame = []
+    video_act = []
+    video_sound = []
+    window_step_label = video["label"]
+    window_position_label = [[0, 0]] * len(window_step_label)
+    window_stop_frame = []
+
+    apply_img_aug = False
+    img_aug_idx = None
+
+    if self.image_augs and self.rng.choice([True, False], p = [self.cfg.DATASET.IMAGE_AUGMENTATION_PERCENTAGE, 1.0 - self.cfg.DATASET.IMAGE_AUGMENTATION_PERCENTAGE]):
+      apply_img_aug = True
+
+    if self.cfg.MODEL.USE_OBJECTS:
+      image_embeddings = glob.glob(os.path.join(video["video_id"], "features", "*_{}_{}.npy".format(video["win_size"], video["hop_size"])))
+      # obj_embeddings, frame_embeddings
+
+    if self.cfg.MODEL.USE_ACTION:
+      action_embeddings = glob.glob(os.path.join(video["video_id"], "features", "*_{}_{}.npy".format(video["win_size"], video["hop_size"])))
+
+      if apply_img_aug:
+        action_embeddings = [act for act in action_embeddings if "original" not in act]
+        img_aug_idx = self.rng.integers(len(action_embeddings)) if img_aug_idx is None else img_aug_idx
+        action_embeddings = action_embeddings[img_aug_idx]
+      else:
+        action_embeddings = [act for act in action_embeddings if "original" in act]
+        action_embeddings = action_embeddings[0]
+
+      action_embeddings = np.load(action_embeddings, allow_pickle=True)
+      video_act.extend(action_embeddings[()]["embeddings"])
+
+    if self.cfg.MODEL.USE_AUDIO:
+      audio_embeddings = glob.glob(os.path.join(video["video_id"], "features", "*_{}_{}.npy".format(video["win_size"], video["hop_size"])))
+
+    video_obj = torch.from_numpy(np.array(video_obj))
+    video_frame = torch.from_numpy(np.array(video_frame))
+    video_act = torch.from_numpy(np.array(video_act))
+    video_sound = torch.from_numpy(np.array(video_sound))
+    window_step_label = torch.from_numpy(np.array(window_step_label))
+    window_position_label = torch.from_numpy(np.array(window_position_label))
+    window_stop_frame = torch.from_numpy(np.array(window_stop_frame))
+    video_id = np.array([ os.path.basename(video["video_id"]) ])
+
+    return video_act, video_obj, video_frame, video_sound, window_step_label, window_position_label, window_stop_frame, video_id
+
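For reference, the feature files that __getitem__ globs are keyed by the window/hop pair chosen in _construct_loader. A worked example with the defaults (time_augs off); the <video_id> stem is a placeholder:

import os

win_size_sec  = 2                              # default window when time_augs is off
hop_size_perc = 0.5                            # default hop fraction
hop_size = int(hop_size_perc * win_size_sec)   # int(0.5 * 2) -> 1
pattern = "*_{}_{}.npy".format(win_size_sec, hop_size)
print(os.path.join("<video_id>", "features", pattern))  # <video_id>/features/*_2_1.npy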
