7777# ' @useDynLib snpnet, .registration=TRUE
7878# ' @export
7979snpnet  <-  function (genotype.pfile , phenotype.file , phenotype , status.col  =  NULL , covariates  =  NULL , 
80-                    split.col = NULL , family  =  NULL , configs = NULL ) {
81-   print(" Ruilin's version here!" 
80+                    split.col = NULL , family  =  NULL , p.factor = NULL , configs = NULL ) {
81+ 
82+   need.rank  <-  configs [[' rank' 
8283  validation  <-  (! is.null(split.col ))
8384  if  (configs [[' prevIter' > =  configs [[' niter' " prevIter is greater or equal to the total number of iterations." 
8485  time.start  <-  Sys.time()
8586  snpnetLogger(' Start snpnet' log.time  =  time.start )
8687
8788  snpnetLogger(' Preprocessing start..' 
88- 
89+      
8990  # ## --- Read genotype IDs --- ###
9091  ids  <-  list (); phe  <-  list ()
9192  ids [[' all' <-  readIDsFromPsam(paste0(genotype.pfile , ' .psam' 
@@ -175,6 +176,14 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
175176
176177  stats  <-  computeStats(genotype.pfile , phe [[' train' $ ID , configs  =  configs )
177178
179+   # ## --- Keep track of the lambda index at which each variant is first added to the model, if required --- ###
180+   if  (need.rank ){
181+     var.rank  <-  rep(configs [[' nlambda' + 1 , length(vars ))
182+     names(var.rank ) <-  vars 
183+   } else {
184+     var.rank  =  NULL 
185+   }
186+ 
178187  # ## --- End --- ###
179188  snpnetLoggerTimeDiff(" Preprocessing end." time.start , indent = 1 )
180189
@@ -198,6 +207,9 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
198207
199208    prod.full  <-  computeProduct(residual , genotype.pfile , vars , stats , configs , iter = 0 ) /  nrow(phe [[' train' 
200209    score  <-  abs(prod.full [, 1 ])
210+ 
211+     if  (! is.null(p.factor )){score  <-  score / p.factor [names(score )]} #  Divide the score by the penalty factor
212+       
201213    if  (configs [[' verbose' "   End computing inner product for initialization." time.prod.init.start )
202214
203215    if  (is.null(configs [[' lambda.min.ratio' 
@@ -208,6 +220,7 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
208220    lambda.idx  <-  1 
209221    num.lams  <-  configs [[" nlams.init" 
210222    features.to.keep  <-  names(glmmod $ coefficients [- 1 ])
223+ 
211224    prev.beta  <-  NULL 
212225    num.new.valid  <-  NULL   #  track number of new valid solutions every iteration, to adjust length of current lambda seq or size of additional variables
213226
@@ -261,6 +274,7 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
261274      which.in.model  <-  which(names(score ) %in%  colnames(features [[' train' 
262275      score [which.in.model ] <-  NA 
263276    }
277+     if  (! is.null(p.factor )) {score  <-  score / p.factor [names(score )]} 
264278    sorted.score  <-  sort(score , decreasing  =  T , na.last  =  NA )
265279    if  (length(sorted.score ) >  0 ) {
266280      features.to.add  <-  names(sorted.score )[1 : min(configs [[' num.snps.batch' sorted.score ))]
@@ -284,6 +298,7 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
284298      snpnetLogger(paste0(" - # newly added variables: " features.to.add ), " ." indent = 2 )
285299      snpnetLogger(paste0(" - Total # variables in the strong set: " features [[' train' " ." indent = 2 )
286300    }
301+       
287302    # ## --- Fit glmnet --- ###
288303    if  (configs [[' verbose' 
289304        if (configs [[' use.glmnetPlus' 
@@ -292,8 +307,12 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
292307            snpnetLogger(" Start fitting Glmnet ..." indent = 1 )
293308        }
294309    }
295-     penalty.factor  <-  rep(1 , ncol(features [[' train' 
296-     penalty.factor [seq_len(length(covariates ))] <-  0 
310+     if  (is.null(p.factor )){
311+       penalty.factor  <-  rep(1 , ncol(features [[' train' 
312+       penalty.factor [seq_len(length(covariates ))] <-  0       
313+     } else  {
314+       penalty.factor  <-  c(rep(0 , length(covariates )), p.factor [colnames(features [[' train' - (1 : length(covariates ))]])
315+     }
297316    current.lams  <-  full.lams [1 : num.lams ]
298317    current.lams.adjusted  <-  full.lams [1 : num.lams ] *  sum(penalty.factor ) /  length(penalty.factor )  #  adjustment to counteract penalty factor normalization in glmnet
299318    time.glmnet.start  <-  Sys.time()
@@ -306,14 +325,24 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
306325      } else  {
307326        beta0  <-  prev.beta 
308327      }
309-       glmfit  <-  glmnetPlus :: glmnet(
310-           features [[' train' response [[' train' family  =  family ,
311-           lambda  =  current.lams.adjusted [start.lams : num.lams ], penalty.factor  =  penalty.factor ,
312-           standardize  =  configs [[' standardize.variant' thresh  =  configs [[' glmnet.thresh' 
313-           type.gaussian  =  " naive" beta0  =  beta0 
328+       if (family  ==  " cox" 
329+         glmfit  <-  glmnetPlus :: glmnet(
330+                 features [[' train' surv [[' train' family  =  family ,
331+                 lambda  =  current.lams.adjusted [start.lams : num.lams ], penalty.factor  =  penalty.factor ,
332+                 standardize  =  configs [[' standardize.variant' thresh  =  configs [[' glmnet.thresh' beta0  =  beta0 
333+             )
334+         pred.train  <-  stats :: predict(glmfit , newx  =  features [[' train' 
335+         residual  <-  computeCoxgrad(pred.train , response [[' train' status [[' train' 
336+       } else  {
337+         glmfit  <-  glmnetPlus :: glmnet(
338+         features [[' train' response [[' train' family  =  family ,
339+         lambda  =  current.lams.adjusted [start.lams : num.lams ], penalty.factor  =  penalty.factor ,
340+         standardize  =  configs [[' standardize.variant' thresh  =  configs [[' glmnet.thresh' 
341+         type.gaussian  =  " naive" beta0  =  beta0 
314342      )
315343      residual  <-  glmfit $ residuals 
316344      pred.train  <-  response [[' train' -  residual 
345+       }
317346
318347    } else  {
319348        start.lams  <-  1 
@@ -349,20 +378,32 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
349378
350379    check.obj  <-  KKT.check(
351380        residual , genotype.pfile , vars , nrow(phe [[' train' 
352-         current.lams [start.lams : num.lams ],
353-         ifelse(configs [[' use.glmnetPlus' 1 , lambda.idx ),
354-         stats , glmfit , configs , iter 
381+         current.lams [start.lams : num.lams ], ifelse(configs [[' use.glmnetPlus' 1 , lambda.idx ),
382+         stats , glmfit , configs , iter , p.factor 
355383    )
356384    snpnetLogger(" KKT check obj done ..." indent = 1 )
357385    max.valid.idx  <-  check.obj [[" max.valid.idx" +  (start.lams  -  1 )  #  max valid index in the whole lambda sequence
358-     if  (lambda.idx  <  max.valid.idx ) {
386+ 
387+     #  Update the lambda index of variants added
388+     if  (need.rank  &&  check.obj [[" max.valid.idx" >  0 ){
389+       tmp  <-  1 
390+       for  (lam.idx  in  start.lams : max.valid.idx ){       
391+        current_active  <-  setdiff(names(which(glmfit $ beta [, tmp ] !=  0 )), covariates )
392+        tmp  <-  tmp  +  1 
393+        var.rank [current_active ] =  pmin(var.rank [current_active ], lam.idx )
394+      } 
395+     }
396+       
397+ 
398+       if  (lambda.idx  <  max.valid.idx ) {
359399        is.KKT.valid.for.at.least.one  <-  TRUE 
360400    } else {
361401        is.KKT.valid.for.at.least.one  <-  FALSE 
362402    }
363403    lambda.idx  <-  check.obj [[" next.lambda.idx" +  (start.lams  -  1 )
364404
365-     if  (configs [[' use.glmnetPlus' &&  check.obj [[" max.valid.idx" >  0 ) {
405+ 
406+       if  (configs [[' use.glmnetPlus' &&  check.obj [[" max.valid.idx" >  0 ) {
366407      prev.beta  <-  glmfit $ beta [, check.obj [[" max.valid.idx" 
367408      prev.beta  <-  prev.beta [prev.beta  !=  0 ]
368409    }
@@ -442,6 +483,6 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
442483  if (configs [[' verbose' 
443484
444485  out  <-  list (metric.train  =  metric.train , metric.val  =  metric.val , glmnet.results  =  glmnet.results ,
445-               full.lams  =  full.lams , a0  =  a0 , beta  =  beta , configs  =  configs )
486+               full.lams  =  full.lams , a0  =  a0 , beta  =  beta , configs  =  configs ,  var.rank = var.rank )
446487  out 
447488}
0 commit comments