Coverage for src/pythia/pipelines/base.py: 77%
248 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-10-07 19:27 +0000
« prev ^ index » next coverage.py v6.4.4, created at 2022-10-07 19:27 +0000
1"""Pipeline.
3A Gstreamer pipeline, used to process video/image input.
4Contains:
6 A source: video/image.
7 A sink: display/file.
8 At least one PythIA model.
10Note:
11 Although a one-shot uridecodebin usage seems to work, there seems to
12 be an issue with quickly subsequent runs, producing segfaults.
14"""
16from __future__ import annotations
18import abc
19import re
20from collections import defaultdict
21from pathlib import Path
22from textwrap import dedent as _
23from typing import Collection
24from typing import Dict
25from typing import Iterator
26from typing import List
27from typing import Optional
28from typing import Tuple
29from typing import Union
30from urllib.parse import parse_qs
31from urllib.parse import urlparse
32from urllib.parse import urlunparse
34from pythia.exceptions import IncompatiblePipelineError
35from pythia.exceptions import InvalidPipelineError
36from pythia.models.base import Analytics
37from pythia.models.base import InferenceEngine
38from pythia.models.base import Tracker
39from pythia.types import Con
40from pythia.types import HasConnections
41from pythia.types import SinkUri
42from pythia.types import SourceUri
43from pythia.utils.ext import get_arch
44from pythia.utils.gst import GLib
45from pythia.utils.gst import Gst
46from pythia.utils.gst import gst_init
47from pythia.utils.str2pythia import find_analytics
48from pythia.utils.str2pythia import find_models
49from pythia.utils.str2pythia import find_tracker
# Concrete source wrappers that `PythiaSourceBase.from_uris` may return.
PSB = Union["PythiaTestSource", "PythiaSource", "PythiaMultiSource"]
# Concrete sink wrappers that `PythiaSink.from_uri` may return.
PS = Union[
    "PythiaFakesink", "PythiaFilesink", "PythiaMultifileSink", "PythiaLiveSink"
]

# Error text used by `BasePipeline.start` when PLAYING cannot be reached.
UNABLE_TO_PLAY_PIPELINE = "Unable to play the pipeline."
class PythiaSourceBase(abc.ABC, HasConnections):
    """Abstract wrapper around one or more Gstreamer sources.

    Concrete subclasses decide how pythia-specific query params are
    stripped from the incoming uris (`pop_pythia_args_from_uris`) and
    how the source section of the pipeline is rendered (`gst`).

    """

    CONNECTIONS: Con = {}

    def __init__(self, *uris: SourceUri) -> None:
        """Store the cleaned uris and the params popped from them.

        Args:
            uris: Collection of `SourceUri` s.

        """
        params, cleaned = self.pop_pythia_args_from_uris(uris)
        self.pythia_params = params
        self.uris = cleaned

    def __iter__(self) -> Iterator[SourceUri]:
        """Iterate over the configured uris.

        Yields:
            Own `SourceUri`s, in the order they are stored.

        """
        for uri in self.uris:
            yield uri

    def __len__(self) -> int:
        """Get the number of sources.

        Returns:
            The number of configured uris.

        """
        count = len(self.uris)
        return count

    @abc.abstractmethod
    def gst(self) -> str:
        """Render as string with `gst-launch`-like syntax."""

    @classmethod
    def from_uris(cls, *uris: SourceUri) -> PSB:
        """Pick and instantiate the concrete source for the given uris.

        A single uri starting with ``"test"`` yields a
        :class:`PythiaTestSource`, any other single uri yields a
        :class:`PythiaSource`, and several uris yield a
        :class:`PythiaMultiSource`.

        Args:
            uris: Collection of uris to build the source from.

        Returns:
            The instantiated source object.

        Raises:
            ValueError: No source uris received

        """
        if not uris:
            raise ValueError("No source uris")
        if len(uris) > 1:
            return PythiaMultiSource(*uris)
        (single,) = uris
        if single.startswith("test"):
            return PythiaTestSource(single)
        return PythiaSource(single)

    @staticmethod
    @abc.abstractmethod
    def pop_pythia_args_from_uris(
        uris: Tuple[SourceUri, ...],
    ) -> Tuple[dict, List[SourceUri]]:
        """Pop pythia-related query params from source uri.

        Args:
            uris: input uris to filter

        """
def clean_single_uri(uri: SourceUri) -> Tuple[dict, SourceUri]:
    """Pop the pythia-specific query params from a single uri.

    Args:
        uri: input uri to parse. Its query string must contain
            `muxer_width` and `muxer_height`; `num_buffers` is
            optional.

    Returns:
        extracted: dictionary containing the popped params as integers.
            `num_buffers` is -1 when the uri does not define it.
        The uri without its pythia query params. Any other query param
        is kept as it was written.

    Raises:
        KeyError: `muxer_width` or `muxer_height` missing from the uri.
        ValueError: A pythia param cannot be converted to an integer.

    Examples:
        >>> clean_single_uri("file://video.mp4?muxer_width=1280&muxer_height=720")
        ({'muxer_width': 1280, 'muxer_height': 720, 'num_buffers': -1}, 'file://video.mp4')

    """  # noqa: C0301
    parsed = urlparse(uri)
    query = parse_qs(parsed.query, strict_parsing=False)
    extracted = {
        "muxer_width": int(query["muxer_width"][0]),
        "muxer_height": int(query["muxer_height"][0]),
        "num_buffers": int(query.get("num_buffers", ["-1"])[0]),
    }
    # Drop params by *name*. Replacing the re-rendered `name=value` text
    # misses values that do not round-trip through `int` (e.g. "01280")
    # and corrupts other params whose text merely contains the same
    # substring (e.g. "x_num_buffers=10" when popping "num_buffers=1").
    kept = [
        param
        for param in parsed.query.split("&")
        if param and param.partition("=")[0] not in extracted
    ]
    # Stray "?" at either end is stripped, as previous versions did.
    clean_query = "&".join(kept).strip("?")
    clean_uri = urlunparse(parsed._replace(query=clean_query))
    return extracted, clean_uri
class PythiaSource(PythiaSourceBase):
    """Single-source building block backed by `uridecodebin`."""

    CONNECTIONS: Con = {}

    @staticmethod
    def pop_pythia_args_from_uris(
        uris: Tuple[SourceUri, ...],
    ) -> Tuple[dict, List[SourceUri]]:
        """Pop the pythia params from the first (and only used) uri.

        Args:
            uris: input uris to filter. Only the first one is read.

        Returns:
            extracted: dictionary containing popped params.
            list containing the single uri without its pythia query params.

        Examples:
            >>> uris = ["file://video.mp4?muxer_width=1280&muxer_height=720"]
            >>> PythiaSource.pop_pythia_args_from_uris(uris)
            ({'muxer_width': 1280, 'muxer_height': 720, 'num_buffers': -1}, ['file://video.mp4'])

        """  # noqa: C0301
        params, stripped = clean_single_uri(uris[0])
        return params, [stripped]

    def gst(self) -> str:
        """Render from single uridecodebin up to nvmuxer.

        Returns:
            Rendered string

        """
        uri = self.uris[0]
        batch_size = len(self)
        width = self.pythia_params["muxer_width"]
        height = self.pythia_params["muxer_height"]
        return _(
            f"""\
            uridecodebin
              uri={uri}
            ! queue
            ! nvvideoconvert
            ! video/x-raw(memory:NVMM)
            ! m.sink_0
            nvstreammux
              name=m
              batch-size={batch_size}
              width={width}
              height={height}
            """
        )
class PythiaMultiSource(PythiaSourceBase):
    """Uridecodebin wrapper building block for multiple sources."""

    @staticmethod
    def pop_pythia_args_from_uris(
        uris: Tuple[SourceUri, ...],
    ) -> Tuple[dict, List[SourceUri]]:
        """Extract the largest muxer width and height among the uris.

        Args:
            uris: input uris to filter

        Returns:
            extracted: dictionary containing, for `muxer_width` and
                `muxer_height`, the maximum value found among all uris.
            list containing every uri without its pythia query params.

        Examples:
            >>> uris = [
            ...     "./frames/%04d.jpg?muxer_width=320&muxer_height=240",
            ...     "./annotations/%04d.jpg?muxer_width=1280&muxer_height=100",
            ... ]
            >>> PythiaMultiSource.pop_pythia_args_from_uris(uris)
            ({'muxer_width': 1280, 'muxer_height': 240}, ['./frames/%04d.jpg', './annotations/%04d.jpg'])

        """  # noqa: C0301
        extrema = {
            "muxer_width": 0,
            "muxer_height": 0,
        }
        uris_out = []
        for uri in uris:
            extracted, clean_uri = clean_single_uri(uri)
            uris_out.append(clean_uri)
            for key in extrema:
                extrema[key] = max(extrema[key], extracted[key])
        return extrema, uris_out

    def gst(self) -> str:
        """Render from several uridecodebin up to a single nvmuxer.

        Every uri gets its own `uridecodebin` branch linked to a
        dedicated `m.sink_<idx>` pad. The muxer is declared once, after
        all branches: declaring one `nvstreammux name=m` per branch
        would produce several elements with the same name. Its size is
        taken from the extrema computed in `pop_pythia_args_from_uris`.

        Returns:
            Rendered string

        """
        branches = "\n".join(
            _(
                f"""\
                uridecodebin
                  uri={uri}
                ! queue
                ! nvvideoconvert
                ! video/x-raw(memory:NVMM)
                ! m.sink_{idx}
                """
            )
            for idx, uri in enumerate(self.uris)
        )
        muxer = _(
            f"""\
            nvstreammux
              name=m
              batch-size={len(self.uris)}
              width={self.pythia_params['muxer_width']}
              height={self.pythia_params['muxer_height']}
            """
        )
        return f"{branches}\n{muxer}"
class PythiaTestSource(PythiaSourceBase):
    """Single-source building block backed by `videotestsrc`."""

    @staticmethod
    def pop_pythia_args_from_uris(
        uris: Tuple[SourceUri, ...],
    ) -> Tuple[dict, List[SourceUri]]:
        """Pop the pythia params from the first (and only used) uri.

        Args:
            uris: input uris to filter. Only the first one is read.

        Returns:
            extracted: dictionary containing popped params.
            list containing the single uri without its pythia query params.

        Examples:
            >>> uris = ["test://?muxer_width=320&muxer_height=240"]
            >>> PythiaTestSource.pop_pythia_args_from_uris(uris)
            ({'muxer_width': 320, 'muxer_height': 240, 'num_buffers': -1}, ['test:'])

        """  # noqa: C0301
        params, stripped = clean_single_uri(uris[0])
        return params, [stripped]

    def gst(self) -> str:
        """Render from single videotestsrc up to nvmuxer.

        Returns:
            Rendered string.

        """
        num_buffers = self.pythia_params["num_buffers"]
        batch_size = len(self)
        width = self.pythia_params["muxer_width"]
        height = self.pythia_params["muxer_height"]
        return _(
            f"""
            videotestsrc
              num-buffers={num_buffers}
            ! queue
            ! nvvideoconvert
            ! video/x-raw(memory:NVMM)
            ! m.sink_0
            nvstreammux
              name=m
              batch-size={batch_size}
              nvbuf-memory-type=0
              width={width}
              height={height}
            """
        )
class PythiaSink(abc.ABC, HasConnections):
    """Abstract wrapper around the sink section of a pipeline."""

    CONNECTIONS: Con = {}
    # File suffixes routed to `PythiaFilesink` by `from_uri`.
    VIDEO_EXTENSIONS = [
        ".mp4",
        ".avi",
        ".mov",
        ".mkv",
        ".webm",
        ".flv",
        ".wmv",
        ".mpg",
        ".mpeg",
        ".m4v",
    ]

    def __init__(self, uri: SinkUri) -> None:
        """Instantiate sink wrapper with one of the available uris.

        Args:
            uri: the uri to build a gst sink and finish the pipeline.

        """
        self.uri = uri

    @classmethod
    def from_uri(cls, uri: SinkUri) -> PS:
        """Pick and instantiate the concrete sink for a `SinkUri` .

        Args:
            uri: the uri to use. Checked in this order:

                * ``'live'`` builds a `PythiaLiveSink` .
                * ``'fakesink'`` builds a `PythiaFakesink` .
                * A path whose stem contains ``%`` builds a
                  `PythiaMultifileSink` .
                * A path whose suffix is listed in `VIDEO_EXTENSIONS`
                  builds a `PythiaFilesink` .

        Returns:
            The instantiated `PythiaSink` , depending on its uri.

        Raises:
            ValueError: unsupported sink uri.

        """
        if uri == "live":
            return PythiaLiveSink(uri)
        if uri == "fakesink":
            return PythiaFakesink(uri)

        location = Path(uri)
        if "%" in location.stem:
            return PythiaMultifileSink(uri)
        if location.suffix in cls.VIDEO_EXTENSIONS:
            return PythiaFilesink(uri)

        raise ValueError(f"Unknown sink uri: {uri}")

    @abc.abstractmethod
    def gst(self) -> str:
        """Render as string with `gst-launch`-like syntax."""
class PythiaFakesink(PythiaSink):
    """Sink building block that renders a bare `fakesink` element."""

    def gst(self) -> str:
        """Render the `fakesink` element; the stored uri is not used.

        Returns:
            Rendered string

        """
        element = "fakesink"
        return element
class PythiaFilesink(PythiaSink):
    """Sink building block writing a single file through `filesink`.

    An `encodebin` is placed before the `filesink` to encode upstream
    buffers.

    """

    def gst(self) -> str:
        """Render from single encodebin up to filesink.

        Returns:
            Rendered string

        """
        location = self.uri
        return _(
            f"""\
            encodebin
            ! filesink
              location="{location}"
            """
        )
class PythiaMultifileSink(PythiaSink):
    """Sink building block writing one file per buffer (`multifilesink`).

    The encoder placed before the `multifilesink` is chosen from
    `SUPPORTED_FORMATS` according to the suffix of the uri.

    """

    # Maps an output suffix to the converter + encoder section rendered
    # right before the `multifilesink`.
    SUPPORTED_FORMATS = {
        ".jpg": """
            nvvideoconvert
            ! jpegenc
              quality=100
              idct-method=float
        """,
        ".png": """
            nvvideoconvert
            ! avenc_png
        """,
        ".webp": """
            nvvideoconvert
            ! webpenc
              lossless=true
              quality=100
              speed=6
        """,
    }

    def gst(self) -> str:
        """Render from the format-specific encoder up to multifilesink.

        Returns:
            Rendered string

        Raises:
            KeyError: the suffix of the uri is not a key of
                `SUPPORTED_FORMATS`.

        """
        suffix = Path(self.uri).suffix
        encode = self.SUPPORTED_FORMATS[suffix]
        location = self.uri
        return _(
            f"""\
            {encode}
            ! multifilesink
              location="{location}"
            """
        )
class PythiaLiveSink(PythiaSink):
    """nveglglessink wrapper."""

    def __init__(self, uri: SinkUri, arch: str = "") -> None:
        """Construct nveglglessink wrapper.

        Args:
            uri: uri for `PythiaSink`'s constructor.
            arch: platform architecture, to differentiate GPU and
                jetson devices. If not set, automatically computed by
                :func:`get_arch`. In jetson devices, injects an
                additional `nvegltransform`.

        See Also:
            https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_FAQ.html#why-is-a-gst-nvegltransform-plugin-required-on-a-jetson-platform-upstream-from-gst-nveglglessink

        """
        super().__init__(uri)
        self.arch = arch or get_arch()
        # Decide from `self.arch` so an explicitly supplied `arch` is
        # honoured; querying `get_arch()` again would silently ignore it.
        self.transform = "! nvegltransform" if self.arch == "aarch64" else ""

    def gst(self) -> str:
        """Render from nvvideoconvert to nveglglessink.

        Returns:
            Rendered string. Contains `nvegltransform` only when the
            configured architecture is ``"aarch64"``.

        """
        return _(
            f"""\
            nvvideoconvert
            {self.transform}
            ! nveglglessink
            """
        )
class BasePipeline(HasConnections, abc.ABC):
    """Shared behaviour for pythia pipelines, whatever their origin."""

    # Lazily created by the `pipeline` property.
    _pipeline: Optional[Gst.Pipeline] = None
    models: Collection[InferenceEngine]
    analytics: Optional[Analytics]
    tracker: Optional[Tracker]

    @abc.abstractmethod
    def gst(self) -> str:
        """Render its string for to use in `gst-launch`-like syntax."""

    @property
    @abc.abstractmethod
    def CONNECTIONS(self) -> Con:  # type: ignore[override] # noqa: C0103,C0116
        ...

    def validate(
        self,
    ) -> None:
        """Check that the configured elements are compatible.

        Rules enforced here:

        * `Analytics` requires at least one `InferenceEngine`.
        * `Analytics` requires a `Tracker` when its
          `requires_tracker()` reports so.
        * `Tracker` requires at least one `InferenceEngine`.

        Raises:
            IncompatiblePipelineError: `Analytics` requires `Tracker`
                but none supplied.
            IncompatiblePipelineError: `Analytics` or `Tracker` require
                a model but none supplied.

        """
        if self.analytics:
            model_count = len(self.models)
            if model_count < 1:
                raise IncompatiblePipelineError(
                    "Analytics requires at least 1 InferenceEngine."
                    f" Found {model_count}."
                )
            tracker_missing = self.tracker is None
            if self.analytics.requires_tracker() and tracker_missing:
                raise IncompatiblePipelineError(
                    "Current Analytics spec requires at least Tracker, "
                    "but none found."
                )
        if self.tracker:
            model_count = len(self.models)
            if model_count < 1:
                raise IncompatiblePipelineError(
                    "Tracker requires at least 1 InferenceEngine."
                    f" Found {model_count}."
                )

    @property
    def pipeline(self) -> Gst.Pipeline:
        """Gstreamer pipeline, instantiated on first access.

        Returns:
            The `Gst.Pipeline` held by this object.

        """
        existing = self._pipeline
        if existing:
            return existing
        self._pipeline = self.parse_launch()
        return self._pipeline

    def parse_launch(self) -> Gst.Pipeline:
        """Build a `Gst.Pipeline` from the string rendered by `gst`.

        Returns:
            The instantiated :class:`Gst.Pipeline`.

        Raises:
            NotImplementedError: pipeline already instantiated.
            InvalidPipelineError: Unable to parse pipeline because of a
                syntax error in the pipeline string.
            GLib.Error: Syntax unrelated error - unable to parse
                pipeline.

        """
        gst_init()
        if self._pipeline:
            raise NotImplementedError(
                "TODO: make a copy of the pipeline,"
                " this one is already in use"
            )
        try:
            return Gst.parse_launch(self.gst())
        except GLib.Error as exc:
            is_syntax_error = "syntax error" in str(exc)
            if not is_syntax_error:
                raise
            raise InvalidPipelineError from exc

    def start(self) -> Gst.StateChangeReturn:
        """Validate the pipeline and set it to PLAYING state.

        Returns:
            The state change result enum.

        Raises:
            RuntimeError: Unable to play the pipeline. The pipeline is
                stopped before raising.

        """
        self.validate()
        result = self.pipeline.set_state(Gst.State.PLAYING)
        if result is not Gst.StateChangeReturn.FAILURE:
            return result
        self.stop()
        raise RuntimeError(f"ERROR: {UNABLE_TO_PLAY_PIPELINE}")

    def stop(self) -> None:
        """Set the pipeline to null state.

        Goes through the lazy `pipeline` property, so the pipeline is
        instantiated first if it was not already.

        """
        self.pipeline.set_state(Gst.State.NULL)

    def send_eos(self) -> None:
        """Send a gstreamer 'end of stream' event to the pipeline."""
        eos = Gst.Event.new_eos()
        self.pipeline.send_event(eos)
# Accepted forms for the `models` argument of `Pipeline`: a collection of
# paths and/or engines, a single path, a single engine, or nothing.
ModelType = Union[
    Collection[Union[Path, InferenceEngine]], Path, InferenceEngine, None
]
class Pipeline(BasePipeline):
    """Pipeline assembled from source, model, tracker, analytics and sink."""

    def __init__(  # noqa: R0913
        self,
        sources: SourceUri | list[SourceUri] | tuple[SourceUri],
        models: ModelType = None,
        sink: SinkUri = "fakesink",
        analytics: Union[Path, Analytics] | None = None,
        tracker: Union[Path, Tracker] | None = None,
    ) -> None:
        """Wrap every received building block, loading paths as needed.

        Args:
            sources: Single uri or collection of uri sources to join in
                `nvstreammux`.
            models: Models to insert in the pipeline. Paths are loaded
                with `InferenceEngine.from_folder`.
            sink: Final element of the pipeline.
            analytics: Optional `nvdsanalytics`. A path is loaded with
                `Analytics.from_file`.
            tracker: Optional `nvtracker`. A path is loaded with
                `Tracker.from_file`.

        Raises:
            ValueError: invalid analytics or tracker object.

        """
        super().__init__()
        if isinstance(sources, SourceUri):
            sources = [sources]
        self.source = PythiaSourceBase.from_uris(*sources)

        if isinstance(models, (Path, InferenceEngine)):
            models = [models]
        engines = []
        for model in models or []:
            if not isinstance(model, InferenceEngine):
                model = InferenceEngine.from_folder(model)
            engines.append(model)
        self.models = engines
        # Filled by `gst`, which is where model names are assigned.
        self._model_map: dict[str, InferenceEngine] = {}

        if analytics is None or isinstance(analytics, Analytics):
            self.analytics = analytics
        elif isinstance(analytics, Path):
            self.analytics = Analytics.from_file(analytics)
        else:
            raise ValueError(f"Unhandled {analytics=}")

        if tracker is None or isinstance(tracker, Tracker):
            self.tracker = tracker
        elif isinstance(tracker, Path):
            self.tracker = Tracker.from_file(tracker)
        else:
            raise ValueError(f"Unhandled {tracker=}")

        self.sink = PythiaSink.from_uri(sink)

    @property
    def CONNECTIONS(self) -> Con:  # type: ignore[override] # noqa: C0103
        merged: Con = defaultdict(dict)
        # Later building blocks override earlier ones on the same
        # (element name, signal) pair.
        building_blocks = (self.source, *self.models, self.sink)
        for block in building_blocks:
            for element_name, signals in block.CONNECTIONS.items():
                for signal, callback in signals.items():
                    merged[element_name][signal] = callback

        return merged

    @property
    def model_map(self) -> Dict[str, InferenceEngine]:
        """Mapping from model names to inference engines.

        The mapping is populated by `gst`; when it is still empty,
        `gst` is called first.

        Returns:
            A dictionary whose keys are nvinfer names and their values
            are their respective :class:`InferenceEngine` wrappers.

        """
        if not self._model_map:
            self.gst()
        return self._model_map

    def gst(self) -> str:
        """Render its string for to use in `gst-launch`-like syntax.

        Models are named ``model_<idx>`` and registered in the model
        map as a side effect.

        Returns:
            The pipeline as it would be used when calling `gst-launch`.

        """
        source = self.source.gst()

        rendered_models = []
        for idx, model in enumerate(self.models):
            name = f"model_{idx}"
            self._model_map[name] = model
            rendered_models.append(model.gst(name=name, unique_id=idx + 1))
        models = "".join(rendered_models)

        sink = self.sink.gst()
        tracker = self.tracker.gst() if self.tracker else None
        analytics = self.analytics.gst() if self.analytics else None

        # Optional sections collapse to an empty line when absent.
        models_section = "! " + models if models else ""
        tracker_section = "! " + tracker if tracker else ""
        analytics_section = "! " + analytics if analytics else ""
        return _(
            f"""
            {source}
            {models_section}
            {tracker_section}
            {analytics_section}
            ! {sink}
            """
        )
class StringPipeline(BasePipeline):
    """Pipeline built directly from a `gst-launch`-like string."""

    CONNECTIONS: Con = {}

    def __init__(self, pipeline_string: str) -> None:
        """Parse the pipeline string and discover its pythia elements.

        Args:
            pipeline_string: A `gst-launch`-like pipeline string.

        Raises:
            InvalidPipelineError: Unable to parse pipeline because of a
                syntax error in the pipeline string.
            GLib.Error: Syntax unrelated error - unable to parse
                pipeline.

        """
        super().__init__()
        self.pipeline_string = pipeline_string
        try:
            # Accessed only to force the lazy instantiation (and parsing).
            self.pipeline
        except GLib.Error as exc:
            if "gst_parse_error" in str(exc):
                raise InvalidPipelineError(
                    f"Unable to parse pipeline:\n```gst\n{pipeline_string}\n```"
                ) from exc
            raise
        parsed = self.pipeline
        self.models = find_models(parsed)
        self.analytics = find_analytics(parsed)
        self.tracker = find_tracker(parsed)

    def gst(self) -> str:
        """Return the pipeline string received at construction.

        Returns:
            The unmodified `pipeline_string`.

        """
        return self.pipeline_string