 """
 Wrapper for OWLv2 text-conditioned object detection models from HuggingFace Transformers.
 """
-from typing import Union, List
+from typing import Union, List, Dict, Any
 from transformers import Owlv2Processor, Owlv2ForObjectDetection
 import torch
 
-class OWLV2Wrapper:  # pylint: disable=too-few-public-methods
-    """Wrapper for OWLv2 text-conditioned object detection."""
-    def __init__(self, model_name="google/owlv2-base-patch16-ensemble", device="cpu"):
+# This is a deliberately focused wrapper with a single public method, so
+# Pylint's R0903 (too-few-public-methods) is expected and acceptable here.
+# pylint: disable=R0903
+
+class OWLV2Wrapper:
+    """
+    Wrapper for OWLv2 text-conditioned object detection.
+
+    This class handles the loading of OWLv2 models and processors,
+    and provides a method to detect objects based on text prompts.
+    It formats the output to include both the full label matched by OWLv2
+    and the original "core" prompt term provided by the user.
+    """
+    def __init__(self, model_name: str = "google/owlv2-base-patch16-ensemble",
+                 device: str = "cpu"):
+        """
+        Initialize the OWLV2Wrapper.
+
+        Args:
+            model_name (str): The Hugging Face model identifier for OWLv2.
+            device (str): The device to run the model on (e.g., "cpu", "cuda").
+        """
         self.device = device
         self.processor: Owlv2Processor = Owlv2Processor.from_pretrained(model_name)
-        self.model = Owlv2ForObjectDetection.from_pretrained(model_name).to(device)
+        self.model: Owlv2ForObjectDetection = Owlv2ForObjectDetection.from_pretrained(
+            model_name
+        ).to(self.device)
 
-    def detect(self, *, image, prompt: Union[str, List[str]], threshold=0.1):
+    def detect(self, *, image: Any, prompt: Union[str, List[str]],
+               threshold: float = 0.1) -> List[Dict[str, Any]]:
         """
-        Detect objects in the image matching the text prompt.
-        Returns a list of dict with keys: box, score, label.
+        Detect objects in the image matching the text prompt(s).
+
+        Args:
+            image (Any): The input image (e.g., a PIL Image).
+            prompt (Union[str, List[str]]): A single text prompt or a list of text prompts.
+            threshold (float): The confidence threshold for detections.
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries, where each dictionary
+            represents a detected object and contains 'box', 'score', 'label'
+            (the full text matched by OWLv2), and 'core_prompt' (the original
+            user-provided term that led to this detection).
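+
+            Example element (illustrative values):
+                {"box": [x_min, y_min, x_max, y_max], "score": 0.31,
+                 "label": "a photo of cat", "core_prompt": "cat"}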
19 | 51 | """ |
20 | 52 | if isinstance(prompt, str): |
21 | | - processed_prompts = [f"a photo of {prompt}"] |
22 | | - else: # prompt is a list of strings |
23 | | - processed_prompts = [f"a photo of {p}" for p in prompt] |
| 53 | + original_prompt_terms: List[str] = [prompt] |
| 54 | + else: |
| 55 | + original_prompt_terms: List[str] = prompt |
24 | 56 |
-        text_labels = [processed_prompts]  # Batch size of 1, with potentially multiple queries
+        # OWLv2 typically expects prompts phrased like "a photo of <object>".
+        processed_prompts_for_owl: List[str] = [
+            f"a photo of {p}" for p in original_prompt_terms
+        ]
+        # The 'text' argument to the processor for multiple queries on a single
+        # image should be List[List[str]]; the outer list is the batch dimension.
+        text_labels_for_owl: List[List[str]] = [processed_prompts_for_owl]
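+        # e.g. prompt=["cat", "dog"] becomes [["a photo of cat", "a photo of dog"]]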
 
         inputs = self.processor(
-            text=text_labels, images=image, return_tensors="pt"
+            text=text_labels_for_owl, images=image, return_tensors="pt"
         ).to(self.device)
 
         with torch.no_grad():
             outputs = self.model(**inputs)
 
-        target_sizes = torch.tensor([(image.height, image.width)]).to(self.device)
+        # target_sizes must have shape (batch_size, 2), ordered (height, width).
+        target_sizes = torch.tensor([image.size[::-1]], device=self.device)
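+        # e.g. a 1024x768 (width x height) PIL image gives tensor([[768, 1024]])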
+
+        # Pass text_labels_for_owl to the post-processor for correct label association.
         results = self.processor.post_process_grounded_object_detection(
-            outputs=outputs, target_sizes=target_sizes,
-            threshold=threshold, text_labels=text_labels
+            outputs=outputs,
+            target_sizes=target_sizes,
+            threshold=threshold,
+            text_labels=text_labels_for_owl
         )
 
-        # Determine the list of original prompt terms for fallback in _format_detections
-        if isinstance(prompt, str):
-            fallback_labels = [prompt]
-        else:
-            fallback_labels = prompt
-
-        return self._format_detections(results, fallback_labels)
+        return self._format_detections(results, original_prompt_terms)
 
-    def _format_detections(self, results, fallback_prompts: List[str]):
+    def _format_detections(self, results: List[Dict[str, Any]],
+                           original_prompt_terms: List[str]) -> List[Dict[str, Any]]:
         """
-        Helper to format detection results into a list of dicts.
-        fallback_prompts: The list of original prompt terms used for searching.
+        Helper to format raw detection results into a structured list of dictionaries.
+
+        Args:
+            results (List[Dict[str, Any]]): Raw results from the OWLv2 processor's
+                post_process_grounded_object_detection method.
+            original_prompt_terms (List[str]): The list of original, user-provided
+                prompt terms (e.g., ["cat", "dog"]).
+
+        Returns:
+            List[Dict[str, Any]]: Formatted list of detections.
         """
-        detections = []
-        if results and results[0]:
-            result = results[0]
-            boxes = result["boxes"].cpu().numpy()
-            scores = result["scores"].cpu().numpy()
-            # 'text_labels' in result should be populated by post_process_object_detection
-            # with the specific query that matched each box (e.g., "a photo of cat").
-            # The processor.post_process_grounded_object_detection
-            # returns the text_labels as they were passed in.
-            returned_labels = result.get("text_labels", fallback_prompts * len(boxes))  # Fallback
-            for box, score, label_text in zip(boxes, scores, returned_labels):
-                detections.append({
-                    "box": [float(coord) for coord in box],
-                    "score": float(score),
-                    "label": label_text
-                })
+        detections: List[Dict[str, Any]] = []
+        if not results or not results[0]:
+            return detections
+
+        # results is a batch-indexed list; this wrapper processes one image at a time.
+        first_image_results = results[0]
+        boxes = first_image_results["boxes"].cpu().numpy()
+        scores = first_image_results["scores"].cpu().numpy()
+
+        # 'labels' holds integer indices into the list of queries for the current
+        # image, i.e. into text_labels_for_owl[0] as passed to
+        # post_process_grounded_object_detection.
+        prompt_indices = first_image_results.get(
+            "labels", torch.zeros(len(boxes), dtype=torch.long)
+        ).cpu().numpy()
+
+        # 'text_labels' holds the actual prompt strings that matched each box.
+        owl_matched_full_labels = first_image_results.get("text_labels", [])
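+        # e.g. ["a photo of cat", "a photo of dog"], one entry per detected box
+        # (illustrative; alignment with boxes/scores comes from the post-processor)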
+
+        for i, (current_box, current_score) in enumerate(zip(boxes, scores)):
+            try:
+                core_prompt = original_prompt_terms[prompt_indices[i]]
+            except IndexError:
+                core_prompt = "unknown_prompt_term"
+                print(
+                    f"Warning: prompt index {prompt_indices[i]} is out of bounds "
+                    f"for {len(original_prompt_terms)} prompt term(s)"
+                )
+
+            full_owl_label = (
+                owl_matched_full_labels[i]
+                if i < len(owl_matched_full_labels)
+                else f"a photo of {core_prompt}"
+            )
+
+            detections.append({
+                "box": [float(coord) for coord in current_box],
+                "score": float(current_score),
+                "label": full_owl_label,
+                "core_prompt": core_prompt,
+            })
         return detections
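
A minimal usage sketch of the new interface (assumptions: OWLV2Wrapper is
imported from this module, and "example.jpg" is a hypothetical local image):

    from PIL import Image

    wrapper = OWLV2Wrapper(device="cpu")  # downloads weights on first use
    image = Image.open("example.jpg").convert("RGB")
    detections = wrapper.detect(image=image, prompt=["cat", "dog"], threshold=0.2)
    for det in detections:
        # det["label"] is the full matched query; det["core_prompt"] the original term
        print(det["core_prompt"], round(det["score"], 3), det["box"])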