Skip to content

Commit

Permalink
add vignette on fine-tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
MMenchero committed Dec 19, 2024
1 parent 1386791 commit c2ee9cc
Show file tree
Hide file tree
Showing 13 changed files with 917 additions and 18 deletions.
6 changes: 3 additions & 3 deletions R/nixtla_client_cross_validation.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
#' @param quantiles Quantiles to forecast. Should be between 0 and 1.
#' @param n_windows Number of windows to evaluate.
#' @param step_size Step size between each cross validation window. If NULL, it will equal the forecast horizon (h).
#' @param finetune_steps Number of steps used to finetune 'TimeGPT' in the new data.
#' @param finetune_depth The depth of the fine-tuning. Accepts an integer value from 1 to 5, where 1 (default) means little fine-tuning and 5 means that the entire model is fine-tuned to your dataset.
#' @param finetune_loss Loss function to use for finetuning. Options are: "default", "mae", "mse", "rmse", "mape", and "smape".
#' @param finetune_steps Number of steps used to fine-tune 'TimeGPT' in the new data.
#' @param finetune_depth The depth of the fine-tuning. Uses a scale from 1 to 5, where 1 means little fine-tuning and 5 means that the entire model is fine-tuned.
#' @param finetune_loss Loss function to use for fine-tuning. Options are: "default", "mae", "mse", "rmse", "mape", and "smape".
#' @param clean_ex_first Clean exogenous signal before making the forecasts using 'TimeGPT'.
#' @param model Model to use, either "timegpt-1" or "timegpt-1-long-horizon". Use "timegpt-1-long-horizon" if you want to forecast more than one seasonal period given the frequency of the data.
#'
Expand Down
6 changes: 3 additions & 3 deletions R/nixtla_client_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
#' @param X_df A tsibble or a data frame with future exogenous variables.
#' @param level The confidence levels (0-100) for the prediction intervals.
#' @param quantiles Quantiles to forecast. Should be between 0 and 1.
#' @param finetune_steps Number of steps used to finetune 'TimeGPT' in the new data.
#' @param finetune_depth The depth of the fine-tuning. Accepts an integer value from 1 to 5, where 1 (default) means little fine-tuning and 5 means that the entire model is fine-tuned to your dataset.
#' @param finetune_loss Loss function to use for finetuning. Options are: "default", "mae", "mse", "rmse", "mape", and "smape".
#' @param finetune_steps Number of steps used to fine-tune 'TimeGPT' in the new data.
#' @param finetune_depth The depth of the fine-tuning. Uses a scale from 1 to 5, where 1 means little fine-tuning and 5 means that the entire model is fine-tuned.
#' @param finetune_loss Loss function to use for fine-tuning. Options are: "default", "mae", "mse", "rmse", "mape", and "smape".
#' @param clean_ex_first Clean exogenous signal before making the forecasts using 'TimeGPT'.
#' @param add_history Return fitted values of the model.
#' @param model Model to use, either "timegpt-1" or "timegpt-1-long-horizon". Use "timegpt-1-long-horizon" if you want to forecast more than one seasonal period given the frequency of the data.
Expand Down
6 changes: 3 additions & 3 deletions R/nixtla_client_historic.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
#' @param target_col Column that contains the target variable.
#' @param level The confidence levels (0-100) for the prediction intervals.
#' @param quantiles Quantiles to forecast. Should be between 0 and 1.
#' @param finetune_steps Number of steps used to finetune 'TimeGPT' in the new data.
#' @param finetune_depth The depth of the fine-tuning. Accepts an integer value from 1 to 5, where 1 (default) means little fine-tuning and 5 means that the entire model is fine-tuned to your dataset.
#' @param finetune_loss Loss function to use for finetuning. Options are: "default", "mae", "mse", "rmse", "mape", and "smape".
#' @param finetune_steps Number of steps used to fine-tune 'TimeGPT' in the new data.
#' @param finetune_depth The depth of the fine-tuning. Uses a scale from 1 to 5, where 1 means little fine-tuning and 5 means that the entire model is fine-tuned.
#' @param finetune_loss Loss function to use for fine-tuning. Options are: "default", "mae", "mse", "rmse", "mape", and "smape".
#' @param clean_ex_first Clean exogenous signal before making the forecasts using 'TimeGPT'.
#' @param model Model to use, either "timegpt-1" or "timegpt-1-long-horizon". Use "timegpt-1-long-horizon" if you want to forecast more than one seasonal period given the frequency of the data.
#'
Expand Down
6 changes: 3 additions & 3 deletions man/nixtla_client_cross_validation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/nixtla_client_forecast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/nixtla_client_historic.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

131 changes: 131 additions & 0 deletions tests/mocks/api.nixtla.io/v2/forecast-038639-POST.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
{
"input_tokens": 600,
"output_tokens": 120,
"finetune_tokens": 0,
"mean": [
38.820404,
36.293564,
34.978855,
32.99633,
31.583813,
33.274483,
36.975334,
42.980255,
47.18004,
49.38613,
50.58629,
50.16604,
49.604767,
48.710793,
47.50801,
49.641068,
52.359497,
58.20144,
58.857323,
55.131165,
49.567757,
45.881218,
44.496113,
43.748726,
10.04649,
6.3477135,
3.8844414,
2.3206596,
2.5391273,
4.168751,
4.5570107,
7.487301,
9.367214,
11.309727,
11.275341,
11.971142,
11.837162,
12.7635765,
13.369236,
14.58382,
15.6118355,
15.449764,
15.033131,
14.167515,
13.30592,
12.907833,
10.776638,
7.633236,
52.749832,
49.590576,
46.05893,
41.110207,
39.65557,
42.497383,
48.12867,
54.479225,
58.289307,
60.86104,
61.56459,
60.438133,
59.821243,
56.35112,
54.18071,
55.2239,
56.647575,
64.53344,
67.05695,
65.62052,
61.02072,
56.786224,
56.03811,
54.446175,
49.7386,
49.122772,
48.87745,
48.59867,
49.343666,
51.27717,
53.082943,
56.1737,
57.441006,
57.917213,
58.155327,
58.4386,
58.29986,
58.309822,
58.473557,
59.172974,
59.360233,
59.138382,
57.66301,
55.740982,
54.20776,
53.051807,
51.7741,
50.458515,
22.853998,
22.351372,
21.971369,
22.012585,
22.484016,
22.69992,
26.459377,
31.47617,
31.518444,
30.856718,
30.465942,
29.211628,
27.852182,
27.181736,
26.28147,
26.776318,
28.727142,
35.49711,
35.459,
33.863174,
32.520737,
29.762478,
27.170156,
24.734856
],
"intervals": null,
"weights_x": null,
"feature_contributions": null,
"request_id": "JEAB9GZ9KA"
}
131 changes: 131 additions & 0 deletions tests/mocks/api.nixtla.io/v2/forecast-dd4308-POST.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
{
"input_tokens": 8280,
"output_tokens": 120,
"finetune_tokens": 500,
"mean": [
39.33689,
36.81118,
35.862137,
34.178474,
33.005707,
34.602592,
38.382442,
45.29547,
49.116364,
51.709076,
52.31858,
51.4066,
50.5456,
49.365852,
48.270195,
50.595474,
53.19593,
59.678688,
60.51818,
57.22879,
51.752975,
47.76425,
46.002213,
44.784706,
10.026592,
6.836117,
4.517971,
3.9740334,
3.70512,
5.399822,
6.174061,
9.923893,
12.239056,
14.2149315,
13.17947,
14.263432,
13.346569,
14.199467,
14.601391,
16.600262,
19.36082,
20.515034,
21.571346,
21.480137,
21.020634,
20.655666,
18.897038,
16.328136,
52.39077,
49.435287,
46.650024,
43.17102,
42.517525,
45.289276,
51.680283,
58.49974,
62.696686,
64.67974,
63.97027,
62.087547,
60.991127,
57.563072,
55.60578,
56.84146,
58.360245,
66.88493,
69.96444,
69.03388,
64.37969,
59.964928,
58.9301,
56.773148,
49.328724,
48.39244,
47.72139,
47.343994,
47.672806,
48.544514,
49.517147,
51.288704,
52.273098,
53.232426,
53.966618,
54.523796,
54.51413,
54.70507,
55.0904,
55.68209,
56.167664,
56.07667,
55.04934,
53.626564,
52.309814,
51.268036,
50.12602,
48.925766,
23.074203,
22.41476,
21.85689,
22.067892,
22.626034,
22.7155,
26.250298,
31.046175,
31.589384,
31.64193,
31.067324,
29.933369,
28.635342,
27.827301,
26.68113,
27.277428,
29.669315,
35.96617,
36.442406,
35.202442,
33.239815,
30.23969,
27.701052,
25.199667
],
"intervals": null,
"weights_x": null,
"feature_contributions": null,
"request_id": "AHD4U6ADNU"
}
Loading

0 comments on commit c2ee9cc

Please sign in to comment.