Skip to content

Commit 9e8eecb

Browse files
authored
feat/pipeop-nn-ft-cls (#381)
1 parent b5ad5a3 commit 9e8eecb

File tree

77 files changed

+383
-2
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

77 files changed

+383
-2
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ Collate:
117117
'PipeOpTorchConv.R'
118118
'PipeOpTorchConvTranspose.R'
119119
'PipeOpTorchDropout.R'
120+
'PipeOpTorchFTCLS.R'
120121
'PipeOpTorchFn.R'
121122
'PipeOpTorchHead.R'
122123
'PipeOpTorchIdentity.R'

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ export(PipeOpTorchConvTranspose2D)
102102
export(PipeOpTorchConvTranspose3D)
103103
export(PipeOpTorchDropout)
104104
export(PipeOpTorchELU)
105+
export(PipeOpTorchFTCLS)
105106
export(PipeOpTorchFlatten)
106107
export(PipeOpTorchFn)
107108
export(PipeOpTorchGELU)
@@ -181,6 +182,7 @@ export(model_descriptor_to_learner)
181182
export(model_descriptor_to_module)
182183
export(model_descriptor_union)
183184
export(nn)
185+
export(nn_ft_cls)
184186
export(nn_geglu)
185187
export(nn_graph)
186188
export(nn_merge_cat)

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
* feat: `TorchIngressToken` can now also take a `Selector` as argument `features`.
66
* feat: Added encoders for numericals and categoricals
77
* feat: Added `po("nn_fn")` for calling custom functions in a network.
8+
* feat: Added `po("nn_ft_cls")` for concatenating a CLS token to a tokenized input.
89

910
# mlr3torch 0.2.1
1011

R/PipeOpTorchFTCLS.R

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#' @title CLS Token for FT-Transformer
#' @description
#' Concatenates a CLS token to the input as the last feature.
#' The input shape is expected to be `(batch, n_features, d_token)` and the output shape is
#' `(batch, n_features + 1, d_token)`.
#'
#' This is used in the FT-Transformer.
#'
#' @param d_token (`integer(1)`)\cr
#'   The dimension of the embedding.
#' @param initialization (`character(1)`)\cr
#'   The initialization method for the embedding weights. Possible values are `"uniform"`
#'   and `"normal"`. Defaults to `"uniform"`.
#'
#' @references
#' `r format_bib("devlin2018bert")`
#'
#' @export
nn_ft_cls = nn_module(
  "nn_ft_cls",
  # `initialization` defaults to "uniform" so that the module can be constructed when the
  # caller (e.g. PipeOpTorchFTCLS, which declares "uniform" as the parameter default) does
  # not pass the argument explicitly.
  initialize = function(d_token, initialization = "uniform") {
    self$d_token = d_token
    # an individual CLS token: a 1d parameter of length d_token
    self$weight = nn_parameter(torch_empty(d_token))
    self$initialization = initialization
    self$reset_parameters()
  },
  # (Re-)initializes the CLS token in place using the stored initialization method.
  reset_parameters = function() {
    initialize_token_(self$weight, d = self$d_token, self$initialization)
  },
  # Repeats the underlying CLS token to create a tensor with the given leading dimensions.
  # Used for creating a batch of CLS tokens.
  # Without arguments, the raw token (the 1d weight) is returned.
  expand = function(...) {
    leading_dimensions = list(...)
    if (length(leading_dimensions) == 0) {
      return(self$weight)
    }
    # View the token with (n - 1) singleton dimensions in front of the token dimension,
    # then expand to the requested leading dimensions; -1 keeps the token dimension as is.
    new_dims = rep(1, length(leading_dimensions) - 1)
    self$weight$view(c(new_dims, -1))$expand(c(leading_dimensions, -1))
  },
  # Appends one CLS token per batch element along dimension 2, after the existing features.
  forward = function(input) {
    cls_tokens = self$expand(input$shape[1], 1)
    torch_cat(list(input, cls_tokens), dim = 2)
  }
)
45+
46+
#' @title CLS Token for FT-Transformer
#' @inherit nn_ft_cls description
#' @section nn_module:
#' Calls [`nn_ft_cls()`] when trained.
#' @templateVar id nn_ft_cls
#' @template pipeop_torch
#' @template pipeop_torch_example
#' @export
PipeOpTorchFTCLS = R6::R6Class("PipeOpTorchFTCLS",
  inherit = PipeOpTorch,
  public = list(
    #' @description Creates a new instance of this [R6][R6::R6Class] class.
    #' @template params_pipelines
    initialize = function(id = "nn_ft_cls", param_vals = list()) {
      # The only user-facing hyperparameter is how the CLS token is initialized;
      # `d_token` is derived from the input shape in `.shape_dependent_params()`.
      cls_param_set = ps(
        initialization = p_fct(
          levels = c("uniform", "normal"),
          default = "uniform",
          tags = c("train")
        )
      )

      super$initialize(
        id = id,
        param_set = cls_param_set,
        param_vals = param_vals,
        module_generator = nn_ft_cls
      )
    }
  ),
  private = list(
    # The output shapes equal the input shapes, except that the second dimension
    # of the first input grows by one (the appended CLS token).
    .shapes_out = function(shapes_in, param_vals, task) {
      n_dims = length(shapes_in$input)
      if (n_dims != 3) {
        stop("Input tensor must have 3 dimensions.")
      }
      shapes_out = shapes_in
      shapes_out[[1]][2] = shapes_out[[1]][2] + 1
      shapes_out
    },
    # The token dimension handed to the module is the third entry of the input shape.
    .shape_dependent_params = function(shapes_in, param_vals, task) {
      param_vals[["d_token"]] = shapes_in$input[3]
      param_vals
    }
  )
)

#' @include aaa.R
register_po("nn_ft_cls", PipeOpTorchFTCLS)

R/PipeOpTorchTokenizer.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ PipeOpTorchTokenizerNum = R6Class("PipeOpTorchTokenizerNum",
6868
#' @name nn_tokenizer_num
6969
#' @description
7070
#' Tokenizes numeric features into a dense embedding.
71+
#' For an input of shape `(batch, n_features)` the output shape is `(batch, n_features, d_token)`.
7172
#' @param n_features (`integer(1)`)\cr
7273
#' The number of features.
7374
#' @param d_token (`integer(1)`)\cr
@@ -119,6 +120,7 @@ nn_tokenizer_num = nn_module(
119120
#' @name nn_tokenizer_categ
120121
#' @description
121122
#' Tokenizes categorical features into a dense embedding.
123+
#' For an input of shape `(batch, n_features)` the output shape is `(batch, n_features, d_token)`.
122124
#' @param cardinalities (`integer()`)\cr
123125
#' The number of categories for each feature.
124126
#' @param d_token (`integer(1)`)\cr

man/mlr_pipeops_nn_adaptive_avg_pool1d.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_pipeops_nn_adaptive_avg_pool2d.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_pipeops_nn_adaptive_avg_pool3d.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_pipeops_nn_avg_pool1d.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_pipeops_nn_avg_pool2d.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)