party from the partykit package to a data.tree structure. Source: R/node_conversion_party.R
as.Node.party.Rd
Convert a party from the partykit package to a data.tree structure.
# S3 method for class 'party'
as.Node(x, ...)
library(partykit)
#> Loading required package: libcoin
#>
#> Attaching package: ‘partykit’
#> The following objects are masked from ‘package:party’:
#>
#> cforest, ctree, ctree_control, edge_simple, mob, mob_control,
#> node_barplot, node_bivplot, node_boxplot, node_inner, node_surv,
#> node_terminal, varimp
# Load the example data shipped with partykit.
data("WeatherPlay", package = "partykit")

### splits ###
# One split per predictor, addressed by variable index:
# outlook (1), humidity (3) and windy (4).
sp_o <- partysplit(varid = 1L, index = 1:3)
sp_h <- partysplit(varid = 3L, breaks = 75)
sp_w <- partysplit(varid = 4L, index = 1:2)

## query labels
character_split(sp_o)
#> $name
#> [1] "V1"
#>
#> $levels
#> [1] "(-Inf,1]" "(1,2]" "(2, Inf]"
#>
### nodes ###
## Assemble the recursive partynode structure: the root (node 1) splits on
## sp_o into three children.
pn <- partynode(
  id = 1L,
  split = sp_o,
  kids = list(
    # first child: split again on sp_h
    partynode(
      id = 2L,
      split = sp_h,
      kids = list(
        partynode(id = 3L, info = "yes"),
        partynode(id = 4L, info = "no")
      )
    ),
    # second child: terminal node
    partynode(id = 5L, info = "yes"),
    # third child: split again on sp_w
    partynode(
      id = 6L,
      split = sp_w,
      kids = list(
        partynode(id = 7L, info = "yes"),
        partynode(id = 8L, info = "no")
      )
    )
  )
)
pn
#> [1] root
#> | [2] V1 in (-Inf,1]
#> | | [3] V3 <= 75 *
#> | | [4] V3 > 75 *
#> | [5] V1 in (1,2] *
#> | [6] V1 in (2, Inf]
#> | | [7] V4 <= 1 *
#> | | [8] V4 > 1 *
### tree ###
## party: attach the data to the recursive partynode structure
py <- party(pn, WeatherPlay)

## convert to a data.tree structure and print selected fields per node;
## `count` is computed on the fly from the rows stored on each node
tree <- as.Node(py)
print(
  tree,
  "splitname",
  count = function(node) nrow(node$data),
  "splitLevel"
)
#> levelName splitname count splitLevel
#> 1 1 outlook 14
#> 2 ¦--2 humidity 5 sunny
#> 3 ¦ ¦--3 yes 2 <= 75
#> 4 ¦ °--4 no 3 > 75
#> 5 ¦--5 yes 4 overcast
#> 6 °--6 windy 5 rainy
#> 7 ¦--7 yes 3 false
#> 8 °--8 no 2 true
# Node appearance: label every node "<name>: <splitname>" and show the number
# of rows in the node's data as tooltip.
SetNodeStyle(
  tree,
  label = function(node) {
    paste0(node$name, ": ", node$splitname)
  },
  tooltip = function(node) {
    n_obs <- nrow(node$data)
    paste0(n_obs, " observations")
  },
  fontname = "helvetica"
)
# Edge appearance: label each edge with the split level; width and grey level
# both follow the node's row count relative to the root's row count.
SetEdgeStyle(
  tree,
  arrowhead = "none",
  label = function(node) node$splitLevel,
  fontname = "helvetica",
  penwidth = function(node) {
    n_here <- nrow(node$data)
    n_root <- nrow(node$root$data)
    12 * n_here / n_root
  },
  color = function(node) {
    n_here <- nrow(node$data)
    n_root <- nrow(node$root$data)
    # Grey level is 100 minus the truncated percentage of root rows.
    grey_level <- 100 - as.integer(100 * n_here / n_root)
    paste0("grey", grey_level)
  }
)
# Leaf appearance: filled, rounded boxes coloured by the leaf's splitname
# ("yes" leaves green, all other leaves salmon).
Do(
  tree$leaves,
  function(node) {
    # Scalar test, evaluated once. isTRUE() maps a missing or NA splitname to
    # FALSE, so the `if` below always receives a single TRUE/FALSE instead of
    # the zero-length or NA result a vectorised ifelse() would pass on.
    is_yes <- isTRUE(node$splitname == "yes")
    SetNodeStyle(
      node,
      shape = "box",
      color = if (is_yes) "darkolivegreen4" else "lightsalmon4",
      fillcolor = if (is_yes) "darkolivegreen1" else "lightsalmon",
      style = "filled,rounded",
      penwidth = 2
    )
  }
)
plot(tree)