Coverage for src/pythia/utils/str2pythia.py: 94%

24 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-10-07 19:27 +0000

1"""Utilities to convert strings to pythia wrappers.""" 

2 

3from __future__ import annotations 

4 

5from pythia.models.base import Analytics 

6from pythia.models.base import InferenceEngine 

7from pythia.models.base import Tracker 

8from pythia.utils.gst import Gst 

9from pythia.utils.gst import gst_iter 

10 

11 

def is_inference(element: Gst.Element) -> bool:
    """Tell whether a `Gst.Element` is an `nvinfer` element.

    The test is a case-insensitive substring match of ``"nvinfer"``
    against the name of the element's Python class. Because it is a
    substring test, any class name containing ``nvinfer`` matches.

    Args:
        element: the gstreamer element to inspect.

    Returns:
        `True` if the element's class name contains ``nvinfer``,
        `False` otherwise.

    """
    class_name = element.__class__.__name__.lower()
    return "nvinfer" in class_name

23 

24 

def is_tracker(element: Gst.Element) -> bool:
    """Tell whether a `Gst.Element` is an `nvtracker` element.

    The test is a case-insensitive substring match of ``"nvtracker"``
    against the name of the element's Python class.

    Args:
        element: the gstreamer element to inspect.

    Returns:
        `True` if the element's class name contains ``nvtracker``,
        `False` otherwise.

    """
    class_name = element.__class__.__name__.lower()
    return "nvtracker" in class_name

36 

37 

def is_analytics(element: Gst.Element) -> bool:
    """Tell whether a `Gst.Element` is an `nvdsanalytics` element.

    The test is a case-insensitive substring match of ``"nvdsanalytics"``
    against the name of the element's Python class.

    Args:
        element: the gstreamer element to inspect.

    Returns:
        `True` if the element's class name contains ``nvdsanalytics``,
        `False` otherwise.

    """
    class_name = element.__class__.__name__.lower()
    return "nvdsanalytics" in class_name

49 

50 

def find_models(pipeline: Gst.Pipeline) -> list[InferenceEngine]:
    """Collect every `nvinfer` element of a parsed pipeline.

    Args:
        pipeline: the root bin whose elements are scanned for `nvinfer`
            elements.

    Returns:
        One :class:`InferenceEngine` wrapper per `nvinfer` element, in the
        order ``gst_iter`` yields them. The list is empty when the
        pipeline has no `nvinfer` element.

    """
    engines: list[InferenceEngine] = []
    for candidate in gst_iter(pipeline.iterate_elements()):
        if not is_inference(candidate):
            continue
        engines.append(InferenceEngine.from_element(candidate))
    return engines

67 

68 

def find_analytics(pipeline: Gst.Pipeline) -> Analytics | None:
    """Locate the `nvdsanalytics` element of a parsed pipeline.

    Args:
        pipeline: the root bin whose elements are scanned for an
            `nvdsanalytics` element.

    Returns:
        The first `nvdsanalytics` element encountered, wrapped as
        :class:`Analytics`, or `None` when the pipeline has none.

    """
    # The generator is lazy, so scanning and wrapping stop at the first
    # match, exactly like an early return from a loop.
    wrapped = (
        Analytics.from_element(candidate)
        for candidate in gst_iter(pipeline.iterate_elements())
        if is_analytics(candidate)
    )
    return next(wrapped, None)

84 

85 

def find_tracker(pipeline: Gst.Pipeline) -> Tracker | None:
    """Locate the `nvtracker` element of a parsed pipeline.

    Args:
        pipeline: the root bin whose elements are scanned for an
            `nvtracker` element.

    Returns:
        The first `nvtracker` element encountered, wrapped as
        :class:`Tracker`, or `None` when the pipeline has none.

    """
    # Lazy generator: only the first matching element is wrapped, and
    # iteration stops as soon as it is found.
    wrapped = (
        Tracker.from_element(candidate)
        for candidate in gst_iter(pipeline.iterate_elements())
        if is_tracker(candidate)
    )
    return next(wrapped, None)