October 2014

Volume 29 Number 10


Test Run: Probit Classification Using C#

James McCaffrey

Probit (“probability unit”) classification is a machine learning (ML) technique that can be used to make predictions in situations where the dependent variable to predict is binary, meaning it can take one of just two possible values. Probit classification is also called probit regression and probit modeling.

Probit classification is quite similar to logistic regression (LR) classification. The two techniques apply to the same types of problems and tend to give similar results, and the choice of using probit or LR classification usually depends on the discipline in which you’re working. Probit is often used in economics and finance, while LR is more common in other fields.

To get an understanding of what probit classification is, take a look at the demo program in Figure 1.

Figure 1 Probit Classification in Action

The demo uses probit classification to create a model that predicts whether a hospital patient will die, based on age, sex and the results of a kidney test. The data is completely artificial. The first raw data item is:

48.00   +1.00   4.40   0

The raw data consists of 30 items. Sex is encoded as -1 for male and +1 for female. The variable to predict, Died, is in the last column and is encoded as 0 = false (the person survived) and 1 = true (the person died). So, the first data item indicates a 48-year-old female with a kidney score of 4.40 who survived. The demo begins by normalizing the age and kidney data so that all values have roughly the same magnitude. The first data item, after normalization, is:

-0.74   +1.00   -0.61   0.00

Normalized values less than 0.0 (here, both age and kidney score) are below average, and values greater than 0.0 are above average.
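Gaussian (z-score) normalization replaces each value with (value - column mean) / (column standard deviation). For example, if the mean of the age column were 57.5 with a standard deviation of 12.8 (hypothetical values used only for illustration), the first item's age would become (48.00 - 57.5) / 12.8 ≈ -0.74.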

The source data is then randomly split into a 24-item training set to create the model, and a six-item test set to estimate the accuracy of the model when applied to new data with unknown results.

The demo program then creates a probit model. Behind the scenes, training is performed using a technique called simplex optimization, with the maximum number of iterations set to 100. After training, the weights that define the model are displayed: { -4.26, 2.30, -1.29, 3.45 }.

The first weight value, -4.26, is a general constant and doesn’t apply to any one specific predictor variable. The second weight, 2.30, applies to age; the third weight, -1.29, applies to sex; and the fourth weight, 3.45, applies to kidney score. Positive weights, such as the ones associated with age and kidney score, mean larger values of the predictor indicate the dependent variable, Died, will be larger—that is, closer to true.

The demo computes the accuracy of the model on the 24-item training set (100 percent correct) and on the test set (83.33 percent, or five correct and one wrong). The more significant of these two values is the test accuracy. It’s a rough estimate of the overall accuracy of the probit model.

This article assumes you have at least intermediate programming skills and a basic understanding of ML classification, but doesn’t assume you know anything about probit classification. The demo program is coded using C#, but you should be able to refactor the demo to other .NET languages without too much trouble. The demo code is too long to present in its entirety, but all the code is available in the download that accompanies this article at msdn.microsoft.com/magazine/msdnmag1014. All normal error checking has been removed to keep the main ideas clear.

Understanding Probit Classification

A simple way to predict death from age, sex and kidney score would be to form a linear combination along the lines of:

died = b0 + (b1)(age) + (b2)(sex) + (b3)(kidney)

where b0, b1, b2 and b3 are weights that must somehow be determined so that the computed output values on the training data closely match the known dependent variable values. Logistic regression extends this idea with a more complicated prediction function:

z = b0 + (b1)(age) + (b2)(sex) + (b3)(kidney)
died = 1.0 / (1.0 + e^(-z))

The math is very deep, but the prediction function, called the logistic sigmoid function, conveniently always returns a value between 0.0 and 1.0, which can be interpreted as a probability. Probit classification uses a different prediction function:

z = b0 + (b1)(age) + (b2)(sex) + (b3)(kidney)
died = Phi(z)

The Phi(z) function is called the standard normal cumulative density function (which is usually abbreviated CDF) and it always returns a value between 0.0 and 1.0. The CDF is tricky because there’s no simple equation for it. The CDF for a value z is the area under the famous bell-shaped curve function (the Gaussian function) from negative infinity to z.

This sounds a lot more complicated than it really is. Take a look at the graph in Figure 2. The graph shows the logistic sigmoid function and the CDF function plotted side by side. The important point is that for any value z, even though the underlying functions are very different, both functions return a value between 0.0 and 1.0 that can be interpreted as a probability.
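For example, at z = 0.0 both functions return exactly 0.5, and at z = 1.0 the logistic sigmoid returns about 0.73 while the CDF returns about 0.84; the CDF rises a bit more steeply near zero, but both curves have the same overall S shape.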

Figure 2 The Graph of the Cumulative Density Function

So, from a developer’s point of view, the first challenge is to write a function that computes the CDF for a value z. There’s no simple equation to compute CDF, but there are dozens of exotic-looking approximations. One of the most common ways to approximate the CDF function is to compute something called the erf function (short for Error Function) using an equation called A&S 7.1.26, and then use erf to determine CDF. Code for the CDF function is presented in Figure 3.

Figure 3 The CDF Function in C#

static double CumDensity(double z)
{
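  // Constants for the Abramowitz and Stegun equation 7.1.26 approximation of erf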
  double p = 0.3275911;
  double a1 = 0.254829592;
  double a2 = -0.284496736;
  double a3 = 1.421413741;
  double a4 = -1.453152027;
  double a5 = 1.061405429;
  int sign;
  if (z < 0.0)
    sign = -1;
  else
    sign = 1;
  double x = Math.Abs(z) / Math.Sqrt(2.0);
  double t = 1.0 / (1.0 + p * x);
  double erf = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) *
    t * Math.Exp(-x * x);
  return 0.5 * (1.0 + (sign * erf));
}
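As a quick sanity check, CumDensity(0.0) should return approximately 0.5000 and CumDensity(1.0) approximately 0.8413, matching standard normal CDF table values. A throwaway test harness (not part of the demo download) might look like:

static void TestCumDensity()
{
  // Print the approximate CDF for a few z values and compare against
  // standard normal tables (z = 0.0 -> 0.5000, z = 1.0 -> 0.8413)
  double[] zValues = new double[] { -2.0, -1.0, 0.0, 1.0, 2.0 };
  foreach (double z in zValues)
    Console.WriteLine("z = " + z.ToString("F1") +
      "  CDF = " + CumDensity(z).ToString("F4"));
}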

The second challenge when writing probit classification code is to determine the values for the weights so when presented with training data, the computed output values closely match the known output values. Another way of looking at the problem is that the goal is to minimize the error between computed and known output values. This is called training the model using numerical optimization.
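To make the notion of error concrete, one common measure, and a plausible basis for the demo's Error method, is the mean squared difference between each computed probit output and the corresponding known 0-or-1 Died value. The following is only a sketch of that idea, assuming the CumDensity method from Figure 3:

static double MeanSquaredError(double[][] trainData, double[] weights)
{
  // weights[0] is the constant b0; weights[1..] pair with the predictor columns
  int numFeatures = weights.Length - 1;
  double sumSquaredError = 0.0;
  for (int i = 0; i < trainData.Length; ++i)
  {
    double z = weights[0];
    for (int j = 0; j < numFeatures; ++j)
      z += weights[j + 1] * trainData[i][j];
    double computed = CumDensity(z);            // Between 0.0 and 1.0
    double desired = trainData[i][numFeatures]; // Died is the last column (0 or 1)
    sumSquaredError += (computed - desired) * (computed - desired);
  }
  return sumSquaredError / trainData.Length;
}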

There’s no easy way to train most ML classifiers, including probit classifiers. There are roughly a dozen major techniques you can use, and each technique has dozens of variations. Common training techniques include simple gradient descent, back-propagation, Newton-Raphson, particle swarm optimization, evolutionary optimization and L-BFGS. The demo program uses one of the oldest and simplest training techniques—simplex optimization.

Understanding Simplex Optimization

Loosely speaking, a simplex is a triangle. The idea behind simplex optimization is to start with three possible solutions (hence, “simplex”). One solution will be “best” (have the smallest error), one will be “worst” (largest error), and the third is called “other.” Next, simplex optimization creates three new potential solutions called “expanded,” “reflected,” and “contracted.” Each of these is compared against the current worst solution, and if any of the new candidates is better (smaller error), the worst solution is replaced.

Simplex optimization is illustrated in Figure 4. In a simple case where a solution consists of two values, such as (1.23, 4.56), you can think of a solution as a point on the (x, y) plane. The left side of Figure 4 shows how three new candidate solutions are generated from the current best, worst and “other” solutions.

Figure 4 Simplex Optimization

First, a centroid is computed. The centroid is the average of the best and other solutions. In two dimensions, this is a point halfway between the best and other points. Next, an imaginary line is created, which starts at the worst point and extends through the centroid. The contracted candidate is between the worst and centroid points. The reflected candidate is on the imaginary line, past the centroid. And the expanded candidate is past the reflected point.
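When a solution is represented as an array of weight values, each candidate is easy to generate. The following sketch uses typical textbook step sizes (1.0 for reflected, 2.0 for expanded, and 0.5 back toward worst for contracted); the constants used by the demo program may differ, and these standalone helpers are for illustration only (the demo implements similar logic as private methods of the classifier, as shown in Figure 7).

static double[] Centroid(double[] best, double[] other)
{
  double[] result = new double[best.Length];
  for (int i = 0; i < best.Length; ++i)
    result[i] = (best[i] + other[i]) / 2.0; // Midpoint of best and other
  return result;
}

static double[] Reflected(double[] centroid, double[] worst)
{
  double[] result = new double[centroid.Length];
  for (int i = 0; i < centroid.Length; ++i)
    result[i] = centroid[i] + 1.0 * (centroid[i] - worst[i]); // Just past the centroid
  return result;
}

static double[] Expanded(double[] centroid, double[] worst)
{
  double[] result = new double[centroid.Length];
  for (int i = 0; i < centroid.Length; ++i)
    result[i] = centroid[i] + 2.0 * (centroid[i] - worst[i]); // Farther past the centroid
  return result;
}

static double[] Contracted(double[] centroid, double[] worst)
{
  double[] result = new double[centroid.Length];
  for (int i = 0; i < centroid.Length; ++i)
    result[i] = centroid[i] - 0.5 * (centroid[i] - worst[i]); // Between worst and centroid
  return result;
}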

In each iteration of simplex optimization, if one of the expanded, reflected or contracted candidates is better than the current worst solution, worst is replaced by that candidate. If none of the three candidates generated are better than the worst solution, the current worst and other solutions are moved toward the best solution to points somewhere between their current position and the best solution, as shown in the right-hand side of Figure 4.
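The shrink step can be sketched as moving each component of a solution partway toward the corresponding component of the best solution. Halving the distance is a common choice; the demo's exact shrink fraction may differ:

static double[] ShrinkToward(double[] solution, double[] best)
{
  // Move the solution halfway toward the best solution, component by component
  double[] result = new double[solution.Length];
  for (int i = 0; i < solution.Length; ++i)
    result[i] = (solution[i] + best[i]) / 2.0;
  return result;
}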

After each iteration, a new virtual “best-other-worst” triangle is formed, getting closer and closer to an optimal solution. If you take a snapshot of each triangle and view the snapshots in sequence, the shifting triangles look like a pointy blob moving across the plane, much like a single-celled amoeba. For this reason, simplex optimization is sometimes called amoeba method optimization.

There are many variations of simplex optimization, which differ in how far the contracted, reflected, and expanded candidate solutions are from the current centroid, and the order in which the candidate solutions are checked to see if they’re better than the current worst solution. The most common form of simplex optimization is called the Nelder-Mead algorithm. The demo program uses a simpler variation that doesn’t have a specific name.

For probit classification, each potential solution is a set of weight values. Figure 5 shows, in pseudocode, the variation of simplex optimization used in the demo program.

Figure 5 Pseudocode for the Simplex Optimization Used in the Demo Program

randomly initialize best, worst, and other solutions
loop maxEpochs times
  create centroid from worst and other
  create expanded
  if expanded is better than worst, replace worst with expanded,
    continue loop
  create reflected
  if reflected is better than worst, replace worst with reflected,
    continue loop
  create contracted
  if contracted is better than worst, replace worst with contracted,
    continue loop
  create a random solution
  if random solution is better than worst, replace worst,
    continue loop
  shrink worst and other toward best
end loop
return best solution found

Simplex optimization, like all other ML optimization algorithms, has pros and cons. However, it’s (relatively) simple to implement and usually—though not always—works well in practice.

The Demo Program

To create the demo program, I launched Visual Studio and selected the C# console application program template and named it ProbitClassification. The demo has no significant Microsoft .NET Framework version dependencies, so any relatively recent version of Visual Studio should work. After the template code loaded, in the Solution Explorer window I renamed file Program.cs to ProbitProgram.cs and Visual Studio automatically renamed class Program.

Figure 6 Beginning of the Demo Code

using System;
namespace ProbitClassification
{
  class ProbitProgram
  {
    static void Main(string[] args)
    {
      Console.WriteLine("\nBegin Probit Binary Classification demo");
      Console.WriteLine("Goal is to predict death (0 = false, 1 = true)");
      double[][] data = new double[30][];
      data[0] = new double[] { 48, +1, 4.40, 0 };
      data[1] = new double[] { 60, -1, 7.89, 1 };
      // Etc.
      data[29] = new double[] { 68, -1, 8.38, 1 };
...

The beginning of the demo code is shown in Figure 6. The dummy data is hardcoded into the program. In a non-demo scenario, your data would be stored in a text file and you’d have to write a utility method to load the data into memory. Next, the source data is displayed using program-defined helper method ShowData:

Console.WriteLine("\nRaw data: \n");
Console.WriteLine("       Age       Sex      Kidney   Died");
Console.WriteLine("=======================================");
ShowData(data, 5, 2, true);

Next, columns 0 and 2 of the source data are normalized:

Console.WriteLine("Normalizing age and kidney data");
int[] columns = new int[] { 0, 2 };
double[][] means = Normalize(data, columns); // Normalize, save means & stdDevs
Console.WriteLine("Done");
Console.WriteLine("\nNormalized data: \n");
ShowData(data, 5, 2, true);
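A minimal sketch of a Gaussian normalization method along these lines is shown below. It assumes the returned matrix stores the column means in row 0 and the column standard deviations in row 1; the actual Normalize method in the demo may organize its return value differently:

static double[][] Normalize(double[][] data, int[] columns)
{
  int numCols = data[0].Length;
  double[][] result = new double[2][];
  result[0] = new double[numCols]; // Means
  result[1] = new double[numCols]; // Standard deviations
  foreach (int c in columns)
  {
    double sum = 0.0;
    for (int i = 0; i < data.Length; ++i)
      sum += data[i][c];
    double mean = sum / data.Length;
    double sumSquares = 0.0;
    for (int i = 0; i < data.Length; ++i)
      sumSquares += (data[i][c] - mean) * (data[i][c] - mean);
    double stdDev = Math.Sqrt(sumSquares / data.Length); // Assumes stdDev != 0
    for (int i = 0; i < data.Length; ++i)
      data[i][c] = (data[i][c] - mean) / stdDev;
    result[0][c] = mean;
    result[1][c] = stdDev;
  }
  return result;
}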

The Normalize method saves and returns the means and standard deviations of all columns so that when new data is encountered, it can be normalized using the same parameters used to train the model. Next, the normalized data is split into a training set (80 percent) and test set (20 percent):

Console.WriteLine("Creating train (80%) and test (20%) matrices");
double[][] trainData;
double[][] testData;
MakeTrainTest(data, 0, out trainData, out testData);
Console.WriteLine("Done");
Console.WriteLine("\nNormalized training data: \n");
ShowData(trainData, 3, 2, true);
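You might want to parameterize method MakeTrainTest to accept the percentage of items to place in the training set. A hypothetical variation along those lines (not the demo's actual signature) could be:

static void MakeTrainTest(double[][] allData, double trainPct, int seed,
  out double[][] trainData, out double[][] testData)
{
  // Shuffle the row order (Fisher-Yates), then split by percentage
  Random rnd = new Random(seed);
  int totRows = allData.Length;
  int numTrainRows = (int)(totRows * trainPct); // For example, 0.80
  int numTestRows = totRows - numTrainRows;
  int[] sequence = new int[totRows];
  for (int i = 0; i < sequence.Length; ++i)
    sequence[i] = i;
  for (int i = 0; i < sequence.Length; ++i)
  {
    int r = rnd.Next(i, sequence.Length);
    int tmp = sequence[r];
    sequence[r] = sequence[i];
    sequence[i] = tmp;
  }
  trainData = new double[numTrainRows][];
  testData = new double[numTestRows][];
  for (int i = 0; i < numTrainRows; ++i)
    trainData[i] = allData[sequence[i]]; // References, not copies
  for (int i = 0; i < numTestRows; ++i)
    testData[i] = allData[sequence[numTrainRows + i]];
}

Calling MakeTrainTest(data, 0.80, 0, out trainData, out testData) would then reproduce the demo's 80-20 split.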

Next, a program-defined probit classifier object is instantiated:

int numFeatures = 3; // Age, sex, kidney
Console.WriteLine("Creating probit binary classifier");
ProbitClassifier pc = new ProbitClassifier(numFeatures);

And then the probit classifier is trained, using simplex optimization to find values for the weights so that computed output values closely match the known output values:

int maxEpochs = 100; // 100 gives a representative demo
Console.WriteLine("Setting maxEpochs = " + maxEpochs);
Console.WriteLine("Starting training");
double[] bestWeights = pc.Train(trainData, maxEpochs, 0);
Console.WriteLine("Training complete");
Console.WriteLine("\nBest weights found:");
ShowVector(bestWeights, 4, true);

The demo program concludes by computing the classification accuracy of the model on the training data and on the test data:

...
  double testAccuracy = pc.Accuracy(testData, bestWeights);
  Console.WriteLine("Prediction accuracy on test data = 
    " + testAccuracy.ToString("F4"));
  Console.WriteLine("\nEnd probit binary classification demo\n");
  Console.ReadLine();
} // Main

The demo doesn’t make a prediction for previously unseen data. Making a prediction could look like:

// Slightly older, male, higher kidney score
double[] unknownNormalized = new double[] { 0.25, -1.0, 0.50 };
int died = pc.ComputeDependent(unknownNormalized, bestWeights);
if (died == 0)
  Console.WriteLine("Predict survive");
else if (died == 1)
  Console.WriteLine("Predict die");

This code assumes that the independent x variables—age, sex and kidney score—have been normalized using the means and standard deviations from the training data normalization process.
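For example, suppose the raw inputs are age 62.0, male, and kidney score 6.00 (hypothetical values). Assuming the matrix returned by Normalize stores the column means in row 0 and the standard deviations in row 1, the normalization might look like:

double[] unknownRaw = new double[] { 62.0, -1.0, 6.00 }; // Age, sex, kidney (raw)
double[] unknownNormalized = new double[3];
unknownNormalized[0] = (unknownRaw[0] - means[0][0]) / means[1][0]; // Age
unknownNormalized[1] = unknownRaw[1];                               // Sex is already -1 or +1
unknownNormalized[2] = (unknownRaw[2] - means[0][2]) / means[1][2]; // Kidney

The normalized values can then be passed to ComputeDependent as shown above.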

The ProbitClassifier Class

The overall structure of the ProbitClassifier class is presented in Figure 7. The ProbitClassifier definition contains a nested class named Solution. That nested class implements the IComparable&lt;Solution&gt; interface so that an array of three Solution objects can be automatically sorted to give the best, other and worst solutions. Normally I don’t like fancy coding techniques, but in this situation the benefit slightly outweighs the added complexity.

Figure 7 The ProbitClassifier Class

public class ProbitClassifier
{
  private int numFeatures; // Number of independent variables
  private double[] weights; // b0 = constant
  private Random rnd;
  public ProbitClassifier(int numFeatures) { . . }
  public double[] Train(double[][] trainData, int maxEpochs, int seed) { . . }
  private double[] Expanded(double[] centroid, double[] worst) { . . }
  private double[] Contracted(double[] centroid, double[] worst) { . . }
  private double[] RandomSolution() { . . }
  private double Error(double[][] trainData, double[] weights) { . . }
  public void SetWeights(double[] weights) { . . }
  public double[] GetWeights() { . . }
  public double ComputeOutput(double[] dataItem, double[] weights) { . . }
  private static double CumDensity(double z) { . . }
  public int ComputeDependent(double[] dataItem, double[] weights) { . . }
  public double Accuracy(double[][] trainData, double[] weights) { . . }
  private class Solution : IComparable<Solution>
  {
    // Defined here
  }
}

The ProbitClassifier has two output methods. Method ComputeOutput returns a value between 0.0 and 1.0 and is used during training to compute an error value. Method ComputeDependent is a wrapper around ComputeOutput and returns 0 if the output is less than or equal to 0.5, or 1 if the output is greater than 0.5. These return values are used to compute accuracy.
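A sketch of how these two methods might be implemented, consistent with the prediction function described earlier (the demo's actual code may differ in minor details), is:

public double ComputeOutput(double[] dataItem, double[] weights)
{
  // z = b0 + (b1)(x1) + (b2)(x2) + (b3)(x3), then the standard normal CDF of z
  double z = weights[0]; // The constant b0
  for (int i = 0; i < numFeatures; ++i)
    z += weights[i + 1] * dataItem[i];
  return CumDensity(z); // Between 0.0 and 1.0
}

public int ComputeDependent(double[] dataItem, double[] weights)
{
  double p = ComputeOutput(dataItem, weights);
  if (p <= 0.5)
    return 0;
  else
    return 1;
}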

Wrapping Up

Probit classification is one of the oldest ML techniques. Because probit classification is so similar to logistic regression classification, common wisdom is to use either one technique or the other. Because LR is slightly easier to implement than probit, probit classification is used less often, and over time has become somewhat of a second-class ML citizen. However, probit classification is often very effective and can be a valuable addition to your ML toolkit.


Dr. James McCaffrey works for Microsoft Research in Redmond, Wash. He has worked on several Microsoft products including Internet Explorer and Bing. Dr. McCaffrey can be reached at jammc@microsoft.com.

Thanks to the following Microsoft Research technical experts for reviewing this article: Nathan Brown and Kirk Olynyk.