136 lines
3.2 KiB
Python
136 lines
3.2 KiB
Python
# Copyright (C) 2022 Intel Corporation
|
|
# Copyright (C) CVAT.ai Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
from __future__ import annotations
|
|
|
|
import contextlib
|
|
from collections.abc import Generator, Iterable
|
|
from typing import Optional, TypeVar
|
|
|
|
# Type variable for the elements of the iterable passed to ProgressReporter.iter().
T = TypeVar("T")
|
|
|
|
|
|
class ProgressReporter:
    """
    Interface for reporting the progress of a long-running task.

    Use as follows:

    with r.task(...):
        r.report_status(...)
        r.advance(...)

        for x in r.iter(...):
            ...

    Implementations must override start2, finish, report_status and advance.
    """

    @contextlib.contextmanager
    def task(self, **kwargs) -> Generator[None, None, None]:
        """
        Returns a context manager that represents a long-running task
        for which progress can be reported.

        Entering it calls `start2()`, and exiting it calls `finish()`,
        even if the body of the `with` statement raises.

        kwargs will be passed to `start2()`.
        """
        self.start2(**kwargs)

        try:
            yield
        finally:
            # Always close the progress bar, including on errors in the body.
            self.finish()

    def start(self, total: int, *, desc: Optional[str] = None) -> None:
        """
        This is a compatibility method. Override start2 instead.

        Raises:
            NotImplementedError: always, unless overridden.
        """
        raise NotImplementedError

    def start2(
        self,
        total: int,
        *,
        desc: Optional[str] = None,
        unit: str = "it",
        unit_scale: bool = False,
        unit_divisor: int = 1000,
        **kwargs,
    ) -> None:
        """
        Initializes the progress bar.

        total, desc, unit, unit_scale, unit_divisor have the same meaning as in tqdm.

        kwargs is included for future extension; implementations of this method
        must ignore it.

        The default implementation forwards only `total` and `desc` to the
        legacy `start()` method, so that reporters that still override
        `start()` keep working.
        """
        self.start(total=total, desc=desc)

    def report_status(self, progress: int) -> None:
        """
        Updates the progress bar.

        Args:
            progress: The new progress value (presumably an absolute value
                in the units of `total`; confirm against implementations).

        Raises:
            NotImplementedError: always, unless overridden.
        """
        raise NotImplementedError

    def advance(self, delta: int) -> None:
        """
        Updates the progress bar.

        Args:
            delta: The amount to advance by; `iter()` passes 1 per element.

        Raises:
            NotImplementedError: always, unless overridden.
        """
        raise NotImplementedError

    def finish(self) -> None:
        """Finishes the progress bar. The default implementation does nothing."""

    def iter(
        self,
        iterable: Iterable[T],
    ) -> Iterable[T]:
        """
        Traverses the iterable and reports progress simultaneously.

        This method only calls `advance(1)`; it does not call `start2()` or
        `finish()` itself. `advance(1)` runs when the consumer resumes the
        generator after receiving an element, so an element is not counted
        if the consumer stops iterating right after receiving it.

        Args:
            iterable: An iterable to be traversed

        Returns:
            An iterable over elements of the input sequence
        """
        for elem in iterable:
            yield elem
            self.advance(1)
|
|
|
|
|
|
class BaseProgressReporter(ProgressReporter):
    """
    A ProgressReporter that tracks whether a task is currently in progress.

    The methods here display nothing; they only assert that calls arrive in
    a valid order: start2() when no task is running, and report_status(),
    advance() and finish() only while one is running.
    """

    # Class-level default so that the state checks (and __del__ in particular)
    # still work on an instance whose __init__ did not reach
    # BaseProgressReporter.__init__, e.g. a subclass that skips
    # super().__init__() or raises before calling it.
    _in_progress: bool = False

    def __init__(self) -> None:
        self._in_progress = False

    def start2(
        self,
        total: int,
        *,
        desc: Optional[str] = None,
        unit: str = "it",
        unit_scale: bool = False,
        unit_divisor: int = 1000,
        **kwargs,
    ) -> None:
        """Marks a task as started; asserts that no task is already running."""
        assert not self._in_progress
        self._in_progress = True

    def report_status(self, progress: int) -> None:
        """Asserts that a task is running; the value itself is not used here."""
        assert self._in_progress

    def advance(self, delta: int) -> None:
        """Asserts that a task is running; the value itself is not used here."""
        assert self._in_progress

    def finish(self) -> None:
        """Marks the running task as finished; asserts that one is running."""
        assert self._in_progress
        self._in_progress = False

    def __del__(self) -> None:
        # Diagnostic only: an exception raised from __del__ is not propagated;
        # the interpreter reports it on stderr instead.
        assert not self._in_progress, "Unfinished task!"
|
|
|
|
|
|
class NullProgressReporter(BaseProgressReporter):
    """
    A progress reporter that adds nothing on top of BaseProgressReporter.

    All behavior is inherited unchanged from the base class.
    """
|