@@ -149,3 +149,143 @@ impl LassoOutput<f64> {
149149 Ok ( predictions)
150150 }
151151}
152+
153+
154+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
155+ // UNIT TESTS
156+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
157+
158+ #[ cfg( test) ]
159+ mod tests_lasso_regression {
160+ use super :: * ;
161+ use RustQuant_utils :: assert_approx_equal;
162+
163+ struct DataForTests {
164+ training_set : DMatrix < f64 > ,
165+ testing_set : DMatrix < f64 > ,
166+ response : DVector < f64 > ,
167+ }
168+
169+ fn setup_test ( ) -> DataForTests {
170+ DataForTests {
171+ training_set : DMatrix :: from_row_slice (
172+ 4 ,
173+ 3 ,
174+ & [
175+ -0.083_784_355 , -0.633_485_70 , -0.399_266_60 ,
176+ -0.982_943_745 , 1.090_797_46 , -0.468_123_05 ,
177+ -1.875_067_321 , -0.913_727_27 , 0.326_962_08 ,
178+ -0.186_144_661 , 1.001_639_71 , -0.412_746_90 ] ,
179+ ) ,
180+
181+ testing_set : DMatrix :: from_row_slice (
182+ 4 ,
183+ 3 ,
184+ & [
185+ 0.562_036_47 , 0.595_846_45 , -0.411_653_01 ,
186+ 0.663_358_26 , 0.452_091_83 , -0.294_327_15 ,
187+ -0.602_897_28 , 0.896_743_96 , 1.218_573_96 ,
188+ 0.698_377_69 , 0.572_216_51 , 0.244_111_43 ] ,
189+ ) ,
190+
191+ response : DVector :: from_row_slice (
192+ & [
193+ -0.445_151_96 ,
194+ -1.847_803_64 ,
195+ -0.628_825_31 ,
196+ -0.861_080_69
197+ ]
198+ ) ,
199+ }
200+ }
201+
202+ #[ test]
203+ fn test_lasso_without_intercept ( ) -> Result < ( ) , RustQuantError > {
204+
205+ let data: DataForTests = setup_test ( ) ;
206+
207+ let input: LassoInput < f64 > = LassoInput {
208+ x : data. training_set ,
209+ y : data. response ,
210+ lambda : 0.01 ,
211+ fit_intercept : false ,
212+ max_iter : 1000 ,
213+ tolerance : 1e-4 ,
214+ } ;
215+
216+ let output: LassoOutput < f64 > = input. fit ( ) ?;
217+ let predictions = output. predict ( data. testing_set ) ?;
218+
219+ for ( i, coefficient) in output. coefficients . iter ( ) . enumerate ( ) {
220+ assert_approx_equal ! (
221+ coefficient,
222+ & [
223+ 0.0 ,
224+ 0.743_965_706_491_596_7 ,
225+ -0.304_713_846_510_641_43 ,
226+ 1.355_162_653_724_116_22 ,
227+ ] [ i] ,
228+ f64 :: EPSILON
229+ ) ;
230+ }
231+
232+ for ( i, pred) in predictions. iter ( ) . enumerate ( ) {
233+ assert_approx_equal ! (
234+ pred,
235+ & [
236+ -0.321_283_589_676_737_6 ,
237+ -0.04310400559445471 ,
238+ 0.9295807191488583 ,
239+ 0.6760174510230131
240+ ] [ i] ,
241+ f64 :: EPSILON
242+ ) ;
243+ }
244+ Ok ( ( ) )
245+ }
246+
247+ #[ test]
248+ fn test_lasso_with_intercept ( ) -> Result < ( ) , RustQuantError > {
249+
250+ let data: DataForTests = setup_test ( ) ;
251+
252+ let input: LassoInput < f64 > = LassoInput {
253+ x : data. training_set ,
254+ y : data. response ,
255+ lambda : 0.01 ,
256+ fit_intercept : true ,
257+ max_iter : 1000 ,
258+ tolerance : 1e-4 ,
259+ } ;
260+
261+ let output: LassoOutput < f64 > = input. fit ( ) ?;
262+ let predictions = output. predict ( data. testing_set ) ?;
263+
264+ for ( i, coefficient) in output. coefficients . iter ( ) . enumerate ( ) {
265+ assert_approx_equal ! (
266+ coefficient,
267+ & [
268+ 0.009_633_706_736_496_328 ,
269+ 0.750_479_303_541_854_1 ,
270+ -0.301_997_087_876_784_5 ,
271+ 1.373_605_833_196_545_3 ,
272+ ] [ i] ,
273+ f64 :: EPSILON
274+ ) ;
275+ }
276+
277+ for ( i, pred) in predictions. iter ( ) . enumerate ( ) {
278+ assert_approx_equal ! (
279+ pred,
280+ & [
281+ -0.313_962_423_203_417_3 ,
282+ -0.033_349_554_520_968_38 ,
283+ 0.960_198_011_081_136_2 ,
284+ 0.696_256_873_679_798_4 ,
285+ ] [ i] ,
286+ f64 :: EPSILON
287+ ) ;
288+ }
289+ Ok ( ( ) )
290+ }
291+ }
0 commit comments