How to Predict Outcomes Using Random Forests and Spark

Random forests are an ensemble, or model of models, machine learning approach. The algorithm builds multiple decision trees, each based on a different subset of the features in the data. Outcomes are then predicted by running observations through all the trees and combining the individual predictions: averaging them for regression, or taking a majority vote for classification.

Think wisdom of crowds.
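
To make the voting idea concrete, here is a minimal sketch in plain Scala (no Spark needed). The three "trees" are hypothetical one-line rules with made-up feature indices and thresholds, not anything MLlib would learn; the point is only how a majority vote combines them.

```scala
// A "tree" here is just a function from a feature vector to a class (0 or 1).
type Tree = Array[Double] => Int

// Three hypothetical trees, each looking at a different feature.
// The thresholds are invented for illustration.
val trees: Seq[Tree] = Seq(
  features => if (features(0) > 5.0) 1 else 0,
  features => if (features(1) > 3.0) 1 else 0,
  features => if (features(2) > 4.0) 1 else 0
)

// Combine the trees by majority vote: predict 1 if more than half say 1.
def predict(features: Array[Double]): Int = {
  val votes = trees.map(tree => tree(features))
  if (votes.sum * 2 > votes.size) 1 else 0
}

val sample = Array(7.0, 2.0, 6.0)
val prediction = predict(sample) // trees vote 1, 0, 1, so the forest says 1
```

No single rule has to be right every time; the forest only needs most of them to agree.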

Spark’s machine learning library, MLlib, has support for random forest modeling. In this article, we’ll use MLlib to build a model for predicting cancer diagnoses.

We’ll develop the model using Scala, the language in which Spark is written. Scala is a good choice because:

  • functional languages allow us to represent mathematical concepts naturally
  • functional languages, by encouraging immutability, are well suited to parallel environments (such as cluster computing)
  • features in Spark tend to be implemented in the Scala API first

The problem

We want to build a classification model that will predict whether a tissue sample is malignant or benign based on characteristics of the sample.

Random forests are a supervised learning technique, so we need a training set. UCI’s Machine Learning Repository contains a dataset of breast cancer tissue samples, each described by nine characteristics along with the eventual diagnosis.

Working with Spark and Big Data

I’ll assume you have access to a Spark cluster, as installing Spark is outside the scope of this article. If you don’t have a cluster, consider using Amazon EC2 to get up and running.

Launch the Spark shell. If you want to run locally, using all the cores of your machine, use

spark-shell --master local[*]

This puts us in a modified version of Scala’s interactive REPL shell. This is a great way to do data science using Spark and Scala.

The use case for Spark is big data. The model we’ll be building will be pretty small, but will demonstrate the concepts. Spark operates on Resilient Distributed Datasets (RDDs). For example, to load a big HDFS dataset into an RDD you might use

val rdd = sc.textFile("hdfs://datasets/observations.dat")

sc is the Spark context. It’s how we interact with Spark.

We’ll load our (small) dataset directly from the UCI website and convert it to an RDD. From that point on the steps will be identical to working with a big distributed dataset.

Loading and preparing the data

Import a Scala library that allows us to read data from a URL.

import scala.io.Source

Read the (CSV) file in as a sequence of lines.

val csv = Source.fromURL("

Now convert the data to an RDD.

val rdd = sc.parallelize(csv)

OK. With the preliminaries over, let’s get down to some analysis.

First of all we have to prepare the data. This is always the bulk of any real-world data science project.

val data = rdd.map(_.split(",")).filter(_(6) != "?").map(_.drop(1)).map(_.map(_.toDouble))

Got that? Probably not. One of the benefits of Scala is that it’s very concise, but let’s break that last statement down.

  1. Split each line (using “,”) of the CSV file into separate fields
  2. Some of the rows in the dataset contain missing values in the seventh field. Remove those.
  3. The first column contains an ID. Drop it.
  4. Convert all remaining values to floating point numbers
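
If you’d like to see the four steps at work without a cluster, the sketch below applies the same chain to an ordinary Scala collection. The three input lines are made up for illustration (they follow the ID, nine features, diagnosis layout described above), and the names sampleLines and parsedRows are ours, not part of the article’s pipeline.

```scala
// Hypothetical rows: an ID, nine feature values, then the diagnosis (2 or 4).
// The second row has a missing value ("?") in its seventh field.
val sampleLines = Seq(
  "1001,5,1,1,1,2,1,3,1,1,2",
  "1002,8,4,5,1,2,?,7,3,1,4",
  "1003,8,10,10,8,7,10,9,7,1,4"
)

val parsedRows: Seq[Array[Double]] = sampleLines
  .map(_.split(","))       // 1. split each line into fields
  .filter(_(6) != "?")     // 2. remove rows with a missing seventh field
  .map(_.drop(1))          // 3. drop the ID column
  .map(_.map(_.toDouble))  // 4. convert the remaining values to doubles

// Two rows survive, each with ten values: nine features plus the diagnosis.
```

The same methods exist on an RDD, which is why the one-liner reads the same whether the data fits on a laptop or is spread across a cluster.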

Training the classifier

With the data ready we can set about constructing the classifier.

Import some data types we need.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

Create a set of LabeledPoint objects that contain a set of feature values and a label (the diagnosis).

val labeledPoints = data.map(x => LabeledPoint(if (x.last == 4) 1 else 0, 
  Vectors.dense(x.init)))

The dataset uses 2 to represent a benign diagnosis and 4 to represent a malignant one. We’ll use 0 and 1, respectively. init returns all the values in a sequence except the last (which in this case removes the diagnosis value from the list of features).
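
As a quick check of what init and last do, here is a small plain-Scala sketch on a single hypothetical row (the values are invented; only the layout of nine features followed by a diagnosis code matches the dataset).

```scala
// One hypothetical parsed row: nine feature values followed by the diagnosis code.
val row = Array(5.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 1.0, 2.0)

// init keeps everything except the last element: the nine features.
val features = row.init

// last is the diagnosis code: map 4 (malignant) to 1.0 and 2 (benign) to 0.0.
val label = if (row.last == 4.0) 1.0 else 0.0
```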

Now we split the data into training and test datasets—70% for training and 30% for testing.

val splits = labeledPoints.randomSplit(Array(0.7, 0.3), seed = 5043L)

val trainingData = splits(0)
val testData = splits(1)
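
The idea behind a seeded random split can be sketched in plain Scala. This is an illustration of the concept only, not Spark’s implementation: the helper seededSplit is hypothetical, and it simply sends each element to the training side with the given probability. As with this kind of sampling in general, the resulting sizes are approximately, not exactly, 70% and 30%.

```scala
import scala.util.Random

// Hypothetical helper: send each item to the first group with probability
// trainFraction, using a seeded generator so the split is repeatable.
def seededSplit[A](items: Seq[A], trainFraction: Double, seed: Long): (Seq[A], Seq[A]) = {
  val rng = new Random(seed)
  // Draw once per item first, then partition on the stored result.
  val tagged = items.map(item => (item, rng.nextDouble() < trainFraction))
  val (first, second) = tagged.partition(_._2)
  (first.map(_._1), second.map(_._1))
}

val (trainingItems, testItems) = seededSplit(1 to 1000, 0.7, seed = 5043L)
```

Fixing the seed means the same observations land in the same group on every run, which keeps results comparable while you experiment with the model.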

Set up the model’s hyperparameters

import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.impurity.Gini

val algorithm = Algo.Classification
val impurity = Gini
val maximumDepth = 3
val treeCount = 20
val featureSubsetStrategy = "auto"
val seed = 5043

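We have a classification problem, and we’ll use 20 trees, each with a maximum depth of three. Gini impurity is a common split criterion for classification trees, and the "auto" feature subset strategy leaves it to MLlib to decide how many features to consider at each split.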
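
For a feel of what the impurity setting measures, here is a small sketch of the standard Gini impurity formula (one minus the sum of squared class proportions) in plain Scala. The function name giniImpurity is ours; MLlib computes this internally when it chooses splits.

```scala
// Gini impurity of a node, given how many samples of each class it holds:
// 1 - sum over classes of (class proportion)^2.
def giniImpurity(classCounts: Seq[Int]): Double = {
  val total = classCounts.sum.toDouble
  1.0 - classCounts.map(count => math.pow(count / total, 2)).sum
}

val evenlyMixed = giniImpurity(Seq(50, 50)) // 0.5, the worst case for two classes
val pureNode = giniImpurity(Seq(100, 0))    // 0.0, every sample has the same class
```

A tree prefers splits whose child nodes have lower impurity, i.e. are closer to containing a single diagnosis.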

Now let’s build the model!

import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.RandomForest

val model = RandomForest.trainClassifier(trainingData, new Strategy(algorithm, 
  impurity, maximumDepth), treeCount, featureSubsetStrategy, seed)

That was easy. Time to try it out.

Evaluating the model

Let’s see how well it works. Using the test dataset we’ll generate a number of predictions and cross reference them with the actual diagnoses.

val labeledPredictions = testData.map { labeledPoint =>
    val predictions = model.predict(labeledPoint.features)
    (labeledPoint.label, predictions)
}

MLlib makes it easy to evaluate the effectiveness of our classifier.

import org.apache.spark.mllib.evaluation.MulticlassMetrics

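// MulticlassMetrics expects (prediction, label) pairs, so swap each (label, prediction) tuple
val evaluationMetrics = new MulticlassMetrics(labeledPredictions.map(x => 
  (x._2, x._1)))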

Let’s look at the overall precision of the classifier, i.e. the fraction of predictions it gets right.

evaluationMetrics.precision

97.2%. Pretty good.

What about the confusion matrix?

evaluationMetrics.confusionMatrix

148   2
  4  57

Columns are predictions. Rows and columns are ordered by ascending order of label (diagnosis). So, we have 148 correctly predicted benign cases and 57 correctly predicted malignant cases. There are a couple of false positives, where we predicted cancer when there wasn’t any. More concerning, we give four people the all-clear when they had malignant tumors.
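
The headline figure can be recomputed from those four numbers. The sketch below is plain Scala arithmetic on the matrix above; the variable names are ours, and the recall and per-class precision figures are extra derived values rather than something reported earlier.

```scala
// Cells of the confusion matrix above (rows: actual, columns: predicted).
val trueNegatives = 148.0  // benign, predicted benign
val falsePositives = 2.0   // benign, predicted malignant
val falseNegatives = 4.0   // malignant, predicted benign
val truePositives = 57.0   // malignant, predicted malignant

val total = trueNegatives + falsePositives + falseNegatives + truePositives // 211

// Fraction of all predictions that were right: 205 / 211, about 0.972.
val overallAccuracy = (truePositives + trueNegatives) / total

// Of the actual malignant cases, the fraction we caught: 57 / 61, about 0.934.
val malignantRecall = truePositives / (truePositives + falseNegatives)

// Of the malignant predictions, the fraction that were right: 57 / 59, about 0.966.
val malignantPrecision = truePositives / (truePositives + falsePositives)
```

The recall figure is the one to watch here: it is lower than the overall score because it is driven entirely by the four missed malignant cases.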

In general, you find yourself trading off false positives against false negatives. Maybe the balance needs a little tuning. Given our problem, false negatives are considerably worse than false positives.

But, that’ll have to be for another time…

Learning More About Spark and Big Data

In this short article we’ve had a quick run through what’s involved in using MLlib to make predictions. Random forests are just one of many algorithms available to us. We’ve also necessarily skipped over some of the checks and balances you’d apply in a real machine learning project.

If this has piqued your interest, Learning Tree has a number of courses where you can dig deeper into similar techniques and technologies.
