Skip to content

Commit b51f79a

Browse files
committed
Remove conline::output()
1 parent f9762f2 commit b51f79a

File tree

4 files changed

+118
-76
lines changed

4 files changed

+118
-76
lines changed

R/online.R

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -456,15 +456,61 @@ online <- function(y, experts, tau,
456456
model_instance$learn()
457457

458458
# Generate output
459-
model <- model_instance$output()
459+
model <- list(
460+
predictions = model_instance$predictions,
461+
predictions_got_sorted = model_instance$predictions_got_sorted,
462+
weights = model_instance$weights,
463+
forecaster_loss = model_instance$loss_for,
464+
experts_loss = model_instance$loss_exp,
465+
past_performance = model_instance$past_performance,
466+
opt_index = model_instance$opt_index + 1, # Respect one-based indexing
467+
parametergrid = model_instance$params,
468+
params_basis_pr = model_instance$params_basis_pr,
469+
params_basis_mv = model_instance$params_basis_mv,
470+
params_hat_pr = model_instance$params_hat_pr,
471+
params_hat_mv = model_instance$params_hat_mv
472+
)
460473

461-
model$specification[["data"]] <-
474+
model[["specification"]] <-
462475
list(
463-
y = model_instance$y,
464-
experts = model_instance$experts,
465-
tau = model_instance$tau
476+
data =
477+
list(
478+
y = model_instance$y,
479+
experts = model_instance$experts,
480+
tau = model_instance$tau
481+
),
482+
objects =
483+
list(
484+
weights_tmp = model_instance$weights_tmp,
485+
predictions_grid = model_instance$predictions_grid,
486+
cum_performance = model_instance$cum_performance,
487+
hat_pr = model_instance$hat_pr,
488+
hat_mv = model_instance$hat_mv,
489+
basis_pr = model_instance$basis_pr,
490+
basis_mv = model_instance$basis_mv,
491+
V = model_instance$V,
492+
E = model_instance$E,
493+
eta = model_instance$eta,
494+
R = model_instance$R,
495+
beta = model_instance$beta,
496+
beta0field = model_instance$beta0field
497+
),
498+
parameters =
499+
list(
500+
lead_time = model_instance$lead_time,
501+
loss_function = model_instance$loss_function,
502+
loss_parameter = model_instance$loss_parameter,
503+
loss_gradient = model_instance$loss_gradient,
504+
method = model_instance$method,
505+
forget_past_performance = model_instance$forget_past_performance,
506+
allow_quantile_crossing = model_instance$allow_quantile_crossing,
507+
save_past_performance = model_instance$save_past_performance,
508+
save_predictions_grid = model_instance$save_predictions_grid
509+
)
466510
)
467511

512+
attr(model, "class") <- c("online", "list")
513+
468514
model_instance$teardown()
469515
rm(model_instance)
470516
model <- post_process_model(model, names)

R/online_update.R

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,62 @@ update.online <- function(object,
7171
new_experts
7272
)
7373
model_instance$learn()
74-
object <- model_instance$output()
7574

76-
object$specification[["data"]] <-
75+
object <- list(
76+
predictions = model_instance$predictions,
77+
predictions_got_sorted = model_instance$predictions_got_sorted,
78+
weights = model_instance$weights,
79+
forecaster_loss = model_instance$loss_for,
80+
experts_loss = model_instance$loss_exp,
81+
past_performance = model_instance$past_performance,
82+
opt_index = model_instance$opt_index + 1, # Respect one-based indexing
83+
parametergrid = model_instance$params,
84+
params_basis_pr = model_instance$params_basis_pr,
85+
params_basis_mv = model_instance$params_basis_mv,
86+
params_hat_pr = model_instance$params_hat_pr,
87+
params_hat_mv = model_instance$params_hat_mv
88+
)
89+
90+
object[["specification"]] <-
7791
list(
78-
y = model_instance$y,
79-
experts = model_instance$experts,
80-
tau = model_instance$tau
92+
data =
93+
list(
94+
y = model_instance$y,
95+
experts = model_instance$experts,
96+
tau = model_instance$tau
97+
),
98+
objects =
99+
list(
100+
weights_tmp = model_instance$weights_tmp,
101+
predictions_grid = model_instance$predictions_grid,
102+
cum_performance = model_instance$cum_performance,
103+
hat_pr = model_instance$hat_pr,
104+
hat_mv = model_instance$hat_mv,
105+
basis_pr = model_instance$basis_pr,
106+
basis_mv = model_instance$basis_mv,
107+
V = model_instance$V,
108+
E = model_instance$E,
109+
eta = model_instance$eta,
110+
R = model_instance$R,
111+
beta = model_instance$beta,
112+
beta0field = model_instance$beta0field
113+
),
114+
parameters =
115+
list(
116+
lead_time = model_instance$lead_time,
117+
loss_function = model_instance$loss_function,
118+
loss_parameter = model_instance$loss_parameter,
119+
loss_gradient = model_instance$loss_gradient,
120+
method = model_instance$method,
121+
forget_past_performance = model_instance$forget_past_performance,
122+
allow_quantile_crossing = model_instance$allow_quantile_crossing,
123+
save_past_performance = model_instance$save_past_performance,
124+
save_predictions_grid = model_instance$save_predictions_grid
125+
)
81126
)
82127

128+
attr(object, "class") <- c("online", "list")
129+
83130
model_instance$teardown()
84131
rm(model_instance)
85132
object <- post_process_model(model = object, names = names)

src/conline.cpp

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -551,71 +551,6 @@ void conline::learn()
551551
clock.tock("core");
552552
}
553553

554-
Rcpp::List conline::output()
555-
{
556-
clock.tick("wrangle");
557-
558-
// 1-Indexing for R-Output
559-
opt_index += 1;
560-
561-
Rcpp::List model_data = Rcpp::List::create(
562-
Rcpp::Named("y") = y,
563-
Rcpp::Named("experts") = experts,
564-
Rcpp::Named("tau") = tau);
565-
566-
Rcpp::List model_parameters = Rcpp::List::create(
567-
Rcpp::Named("lead_time") = lead_time,
568-
Rcpp::Named("loss_function") = loss_function,
569-
Rcpp::Named("loss_parameter") = loss_parameter,
570-
Rcpp::Named("loss_gradient") = loss_gradient,
571-
Rcpp::Named("method") = method,
572-
Rcpp::Named("forget_past_performance") = forget_past_performance,
573-
Rcpp::Named("allow_quantile_crossing") = allow_quantile_crossing,
574-
Rcpp::Named("save_past_performance") = save_past_performance,
575-
Rcpp::Named("save_predictions_grid") = save_predictions_grid);
576-
577-
Rcpp::List model_objects = Rcpp::List::create(
578-
Rcpp::Named("weights_tmp") = weights_tmp,
579-
Rcpp::Named("predictions_grid") = predictions_grid,
580-
Rcpp::Named("cum_performance") = cum_performance,
581-
Rcpp::Named("hat_pr") = hat_pr,
582-
Rcpp::Named("hat_mv") = hat_mv,
583-
Rcpp::Named("basis_pr") = basis_pr,
584-
Rcpp::Named("basis_mv") = basis_mv,
585-
Rcpp::Named("V") = V,
586-
Rcpp::Named("E") = E,
587-
Rcpp::Named("eta") = eta,
588-
Rcpp::Named("R") = R,
589-
Rcpp::Named("beta") = beta,
590-
Rcpp::Named("beta0field") = beta0field);
591-
592-
Rcpp::List model_spec = Rcpp::List::create(
593-
// Rcpp::Named("data") = model_data,
594-
Rcpp::Named("parameters") = model_parameters,
595-
Rcpp::Named("objects") = model_objects);
596-
597-
Rcpp::List out = Rcpp::List::create(
598-
Rcpp::Named("predictions") = predictions,
599-
Rcpp::Named("predictions_got_sorted") = predictions_got_sorted,
600-
Rcpp::Named("weights") = weights,
601-
Rcpp::Named("forecaster_loss") = loss_for,
602-
Rcpp::Named("experts_loss") = loss_exp,
603-
Rcpp::Named("past_performance") = past_performance,
604-
Rcpp::Named("opt_index") = opt_index,
605-
Rcpp::Named("parametergrid") = params,
606-
Rcpp::Named("params_basis_pr") = params_basis_pr,
607-
Rcpp::Named("params_basis_mv") = params_basis_mv,
608-
Rcpp::Named("params_hat_pr") = params_hat_pr,
609-
Rcpp::Named("params_hat_mv") = params_hat_mv,
610-
Rcpp::Named("specification") = model_spec);
611-
612-
out.attr("class") = "online";
613-
614-
clock.tock("wrangle");
615-
// Rcpp::List out;
616-
return out;
617-
}
618-
619554
void conline::init_update(
620555
Rcpp::List &object,
621556
arma::mat &new_y,

src/conline_exports.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@ RCPP_MODULE(conlineEx)
1818
.field("loss_gradient", &conline::loss_gradient)
1919
.field("method", &conline::method)
2020
.field("save_past_performance", &conline::save_past_performance)
21+
.field("past_performance", &conline::past_performance)
22+
.field("cum_performance", &conline::cum_performance)
2123
.field("save_predictions_grid", &conline::save_predictions_grid)
24+
.field("predictions_grid", &conline::predictions_grid)
25+
.field("predictions", &conline::predictions)
26+
.field("predictions_got_sorted", &conline::predictions_got_sorted)
2227
.field("forget_past_performance", &conline::forget_past_performance)
2328
.field("allow_quantile_crossing", &conline::allow_quantile_crossing)
2429
.field("trace", &conline::trace)
@@ -27,19 +32,28 @@ RCPP_MODULE(conlineEx)
2732
.field("hat_pr", &conline::hat_pr)
2833
.field("hat_mv", &conline::hat_mv)
2934
.field("w0", &conline::w0)
35+
.field("beta0field", &conline::beta0field)
36+
.field("beta", &conline::beta)
3037
.field("weights", &conline::weights)
38+
.field("weights_tmp", &conline::weights_tmp)
3139
.field("R0", &conline::R0)
40+
.field("V", &conline::V)
41+
.field("E", &conline::E)
42+
.field("R", &conline::R)
43+
.field("loss_exp", &conline::loss_exp)
44+
.field("loss_for", &conline::loss_for)
45+
.field("eta", &conline::eta)
3246
.field("params", &conline::params)
3347
.field("params_basis_pr", &conline::params_basis_pr)
3448
.field("params_basis_mv", &conline::params_basis_mv)
3549
.field("params_hat_pr", &conline::params_hat_pr)
3650
.field("params_hat_mv", &conline::params_hat_mv)
51+
.field("opt_index", &conline::opt_index)
3752
.field("loss_array", &conline::loss_array)
3853
.field("regret_array", &conline::regret_array)
3954
.method("set_defaults", &conline::set_defaults)
4055
.method("set_grid_objects", &conline::set_grid_objects)
4156
.method("learn", &conline::learn)
42-
.method("output", &conline::output)
4357
.method("init_update", &conline::init_update)
4458
.method("getT", &conline::getT)
4559
.method("getD", &conline::getD)

0 commit comments

Comments
 (0)