How to Build a Random Forest Classifier Using Data Frames in Spark


The release of Spark 1.5 increased support for using data frames with MLlib, Spark’s machine learning library. MLlib is now divided into two packages:

  • spark.mllib, which contains the original API built on top of RDDs
  • spark.ml, which provides a higher-level API built on top of DataFrames for constructing machine learning pipelines

While the spark.mllib package will continue to be supported, the spark.ml package is recommended because of the flexibility and performance benefits offered through the use of data frames.

Data frames were popularized by R and are similar to SQL tables. They have been adopted by various data science frameworks, such as the Pandas library in Python and Deedle in .NET. Their introduction throughout the Spark API allows a degree of optimization that wasn’t available when using the RDD-based APIs. APIs built on data frames give developers a higher level of abstraction and give Spark the freedom to make aggressive improvements to execution plans.

However, because the spark.ml package is relatively new, most of the available examples still use the older API. In this article we update a previous analysis that used the spark.mllib package to predict cancer diagnoses using a random forest.

We’ll be using the interactive Spark Shell to build our model.

Importing the required classes

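We will need to import a number of classes from the spark.ml package (org.apache.spark.ml). Let’s do this all at once to get it out of the way.

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}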

Read the data

In this example we’ll load the data from a JSON file. Spark does not expect a typical JSON document. Each object (i.e. each row of the eventual data frame) must be serialized on a single line, and the objects are not wrapped in an array. The file looks like this:

{"bareNuclei":1.0,"blandChromatin":3.0,"class":2.0,"clumpThickness":5.0, ... }
{"bareNuclei":10.0,"blandChromatin":3.0,"class":2.0,"clumpThickness":5.0, ... }
{"bareNuclei":2.0,"blandChromatin":3.0,"class":2.0,"clumpThickness":3.0, ... }
{"bareNuclei":10.0,"blandChromatin":9.0,"class":4.0,"clumpThickness":8.0, ... }
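
To make this layout concrete, here is a minimal plain-Scala sketch (not part of Spark) that writes flat records in the one-object-per-line form shown above. The object and method names are hypothetical. The sketch assumes field names need no JSON escaping and that all values are doubles.

```scala
// Hypothetical sketch: write flat numeric records in the one-object-per-line
// layout that Spark's JSON reader expects. Assumes keys need no escaping.
object SparkJsonSketch {
  // Render a single record as one JSON object on a single line.
  // Keys are sorted only to make the output deterministic.
  def toJsonLine(record: Map[String, Double]): String =
    record.toSeq
      .sortBy(_._1)
      .map { case (key, value) => "\"" + key + "\":" + value }
      .mkString("{", ",", "}")

  // One object per line: no enclosing [ ], no commas between objects.
  def toJsonLines(records: Seq[Map[String, Double]]): String =
    records.map(toJsonLine).mkString("\n")

  def main(args: Array[String]): Unit = {
    // Illustrative records, not taken from the real data file.
    val rows = Seq(
      Map("bareNuclei" -> 1.0, "class" -> 2.0, "clumpThickness" -> 5.0),
      Map("bareNuclei" -> 10.0, "class" -> 4.0, "clumpThickness" -> 8.0)
    )
    println(toJsonLines(rows))
  }
}
```

The important property is what is missing from the output: there is no surrounding array and no separator other than the newline.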

I’ve created a (zipped) “Spark JSON” version of the data that you can download. This file contains a set of samples taken from patients. Each sample has a number of features (e.g. bare nuclei), each on a 1-10 scale. It also has a class field holding the diagnosis, which takes the value 2 (benign) or 4 (malignant). The original data set is available from the UCI Machine Learning Repository.

Let’s read the data.

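// sqlContext is predefined in the Spark 1.x shell; in Spark 2.0+ use spark.read.json
val df1 = sqlContext.read.json("file:///home/user/cancer.json")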

We can now examine the structure of the data frame. First, let’s review the “table” schema:

df1.printSchema()

This displays the schema that Spark inferred from the JSON file, i.e. the list of columns in the data frame and their types. There are 11 columns.

root
 |-- bareNuclei: double (nullable = true)
 |-- blandChromatin: double (nullable = true)
 |-- class: double (nullable = true)
 |-- clumpThickness: double (nullable = true)
 |-- marginalAdhesion: double (nullable = true)
 |-- mitoses: double (nullable = true)
 |-- normalNucleoli: double (nullable = true)
 |-- sampleCodeNumber: double (nullable = true)
 |-- singleEpithelialCellSize: double (nullable = true)
 |-- uniformityOfCellShape: double (nullable = true)
 |-- uniformityOfCellSize: double (nullable = true)

We can also look at the contents of the data frame.

df1.show()

By default, this displays the first 20 rows.

Remove samples with missing values

There are missing values in this data set. Specifically, the bareNuclei field has some null values. Let’s remove those samples.

val df2 = df1.filter("bareNuclei is not null")
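
Conceptually, the filter keeps only the rows whose bareNuclei value is present. The plain-Scala sketch below models a nullable column as Option[Double]. The Sample case class and its fields are hypothetical stand-ins for rows of df1, not Spark types. Spark’s DataFrame API also offers df1.na.drop(Seq("bareNuclei")), which drops rows where that column is null or NaN. The "is not null" filter keeps NaN values.

```scala
// Hypothetical sketch: model a nullable column as Option[Double] and keep
// only samples where bareNuclei is present, like filter("bareNuclei is not null").
object DropMissingSketch {
  final case class Sample(sampleCodeNumber: Double, bareNuclei: Option[Double], cls: Double)

  def dropMissingBareNuclei(samples: Seq[Sample]): Seq[Sample] =
    samples.filter(_.bareNuclei.isDefined)

  def main(args: Array[String]): Unit = {
    // Illustrative samples, not taken from the real data set.
    val samples = Seq(
      Sample(1.0, Some(1.0), 2.0),
      Sample(2.0, None, 2.0), // missing value: removed
      Sample(3.0, Some(10.0), 4.0)
    )
    println(dropMissingBareNuclei(samples).map(_.sampleCodeNumber))
  }
}
```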

Prepare the data for classification

We need to do two things to prepare our data for the random forest classifier:

  • Create a column that is a vector of all the features (predictor values)
  • Transform the class field into a label index, because the classifier expects discrete label values (0, 1, ...)

First, we create a “features” column that holds all the predictor values as a single vector.

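// class (the target) and sampleCodeNumber (an identifier) are not features
val assembler = new VectorAssembler().setInputCols(Array("bareNuclei",
  "blandChromatin", "clumpThickness", "marginalAdhesion", "mitoses",
  "normalNucleoli", "singleEpithelialCellSize", "uniformityOfCellShape",
  "uniformityOfCellSize")).setOutputCol("features")
val df3 = assembler.transform(df2)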
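
VectorAssembler concatenates the listed input columns, in the order given, into a single vector column. The plain-Scala sketch below mimics that for one row held as a Map. The object and method names are hypothetical. It ignores details of the real transformer, such as choosing a sparse representation and accepting vector-valued input columns.

```scala
// Hypothetical sketch of what VectorAssembler does for a single row:
// read the input columns in the given order and pack them into one array.
object AssembleSketch {
  val inputCols: Seq[String] = Seq(
    "bareNuclei", "blandChromatin", "clumpThickness", "marginalAdhesion",
    "mitoses", "normalNucleoli", "singleEpithelialCellSize",
    "uniformityOfCellShape", "uniformityOfCellSize")

  // Throws NoSuchElementException if a listed column is missing from the row.
  def assemble(row: Map[String, Double], cols: Seq[String]): Array[Double] =
    cols.map(c => row(c)).toArray

  def main(args: Array[String]): Unit = {
    // Illustrative values; class and sampleCodeNumber are present in the row
    // but ignored because they are not listed in inputCols.
    val row = Map(
      "bareNuclei" -> 1.0, "blandChromatin" -> 3.0, "clumpThickness" -> 5.0,
      "marginalAdhesion" -> 1.0, "mitoses" -> 1.0, "normalNucleoli" -> 1.0,
      "singleEpithelialCellSize" -> 2.0, "uniformityOfCellShape" -> 1.0,
      "uniformityOfCellSize" -> 1.0, "class" -> 2.0, "sampleCodeNumber" -> 7.0)
    println(assemble(row, inputCols).mkString("[", ",", "]"))
  }
}
```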

Next, we create discrete target values (labels) from the class field.

val labelIndexer = new StringIndexer().setInputCol("class").setOutputCol("label")
val df4 = labelIndexer.fit(df3).transform(df3)

Review the df4 data frame.

df4.show()

You’ll see that we now have a vector field called “features” and a “double” field called “label”. These are the default field names expected by the classifier. If we compare the class and label fields we can see that the label “0” has been assigned to benign samples (2) and “1” to malignant samples (4).
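
StringIndexer casts numeric values to strings ("2.0", "4.0"). By default it assigns indices by descending frequency, so the most common value gets index 0. That is why benign samples, which are the majority in this data set, receive label 0. Below is a minimal plain-Scala sketch of that ordering. The names are hypothetical, and breaking ties alphabetically is an assumption of the sketch.

```scala
// Hypothetical sketch of StringIndexer's default ordering: cast values to
// strings, count them, and give index 0 to the most frequent value.
object IndexerSketch {
  // Ties are broken alphabetically here; that detail is an assumption.
  def fitLabels(values: Seq[Double]): Map[String, Double] = {
    val counts: Map[String, Int] =
      values.map(_.toString).groupBy(identity).map { case (k, vs) => k -> vs.size }
    counts.toSeq
      .sortBy { case (label, count) => (-count, label) }
      .zipWithIndex
      .map { case ((label, _), index) => label -> index.toDouble }
      .toMap
  }

  def main(args: Array[String]): Unit = {
    // Illustrative class values: more benign (2.0) than malignant (4.0).
    val classes = Seq(2.0, 2.0, 4.0, 2.0, 4.0, 2.0)
    val mapping = fitLabels(classes)
    println(mapping)                               // index assigned to each class value
    println(classes.map(c => mapping(c.toString))) // the resulting "label" values
  }
}
```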

Train the classifier

Before we train the classifier we need to split our data into training and test data sets (data frames).

val splitSeed = 5043 
val Array(trainingData, testData) = df4.randomSplit(Array(0.7, 0.3), splitSeed)

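About 70% of the data is used to train the model, and the remaining 30% is held back for testing. Because randomSplit assigns each row at random, these proportions are approximate; fixing the seed makes the split repeatable.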
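
The idea behind a seeded, weighted split can be shown in a few lines. In this plain-Scala sketch each row is independently sent to the training set with probability 0.7, using scala.util.Random. It is not Spark’s implementation, which samples within each partition, and the object and method names are hypothetical. It does show the two properties that matter here. The split sizes only approximate the weights, and the same seed reproduces the same split on the same data.

```scala
import scala.util.Random

// Hypothetical sketch of a seeded, weighted random split: each row is
// independently sent to the training side with probability trainFraction.
object SplitSketch {
  def randomSplit[A](rows: Seq[A], trainFraction: Double, seed: Long): (Vector[A], Vector[A]) = {
    val rng = new Random(seed)
    // Draw exactly one random number per row, in row order.
    val tagged = rows.toVector.map(row => (row, rng.nextDouble() < trainFraction))
    val training = tagged.collect { case (row, true) => row }
    val test = tagged.collect { case (row, false) => row }
    (training, test)
  }

  def main(args: Array[String]): Unit = {
    val (training, test) = randomSplit(1 to 1000, 0.7, 5043L)
    // Sizes are close to, but generally not exactly, 700 and 300.
    println(s"training=${training.size}, test=${test.size}")
  }
}
```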

Now let’s train a random forest classifier with the following hyper-parameter values:

  • Gini impurity as the split criterion
  • A maximum tree depth of 3
  • 20 trees in the forest
  • Automatic selection of the number of features to consider for splits at each tree node
  • A random number seed of 5043, so the results can be repeated
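
Gini impurity measures how mixed the classes are at a tree node. For class proportions p_k it is 1 minus the sum of p_k squared. It is 0 for a pure node and 0.5 for an even two-class mix, and splits are chosen to maximize the decrease in impurity. The sketch below computes it from class counts in plain Scala; the object and function names are hypothetical.

On the feature subset: with more than one tree, Spark’s "auto" setting uses "sqrt" for classification. With our 9 features, each split therefore considers sqrt(9) = 3 randomly chosen features.

```scala
// Gini impurity of a node, computed from the number of samples in each class:
// 1 - sum over classes of (class proportion)^2.
object GiniSketch {
  def gini(classCounts: Seq[Long]): Double = {
    val total = classCounts.sum.toDouble
    if (total == 0.0) 0.0 // an empty node is treated as pure in this sketch
    else 1.0 - classCounts.map { c => val p = c / total; p * p }.sum
  }

  def main(args: Array[String]): Unit = {
    println(gini(Seq(10L, 0L))) // a pure node
    println(gini(Seq(5L, 5L)))  // an even benign/malignant mix
    println(gini(Seq(7L, 3L)))  // somewhere in between
  }
}
```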

We create the classifier and then use it to train (fit) the model.

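// trailing dots let the spark-shell continue the expression on the next line
val classifier = new RandomForestClassifier().setImpurity("gini").setMaxDepth(3).
  setNumTrees(20).setFeatureSubsetStrategy("auto").setSeed(5043)
val model = classifier.fit(trainingData)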

Predicting diagnoses using the test data

We can now ask the model to predict diagnoses for the test samples.

val predictions = model.transform(testData)

Let’s examine the first 5 predictions.

predictions.select("sampleCodeNumber", "label", "prediction").show(5)

This should show the following:

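+----------------+-----+----------+
|sampleCodeNumber|label|prediction|
+----------------+-----+----------+
|       1002945.0|  0.0|       1.0|
|       1017023.0|  0.0|       0.0|
|       1033078.0|  0.0|       0.0|
|       1036172.0|  0.0|       0.0|
|       1041801.0|  1.0|       1.0|
+----------------+-----+----------+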

As you can see, the previous transform produced a new prediction column.
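
Behind the prediction column, each of the 20 trees makes its own prediction and the forest combines them. The plain-Scala sketch below shows the classic combination rule, a majority vote over the per-tree predictions. Treat it as an illustration only: depending on the release, Spark may instead sum the trees’ class-probability estimates. The names, the tie-break toward the lower class index, and the vote counts are hypothetical.

```scala
// Hypothetical sketch of hard (majority) voting across the trees of a forest.
object VoteSketch {
  // treePredictions: one predicted class index per tree (assumed valid, e.g. 0.0 or 1.0).
  // Ties go to the lower class index, an arbitrary choice in this sketch.
  def majorityVote(treePredictions: Seq[Double], numClasses: Int): Double = {
    val votes = Array.fill(numClasses)(0)
    treePredictions.foreach(p => votes(p.toInt) += 1)
    votes.indices.maxBy(i => (votes(i), -i)).toDouble
  }

  def main(args: Array[String]): Unit = {
    // Illustrative votes from 20 trees: 13 say benign (0.0), 7 say malignant (1.0).
    val preds = Seq.fill(13)(0.0) ++ Seq.fill(7)(1.0)
    println(majorityVote(preds, numClasses = 2))
  }
}
```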

Evaluate the quality of the model

Spark provides us with tools to evaluate the accuracy of our model. Let’s generate the “precision” metric by comparing the label column with the prediction column.

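// "precision" here is overall precision, i.e. accuracy; Spark 2.0+ names this metric "accuracy"
val evaluator = new MulticlassClassificationEvaluator().setLabelCol("label").
  setPredictionCol("prediction").setMetricName("precision")
val accuracy = evaluator.evaluate(predictions)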

You should see that the accuracy is approximately 97%. Not bad, given that we’ve made no attempt to tune the model.
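
The metric itself is simple: the fraction of test rows whose prediction equals the label. The plain-Scala sketch below (names hypothetical) applies it to the five rows shown above. Four of the five predictions match, which gives 0.8. Over the whole test set the evaluator reports roughly 0.97.

```scala
// Hypothetical sketch: accuracy = fraction of rows where prediction == label.
object AccuracySketch {
  def accuracy(labelAndPrediction: Seq[(Double, Double)]): Double =
    if (labelAndPrediction.isEmpty) 0.0
    else labelAndPrediction.count { case (label, pred) => label == pred }.toDouble /
      labelAndPrediction.size

  def main(args: Array[String]): Unit = {
    // (label, prediction) pairs from the five rows shown above.
    val firstFive = Seq((0.0, 1.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0), (1.0, 1.0))
    println(accuracy(firstFive))
  }
}
```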

In this article, we’ve introduced Spark MLlib’s data frame API and used it to build a random forest classifier for a realistic data set. There are many other machine learning algorithms to explore.

If you’ve enjoyed this article and wish to learn more, Learning Tree has a number of courses that may help.
