August 2015

Volume 30 Number 8

Test Run - K-Means++ Data Clustering

By James McCaffrey

Data clustering is the process of grouping data items so that similar items are placed together. Once grouped, the clusters of data can be examined to see if there are relationships that might be useful. For example, if a huge set of sales data were clustered, information about the data in each cluster might reveal patterns that could be used for targeted marketing.

There are many clustering algorithms. One of the most common is the k-means algorithm, which has several variations. This article explains a relatively recent variation called the k-means++ algorithm.

Take a look at the demo program in Figure 1. The program starts with 20 data items, each consisting of a person’s height in inches and weight in pounds. Next, the number of clusters is set to 3. In most data clustering scenarios, the number of clusters must be specified by the user.

Figure 1 K-Means++ Clustering in Action

The demo program then clusters the data using the k-means++ algorithm. Each of the 20 data items is assigned to one cluster with an ID of 0, 1 or 2. The cluster assignments are stored in an array, where the array index corresponds to a data index and the array value is the associated cluster ID. For example, the final clustering of the demo data is:

0 2 1 1 2 . . . 1

This indicates data item 0 (height = 65.0, weight = 220.0) is assigned to cluster 0, data item 1 is assigned to cluster 2, data item 2 is assigned to cluster 1, and so on. The demo concludes by displaying the data, grouped by cluster. Here a very clear pattern is revealed. There are eight people characterized by medium height and heavy weight, seven people with short height and light weight, and five people with tall height and medium weight.
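The demo's ShowClustered display method isn't listed in the article. As a minimal sketch, assuming only the clustering-array layout just described, a hypothetical GroupByCluster helper can invert that array into one list of data indices per cluster, which is the information needed to print the data grouped by cluster:

```csharp
using System;
using System.Collections.Generic;

public static class ClusterGroupingSketch
{
    // Converts a clustering array (index = data item, value = cluster ID)
    // into one list of data-item indices per cluster.
    public static List<int>[] GroupByCluster(int[] clustering, int numClusters)
    {
        var groups = new List<int>[numClusters];
        for (int c = 0; c < numClusters; ++c)
            groups[c] = new List<int>();

        for (int i = 0; i < clustering.Length; ++i)
        {
            int id = clustering[i];
            if (id < 0 || id >= numClusters)
                throw new ArgumentException("Invalid cluster ID at index " + i);
            groups[id].Add(i); // data item i belongs to cluster id
        }
        return groups;
    }
}
```

Applied to the first five values of the demo clustering (0 2 1 1 2) with three clusters, it yields cluster 0 = {0}, cluster 1 = {2, 3} and cluster 2 = {1, 4}.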

This article assumes you have at least intermediate programming skills but doesn’t assume you know anything about the k-means++ algorithm. The demo program is coded using C# but you shouldn’t have much difficulty refactoring the code to another language, such as Python or JavaScript. The demo code is too long to present in its entirety, but the complete source is available in the code download that accompanies this article.

Understanding the K-Means++ Algorithm

The k-means++ algorithm is a variation of the standard k-means algorithm, so to understand k-means++ you must first understand regular k-means. The k-means algorithm has an interesting history and is sometimes called Lloyd’s algorithm. The “k” in k-means refers to the number of clusters. In very high-level pseudo-code, the most common form of standard k-means is deceptively simple:

pick k initial means
loop until no change
  assign each data item to closest mean
  compute new means based on new clusters
end loop

In spite of its simple appearance, the standard k-means algorithm is, in fact, very subtle, and implementation is surprisingly tricky. Suppose the data to be clustered consists of the 20 data items shown in Figure 1, with k set to 3. The first step is to select three of the data items to act as initial means. A common approach is to select data items at random. Suppose the three randomly selected data items are item 2 (59.0, 110.0) as the mean of cluster 0, item 4 (75.0, 150.0) as the mean of cluster 1, and item 6 (68.0, 230.0) as the mean of cluster 2.

Inside the main processing loop, each data item is examined and assigned to the cluster with the closest mean. So, data item 0 (65.0, 220.0) would be assigned to cluster 2 because item 0 is closer to (68.0, 230.0) than to the other two means. Each of the remaining 19 data items would be assigned to a cluster in the same way. Note that each data item that was initially selected as a mean would be assigned to its own cluster, because its distance to that mean is 0.

After each data item is assigned to one of the clusters, a new mean for each cluster is calculated. Suppose that cluster 0 currently contains just three data items: (59.0, 110.0), (70.0, 210.0), (61.0, 130.0). The new mean for this cluster would be:

( (59 + 70 + 61)/3, (110 + 210 + 130)/3 ) =
(190/3, 450/3) =
(63.3, 150.0)

New means for clusters 1 and 2 would be calculated similarly. Notice the new means are not necessarily one of the actual data items anymore. Technically, each of the new means is a “centroid,” but the term “mean” is commonly used.
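The new-mean arithmetic above generalizes to any number of items and components. Here is a short sketch; the CentroidSketch class and Centroid method are hypothetical names, not part of the demo:

```csharp
using System;

public static class CentroidSketch
{
    // Returns the component-wise mean (centroid) of the given data items.
    public static double[] Centroid(double[][] items)
    {
        if (items.Length == 0)
            throw new ArgumentException("Cannot compute the mean of zero items");

        int dim = items[0].Length;
        double[] result = new double[dim];
        foreach (double[] item in items)
            for (int j = 0; j < dim; ++j)
                result[j] += item[j]; // sum each component

        for (int j = 0; j < dim; ++j)
            result[j] /= items.Length; // divide sums by item count
        return result;
    }
}
```

Passing the three cluster-0 items (59.0, 110.0), (70.0, 210.0) and (61.0, 130.0) returns approximately (63.33, 150.0), matching the hand calculation. The guard against an empty set reflects the same problem the demo's UpdateMeans method checks for: a cluster with no items has no mean.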

After computing new means, each data item is examined again and assigned to the cluster with the closest new mean. The iterative update-clusters, update-means process continues until there’s no change in cluster assignments.

This all sounds relatively simple, but a lot can go wrong with a naive implementation of the standard k-means algorithm. In particular, a bad selection for the initial means can lead to a very poor clustering of data, or to a very long runtime to stabilization, or both. As it turns out, good initial means are ones that aren’t close to each other. The k-means++ algorithm selects initial means that aren’t close to each other, then uses the standard k-means algorithm for clustering.

The K-Means++ Initialization Mechanism

In high-level pseudo-code the k-means++ initialization mechanism to select means is:

select one data item at random as first mean
loop k-1 times
  compute distance from each item to closest mean
  select an item that has large distance-squared
    as next initial mean
end loop

Again, the pseudo-code is deceptively simple. The k-means++ initialization mechanism is illustrated in Figure 2. There are nine data points, each of which has two components. The number of clusters is set to 3, so 3 data items must be selected as initial means.

Figure 2 K-Means++ Initialization Mechanism

The diagram in Figure 2 shows the k-means++ initialization process after two of the three initial means have been selected. The first initial mean at (3, 6) was randomly selected. Then the distance-squared from each of the other 8 data items to the first mean was computed, and using that information, the second initial mean at (4, 3) was selected (in a way I’ll explain shortly).

To select a data item as the third initial mean, the squared distance from each data point to its closest mean is computed. The distances are shown as dashed lines. Using these squared distance values, the third mean will be selected so that data items with small squared distance values have a low probability of being selected, and data items with large squared distance values have a high probability of being selected. This technique is sometimes called proportional fitness selection.

Proportional fitness selection is the heart of the k-means++ initialization mechanism. There are several ways to implement proportional fitness selection. The demo program uses a technique called roulette wheel selection. In high-level pseudo-code, one form of roulette wheel selection is:

p = random value between 0.0 and 1.0
create an array of cumulative probabilities
loop each cell of cum prob array
  if cum[i] >= p
    return i
  end if
end loop

A concrete example will help clarify roulette wheel selection. Suppose there are four candidate items (0, 1, 2, 3) with associated values (20.0, 10.0, 40.0, 30.0). The sum of the values is 20.0 + 10.0 + 40.0 + 30.0 = 100.0. Proportional fitness selection will pick item 0 with probability 20.0/100.0 = 0.20; pick item 1 with probability 10.0/100.0 = 0.10; pick item 2 with probability 40.0/100.0 = 0.40; and pick item 3 with probability 30.0/100.0 = 0.30.

If the probabilities of selection are stored in an array as (0.20, 0.10, 0.40, 0.30), the cumulative probabilities can be stored in an array with values (0.20, 0.30, 0.70, 1.00). Now, suppose a random p is generated with value 0.83. If i is an array index into the cumulative probabilities array, when i = 0, cum[i] = 0.20, which isn’t greater than or equal to p = 0.83, so i increments to 1. Now cum[i] = 0.30, which is still less than p, so i increments to 2. Now cum[i] = 0.70, which is still less than p, so i increments to 3. Now cum[i] = 1.00, which is greater than p, so i = 3 is returned as the selected item.

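Notice that the gaps between consecutive cumulative probabilities differ. The gap that ends at cum[i] equals the selection probability of item i, so items with higher probabilities cover larger slices of the range from 0.0 to 1.0 and are more likely to capture the random value p.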
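The cumulative-array form of roulette wheel selection in the pseudo-code can be written as a small C# method. This is a sketch that assumes the caller supplies the random value p; the class and method names are hypothetical. (The demo program itself computes cumulative probabilities on the fly instead of building an array.)

```csharp
using System;

public static class RouletteSketch
{
    // Roulette wheel selection: returns index i with probability
    // values[i] / sum(values). p is a random value in [0.0, 1.0).
    public static int Select(double[] values, double p)
    {
        double sum = 0.0;
        foreach (double v in values) sum += v;

        // Build the cumulative probabilities, e.g. (0.20, 0.30, 0.70, 1.00).
        double[] cum = new double[values.Length];
        double running = 0.0;
        for (int i = 0; i < values.Length; ++i)
        {
            running += values[i] / sum;
            cum[i] = running;
        }

        // First index whose cumulative probability reaches p.
        for (int i = 0; i < cum.Length; ++i)
            if (cum[i] >= p) return i;

        // Rounding can leave the last cumulative value just below 1.0.
        return values.Length - 1;
    }
}
```

With the values (20.0, 10.0, 40.0, 30.0) and p = 0.83, Select returns 3, as in the hand trace. The final return statement guards against floating-point rounding that can leave the last cumulative value slightly below 1.0.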

To summarize, the k-means++ algorithm selects initial means so the means are dissimilar, then uses the standard k-means algorithm to cluster data. The initialization process uses proportional fitness selection, which can be implemented in several ways. The demo program uses roulette wheel selection.
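For comparison with the demo's InitMeans method, which is presented in pieces later, here is a compact sketch of the whole k-means++ seeding step, assuming numeric data in array-of-arrays form. It differs from the demo in two design choices: it returns the indices of the chosen items rather than copying their values, and instead of keeping a separate list of used items, it relies on the fact that an already-chosen item has squared distance 0 and therefore zero probability of being chosen again. All names are hypothetical.

```csharp
using System;
using System.Collections.Generic;

public static class KMeansPlusPlusSeedingSketch
{
    // Squared Euclidean distance between two items.
    static double DistSq(double[] a, double[] b)
    {
        double sum = 0.0;
        for (int j = 0; j < a.Length; ++j)
        {
            double d = a[j] - b[j];
            sum += d * d;
        }
        return sum;
    }

    // Returns the indices of k data items chosen as initial means.
    public static int[] SelectInitialIndices(double[][] data, int k, Random rnd)
    {
        if (k < 1 || k > data.Length)
            throw new ArgumentOutOfRangeException(nameof(k));

        // First mean: a data item chosen uniformly at random.
        var chosen = new List<int> { rnd.Next(0, data.Length) };
        double[] dSquared = new double[data.Length];

        while (chosen.Count < k)
        {
            // Squared distance from each item to its closest chosen mean.
            double sum = 0.0;
            for (int i = 0; i < data.Length; ++i)
            {
                double best = double.MaxValue;
                foreach (int c in chosen)
                    best = Math.Min(best, DistSq(data[i], data[c]));
                dSquared[i] = best;
                sum += best;
            }
            if (sum == 0.0)
                throw new InvalidOperationException("Fewer than k distinct data items");

            // Proportional fitness selection: item i wins with probability dSquared[i] / sum.
            double threshold = rnd.NextDouble() * sum;
            double cumulative = 0.0;
            int next = -1;
            for (int i = 0; i < data.Length; ++i)
            {
                if (dSquared[i] == 0.0) continue; // already chosen (or a duplicate of a mean)
                next = i;                         // last positive item doubles as a rounding fallback
                cumulative += dSquared[i];
                if (cumulative > threshold) break;
            }
            chosen.Add(next);
        }
        return chosen.ToArray();
    }
}
```

With three distinct points and k = 3, every run returns all three indices in some order, because after each pick only the unchosen points have nonzero squared distance.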

Overall Program Structure

The overall structure of the demo program, with a few minor edits to save space, is presented in Figure 3. To create the demo program, I launched Visual Studio and created a new C# console application project named KMeansPlus. The demo program has no significant Microsoft .NET Framework dependencies so any relatively recent version of Visual Studio will work.

Figure 3 Overall Program Structure

using System;
using System.Collections.Generic;
namespace KMeansPlus
{
  class KMeansPlusProgram
  {
    static void Main(string[] args)
    {
      Console.WriteLine("Begin k-means++ demo");
      // All program code here
      Console.WriteLine("End k-means++ demo");
      Console.ReadLine();
    }
    public static int[] Cluster(double[][] rawData,
      int numClusters, int seed) { . . }
    public static double[][] InitMeans(int numClusters,
      double[][] data, int seed) { . . }
    private static double[][] Normalized(double[][] rawData) { . . }
    private static double[][] MakeMatrix(int rows, int cols) { . . }
    private static bool UpdateMeans(double[][] data,
      int[] clustering, double[][] means) { . . }
    private static bool UpdateClustering(double[][] data,
      int[] clustering, double[][] means) { . . }
    private static double Distance(double[] tuple,
      double[] mean) { . . }
    private static int MinIndex(double[] distances) { . . }
    static void ShowData(double[][] data, int decimals,
      bool indices, bool newLine) { . . }
    static void ShowVector(int[] vector, bool newLine) { . . }
    static void ShowClustered(double[][] data, int[] clustering,
      int numClusters, int decimals) { . . }
  }
} // ns

After the template code loaded into the editor, in the Solution Explorer window I right-clicked on file Program.cs, renamed it to the more descriptive KMeansPlusProgram.cs and allowed Visual Studio to automatically rename class Program. In the editor window, at the top of the template-generated code, I deleted all references to namespaces except the ones to the top-level System namespace and the Collections.Generic namespace.

The Main method begins by setting up 20 raw data items:

double[][] rawData = new double[20][];
rawData[0] = new double[] { 65.0, 220.0 };
...
rawData[19] = new double[] { 61.0, 130.0 };

In a non-demo scenario you’d probably read data from a text file or SQL database. After displaying the raw data using a program-defined helper method named ShowData, the data is clustered:

int numClusters = 3;
int[] clustering = Cluster(rawData, numClusters, 0);

Although there are some techniques you can use to guess the best number of clusters, in general you must use trial and error. The Cluster method accepts numeric raw data to cluster in an array-of-array style matrix; the number of clusters to use (I could have used “k” but “numClusters” is more readable); and a seed value to use for randomization.

The Main method concludes by displaying the clustering array and showing the raw data grouped by cluster:

ShowVector(clustering, true);
ShowClustered(rawData, clustering, numClusters, 1);

I used a static method approach rather than an OOP approach. Method Cluster calls four helper methods. Helper method Normalized accepts a matrix of raw data and returns a matrix where the data has been normalized so that all values are roughly the same magnitude (typically between -6.0 and +6.0). Method InitMeans implements the k-means++ initialization mechanism. Methods UpdateClustering and UpdateMeans implement the core parts of the standard k-means algorithm. 

Methods InitMeans and UpdateClustering both call helper method Distance, which returns the Euclidean distance between two data items. For example, if one data tuple is (3.0, 9.0, 5.0) and a second tuple is (2.0, 6.0, 1.0), the Euclidean distance between the two items is:

Sqrt( (3-2)^2 + (9-6)^2 + (5-1)^2 ) =
Sqrt( 1 + 9 + 16 ) =
Sqrt(26) ≈ 5.10
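A Distance method consistent with the calculation above might look like the following sketch. The signature mirrors the Distance declaration in Figure 3, but the body is an assumption, because the article doesn't list it:

```csharp
using System;

public static class DistanceSketch
{
    // Euclidean distance between two tuples of equal length:
    // square root of the sum of squared component differences.
    public static double Distance(double[] tuple, double[] mean)
    {
        double sumSquaredDiffs = 0.0;
        for (int j = 0; j < tuple.Length; ++j)
            sumSquaredDiffs += (tuple[j] - mean[j]) * (tuple[j] - mean[j]);
        return Math.Sqrt(sumSquaredDiffs);
    }
}
```

For the tuples (3.0, 9.0, 5.0) and (2.0, 6.0, 1.0), it returns Sqrt(26), about 5.099.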

Other distance definitions can be used. In general, k-means and k-means++ are used to cluster strictly numeric data rather than categorical data.

Implementing K-Means++

The code for method Cluster is presented in Figure 4. Method Cluster begins by normalizing the raw data so that large components in data items (such as weight values in the demo) don’t dominate smaller components (height values). The demo uses Gaussian normalization. Two common alternatives are min-max normalization and order-of-magnitude normalization. A design alternative is to normalize your raw data in a preprocessing step, then pass the normalized data directly to method Cluster.
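The article doesn't list the Normalized method. As a minimal sketch of Gaussian normalization, each component (column) is replaced by x' = (x - mean) / sd, computed over all data items. Whether the demo uses the population or the sample standard deviation isn't stated; this sketch assumes the population version and, also by assumption, maps a constant column to 0.0 to avoid dividing by zero:

```csharp
using System;

public static class NormalizationSketch
{
    // Gaussian (z-score) normalization, column by column:
    // x' = (x - mean) / sd. Returns a new matrix; rawData is not modified.
    public static double[][] Normalized(double[][] rawData)
    {
        int rows = rawData.Length;
        int cols = rawData[0].Length;
        double[][] result = new double[rows][];
        for (int i = 0; i < rows; ++i)
            result[i] = (double[])rawData[i].Clone();

        for (int j = 0; j < cols; ++j)
        {
            double mean = 0.0;
            for (int i = 0; i < rows; ++i) mean += rawData[i][j];
            mean /= rows;

            double variance = 0.0; // population variance (divide by rows)
            for (int i = 0; i < rows; ++i)
                variance += (rawData[i][j] - mean) * (rawData[i][j] - mean);
            variance /= rows;
            double sd = Math.Sqrt(variance);

            for (int i = 0; i < rows; ++i)
                result[i][j] = sd > 0.0 ? (rawData[i][j] - mean) / sd : 0.0;
        }
        return result;
    }
}
```

For a column holding 1.0, 2.0 and 3.0, the normalized values are about -1.22, 0.0 and 1.22; a column holding 100.0, 200.0 and 300.0 normalizes to the same values, so after normalization the weight component no longer dominates the height component.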

Figure 4 Method Cluster

public static int[] Cluster(double[][] rawData,
 int numClusters, int seed)
{
  double[][] data = Normalized(rawData);
  bool changed = true;
  bool success = true;
  double[][] means = InitMeans(numClusters, data, seed);
  int[] clustering = new int[data.Length];
  int maxCount = data.Length * 10;
  int ct = 0;
  while (changed == true && success == true &&
    ct < maxCount)
  {
    changed = UpdateClustering(data, clustering, means);
    success = UpdateMeans(data, clustering, means);
    ++ct;
  }
  return clustering;
}

Method InitMeans implements the k-means++ initialization mechanism and returns a set of means that are far apart from each other in terms of Euclidean distance. Inside the main clustering loop, method UpdateClustering iterates through each data item and assigns each item to the cluster associated with the closest current mean/centroid. The method returns false if there’s no change to cluster assignments (indicating that clustering is complete) or if the new clustering would result in a cluster that has no data items (indicating something is wrong). An alternative is to throw an exception on a zero-count cluster situation.

Method UpdateMeans iterates through the data assigned to each cluster and computes a new mean/centroid for each cluster. The method returns false if one or more means can’t be calculated because a cluster has no data items.
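The article describes UpdateClustering and UpdateMeans only by their contracts. The following sketch follows those contracts, using the same parameter lists as Figure 3, but it is not the demo's code. One detail is an assumption: when a new assignment would leave a cluster empty, the sketch leaves the clustering array unchanged and returns false; likewise, UpdateMeans leaves the means untouched when it returns false.

```csharp
using System;

public static class KMeansStepSketch
{
    // Euclidean distance between two items.
    static double Distance(double[] a, double[] b)
    {
        double sum = 0.0;
        for (int j = 0; j < a.Length; ++j)
            sum += (a[j] - b[j]) * (a[j] - b[j]);
        return Math.Sqrt(sum);
    }

    // Assigns each item to the cluster with the closest mean. Returns false,
    // leaving clustering unchanged, if no assignment changed or if the new
    // assignment would leave some cluster with no items.
    public static bool UpdateClustering(double[][] data, int[] clustering, double[][] means)
    {
        int numClusters = means.Length;
        int[] newClustering = new int[data.Length];
        int[] counts = new int[numClusters];
        bool changed = false;
        for (int i = 0; i < data.Length; ++i)
        {
            int best = 0;
            double bestDist = Distance(data[i], means[0]);
            for (int c = 1; c < numClusters; ++c)
            {
                double d = Distance(data[i], means[c]);
                if (d < bestDist) { bestDist = d; best = c; }
            }
            newClustering[i] = best;
            ++counts[best];
            if (best != clustering[i]) changed = true;
        }
        if (!changed) return false;           // clustering is stable
        foreach (int count in counts)
            if (count == 0) return false;     // would create an empty cluster
        Array.Copy(newClustering, clustering, data.Length);
        return true;
    }

    // Recomputes each mean as the centroid of its assigned items. Returns
    // false, leaving means unchanged, if any cluster has no items.
    public static bool UpdateMeans(double[][] data, int[] clustering, double[][] means)
    {
        int numClusters = means.Length;
        int dim = means[0].Length;
        int[] counts = new int[numClusters];
        double[][] sums = new double[numClusters][];
        for (int c = 0; c < numClusters; ++c)
            sums[c] = new double[dim];
        for (int i = 0; i < data.Length; ++i)
        {
            int id = clustering[i];
            ++counts[id];
            for (int j = 0; j < dim; ++j)
                sums[id][j] += data[i][j];
        }
        for (int c = 0; c < numClusters; ++c)
            if (counts[c] == 0) return false; // mean can't be computed
        for (int c = 0; c < numClusters; ++c)
            for (int j = 0; j < dim; ++j)
                means[c][j] = sums[c][j] / counts[c];
        return true;
    }
}
```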

The main clustering loop uses a sanity count check to prevent an infinite loop. The k-means algorithm typically stabilizes very quickly, but there’s no guarantee the algorithm will stabilize at all. The value of maxCount is set to 10 times the number of data items, which is arbitrary but has worked well for me in practice.

The definition of method InitMeans begins with:

public static double[][] InitMeans(int numClusters,
  double[][] data, int seed)
{
  double[][] means = MakeMatrix(numClusters, data[0].Length);
  List<int> used = new List<int>();
...

The local array-of-arrays style matrix named means holds the method return, where the row index is a cluster ID and each row is an array that holds the components of the associated mean. The List<int> named used holds indices of data items that have been assigned as initial means, so duplicate initial means can be prevented. This approach assumes there are no data items with identical values. When clustering, how you deal with duplicate data items depends on your particular problem scenario. One alternative to removing duplicate items from the source data is to weight duplicate items by their frequency.

Next, the first initial mean is selected and stored:

Random rnd = new Random(seed);
int idx = rnd.Next(0, data.Length);
Array.Copy(data[idx], means[0], data[idx].Length);
used.Add(idx);

The first initial mean is selected at random from all data items. The initial means are existing data items and they are sometimes called medoids.

Next, a for loop is constructed to select the remaining k-1 means:

for (int k = 1; k < numClusters; ++k)
{
  double[] dSquared = new double[data.Length];
  int newMean = -1;
...

Array dSquared holds the squared distances between each data item and the closest existing initial mean. Variable newMean holds the index of a data item that will be the next initial mean. Next, each (normalized) data item is examined and its dSquared value is computed and stored:

for (int i = 0; i < data.Length; ++i)
{
  if (used.Contains(i) == true) continue;
  double[] distances = new double[k];
  for (int j = 0; j < k; ++j)
    distances[j] = Distance(data[i], means[j]);
  int m = MinIndex(distances);
  dSquared[i] = distances[m] * distances[m];
}

The check to determine if a data item has already been used as an initial mean isn’t strictly necessary, because if the item has been used, its closest mean is itself and the distance squared is 0. The array named distances holds the Euclidean distances from the current data item to each of the k initial means that have been selected so far.

Recall that Euclidean distance in the Distance method takes the square root of the sum of the squared differences between data item components. Because k-means++ uses squared distances, the squaring operation in InitMeans undoes the square root operation in Distance. Therefore, you could simplify the code by defining a method that returns squared distance directly.
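A squared-distance helper along those lines, plus the closest-mean computation that fills dSquared, could be sketched as follows (both names are hypothetical). Comparing squared distances gives the same closest mean as comparing distances, because squaring preserves the order of nonnegative values:

```csharp
using System;

public static class SquaredDistanceSketch
{
    // Squared Euclidean distance: same as Distance(a, b)^2, without the Sqrt call.
    public static double DistanceSquared(double[] a, double[] b)
    {
        double sum = 0.0;
        for (int j = 0; j < a.Length; ++j)
        {
            double diff = a[j] - b[j];
            sum += diff * diff;
        }
        return sum;
    }

    // Squared distance from item to the closest of the first k means,
    // i.e. the value stored in dSquared[i] inside InitMeans.
    public static double ClosestSquaredDistance(double[] item, double[][] means, int k)
    {
        double best = double.MaxValue;
        for (int j = 0; j < k; ++j)
            best = Math.Min(best, DistanceSquared(item, means[j]));
        return best;
    }
}
```

DistanceSquared for (3.0, 9.0, 5.0) and (2.0, 6.0, 1.0) returns 26.0, the value under the square root in the earlier Distance example.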

Next, the loop to scan through cumulative probabilities for roulette wheel selection is prepared:

double p = rnd.NextDouble();
double sum = 0.0;
for (int i = 0; i < dSquared.Length; ++i)
  sum += dSquared[i];
double cumulative = 0.0;
int ii = 0;
int sanity = 0;

A random value between 0.0 and 1.0 is generated and the sum of the squared distances is calculated as explained in the section describing proportional fitness selection. Instead of explicitly creating an array of cumulative probabilities, it’s more efficient to generate the current cumulative probability on the fly.

Each cumulative probability is computed and examined in a while loop that implements roulette wheel selection:

while (sanity < data.Length * 2)
{
  cumulative += dSquared[ii] / sum;
  if (cumulative >= p && used.Contains(ii) == false)
  {
    newMean = ii; // the chosen index
    used.Add(newMean); // don't pick again
    break;
  }
  ++ii; // next candidate
  if (ii >= dSquared.Length) ii = 0; // past the end
  ++sanity;
}

The while loop advances until the cumulative probability value is greater than or equal to the random p value. However, duplicate initial means can’t be allowed so if the selected mean is in the “used” List<int>, the next available data item is selected. If the ii index runs past the end of the data, it’s reset to 0. Note that if a data item has already been selected as an initial mean, the next available data item will probably not be the next most likely item.

Method InitMeans concludes by saving the selected initial mean, and returning the array of selected means:

...
    Array.Copy(data[newMean], means[k], data[newMean].Length);
  } // k, each remaining mean
  return means;
} // InitMeans

The purpose of the InitMeans method is to find k dissimilar data items to be used as initial means. Roulette wheel selection doesn’t guarantee that the selected means are maximally different from each other, and there’s a very small chance the selected means could be quite close to each other. Therefore, you may want to refactor method InitMeans so that roulette wheel selection is used several times to generate candidate sets of means, and then return the set containing means that are most different from each other.
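One way to implement that refactoring is sketched below. The article doesn’t say how to measure which set is “most different,” so the sketch assumes a maximin rule: score each candidate set by the squared distance between its two closest means, and keep the set with the largest score. Summing all pairwise distances would be a reasonable alternative. The candidate sets could come from calling InitMeans several times with different seed values; the class and method names here are hypothetical.

```csharp
using System;
using System.Collections.Generic;

public static class CandidateMeansSketch
{
    // Squared Euclidean distance between two means.
    static double DistanceSquared(double[] a, double[] b)
    {
        double sum = 0.0;
        for (int j = 0; j < a.Length; ++j)
            sum += (a[j] - b[j]) * (a[j] - b[j]);
        return sum;
    }

    // Spread score: smallest squared distance between any two means in the set.
    public static double MinPairwiseDistanceSquared(double[][] means)
    {
        double min = double.MaxValue;
        for (int a = 0; a < means.Length; ++a)
            for (int b = a + 1; b < means.Length; ++b)
                min = Math.Min(min, DistanceSquared(means[a], means[b]));
        return min;
    }

    // Returns the candidate set whose closest pair of means is farthest apart.
    public static double[][] MostSpreadOut(IList<double[][]> candidates)
    {
        if (candidates.Count == 0)
            throw new ArgumentException("No candidate sets supplied");

        double[][] best = candidates[0];
        double bestScore = MinPairwiseDistanceSquared(best);
        for (int i = 1; i < candidates.Count; ++i)
        {
            double score = MinPairwiseDistanceSquared(candidates[i]);
            if (score > bestScore) { bestScore = score; best = candidates[i]; }
        }
        return best;
    }
}
```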

Wrapping Up

This article is based on the 2007 research paper, “K-Means++: The Advantages of Careful Seeding,” by D. Arthur and S. Vassilvitskii. As is usually the case with research papers, virtually no implementation details are given. However, the paper carefully explains how k-means++ initialization works and establishes theoretical boundaries for its behavior.


Dr. James McCaffrey works for Microsoft Research in Redmond, Wash. He has worked on several Microsoft products including Internet Explorer and Bing. Dr. McCaffrey can be reached at jammc@microsoft.com.


Thanks to the following Microsoft Research technical expert for reviewing this article: Kirk Olynyk