Skip to content

Commit a2830dc

Browse files
committed
Fix Tensor type cast
This also avoids having to import Tensor
1 parent 424198c commit a2830dc

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

baybe/recommenders/naive.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Naive recommender for hybrid spaces."""
22

33
import warnings
4-
from typing import TYPE_CHECKING, ClassVar, Optional, cast
4+
from typing import ClassVar, Optional
55

66
import pandas as pd
77
from attrs import define, evolve, field, fields
@@ -15,9 +15,6 @@
1515
from baybe.searchspace import SearchSpace, SearchSpaceType
1616
from baybe.utils.dataframe import to_tensor
1717

18-
if TYPE_CHECKING:
19-
from torch import Tensor
20-
2118

2219
@define
2320
class NaiveHybridSpaceRecommender(PureRecommender):
@@ -119,7 +116,7 @@ def recommend( # noqa: D102
119116
# will then be attached to every discrete point when the acquisition function
120117
# is evaluated.
121118
cont_part = searchspace.continuous.samples_random(1)
122-
cont_part_tensor = cast(Tensor, to_tensor(cont_part)).unsqueeze(-2)
119+
cont_part_tensor = to_tensor(cont_part).unsqueeze(-2)
123120

124121
# Get discrete candidates. The metadata flags are ignored since the search space
125122
# is hybrid
@@ -154,7 +151,7 @@ def recommend( # noqa: D102
154151
# Get one random discrete point that will be attached when evaluating the
155152
# acquisition function in the discrete space.
156153
disc_part = searchspace.discrete.comp_rep.loc[disc_rec_idx].sample(1)
157-
disc_part_tensor = cast(Tensor, to_tensor(disc_part)).unsqueeze(-2)
154+
disc_part_tensor = to_tensor(disc_part).unsqueeze(-2)
158155

159156
# Setup a fresh acquisition function for the continuous recommender
160157
self.cont_recommender._setup_botorch_acqf(searchspace, train_x, train_y)

baybe/utils/dataframe.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
from __future__ import annotations
44

55
import logging
6-
from collections.abc import Iterable, Sequence
6+
from collections.abc import Iterable, Iterator, Sequence
77
from typing import (
88
TYPE_CHECKING,
99
Literal,
1010
Optional,
1111
Union,
12+
overload,
1213
)
1314

1415
import numpy as np
@@ -28,7 +29,17 @@
2829
_logger = logging.getLogger(__name__)
2930

3031

31-
def to_tensor(*dfs: pd.DataFrame) -> Union[Tensor, Iterable[Tensor]]:
32+
@overload
33+
def to_tensor(df: pd.DataFrame) -> Tensor:
34+
...
35+
36+
37+
@overload
38+
def to_tensor(*dfs: pd.DataFrame) -> Iterator[Tensor]:
39+
...
40+
41+
42+
def to_tensor(*dfs: pd.DataFrame) -> Union[Tensor, Iterator[Tensor]]:
3243
"""Convert a given set of dataframes into tensors (dropping all indices).
3344
3445
Args:

0 commit comments

Comments
 (0)