
Commit

Fix distributed sampler initialization and exceeded sampler warning false positives (#1270)

* Fix false positives about 'exceeded' warning

* Fix _maybe_init_distributed
pzelasko authored Jan 24, 2024
1 parent c678849 commit 9f4bfa1
Showing 2 changed files with 17 additions and 12 deletions.
25 changes: 15 additions & 10 deletions lhotse/dataset/sampling/base.py
@@ -10,7 +10,7 @@

 from lhotse.cut import Cut, CutSet
 from lhotse.lazy import Dillable
-from lhotse.utils import Seconds, is_none_or_gt
+from lhotse.utils import Seconds, ifnone, is_none_or_gt
 
 
 class CutSampler(Sampler, Dillable):
@@ -101,17 +101,22 @@ def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]
         assert world_size >= 1
         if rank is not None:
             assert rank >= 0
+
+        # Order of precedence:
+        # 1. When world size or rank are explicitly provided, we will use them.
+        # 2. Next, check WORLD_SIZE and RANK env variables; yes? use them.
+        # 3. Next, check if torch.distributed is initialized and has them set; yes? use them.
+        # 4. If none of those are available, rank=0 and world_size=1.
         if "WORLD_SIZE" in os.environ and "RANK" in os.environ:
             # If deepspeed launcher is being used, it will set the env variables automatically.
-            self.world_size = int(os.environ["WORLD_SIZE"])
-            self.rank = int(os.environ["RANK"])
-            return
-        if not dist.is_available() or not dist.is_initialized():
-            self.world_size = 1 if world_size is None else world_size
-            self.rank = 0 if rank is None else rank
-            return
-        self.world_size = dist.get_world_size() if world_size is None else world_size
-        self.rank = dist.get_rank() if rank is None else rank
+            self.world_size = ifnone(world_size, int(os.environ["WORLD_SIZE"]))
+            self.rank = ifnone(rank, int(os.environ["RANK"]))
+        elif dist.is_available() and dist.is_initialized():
+            self.world_size = ifnone(world_size, dist.get_world_size())
+            self.rank = ifnone(rank, dist.get_rank())
+        else:
+            self.world_size = ifnone(world_size, 1)
+            self.rank = ifnone(rank, 0)
         assert self.rank < self.world_size
 
     def set_epoch(self, epoch: int) -> None:
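To make the new resolution order concrete, here is a minimal, self-contained sketch of the same precedence logic. resolve_rank_world_size is a hypothetical stand-in for CutSampler._maybe_init_distributed (it is not part of lhotse's API) and assumes only the standard library and torch.distributed:

import os
from typing import Optional, Tuple

import torch.distributed as dist


def resolve_rank_world_size(
    world_size: Optional[int] = None, rank: Optional[int] = None
) -> Tuple[int, int]:
    # 1. Explicit arguments always win.
    # 2. Otherwise fall back to the WORLD_SIZE/RANK env variables (e.g. set by deepspeed).
    # 3. Otherwise ask torch.distributed, if a process group is initialized.
    # 4. Otherwise assume single-process training: world_size=1, rank=0.
    if "WORLD_SIZE" in os.environ and "RANK" in os.environ:
        world_size = int(os.environ["WORLD_SIZE"]) if world_size is None else world_size
        rank = int(os.environ["RANK"]) if rank is None else rank
    elif dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size() if world_size is None else world_size
        rank = dist.get_rank() if rank is None else rank
    else:
        world_size = 1 if world_size is None else world_size
        rank = 0 if rank is None else rank
    assert rank < world_size
    return world_size, rank


# Example: env variables are used only when no explicit values are given.
os.environ["WORLD_SIZE"], os.environ["RANK"] = "4", "2"
print(resolve_rank_world_size())      # (4, 2) -- taken from the environment
print(resolve_rank_world_size(8, 5))  # (8, 5) -- explicit arguments take precedence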
4 changes: 2 additions & 2 deletions lhotse/dataset/sampling/dynamic.py
@@ -322,9 +322,9 @@ def detuplify(
             # Did we exceed the max_frames and max_cuts constraints?
             if self.time_constraint.close_to_exceeding():
                 # Yes. Finish sampling this batch.
-                if self.time_constraint.exceeded():
+                if self.time_constraint.exceeded() and len(cuts) == 1:
                     warnings.warn(
-                        "We have exceeded the max_duration constraint during sampling. "
+                        "We have exceeded the max_duration constraint during sampling but have only 1 cut. "
                         "This is likely because max_duration was set to a very low value ~10s, "
                         "or you're using a CutSet with very long cuts (e.g. 100s of seconds long)."
                     )
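A small illustration of when the relaxed warning now fires. ToyTimeConstraint is a deliberately simplified, hypothetical stand-in for lhotse's real TimeConstraint (which tracks more state), and maybe_warn mirrors the condition changed above: the warning is only raised when the batch has exceeded max_duration while containing a single cut, i.e. one cut alone blows the budget:

import warnings
from typing import List


class ToyTimeConstraint:
    # Hypothetical, minimal duration counter -- not lhotse's TimeConstraint.
    def __init__(self, max_duration: float) -> None:
        self.max_duration = max_duration
        self.current = 0.0

    def add(self, cut_duration: float) -> None:
        self.current += cut_duration

    def close_to_exceeding(self) -> bool:
        return self.current >= self.max_duration

    def exceeded(self) -> bool:
        return self.current > self.max_duration


def maybe_warn(constraint: ToyTimeConstraint, cuts: List[str]) -> None:
    # Previously the warning fired whenever a batch overshot max_duration, even for
    # well-filled batches; now it fires only when a single cut already exceeds the
    # budget, which suggests max_duration is too low for the data.
    if constraint.close_to_exceeding():
        if constraint.exceeded() and len(cuts) == 1:
            warnings.warn(
                "We have exceeded the max_duration constraint during sampling but have only 1 cut."
            )


c = ToyTimeConstraint(max_duration=10.0)
c.add(25.0)                    # a single 25 s cut against a 10 s budget
maybe_warn(c, cuts=["cut-0"])  # warns: one cut alone exceeds max_duration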
