diff --git a/src/lighteval/data.py b/src/lighteval/data.py index 7cb105e6..24ad529b 100644 --- a/src/lighteval/data.py +++ b/src/lighteval/data.py @@ -26,7 +26,15 @@ import torch from torch.utils.data import Dataset -from torch.utils.data.distributed import DistributedSampler, T_co +from packaging import version + +torch_version = torch.__version__ + +if version.parse(torch_version) >= version.parse("2.5.0"): + from torch.utils.data.distributed import DistributedSampler, _T_co +else: + from torch.utils.data.distributed import DistributedSampler + from torch.utils.data.distributed import T_co as _T_co from lighteval.tasks.requests import ( GreedyUntilRequest, @@ -318,7 +326,7 @@ class GenDistributedSampler(DistributedSampler): as our samples are sorted by length. """ - def __iter__(self) -> Iterator[T_co]: + def __iter__(self) -> Iterator[_T_co]: if self.shuffle: # deterministically shuffle based on epoch and seed g = torch.Generator()