Skip to content

Commit dee7949

Browse files
Update dimensionality_reduction_api
1 parent c280946 commit dee7949

File tree

5 files changed

+33
-31
lines changed

5 files changed

+33
-31
lines changed

dimensionality_reduction_api/Dockerfile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
FROM python:3.10-slim
2-
LABEL author="Jaime Céspedes Sisniega <cespedes@ifca.unican.es>"
2+
LABEL author="Jaime Céspedes Sisniega <jaime.cespedes@alumnos.unican.es>"
33

44
ARG DATASET
55
ARG PARENT_DIR
@@ -15,8 +15,8 @@ RUN apt-get -y update && \
1515
rm requirements.txt
1616

1717
COPY ${PARENT_DIR}/app ./
18-
COPY ./ml/${DATASET}/encoder.pt ./objects/encoder.pt
19-
COPY ./ml/${DATASET}/transformer.pt ./objects/transformer.pt
18+
COPY ./ml/objects/${DATASET}/encoder.pt ./objects/encoder.pt
19+
COPY ./ml/objects/${DATASET}/transform.pt ./objects/transform.pt
2020

2121
USER app
2222

dimensionality_reduction_api/app/api.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
from settings import (
2020
api_settings,
2121
encoder_settings,
22-
transformer_settings,
22+
transform_settings,
2323
)
2424

2525
dr = DimensionalityReduction(
2626
settings_encoder=encoder_settings,
27-
settings_transformer=transformer_settings,
27+
settings_transform=transform_settings,
2828
)
2929

3030

@@ -33,9 +33,9 @@
3333
status_code=HTTP_200_OK,
3434
)
3535
async def dimensionality_reduction(
36-
data: DimensionalityReductionInputData = Body(
37-
media_type=RequestEncodingType.MULTI_PART,
38-
),
36+
data: DimensionalityReductionInputData = Body(
37+
media_type=RequestEncodingType.MULTI_PART,
38+
),
3939
) -> DimensionalityReductionResponse:
4040
"""Reduce image.
4141
@@ -46,7 +46,7 @@ async def dimensionality_reduction(
4646
"""
4747
image = await data.image
4848
logging.info("Transforming image...")
49-
transformed_image = dr.transform(
49+
transformed_image = dr.apply_transform(
5050
data=image,
5151
)
5252
logging.info("Image transformed.")

dimensionality_reduction_api/app/dr.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torchvision
99
from torch import nn
1010

11-
from settings import EncoderSettings, TransformerSettings
11+
from settings import EncoderSettings, TransformSettings
1212
from utils import SingletonMeta
1313

1414

@@ -70,7 +70,9 @@ def __init__(self, input_size: Tuple[int, int, int], latent_dim: int) -> None:
7070
),
7171
)
7272

73-
def _get_feature_size(self, input_channels: int, input_size: Tuple[int, int]) -> Tuple[int, int, int]:
73+
def _get_feature_size(
74+
self, input_channels: int, input_size: Tuple[int, int]
75+
) -> Tuple[int, int, int]:
7476
# Function to compute the size of the feature maps after convolutional layers
7577
x = torch.randn(1, input_channels, *input_size)
7678
x = self.encoder_conv(x)
@@ -89,19 +91,19 @@ class DimensionalityReduction(metaclass=SingletonMeta):
8991
def __init__(
9092
self: "DimensionalityReduction",
9193
settings_encoder: EncoderSettings,
92-
settings_transformer: TransformerSettings,
94+
settings_transform: TransformSettings,
9395
) -> None:
9496
"""Init method."""
9597
logging.info("Loading image encoder...")
9698
self.encoder = self.load_encoder(
9799
settings=settings_encoder,
98100
)
99101
logging.info("Image encoder loaded.")
100-
logging.info("Loading transformer...")
101-
self.transformer = self.load_transformer(
102-
settings=settings_transformer,
102+
logging.info("Loading transform...")
103+
self.transform = self.load_transform(
104+
settings=settings_transform,
103105
)
104-
logging.info("Transformer loaded.")
106+
logging.info("Transform loaded.")
105107

106108
def load_encoder(
107109
self: "DimensionalityReduction",
@@ -117,16 +119,16 @@ def load_encoder(
117119
)
118120
return encoder
119121

120-
def load_transformer(
122+
def load_transform(
121123
self: "DimensionalityReduction",
122-
settings: EncoderSettings,
124+
settings: TransformSettings
123125
) -> torchvision.transforms.Compose:
124126
"""Load image encoder.
125127
126128
:return encoder
127129
:rtype: dict[str, str | None]
128130
"""
129-
transformer = self._load_transformer(
131+
transformer = self._load_transform(
130132
settings=settings,
131133
)
132134
return transformer
@@ -143,7 +145,7 @@ def encode(self, data: np.ndarray) -> np.ndarray:
143145
encoded = self.encoder(data).numpy()
144146
return encoded
145147

146-
def transform(self, data: np.ndarray) -> torch.Tensor:
148+
def apply_transform(self, data: np.ndarray) -> torch.Tensor:
147149
"""Transform data.
148150
149151
:param data: data
@@ -152,7 +154,7 @@ def transform(self, data: np.ndarray) -> torch.Tensor:
152154
:rtype: np.ndarray
153155
"""
154156

155-
transformed = self.transformer(data).unsqueeze(0)
157+
transformed = self.transform(data).unsqueeze(0)
156158
return transformed
157159

158160
@staticmethod
@@ -173,8 +175,8 @@ def _load_encoder(settings: EncoderSettings) -> Encoder:
173175
return encoder
174176

175177
@staticmethod
176-
def _load_transformer(
177-
settings: TransformerSettings,
178+
def _load_transform(
179+
settings: TransformSettings,
178180
) -> torchvision.transforms.Compose:
179181
transformer = torch.load(
180182
f=settings.FILE_PATH,

dimensionality_reduction_api/app/schemas/dr.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ async def parse_image(cls, data: UploadFile) -> Image:
2323
"""
2424
data = await data.read()
2525
image = Image.open(
26-
fp=BytesIO(
27-
initial_bytes=data,
28-
),
29-
)
26+
fp=BytesIO(
27+
initial_bytes=data,
28+
),
29+
)
3030
return image
3131

3232
class Config(BaseConfig):

dimensionality_reduction_api/app/settings.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ class EncoderSettings(BaseSettings):
2222
FILE_PATH: Path = Path("objects/encoder.pt")
2323

2424

25-
class TransformerSettings(BaseSettings):
26-
"""Transformer settings class."""
25+
class TransformSettings(BaseSettings):
26+
"""Transform settings class."""
2727

28-
FILE_PATH: Path = Path("objects/transformer.pt")
28+
FILE_PATH: Path = Path("objects/transform.pt")
2929

3030

3131
api_settings = APISettings()
3232
encoder_settings = EncoderSettings()
33-
transformer_settings = TransformerSettings()
33+
transform_settings = TransformSettings()

0 commit comments

Comments
 (0)