Description
Is your feature request related to a problem? Please describe.
With #558 we now have better control over how a pulsemap is processed. The Kaggle competition made it apparent that many of the top-scoring models simply cropped the pulsemap to a fixed number of pulses, to reduce the impact of the n^2 term from self-attention components.
Their primary method was to keep the first n pulses, but I believe it would be interesting to look into other ways of selecting pulses: randomly, sorted by charge, sorted by probability of being real signal, farthest point sampling, etc.
Describe the solution you'd like
To avoid having to implement many node definitions, I think it makes sense to have a common class for all cropped nodes:
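from typing import Callable

import torch
from torch_geometric.data import Data

# NodeDefinition is the existing GraphNeT base class for node definitions.
class CroppedNodes(NodeDefinition):
    def __init__(self, max_pulses: int, cropping_method: Callable) -> None:
        super().__init__()
        self.max_pulses = max_pulses
        self._cropping_method = cropping_method

    def _construct_nodes(self, x: torch.Tensor) -> Data:
        # Delegate pulse selection to the injected cropping method.
        x = self._cropping_method(x, self.max_pulses)
        return Data(x=x)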
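As a rough illustration of what could be passed as cropping_method, here is a minimal sketch of three of the selection strategies mentioned above. The function names, the `(x, max_pulses)` call signature, and the position of the charge column are assumptions made for this example, not existing GraphNeT code.

```python
import torch

# Assumption for illustration only: charge is stored in column 4 of the pulsemap.
CHARGE_COLUMN = 4


def crop_first(x: torch.Tensor, max_pulses: int) -> torch.Tensor:
    """Keep the first `max_pulses` pulses in their existing order."""
    return x[:max_pulses]


def crop_random(x: torch.Tensor, max_pulses: int) -> torch.Tensor:
    """Keep a uniformly random subset of at most `max_pulses` pulses."""
    if x.shape[0] <= max_pulses:
        return x
    idx = torch.randperm(x.shape[0], device=x.device)[:max_pulses]
    idx = torch.sort(idx).values  # preserve the original pulse order
    return x[idx]


def crop_by_charge(
    x: torch.Tensor, max_pulses: int, charge_column: int = CHARGE_COLUMN
) -> torch.Tensor:
    """Keep the `max_pulses` pulses with the highest charge."""
    if x.shape[0] <= max_pulses:
        return x
    idx = torch.topk(x[:, charge_column], k=max_pulses).indices
    idx = torch.sort(idx).values  # preserve the original pulse order
    return x[idx]
```

With functions like these, the node definition could be configured as `CroppedNodes(max_pulses=128, cropping_method=crop_by_charge)`; a different charge column could be bound with `functools.partial`.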
Such a structure would also make it easier to reuse the cropping methods in other node definitions. For example, you might want to crop after calculating summary nodes per DOM, to make sure you do not get an event that triggered 5k DOMs.
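To make the reuse case concrete, the sketch below first reduces pulses to one summary row per DOM and then applies a cropping method to the result. The column layout (x, y, z, time, charge), the choice of total charge as the only summary feature, and all function names are hypothetical and only serve the illustration.

```python
from typing import Callable, Sequence

import torch


def crop_first(x: torch.Tensor, max_rows: int) -> torch.Tensor:
    """Keep the first `max_rows` rows in their existing order."""
    return x[:max_rows]


def summarize_per_dom(
    x: torch.Tensor,
    xyz_columns: Sequence[int] = (0, 1, 2),  # assumed DOM position columns
    charge_column: int = 4,  # assumed charge column
) -> torch.Tensor:
    """Return one row per DOM: (x, y, z, total charge)."""
    xyz = x[:, list(xyz_columns)]
    # Pulses sharing a position are treated as belonging to the same DOM.
    doms, inverse = torch.unique(xyz, dim=0, return_inverse=True)
    total_charge = torch.zeros(doms.shape[0], dtype=x.dtype, device=x.device)
    total_charge.index_add_(0, inverse, x[:, charge_column])
    return torch.cat([doms, total_charge.unsqueeze(1)], dim=1)


def summarize_then_crop(
    x: torch.Tensor, max_doms: int, cropping_method: Callable = crop_first
) -> torch.Tensor:
    """Summarize pulses per DOM, then cap the number of DOM rows."""
    return cropping_method(summarize_per_dom(x), max_doms)
```

Because the cropping method is just a callable, the same function can cap raw pulses in one node definition and summary rows in another.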
Describe alternatives you've considered
We could of course just implement each cropping algorithm as a subclass of a common CroppedNodes class and keep the logic restricted to each subclass. But I think the cropping logic is general enough that there is merit in having it as a separate component.