|
4 | 4 | """ |
5 | 5 | import os |
6 | 6 | from glob import glob |
| 7 | +import re |
7 | 8 | import cv2 # pylint: disable=import-error |
8 | 9 | from PIL import Image |
9 | 10 | import numpy as np |
10 | 11 |
|
| 12 | + |
11 | 13 | # Disable no-member for cv2 (OpenCV) for the whole file |
12 | 14 | # pylint: disable=no-member |
13 | 15 |
|
@@ -65,31 +67,88 @@ def images_to_video(image_files, video_path, fps=30): |
65 | 67 | video_writer.release() |
66 | 68 | print(f"Saved video {video_path}") |
67 | 69 |
|
def _parse_mask_filename(fname):
    """
    Split a mask filename into its object token and prompt slug.

    Two layouts are recognised:
      * ``<digits>_obj<N>_<slug>_mask.png``  -> ("obj<N>", "<slug>")
      * ``<digits>_obj<N>_mask.png``         -> ("obj<N>", None)

    Any other name yields (None, None).
    """
    # Layout with a prompt slug, e.g. 000001_obj1_dog_mask.png
    with_prompt = re.match(r"^\d+_(obj\d+)_([a-zA-Z0-9_]+)_mask\.png$", fname)
    if with_prompt is not None:
        sam_id_token, core_prompt_slug = with_prompt.groups()
        return sam_id_token, core_prompt_slug

    # Layout without a prompt slug, e.g. 000001_obj1_mask.png
    without_prompt = re.match(r"^\d+_(obj\d+)_mask\.png$", fname)
    if without_prompt is None:
        return None, None
    return without_prompt.group(1), None
| 84 | + |
def _collect_unique_tracked_objects(mask_files):
    """
    Build a registry of the distinct objects named by a list of mask paths.

    Each path's basename is parsed with _parse_mask_filename. The result maps
    (sam_id_token, core_prompt_slug) to a dict holding those same two values
    under the keys "sam_id_token" and "core_prompt_slug". Names that do not
    parse are reported with a printed warning and left out.
    """
    tracked = {}
    for path in mask_files:
        base_name = os.path.basename(path)
        sam_id_token, core_prompt_slug = _parse_mask_filename(base_name)

        if sam_id_token is None:
            print(f"Warning: Filename {base_name} did not match expected pattern.")
            continue

        # setdefault keeps the first entry seen for a given object key.
        tracked.setdefault(
            (sam_id_token, core_prompt_slug),
            {
                "sam_id_token": sam_id_token,
                "core_prompt_slug": core_prompt_slug,
            },
        )
    return tracked
| 104 | + |
def _get_obj_files(output_dir, sam_id_token, core_prompt_slug):
    """
    Get sorted mask and overlay files for a given object.

    Args:
        output_dir: Directory to search. It is treated as a literal path:
            glob metacharacters in it (``[``, ``]``, ``*``, ``?``) are
            escaped so they cannot change which files match.
        sam_id_token: Object token as it appears in filenames, e.g. "obj1".
        core_prompt_slug: Prompt slug from the filename, or a falsy value
            (None or "") when the filenames carry no prompt part.

    Returns:
        A ``(mask_files, overlay_files, video_prefix)`` tuple: two
        lexicographically sorted lists of matching paths, and the
        ``<token>`` or ``<token>_<slug>`` stem used to name the videos.
    """
    # Imported locally because the module only pulls `glob` out of the glob
    # module at file level.
    from glob import escape  # pylint: disable=import-outside-toplevel

    if core_prompt_slug:
        video_prefix = f"{sam_id_token}_{core_prompt_slug}"
    else:
        video_prefix = sam_id_token

    # Only the leading "*" (the frame-number part) is meant to be a wildcard;
    # everything else in the pattern must match literally.
    literal_dir = escape(output_dir)
    literal_stem = escape(video_prefix)

    mask_pattern = os.path.join(literal_dir, f"*_{literal_stem}_mask.png")
    overlay_pattern = os.path.join(literal_dir, f"*_{literal_stem}_overlay.png")

    mask_files = sorted(glob(mask_pattern))
    overlay_files = sorted(glob(overlay_pattern))
    return mask_files, overlay_files, video_prefix
| 122 | + |
def generate_per_object_videos(output_dir, fps=30):
    """
    Generate per-object videos from mask and overlay images.
    Each object (identified by sam_id and core_prompt) will have its own
    video for masks and overlays.

    Args:
        output_dir: Directory searched for ``*_mask.png`` frames; the video
            paths handed to images_to_video are placed in this same directory.
        fps: Frame rate forwarded to images_to_video.

    Returns:
        None. A message is printed and the function returns early when no
        mask files are found or none of their names can be parsed.
    """
    all_mask_files_pattern = os.path.join(output_dir, "*_mask.png")
    all_mask_files = sorted(glob(all_mask_files_pattern))

    if not all_mask_files:
        print(f"No mask files found in {output_dir} matching pattern.")
        return

    unique_tracked_objects = _collect_unique_tracked_objects(all_mask_files)
    if not unique_tracked_objects:
        print(f"No objects successfully parsed from filenames in {output_dir}.")
        return

    def _object_order(key):
        # Keys are (sam_id_token, core_prompt_slug) and the slug may be None.
        # Comparing None with str raises TypeError in Python 3, so the slug is
        # mapped to a sortable form, with prompt-less entries ordered first.
        sam_id_token, core_prompt_slug = key
        return (sam_id_token, core_prompt_slug is not None, core_prompt_slug or "")

    # Sorted iteration keeps the processing order deterministic.
    for key in sorted(unique_tracked_objects, key=_object_order):
        obj_info = unique_tracked_objects[key]
        sam_id_token = obj_info["sam_id_token"]
        core_prompt_slug = obj_info["core_prompt_slug"]

        mask_files, overlay_files, video_file_prefix = _get_obj_files(
            output_dir, sam_id_token, core_prompt_slug
        )

        mask_video_path = os.path.join(output_dir, f"{video_file_prefix}_mask_video.mp4")
        overlay_video_path = os.path.join(output_dir, f"{video_file_prefix}_overlay_video.mp4")

        images_to_video(mask_files, mask_video_path, fps)
        images_to_video(overlay_files, overlay_video_path, fps)
0 commit comments