Commit de8d778

Merge pull request #12 from bladeszasza/chore/add-multi-color-segmentation-layers
Chore/add multi color segmentation layers
2 parents 34f0a2b + ef243f3 commit de8d778

8 files changed: 580 additions & 230 deletions

.github/workflows/pylint.yml

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.10"]
+        python-version: ["3.10", "3.11"]
     steps:
       - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}

README.md

Lines changed: 4 additions & 0 deletions
@@ -30,6 +30,8 @@ Given one or more text prompts (e.g., `"a red bicycle"`, or `"cat" "dog"`) and a
 2. Employ **SAM 2** to generate detailed segmentation masks for each detected object, leveraging techniques from the paper [SAM 2: Segment Anything in Images and Videos](https://arxiv.org/abs/2408.00714).
 3. Save both **binary segmentation masks** (foreground vs. background) and **overlay images** (original image with masks visually overlaid) to a specified output directory.

+![Multilabel output showcase](./assets/SOWLv2Multilabel.png "Multilabel Output Showcase")
+
 ## ✨ Key Features

 * **Text-Prompted Segmentation:** Identify and segment objects using free-form text descriptions.
@@ -130,6 +132,8 @@ The tool saves results in the specified output directory. For each detected obje

 Objects are numbered sequentially (e.g., `object0`, `object1`) in the order they are detected by OWLv2, regardless of which text prompt they matched. For video inputs, output filenames will also include frame identifiers, and separate videos for each object's masks and overlays will be generated (e.g., `obj0_mask_video.mp4`, `obj0_overlay_video.mp4`).

+SOWLv2 automatically assigns a unique color to each detected OWLv2 label, making it easy to visually distinguish different object classes in the output overlays and merged results.
+
 ### <a name="configuration"></a>Configuration File (Optional):

 You can use a YAML configuration file to specify arguments, which is useful for managing complex settings or reproducing experiments. The `prompt` field in the YAML file can also be a list of strings.
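
The per-label coloring described in this README hunk can be pictured with a minimal sketch. This is not SOWLv2's actual implementation: the palette values, the `label_colors` cache, and the `color_for_label` helper are all assumptions made for illustration.

```python
from itertools import cycle
from typing import Dict, Tuple

# Hypothetical palette; the colors SOWLv2 actually uses may differ.
_PALETTE = cycle([
    (230, 25, 75),   # red
    (60, 180, 75),   # green
    (0, 130, 200),   # blue
    (255, 225, 25),  # yellow
    (145, 30, 180),  # purple
])

# One color per OWLv2 label, reused for every detection of that label.
label_colors: Dict[str, Tuple[int, int, int]] = {}

def color_for_label(label: str) -> Tuple[int, int, int]:
    """Return a stable RGB color for a label, assigning a new one on first use."""
    if label not in label_colors:
        label_colors[label] = next(_PALETTE)
    return label_colors[label]

print(color_for_label("cat"))  # (230, 25, 75)
print(color_for_label("dog"))  # (60, 180, 75)
print(color_for_label("cat"))  # (230, 25, 75) -- same label, same color
```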

assets/SOWLv2Multilabel.png

13.1 MB

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "sowlv2"
-version = "1.0.0"
+version = "0.2.0"
 authors = [
   { name="Csaba Bolyos", email="bladeszasza@gmail.com" },
 ]
@@ -35,4 +35,4 @@ Homepage = "https://github.com/bladeszasza/SOWLv2"
 Issues = "https://github.com/bladeszasza/SOWLv2/issues"

 [project.scripts]
-sowlv2-detect = "sowlv2.cli:main"
+sowlv2-detect = "sowlv2.cli:main"

setup.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@

 setup(
     name="sowlv2",
-    version="1.0.0",
+    version="0.2.0",
     description="SOWLv2: Text-prompted object segmentation using OWLv2 and SAM 2",
     author="Bolyos Csaba",
     author_email="bladeszasza@gmail.com",

sowlv2/data/config.py

Lines changed: 125 additions & 2 deletions
@@ -1,9 +1,10 @@
 """
 Dataclasses for configuring the SOWLv2 object detection and segmentation pipeline.
 """
-
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Tuple, List, Dict
+import numpy as np
+from PIL import Image

 @dataclass
 class PipelineBaseData:
@@ -20,9 +21,131 @@ class PipelineBaseData:
 class SaveMaskOverlayConfig:
     """
     Configuration for saving masks and overlays for a frame.
+
+    Attributes:
+        pil_img (Any): The PIL image for the frame.
+        frame_idx (int): The frame index.
+        obj_ids (Any): Object IDs for the masks.
+        masks (Any): Masks for the objects.
+        out_dir (str): Output directory for saving results.
     """
     pil_img: Any
     frame_idx: int
     obj_ids: Any
     masks: Any
     out_dir: str
+
+@dataclass
+class MaskObject:
+    """
+    Stores the mask and its properties for a detected object.
+
+    Attributes:
+        mask_np (np.ndarray): The mask as a NumPy array.
+        mask_img_pil (Image.Image): The mask as a PIL image.
+        mask_file (str): Path to the saved mask file.
+    """
+    mask_np: np.ndarray
+    mask_img_pil: Image.Image
+    mask_file: str
+
+@dataclass
+class DetectionResult:
+    """
+    Stores the result of a single detection and segmentation.
+
+    Attributes:
+        box (Any): The bounding box for the detected object.
+        core_prompt (str): The core prompt/label for the object.
+        object_color (Tuple[int, int, int]): The assigned color for the object.
+        mask (MaskObject): The mask data for the object, bundling
+            mask_np (np.ndarray), mask_img_pil (Image.Image), and
+            mask_file (str) as described in MaskObject.
+        individual_overlay_pil (Image.Image): The overlay as a PIL image.
+        overlay_file (str): Path to the saved overlay file.
+    """
+    box: Any
+    core_prompt: str
+    object_color: Tuple[int, int, int]
+    mask: MaskObject
+    individual_overlay_pil: Image.Image
+    overlay_file: str
+
+@dataclass
+class MergedOverlayItem:
+    """
+    Stores information for a single mask/color/label used in a merged overlay.
+
+    Attributes:
+        mask (np.ndarray): Boolean mask for the object.
+        color (Tuple[int, int, int]): Color assigned to the object.
+        label (str): The core prompt/label for the object.
+    """
+    mask: np.ndarray
+    color: Tuple[int, int, int]
+    label: str
+
+@dataclass
+class VideoDetectionDetail:
+    """
+    Stores details for each detected object in a video, including its SAM object ID,
+    the core prompt/label, and the assigned color for consistent visualization.
+
+    Attributes:
+        sam_id (int): The unique SAM object ID assigned for tracking in the video.
+        core_prompt (str): The core prompt/label for the detected object.
+        color (Tuple[int, int, int]): The RGB color assigned to this object for overlays.
+    """
+    sam_id: int
+    core_prompt: str
+    color: Tuple[int, int, int]
+
+@dataclass
+class PropagatedFrameOutput:
+    """
+    Stores all relevant data for a single propagated frame in video segmentation.
+
+    Attributes:
+        current_pil_img (Image.Image): The current frame as a PIL image.
+        frame_num (int): The frame number (1-based).
+        sam_obj_ids_tensor (Any): Tensor or list of SAM object IDs for this frame.
+        mask_logits_tensor (Any): Tensor of mask logits for this frame.
+        detection_details_map (List[Dict[str, Any]]): List of detection details for mapping IDs.
+        output_dir (str): Output directory for saving results.
+    """
+    current_pil_img: Image.Image
+    frame_num: int
+    sam_obj_ids_tensor: Any
+    mask_logits_tensor: Any
+    detection_details_map: List[Dict[str, Any]]
+    output_dir: str
+
+@dataclass
+class SingleDetectionInput:
+    """
+    Stores all relevant data for processing a single detection in an image.
+
+    Attributes:
+        pil_image (Image.Image): The PIL image being processed.
+        detection_detail (dict): The detection dictionary for the object.
+        obj_idx (int): The index of the object in the detections list.
+        base_name (str): The base name of the input image file.
+        output_dir (str): Output directory for saving results.
+    """
+    pil_image: Image.Image
+    detection_detail: dict
+    obj_idx: int
+    base_name: str
+    output_dir: str
+
+@dataclass
+class VideoProcessContext:
+    """
+    Holds all context/state for processing a video in SOWLv2Pipeline.
+    """
+    tmp_frames_dir: str
+    initial_sam_state: Any
+    first_img_path: str
+    first_pil_img: Image.Image
+    detection_details_for_video: List[Dict[str, Any]]
+    updated_sam_state: Any
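
To see how the new dataclasses fit together, here is a small construction sketch. The box coordinates, file paths, and mask contents are invented for illustration; only the class and field names come from the diff above.

```python
import numpy as np
from PIL import Image

from sowlv2.data.config import DetectionResult, MaskObject, MergedOverlayItem

# Invented 640x480 example mask with one rectangular foreground region.
mask_np = np.zeros((480, 640), dtype=bool)
mask_np[100:200, 150:300] = True

mask = MaskObject(
    mask_np=mask_np,
    mask_img_pil=Image.fromarray(mask_np.astype(np.uint8) * 255),
    mask_file="output/object0_cat_mask.png",  # hypothetical path
)

detection = DetectionResult(
    box=[150.0, 100.0, 300.0, 200.0],  # illustrative xyxy pixel coordinates
    core_prompt="cat",
    object_color=(230, 25, 75),
    mask=mask,
    individual_overlay_pil=Image.new("RGB", (640, 480)),
    overlay_file="output/object0_cat_overlay.png",  # hypothetical path
)

# A merged multi-color overlay is assembled from items like this one.
merged_item = MergedOverlayItem(
    mask=detection.mask.mask_np,
    color=detection.object_color,
    label=detection.core_prompt,
)
```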

sowlv2/owl.py

Lines changed: 112 additions & 42 deletions
@@ -1,69 +1,139 @@
 """
 Wrapper for OWLv2 text-conditioned object detection models from HuggingFace Transformers.
 """
-from typing import Union, List
+from typing import Union, List, Dict, Any
 from transformers import Owlv2Processor, Owlv2ForObjectDetection
 import torch

-class OWLV2Wrapper: # pylint: disable=too-few-public-methods
-    """Wrapper for OWLv2 text-conditioned object detection."""
-    def __init__(self, model_name="google/owlv2-base-patch16-ensemble", device="cpu"):
+# It's a focused wrapper, so R0903 (too-few-public-methods) might be flagged
+# but is acceptable for this type of class. We can add the disable if Pylint complains.
+# pylint: disable=R0903
+
+class OWLV2Wrapper:
+    """
+    Wrapper for OWLv2 text-conditioned object detection.
+
+    This class handles the loading of OWLv2 models and processors,
+    and provides a method to detect objects based on text prompts.
+    It formats the output to include both the full label matched by OWLv2
+    and the original "core" prompt term provided by the user.
+    """
+    def __init__(self, model_name: str = "google/owlv2-base-patch16-ensemble", device: str = "cpu"):
+        """
+        Initialize the OWLV2Wrapper.
+
+        Args:
+            model_name (str): The Hugging Face model identifier for OWLv2.
+            device (str): The device to run the model on (e.g., "cpu", "cuda").
+        """
         self.device = device
         self.processor: Owlv2Processor = Owlv2Processor.from_pretrained(model_name)
-        self.model = Owlv2ForObjectDetection.from_pretrained(model_name).to(device)
+        self.model: Owlv2ForObjectDetection = Owlv2ForObjectDetection.from_pretrained(
+            model_name).to(
+            self.device
+        )

-    def detect(self, *, image, prompt: Union[str, List[str]], threshold=0.1):
+    def detect(self, *, image: Any, prompt: Union[str, List[str]], threshold: float = 0.1
+               ) -> List[Dict[str, Any]]:
         """
-        Detect objects in the image matching the text prompt.
-        Returns a list of dict with keys: box, score, label.
+        Detect objects in the image matching the text prompt(s).
+
+        Args:
+            image (Any): The input image (e.g., a PIL Image).
+            prompt (Union[str, List[str]]): A single text prompt or a list of text prompts.
+            threshold (float): The confidence threshold for detections.
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries, where each dictionary
+            represents a detected object and contains 'box', 'score', 'label'
+            (the full text matched by OWLv2), and 'core_prompt' (the original
+            user-provided term that led to this detection).
         """
         if isinstance(prompt, str):
-            processed_prompts = [f"a photo of {prompt}"]
-        else:  # prompt is a list of strings
-            processed_prompts = [f"a photo of {p}" for p in prompt]
+            original_prompt_terms: List[str] = [prompt]
+        else:
+            original_prompt_terms: List[str] = prompt

-        text_labels = [processed_prompts]  # Batch size of 1, with potentially multiple queries
+        # OWLv2 typically expects prompts like "a photo of <object>"
+        processed_prompts_for_owl: List[str] = [
+            f"a photo of {p}" for p in original_prompt_terms
+        ]
+        # The 'text' argument to the processor for multiple queries on a single image
+        # should be List[List[str]], where the outer list is for batch items.
+        text_labels_for_owl: List[List[str]] = [processed_prompts_for_owl]

         inputs = self.processor(
-            text=text_labels, images=image, return_tensors="pt"
+            text=text_labels_for_owl, images=image, return_tensors="pt"
         ).to(self.device)

         with torch.no_grad():
             outputs = self.model(**inputs)

-        target_sizes = torch.tensor([(image.height, image.width)]).to(self.device)
+        # target_sizes should be a tensor of shape (batch_size, 2)
+        target_sizes = torch.tensor([image.size[::-1]], device=self.device)
+
+        # Pass text_labels_for_owl to post_process for correct label association
         results = self.processor.post_process_grounded_object_detection(
-            outputs=outputs, target_sizes=target_sizes,
-            threshold=threshold, text_labels=text_labels
+            outputs=outputs,
+            target_sizes=target_sizes,
+            threshold=threshold,
+            text_labels=text_labels_for_owl
         )

-        # Determine the list of original prompt terms for fallback in _format_detections
-        if isinstance(prompt, str):
-            fallback_labels = [prompt]
-        else:
-            fallback_labels = prompt
-
-        return self._format_detections(results, fallback_labels)
+        return self._format_detections(results, original_prompt_terms)

-    def _format_detections(self, results, fallback_prompts: List[str]):
+    def _format_detections(self, results: List[Dict[str, Any]],
+                           original_prompt_terms: List[str]) -> List[Dict[str, Any]]:
         """
-        Helper to format detection results into a list of dicts.
-        fallback_prompts: The list of original prompt terms used for searching.
+        Helper to format raw detection results into a structured list of dictionaries.
+
+        Args:
+            results (List[Dict[str, Any]]): Raw results from the OWLv2 processor's
+                post_process_grounded_object_detection method.
+            original_prompt_terms (List[str]): The list of original, user-provided
+                prompt terms (e.g., ["cat", "dog"]).
+
+        Returns:
+            List[Dict[str, Any]]: Formatted list of detections.
         """
-        detections = []
-        if results and results[0]:
-            result = results[0]
-            boxes = result["boxes"].cpu().numpy()
-            scores = result["scores"].cpu().numpy()
-            # 'text_labels' in result should be populated by post_process_object_detection
-            # with the specific query that matched each box (e.g., "a photo of cat").
-            # The processor.post_process_grounded_object_detection
-            # returns the text_labels as they were passed in.
-            returned_labels = result.get("text_labels", fallback_prompts * len(boxes))  # Fallback
-            for box, score, label_text in zip(boxes, scores, returned_labels):
-                detections.append({
-                    "box": [float(coord) for coord in box],
-                    "score": float(score),
-                    "label": label_text
-                })
+        detections: List[Dict[str, Any]] = []
+        if not results or not results[0]:
+            return detections
+
+        # Results is a list (batch); we typically process one image at a time here.
+        first_image_results = results[0]
+        boxes = first_image_results["boxes"].cpu().numpy()
+        scores = first_image_results["scores"].cpu().numpy()
+
+        # 'labels' are integer indices into the list of queries *for the current image*
+        # that were passed to post_process_grounded_object_detection
+        # (i.e., text_labels_for_owl[0]).
+        prompt_indices = first_image_results.get(
+            "labels", torch.zeros(len(boxes), dtype=torch.long)
+        ).cpu().numpy()
+
+        # 'text_labels' from results should be the actual prompt strings that matched.
+        owl_matched_full_labels = first_image_results.get("text_labels", [])
+
+        for i, (current_box, current_score) in enumerate(zip(boxes, scores)):
+            try:
+                core_prompt = original_prompt_terms[prompt_indices[i]]
+            except IndexError:
+                core_prompt = "unknown_prompt_term"
+                print(
+                    f"Warning: Index {prompt_indices[i]} out of bounds"
+                )
+
+            full_owl_label = (
+                owl_matched_full_labels[i]
+                if i < len(owl_matched_full_labels)
+                else f"a photo of {core_prompt}"
+            )
+
+            detections.append({
+                "box": [float(coord) for coord in current_box],
+                "score": float(current_score),
+                "label": full_owl_label,
+                "core_prompt": core_prompt
+            })
         return detections
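
A minimal usage sketch for the reworked wrapper, assuming a local test image; the image path and prompt terms are placeholders. Note that `detect` takes keyword-only arguments and now returns both `label` and `core_prompt` per detection:

```python
from PIL import Image

from sowlv2.owl import OWLV2Wrapper

detector = OWLV2Wrapper(device="cpu")  # or "cuda" if available

# Placeholder image path; any RGB image works.
image = Image.open("samples/cats_and_dogs.jpg").convert("RGB")

detections = detector.detect(image=image, prompt=["cat", "dog"], threshold=0.1)

for det in detections:
    # 'label' is the full matched query (e.g., "a photo of cat");
    # 'core_prompt' is the original user term (e.g., "cat").
    print(det["core_prompt"], det["label"], round(det["score"], 3), det["box"])
```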
