Coverage for src/pythia/models/base.py: 81%
201 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-10-07 19:27 +0000
"""Model.

A deep learning model used to extract metadata from a video source.
A model is a directory containing:

    - A ``labels.txt`` file, containing the list of model labels.
    - A ``pgie.conf`` file.
    - A ``model.etlt``, ``model.engine``, or any other
      DeepStream-nvinfer-compatible engine file.

"""
11from __future__ import annotations
13import configparser
14import json
15from dataclasses import dataclass
16from dataclasses import field
17from pathlib import Path
18from textwrap import dedent as _
19from typing import Collection
20from typing import Dict
21from typing import Optional
22from typing import Type
23from typing import TypeVar
25from pythia.types import Con
26from pythia.types import HasConnections
27from pythia.utils.ext import not_empty
28from pythia.utils.ext import not_none
29from pythia.utils.gst import Gst
# Self-typed TypeVars for the alternate constructors (``from_folder``,
# ``from_file``, ``from_element``): annotating ``cls: Type[X] -> X`` makes
# each factory return the concrete subclass it was invoked on.
A = TypeVar("A", bound="Analytics")
IE = TypeVar("IE", bound="InferenceEngine")
T = TypeVar("T", bound="Tracker")
@dataclass
class InferenceEngine(HasConnections):
    """Pythia wrapper around nvinfer gst element."""

    MODEL_SUFFIXES = {
        ".etlt": "tlt-encoded-model",
        ".caffe": "model-file",
        ".caffemodel": "model-file",
        ".prototxt": "proto-file",
        ".onnx": "onnx-file",
        ".uff": "uff-file",
        # ".engine" (nvinfer's "model-engine-file") is deliberately absent:
        # compiled engines are looked up by `locate_compiled_model` instead.
    }
    """Supported model extensions (prioritized order).

    Maps each extension to the nvinfer config key that references it.
    Used when an inference engine is to be instantiated by a directory,
    to locate supported models from their extension.

    See Also: :meth:`locate_source_model`.

    """

    MODEL_CONFIG_SUFFIXES = (
        ".conf",
        ".ini",
        ".yml",
        ".yaml",
    )
    """Ordered collection of supported model config file extensions.

    Used when an inference engine is to be instantiated by a directory,
    to locate `config-file-path` from their extension.

    See Also: :meth:`locate_config_file`.

    """

    labels_file: Path
    config_file: Path
    _string: Optional[str] = None  # last string rendered by `gst`
    source_model: Optional[Path] = None
    compiled_model: Optional[Path] = None
    _default_props: Dict[str, str] = field(default_factory=dict)

    CONNECTIONS: Con = field(default_factory=dict)  # noqa: C0103

    def gst(self, name: str, **kw) -> str:
        """Render nvinfer with `gst-launch`-like syntax.

        Args:
            name: nvinfer gstreamer element name property.
            kw: nvinfer gstreamer property name and value. Underscores
                in names are rendered as dashes.

        Returns:
            Rendered string: the element name followed by one
            ``property=value`` per line, newline-terminated.

        See Also:
            https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_plugin_gst-nvinfer.html#gst-properties

        """
        props = "\n".join(f"{k.replace('_', '-')}={v}" for k, v in kw.items())
        # Lines are joined explicitly rather than dedent-ing an indented
        # template: a multi-line `props` starts at column 0 and would stop
        # `textwrap.dedent` from removing the template's indentation.
        self._string = (
            "\n".join(
                (
                    "nvinfer",
                    f"config-file-path={self.config_file}",
                    f"name={name}",
                    props,
                )
            )
            + "\n"
        )
        return self._string

    @classmethod
    def locate_source_model(cls, folder: Path) -> Path | None:
        """Find the first deepstream model file in a folder.

        It iterates over the known nvinfer-compatible model file
        extensions (in `MODEL_SUFFIXES` order), and returns at the first
        success.

        Args:
            folder: Directory to search the model.

        Returns:
            Found model, or `None` if not found.

        """
        for suffix in cls.MODEL_SUFFIXES:
            candidate = next(folder.glob(f"*{suffix}"), None)
            if candidate is not None:
                return not_empty(candidate)
        return None

    @staticmethod
    def locate_labels_file(folder: Path) -> Path:
        """Find labels file from a directory.

        Args:
            folder: directory to search labels file.

        Returns:
            The first file matching the `*label*` pattern inside the
            directory.

        Raises:
            FileNotFoundError: no file matches the expected labels
                pattern.

        """
        candidate = next(folder.glob("*label*"), None)
        if candidate is None:
            raise FileNotFoundError(f"No labels file found at {folder}")
        return not_empty(candidate)

    @classmethod
    def locate_config_file(cls, folder: Path) -> Path:
        """Find the first model config file in a folder.

        Iterate over the known nvinfer-compatible config-file-path file
        extensions (in `MODEL_CONFIG_SUFFIXES` order), and returns at the
        first success.

        Args:
            folder: Directory to search the model.

        Returns:
            path to the found configuration file.

        Raises:
            FileNotFoundError: No configuration file found.

        """
        for suffix in cls.MODEL_CONFIG_SUFFIXES:
            candidate = next(folder.glob(f"*{suffix}"), None)
            if candidate is not None:
                return not_empty(candidate)
        raise FileNotFoundError(f"No config file found at {folder}")

    @staticmethod
    def locate_compiled_model(
        folder: Path, source_model: Path | None
    ) -> Path | None:
        """Find the first model engine file in a folder.

        Two strategies are tried in order:

        1. If `source_model` is set, a `*{stem}*.engine` file next to it.
        2. Any `*.engine` file inside `folder`.

        Args:
            folder: Directory to search the model.
            source_model: If set, use this path's stem to try to locate
                the `.engine` file. Otherwise, finds any `*.engine`.

        Returns:
            Path to the found engine file, or `None` when `source_model`
            is set but no engine file was found (the source model can
            still be used on its own).

        Raises:
            FileNotFoundError: `source_model` is not set and no
                `*.engine` file exists in `folder`, so there is no model
                at all.

        """
        if source_model:
            matching = next(
                source_model.parent.glob(f"*{source_model.stem}*.engine"),
                None,
            )
            if matching is not None:
                return not_empty(matching)
        fallback = next(folder.glob("*.engine"), None)
        if fallback is not None:
            return fallback
        if not source_model:
            raise FileNotFoundError(
                f"Neither {source_model=} nor its compiled version exist."
            )
        return None

    @classmethod
    def from_folder(cls: Type[IE], folder: str | Path) -> IE:
        """Factory to instantiate from directories.

        Args:
            folder: Directory where the model files are located.

        Returns:
            Instantiated model.

        Raises:
            FileNotFoundError: `folder` is not an existing directory, or
                a required file (labels, config, or any model) is
                missing from it.

        """
        folder = Path(folder).resolve()
        if not folder.is_dir():
            raise FileNotFoundError(f"No directory found at {folder}.")

        source_model = cls.locate_source_model(folder)
        labels_file = cls.locate_labels_file(folder)
        config_file = cls.locate_config_file(folder)
        compiled_model = cls.locate_compiled_model(folder, source_model)

        return cls(
            labels_file=labels_file,
            config_file=config_file,
            source_model=source_model,
            compiled_model=compiled_model,
        )

    @classmethod
    def from_element(cls: Type[IE], element: Gst.Element) -> IE:
        """Factory from nvinfer.

        Every element property except `parent` is read and stringified.
        File properties are extracted into dedicated fields; the
        remaining ones are kept as `_default_props`.

        Args:
            element: The nvinfer to use as source.

        Returns:
            The instantiated nvinfer wrapper.

        """
        skip = ("parent",)
        props = {}
        for prop in element.list_properties():
            name = prop.name
            if name in skip:
                continue

            raw = element.get_property(name)
            try:
                value = raw.value_nick  # enums
            except AttributeError:
                value = str(raw)
            if isinstance(raw, bool):
                value = value.lower()  # False -> false
            props[name] = value

        config_file = Path(props.pop("config-file-path")).resolve()
        # Each lookup below pops the property it finds from `props`, so
        # `_default_props` only keeps the properties not extracted here.
        labels_file = not_none(
            cls.pop_property_or_get_from_nvinfer_conf(
                config_file,
                props,
                property_names=("labelfile-path",),
            )
        )
        source_model = cls.pop_property_or_get_from_nvinfer_conf(
            config_file,
            props,
            property_names=cls.MODEL_SUFFIXES.values(),
        )
        compiled_model = cls.pop_property_or_get_from_nvinfer_conf(
            config_file,
            props,
            property_names=("model-engine-file",),
        )
        return cls(
            config_file=config_file,
            labels_file=labels_file,
            source_model=source_model,
            compiled_model=compiled_model,
            _default_props=props,
        )

    @staticmethod
    def pop_property_or_get_from_nvinfer_conf(
        config_file: Path,
        props: dict[str, str],
        *,
        property_names: Collection[str],
    ) -> Path | None:
        """Pop nvinfer property, or get from config_file if not found.

        Relative paths (from either source) are resolved against the
        directory containing `config_file`.

        Args:
            config_file: `nvinfer.conf` ini file. Used to compute
                absolute paths, and default source for property values
                if not found in the props arg.
            props: element properties where to look for the desired
                properties. Note: If the property is found, it is popped
                from this dict.
            property_names: possible property names to look for.

        Returns:
            First occurrence of the property_names, as found either in
            the props dict or in the nvinfer.conf `[property]`
            section; `None` if none of them is set in either place.

        Raises:
            FileNotFoundError: None of the requested names is available
                in the properties, and the config file does not exist.

        """

        def _absolute(value: str) -> Path:
            # `relative_to` would be wrong here: it computes a path *from*
            # the config file, and raises for any relative `value`.
            path = Path(value)
            if path.is_absolute():
                return path
            return (config_file.parent / path).resolve()

        # extract from nvinfer's properties
        for prop_name in property_names:
            value = props.pop(prop_name, None)
            if value is not None:
                return _absolute(value)

        # extract from nvinfer's config file
        if not config_file.exists():
            raise FileNotFoundError(config_file)
        config = configparser.ConfigParser()
        config.read(str(config_file))
        if not config.has_section("property"):
            return None
        section = config["property"]
        for prop_name in property_names:
            value = section.get(prop_name, None)
            if value is not None:
                return _absolute(value)

        return None
@dataclass
class Tracker(HasConnections):
    """Pythia wrapper around nvtracker gst element."""

    config_file: Path
    low_level_library: Path = Path(
        "/opt/nvidia/deepstream/deepstream/lib/libnvds_nvmultiobjecttracker.so"
    ).resolve()
    _string: Optional[str] = None  # last string rendered by `gst`
    _default_props: Dict[str, str] = field(default_factory=dict)
    CONNECTIONS: Con = field(default_factory=dict)  # noqa: C0103

    @classmethod
    def from_file(
        cls: Type[T],
        config_file: Path,
        low_level_library: Path = low_level_library,
    ) -> T:
        """Factory to create `Tracker` s from configuration file.

        Args:
            config_file: path for the `nvtracker` gstreamer element
                'll-config-file' property.
            low_level_library: path for the `nvtracker` gstreamer element
                'll-lib-file' property (shared object). Strings are
                accepted and converted to `Path`.

        Returns:
            Instantiated `Tracker`.

        Raises:
            FileNotFoundError: Tracker config file does not exist.

        """
        config_file = Path(config_file).resolve()
        if not config_file.exists():
            raise FileNotFoundError(
                f"No Tracker configuration file found at {config_file}."
            )
        return cls(
            config_file=config_file,
            # `gst` calls `.exists()` on it, so it must be a Path.
            low_level_library=Path(low_level_library),
        )

    @classmethod
    def from_element(cls: Type[T], element: Gst.Element) -> T:
        """Factory from nvtracker.

        Every element property except `parent` is read and stringified;
        `ll-config-file` and `ll-lib-file` become fields, the rest is
        kept as `_default_props`.

        Args:
            element: The nvtracker to use as source.

        Returns:
            The instantiated nvtracker wrapper.

        """
        skip = ("parent",)
        props = {}
        for prop in element.list_properties():
            name = prop.name
            if name in skip:
                continue

            raw = element.get_property(name)
            try:
                value = raw.value_nick  # enums
            except AttributeError:
                value = str(raw)
            if isinstance(raw, bool):
                value = value.lower()  # False -> false
            props[name] = value

        # Property values were stringified above; convert the file ones
        # back to `Path` so the fields honour their annotations (and
        # `gst` can call `.exists()` on the library).
        return cls(
            config_file=Path(props.pop("ll-config-file")).resolve(),
            low_level_library=Path(props.pop("ll-lib-file")),
            _default_props=props,
        )

    def gst(self, **kwargs: str) -> str:
        """Render nvtracker element with `gst-launch`-like syntax.

        Args:
            kwargs: Additional gst element properties. Underscores in
                names are rendered as dashes; they override any default
                property with the same (dashed) name.

        Returns:
            Rendered string: the element name followed by one
            ``property=value`` per line, newline-terminated.

        Raises:
            FileNotFoundError: Tracker `ll-lib-file` (the low level
                library) not found.

        See Also:
            https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_plugin_gst-nvtracker.html#gst-properties

        """
        library = Path(self.low_level_library)
        if not library.exists():
            raise FileNotFoundError(
                "Could not find Tracker implementation"
                f" at {self.low_level_library}"
            )

        # Normalize names *before* merging so that `tracker_width=...`
        # overrides a default `tracker-width` instead of duplicating it.
        merged = {}
        for key, value in [*self._default_props.items(), *kwargs.items()]:
            merged[key.replace("_", "-")] = value
        props = "\n".join(f"{key}={value}" for key, value in merged.items())

        # Explicit join instead of `textwrap.dedent`: a multi-line `props`
        # starting at column 0 would defeat dedent of an indented template.
        self._string = (
            "\n".join(
                (
                    "nvtracker",
                    f"ll-config-file={self.config_file}",
                    f"ll-lib-file={self.low_level_library}",
                    props,
                )
            )
            + "\n"
        )
        return self._string
@dataclass
class Analytics(HasConnections):
    """Pythia wrapper around nvdsanalytics gst element."""

    config_file: Path
    _string: Optional[str] = None  # last string rendered by `gst`
    _default_props: Dict[str, str] = field(default_factory=dict)
    CONNECTIONS: Con = field(default_factory=dict)  # noqa: C0103

    def gst(self, **kwargs: str) -> str:
        """Render string as `gst-launch`-like parseable string.

        Args:
            kwargs: Additional gst element properties. Underscores in
                names are rendered as dashes; they override any default
                property with the same (dashed) name.

        Returns:
            Rendered `nvdsanalytics`: the element name followed by one
            ``property=value`` per line, newline-terminated.

        See Also:
            https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_plugin_gst-nvdsanalytics.html#gst-properties

        """
        # Normalize names *before* merging so that an underscore kwarg
        # overrides the matching dashed default instead of duplicating it.
        merged = {}
        for key, value in [*self._default_props.items(), *kwargs.items()]:
            merged[key.replace("_", "-")] = value
        props = "\n".join(f"{key}={value}" for key, value in merged.items())

        # Explicit join instead of `textwrap.dedent`: a multi-line `props`
        # starting at column 0 would defeat dedent of an indented template.
        self._string = (
            "\n".join(
                (
                    "nvdsanalytics",
                    f"config-file={self.config_file}",
                    props,
                )
            )
            + "\n"
        )

        return self._string

    def requires_tracker(self) -> bool:
        """Return `True` if its `nvdsanalytics` requires `nvtracker`.

        Returns:
            `True` if the `nvdsanalytics` config file has any section
            whose name starts with `line-crossing` or
            `direction-detection`, i.e. it defines line crossing or
            direction detection data.

        """
        config = configparser.ConfigParser()
        config.read(str(self.config_file))
        tracker_prefixes = ("line-crossing", "direction-detection")
        return any(
            section_name.startswith(tracker_prefixes)
            for section_name in config.sections()
        )

    @classmethod
    def from_file(cls: Type[A], config_file: Path) -> A:
        """Factory from configuration file.

        Args:
            config_file: location of the nvdsanalytics `config-file`
                property.

        Returns:
            The instantiated `nvdsanalytics` wrapper class.

        Raises:
            FileNotFoundError: The `nvdsanalytics` `config-file`
                property is not found.

        """
        config_file = Path(config_file).resolve()
        if not config_file.exists():
            raise FileNotFoundError(
                f"No Analytics configuration file found at {config_file}."
            )
        return cls(config_file=config_file)

    @classmethod
    def from_element(cls: Type[A], element: Gst.Element) -> A:
        """Factory from nvdsanalytics.

        Every element property except `parent` is read and stringified;
        `config-file` becomes a field, the rest is kept as
        `_default_props`.

        Args:
            element: The nvdsanalytics to use as source.

        Returns:
            The instantiated nvdsanalytics wrapper.

        """
        skip = ("parent",)
        props = {}
        for prop in element.list_properties():
            name = prop.name
            if name in skip:
                continue

            raw = element.get_property(name)
            try:
                value = raw.value_nick  # enums
            except AttributeError:
                value = str(raw)
            if isinstance(raw, bool):
                value = value.lower()  # False -> false
            props[name] = value

        # The value was stringified above; convert it back to a resolved
        # Path, matching what `from_file` stores.
        return cls(
            config_file=Path(props.pop("config-file")).resolve(),
            _default_props=props,
        )