Convert a `party` object from the partykit package to a data.tree structure.

# S3 method for class 'party'
as.Node(x, ...)

Arguments

x

The `party` object to convert.

...

Other arguments (unused).

Examples

library(partykit)
#> Loading required package: libcoin
#> 
#> Attaching package: ‘partykit’
#> The following objects are masked from ‘package:party’:
#> 
#>     cforest, ctree, ctree_control, edge_simple, mob, mob_control,
#>     node_barplot, node_bivplot, node_boxplot, node_inner, node_surv,
#>     node_terminal, varimp
data("WeatherPlay", package = "partykit")
### splits ###
# WeatherPlay columns used: 1 = outlook, 3 = humidity, 4 = windy.
# Split outlook into its three levels, humidity at 75, and windy into
# its two levels.
sp_o <- partysplit(1L, index = 1:3)
sp_h <- partysplit(3L, breaks = 75)
sp_w <- partysplit(4L, index = 1:2)

## query labels; passing `data` resolves the real variable name and
## levels (without it, placeholders such as "V1" are reported)
character_split(sp_o, data = WeatherPlay)
#> $name
#> [1] "outlook"
#> 
#> $levels
#> [1] "sunny"    "overcast" "rainy"   
#> 

### nodes ###
## Assemble the recursive partynode hierarchy: inner nodes carry a
## split plus a list of kids; terminal nodes store their label
## ("yes"/"no") in `info`.
pn <- partynode(
  id = 1L,
  split = sp_o,
  kids = list(
    partynode(
      id = 2L,
      split = sp_h,
      kids = list(
        partynode(id = 3L, info = "yes"),
        partynode(id = 4L, info = "no")
      )
    ),
    partynode(id = 5L, info = "yes"),
    partynode(
      id = 6L,
      split = sp_w,
      kids = list(
        partynode(id = 7L, info = "yes"),
        partynode(id = 8L, info = "no")
      )
    )
  )
)
pn
#> [1] root
#> |   [2] V1 in (-Inf,1]
#> |   |   [3] V3 <= 75 *
#> |   |   [4] V3 > 75 *
#> |   [5] V1 in (1,2] *
#> |   [6] V1 in (2, Inf]
#> |   |   [7] V4 <= 1 *
#> |   |   [8] V4 > 1 *
### tree ###
## party: attach the data to the recursive partynode structure, then
## convert the resulting party object into a data.tree Node
py <- party(node = pn, data = WeatherPlay)
tree <- as.Node(py)

## Tabulate, per node, the split variable (or leaf label), the number
## of observations reaching it, and the split level that leads into it
print(
  tree,
  "splitname",
  count = function(node) nrow(node$data),
  "splitLevel"
)
#>   levelName splitname count splitLevel
#> 1 1           outlook    14           
#> 2  ¦--2      humidity     5      sunny
#> 3  ¦   ¦--3       yes     2      <= 75
#> 4  ¦   °--4        no     3       > 75
#> 5  ¦--5           yes     4   overcast
#> 6  °--6         windy     5      rainy
#> 7      ¦--7       yes     3      false
#> 8      °--8        no     2       true

## Node style: label shows the node name and split variable; the
## tooltip shows how many observations reach the node.
SetNodeStyle(
  tree,
  label = function(node) paste0(node$name, ": ", node$splitname),
  tooltip = function(node) paste0(nrow(node$data), " observations"),
  fontname = "helvetica"
)
## Edge style: both the width and the grey level are driven by the
## node's share of all observations at the root (a larger share gives a
## thicker edge and a lower grey number, i.e. a darker colour).
SetEdgeStyle(
  tree,
  arrowhead = "none",
  label = function(node) node$splitLevel,
  fontname = "helvetica",
  penwidth = function(node) 12 * nrow(node$data) / nrow(node$root$data),
  color = function(node) {
    share <- nrow(node$data) / nrow(node$root$data)
    sprintf("grey%d", 100L - as.integer(100 * share))
  }
)
## Leaves: boxed, rounded and filled; green tones for "yes" leaves,
## salmon tones for all others. Each call decides for a single node,
## so a scalar if/else is used rather than the vectorised ifelse().
Do(tree$leaves,
   function(node) {
     # isTRUE() keeps the test a single TRUE/FALSE even if splitname
     # is missing (NULL) on a node
     is_yes <- isTRUE(node$splitname == "yes")
     SetNodeStyle(node,
                  shape = "box",
                  color = if (is_yes) "darkolivegreen4" else "lightsalmon4",
                  fillcolor = if (is_yes) "darkolivegreen1" else "lightsalmon",
                  style = "filled,rounded",
                  penwidth = 2
                  )
   }
   )

plot(tree)