Skip to content

Commit 307ab40

Browse files
authored
BREAKING_CHANGE: change format for binary classification heads (#385)
1 parent 9e8eecb commit 307ab40

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+597
-187
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ S3method(marshal_model,learner_torch_model)
3636
S3method(materialize,data.frame)
3737
S3method(materialize,lazy_tensor)
3838
S3method(materialize,list)
39+
S3method(output_dim_for,TaskClassif)
40+
S3method(output_dim_for,TaskRegr)
3941
S3method(print,ModelDescriptor)
4042
S3method(print,Select)
4143
S3method(print,TorchIngressToken)
@@ -194,6 +196,7 @@ export(nn_squeeze)
194196
export(nn_tokenizer_categ)
195197
export(nn_tokenizer_num)
196198
export(nn_unsqueeze)
199+
export(output_dim_for)
197200
export(pipeop_preproc_torch)
198201
export(replace_head)
199202
export(select_all)

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
* feat: Added encoders for numericals and categoricals
77
* feat: Added `po("nn_fn")` for calling custom functions in a network.
88
* feat: Added `po("nn_ft_cls")` for concatenating a CLS token to a tokenized input.
9+
* BREAKING_CHANGE: The output dimension of neural networks for binary classification tasks is now
10+
expected to be 1 and not 2 as before. The behavior of `nn("head")` was also changed to match this.
11+
This means that for binary classification tasks, `t_loss("cross_entropy")` now generates
12+
`nn_bce_with_logits_loss` instead of `nn_cross_entropy_loss`.
13+
This also came with a reparametrization of the `t_loss("cross_entropy")` loss (thanks to @tdhock, #374).
914

1015
# mlr3torch 0.2.1
1116

R/LearnerTorch.R

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,21 @@
2929
#' To do so, you just need to include `epochs = to_tune(upper = <upper>, internal = TRUE)` in the search space,
3030
#' where `<upper>` is the maximally allowed number of epochs, and configure the early stopping.
3131
#'
32+
#' @section Network Head and Target Encoding:
33+
#' Torch learners are expected to have the following output:
34+
#' * binary classification: `c(batch_size, 1)`, representing the logits for the positive class.
35+
#' * multiclass classification: `c(batch_size, n_classes)`, representing the logits for all classes.
36+
#' * regression: `c(batch_size, 1)` representing the response prediction.
37+
#'
38+
#' Furthermore, the target encoding is expected to be as follows:
39+
#' * regression: The `numeric` target variable of a [`TaskRegr`][mlr3::TaskRegr] is encoded as a
40+
#' [`torch_float`][torch::torch_float] with shape `c(batch_size, 1)`.
41+
#' * binary classification: The `factor` target variable of a [`TaskClassif`][mlr3::TaskClassif] is encoded as a
42+
#' [`torch_float`][torch::torch_float] with shape `(batch_size, 1)` where the positive class is `1` and the negative
43+
#' class is `0`.
44+
#' * multi-class classification: The `factor` target variable of a [`TaskClassif`][mlr3::TaskClassif] is a label-encoded
45+
#' [`torch_long`][torch::torch_long] with shape `(batch_size, n_classes)` starting at `1` and ending at `n_classes`.
46+
#'
3247
#' @template param_id
3348
#' @template param_task_type
3449
#' @template param_param_vals
@@ -81,8 +96,8 @@
8196
#' ([`Task`][mlr3::Task], `list()`) -> [`nn_module`][torch::nn_module]\cr
8297
#' Construct a [`torch::nn_module`] object for the given task and parameter values, i.e. the neural network that
8398
#' is trained by the learner.
84-
#' For classification, the output of this network are expected to be the scores before the application of the
85-
#' final softmax layer.
99+
#' Note that a specific output shape is expected from the returned network, see section *Network Head and Target Encoding*.
100+
#' You can use [`output_dim_for()`] to obtain the correct output dimension for a given task.
86101
#' * `.dataset(task, param_vals)`\cr
87102
#' ([`Task`][mlr3::Task], `list()`) -> [`torch::dataset`]\cr
88103
#' Create the dataset for the task.
@@ -92,17 +107,19 @@
92107
#' * `y` is the target tensor.
93108
#' * `.index` are the indices of the batch (`integer()` or a `torch_int()`).
94109
#'
110+
#' For information on the expected target encoding of `y`, see section *Network Head and Target Encoding*.
95111
#' Moreover, one needs to pay attention respect the row ids of the provided task.
112+
#' It is recommended to relu on [`task_dataset`] for creating the [`dataset`][torch::dataset].
96113
#'
97114
#' It is also possible to overwrite the private `.dataloader()` method.
98115
#' This must respect the dataloader parameters from the [`ParamSet`][paradox::ParamSet].
99116
#'
100117
#' * `.dataloader(dataset, param_vals)`\cr
101118
#' ([`Task`][mlr3::Task], `list()`) -> [`torch::dataloader`]\cr
102119
#' Create a dataloader from the task.
103-
#' Needs to respect at least `batch_size` and `shuffle` (otherwise predictions can be permuted).
120+
#' Needs to respect at least `batch_size` and `shuffle` (otherwise predictions will be incorrectly ordered).
104121
#'
105-
#' To change the predict types, the it is possible to overwrite the method below:
122+
#' To change the predict types, it is possible to overwrite the method below:
106123
#'
107124
#' * `.encode_prediction(predict_tensor, task)`\cr
108125
#' ([`torch_tensor`][torch::torch_tensor], [`Task`][mlr3::Task]) -> `list()`\cr

R/LearnerTorchFeatureless.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ LearnerTorchFeatureless = R6Class("LearnerTorchFeatureless",
4141
),
4242
private = list(
4343
.network = function(task, param_vals) {
44-
nn_featureless(nout = get_nout(task))
44+
nn_featureless(nout = output_dim_for(task))
4545
},
4646
.dataset = function(task, dataset) {
4747
dataset_featureless(task)
@@ -52,7 +52,7 @@ LearnerTorchFeatureless = R6Class("LearnerTorchFeatureless",
5252
dataset_featureless = dataset(
5353
initialize = function(task) {
5454
self$task = task
55-
self$target_batchgetter = get_target_batchgetter(task$task_type)
55+
self$target_batchgetter = get_target_batchgetter(task)
5656
},
5757
.getbatch = function(index) {
5858
target = self$task$data(rows = self$task$row_ids[index], cols = self$task$target_names)

R/LearnerTorchMLP.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ LearnerTorchMLP = R6Class("LearnerTorchMLP",
8787
private = list(
8888
.network = function(task, param_vals) {
8989
# verify_train_task was already called beforehand, so we can make some assumptions
90-
d_out = get_nout(task)
90+
d_out = output_dim_for(task)
9191
d_in = if (single_lazy_tensor(task)) {
9292
private$.get_input_shape(task, param_vals$shape)[2L]
9393
} else {

R/LearnerTorchModel.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ LearnerTorchModel = R6Class("LearnerTorchModel",
129129
dataset = task_dataset(
130130
task,
131131
feature_ingress_tokens = ingress_tokens,
132-
target_batchgetter = get_target_batchgetter(self$task_type)
132+
target_batchgetter = get_target_batchgetter(task)
133133
)
134134
},
135135
.network_stored = NULL,

R/LearnerTorchModule.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
#' nn_one_layer = nn_module("nn_one_layer",
3636
#' initialize = function(task, size_hidden) {
3737
#' self$first = nn_linear(task$n_features, size_hidden)
38-
#' self$second = nn_linear(size_hidden, length(task$class_names))
38+
#' self$second = nn_linear(size_hidden, output_dim_for(task))
3939
#' },
4040
#' # argument x corresponds to the ingress token x
4141
#' forward = function(x) {
@@ -117,7 +117,7 @@ LearnerTorchModule = R6Class("LearnerTorchModule",
117117
dataset = task_dataset(
118118
task,
119119
feature_ingress_tokens = ingress_tokens,
120-
target_batchgetter = get_target_batchgetter(self$task_type)
120+
target_batchgetter = get_target_batchgetter(task)
121121
)
122122
},
123123

R/LearnerTorchTabResNet.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ LearnerTorchTabResNet = R6Class("LearnerTorchTabResNet",
4343
)
4444

4545
private$.param_set_base = ps(
46-
n_blocks = p_int(1, tags = c("train", "required")),
46+
n_blocks = p_int(0, tags = c("train", "required")),
4747
d_block = p_int(1, tags = c("train", "required"))
4848
)
4949
param_set = alist(private$.block$param_set, private$.param_set_base)
@@ -65,10 +65,10 @@ LearnerTorchTabResNet = R6Class("LearnerTorchTabResNet",
6565
private = list(
6666
.block = NULL,
6767
.dataset = function(task, param_vals) {
68-
dataset_num(task, param_vals)
68+
dataset_num(task, param_vals, argname = "num.input")
6969
},
7070
.network = function(task, param_vals) {
71-
graph = po("torch_ingress_num") %>>%
71+
graph = po("torch_ingress_num", id = "num") %>>%
7272
po("nn_linear", out_features = param_vals$d_block) %>>%
7373
po("nn_block", private$.block, n_blocks = param_vals$n_blocks) %>>%
7474
po("nn_head")

R/LearnerTorchVision.R

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ LearnerTorchVision = R6Class("LearnerTorchVision",
5858
private = list(
5959
.module_generator = NULL,
6060
.network = function(task, param_vals) {
61-
nout = get_nout(task)
61+
nout = output_dim_for(task)
6262
if (param_vals$pretrained) {
6363
network = replace_head(private$.module_generator(pretrained = TRUE), nout)
6464
return(network)
@@ -107,126 +107,126 @@ replace_head.VGG = function(network, d_out) {
107107
}
108108

109109
#' @include aaa.R
110-
register_learner("classif.alexnet",
110+
register_learner("classif.alexnet",
111111
function(loss = NULL, optimizer = NULL, callbacks = list()) {
112112
LearnerTorchVision$new("alexnet", torchvision::model_alexnet, "AlexNet",
113113
loss = loss, optimizer = optimizer, callbacks = callbacks)
114114
}
115115
)
116116

117-
# register_learner("classif.inception_v3",
117+
# register_learner("classif.inception_v3",
118118
# function(loss = NULL, optimizer = NULL, callbacks = list()) {
119119
# LearnerTorchVision$new("inception_v3", torchvision::model_inception_v3, "Inception V3",
120120
# loss = loss, optimizer = optimizer, callbacks = callbacks)
121121
# }
122122
# )
123123

124-
register_learner("classif.mobilenet_v2",
124+
register_learner("classif.mobilenet_v2",
125125
function(loss = NULL, optimizer = NULL, callbacks = list()) {
126126
LearnerTorchVision$new("mobilenet_v2", torchvision::model_mobilenet_v2, "Mobilenet V2",
127127
loss = loss, optimizer = optimizer, callbacks = callbacks)
128128
}
129129
)
130130

131-
register_learner("classif.resnet18",
131+
register_learner("classif.resnet18",
132132
function(loss = NULL, optimizer = NULL, callbacks = list()) {
133133
LearnerTorchVision$new("resnet18", torchvision::model_resnet18, "ResNet-18",
134134
loss = loss, optimizer = optimizer, callbacks = callbacks)
135135
}
136136
)
137137

138-
register_learner("classif.resnet34",
138+
register_learner("classif.resnet34",
139139
function(loss = NULL, optimizer = NULL, callbacks = list()) {
140140
LearnerTorchVision$new("resnet34", torchvision::model_resnet34, "ResNet-34",
141141
loss = loss, optimizer = optimizer, callbacks = callbacks)
142142
}
143143
)
144144

145-
register_learner("classif.resnet50",
145+
register_learner("classif.resnet50",
146146
function(loss = NULL, optimizer = NULL, callbacks = list()) {
147147
LearnerTorchVision$new("resnet50", torchvision::model_resnet50, "ResNet-50",
148148
loss = loss, optimizer = optimizer, callbacks = callbacks)
149149
}
150150
)
151151

152-
register_learner("classif.resnet101",
152+
register_learner("classif.resnet101",
153153
function(loss = NULL, optimizer = NULL, callbacks = list()) {
154154
LearnerTorchVision$new("resnet101", torchvision::model_resnet101, "ResNet-101",
155155
loss = loss, optimizer = optimizer, callbacks = callbacks)
156156
}
157157
)
158158

159-
register_learner("classif.resnet152",
159+
register_learner("classif.resnet152",
160160
function(loss = NULL, optimizer = NULL, callbacks = list()) {
161161
LearnerTorchVision$new("resnet152", torchvision::model_resnet152, "ResNet-152",
162162
loss = loss, optimizer = optimizer, callbacks = callbacks)
163163
}
164164
)
165165

166-
register_learner("classif.resnext101_32x8d",
166+
register_learner("classif.resnext101_32x8d",
167167
function(loss = NULL, optimizer = NULL, callbacks = list()) {
168168
LearnerTorchVision$new("resnext101_32x8d", torchvision::model_resnext101_32x8d, "ResNeXt-101 32x8d",
169169
loss = loss, optimizer = optimizer, callbacks = callbacks)
170170
}
171171
)
172172

173-
register_learner("classif.resnext50_32x4d",
173+
register_learner("classif.resnext50_32x4d",
174174
function(loss = NULL, optimizer = NULL, callbacks = list()) {
175175
LearnerTorchVision$new("resnext50_32x4d", torchvision::model_resnext50_32x4d, "ResNeXt-50 32x4d",
176176
loss = loss, optimizer = optimizer, callbacks = callbacks)
177177
}
178178
)
179179

180-
register_learner("classif.vgg11",
180+
register_learner("classif.vgg11",
181181
function(loss = NULL, optimizer = NULL, callbacks = list()) {
182182
LearnerTorchVision$new("vgg11", torchvision::model_vgg11, "VGG 11",
183183
loss = loss, optimizer = optimizer, callbacks = callbacks)
184184
}
185185
)
186186

187-
register_learner("classif.vgg11_bn",
187+
register_learner("classif.vgg11_bn",
188188
function(loss = NULL, optimizer = NULL, callbacks = list()) {
189189
LearnerTorchVision$new("vgg11_bn", torchvision::model_vgg11_bn, "VGG 11",
190190
loss = loss, optimizer = optimizer, callbacks = callbacks)
191191
}
192192
)
193193

194-
register_learner("classif.vgg13",
194+
register_learner("classif.vgg13",
195195
function(loss = NULL, optimizer = NULL, callbacks = list()) {
196196
LearnerTorchVision$new("vgg13", torchvision::model_vgg13, "VGG 13",
197197
loss = loss, optimizer = optimizer, callbacks = callbacks)
198198
}
199199
)
200200

201-
register_learner("classif.vgg13_bn",
201+
register_learner("classif.vgg13_bn",
202202
function(loss = NULL, optimizer = NULL, callbacks = list()) {
203203
LearnerTorchVision$new("vgg13_bn", torchvision::model_vgg13_bn, "VGG 13",
204204
loss = loss, optimizer = optimizer, callbacks = callbacks)
205205
}
206206
)
207207

208-
register_learner("classif.vgg16",
208+
register_learner("classif.vgg16",
209209
function(loss = NULL, optimizer = NULL, callbacks = list()) {
210210
LearnerTorchVision$new("vgg16", torchvision::model_vgg16, "VGG 16",
211211
loss = loss, optimizer = optimizer, callbacks = callbacks)
212212
}
213213
)
214214

215-
register_learner("classif.vgg16_bn",
215+
register_learner("classif.vgg16_bn",
216216
function(loss = NULL, optimizer = NULL, callbacks = list()) {
217217
LearnerTorchVision$new("vgg16_bn", torchvision::model_vgg16_bn, "VGG 16",
218218
loss = loss, optimizer = optimizer, callbacks = callbacks)
219219
}
220220
)
221221

222-
register_learner("classif.vgg19",
222+
register_learner("classif.vgg19",
223223
function(loss = NULL, optimizer = NULL, callbacks = list()) {
224224
LearnerTorchVision$new("vgg19", torchvision::model_vgg19, "VGG 19",
225225
loss = loss, optimizer = optimizer, callbacks = callbacks)
226226
}
227227
)
228228

229-
register_learner("classif.vgg19_bn",
229+
register_learner("classif.vgg19_bn",
230230
function(loss = NULL, optimizer = NULL, callbacks = list()) {
231231
LearnerTorchVision$new("vgg19_bn", torchvision::model_vgg19_bn, "VGG 19",
232232
loss = loss, optimizer = optimizer, callbacks = callbacks)

R/PipeOpTorch.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898
#' self$output = if (task$task_type == "regr") {
9999
#' torch::nn_linear(d_hidden, 1)
100100
#' } else if (task$task_type == "classif") {
101-
#' torch::nn_linear(d_hidden, length(task$class_names))
101+
#' torch::nn_linear(d_hidden, output_dim_for(task))
102102
#' }
103103
#' },
104104
#' forward = function(x) {

0 commit comments

Comments
 (0)