diff --git a/proteinflow/data/torch.py b/proteinflow/data/torch.py index a31e6ca..1224d99 100644 --- a/proteinflow/data/torch.py +++ b/proteinflow/data/torch.py @@ -130,6 +130,7 @@ def from_args( cut_edges=False, require_antigen=False, require_light_chain=False, + require_no_light_chain=False, require_heavy_chain=False, *args, **kwargs, @@ -191,6 +192,8 @@ def from_args( if `True`, only entries with an antigen will be included (used if the dataset is SAbDab) require_light_chain : bool, default False if `True`, only entries with a light chain will be included (used if the dataset is SAbDab) + require_no_light_chain : bool, default False + if `True`, only entries without a light chain will be included (used if the dataset is SAbDab) require_heavy_chain : bool, default False if `True`, only entries with a heavy chain will be included (used if the dataset is SAbDab) *args @@ -225,6 +228,7 @@ def from_args( cut_edges=cut_edges, require_antigen=require_antigen, require_light_chain=require_light_chain, + require_no_light_chain=require_no_light_chain, require_heavy_chain=require_heavy_chain, ) return ProteinLoader( @@ -323,6 +327,7 @@ def __init__( antigen_patch_size=128, require_antigen=False, require_light_chain=False, + require_no_light_chain=False, require_heavy_chain=False, ): """Initialize the dataset. 
@@ -398,6 +403,8 @@ def __init__(
             if `True`, only entries with an antigen will be included (used if the dataset is SAbDab)
         require_light_chain : bool, default False
             if `True`, only entries with a light chain will be included (used if the dataset is SAbDab)
+        require_no_light_chain : bool, default False
+            if `True`, only entries without a light chain will be included (used if the dataset is SAbDab)
-        requre_heavy_chain : bool, default False
+        require_heavy_chain : bool, default False
             if `True`, only entries with a heavy chain will be included (used if the dataset is SAbDab)
 
@@ -538,7 +545,15 @@ def __init__(
-        if require_antigen or require_light_chain:
+        if (
+            require_antigen
+            or require_light_chain
+            or require_no_light_chain
+            or require_heavy_chain
+        ):
             to_exclude.update(
                 self._exclude_by_chains(
-                    require_antigen, require_light_chain, require_heavy_chain
+                    require_antigen,
+                    require_light_chain,
+                    require_no_light_chain,
+                    require_heavy_chain,
                 )
             )
         if self.clusters is not None:
@@ -604,7 +619,11 @@ def _check_chain_types(self, file):
         return chain_types
 
     def _exclude_by_chains(
-        self, require_antigen, require_light_chain, require_heavy_chain
+        self,
+        require_antigen,
+        require_light_chain,
+        require_no_light_chain,
+        require_heavy_chain,
     ):
-        """Exclude entries that do not have an antigen or a light chain."""
+        """Exclude entries that do not meet the chain type requirements."""
         to_exclude = set()
@@ -617,6 +636,8 @@ def _exclude_by_chains(
                 to_exclude.add(id)
             if require_light_chain and "light" not in chain_types:
                 to_exclude.add(id)
+            if require_no_light_chain and "light" in chain_types:
+                to_exclude.add(id)
             if require_heavy_chain and "heavy" not in chain_types:
                 to_exclude.add(id)
         return to_exclude