Coverage for src/pythia/applications/annotation.py: 73%
144 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-10-07 19:27 +0000
1"""Annotation applications."""
2from __future__ import annotations
4import abc
5import json
6import logging
7import sys
8from functools import partial
9from pathlib import Path
10from typing import Any
11from typing import Callable
12from typing import List
13from typing import Literal
14from typing import Optional
15from typing import Union
17import pyds
19try:
20 import cv2
21except ImportError:
22 cv2 = None # type: ignore[assignment]
23import numpy as np
25from pythia.applications.base import Application
26from pythia.applications.base import BoundSupportedCb
27from pythia.iterators import analytics_per_obj
28from pythia.iterators import objects_per_batch
29from pythia.models.base import Analytics
30from pythia.models.base import InferenceEngine
31from pythia.models.base import Tracker
32from pythia.pipelines.base import Pipeline
33from pythia.types import SourceUri
34from pythia.types import SupportedCb
35from pythia.utils.gst import Gst
36from pythia.utils.gst import gst_init
37from pythia.utils.gst import gst_iter
38from pythia.utils.maskrcnn import extract_maskrcnn_mask
39from pythia.utils.message_handlers import on_message_error
42class _DumpLogger(logging.Logger): # noqa: C0115
43 def json(self, msg, *args, **kwargs): # noqa: C0116
44 logging.Logger._log( # noqa: W0212
45 self, logging.INFO, json.dumps(msg), args, **kwargs
46 )
49def _make_handler(
50 dst: Path | Literal["stdout", "stderr"]
51) -> logging.StreamHandler:
52 if not isinstance(dst, Path): 52 ↛ 53line 52 didn't jump to line 53, because the condition on line 52 was never true
53 return logging.StreamHandler(getattr(sys, dst))
54 if dst.is_dir(): 54 ↛ 58line 54 didn't jump to line 58, because the condition on line 54 was never false
55 logfile = dst / "detections.jsonl"
56 logfile.unlink(missing_ok=True)
57 else:
58 logfile = dst
59 return logging.FileHandler(str(logfile))
def _make_logger(
    name: str, dst: Path | Literal["stdout", "stderr"]
) -> _DumpLogger:
    """Create a DEBUG-level ``_DumpLogger`` emitting bare messages to *dst*."""
    # Temporarily swap the logger class so getLogger builds a _DumpLogger.
    logging.setLoggerClass(_DumpLogger)
    try:
        json_logger = logging.getLogger(name)
        json_logger.setLevel(logging.DEBUG)
        sink = _make_handler(dst)
        sink.setLevel(logging.DEBUG)
        # "%(message)s" only: records are raw JSON lines, no metadata.
        sink.setFormatter(logging.Formatter("%(message)s"))
        json_logger.addHandler(sink)
    finally:
        # Restore the default class so unrelated loggers are unaffected.
        logging.setLoggerClass(logging.Logger)
    return json_logger  # type: ignore[return-value]
class AnnotateFramesBase(Application, abc.ABC):
    """Base class for creating dataset / annotations."""

    # Optional hook to parse frame-level metadata; not used by this base
    # class itself — subclasses/applications may assign it.
    nvds_frame_meta_parser: Optional[Callable[[pyds.NvDsFrameMeta], Any]]

    on_message_error = on_message_error

    @abc.abstractmethod
    def annotator_probe(
        self,
        pad: Gst.Pad,
        info: Gst.PadProbeInfo,
        batch_meta: pyds.NvDsBatchMeta,
    ) -> Gst.PadProbeReturn:
        """Implement this to process incoming batch metadata.

        Args:
            pad: gstreamer pad where the probe was attached.
            info: gstreamer pad probe info.
            batch_meta: deepstream metadata (batched!).

        """

    def __init__(self, pipeline, dst_folder: Path, *args, **kwargs) -> None:
        """Construct a Frame annotator.

        Args:
            pipeline: forwarded to pythia application constructor.
            dst_folder: location for the annotations.
            args: forwarded to pythia application constructor.
            kwargs: forwarded to pythia application constructor.

        """
        super().__init__(pipeline, *args, **kwargs)
        self._dst_folder = dst_folder
        self.logger = _make_logger(type(self).__name__, dst_folder)

    @staticmethod
    def _extract_common(
        pad, frame, detection, *, extract_analytics: bool = False
    ):
        """Build a flat json-serializable dict for a single detection.

        Args:
            pad: gstreamer pad the probe is attached to (its parent
                element name is recorded as ``"engine"``).
            frame: deepstream frame metadata.
            detection: deepstream object metadata.
            extract_analytics: when set, also attach the first available
                nvdsanalytics object info under ``"analytics"``.

        Returns:
            Mapping with frame/object ids, bounding box geometry,
            label, confidence and (optionally) analytics status fields.
        """
        frame_num = frame.frame_num
        box = detection.rect_params
        base = {
            "frame_num": frame_num,
            "id": detection.object_id,
            "engine_id": detection.unique_component_id,
            "engine": pad.parent.name,
            "pad_index": frame.pad_index,
            "label": detection.obj_label,
            "left": box.left,
            "top": box.top,
            "width": box.width,
            "height": box.height,
            "confidence": detection.confidence,
        }
        if not extract_analytics:
            return base
        try:
            # Only the first analytics entry per object is recorded.
            analytics = next(iter(analytics_per_obj(detection)))
        except StopIteration:
            return base
        base["analytics"] = {
            attr: getattr(analytics, attr)
            for attr in (
                "dirStatus",
                "lcStatus",
                "ocStatus",
                "roiStatus",
                "unique_id",
            )
        }
        return base

    @classmethod
    def run(  # noqa: R0913
        cls,
        src: SourceUri,
        model: Union[str, Path, InferenceEngine],
        dst_folder: str | Path,
        *args,
        suffix: str = ".jpg",
        analytics: Union[Path, Analytics] | None = None,
        tracker: Union[Path, Tracker] | None = None,
        **kwargs,
    ) -> None:
        """Run an annotation application.

        Args:
            src: Source uri used for frames/video input.
            model: Deepstream inference model.
            dst_folder: Path to store annotations and frames.
            args: Forwarded to class constructor.
            suffix: output frames suffix.
            analytics: optional analytics for the pipeline.
            tracker: optional tracker for the pipeline.
            kwargs: Forwarded to class constructor.

        Raises:
            FileExistsError: Non-empty dst_folder

        """
        gst_init()
        dst_folder = Path(dst_folder)
        if not dst_folder.exists():
            dst_folder.mkdir(parents=True, exist_ok=False)
        # multifilesink-style pattern: frames/000000000000.jpg, ...
        frames = dst_folder / f"frames/%012d{suffix}"
        frames_folder = frames.parent

        frames_folder.mkdir(parents=True, exist_ok=True)
        # Refuse to overwrite a previous run's frames.
        if any(frames_folder.glob("**/*")):
            raise FileExistsError(frames_folder)

        if isinstance(model, str):
            model = Path(model)
        pipeline = Pipeline(
            sources=[src],
            models=[model],
            sink=str(frames),
            analytics=analytics,
            tracker=tracker,
        )
        app = cls(pipeline, dst_folder, *args, **kwargs)

        # Attach the annotator at the most downstream metadata producer.
        target = cls.default_probe_target(pipeline)
        app.probe(target, "src")(app.annotator_probe)

        app()

    @staticmethod
    def default_probe_target(pipeline: Pipeline) -> str:
        """Retrieve an element name to attach the probe to.

        Args:
            pipeline: Pythia Pipeline containing the elements.

        Returns:
            The name of the most downstream element contained in the
            pipeline. The precedence order is: analytics, tracker,
            nvinfer.

        Raises:
            LookupError: none of the required deepstream elements was
            found.

        """
        # iterate_sorted walks sink->source, so the first match is the
        # most downstream metadata-producing element.
        for element in gst_iter(pipeline.pipeline.iterate_sorted()):
            factory_name = element.get_factory().name
            if factory_name in ("nvdsanalytics", "nvtracker", "nvinfer"):
                return element.name
        raise LookupError("Unable to find analytics, tracker, nvinfer")

    @classmethod
    def run_with_probe(
        cls,
        *args,
        probe: SupportedCb | BoundSupportedCb | None = None,
        **kwargs,
    ) -> None:
        """Run an annotator with a custom probe.

        Args:
            args: forwarded to :meth:`run`.
            probe: annotator_probe to use.
            kwargs: forwarded to :meth:`run`.

        """
        import inspect

        signature = inspect.getfullargspec(probe)
        # Unbound callbacks (no leading 'self') must be wrapped so the
        # dynamically-built class does not bind them as instance methods.
        if not signature.args or signature.args[0] != "self":
            probe = staticmethod(probe)  # type: ignore[arg-type, assignment]
        klass = type("Annotator", (cls,), {"annotator_probe": probe})
        klass.run(*args, **kwargs)  # type: ignore[attr-defined]
class AnnotateFramesBbox(AnnotateFramesBase):
    """Annotate frames with boundingboxes."""

    def annotator_probe(
        self,
        pad: Gst.Pad,
        info: Gst.PadProbeInfo,
        batch_meta: pyds.NvDsBatchMeta,
    ) -> Gst.PadProbeReturn:
        """Dump one bounding-box json line per detection in the batch.

        Args:
            pad: gstreamer pad where the probe was attached.
            info: gstreamer pad probe info.
            batch_meta: deepstream metadata (batched!).

        Returns:
            OK, so upstream keeps the buffer flowing.

        """
        for frame_meta, obj_meta in objects_per_batch(batch_meta):
            self.logger.json(self._extract_common(pad, frame_meta, obj_meta))
        return Gst.PadProbeReturn.OK
class AnnotateFramesMaskRcnn(AnnotateFramesBase):
    """Annotate frames with maskrcnn."""

    def __init__(
        self,
        pipeline,
        dst_folder: Path,
        *args,
        contour_kw: Optional[dict] = None,
        **kwargs,
    ) -> None:
        """Construct a maskrcnn annotation application.

        Args:
            pipeline: forwarded to the annotator constructor.
            dst_folder: forwarded to the annotator constructor.
            args: forwarded to the annotator constructor.
            contour_kw: arbitrary dict containing kwargs for 'cv2.findContours'
            kwargs: forwarded to the annotator constructor.

        Raises:
            ImportError: opencv missing

        See Also:
            https://docs.opencv.org/4.x/d3/dc0/group__imgproc__shape.html#gadf1ad6a0b82947fa1fe3c3d497f260e0

        """
        # cv2 is an optional dependency (see module-level guarded import).
        if cv2 is None:
            raise ImportError(
                "Unable to initialize MaskRcnn annotator."
                " Reason: opencv-python not installed."
                " Reinstall with 'opencv' extra,"
                " eg 'pip install pythia[opencv]'."
            )
        self._contour_kw = contour_kw or {}
        self._contour_kw.setdefault("mode", cv2.RETR_TREE)  # noqa: E1101
        self._contour_kw.setdefault(
            "method", cv2.CHAIN_APPROX_SIMPLE  # noqa: E1101
        )
        # Bind the configured kwargs once; probes call find_contours(mask).
        self.find_contours = partial(
            cv2.findContours,  # noqa: E1101
            **self._contour_kw,
        )
        super().__init__(pipeline, dst_folder, *args, **kwargs)

    def generate_mask_polygon(self, mask: np.ndarray) -> List[List[int]]:
        """Convert 2d numpy array mask into coco-"segmentation".

        Args:
            mask: used to convert to polygon as a matrix.

        Returns:
            mask in polygon form.

        See Also:
            https://learnopencv.com/deep-learning-based-object-detection-and-instance-segmentation-using-mask-rcnn-in-opencv-python-c/

        """
        contours, _ = self.find_contours(mask)
        # Flatten each contour to coco style [x0, y0, x1, y1, ...].
        return [[int(i) for i in c.flatten()] for c in contours]

    def annotator_probe(
        self,
        pad: Gst.Pad,
        info: Gst.PadProbeInfo,
        batch_meta: pyds.NvDsBatchMeta,
    ) -> Gst.PadProbeReturn:
        """Dump bbox, optional analytics and mask polygon per detection.

        Args:
            pad: gstreamer pad where the probe was attached.
            info: gstreamer pad probe info.
            batch_meta: deepstream metadata (batched!).

        Returns:
            OK, so upstream keeps the buffer flowing.

        """
        for frame, detection in objects_per_batch(batch_meta):
            bbox_data = self._extract_common(
                pad,
                frame,
                detection,
                # Analytics fields only exist when the pipeline has them.
                extract_analytics=self.pipeline.analytics is not None,
            )
            mask_mtx = extract_maskrcnn_mask(detection)
            mask_poly = self.generate_mask_polygon(mask_mtx)
            self.logger.json(
                {
                    "mask": mask_poly,
                    **bbox_data,
                }
            )
        return Gst.PadProbeReturn.OK