Clustering data using k-means in ML.NET

Microsoft recently released a preview of a machine learning framework for .NET developers—ML.NET.

I needed to perform a clustering analysis from existing data in one of my applications. This is a pretty common machine learning task, so I decided to document the basic approach in this article.

We’ll use the well-worn iris data set from the UCI Machine Learning Repository to demonstrate how to perform a cluster analysis using ML.NET. The iris data set contains three fairly distinct clusters of different types of iris, with a few difficult-to-classify specimens, which makes it easy to check the output of the analysis.

If you wish to review/run the completed application, you can clone/download it from GitHub.

Data preparation

The ML.NET framework allows you to import your data directly as part of the analysis pipeline. However, I have a number of libraries I like to use to load and transform my data, so I used CsvHelper to load the data. This resulted in an IEnumerable of Iris objects.

public class Iris
{
    public double PetalLength { get; set; }
    public double PetalWidth { get; set; }
    public double SepalLength { get; set; }
    public double SepalWidth { get; set; }
    public string Type { get; set; }
}
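The CsvHelper loading code itself isn't shown, so as a stand-in, here is a minimal sketch that fills the same Iris shape using only the .NET base class library. It assumes the UCI iris.data layout (comma-separated sepal length, sepal width, petal length, petal width and class name, with no header row); IrisCsv and Parse are hypothetical names, and this is not the author's CsvHelper-based loader. Numbers are parsed with the invariant culture so the decimal point is read correctly regardless of the machine's locale.

```csharp
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;

// Same shape as the Iris class above.
public class Iris
{
    public double PetalLength { get; set; }
    public double PetalWidth { get; set; }
    public double SepalLength { get; set; }
    public double SepalWidth { get; set; }
    public string Type { get; set; }
}

// Hypothetical BCL-only stand-in for the CsvHelper loading step.
public static class IrisCsv
{
    // Parses rows laid out as: sepal length, sepal width, petal length, petal width, class.
    public static IEnumerable<Iris> Parse(TextReader reader)
    {
        string line;
        while ((line = reader.ReadLine()) != null)
        {
            // Skip blank lines, such as a trailing empty line at the end of the file.
            if (string.IsNullOrWhiteSpace(line)) continue;

            string[] parts = line.Split(',');
            yield return new Iris
            {
                SepalLength = double.Parse(parts[0], CultureInfo.InvariantCulture),
                SepalWidth = double.Parse(parts[1], CultureInfo.InvariantCulture),
                PetalLength = double.Parse(parts[2], CultureInfo.InvariantCulture),
                PetalWidth = double.Parse(parts[3], CultureInfo.InvariantCulture),
                Type = parts[4].Trim()
            };
        }
    }
}
```

Because Parse reads lazily, materialize the results before the reader is disposed, e.g. `using (var reader = File.OpenText("iris.data")) { data = IrisCsv.Parse(reader).ToList(); }`.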

At present, ML.NET only seems to work with public fields rather than properties (which I understand will be addressed as the framework matures). Rather than expose public fields in my domain classes (e.g. Iris), I wrote an Observation data transfer class and converted my Iris data to Observation data.

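public class Observation
{
    // ML.NET (preview) reads public fields; this four-element vector is the model input
    [VectorType(4)]
    public float[] Features;

    public static Observation Create(Iris iris)
    {
        return new Observation
        {
            Features = new[]
            {
                (float)iris.SepalLength,
                (float)iris.SepalWidth,
                (float)iris.PetalLength,
                (float)iris.PetalWidth
            }
        };
    }
}

IEnumerable<Observation> observations = data.Select(Observation.Create).ToList();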

ML.NET looks for a Features field during training, so naming is important. This must be a vector of floats. The VectorType attribute specifies that the feature data is four-dimensional, with the values in this order:

  1. Sepal length
  2. Sepal width
  3. Petal length
  4. Petal width

Creating the prediction class

We also need a class that defines a prediction, i.e. the cluster that contains a given iris.

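public class ClusterPrediction
{
    // Number of the cluster the observation was assigned to
    public uint PredictedLabel;

    // Distance from the observation to each of the k cluster centroids
    public float[] Score;
}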

Again, the type and name of these fields (not properties) are important.

Building the pipeline

We are now ready to construct the learning pipeline.

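var pipeline = new LearningPipeline
{
    // Training data: the in-memory observations created earlier
    CollectionDataSource.Create(observations),

    // Cluster the observations into K = 3 groups
    new KMeansPlusPlusClusterer
    {
        K = 3,
        NormalizeFeatures = NormalizeOption.Yes,
        MaxIterations = 100
    }
};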

We want to identify 3 clusters, normalize the training data and stop the analysis after 100 iterations (if it still hasn’t converged).
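Normalization matters because k-means compares Euclidean distances, and the four measurements span different ranges (sepal length is several centimetres, petal width well under three). The article doesn't say which normalizer NormalizeOption.Yes applies, so the sketch below assumes simple min-max scaling of each feature to [0, 1]; FeatureScaling.MinMaxScale is a hypothetical helper, not part of ML.NET.

```csharp
using System;
using System.Linq;

public static class FeatureScaling
{
    // Min-max scaling: rescales each feature (column) to [0, 1] so that no single
    // measurement dominates the distances used by k-means.
    // Assumption: ML.NET's NormalizeOption.Yes may use a different normalizer.
    public static float[][] MinMaxScale(float[][] rows)
    {
        int dims = rows[0].Length;
        var min = new float[dims];
        var max = new float[dims];
        for (int d = 0; d < dims; d++)
        {
            min[d] = rows.Min(r => r[d]);
            max[d] = rows.Max(r => r[d]);
        }

        // A constant column carries no information for clustering, so it is mapped to 0.
        return rows
            .Select(r => r.Select((v, i) => max[i] > min[i] ? (v - min[i]) / (max[i] - min[i]) : 0f).ToArray())
            .ToArray();
    }
}
```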

The model can now be trained.

PredictionModel<Observation, ClusterPrediction> model = pipeline.Train<Observation, ClusterPrediction>();
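To make concrete what training involves, here is a rough, library-free sketch of k-means with k-means++ seeding, the initialization scheme the KMeansPlusPlusClusterer name alludes to. It is not ML.NET's implementation, whose initialization and update details may differ. The idea: pick spread-out starting centroids, then alternate between assigning each point to its nearest centroid and moving each centroid to the mean of its points, stopping when no assignment changes or after maxIterations passes. KMeansSketch.Cluster is a hypothetical name, and it returns 0-based cluster indices.

```csharp
using System;
using System.Linq;

// Illustrative k-means sketch (not ML.NET's implementation).
public static class KMeansSketch
{
    static double SquaredDistance(float[] a, float[] b)
    {
        double sum = 0;
        for (int i = 0; i < a.Length; i++)
        {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }

    // Index of the centroid closest to the point (ties go to the lower index).
    static int Nearest(float[] point, float[][] centroids)
    {
        int best = 0;
        for (int c = 1; c < centroids.Length; c++)
            if (SquaredDistance(point, centroids[c]) < SquaredDistance(point, centroids[best]))
                best = c;
        return best;
    }

    // Returns a 0-based cluster index for each point.
    public static int[] Cluster(float[][] points, int k, int maxIterations, Random rng)
    {
        // k-means++ seeding: the first centroid is a random point; each further centroid is
        // sampled with probability proportional to its squared distance to the nearest
        // centroid chosen so far, which tends to spread the initial centroids out.
        var centroids = new float[k][];
        centroids[0] = (float[])points[rng.Next(points.Length)].Clone();
        for (int c = 1; c < k; c++)
        {
            double[] weights = points
                .Select(p => centroids.Take(c).Min(m => SquaredDistance(p, m)))
                .ToArray();
            double target = rng.NextDouble() * weights.Sum();
            int chosen = points.Length - 1;
            double cumulative = 0;
            for (int i = 0; i < points.Length; i++)
            {
                cumulative += weights[i];
                if (cumulative > target) { chosen = i; break; }
            }
            centroids[c] = (float[])points[chosen].Clone();
        }

        var assignments = Enumerable.Repeat(-1, points.Length).ToArray();
        for (int iter = 0; iter < maxIterations; iter++)
        {
            // Assignment step: attach each point to its nearest centroid.
            bool changed = false;
            for (int i = 0; i < points.Length; i++)
            {
                int nearest = Nearest(points[i], centroids);
                if (nearest != assignments[i]) { assignments[i] = nearest; changed = true; }
            }
            if (!changed) break; // converged: no point switched clusters

            // Update step: move each centroid to the mean of the points assigned to it.
            for (int c = 0; c < k; c++)
            {
                var members = points.Where((p, i) => assignments[i] == c).ToArray();
                if (members.Length == 0) continue; // leave an empty cluster's centroid in place
                centroids[c] = Enumerable.Range(0, points[0].Length)
                    .Select(d => members.Average(m => m[d]))
                    .ToArray();
            }
        }
        return assignments;
    }
}
```

Calling `KMeansSketch.Cluster(features, 3, 100, new Random())` on scaled iris feature vectors roughly mirrors the K = 3 and MaxIterations = 100 settings used in the pipeline.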

Determining the cluster assignments

Finally, we can predict the clusters containing each observation (iris).

data.ToList().ForEach(x =>
{
    ClusterPrediction prediction = model.Predict(Observation.Create(x));

    Console.WriteLine($"Type {x.Type} was assigned to cluster {prediction.PredictedLabel}");
});

prediction.Score is an array of k numbers (one per cluster), each giving the distance between the observation and the respective cluster centroid.
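Because Score holds one distance per cluster, the predicted cluster should be the one with the smallest entry. The sketch below recovers it from a score array, assuming (as the cluster numbers 1 to 3 in the test run results suggest) that cluster numbers are 1-based and that Score[i] belongs to cluster i + 1. ScoreHelper.NearestCluster is a hypothetical helper, not an ML.NET API.

```csharp
using System;

public static class ScoreHelper
{
    // Returns the 1-based number of the cluster with the smallest distance.
    // Assumption: Score[i] is the distance to cluster i + 1.
    public static uint NearestCluster(float[] score)
    {
        int best = 0;
        for (int i = 1; i < score.Length; i++)
            if (score[i] < score[best]) best = i;
        return (uint)(best + 1);
    }
}
```

Comparing `ScoreHelper.NearestCluster(prediction.Score)` with `prediction.PredictedLabel` inside the prediction loop is a quick way to check that assumption against your own run.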

I obtained the following results on a test run:

  • Cluster 1: 40 versicolor and 11 virginica
  • Cluster 2: 39 virginica and 10 versicolor
  • Cluster 3: 50 setosa

Your results may differ from run to run, as k-means is non-deterministic.
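Since the cluster numbers themselves are arbitrary and can be shuffled between runs, one label-independent way to compare runs is purity: for each cluster, take the count of its most common iris type, sum those counts and divide by the number of observations. For the test run above this gives (40 + 39 + 50) / 150 = 0.86. Purity is a standard clustering metric rather than something the article or ML.NET computes; ClusterReport.Purity is a hypothetical helper that takes (Type, Cluster) pairs such as those printed by the prediction loop.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class ClusterReport
{
    // Purity: sum, over clusters, the count of each cluster's most common true type,
    // then divide by the total number of observations. 1.0 means every cluster is "pure".
    public static double Purity(IEnumerable<(string Type, uint Cluster)> assignments)
    {
        var list = assignments.ToList();
        int majorityTotal = list
            .GroupBy(a => a.Cluster)
            .Sum(cluster => cluster.GroupBy(a => a.Type).Max(type => type.Count()));
        return (double)majorityTotal / list.Count;
    }
}
```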


ML.NET is currently in its first preview release, so there are clearly things that will improve over the next few iterations (e.g. support for properties), but it already looks useful if you need to include machine learning in your .NET applications.

Learning Tree training

If you are interested in the topics covered in this blog post, Learning Tree has a number of courses that may help advance your skills in these areas…and avoid any traps.
