Skip to content

Commit a5974b4

Browse files
committed
Lasso: Create unit tests
1 parent c07bb50 commit a5974b4

File tree

1 file changed

+140
-0
lines changed

1 file changed

+140
-0
lines changed

crates/RustQuant_ml/src/lasso.rs

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,143 @@ impl LassoOutput<f64> {
149149
Ok(predictions)
150150
}
151151
}
152+
153+
154+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
155+
// UNIT TESTS
156+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
157+
158+
#[cfg(test)]
159+
mod tests_lasso_regression {
160+
use super::*;
161+
use RustQuant_utils::assert_approx_equal;
162+
163+
struct DataForTests {
164+
training_set: DMatrix<f64>,
165+
testing_set: DMatrix<f64>,
166+
response: DVector<f64>,
167+
}
168+
169+
fn setup_test() -> DataForTests {
170+
DataForTests {
171+
training_set: DMatrix::from_row_slice(
172+
4,
173+
3,
174+
&[
175+
-0.083_784_355, -0.633_485_70, -0.399_266_60,
176+
-0.982_943_745, 1.090_797_46, -0.468_123_05,
177+
-1.875_067_321, -0.913_727_27, 0.326_962_08,
178+
-0.186_144_661, 1.001_639_71, -0.412_746_90],
179+
),
180+
181+
testing_set: DMatrix::from_row_slice(
182+
4,
183+
3,
184+
&[
185+
0.562_036_47, 0.595_846_45, -0.411_653_01,
186+
0.663_358_26, 0.452_091_83, -0.294_327_15,
187+
-0.602_897_28, 0.896_743_96, 1.218_573_96,
188+
0.698_377_69, 0.572_216_51, 0.244_111_43],
189+
),
190+
191+
response: DVector::from_row_slice(
192+
&[
193+
-0.445_151_96,
194+
-1.847_803_64,
195+
-0.628_825_31,
196+
-0.861_080_69
197+
]
198+
),
199+
}
200+
}
201+
202+
#[test]
203+
fn test_lasso_without_intercept() -> Result<(), RustQuantError> {
204+
205+
let data: DataForTests = setup_test();
206+
207+
let input: LassoInput<f64> = LassoInput {
208+
x: data.training_set,
209+
y: data.response,
210+
lambda: 0.01,
211+
fit_intercept: false,
212+
max_iter: 1000,
213+
tolerance: 1e-4,
214+
};
215+
216+
let output: LassoOutput<f64> = input.fit()?;
217+
let predictions = output.predict(data.testing_set)?;
218+
219+
for (i, coefficient) in output.coefficients.iter().enumerate() {
220+
assert_approx_equal!(
221+
coefficient,
222+
&[
223+
0.0,
224+
0.743_965_706_491_596_7,
225+
-0.304_713_846_510_641_43,
226+
1.355_162_653_724_116_22,
227+
][i],
228+
f64::EPSILON
229+
);
230+
}
231+
232+
for (i, pred) in predictions.iter().enumerate() {
233+
assert_approx_equal!(
234+
pred,
235+
&[
236+
-0.321_283_589_676_737_6,
237+
-0.04310400559445471,
238+
0.9295807191488583,
239+
0.6760174510230131
240+
][i],
241+
f64::EPSILON
242+
);
243+
}
244+
Ok(())
245+
}
246+
247+
#[test]
248+
fn test_lasso_with_intercept() -> Result<(), RustQuantError> {
249+
250+
let data: DataForTests = setup_test();
251+
252+
let input: LassoInput<f64> = LassoInput {
253+
x: data.training_set,
254+
y: data.response,
255+
lambda: 0.01,
256+
fit_intercept: true,
257+
max_iter: 1000,
258+
tolerance: 1e-4,
259+
};
260+
261+
let output: LassoOutput<f64> = input.fit()?;
262+
let predictions = output.predict(data.testing_set)?;
263+
264+
for (i, coefficient) in output.coefficients.iter().enumerate() {
265+
assert_approx_equal!(
266+
coefficient,
267+
&[
268+
0.009_633_706_736_496_328,
269+
0.750_479_303_541_854_1,
270+
-0.301_997_087_876_784_5,
271+
1.373_605_833_196_545_3,
272+
][i],
273+
f64::EPSILON
274+
);
275+
}
276+
277+
for (i, pred) in predictions.iter().enumerate() {
278+
assert_approx_equal!(
279+
pred,
280+
&[
281+
-0.313_962_423_203_417_3,
282+
-0.033_349_554_520_968_38,
283+
0.960_198_011_081_136_2,
284+
0.696_256_873_679_798_4,
285+
][i],
286+
f64::EPSILON
287+
);
288+
}
289+
Ok(())
290+
}
291+
}

0 commit comments

Comments
 (0)