# poutyne_example.py
# Forked from isaaccorley/pytorch-enhance.
import torch
from torch.utils.data import DataLoader
from poutyne.framework import Model
from torch_enhance.datasets import BSDS300, Set14, Set5
from torch_enhance.models import SRCNN
from torch_enhance import metrics
# Example: fit an SRCNN network for one epoch through Poutyne's Model wrapper.

# Upscaling factor handed to both datasets and to the network constructor.
scale_factor = 2

# Datasets: BSDS300 feeds the training loader, Set14 the second
# (presumably validation) loader passed to fit_generator.
train_dataset = BSDS300(scale_factor=scale_factor)
val_dataset = Set14(scale_factor=scale_factor)

# Batch the two datasets; no other DataLoader options are set.
train_dataloader = DataLoader(train_dataset, batch_size=8)
val_dataloader = DataLoader(val_dataset, batch_size=2)

# Channel count follows the training set's color space:
# 3 when it reports "RGB", otherwise 1.
if train_dataset.color_space == "RGB":
    channels = 3
else:
    channels = 1

pytorch_network = SRCNN(scale_factor, channels)

# The "sgd" and "mse" strings are passed positionally as Model's second and
# third arguments (optimizer and loss, per Poutyne's documentation).
model = Model(pytorch_network, "sgd", "mse")

# Run a single epoch over the training loader, with the second loader
# supplied alongside it.
model.fit_generator(train_dataloader, val_dataloader, epochs=1)