Coverage for src/pythia/models/base.py: 81%

201 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-10-07 19:27 +0000

"""Model.

A deep learning model used to extract metadata from a video source.

Contains:

- A ``labels.txt`` file, containing the list of model labels.
- A ``pgie.conf`` file.
- A ``model.etlt``, ``model.engine`` or any other DeepStream
  nvinfer-compatible model or engine file.

"""

11from __future__ import annotations 

12 

13import configparser 

14import json 

15from dataclasses import dataclass 

16from dataclasses import field 

17from pathlib import Path 

18from textwrap import dedent as _ 

19from typing import Collection 

20from typing import Dict 

21from typing import Optional 

22from typing import Type 

23from typing import TypeVar 

24 

25from pythia.types import Con 

26from pythia.types import HasConnections 

27from pythia.utils.ext import not_empty 

28from pythia.utils.ext import not_none 

29from pythia.utils.gst import Gst 

30 

# Bound type variables used by the alternate constructors below
# (``cls: Type[X]`` -> ``X``), so a subclass factory returns the subclass.
A = TypeVar("A", bound="Analytics")
T = TypeVar("T", bound="Tracker")
IE = TypeVar("IE", bound="InferenceEngine")

34 

35 

@dataclass
class InferenceEngine(HasConnections):
    """Pythia wrapper around nvinfer gst element."""

    MODEL_SUFFIXES = {
        ".etlt": "tlt-encoded-model",
        ".caffe": "model-file",
        ".caffemodel": "model-file",
        ".prototxt": "proto-file",
        ".onnx": "onnx-file",
        ".uff": "uff-file",
        # ".engine" is deliberately absent: compiled engines are located
        # separately, by `locate_compiled_model`.
    }
    """Supported model extensions (prioritized order).

    Used when an inference engine is to be instantiated by a directory,
    to locate supported models from their extension.

    See Also: :meth:`locate_source_model`.

    """

    MODEL_CONFIG_SUFFIXES = (
        ".conf",
        ".ini",
        ".yml",
        ".yaml",
    )
    """Ordered collection of supported model config file extensions.

    Used when an inference engine is to be instantiated by a directory,
    to locate `config-file-path` from their extension.

    See Also: :meth:`locate_config_file`.

    """

    # nvinfer labels file.
    labels_file: Path
    # nvinfer `config-file-path`.
    config_file: Path
    # Last string rendered by `gst`.
    _string: Optional[str] = None
    # Uncompiled model (see MODEL_SUFFIXES), if any.
    source_model: Optional[Path] = None
    # Compiled `.engine` model, if any.
    compiled_model: Optional[Path] = None
    # Remaining element properties captured by `from_element`.
    _default_props: Dict[str, str] = field(default_factory=dict)

    CONNECTIONS: Con = field(default_factory=dict)  # noqa: C0103

    def gst(self, name: str, **kw) -> str:
        """Render nvinfer with `gst-launch`-like syntax.

        One property per line: ``config-file-path`` and ``name`` first,
        then every ``kw`` entry. Properties stored in ``_default_props``
        are not rendered here.

        Args:
            name: nvinfer gstreamer element name property.
            kw: nvinfer gstreamer property name and value. Underscores in
                names are rendered as dashes (``batch_size`` ->
                ``batch-size``).

        Returns:
            Rendered string (also stored in ``_string``).

        See Also:
            https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_plugin_gst-nvinfer.html#gst-properties

        """
        # Build the lines explicitly: dedenting an f-string template stops
        # working as soon as interpolated multi-line props have no indent.
        lines = [
            "nvinfer",
            f"config-file-path={self.config_file}",
            f"name={name}",
        ]
        lines.extend(f"{k.replace('_', '-')}={v}" for k, v in kw.items())
        self._string = "\n".join(lines) + "\n"
        return self._string

    @staticmethod
    def _first_match(folder: Path, pattern: str) -> Path | None:
        """Return the lexicographically first `folder` entry matching `pattern`.

        Sorting makes the choice deterministic when several files match;
        raw `glob` order depends on the filesystem.
        """
        return min(folder.glob(pattern), default=None)

    @classmethod
    def locate_source_model(cls, folder: Path) -> Path | None:
        """Find the first deepstream model file in a folder.

        It iterates over the known nvinfer-compatible model file
        extensions (in `MODEL_SUFFIXES` order), and returns at the first
        success.

        Args:
            folder: Directory to search the model.

        Returns:
            Found model, or `None` if not found.

        """
        for suffix in cls.MODEL_SUFFIXES:
            candidate = cls._first_match(folder, f"*{suffix}")
            if candidate is not None:
                return not_empty(candidate)
        return None

    @staticmethod
    def locate_labels_file(folder: Path) -> Path:
        """Find labels file from a directory.

        Args:
            folder: directory to search labels file.

        Returns:
            The first (sorted) file matching the `*label*` pattern inside
            the directory.

        Raises:
            FileNotFoundError: no file matches the expected labels
                pattern.

        """
        labels_file = InferenceEngine._first_match(folder, "*label*")
        if labels_file is None:
            raise FileNotFoundError(f"No labels file found at {folder}")
        return not_empty(labels_file)

    @classmethod
    def locate_config_file(cls, folder: Path) -> Path:
        """Find the first model config file in a folder.

        Iterate over the known nvinfer-compatible config-file-path file
        extensions (in `MODEL_CONFIG_SUFFIXES` order), and returns at the
        first success.

        Args:
            folder: Directory to search the model.

        Returns:
            path to the found configuration file.

        Raises:
            FileNotFoundError: No configuration file found.

        """
        for suffix in cls.MODEL_CONFIG_SUFFIXES:
            candidate = cls._first_match(folder, f"*{suffix}")
            if candidate is not None:
                return not_empty(candidate)
        raise FileNotFoundError(f"No config file found at {folder}")

    @staticmethod
    def locate_compiled_model(
        folder: Path, source_model: Path | None
    ) -> Path | None:
        """Find the compiled (`.engine`) model file.

        Strategies, in order:

        1. If `source_model` is set, an engine next to it whose name
           contains the source model stem (``*<stem>*.engine``).
        2. Any ``*.engine`` file in `folder`.

        Args:
            folder: Directory to search the model.
            source_model: If set, use this path's stem to try to locate
                the `.engine` file. Otherwise, finds any `*.engine`.

        Returns:
            Path to the found engine file, or `None` when `source_model`
            is set but no engine file was found.

        Raises:
            FileNotFoundError: `source_model` is not set and no
                ``*.engine`` file exists in `folder`.

        """
        if source_model:
            engine = InferenceEngine._first_match(
                source_model.parent, f"*{source_model.stem}*.engine"
            )
            if engine is not None:
                return not_empty(engine)
        engine = InferenceEngine._first_match(folder, "*.engine")
        if engine is not None:
            return engine
        if not source_model:
            # Without a source model there is nothing nvinfer could use.
            raise FileNotFoundError(
                f"Neither {source_model=} nor its compiled version exist."
            )
        return None

    @classmethod
    def from_folder(cls: Type[IE], folder: str | Path) -> IE:
        """Factory to instantiate from directories.

        Args:
            folder: Directory where the model files are located.

        Returns:
            Instantiated model.

        Raises:
            FileNotFoundError: `folder` is not an existing directory, or
                the labels file, config file or model could not be
                located inside it.

        """
        folder = Path(folder).resolve()
        if not folder.is_dir():
            raise FileNotFoundError(f"No directory found at {folder}.")

        source_model = cls.locate_source_model(folder)
        labels_file = cls.locate_labels_file(folder)
        config_file = cls.locate_config_file(folder)
        compiled_model = cls.locate_compiled_model(folder, source_model)

        return cls(
            labels_file=labels_file,
            config_file=config_file,
            source_model=source_model,
            compiled_model=compiled_model,
        )

    @classmethod
    def from_element(cls: Type[IE], element: Gst.Element) -> IE:
        """Factory from nvinfer.

        Every element property except ``parent`` is read as a string
        (enums by nick, booleans lowercased); unset (`None`) properties
        are skipped. ``config-file-path``, the labels file and the model
        paths go to dedicated fields; the rest is kept in
        ``_default_props``.

        Args:
            element: The nvinfer to use as source.

        Returns:
            The instantiated nvinfer wrapper.

        Raises:
            FileNotFoundError: the element has no ``config-file-path``,
                or a path property is missing from the element and its
                config file does not exist.

        """
        skip = ("parent",)
        props: Dict[str, str] = {}
        for prop in element.list_properties():
            name = prop.name
            if name in skip:
                continue

            raw = element.get_property(name)
            if raw is None:
                # Unset property: str(None) would store the literal "None",
                # which would later be mistaken for a relative path.
                continue
            try:
                value = raw.value_nick  # enums
            except AttributeError:
                value = str(raw)
                if isinstance(raw, bool):
                    value = value.lower()  # False -> false
            props[name] = value

        config_file_path = props.pop("config-file-path", None)
        if config_file_path is None:
            raise FileNotFoundError(
                "nvinfer element has no 'config-file-path' property set."
            )
        config_file = Path(config_file_path).resolve()
        return cls(
            config_file=config_file,
            labels_file=not_none(
                cls.pop_property_or_get_from_nvinfer_conf(  # noqa: C0301
                    config_file,
                    props,
                    property_names=("labelfile-path",),
                )
            ),
            source_model=cls.pop_property_or_get_from_nvinfer_conf(
                config_file,
                props,
                property_names=cls.MODEL_SUFFIXES.values(),
            ),
            compiled_model=cls.pop_property_or_get_from_nvinfer_conf(
                config_file,
                props,
                property_names=("model-engine-file",),
            ),
            _default_props=props,
        )

    @staticmethod
    def pop_property_or_get_from_nvinfer_conf(
        config_file: Path,
        props: dict[str, str],
        *,
        property_names: Collection[str],
    ) -> Path | None:
        """Pop nvinfer property, or get from config_file if not found.

        Relative paths, from either source, are resolved against the
        directory containing `config_file`.

        Args:
            config_file: `nvinfer.conf` ini file. Used to compute
                absolute paths, and default source for property values
                if not found in the props arg.
            props: element properties where to look for the desired
                properties. Note: every name checked until the first hit
                (inclusive) is popped from this dict.
            property_names: possible property names to look for.

        Returns:
            First non-empty occurrence of the property_names, as found
            either in the props dict or in the nvinfer.conf `[property]`
            section; `None` if none is set.

        Raises:
            FileNotFoundError: None of the requested names is available
                in the properties, and the config file does not exist.

        """
        # extract from nvinfer's properties
        for prop_name in property_names:
            value = props.pop(prop_name, None)
            if not value:
                continue
            value_path = Path(value)
            if not value_path.is_absolute():
                # `relative_to(config_file)` would raise ValueError here;
                # join onto the config directory, like the branch below.
                value_path = (config_file.parent / value_path).resolve()
            return value_path

        # extract from nvinfer's config file
        if not config_file.exists():
            raise FileNotFoundError(config_file)
        # No interpolation: a literal '%' in a value must not raise.
        config = configparser.ConfigParser(interpolation=None)
        config.read(str(config_file))
        for prop_name in property_names:
            # fallback also covers a missing [property] section.
            value = config.get("property", prop_name, fallback=None)
            if not value:
                continue
            value_path = Path(value)
            if not value_path.is_absolute():
                value_path = (config_file.parent / value_path).resolve()
            return value_path

        return None

351 

352 

@dataclass
class Tracker(HasConnections):
    """Pythia wrapper around nvtracker gst element."""

    # nvtracker `ll-config-file`.
    config_file: Path
    # nvtracker `ll-lib-file`; defaults to DeepStream's multi-object tracker.
    low_level_library: Path = Path(
        "/opt/nvidia/deepstream/deepstream/lib/libnvds_nvmultiobjecttracker.so"
    ).resolve()
    # Last string rendered by `gst`.
    _string: Optional[str] = None
    # Remaining element properties captured by `from_element`.
    _default_props: Dict[str, str] = field(default_factory=dict)
    CONNECTIONS: Con = field(default_factory=dict)  # noqa: C0103

    @classmethod
    def from_file(
        cls: Type[T],
        config_file: Path,
        low_level_library: Path = low_level_library,
    ) -> T:
        """Factory to create `Tracker` s from configuration file.

        Args:
            config_file: path for the `nvtracker` gstreamer element
                'll-config-file' property.
            low_level_library: path for the `nvtracker` gstreamer element
                'll-lib-file' property (shared object). Strings are
                accepted and converted to `Path`.

        Returns:
            Instantiated `Tracker`.

        Raises:
            FileNotFoundError: Tracker config file does not exist.

        """
        config_file = Path(config_file).resolve()
        if not config_file.exists():
            raise FileNotFoundError(
                f"No Tracker configuration file found at {config_file}."
            )
        return cls(
            config_file=config_file,
            # `gst` calls `.exists()` on it, so it must be a Path.
            low_level_library=Path(low_level_library),
        )

    @classmethod
    def from_element(cls: Type[T], element: Gst.Element) -> T:
        """Factory from nvtracker.

        Every element property except ``parent`` is read as a string
        (enums by nick, booleans lowercased); unset (`None`) properties
        are skipped. ``ll-config-file`` and ``ll-lib-file`` go to
        dedicated `Path` fields; the rest is kept in ``_default_props``.

        Args:
            element: The nvtracker to use as source.

        Returns:
            The instantiated nvtracker wrapper. When the element has no
            ``ll-lib-file`` set, the class default library is used.

        Raises:
            FileNotFoundError: the element has no ``ll-config-file`` set.

        """
        skip = ("parent",)
        props: Dict[str, str] = {}
        for prop in element.list_properties():
            name = prop.name
            if name in skip:
                continue

            raw = element.get_property(name)
            if raw is None:
                # Unset property: rendering it would emit `prop=None`.
                continue
            try:
                value = raw.value_nick  # enums
            except AttributeError:
                value = str(raw)
                if isinstance(raw, bool):
                    value = value.lower()  # False -> false
            props[name] = value

        config_file = props.pop("ll-config-file", None)
        if config_file is None:
            raise FileNotFoundError(
                "nvtracker element has no 'll-config-file' property set."
            )
        low_level_library = props.pop("ll-lib-file", None)
        return cls(
            # Stored as Path: `gst` relies on Path methods.
            config_file=Path(config_file).resolve(),
            low_level_library=(
                Path(low_level_library)
                if low_level_library is not None
                else cls.low_level_library
            ),
            _default_props=props,
        )

    def gst(self, **kwargs: str) -> str:
        """Render nvtracker element with `gst-launch`-like syntax.

        One property per line: ``ll-config-file`` and ``ll-lib-file``
        first, then ``_default_props`` overridden by ``kwargs``.

        Args:
            kwargs: Additional gst element properties. Underscores in
                names are rendered as dashes, so ``ll_width`` overrides a
                stored ``ll-width`` instead of duplicating it.

        Returns:
            Rendered string (also stored in ``_string``).

        Raises:
            FileNotFoundError: Tracker `ll-lib-file` not found.

        See Also:
            https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_plugin_gst-nvtracker.html#gst-properties

        """
        # Normalize keys *before* merging so kwargs override defaults.
        inline_props = {
            k.replace("_", "-"): v for k, v in self._default_props.items()
        }
        inline_props.update((k.replace("_", "-"), v) for k, v in kwargs.items())

        if not self.low_level_library.exists():
            raise FileNotFoundError(
                "Could not find Tracker implementation"
                f" at {self.low_level_library}"
            )
        lines = [
            "nvtracker",
            f"ll-config-file={self.config_file}",
            f"ll-lib-file={self.low_level_library}",
        ]
        lines.extend(f"{k}={v}" for k, v in inline_props.items())
        self._string = "\n".join(lines) + "\n"
        return self._string

465 

466 

@dataclass
class Analytics(HasConnections):
    """Pythia wrapper around nvdsanalytics gst element."""

    # nvdsanalytics `config-file`.
    config_file: Path
    # Last string rendered by `gst`.
    _string: Optional[str] = None
    # Remaining element properties captured by `from_element`.
    _default_props: Dict[str, str] = field(default_factory=dict)
    CONNECTIONS: Con = field(default_factory=dict)  # noqa: C0103

    def gst(self, **kwargs: str) -> str:
        """Render string as `gst-launch`-like parseable string.

        One property per line: ``config-file`` first, then
        ``_default_props`` overridden by ``kwargs``.

        Args:
            kwargs: Additional gst element properties. Underscores in
                names are rendered as dashes, so a kwarg overrides the
                matching stored property instead of duplicating it.

        Returns:
            Rendered `nvdsanalytics` (also stored in ``_string``).

        See Also:
            https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_plugin_gst-nvdsanalytics.html#gst-properties

        """
        # Normalize keys *before* merging so kwargs override defaults.
        inline_props = {
            k.replace("_", "-"): v for k, v in self._default_props.items()
        }
        inline_props.update((k.replace("_", "-"), v) for k, v in kwargs.items())

        lines = ["nvdsanalytics", f"config-file={self.config_file}"]
        lines.extend(f"{k}={v}" for k, v in inline_props.items())
        self._string = "\n".join(lines) + "\n"
        return self._string

    def requires_tracker(self) -> bool:
        """Return `True` if its `nvdsanalytics` requires `nvtracker`.

        Returns:
            `True` if its `nvdsanalytics` config file contains a section
            whose name starts with ``line-crossing`` or
            ``direction-detection``.

        """
        config = configparser.ConfigParser()
        config.read(str(self.config_file))
        return any(
            section_name.startswith(("line-crossing", "direction-detection"))
            for section_name in config.sections()
        )

    @classmethod
    def from_file(cls: Type[A], config_file: Path) -> A:
        """Factory from configuration file.

        Args:
            config_file: location of the nvdsanalytics `config-file`
                property.

        Returns:
            The instantiated `nvdsanalytics` wrapper class.

        Raises:
            FileNotFoundError: The `nvdsanalytics` `config-file`
                property is not found.

        """
        config_file = Path(config_file).resolve()
        if not config_file.exists():
            raise FileNotFoundError(
                f"No Analytics configuration file found at {config_file}."
            )
        return cls(config_file=config_file)

    @classmethod
    def from_element(cls: Type[A], element: Gst.Element) -> A:
        """Factory from nvdsanalytics.

        Every element property except ``parent`` is read as a string
        (enums by nick, booleans lowercased); unset (`None`) properties
        are skipped. ``config-file`` goes to a dedicated `Path` field;
        the rest is kept in ``_default_props``.

        Args:
            element: The nvdsanalytics to use as source.

        Returns:
            The instantiated nvdsanalytics wrapper.

        Raises:
            FileNotFoundError: the element has no ``config-file`` set.

        """
        skip = ("parent",)
        props: Dict[str, str] = {}
        for prop in element.list_properties():
            name = prop.name
            if name in skip:
                continue

            raw = element.get_property(name)
            if raw is None:
                # Unset property: rendering it would emit `prop=None`.
                continue
            try:
                value = raw.value_nick  # enums
            except AttributeError:
                value = str(raw)
                if isinstance(raw, bool):
                    value = value.lower()  # False -> false
            props[name] = value

        config_file = props.pop("config-file", None)
        if config_file is None:
            raise FileNotFoundError(
                "nvdsanalytics element has no 'config-file' property set."
            )
        return cls(
            # Stored as Path, matching the field annotation and from_file.
            config_file=Path(config_file).resolve(),
            _default_props=props,
        )