Fix distributed sampler initialization and exceeded sampler warning false positives #1270

Merged 2 commits on Jan 24, 2024
lhotse/dataset/sampling/base.py (25 changes: 15 additions & 10 deletions)
@@ -10,7 +10,7 @@
 from lhotse.cut import Cut, CutSet
 from lhotse.lazy import Dillable
-from lhotse.utils import Seconds, is_none_or_gt
+from lhotse.utils import Seconds, ifnone, is_none_or_gt
 
 
 class CutSampler(Sampler, Dillable):
@@ -101,17 +101,22 @@
             assert world_size >= 1
         if rank is not None:
             assert rank >= 0
+
+        # Order of precedence:
+        # 1. When world size or rank are explicitly provided, we will use them.
+        # 2. Next, check WORLD_SIZE and RANK env variables; yes? use them.
+        # 3. Next, check if torch.distributed is initialized and has them set; yes? use them.
+        # 4. If none of those are available, rank=0 and world_size=1.
         if "WORLD_SIZE" in os.environ and "RANK" in os.environ:
             # If deepspeed launcher is being used, it will set the env variables automatically.
-            self.world_size = int(os.environ["WORLD_SIZE"])
-            self.rank = int(os.environ["RANK"])
-            return
-        if not dist.is_available() or not dist.is_initialized():
-            self.world_size = 1 if world_size is None else world_size
-            self.rank = 0 if rank is None else rank
-            return
-        self.world_size = dist.get_world_size() if world_size is None else world_size
-        self.rank = dist.get_rank() if rank is None else rank
+            self.world_size = ifnone(world_size, int(os.environ["WORLD_SIZE"]))
+            self.rank = ifnone(rank, int(os.environ["RANK"]))
+        elif dist.is_available() and dist.is_initialized():
+            self.world_size = ifnone(world_size, dist.get_world_size())
+            self.rank = ifnone(rank, dist.get_rank())
+        else:
+            self.world_size = ifnone(world_size, 1)
+            self.rank = ifnone(rank, 0)
         assert self.rank < self.world_size
 
     def set_epoch(self, epoch: int) -> None:

Codecov (codecov/patch) annotations: added lines 112-113 and 115-116 of lhotse/dataset/sampling/base.py were not covered by tests.
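The new branching mirrors the precedence comment: explicit arguments win, then the WORLD_SIZE/RANK environment variables (set automatically by launchers such as deepspeed), then an initialized torch.distributed process group, and finally single-process defaults. The sketch below exercises the same resolution order outside of a sampler; resolve_world_size_and_rank is a made-up helper for illustration, and ifnone is re-implemented locally on the assumption that lhotse.utils.ifnone returns its first argument unless it is None.

import os
from typing import Optional, Tuple

import torch.distributed as dist


def ifnone(value, default):
    # Assumed semantics of lhotse.utils.ifnone: fall back to `default` only when `value` is None.
    return value if value is not None else default


def resolve_world_size_and_rank(
    world_size: Optional[int] = None, rank: Optional[int] = None
) -> Tuple[int, int]:
    # Hypothetical helper mirroring the precedence used in CutSampler above.
    if "WORLD_SIZE" in os.environ and "RANK" in os.environ:
        # Launchers such as deepspeed export these variables for every worker.
        world_size = ifnone(world_size, int(os.environ["WORLD_SIZE"]))
        rank = ifnone(rank, int(os.environ["RANK"]))
    elif dist.is_available() and dist.is_initialized():
        # Fall back to the initialized torch.distributed process group.
        world_size = ifnone(world_size, dist.get_world_size())
        rank = ifnone(rank, dist.get_rank())
    else:
        # No distributed context at all: behave like a single process.
        world_size = ifnone(world_size, 1)
        rank = ifnone(rank, 0)
    assert rank < world_size
    return world_size, rank


# Explicit arguments always win, e.g. resolve_world_size_and_rank(4, 2) -> (4, 2),
# while resolve_world_size_and_rank() -> (1, 0) when nothing distributed is set up.

Because ifnone only consults its fallback when the caller passed None, a single if/elif/else chain expresses the whole precedence, which is what lets the patch drop the early returns of the original code.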
lhotse/dataset/sampling/dynamic.py (4 changes: 2 additions & 2 deletions)
@@ -322,9 +322,9 @@ def detuplify(
             # Did we exceed the max_frames and max_cuts constraints?
             if self.time_constraint.close_to_exceeding():
                 # Yes. Finish sampling this batch.
-                if self.time_constraint.exceeded():
+                if self.time_constraint.exceeded() and len(cuts) == 1:
                     warnings.warn(
-                        "We have exceeded the max_duration constraint during sampling. "
+                        "We have exceeded the max_duration constraint during sampling but have only 1 cut. "
                         "This is likely because max_duration was set to a very low value ~10s, "
                         "or you're using a CutSet with very long cuts (e.g. 100s of seconds long)."
                     )
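The added len(cuts) == 1 guard narrows the warning to the situation it was written for: a single cut that on its own exceeds max_duration, so the batch cannot be made any smaller. Previously, an ordinary batch whose last cut nudged the running total just past the limit would also warn, producing the false positives mentioned in the title. Below is a small standalone sketch of the difference; batch_durations_exceeded is a made-up stand-in for lhotse's TimeConstraint.exceeded().

import warnings
from typing import List


def batch_durations_exceeded(durations: List[float], max_duration: float) -> bool:
    # Stand-in for TimeConstraint.exceeded(): total batch duration is past the limit.
    return sum(durations) > max_duration


MAX_DURATION = 100.0

# Five 21 s cuts total 105 s: slightly over the limit, but perfectly normal for a
# dynamic batcher; the old check would have warned about it.
normal_batch = [21.0, 21.0, 21.0, 21.0, 21.0]

# One 250 s cut: this batch cannot be shrunk, which is the case worth warning about.
pathological_batch = [250.0]

for cuts in (normal_batch, pathological_batch):
    if batch_durations_exceeded(cuts, MAX_DURATION) and len(cuts) == 1:
        warnings.warn(
            "We have exceeded the max_duration constraint during sampling but have only 1 cut."
        )

With the old condition both batches would trigger the warning; with the new one only the single 250-second cut does.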