|
1 | 1 | """Naive recommender for hybrid spaces."""
|
2 | 2 |
|
3 | 3 | import warnings
|
4 |
| -from typing import TYPE_CHECKING, ClassVar, Optional, cast |
| 4 | +from typing import ClassVar, Optional |
5 | 5 |
|
6 | 6 | import pandas as pd
|
7 | 7 | from attrs import define, evolve, field, fields
|
|
15 | 15 | from baybe.searchspace import SearchSpace, SearchSpaceType
|
16 | 16 | from baybe.utils.dataframe import to_tensor
|
17 | 17 |
|
18 |
| -if TYPE_CHECKING: |
19 |
| - from torch import Tensor |
20 |
| - |
21 | 18 |
|
22 | 19 | @define
|
23 | 20 | class NaiveHybridSpaceRecommender(PureRecommender):
|
@@ -119,7 +116,7 @@ def recommend( # noqa: D102
|
119 | 116 | # will then be attached to every discrete point when the acquisition function
|
120 | 117 | # is evaluated.
|
121 | 118 | cont_part = searchspace.continuous.samples_random(1)
|
122 |
| - cont_part_tensor = cast(Tensor, to_tensor(cont_part)).unsqueeze(-2) |
| 119 | + cont_part_tensor = to_tensor(cont_part).unsqueeze(-2) |
123 | 120 |
|
124 | 121 | # Get discrete candidates. The metadata flags are ignored since the search space
|
125 | 122 | # is hybrid
|
@@ -154,7 +151,7 @@ def recommend( # noqa: D102
|
154 | 151 | # Get one random discrete point that will be attached when evaluating the
|
155 | 152 | # acquisition function in the discrete space.
|
156 | 153 | disc_part = searchspace.discrete.comp_rep.loc[disc_rec_idx].sample(1)
|
157 |
| - disc_part_tensor = cast(Tensor, to_tensor(disc_part)).unsqueeze(-2) |
| 154 | + disc_part_tensor = to_tensor(disc_part).unsqueeze(-2) |
158 | 155 |
|
159 | 156 | # Setup a fresh acquisition function for the continuous recommender
|
160 | 157 | self.cont_recommender._setup_botorch_acqf(searchspace, train_x, train_y)
|
|
0 commit comments