Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

349: Implemented round() function #359

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions temporian/core/event_set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1501,6 +1501,38 @@ def abs(

return abs(self)

def __round__(self):
from temporian.core.operators.unary import round

return round(input=self)

def round(
self: EventSetOrNode,
) -> EventSetOrNode:
"""Rounds the values of an [`EventSet`][temporian.EventSet]'s features to the nearest integer.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add another line specifying that only float types are allowed, and the output type will always be the same as the input's

Example:
```python
>>> a = tp.event_set(
... timestamps=[1, 2, 3],
... features={"M": [1.4, 2.6, 3.1], "N": [-1.9, -3.5, 5.8]},
... )
>>> a.round()
indexes: ...
'M': [1, 3, 3]
'N': [-2, -4, 6]
...

```

Returns:
EventSet with rounded feature values.
"""
from temporian.core.operators.unary import round

return round(self)


def add_index(
self: EventSetOrNode, indexes: Union[str, List[str]]
) -> EventSetOrNode:
Expand Down Expand Up @@ -2895,6 +2927,7 @@ def log(self: EventSetOrNode) -> EventSetOrNode:

return log(self)


def moving_count(
self: EventSetOrNode,
window_length: WindowLength,
Expand Down
30 changes: 30 additions & 0 deletions temporian/core/operators/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,30 @@ def get_output_dtype(cls, feature_dtype: DType) -> DType:
return feature_dtype


class RoundOperator(BaseUnaryOperator):
@classmethod
def op_key_definition(cls) -> str:
return "ROUND"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unnecessary empty line

@classmethod
def allowed_dtypes(cls) -> List[DType]:
return [
DType.INT32,
DType.INT64,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ints shouldn't be allowed

]

@classmethod
def get_output_dtype(cls, feature_dtype: DType) -> DType:
return feature_dtype
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These look off:

  • allowed_dtypes specifies the types that the operator can consume - should be DType.FLOAT32 and DType.FLOAT64
  • get_output_dtype returns the output dtype, given the input dtype - in this case it should return the same type as received (I said it should return ints in the issue, edited it, needs to return floats because they have a much larger range than ints)



operator_lib.register_operator(InvertOperator)
operator_lib.register_operator(IsNanOperator)
operator_lib.register_operator(NotNanOperator)
operator_lib.register_operator(AbsOperator)
operator_lib.register_operator(LogOperator)
operator_lib.register_operator(RoundOperator)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra newline here too



@compile
Expand Down Expand Up @@ -242,3 +261,14 @@ def log(
return LogOperator(
input=input,
).outputs["output"]


@compile
def round(
input: EventSetOrNode,
) -> EventSetOrNode:
assert isinstance(input, EventSetNode)

return RoundOperator(
input=input,
).outputs["output"]
9 changes: 9 additions & 0 deletions temporian/implementation/numpy/operators/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
NotNanOperator,
IsNanOperator,
LogOperator,
RoundOperator,
)
from temporian.implementation.numpy import implementation_lib
from temporian.implementation.numpy.data.event_set import IndexData
Expand Down Expand Up @@ -77,6 +78,11 @@ def _do_operation(self, feature: np.ndarray) -> np.ndarray:
return np.log(feature)


class RoundNumpyImplementation(BaseUnaryNumpyImplementation):
def _do_operation(self, feature: np.ndarray) -> np.ndarray:
return np.round(feature)


implementation_lib.register_operator_implementation(
AbsOperator, AbsNumpyImplementation
)
Expand All @@ -92,3 +98,6 @@ def _do_operation(self, feature: np.ndarray) -> np.ndarray:
implementation_lib.register_operator_implementation(
LogOperator, LogNumpyImplementation
)
implementation_lib.register_operator_implementation(
RoundOperator, RoundNumpyImplementation
)