Skip to content

Commit 6560c75

Browse files
authored
Fix T_co import bug (#484)
* Fix T_co import bug * Fix styling
1 parent fdb12f4 commit 6560c75

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

src/lighteval/data.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,15 @@
2525
from typing import Iterator, Tuple
2626

2727
import torch
28+
from packaging import version
2829
from torch.utils.data import Dataset
29-
from torch.utils.data.distributed import DistributedSampler, T_co
30+
31+
32+
if version.parse(torch.__version__) >= version.parse("2.5.0"):
33+
from torch.utils.data.distributed import DistributedSampler, _T_co
34+
else:
35+
from torch.utils.data.distributed import DistributedSampler
36+
from torch.utils.data.distributed import T_co as _T_co
3037

3138
from lighteval.tasks.requests import (
3239
GreedyUntilRequest,
@@ -318,7 +325,7 @@ class GenDistributedSampler(DistributedSampler):
318325
as our samples are sorted by length.
319326
"""
320327

321-
def __iter__(self) -> Iterator[T_co]:
328+
def __iter__(self) -> Iterator[_T_co]:
322329
if self.shuffle:
323330
# deterministically shuffle based on epoch and seed
324331
g = torch.Generator()

0 commit comments

Comments
 (0)