|
1 | 1 | """Naive recommender for hybrid spaces.""" |
2 | 2 |
|
3 | 3 | import warnings |
4 | | -from typing import TYPE_CHECKING, ClassVar, Optional, cast |
| 4 | +from typing import ClassVar, Optional |
5 | 5 |
|
6 | 6 | import pandas as pd |
7 | 7 | from attrs import define, evolve, field, fields |
|
15 | 15 | from baybe.searchspace import SearchSpace, SearchSpaceType |
16 | 16 | from baybe.utils.dataframe import to_tensor |
17 | 17 |
|
18 | | -if TYPE_CHECKING: |
19 | | - from torch import Tensor |
20 | | - |
21 | 18 |
|
22 | 19 | @define |
23 | 20 | class NaiveHybridSpaceRecommender(PureRecommender): |
@@ -119,7 +116,7 @@ def recommend( # noqa: D102 |
119 | 116 | # will then be attached to every discrete point when the acquisition function |
120 | 117 | # is evaluated. |
121 | 118 | cont_part = searchspace.continuous.samples_random(1) |
122 | | - cont_part_tensor = cast(Tensor, to_tensor(cont_part)).unsqueeze(-2) |
| 119 | + cont_part_tensor = to_tensor(cont_part).unsqueeze(-2) |
123 | 120 |
|
124 | 121 | # Get discrete candidates. The metadata flags are ignored since the search space |
125 | 122 | # is hybrid |
@@ -154,7 +151,7 @@ def recommend( # noqa: D102 |
154 | 151 | # Get one random discrete point that will be attached when evaluating the |
155 | 152 | # acquisition function in the discrete space. |
156 | 153 | disc_part = searchspace.discrete.comp_rep.loc[disc_rec_idx].sample(1) |
157 | | - disc_part_tensor = cast(Tensor, to_tensor(disc_part)).unsqueeze(-2) |
| 154 | + disc_part_tensor = to_tensor(disc_part).unsqueeze(-2) |
158 | 155 |
|
159 | 156 | # Setup a fresh acquisition function for the continuous recommender |
160 | 157 | self.cont_recommender._setup_botorch_acqf(searchspace, train_x, train_y) |
|
0 commit comments