Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve progress bar display #545

Merged
merged 7 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion lenskit/lenskit/logging/progress/_base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

from typing import Any


class Progress:
"""
Base class for progress reporting. The default implementations do nothing.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass

def update(self, advance: int = 0, **kwargs: float | int | str):
Expand Down
80 changes: 60 additions & 20 deletions lenskit/lenskit/logging/progress/_rich.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,31 @@
# pyright: strict
from __future__ import annotations

from threading import Lock
from uuid import UUID, uuid4

import structlog
from rich.console import Group
from rich.progress import (
BarColumn,
MofNCompleteColumn,
ProgressColumn,
SpinnerColumn,
Task,
TaskID,
TaskProgressColumn,
TextColumn,
TimeRemainingColumn,
)
from rich.progress import Progress as ProgressImpl
from rich.progress import TaskID
from rich.text import Text
from typing_extensions import override

from .._console import console, get_live
from ._base import Progress

_log = structlog.stdlib.get_logger("lenskit.logging.progress")
_pb_lock = Lock()
_progress: ProgressImpl | None = None
_active_bars: dict[UUID, RichProgress] = {}


Expand All @@ -22,59 +35,86 @@
"""

uuid: UUID
label: str
total: int | None
fields: dict[str, str | None]
logger: structlog.stdlib.BoundLogger
_bar: ProgressImpl
_task: TaskID
_task: TaskID | None = None

def __init__(self, label: str, total: int | None, fields: dict[str, str | None]):
super().__init__()
self.uuid = uuid4()
self.label = label

Check warning on line 47 in lenskit/lenskit/logging/progress/_rich.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/logging/progress/_rich.py#L47

Added line #L47 was not covered by tests
self.total = total
self.fields = fields

self.logger = _log.bind(label=label, uuid=str(self.uuid))

self._bar = ProgressImpl(console=console)

_install_bar(self)

self._task = self._bar.add_task(label, total=total)
self._task = _install_bar(self)

Check warning on line 53 in lenskit/lenskit/logging/progress/_rich.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/logging/progress/_rich.py#L53

Added line #L53 was not covered by tests

def update(self, advance: int = 1, **kwargs: float | int | str):
self._bar.update(self._task, advance=advance, **kwargs) # type: ignore
if _progress is not None:
_progress.update(self._task, advance=advance, **kwargs) # type: ignore

Check warning on line 57 in lenskit/lenskit/logging/progress/_rich.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/logging/progress/_rich.py#L56-L57

Added lines #L56 - L57 were not covered by tests

def finish(self):
self._bar.stop()
_remove_bar(self)


def _install_bar(bar: RichProgress):
def _install_bar(bar: RichProgress) -> TaskID | None:
global _progress
bar.logger.debug("installing progress bar")
live = get_live()
if live is None:
bar._bar.disable = True
return
return None

Check warning on line 68 in lenskit/lenskit/logging/progress/_rich.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/logging/progress/_rich.py#L68

Added line #L68 was not covered by tests

with _pb_lock:
if _progress is None:
_progress = ProgressImpl(

Check warning on line 72 in lenskit/lenskit/logging/progress/_rich.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/logging/progress/_rich.py#L71-L72

Added lines #L71 - L72 were not covered by tests
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
MofNCompleteColumn(),
RateColumn(),
TaskProgressColumn(),
TimeRemainingColumn(),
console=console,
)
live.update(_progress)

Check warning on line 82 in lenskit/lenskit/logging/progress/_rich.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/logging/progress/_rich.py#L82

Added line #L82 was not covered by tests

_active_bars[bar.uuid] = bar
rbs = [b._bar for b in _active_bars.values()]
group = Group(*rbs)
live.update(group)
return _progress.add_task(bar.label, total=bar.total)

Check warning on line 85 in lenskit/lenskit/logging/progress/_rich.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/logging/progress/_rich.py#L85

Added line #L85 was not covered by tests


def _remove_bar(bar: RichProgress):
live = get_live()
if live is None:
if live is None or _progress is None:

Check warning on line 90 in lenskit/lenskit/logging/progress/_rich.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/logging/progress/_rich.py#L90

Added line #L90 was not covered by tests
return
if bar.uuid not in _active_bars:
return
if bar._task is None:
return

Check warning on line 95 in lenskit/lenskit/logging/progress/_rich.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/logging/progress/_rich.py#L94-L95

Added lines #L94 - L95 were not covered by tests

bar.logger.debug("uninstalling progress bar")

with _pb_lock:
_progress.remove_task(bar._task)

Check warning on line 100 in lenskit/lenskit/logging/progress/_rich.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/logging/progress/_rich.py#L100

Added line #L100 was not covered by tests
del _active_bars[bar.uuid]

live.update(Group(*[b._bar for b in _active_bars.values()]))
live.refresh()

class RateColumn(ProgressColumn):
def __init__(self):
super().__init__()

Check warning on line 106 in lenskit/lenskit/logging/progress/_rich.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/logging/progress/_rich.py#L106

Added line #L106 was not covered by tests

@override
def render(self, task: Task):
speed = task.finished_speed or task.speed
if speed is None:
disp = "?"
elif speed > 1000:
disp = "{:d} it/s".format(int(speed))
elif speed > 1:
disp = "{:.3g} it/s".format(speed)

Check warning on line 116 in lenskit/lenskit/logging/progress/_rich.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/logging/progress/_rich.py#L110-L116

Added lines #L110 - L116 were not covered by tests
else:
disp = "{:.3g} s/it"

Check warning on line 118 in lenskit/lenskit/logging/progress/_rich.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/logging/progress/_rich.py#L118

Added line #L118 was not covered by tests

return Text(disp, "progress.percentage")

Check warning on line 120 in lenskit/lenskit/logging/progress/_rich.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/logging/progress/_rich.py#L120

Added line #L120 was not covered by tests
6 changes: 3 additions & 3 deletions lenskit/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ license = { file = "LICENSE.md" }
dynamic = ["version"]
dependencies = [
"pandas ~=2.0",
"structlog >= 23.2",
"numpy >= 1.24",
"scipy >= 1.10.0",
"torch ~=2.1",
"rich ~=13.5",
"threadpoolctl >=3.0",
"pydantic ~=2.7",
"structlog >= 23.2",
"rich ~=13.5",
"pyzmq >=24",
"pydantic ~=2.7",
]

[project.optional-dependencies]
Expand Down
Loading
Loading