77
77
# ' @useDynLib snpnet, .registration=TRUE
78
78
# ' @export
79
79
snpnet <- function (genotype.pfile , phenotype.file , phenotype , status.col = NULL , covariates = NULL ,
80
- split.col = NULL , family = NULL , configs = NULL ) {
81
- print(" Ruilin's version here!" )
80
+ split.col = NULL , family = NULL , p.factor = NULL , configs = NULL ) {
81
+
82
+ need.rank <- configs [[' rank' ]]
82
83
validation <- (! is.null(split.col ))
83
84
if (configs [[' prevIter' ]] > = configs [[' niter' ]]) stop(" prevIter is greater or equal to the total number of iterations." )
84
85
time.start <- Sys.time()
85
86
snpnetLogger(' Start snpnet' , log.time = time.start )
86
87
87
88
snpnetLogger(' Preprocessing start..' )
88
-
89
+
89
90
# ## --- Read genotype IDs --- ###
90
91
ids <- list (); phe <- list ()
91
92
ids [[' all' ]] <- readIDsFromPsam(paste0(genotype.pfile , ' .psam' ))
@@ -175,6 +176,14 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
175
176
176
177
stats <- computeStats(genotype.pfile , phe [[' train' ]]$ ID , configs = configs )
177
178
179
+ # ## --- Keep track of the lambda index at which each variant is first added to the model, if required --- ###
180
+ if (need.rank ){
181
+ var.rank <- rep(configs [[' nlambda' ]]+ 1 , length(vars ))
182
+ names(var.rank ) <- vars
183
+ } else {
184
+ var.rank = NULL
185
+ }
186
+
178
187
# ## --- End --- ###
179
188
snpnetLoggerTimeDiff(" Preprocessing end." , time.start , indent = 1 )
180
189
@@ -198,6 +207,9 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
198
207
199
208
prod.full <- computeProduct(residual , genotype.pfile , vars , stats , configs , iter = 0 ) / nrow(phe [[' train' ]])
200
209
score <- abs(prod.full [, 1 ])
210
+
211
+ if (! is.null(p.factor )){score <- score / p.factor [names(score )]} # Divide the score by the penalty factor
212
+
201
213
if (configs [[' verbose' ]]) snpnetLoggerTimeDiff(" End computing inner product for initialization." , time.prod.init.start )
202
214
203
215
if (is.null(configs [[' lambda.min.ratio' ]])) {
@@ -208,6 +220,7 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
208
220
lambda.idx <- 1
209
221
num.lams <- configs [[" nlams.init" ]]
210
222
features.to.keep <- names(glmmod $ coefficients [- 1 ])
223
+
211
224
prev.beta <- NULL
212
225
num.new.valid <- NULL # track number of new valid solutions every iteration, to adjust length of current lambda seq or size of additional variables
213
226
@@ -261,6 +274,7 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
261
274
which.in.model <- which(names(score ) %in% colnames(features [[' train' ]]))
262
275
score [which.in.model ] <- NA
263
276
}
277
+ if (! is.null(p.factor )) {score <- score / p.factor [names(score )]}
264
278
sorted.score <- sort(score , decreasing = T , na.last = NA )
265
279
if (length(sorted.score ) > 0 ) {
266
280
features.to.add <- names(sorted.score )[1 : min(configs [[' num.snps.batch' ]], length(sorted.score ))]
@@ -284,6 +298,7 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
284
298
snpnetLogger(paste0(" - # newly added variables: " , length(features.to.add ), " ." ), indent = 2 )
285
299
snpnetLogger(paste0(" - Total # variables in the strong set: " , ncol(features [[' train' ]]), " ." ), indent = 2 )
286
300
}
301
+
287
302
# ## --- Fit glmnet --- ###
288
303
if (configs [[' verbose' ]]){
289
304
if (configs [[' use.glmnetPlus' ]]){
@@ -292,8 +307,12 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
292
307
snpnetLogger(" Start fitting Glmnet ..." , indent = 1 )
293
308
}
294
309
}
295
- penalty.factor <- rep(1 , ncol(features [[' train' ]]))
296
- penalty.factor [seq_len(length(covariates ))] <- 0
310
+ if (is.null(p.factor )){
311
+ penalty.factor <- rep(1 , ncol(features [[' train' ]]))
312
+ penalty.factor [seq_len(length(covariates ))] <- 0
313
+ } else {
314
+ penalty.factor <- c(rep(0 , length(covariates )), p.factor [colnames(features [[' train' ]])[- (1 : length(covariates ))]])
315
+ }
297
316
current.lams <- full.lams [1 : num.lams ]
298
317
current.lams.adjusted <- full.lams [1 : num.lams ] * sum(penalty.factor ) / length(penalty.factor ) # adjustment to counteract penalty factor normalization in glmnet
299
318
time.glmnet.start <- Sys.time()
@@ -306,14 +325,24 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
306
325
} else {
307
326
beta0 <- prev.beta
308
327
}
309
- glmfit <- glmnetPlus :: glmnet(
310
- features [[' train' ]], response [[' train' ]], family = family ,
311
- lambda = current.lams.adjusted [start.lams : num.lams ], penalty.factor = penalty.factor ,
312
- standardize = configs [[' standardize.variant' ]], thresh = configs [[' glmnet.thresh' ]],
313
- type.gaussian = " naive" , beta0 = beta0
328
+ if (family == " cox" ){
329
+ glmfit <- glmnetPlus :: glmnet(
330
+ features [[' train' ]], surv [[' train' ]], family = family ,
331
+ lambda = current.lams.adjusted [start.lams : num.lams ], penalty.factor = penalty.factor ,
332
+ standardize = configs [[' standardize.variant' ]], thresh = configs [[' glmnet.thresh' ]], beta0 = beta0
333
+ )
334
+ pred.train <- stats :: predict(glmfit , newx = features [[' train' ]])
335
+ residual <- computeCoxgrad(pred.train , response [[' train' ]], status [[' train' ]])
336
+ } else {
337
+ glmfit <- glmnetPlus :: glmnet(
338
+ features [[' train' ]], response [[' train' ]], family = family ,
339
+ lambda = current.lams.adjusted [start.lams : num.lams ], penalty.factor = penalty.factor ,
340
+ standardize = configs [[' standardize.variant' ]], thresh = configs [[' glmnet.thresh' ]],
341
+ type.gaussian = " naive" , beta0 = beta0
314
342
)
315
343
residual <- glmfit $ residuals
316
344
pred.train <- response [[' train' ]] - residual
345
+ }
317
346
318
347
} else {
319
348
start.lams <- 1
@@ -349,20 +378,32 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
349
378
350
379
check.obj <- KKT.check(
351
380
residual , genotype.pfile , vars , nrow(phe [[' train' ]]),
352
- current.lams [start.lams : num.lams ],
353
- ifelse(configs [[' use.glmnetPlus' ]], 1 , lambda.idx ),
354
- stats , glmfit , configs , iter
381
+ current.lams [start.lams : num.lams ], ifelse(configs [[' use.glmnetPlus' ]], 1 , lambda.idx ),
382
+ stats , glmfit , configs , iter , p.factor
355
383
)
356
384
snpnetLogger(" KKT check obj done ..." , indent = 1 )
357
385
max.valid.idx <- check.obj [[" max.valid.idx" ]] + (start.lams - 1 ) # max valid index in the whole lambda sequence
358
- if (lambda.idx < max.valid.idx ) {
386
+
387
+ # Update the lambda index of variants added
388
+ if (need.rank && check.obj [[" max.valid.idx" ]] > 0 ){
389
+ tmp <- 1
390
+ for (lam.idx in start.lams : max.valid.idx ){
391
+ current_active <- setdiff(names(which(glmfit $ beta [, tmp ] != 0 )), covariates )
392
+ tmp <- tmp + 1
393
+ var.rank [current_active ] = pmin(var.rank [current_active ], lam.idx )
394
+ }
395
+ }
396
+
397
+
398
+ if (lambda.idx < max.valid.idx ) {
359
399
is.KKT.valid.for.at.least.one <- TRUE
360
400
} else {
361
401
is.KKT.valid.for.at.least.one <- FALSE
362
402
}
363
403
lambda.idx <- check.obj [[" next.lambda.idx" ]] + (start.lams - 1 )
364
404
365
- if (configs [[' use.glmnetPlus' ]] && check.obj [[" max.valid.idx" ]] > 0 ) {
405
+
406
+ if (configs [[' use.glmnetPlus' ]] && check.obj [[" max.valid.idx" ]] > 0 ) {
366
407
prev.beta <- glmfit $ beta [, check.obj [[" max.valid.idx" ]]]
367
408
prev.beta <- prev.beta [prev.beta != 0 ]
368
409
}
@@ -442,6 +483,6 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
442
483
if (configs [[' verbose' ]]) print(gc())
443
484
444
485
out <- list (metric.train = metric.train , metric.val = metric.val , glmnet.results = glmnet.results ,
445
- full.lams = full.lams , a0 = a0 , beta = beta , configs = configs )
486
+ full.lams = full.lams , a0 = a0 , beta = beta , configs = configs , var.rank = var.rank )
446
487
out
447
488
}
0 commit comments