Providing a custom collate_fn to DataLoader has no effect #9263
Comments
Yes. PyG's DataLoader always installs its own Collater, so you would need to extend it to change the collation behavior.
Thank you.
If we would allow overriding collate_fn, the PyG-specific mini-batching logic would no longer be guaranteed to work. Note that you can also customize concatenation by overriding __cat_dim__() and __inc__() on your Data object.
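For illustration, a minimal sketch of that mechanism (the MyData class and the foo key are made up for this example):

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

class MyData(Data):
    def __cat_dim__(self, key, value, *args, **kwargs):
        if key == 'foo':
            return None  # None: stack `foo` along a new batch dimension
        return super().__cat_dim__(key, value, *args, **kwargs)

# Three graphs, each with two nodes and a fixed-size `foo` tensor:
data_list = [MyData(x=torch.randn(2, 8), foo=torch.randn(4)) for _ in range(3)]
batch = next(iter(DataLoader(data_list, batch_size=3)))
print(batch.foo.shape)  # torch.Size([3, 4]) instead of concatenating to [12]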
Thanks. Nonetheless, the standard DataLoader fails here: it cannot add a batch dimension to the edge index, since the edge tensors have different sizes for different graphs. So suppose I am not interested in batching the edge indices into one huge graph; I just want to wrap multiple graphs together, i.e., stack the keys of the graphs in the batch, where the tensors of a given key can have different shapes (as with edge indices). The gradient computation will then be done on the loss over the whole batch, but the forward pass will be done on each graph in the batch separately anyway (so GPU-wise it's not the most efficient it could be, but that's OK).
Do you mean you simply want to "batch" tensors together by stacking them in a list? I am not yet sure I understand, sorry.
Yes. I have some custom keys in my Data object that have different dimensions, so I cannot stack them; I just want to put them in a list.
I see, that's indeed currently not possible. What we could do is to provide an option in the DataLoader to skip collation for certain keys and return them as plain lists instead.
@mayabechlerspeicher You can utilize the exclude_keys argument for this:

from typing import List, Optional, Union, Sequence
import random
import torch
from typing_extensions import Self
from torch_geometric.data import Data, Batch, Dataset
from torch_geometric.data.collate import collate
from torch_geometric.data.data import BaseData
from torch_geometric.data.datapipes import DatasetAdapter
from torch_geometric.utils import from_smiles
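
# A Batch subclass that collates as usual, but re-attaches every key in
# `exclude_keys` to the batch as a plain Python list (one entry per graph):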
class CustomBatch(Batch):
@classmethod
def from_data_list(
cls,
data_list: List[BaseData],
follow_batch: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
) -> Self:
batch, slice_dict, inc_dict = collate(
cls,
data_list=data_list,
increment=True,
add_batch=not isinstance(data_list[0], Batch),
follow_batch=follow_batch,
exclude_keys=exclude_keys,
)
batch._num_graphs = len(data_list) # type: ignore
batch._slice_dict = slice_dict # type: ignore
batch._inc_dict = inc_dict # type: ignore
if exclude_keys:
for key in exclude_keys:
setattr(batch, key, [getattr(d, key) for d in data_list])
return batch
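
# collate_fn wrapper that turns a list of Data objects into a CustomBatch: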
class Collate:
def __init__(
self,
follow_batch: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None
) -> None:
self.follow_batch = follow_batch
self.exclude_keys = exclude_keys
def __call__(self, batch):
elem = batch[0]
if isinstance(elem, Data):
return CustomBatch.from_data_list(batch, self.follow_batch, self.exclude_keys)
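
# Drop-in DataLoader that mirrors PyG's own loader but installs the Collate above: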
class CustomDataLoader(torch.utils.data.DataLoader):
def __init__(
self,
dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
batch_size: int = 1,
shuffle: bool = False,
follow_batch: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
**kwargs,
):
# Remove for PyTorch Lightning:
kwargs.pop('collate_fn', None)
# Save for PyTorch Lightning < 1.6:
self.follow_batch = follow_batch
self.exclude_keys = exclude_keys
super().__init__(
dataset,
batch_size,
shuffle,
collate_fn=Collate(follow_batch, exclude_keys),
**kwargs,
)
if __name__ == '__main__':
smiles_list = [
'F/C=C/F',
'COC(=O)[C@@]1(Cc2ccccc2)[C@H]2C(=O)N(C)C(=O)[C@H]2[C@H]2CN=C(SC)N21',
'CC1=C(CCN2CCC(CC2)C2=NOC3=C2C=CC(F)=C3)C(=O)N2CCCCC2=N1',
'[H][C@@]12[C@H]3CC[C@H](C3)[C@]1([H])C(=O)N(C[C@@H]1CCCC[C@H]1CN1CCN(CC1)C1=NSC3=CC=CC=C13)C2=O',
]
data_list = []
for smiles in smiles_list:
data = from_smiles(smiles)
data.mol_features = [1] * random.randint(2, 15)
data_list.append(data)
print(data)
print()
dl = CustomDataLoader(dataset=data_list, batch_size=2, exclude_keys=['mol_features'])
for batch in dl:
print(batch)
        print(batch.mol_features)
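Note the trick in CustomBatch.from_data_list: keys listed in exclude_keys are skipped by collate() and then re-attached to the batch as plain Python lists, one entry per graph, which gives exactly the stack-in-a-list behavior discussed above.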
Oh, you are right. Thanks for pointing this out. Completely forgot about this option :)
🐛 Describe the bug
PyG's DataLoader can receive a custom collate_fn, since it extends torch's DataLoader, but its constructor never uses the given collate_fn; it always installs its own Collater instead.
I'm not sure whether this is a bug or a documentation error: the PyG documentation states that any parameter of torch's DataLoader can also be used with PyG's DataLoader, yet the collate_fn parameter has no effect.
So, to actually use a custom collate_fn, do I have to extend DataLoader so that it uses the given collate_fn?
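For reference, this is roughly what happens in the constructor (a paraphrase of the PyG 2.5 sources, not the verbatim code; details may differ between versions):

# torch_geometric/loader/dataloader.py (paraphrased)
class DataLoader(torch.utils.data.DataLoader):
    def __init__(self, dataset, batch_size=1, shuffle=False,
                 follow_batch=None, exclude_keys=None, **kwargs):
        # A user-supplied collate_fn is silently dropped here:
        kwargs.pop('collate_fn', None)
        super().__init__(
            dataset, batch_size, shuffle,
            collate_fn=Collater(dataset, follow_batch, exclude_keys),
            **kwargs,
        )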
Thanks.
Versions
2.5.3