Skip to content

Commit fc09111

Browse files
authored
Merge pull request #1 from rivas-lab/Cox
Include Cox model and Chang C-index support for Snpnet.
2 parents 574e10d + 3b8dc05 commit fc09111

File tree

4 files changed

+67
-23
lines changed

4 files changed

+67
-23
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Description: Fitting the entire lasso solution path on large-scale genotype-phen
1111
License: GPL-2
1212
Encoding: UTF-8
1313
LazyData: true
14-
Imports: BGData (>= 2.0.0), ROCR, Rcpp, crochet, utils, stats, data.table, methods
14+
Imports: BGData (>= 2.0.0), ROCR, Rcpp, crochet, utils, stats, data.table, methods, cindex
1515
LinkingTo: Rcpp, BH
1616
Depends: R (>= 3.0.0)
1717
RoxygenNote: 6.0.1

R/coxgrad.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@ fid=function(x,index){
6363
}
6464

6565
# Use this shorter version
66-
coxgrad <- function(f, time, d, w){
66+
coxgrad <- function(f, time, d, w, eps=0.00001){
6767
if(missing(w))w=rep(1,length(f))
68+
time = time - d*eps
6869
d = d * w
6970
f=scale(f,TRUE,FALSE)
7071
o = order(time)

R/functions.R

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,15 @@ computeProduct <- function(residual, pfile, vars, stats, configs, iter) {
183183
prod.full
184184
}
185185

186-
KKT.check <- function(residual, pfile, vars, n.train, current.lams, prev.lambda.idx, stats, glmfit, configs, iter) {
186+
KKT.check <- function(residual, pfile, vars, n.train, current.lams, prev.lambda.idx, stats, glmfit, configs, iter, p.factor=NULL) {
187187
time.KKT.check.start <- Sys.time()
188188
if (configs[['KKT.verbose']]) snpnetLogger('Start KKT.check()', indent=1, log.time=time.KKT.check.start)
189189
prod.full <- computeProduct(residual, pfile, vars, stats, configs, iter) / n.train
190+
191+
if(!is.null(p.factor)){
192+
prod.full <- sweep(prod.full, 1, p.factor, FUN="/")
193+
}
194+
190195
if (configs[['KKT.verbose']]) snpnetLoggerTimeDiff('- computeProduct.', indent=2, start.time=time.KKT.check.start)
191196
num.lams <- length(current.lams)
192197
if (length(configs[["covariates"]]) > 0) {
@@ -294,7 +299,7 @@ computeMetric <- function(pred, response, metric.type) {
294299
})
295300
} else if (metric.type == 'C'){
296301
metric <- apply(pred, 2, function(p) {
297-
my.cindex(p, response[,1], response[,2])
302+
cindex::CIndex(p, response[,1], response[,2])
298303
})
299304
}
300305
metric
@@ -388,9 +393,6 @@ checkGlmnetPlus <- function(use.glmnetPlus, family) {
388393
if (!requireNamespace("glmnetPlus")) {
389394
warning("use.glmnetPlus was set to TRUE but glmnetPlus not found... Revert back to glmnet.")
390395
use.glmnetPlus <- FALSE
391-
} else if (family != "gaussian") {
392-
warning("glmnetPlus currently does not support non-gaussian family... Revert back to glmnet.")
393-
use.glmnetPlus <- FALSE
394396
}
395397
}
396398
use.glmnetPlus

R/snpnet.R

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,16 @@
7777
#' @useDynLib snpnet, .registration=TRUE
7878
#' @export
7979
snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL, covariates = NULL,
80-
split.col=NULL, family = NULL, configs=NULL) {
81-
print("Ruilin's version here!")
80+
split.col=NULL, family = NULL, p.factor=NULL, configs=NULL) {
81+
82+
need.rank <- configs[['rank']]
8283
validation <- (!is.null(split.col))
8384
if (configs[['prevIter']] >= configs[['niter']]) stop("prevIter is greater or equal to the total number of iterations.")
8485
time.start <- Sys.time()
8586
snpnetLogger('Start snpnet', log.time = time.start)
8687

8788
snpnetLogger('Preprocessing start..')
88-
89+
8990
### --- Read genotype IDs --- ###
9091
ids <- list(); phe <- list()
9192
ids[['all']] <- readIDsFromPsam(paste0(genotype.pfile, '.psam'))
@@ -175,6 +176,14 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
175176

176177
stats <- computeStats(genotype.pfile, phe[['train']]$ID, configs = configs)
177178

179+
### --- Keep track of the lambda index at which each variant is first added to the model, if required --- ###
180+
if (need.rank){
181+
var.rank <- rep(configs[['nlambda']]+1, length(vars))
182+
names(var.rank) <- vars
183+
} else{
184+
var.rank = NULL
185+
}
186+
178187
### --- End --- ###
179188
snpnetLoggerTimeDiff("Preprocessing end.", time.start, indent=1)
180189

@@ -198,6 +207,9 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
198207

199208
prod.full <- computeProduct(residual, genotype.pfile, vars, stats, configs, iter=0) / nrow(phe[['train']])
200209
score <- abs(prod.full[, 1])
210+
211+
if (!is.null(p.factor)){score <- score/p.factor[names(score)]} # Divide the score by the penalty factor
212+
201213
if (configs[['verbose']]) snpnetLoggerTimeDiff(" End computing inner product for initialization.", time.prod.init.start)
202214

203215
if (is.null(configs[['lambda.min.ratio']])) {
@@ -208,6 +220,7 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
208220
lambda.idx <- 1
209221
num.lams <- configs[["nlams.init"]]
210222
features.to.keep <- names(glmmod$coefficients[-1])
223+
211224
prev.beta <- NULL
212225
num.new.valid <- NULL # track number of new valid solutions every iteration, to adjust length of current lambda seq or size of additional variables
213226

@@ -261,6 +274,7 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
261274
which.in.model <- which(names(score) %in% colnames(features[['train']]))
262275
score[which.in.model] <- NA
263276
}
277+
if (!is.null(p.factor)) {score <- score/p.factor[names(score)]}
264278
sorted.score <- sort(score, decreasing = T, na.last = NA)
265279
if (length(sorted.score) > 0) {
266280
features.to.add <- names(sorted.score)[1:min(configs[['num.snps.batch']], length(sorted.score))]
@@ -284,6 +298,7 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
284298
snpnetLogger(paste0("- # newly added variables: ", length(features.to.add), "."), indent=2)
285299
snpnetLogger(paste0("- Total # variables in the strong set: ", ncol(features[['train']]), "."), indent=2)
286300
}
301+
287302
### --- Fit glmnet --- ###
288303
if (configs[['verbose']]){
289304
if(configs[['use.glmnetPlus']]){
@@ -292,8 +307,12 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
292307
snpnetLogger("Start fitting Glmnet ...", indent=1)
293308
}
294309
}
295-
penalty.factor <- rep(1, ncol(features[['train']]))
296-
penalty.factor[seq_len(length(covariates))] <- 0
310+
if (is.null(p.factor)){
311+
penalty.factor <- rep(1, ncol(features[['train']]))
312+
penalty.factor[seq_len(length(covariates))] <- 0
313+
} else {
314+
penalty.factor <- c(rep(0, length(covariates)), p.factor[colnames(features[['train']])[-(1:length(covariates))]])
315+
}
297316
current.lams <- full.lams[1:num.lams]
298317
current.lams.adjusted <- full.lams[1:num.lams] * sum(penalty.factor) / length(penalty.factor) # adjustment to counteract penalty factor normalization in glmnet
299318
time.glmnet.start <- Sys.time()
@@ -306,14 +325,24 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
306325
} else {
307326
beta0 <- prev.beta
308327
}
309-
glmfit <- glmnetPlus::glmnet(
310-
features[['train']], response[['train']], family = family,
311-
lambda = current.lams.adjusted[start.lams:num.lams], penalty.factor = penalty.factor,
312-
standardize = configs[['standardize.variant']], thresh = configs[['glmnet.thresh']],
313-
type.gaussian = "naive", beta0 = beta0
328+
if(family == "cox"){
329+
glmfit <- glmnetPlus::glmnet(
330+
features[['train']], surv[['train']], family = family,
331+
lambda = current.lams.adjusted[start.lams:num.lams], penalty.factor = penalty.factor,
332+
standardize = configs[['standardize.variant']], thresh = configs[['glmnet.thresh']], beta0 = beta0
333+
)
334+
pred.train <- stats::predict(glmfit, newx = features[['train']])
335+
residual <- computeCoxgrad(pred.train, response[['train']], status[['train']])
336+
} else {
337+
glmfit <- glmnetPlus::glmnet(
338+
features[['train']], response[['train']], family = family,
339+
lambda = current.lams.adjusted[start.lams:num.lams], penalty.factor = penalty.factor,
340+
standardize = configs[['standardize.variant']], thresh = configs[['glmnet.thresh']],
341+
type.gaussian = "naive", beta0 = beta0
314342
)
315343
residual <- glmfit$residuals
316344
pred.train <- response[['train']] - residual
345+
}
317346

318347
} else {
319348
start.lams <- 1
@@ -349,20 +378,32 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
349378

350379
check.obj <- KKT.check(
351380
residual, genotype.pfile, vars, nrow(phe[['train']]),
352-
current.lams[start.lams:num.lams],
353-
ifelse(configs[['use.glmnetPlus']], 1, lambda.idx),
354-
stats, glmfit, configs, iter
381+
current.lams[start.lams:num.lams], ifelse(configs[['use.glmnetPlus']], 1, lambda.idx),
382+
stats, glmfit, configs, iter, p.factor
355383
)
356384
snpnetLogger("KKT check obj done ...", indent=1)
357385
max.valid.idx <- check.obj[["max.valid.idx"]] + (start.lams - 1) # max valid index in the whole lambda sequence
358-
if (lambda.idx < max.valid.idx) {
386+
387+
# Update the lambda index of variants added
388+
if (need.rank && check.obj[["max.valid.idx"]] > 0){
389+
tmp <- 1
390+
for (lam.idx in start.lams:max.valid.idx){
391+
current_active <- setdiff(names(which(glmfit$beta[, tmp] != 0)), covariates)
392+
tmp <- tmp + 1
393+
var.rank[current_active] = pmin(var.rank[current_active], lam.idx)
394+
}
395+
}
396+
397+
398+
if (lambda.idx < max.valid.idx) {
359399
is.KKT.valid.for.at.least.one <- TRUE
360400
} else{
361401
is.KKT.valid.for.at.least.one <- FALSE
362402
}
363403
lambda.idx <- check.obj[["next.lambda.idx"]] + (start.lams - 1)
364404

365-
if (configs[['use.glmnetPlus']] && check.obj[["max.valid.idx"]] > 0) {
405+
406+
if (configs[['use.glmnetPlus']] && check.obj[["max.valid.idx"]] > 0) {
366407
prev.beta <- glmfit$beta[, check.obj[["max.valid.idx"]]]
367408
prev.beta <- prev.beta[prev.beta != 0]
368409
}
@@ -442,6 +483,6 @@ snpnet <- function(genotype.pfile, phenotype.file, phenotype, status.col = NULL,
442483
if(configs[['verbose']]) print(gc())
443484

444485
out <- list(metric.train = metric.train, metric.val = metric.val, glmnet.results = glmnet.results,
445-
full.lams = full.lams, a0 = a0, beta = beta, configs = configs)
486+
full.lams = full.lams, a0 = a0, beta = beta, configs = configs, var.rank=var.rank)
446487
out
447488
}

0 commit comments

Comments
 (0)