"""Model.
A deep learning model used to extract metadata from a video source.
Contains:
A labels.txt file, containing the list of model labels.
A pgie.conf file.
A model.etlt,model.engine or any Deepstream-nvinfer compatible engine file.
"""
from __future__ import annotations
import configparser
import json
from dataclasses import dataclass
from dataclasses import field
from pathlib import Path
from textwrap import dedent as _
from typing import Collection
from typing import Dict
from typing import Optional
from typing import Type
from typing import TypeVar
from pythia.types import Con
from pythia.types import HasConnections
from pythia.utils.ext import not_empty
from pythia.utils.ext import not_none
from pythia.utils.gst import Gst
IE = TypeVar("IE", bound="InferenceEngine")
T = TypeVar("T", bound="Tracker")
A = TypeVar("A", bound="Analytics")
[docs]@dataclass
class InferenceEngine(HasConnections):
"""Pythia wrapper around nvinfer gst element."""
MODEL_SUFFIXES = {
".etlt": "tlt-encoded-model",
".caffe": "model-file",
".caffemodel": "model-file",
".prototxt": "proto-file",
".onnx": "onnx-file",
".uff": "uff-file",
# ".engine": "model-engine-file",
}
"""Supported model extensions (prioritized order).
Used when an inference engine is to be instantiated by a directory,
to locate supported models from their extension.
See Also: :meth:`locate_source_model`.
"""
MODEL_CONFIG_SUFFIXES = (
".conf",
".ini",
".yml",
".yaml",
)
"""Ordered collection of supported model config file extensions.
Used when an inference engine is to be instantiated by a directory,
to locate `config-file-path` from their extension.
See Also: :meth:`locate_config_file`.
"""
labels_file: Path
config_file: Path
_string: Optional[str] = None
source_model: Optional[Path] = None
compiled_model: Optional[Path] = None
_default_props: Dict[str, str] = field(default_factory=dict)
CONNECTIONS: Con = field(default_factory=dict) # noqa: C0103
[docs] def gst(self, name: str, **kw) -> str:
"""Render nvinfer with `gst-launch`-like syntax.
Args:
name: nvinfer gstreamer element name property.
kw: nvinfer gstreamer property name and value.
Returns:
Rendered string
See Also:
https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_plugin_gst-nvinfer.html#gst-properties
"""
props = "\n".join(f"{k.replace('_', '-')}={v}" for k, v in kw.items())
self._string = _(
f"""\
nvinfer
config-file-path={self.config_file}
name={name}
{props}
"""
)
return self._string
[docs] @classmethod
def locate_source_model(cls, folder: Path) -> Path | None:
"""Find the first deepstream model file in a folder.
It iterates over the known nvinfer-compatible model file
extensions, and returns at the first success.
Args:
folder: Directory to search the model.
Returns:
Found model, or `None` if not found.
"""
for suffix in cls.MODEL_SUFFIXES:
try:
return not_empty(next(iter(folder.glob(f"*{suffix}"))))
except StopIteration:
pass
return None
[docs] @staticmethod
def locate_labels_file(folder: Path) -> Path:
"""Find labels file from a directory.
Args:
folder: directory to search labels file.
Returns:
The first file matching the `*label*` pattern inside the
directory.
Raises:
FileNotFoundError: no file matches the expected labels
pattern.
"""
try:
return not_empty(next(iter(folder.glob("*label*"))))
except StopIteration as exc:
raise FileNotFoundError(
f"No labels file found at {folder}"
) from exc
[docs] @classmethod
def locate_config_file(cls, folder: Path) -> Path:
"""Find the first model config file in a folder.
Iterate over the known nvinfer-compatible config-file-path file
extensions, and returns at the first success.
Args:
folder: Directory to search the model.
Returns:
path to the found configuration file.
Raises:
FileNotFoundError: No configuration file found.
"""
for suffix in cls.MODEL_CONFIG_SUFFIXES:
try:
return not_empty(next(iter(folder.glob(f"*{suffix}"))))
except StopIteration:
pass
raise FileNotFoundError(f"No config file found at {folder}")
[docs] @staticmethod
def locate_compiled_model(
folder: Path, source_model: Path | None
) -> Path | None:
"""Find the first model engine file in a folder.
Returns a file matching the `*.engine` pattern
Args:
folder: Directory to search the model.
source_model: If set, use this path's stem to try to locate
the `.engine` file. Otherwise, finds any `*.engine`.
Returns:
path to the found configuration file.
Raises:
FileNotFoundError: No configuration file found using any of
the strategies.
"""
if source_model:
try:
return not_empty(
next(
iter(
source_model.parent.glob(
f"*{source_model.stem}*.engine"
)
)
)
)
except StopIteration:
pass
try:
return next(iter(folder.glob("*.engine")))
except StopIteration as exc:
if not source_model:
raise FileNotFoundError(
f"Neither {source_model=} nor its compiled version exist."
) from exc
return None
[docs] @classmethod
def from_folder(cls: Type[IE], folder: str | Path) -> IE:
"""Factory to instantiate from directories.
Args:
folder: Directory where the model files are located.
Returns:
Instantiated model.
Raises:
FileNotFoundError: empty folder received.
"""
folder = Path(folder).resolve()
if not folder.exists():
raise FileNotFoundError(f"No directory not found at {folder}.")
source_model = cls.locate_source_model(folder)
labels_file = cls.locate_labels_file(folder)
config_file = cls.locate_config_file(folder)
compiled_model = cls.locate_compiled_model(folder, source_model)
return cls(
labels_file=labels_file,
config_file=config_file,
source_model=source_model,
compiled_model=compiled_model,
)
[docs] @classmethod
def from_element(cls: Type[IE], element: Gst.Element) -> IE:
"""Factory from nvinfer.
Args:
element: The nvinfer to use as source.
Returns:
The instantiated nvinfer wrapper.
"""
skip = ("parent",)
props = {}
for prop in element.list_properties():
name = prop.name
if name in skip:
continue
raw = element.get_property(name)
try:
value = raw.value_nick # enums
except AttributeError:
value = str(raw)
if isinstance(raw, bool):
value = value.lower() # False -> false
props[name] = value
config_file = Path(props.pop("config-file-path")).resolve()
return cls(
config_file=config_file,
labels_file=not_none(
cls.pop_property_or_get_from_nvinfer_conf( # noqa: C0301
config_file,
props,
property_names=("labelfile-path",),
)
),
source_model=cls.pop_property_or_get_from_nvinfer_conf(
config_file,
props,
property_names=cls.MODEL_SUFFIXES.values(),
),
compiled_model=cls.pop_property_or_get_from_nvinfer_conf(
config_file,
props,
property_names=("model-engine-file",),
),
_default_props=props,
)
[docs] @staticmethod
def pop_property_or_get_from_nvinfer_conf(
config_file: Path,
props: dict[str, str],
*,
property_names: Collection[str],
) -> Path | None:
"""Pop nvinfer property, or get from config_file if not found.
Args:
config_file: `nvinfer.conf` ini file. Used to compute
absolute paths, and default source for property values
if not found in the props arg.
props: element properties where to look for the desired
properties. Note: If the property is found, its popped
from this dict.
property_names: possible property names to look for.
Returns:
First occurence of the property_names, as found either in
the props dict or in the nvinfer.conf `[property]`
section.
Raises:
FileNotFoundError: None of the requested names is available
in the properties, and the config file does not exist.
"""
# extract from nvinfer's properties
for prop_name in property_names:
value = props.get(prop_name, None)
if value is None:
continue
value_path = Path(value)
if not value_path.is_absolute():
value_path = value_path.relative_to(config_file).resolve()
return value_path
# extract from nvinfer's config file
if not config_file.exists():
raise FileNotFoundError(config_file)
config = configparser.ConfigParser()
config.read(str(config_file))
for prop_name in property_names:
value = config["property"].get(prop_name, None)
if value is None:
continue
value_path = Path(value)
if not value_path.is_absolute():
value_path = (config_file.parent / value_path).resolve()
return value_path
return None
[docs]@dataclass
class Tracker(HasConnections):
"""Pythia wrapper around nvtracker gst element."""
config_file: Path
low_level_library: Path = Path(
"/opt/nvidia/deepstream/deepstream/lib/libnvds_nvmultiobjecttracker.so"
).resolve()
_string: Optional[str] = None
_default_props: Dict[str, str] = field(default_factory=dict)
CONNECTIONS: Con = field(default_factory=dict) # noqa: C0103
[docs] @classmethod
def from_file(
cls: Type[T],
config_file: Path,
low_level_library: Path = low_level_library,
) -> T:
"""Factory to create `Tracker` s from configuration file.
Args:
config_file: path for the `nvtracker` gstreamer element
'll-config-file' property.
low_level_library: path for the `nvtracker` gstreamer element
'll-lib-file' property (shared object).
Returns:
Instantiated `Tracker`.
Raises:
FileNotFoundError: Tracker config file does not exist.
"""
config_file = Path(config_file).resolve()
if not config_file.exists():
raise FileNotFoundError(
f"No Tracker configuration file found at {config_file}."
)
return cls(
config_file=config_file,
low_level_library=low_level_library,
)
[docs] @classmethod
def from_element(cls: Type[T], element: Gst.Element) -> T:
"""Factory from nvtracker.
Args:
element: The nvtracker to use as source.
Returns:
The instantiated nvtracker wrapper.
"""
skip = ("parent",)
props = {}
for prop in element.list_properties():
name = prop.name
if name in skip:
continue
raw = element.get_property(name)
try:
value = raw.value_nick # enums
except AttributeError:
value = str(raw)
if isinstance(raw, bool):
value = value.lower() # False -> false
props[name] = value
return cls(
config_file=props.pop("ll-config-file"),
low_level_library=props.pop("ll-lib-file"),
_default_props=props,
)
[docs] def gst(self, **kwargs: str) -> str:
"""Render nvtracker element with `gst-launch`-like syntax.
Args:
kwargs: Additional gst element properties.
Returns:
Rendered string.
Raises:
FileNotFoundError: Tracker `ll-config-file` not found.
See Also:
https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_plugin_gst-nvtracker.html#gst-properties
"""
inline_props = json.loads(json.dumps(self._default_props))
inline_props.update(kwargs)
props = "\n".join(
f"{k.replace('_', '-')}={v}" for k, v in inline_props.items()
)
if not self.low_level_library.exists():
raise FileNotFoundError(
"Could not find Tracker implementation"
f" at {self.low_level_library}"
)
self._string = _(
f"""\
nvtracker
ll-config-file={self.config_file}
ll-lib-file={self.low_level_library}
{props}
"""
)
return self._string
[docs]@dataclass
class Analytics(HasConnections):
"""Pythia wrapper around nvdsanalytics gst element."""
config_file: Path
_string: Optional[str] = None
_default_props: Dict[str, str] = field(default_factory=dict)
CONNECTIONS: Con = field(default_factory=dict) # noqa: C0103
[docs] def gst(self, **kwargs: str) -> str:
"""Render string as `gst-launch`-like parseable string.
Args:
kwargs: Additional gst element properties.
Returns:
Rendered `nvdsanalytics`.
See Also:
https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_plugin_gst-nvdsanalytics.html#gst-properties
"""
inline_props = json.loads(json.dumps(self._default_props))
inline_props.update(kwargs)
props = "\n".join(
f"{k.replace('_', '-')}={v}" for k, v in inline_props.items()
)
self._string = _(
f"""\
nvdsanalytics
config-file={self.config_file}
{props}
"""
)
return self._string
[docs] def requires_tracker(self) -> bool:
"""Return `True` if its `nvdsanalytics` requires `nvtracker`.
Returns:
`True` if its `nvdsanalytics` contains line crossing or
direction andata.
"""
config = configparser.ConfigParser()
config.read(str(self.config_file))
for section_name in config.sections():
if any(
section_name.startswith(pattern)
for pattern in (
"line-crossing",
"direction-detection",
)
):
return True
return False
[docs] @classmethod
def from_file(cls: Type[A], config_file: Path) -> A:
"""Factory from configuration file.
Args:
config_file: location of the nvdsanalytics `config-file`
property.
Returns:
The instantiated `nvdsanalytics` wrapper class.
Raises:
FileNotFoundError: The `nvdsanalytics` `config-file`
property is not found.
"""
config_file = Path(config_file).resolve()
if not config_file.exists():
raise FileNotFoundError(
f"No Analytics configuration file found at {config_file}."
)
return cls(config_file=config_file)
[docs] @classmethod
def from_element(cls: Type[A], element: Gst.Element) -> A:
"""Factory from nvdsanalytics.
Args:
element: The nvdsanalytics to use as source.
Returns:
The instantiated nvdsanalytics wrapper.
"""
skip = ("parent",)
props = {}
for prop in element.list_properties():
name = prop.name
if name in skip:
continue
raw = element.get_property(name)
try:
value = raw.value_nick # enums
except AttributeError:
value = str(raw)
if isinstance(raw, bool):
value = value.lower() # False -> false
props[name] = value
return cls(
config_file=props.pop("config-file"),
_default_props=props,
)