Coverage for src/pythia/applications/annotation.py: 73%

144 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-10-07 19:27 +0000

1"""Annotation applications.""" 

2from __future__ import annotations 

3 

4import abc 

5import json 

6import logging 

7import sys 

8from functools import partial 

9from pathlib import Path 

10from typing import Any 

11from typing import Callable 

12from typing import List 

13from typing import Literal 

14from typing import Optional 

15from typing import Union 

16 

17import pyds 

18 

19try: 

20 import cv2 

21except ImportError: 

22 cv2 = None # type: ignore[assignment] 

23import numpy as np 

24 

25from pythia.applications.base import Application 

26from pythia.applications.base import BoundSupportedCb 

27from pythia.iterators import analytics_per_obj 

28from pythia.iterators import objects_per_batch 

29from pythia.models.base import Analytics 

30from pythia.models.base import InferenceEngine 

31from pythia.models.base import Tracker 

32from pythia.pipelines.base import Pipeline 

33from pythia.types import SourceUri 

34from pythia.types import SupportedCb 

35from pythia.utils.gst import Gst 

36from pythia.utils.gst import gst_init 

37from pythia.utils.gst import gst_iter 

38from pythia.utils.maskrcnn import extract_maskrcnn_mask 

39from pythia.utils.message_handlers import on_message_error 

40 

41 

class _DumpLogger(logging.Logger):  # noqa: C0115
    """Logger with a :meth:`json` helper emitting one JSON document per record."""

    def json(self, msg, *args, **kwargs):  # noqa: C0116
        """Log ``msg`` serialized with :func:`json.dumps` at INFO level.

        Args:
            msg: object to serialize with :func:`json.dumps`.
            args: forwarded to :meth:`logging.Logger._log` as the record's
                formatting arguments.
            kwargs: forwarded to :meth:`logging.Logger._log` (``exc_info``,
                ``extra``, ...).

        """
        # Mirror ``Logger.info``: honour the logger level (and
        # ``logging.disable``) before paying for serialization.
        if self.isEnabledFor(logging.INFO):
            self._log(logging.INFO, json.dumps(msg), args, **kwargs)

47 

48 

def _make_handler(
    dst: Path | Literal["stdout", "stderr"]
) -> logging.StreamHandler:
    """Build the handler backing an annotation logger.

    Args:
        dst: ``"stdout"`` or ``"stderr"`` to log to that standard stream.
            Any other value (``Path`` or plain string) is a filesystem
            path. For an existing directory, records go to
            ``detections.jsonl`` inside it, and a previous file of that
            name is removed first. Otherwise ``dst`` itself is the log file.

    Returns:
        A :class:`logging.StreamHandler` for the standard streams, or a
        :class:`logging.FileHandler` for paths.

    """
    if isinstance(dst, str) and dst in ("stdout", "stderr"):
        return logging.StreamHandler(getattr(sys, dst))
    # Only the two stream names are looked up on ``sys``. Any other string
    # is a path; ``getattr(sys, dst)`` would raise for it or return an
    # unrelated attribute (e.g. ``sys.path``).
    dst = Path(dst)
    if dst.is_dir():
        logfile = dst / "detections.jsonl"
        # FileHandler appends by default; start from an empty file.
        logfile.unlink(missing_ok=True)
    else:
        logfile = dst
    return logging.FileHandler(str(logfile))

60 

61 

def _make_logger(
    name: str, dst: Path | Literal["stdout", "stderr"]
) -> _DumpLogger:
    """Create (or reconfigure) the named annotation logger.

    Args:
        name: logger name passed to :func:`logging.getLogger`.
        dst: destination forwarded to ``_make_handler``.

    Returns:
        The logger, set to DEBUG, with exactly one handler that writes the
        bare message to ``dst``.

    """
    previous_cls = logging.getLoggerClass()
    logging.setLoggerClass(_DumpLogger)
    try:
        # NOTE: ``getLogger`` only instantiates ``_DumpLogger`` when no
        # logger called ``name`` exists yet; an existing one is reused
        # as-is (and then lacks ``.json`` if it is a plain Logger).
        logger = logging.getLogger(name)
    finally:
        # Restore whatever class was active, not unconditionally Logger.
        logging.setLoggerClass(previous_cls)
    logger.setLevel(logging.DEBUG)

    # Loggers are process-wide singletons. Drop handlers left by an
    # earlier call for the same name, so records are not also written to
    # the previous destination and stale file handles are released.
    for old in list(logger.handlers):
        logger.removeHandler(old)
        old.close()

    handler = _make_handler(dst)
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(logging.Formatter("%(message)s"))
    logger.addHandler(handler)
    return logger  # type: ignore[return-value]

78 

79 

class AnnotateFramesBase(Application, abc.ABC):
    """Base class for creating dataset / annotations."""

    nvds_frame_meta_parser: Optional[Callable[[pyds.NvDsFrameMeta], Any]]

    on_message_error = on_message_error

    @abc.abstractmethod
    def annotator_probe(
        self,
        pad: Gst.Pad,
        info: Gst.PadProbeInfo,
        batch_meta: pyds.NvDsBatchMeta,
    ) -> Gst.PadProbeReturn:
        """Implement this to process incoming batch metadata.

        Args:
            pad: gstreamer pad where the probe was attached.
            info: gstreamer pad probe info.
            batch_meta: deepstream metadata (batched!).

        """

    def __init__(self, pipeline, dst_folder: Path, *args, **kwargs) -> None:
        """Construct a frame annotator.

        Args:
            pipeline: forwarded to pythia application constructor.
            dst_folder: location for the annotations. It is handed to
                ``_make_logger``, together with the concrete class name as
                the logger name.
            args: forwarded to pythia application constructor.
            kwargs: forwarded to pythia application constructor.

        """
        super().__init__(pipeline, *args, **kwargs)
        self._dst_folder = dst_folder
        self.logger = _make_logger(type(self).__name__, dst_folder)

    @staticmethod
    def _extract_common(
        pad, frame, detection, *, extract_analytics: bool = False
    ):
        """Collect the per-detection fields shared by all annotators.

        Args:
            pad: pad the probe is attached to; its parent element's name
                is reported as ``engine``.
            frame: frame metadata (``frame_num``, ``pad_index``).
            detection: object metadata (id, component id, label,
                ``rect_params`` box, confidence).
            extract_analytics: when True, also attach the status fields
                of the first analytics entry for ``detection`` (if any)
                under the ``"analytics"`` key.

        Returns:
            A dict describing the detection.

        """
        box = detection.rect_params
        base = {
            "frame_num": frame.frame_num,
            "id": detection.object_id,
            "engine_id": detection.unique_component_id,
            "engine": pad.parent.name,
            "pad_index": frame.pad_index,
            "label": detection.obj_label,
            "left": box.left,
            "top": box.top,
            "width": box.width,
            "height": box.height,
            "confidence": detection.confidence,
        }
        if not extract_analytics:
            return base
        try:
            analytics = next(iter(analytics_per_obj(detection)))
        except StopIteration:
            # No analytics attached to this object: keep the bbox data only.
            return base
        base["analytics"] = {
            attr: getattr(analytics, attr)
            for attr in (
                "dirStatus",
                "lcStatus",
                "ocStatus",
                "roiStatus",
                "unique_id",
            )
        }
        return base

    @classmethod
    def run(  # noqa: R0913
        cls,
        src: SourceUri,
        model: Union[str, Path, InferenceEngine],
        dst_folder: str | Path,
        *args,
        suffix: str = ".jpg",
        analytics: Union[Path, Analytics] | None = None,
        tracker: Union[Path, Tracker] | None = None,
        **kwargs,
    ) -> None:
        """Run an annotation application.

        The pipeline sink is set to ``<dst_folder>/frames/%012d<suffix>``.
        ``annotator_probe`` is attached to the ``src`` pad of the element
        chosen by :meth:`default_probe_target`.

        Args:
            src: Source uri used for frames/video input.
            model: Deepstream inference model (a ``str`` is converted to
                ``Path``).
            dst_folder: Path to store annotations and frames; created if
                missing.
            args: Forwarded to class constructor.
            suffix: output frames suffix.
            analytics: optional analytics for the pipeline.
            tracker: optional tracker for the pipeline.
            kwargs: Forwarded to class constructor.

        Raises:
            FileExistsError: ``<dst_folder>/frames`` exists and is not
                empty.

        """
        gst_init()
        dst_folder = Path(dst_folder)
        frames = dst_folder / f"frames/%012d{suffix}"
        frames_folder = frames.parent
        # ``parents=True`` also creates ``dst_folder`` when missing.
        frames_folder.mkdir(parents=True, exist_ok=True)
        # Refuse to mix new frames with the output of a previous run.
        if any(frames_folder.iterdir()):
            raise FileExistsError(frames_folder)

        if isinstance(model, str):
            model = Path(model)
        pipeline = Pipeline(
            sources=[src],
            models=[model],
            sink=str(frames),
            analytics=analytics,
            tracker=tracker,
        )
        app = cls(pipeline, dst_folder, *args, **kwargs)

        target = cls.default_probe_target(pipeline)
        app.probe(target, "src")(app.annotator_probe)

        app()

    @staticmethod
    def default_probe_target(pipeline: Pipeline) -> str:
        """Retrieve an element name to attach the probe to.

        Args:
            pipeline: Pythia Pipeline containing the elements.

        Returns:
            The name of the first ``nvdsanalytics``, ``nvtracker`` or
            ``nvinfer`` element yielded by the bin's ``iterate_sorted``.
            That iterator walks from the most downstream elements towards
            the sources, so in the usual layout the precedence is:
            analytics, tracker, nvinfer.

        Raises:
            LookupError: none of the required deepstream elements was
                found.

        """
        wanted = ("nvdsanalytics", "nvtracker", "nvinfer")
        for element in gst_iter(pipeline.pipeline.iterate_sorted()):
            # ``get_factory`` returns None for elements not created from a
            # registered factory (e.g. plain bins); skip those.
            factory = element.get_factory()
            if factory is not None and factory.name in wanted:
                return element.name
        raise LookupError("Unable to find analytics, tracker, nvinfer")

    @classmethod
    def run_with_probe(
        cls,
        *args,
        probe: SupportedCb | BoundSupportedCb | None = None,
        **kwargs,
    ) -> None:
        """Run an annotator with a custom probe.

        Args:
            args: forwarded to :meth:`run`.
            probe: callable to use as ``annotator_probe``. If its first
                positional parameter is named ``self``, it is installed as
                a regular method; otherwise (including callables without
                positional parameters) as a staticmethod. When None,
                ``cls`` runs with its own ``annotator_probe``.
            kwargs: forwarded to :meth:`run`.

        Raises:
            TypeError: ``probe`` is None and ``cls`` does not implement
                ``annotator_probe``, or ``probe`` is a callable that
                :func:`inspect.getfullargspec` cannot introspect.

        """
        import inspect

        if probe is None:
            # Nothing to inject: fail early instead of after building the
            # pipeline and creating output folders.
            if inspect.isabstract(cls):
                raise TypeError(
                    f"{cls.__name__} does not implement annotator_probe;"
                    " pass probe=..."
                )
            cls.run(*args, **kwargs)
            return

        positional = inspect.getfullargspec(probe).args
        if not positional or positional[0] != "self":
            probe = staticmethod(probe)  # type: ignore[arg-type, assignment]
        klass = type("Annotator", (cls,), {"annotator_probe": probe})
        klass.run(*args, **kwargs)  # type: ignore[attr-defined]

262 

263 

class AnnotateFramesBbox(AnnotateFramesBase):
    """Annotate frames with bounding boxes.

    Each detected object is logged as one JSON record of the common
    detection fields. Analytics status is included when the pipeline was
    built with analytics, matching ``AnnotateFramesMaskRcnn``.
    """

    def annotator_probe(
        self,
        pad: Gst.Pad,
        info: Gst.PadProbeInfo,
        batch_meta: pyds.NvDsBatchMeta,
    ) -> Gst.PadProbeReturn:
        """Log one JSON record per detected object in ``batch_meta``.

        Args:
            pad: gstreamer pad where the probe was attached.
            info: gstreamer pad probe info (unused).
            batch_meta: deepstream metadata (batched!).

        Returns:
            Always ``Gst.PadProbeReturn.OK``.

        """
        # Same decision for every object in the batch: compute it once.
        extract_analytics = self.pipeline.analytics is not None
        for frame, detection in objects_per_batch(batch_meta):
            data = self._extract_common(
                pad,
                frame,
                detection,
                extract_analytics=extract_analytics,
            )
            self.logger.json(data)
        return Gst.PadProbeReturn.OK

277 

278 

class AnnotateFramesMaskRcnn(AnnotateFramesBase):
    """Annotate frames with maskrcnn."""

    def __init__(
        self,
        pipeline,
        dst_folder: Path,
        *args,
        contour_kw: Optional[dict] = None,
        **kwargs,
    ) -> None:
        """Construct a maskrcnn annotation application.

        Args:
            pipeline: forwarded to the annotator constructor.
            dst_folder: forwarded to the annotator constructor.
            args: forwarded to the annotator constructor.
            contour_kw: kwargs for 'cv2.findContours'. Missing ``mode`` and
                ``method`` default to ``cv2.RETR_TREE`` and
                ``cv2.CHAIN_APPROX_SIMPLE``. The dict is copied, never
                modified.
            kwargs: forwarded to the annotator constructor.

        Raises:
            ImportError: opencv missing

        See Also:
            https://docs.opencv.org/4.x/d3/dc0/group__imgproc__shape.html#gadf1ad6a0b82947fa1fe3c3d497f260e0

        """
        if cv2 is None:
            raise ImportError(
                "Unable to initialize MaskRcnn annotator."
                " Reason: opencv-python not installed."
                " Reinstall with 'opencv' extra,"
                " eg 'pip install pythia[opencv]'."
            )
        # Copy so the defaults set below never leak into the caller's dict.
        # (Attribute name kept as-is for backward compatibility.)
        self._countour_kw = dict(contour_kw or {})
        self._countour_kw.setdefault("mode", cv2.RETR_TREE)  # noqa: E1101
        self._countour_kw.setdefault(
            "method", cv2.CHAIN_APPROX_SIMPLE  # noqa: E1101
        )
        self.find_contours = partial(
            cv2.findContours,  # noqa: E1101
            **self._countour_kw,
        )
        super().__init__(pipeline, dst_folder, *args, **kwargs)

    def generate_mask_polygon(self, mask: np.ndarray) -> List[List[int]]:
        """Convert a 2d numpy array mask into coco-"segmentation".

        Args:
            mask: 2d mask matrix passed to ``cv2.findContours``.
                NOTE(review): findContours expects a single-channel
                8-bit image (or 32-bit int for some modes); confirm
                ``extract_maskrcnn_mask`` yields a compatible dtype.

        Returns:
            One list per contour, holding that contour's point
            coordinates flattened into ints.

        See Also:
            https://learnopencv.com/deep-learning-based-object-detection-and-instance-segmentation-using-mask-rcnn-in-opencv-python-c/

        """
        result = self.find_contours(mask)
        # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4
        # returns (contours, hierarchy); contours is second to last in both.
        contours = result[-2]
        return [[int(i) for i in c.flatten()] for c in contours]

    def annotator_probe(
        self,
        pad: Gst.Pad,
        info: Gst.PadProbeInfo,
        batch_meta: pyds.NvDsBatchMeta,
    ) -> Gst.PadProbeReturn:
        """Log one JSON record (bbox data + mask polygons) per object.

        Args:
            pad: gstreamer pad where the probe was attached.
            info: gstreamer pad probe info (unused).
            batch_meta: deepstream metadata (batched!).

        Returns:
            Always ``Gst.PadProbeReturn.OK``.

        """
        # Same decision for every object in the batch: compute it once.
        extract_analytics = self.pipeline.analytics is not None
        for frame, detection in objects_per_batch(batch_meta):
            bbox_data = self._extract_common(
                pad,
                frame,
                detection,
                extract_analytics=extract_analytics,
            )
            mask_mtx = extract_maskrcnn_mask(detection)
            mask_poly = self.generate_mask_polygon(mask_mtx)
            self.logger.json(
                {
                    "mask": mask_poly,
                    **bbox_data,
                }
            )
        return Gst.PadProbeReturn.OK