A set of functions to calculate performance metrics for prediction models. Also see the Spark ML Documentation https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.ml.evaluation.package
Usage
ml_binary_classification_evaluator(
x,
label_col = "label",
raw_prediction_col = "rawPrediction",
metric_name = "areaUnderROC",
uid = random_string("binary_classification_evaluator_"),
...
)
ml_binary_classification_eval(
x,
label_col = "label",
prediction_col = "prediction",
metric_name = "areaUnderROC"
)
ml_multiclass_classification_evaluator(
x,
label_col = "label",
prediction_col = "prediction",
metric_name = "f1",
uid = random_string("multiclass_classification_evaluator_"),
...
)
ml_classification_eval(
x,
label_col = "label",
prediction_col = "prediction",
metric_name = "f1"
)
ml_regression_evaluator(
x,
label_col = "label",
prediction_col = "prediction",
metric_name = "rmse",
uid = random_string("regression_evaluator_"),
...
)
Arguments
- x
A `spark_connection` object or a `tbl_spark` containing label and prediction columns. The latter should be the output of `sdf_predict`.
- label_col
Name of the column (a string) that contains the true labels or values.
- raw_prediction_col
Raw prediction (a.k.a. confidence) column name.
- metric_name
The performance metric. See details.
- uid
A character string used to uniquely identify the ML estimator.
- ...
Optional arguments; currently unused.
- prediction_col
Name of the column that contains the predicted label or value, NOT the scored probability. The column should be of type `Double`.
Details
The following metrics are supported:
Binary Classification:
`areaUnderROC` (default) or `areaUnderPR` (not available in Spark 2.X).
Multiclass Classification:
`f1` (default), `precision`, `recall`, `weightedPrecision`, `weightedRecall` or `accuracy`; for Spark 2.X: `f1` (default), `weightedPrecision`, `weightedRecall` or `accuracy`.
Regression:
`rmse` (root mean squared error, default), `mse` (mean squared error), `r2`, or `mae` (mean absolute error).
ml_binary_classification_eval() is an alias for ml_binary_classification_evaluator() for backwards compatibility.
ml_classification_eval() is an alias for ml_multiclass_classification_evaluator() for backwards compatibility.
Examples
if (FALSE) { # \dontrun{
sc <- spark_connect(master = "local")
mtcars_tbl <- sdf_copy_to(sc, mtcars, name = "mtcars_tbl", overwrite = TRUE)
partitions <- mtcars_tbl %>%
sdf_random_split(training = 0.7, test = 0.3, seed = 1111)
mtcars_training <- partitions$training
mtcars_test <- partitions$test
# for multiclass classification
rf_model <- mtcars_training %>%
ml_random_forest(cyl ~ ., type = "classification")
pred <- ml_predict(rf_model, mtcars_test)
ml_multiclass_classification_evaluator(pred)
# for regression
rf_model <- mtcars_training %>%
ml_random_forest(cyl ~ ., type = "regression")
pred <- ml_predict(rf_model, mtcars_test)
ml_regression_evaluator(pred, label_col = "cyl")
# for binary classification
rf_model <- mtcars_training %>%
ml_random_forest(am ~ gear + carb, type = "classification")
pred <- ml_predict(rf_model, mtcars_test)
ml_binary_classification_evaluator(pred)
} # }