Skip to content

Commit 33946b5

Browse files
Merge pull request #125 from tidymodels/faster-rf
2 parents 0dde432 + e74164f commit 33946b5

File tree

6 files changed

+65
-74
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# tidypredict (development version)
22

3+
- Speed up `tidypredict_fit()` for partykit and ranger packages. (#125)
4+
35
# tidypredict 0.5.1
46

57
- Exported a number of internal functions to be used in {orbital} package

R/model-partykit.R

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ partykit_tree_info <- function(model) {
6969
get_pk_tree <- function(model) {
7070
tree <- partykit_tree_info(model)
7171
paths <- tree$nodeID[tree[, "terminal"]]
72+
73+
child_info <- get_child_info(tree)
74+
7275
map(
7376
paths,
7477
~ {
@@ -77,7 +80,7 @@ get_pk_tree <- function(model) {
7780
if (is.factor(prediction)) prediction <- as.character(prediction)
7881
list(
7982
prediction = prediction,
80-
path = get_ra_path(.x, tree, FALSE)
83+
path = get_ra_path(.x, tree, child_info, FALSE)
8184
)
8285
}
8386
)
@@ -132,6 +135,9 @@ tidypredict_fit.party <- function(model) {
132135

133136
generate_one_tree <- function(tree_info) {
134137
paths <- tree_info$nodeID[tree_info[, "terminal"]]
138+
139+
child_info <- get_child_info(tree_info)
140+
135141
paths <- map(
136142
paths,
137143
~ {
@@ -140,7 +146,7 @@ tidypredict_fit.party <- function(model) {
140146
if (is.factor(prediction)) prediction <- as.character(prediction)
141147
list(
142148
prediction = prediction,
143-
path = get_ra_path(.x, tree_info, FALSE)
149+
path = get_ra_path(.x, tree_info, child_info, FALSE)
144150
)
145151
}
146152
)

R/model-ranger.R

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,39 @@
11
# Model parser ------------------------------------
2-
get_ra_path <- function(node_id, tree, default_op = TRUE) {
2+
get_ra_path <- function(node_id, tree, child_info, default_op = TRUE) {
33
find <- node_id
44
path <- node_id
5-
for (j in node_id:0) {
6-
row <- tree[tree$nodeID == j, ]
7-
lc <- row["leftChild"][[1]] == find
8-
lr <- row["rightChild"][[1]] == find
9-
if (is.na(lc)) lc <- FALSE
10-
if (is.na(lr)) lr <- FALSE
11-
dir <- NULL
12-
if (lc | lr) {
13-
find <- j
14-
path <- c(path, j)
5+
6+
leftChild <- tree$leftChild
7+
rightChild <- tree$rightChild
8+
splitval <- tree$splitval
9+
splitclass <- tree$splitclass
10+
splitvarName <- tree$splitvarName
11+
12+
new <- child_info[[find]]
13+
path <- find
14+
repeat {
15+
if (new == 0) {
16+
path <- c(path, 0)
17+
break
1518
}
19+
path <- c(path, new)
20+
find <- new
21+
new <- child_info[[find]]
1622
}
23+
1724
map2(
1825
path[1:length(path) - 1],
1926
path[2:length(path)],
2027
~ {
21-
rb <- tree[tree$nodeID == .y, ]
22-
lc <- rb["leftChild"] == .x
23-
lr <- rb["rightChild"] == .x
24-
if (is.na(rb["splitval"][[1]])) {
28+
lc <- leftChild[.y+1] == .x
29+
lr <- rightChild[.y+1] == .x
30+
if (is.na(splitval[.y+1])) {
2531
if (lc) op <- "in"
2632
if (lr) op <- "not-in"
27-
vals <- strsplit(as.character(rb["splitclass"][[1]]), ", ")[[1]]
33+
vals <- strsplit(as.character(splitclass[.y+1]), ", ")[[1]]
2834
list(
2935
type = "set",
30-
col = as.character(rb["splitvarName"][[1]]),
36+
col = as.character(splitvarName[.y+1]),
3137
vals = map(vals, ~.x),
3238
op = op
3339
)
@@ -41,18 +47,44 @@ get_ra_path <- function(node_id, tree, default_op = TRUE) {
4147
}
4248
list(
4349
type = "conditional",
44-
col = as.character(rb["splitvarName"][[1]]),
45-
val = rb["splitval"][[1]],
50+
col = as.character(splitvarName[.y+1]),
51+
val = splitval[.y+1],
4652
op = op
4753
)
4854
}
4955
}
5056
)
5157
}
5258

59+
get_child_info <- function(tree) {
60+
child_info <- numeric(max(tree$nodeID))
61+
left_child <- tree$leftChild
62+
right_child <- tree$rightChild
63+
node_id <- tree$nodeID
64+
65+
for (i in seq_len(nrow(tree))) {
66+
node <- node_id[[i]]
67+
68+
child <- left_child[[i]]
69+
if (!is.na(child)) {
70+
child_info[child] <- node
71+
}
72+
73+
child <- right_child[[i]]
74+
if (!is.na(child)) {
75+
child_info[child] <- node
76+
}
77+
}
78+
79+
child_info
80+
}
81+
5382
get_ra_tree <- function(tree_no, model) {
5483
tree <- ranger::treeInfo(model, tree_no)
5584
paths <- tree$nodeID[tree[, "terminal"]]
85+
86+
child_info <- get_child_info(tree)
87+
5688
map(
5789
paths,
5890
~ {
@@ -61,7 +93,7 @@ get_ra_tree <- function(tree_no, model) {
6193
if (is.factor(prediction)) prediction <- as.character(prediction)
6294
list(
6395
prediction = prediction,
64-
path = get_ra_path(.x, tree, TRUE)
96+
path = get_ra_path(.x, tree, child_info, TRUE)
6597
)
6698
} else {
6799
preds <- map_lgl(colnames(tree), ~ "pred." == substr(.x, 1, 5))
@@ -79,7 +111,7 @@ get_ra_tree <- function(tree_no, model) {
79111
prediction = prediction,
80112
prob = prob,
81113
probs = predictions,
82-
path = get_ra_path(.x, tree, TRUE)
114+
path = get_ra_path(.x, tree, child_info, TRUE)
83115
)
84116
}
85117
}

tests/testthat/_snaps/model-cubist.md

Lines changed: 0 additions & 50 deletions
This file was deleted.

tests/testthat/test-model-ranger.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ run_test <- function(model, test_formula = TRUE) {
1717
}
1818
}
1919

20-
run_test(
20+
tidypredict_fit(
2121
ranger::ranger(Species ~ ., data = iris, num.trees = num_trees, seed = 100, num.threads = 2)
2222
)
2323

tidypredict.Rproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
Version: 1.0
2+
ProjectId: 0cda5e71-1965-4603-b39a-db0e57a2cde6
23

34
RestoreWorkspace: Default
45
SaveWorkspace: Default

0 commit comments

Comments
 (0)