Providing a custom collate_fn to DataLoader has no effect #9263

Closed
mayabechlerspeicher opened this issue Apr 30, 2024 · 9 comments

@mayabechlerspeicher

🐛 Describe the bug

PyG's DataLoader can receive a custom collate_fn, since it extends the torch DataLoader, but its constructor ignores the given collate_fn and always uses Collater instead.
I'm not sure whether this is a bug or whether the documentation is wrong: the PyG documentation states that any parameter accepted by torch's DataLoader can be used with PyG's DataLoader, yet the collate_fn parameter cannot be used.

So, to actually use a custom collate_fn, do I have to subclass DataLoader so that it uses the given collate_fn?

Thanks.

Versions

2.5.3

@rusty1s
Member

rusty1s commented May 2, 2024

Yes, PyG DataLoader is just a wrapper around torch.utils.data.DataLoader with a custom collate_fn. As such, this is the only argument that cannot be overridden. Let me clarify this in the documentation.

@mayabechlerspeicher
Author

Thank you.
Could you please clarify why collate_fn should be restricted from being overridden?
I believe it is a typical case where a data object has custom keys that one wants to batch differently than by concatenation (e.g., when their dimensions do not allow concatenation).

@rusty1s
Member

rusty1s commented May 3, 2024

If we allowed overriding collate_fn in PyG's data loader, it would boil down to torch.utils.data.DataLoader. In that case, I don't see a good reason not to use the vanilla PyTorch DataLoader in the first place.

Note that you can also customize concatenation by overriding Data.__cat_dim__ (see the advanced mini-batch tutorial in our documentation).

@mayabechlerspeicher
Author

mayabechlerspeicher commented May 12, 2024

Thanks. Nonetheless, the standard DataLoader fails to add a batch dimension to the edge index, since different graphs have different numbers of edges.

So let's say I am not interested in batching the edge indices into one huge graph, and I just want to wrap multiple graphs together, i.e., to stack the keys of the graphs in the batch, where the tensors of each key can have different shapes (as with edge indices). The gradient computation will be done on the loss over the whole batch, but the forward pass will be done on each graph in the batch separately anyway (so GPU-wise it's not as efficient as it could be, but that's OK).
Because the tensors do not have the same dimensions, you cannot concatenate them, so Data.__cat_dim__ would not help. What should I do in that case?

@rusty1s
Member

rusty1s commented May 13, 2024

Do you mean you simply want to "batch" tensors together by stacking them in a list? I am not yet sure I understand, sorry.

@mayabechlerspeicher
Author

Yes. I have some custom keys in my Data object that have different dimensions, so I cannot stack them; I just want to put them in a list.

@rusty1s
Member

rusty1s commented May 22, 2024

I see, that's indeed currently not possible. What we could do is to provide an option in Data to restrict concatenation of certain attributes. Would this work for your use-case?

@devanshamin
Contributor

devanshamin commented May 30, 2024

@mayabechlerspeicher You can use the exclude_keys argument of CustomBatch.from_data_list(...) to leave those keys out of the collation and then add them to the batch however you want:

```python
from typing import List, Optional, Sequence, Union
import random

import torch
from torch.utils.data.dataloader import default_collate
from typing_extensions import Self
from torch_geometric.data import Batch, Data, Dataset
from torch_geometric.data.collate import collate
from torch_geometric.data.data import BaseData
from torch_geometric.data.datapipes import DatasetAdapter
from torch_geometric.utils import from_smiles  # requires RDKit


class CustomBatch(Batch):
    @classmethod
    def from_data_list(
        cls,
        data_list: List[BaseData],
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
    ) -> Self:
        # Same as Batch.from_data_list: collate every key except `exclude_keys`.
        batch, slice_dict, inc_dict = collate(
            cls,
            data_list=data_list,
            increment=True,
            add_batch=not isinstance(data_list[0], Batch),
            follow_batch=follow_batch,
            exclude_keys=exclude_keys,
        )

        batch._num_graphs = len(data_list)  # type: ignore
        batch._slice_dict = slice_dict  # type: ignore
        batch._inc_dict = inc_dict  # type: ignore

        # Re-attach the excluded keys as plain Python lists (one entry per graph).
        if exclude_keys:
            for key in exclude_keys:
                setattr(batch, key, [getattr(d, key) for d in data_list])

        return batch


class Collate:
    def __init__(
        self,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
    ) -> None:
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def __call__(self, batch):
        elem = batch[0]
        if isinstance(elem, Data):
            return CustomBatch.from_data_list(batch, self.follow_batch, self.exclude_keys)
        # Fall back to PyTorch's default collation for non-graph samples.
        return default_collate(batch)


class CustomDataLoader(torch.utils.data.DataLoader):
    def __init__(
        self,
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        # Remove for PyTorch Lightning:
        kwargs.pop('collate_fn', None)

        # Save for PyTorch Lightning < 1.6:
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=Collate(follow_batch, exclude_keys),
            **kwargs,
        )


if __name__ == '__main__':
    smiles_list = [
        'F/C=C/F',
        'COC(=O)[C@@]1(Cc2ccccc2)[C@H]2C(=O)N(C)C(=O)[C@H]2[C@H]2CN=C(SC)N21',
        'CC1=C(CCN2CCC(CC2)C2=NOC3=C2C=CC(F)=C3)C(=O)N2CCCCC2=N1',
        '[H][C@@]12[C@H]3CC[C@H](C3)[C@]1([H])C(=O)N(C[C@@H]1CCCC[C@H]1CN1CCN(CC1)C1=NSC3=CC=CC=C13)C2=O',
    ]
    data_list = []
    for smiles in smiles_list:
        data = from_smiles(smiles)
        # Variable-length, per-molecule feature that cannot be concatenated.
        data.mol_features = [1] * random.randint(2, 15)
        data_list.append(data)
        print(data)

    print()
    dl = CustomDataLoader(dataset=data_list, batch_size=2, exclude_keys=['mol_features'])
    for batch in dl:
        print(batch)
        print(batch.mol_features)  # a list with one entry per graph
```

@rusty1s
Member

rusty1s commented Jun 5, 2024

Oh, you are right. Thanks for pointing this out. Completely forgot about this option :)

rusty1s closed this as completed Jun 5, 2024