Skip to content

Commit 940eb06

Browse files
authored
Merge pull request #13 from bladeszasza/bugfix/generate_per_object_videos
Bugfix/generate per object videos
2 parents de8d778 + caaaeee commit 940eb06

1 file changed

Lines changed: 78 additions & 19 deletions

File tree

sowlv2/video_utils.py

Lines changed: 78 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,12 @@
44
"""
55
import os
66
from glob import glob
7+
import re
78
import cv2 # pylint: disable=import-error
89
from PIL import Image
910
import numpy as np
1011

12+
1113
# Disable no-member for cv2 (OpenCV) for the whole file
1214
# pylint: disable=no-member
1315

@@ -65,31 +67,88 @@ def images_to_video(image_files, video_path, fps=30):
6567
video_writer.release()
6668
print(f"Saved video {video_path}")
6769

70+
def _parse_mask_filename(fname):
    """
    Parse a mask filename to extract sam_id_token and core_prompt_slug.

    Recognised forms (frame digits, object token, optional prompt slug):
        000001_obj1_dog_mask.png  -> ("obj1", "dog")
        000001_obj1_mask.png      -> ("obj1", None)

    Args:
        fname: Base filename to parse (no directory component).

    Returns:
        (sam_id_token, core_prompt_slug); the slug is None when the
        filename carries no prompt part. (None, None) if not matched.
    """
    # A single pattern with an optional slug group covers both the
    # prompted and the prompt-less form. fullmatch is used instead of
    # match + "$" because "$" would also accept a trailing newline.
    match = re.fullmatch(
        r"\d+_(obj\d+)(?:_([a-zA-Z0-9_]+))?_mask\.png", fname
    )
    if match is None:
        return None, None
    return match.group(1), match.group(2)
84+
85+
def _collect_unique_tracked_objects(mask_files):
    """
    Collect unique (sam_id_token, core_prompt_slug) pairs from mask filenames.

    Only the base name of each path is parsed. Filenames that do not parse
    are reported with a printed warning and otherwise ignored.

    Args:
        mask_files: Iterable of mask image paths.

    Returns:
        A dict keyed by (sam_id_token, core_prompt_slug); each value is a
        dict holding those same two fields under the keys
        "sam_id_token" and "core_prompt_slug".
    """
    tracked = {}
    for path in mask_files:
        base_name = os.path.basename(path)
        sam_id, prompt_slug = _parse_mask_filename(base_name)
        if sam_id is None:
            print(f"Warning: Filename {base_name} did not match expected pattern.")
            continue
        # setdefault keeps the entry created by the first file of each object.
        tracked.setdefault(
            (sam_id, prompt_slug),
            {"sam_id_token": sam_id, "core_prompt_slug": prompt_slug},
        )
    return tracked
104+
105+
def _get_obj_files(output_dir, sam_id_token, core_prompt_slug):
    """
    Get sorted mask and overlay files for a given object.

    Args:
        output_dir: Directory containing the per-frame PNG files.
        sam_id_token: Object token as it appears in filenames, e.g. "obj1".
        core_prompt_slug: Prompt slug from the filename, or None/"" when
            the files carry no prompt part.

    Returns:
        (mask_files, overlay_files, video_prefix): two lexicographically
        sorted lists of paths, and the "<sam_id>[_<slug>]" prefix shared by
        the matched files.
    """
    # Local import: the module only imports the ``glob`` function itself.
    from glob import escape as glob_escape  # pylint: disable=import-outside-toplevel

    if core_prompt_slug:
        video_prefix = f"{sam_id_token}_{core_prompt_slug}"
    else:
        video_prefix = sam_id_token

    # Escape the directory so characters such as "[" or "*" in its name are
    # matched literally instead of being interpreted as glob wildcards.
    safe_dir = glob_escape(os.fspath(output_dir))
    mask_files = sorted(
        glob(os.path.join(safe_dir, f"*_{video_prefix}_mask.png")))
    overlay_files = sorted(
        glob(os.path.join(safe_dir, f"*_{video_prefix}_overlay.png")))
    return mask_files, overlay_files, video_prefix
122+
68123
def generate_per_object_videos(output_dir, fps=30):
    """
    Generate per-object videos from mask and overlay images.

    Each object (identified by sam_id and core_prompt) will have its own
    video for masks and overlays, written into ``output_dir`` as
    ``<prefix>_mask_video.mp4`` and ``<prefix>_overlay_video.mp4``.

    Args:
        output_dir: Directory holding the ``*_mask.png`` / ``*_overlay.png``
            frames; the videos are written to the same directory.
        fps: Frame rate passed through to ``images_to_video``.

    Returns:
        None. Prints a message and returns early when no mask files are
        found or none of their names can be parsed.
    """
    # Local import: the module only imports the ``glob`` function itself.
    from glob import escape as glob_escape  # pylint: disable=import-outside-toplevel

    # Escape the directory so glob metacharacters in its name are literal.
    all_mask_files_pattern = os.path.join(
        glob_escape(os.fspath(output_dir)), "*_mask.png")
    all_mask_files = sorted(glob(all_mask_files_pattern))

    if not all_mask_files:
        print(f"No mask files found in {output_dir} matching pattern.")
        return

    unique_tracked_objects = _collect_unique_tracked_objects(all_mask_files)
    if not unique_tracked_objects:
        print(f"No objects successfully parsed from filenames in {output_dir}.")
        return

    # Keys are (sam_id_token, core_prompt_slug) and the slug may be None.
    # Sorting the raw tuples raises TypeError when one sam_id occurs both
    # with and without a slug (None vs str), so None is ordered as "".
    ordered_keys = sorted(
        unique_tracked_objects, key=lambda pair: (pair[0], pair[1] or ""))

    for key in ordered_keys:
        obj_info = unique_tracked_objects[key]
        sam_id_token = obj_info["sam_id_token"]
        core_prompt_slug = obj_info["core_prompt_slug"]

        mask_files, overlay_files, video_file_prefix = _get_obj_files(
            output_dir, sam_id_token, core_prompt_slug
        )

        mask_video_path = os.path.join(
            output_dir, f"{video_file_prefix}_mask_video.mp4")
        overlay_video_path = os.path.join(
            output_dir, f"{video_file_prefix}_overlay_video.mp4")

        images_to_video(mask_files, mask_video_path, fps)
        images_to_video(overlay_files, overlay_video_path, fps)

0 commit comments

Comments
 (0)