May 2016

Volume 31 Number 5

[Test Run]

The Multi-Armed Bandit Problem

By James McCaffrey

Imagine you’re in Las Vegas, standing in front of three slot machines. You have 20 tokens to use: each time, you drop a token into any one of the three machines, pull the handle and are paid a random amount. The machines pay out differently, but initially you have no knowledge of the payout schedules they follow. What strategies can you use to try to maximize your gain?

This is an example of what’s called the multi-armed bandit problem, so named because a slot machine is informally called a one-armed bandit. The problem is not as whimsical as it might first seem. There are many important real-life problems, such as clinical drug trials, that are similar to the slot machine example.

It’s unlikely you’ll ever need to code an implementation of the multi-armed bandit problem in most enterprise development scenarios. But you might want to read this article for three reasons. First, several of the programming techniques used in this article can be used in other, more common programming scenarios. Second, a concrete code implementation of the multi-armed bandit problem can serve as a good introduction to an active area of economics and machine learning research. And third, you just might find the topic interesting for its own sake.

The best way to get a feel for where this article is headed is to take a look at the demo program shown in Figure 1. There are many different algorithms that can be used on multi-armed bandit problems. For example, a completely random approach would be to just select a machine at random for each pull, then hope for the best. The demo presented here uses a basic technique called the explore-exploit algorithm.
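For comparison, the completely random strategy just described can be sketched in a few lines. This is a minimal sketch, not part of the demo: to keep it self-contained it assumes each machine is modeled as a hypothetical Func<double> that returns one payout per call (rather than the demo's Machine class), and the names RandomBaseline and PlayRandom are illustrative.

```csharp
using System;

public static class RandomBaseline
{
    // Plays nPulls times, choosing a machine uniformly at random on every pull.
    // Each machine is modeled as a Func<double> returning one payout (an assumption
    // of this sketch). Returns the average payout per pull.
    public static double PlayRandom(Func<double>[] machines, int nPulls, int seed)
    {
        Random r = new Random(seed);
        double totPay = 0.0;
        for (int pull = 0; pull < nPulls; ++pull)
        {
            int m = r.Next(0, machines.Length); // every machine equally likely
            totPay += machines[m]();            // play it and accumulate the payout
        }
        return totPay / nPulls;
    }
}
```

Because this strategy never learns anything, its expected average payout is simply the mean of the machines' mean payouts, which makes it a useful floor when comparing smarter algorithms.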

Figure 1 Using Explore-Exploit on a Multi-Armed Bandit Problem

The demo begins by creating three machines. Each machine pays a random amount on each pull, where the payout follows a Gaussian (bell-shaped) distribution with a specified mean (average) and standard deviation. The third machine is the best in one sense because it has the highest mean payout per pull of 0.1 arbitrary units. In a non-demo scenario, you wouldn’t know the characteristics of the machines.

The total number of pulls available is set to 20. In explore-exploit, you set aside a certain proportion of your allotted pulls and use them to try to find the best machine. Then you use your remaining pulls only on the best machine found during the preliminary explore phase. The key variable for the explore-exploit algorithm is the percentage of pulls you designate for the explore phase. The demo sets the explore percentage to 0.40, so there are 20 * 0.40 = 8 explore pulls followed by 20 - 8 = 12 exploit pulls. Increasing the percentage of explore pulls increases the probability of finding the best machine, at the expense of having fewer pulls left to take advantage of the best machine in the exploit phase.

During the eight-pull explore phase, the demo displays which machine was randomly selected and the associated payout. Behind the scenes, the demo saves the accumulated payouts of each machine. Machine 0 was selected three times and paid -0.09 + 0.12 + 0.29 = +0.32 units. Machine 1 was selected two times and paid -0.46 + -1.91 = -2.37 units. Machine 2 was selected three times and paid 0.19 + 0.14 + 0.70 = +1.03 units. In the case of the demo, the explore-exploit algorithm correctly identifies machine 2 as the best machine because it has the largest total payout. At this point the algorithm has a net gain (loss) of 0.32 + -2.37 + 1.03 = -1.02 units.
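The explore-phase bookkeeping can be checked by hand. The following sketch hard-codes the eight payouts reported above, grouped by machine, and applies the same selection rule the demo uses: the machine with the largest total payout wins. The class and method names are illustrative only.

```csharp
using System;

public static class ExploreTally
{
    // Payouts observed during the demo's eight explore pulls, grouped by machine.
    public static readonly double[][] ExplorePays =
    {
        new double[] { -0.09, 0.12, 0.29 }, // machine 0
        new double[] { -0.46, -1.91 },      // machine 1
        new double[] { 0.19, 0.14, 0.70 }   // machine 2
    };

    // Total payout for each machine.
    public static double[] Totals(double[][] pays)
    {
        double[] totals = new double[pays.Length];
        for (int m = 0; m < pays.Length; ++m)
            foreach (double p in pays[m])
                totals[m] += p;
        return totals;
    }

    // Index of the largest value (the same rule as the demo's BestIdx).
    public static int BestIdx(double[] values)
    {
        int result = 0;
        for (int i = 1; i < values.Length; ++i)
            if (values[i] > values[result]) result = i;
        return result;
    }
}
```

Calling Totals on ExplorePays gives approximately 0.32, -2.37 and 1.03, BestIdx picks index 2, and the three totals sum to about -1.02, matching the figures above up to floating-point rounding.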

During the 12-pull exploit phase, the demo repeatedly plays only machine 2. The exploit phase payouts are 0.03 + 0.33 + . . . + 0.45 = +2.32 units. Therefore, the total payout over all 20 pulls is -1.02 + 2.32 = +1.30 units and the average payout per pull is 1.30 / 20 = 0.065.

There are several different metrics that can be used to evaluate the effectiveness of a multi-armed bandit algorithm. One common measure is called regret. Regret is the difference between a theoretical baseline total payout and the actual total payout for an algorithm. The baseline theoretical payout is the expected payout if all allotted pulls were used on the best machine. In the case of the three machines in the demo, the best machine has an average payout of 0.10 units, so the expected payout if all 20 pulls were used on that machine is 20 * 0.10 = 2.00 units. Because the explore-exploit algorithm yielded a total payout of only 1.30 units, the regret metric is 2.00 - 1.30 = 0.70 units. Algorithms with lower regret values are better than those with higher regret values.
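The regret calculation is simple enough to express as two small helpers. This sketch uses hypothetical names and just restates the definition: baseline total minus actual total.

```csharp
using System;

public static class RegretCalc
{
    // Baseline: expected total payout if every pull went to the best machine.
    public static double Baseline(double bestMean, int nPulls)
    {
        return bestMean * nPulls;
    }

    // Regret: baseline total payout minus the total payout the algorithm actually earned.
    public static double Regret(double bestMean, int nPulls, double actualTotal)
    {
        return Baseline(bestMean, nPulls) - actualTotal;
    }
}
```

With the demo's figures, the actual total is -1.02 + 2.32 = 1.30, so RegretCalc.Regret(0.10, 20, 1.30) returns approximately 0.70.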

This article assumes you have at least intermediate programming skills, but doesn’t assume you know anything about the multi-armed bandit problem. The complete demo program, with a few minor edits to save space, is presented in Figure 2, and it’s also available in the associated code download. The demo is coded using C#, but you shouldn’t have too much trouble refactoring the demo to another language, such as Python or Java. All normal error checking was removed from the demo in order to keep the main ideas of the multi-armed bandit problem as clear as possible.

Figure 2 Complete Multi-Armed Bandit Demo Code

using System;

namespace MultiBandit
{
  class MultiBanditProgram
  {
    static void Main(string[] args)
    {
      Console.WriteLine("\nBegin multi-armed bandit demo \n");
      Console.WriteLine("Creating 3 Gaussian machines");
      Console.WriteLine("Machine 0 mean =  0.0, sd = 1.0");
      Console.WriteLine("Machine 1 mean = -0.5, sd = 2.0");
      Console.WriteLine("Machine 2 mean =  0.1, sd = 0.5");
      Console.WriteLine("Best machine is [2] mean pay = 0.1");
      int nMachines = 3;
      Machine[] machines = new Machine[nMachines];
      machines[0] = new Machine(0.0, 1.0, 0);
      machines[1] = new Machine(-0.5, 2.0, 1);
      machines[2] = new Machine(0.1, 0.5, 2);
      int nPulls = 20;
      double pctExplore = 0.40;
      Console.WriteLine("Setting nPulls = " + nPulls);
      Console.WriteLine("\nUsing pctExplore = " +
        pctExplore.ToString("F2"));
      double avgPay = ExploreExploit(machines, pctExplore,
        nPulls);
      double totPay = avgPay * nPulls;
      Console.WriteLine("\nAverage pay per pull = " +
        avgPay.ToString("F2"));
      Console.WriteLine("Total payout         = " +
        totPay.ToString("F2"));
      double avgBase = machines[2].mean;
      double totBase = avgBase * nPulls;
      Console.WriteLine("\nBaseline average pay = " +
        avgBase.ToString("F2"));
      Console.WriteLine("Total baseline pay   = " +
        totBase.ToString("F2"));
      double regret = totBase - totPay;
      Console.WriteLine("\nTotal regret = " +
        regret.ToString("F2"));
      Console.WriteLine("\nEnd bandit demo \n");
      Console.ReadLine();
    } // Main

    static double ExploreExploit(Machine[] machines,
      double pctExplore, int nPulls)
    {
      // Use basic explore-exploit algorithm
      // Return the average pay per pull
      int nMachines = machines.Length;
      Random r = new Random(2); // which machine
      double[] explorePays = new double[nMachines];
      double totPay = 0.0;
      int nExplore = (int)(nPulls * pctExplore);
      int nExploit = nPulls - nExplore;
      Console.WriteLine("\nStart explore phase");
      for (int pull = 0; pull < nExplore; ++pull)
      {
        int m = r.Next(0, nMachines); // pick a machine
        double pay = machines[m].Pay(); // play
        Console.Write("[" + pull.ToString().PadLeft(3) + "]  ");
        Console.WriteLine("selected machine " + m + ". pay = " +
          pay.ToString("F2").PadLeft(6));
        explorePays[m] += pay; // update
        totPay += pay;
      } // Explore
      int bestMach = BestIdx(explorePays);
      Console.WriteLine("\nBest machine found = " + bestMach);
      Console.WriteLine("\nStart exploit phase");
      for (int pull = 0; pull < nExploit; ++pull)
      {
        double pay = machines[bestMach].Pay();
        Console.Write("[" + pull.ToString().PadLeft(3) + "] ");
        Console.WriteLine("pay = " +
          pay.ToString("F2").PadLeft(6));
        totPay += pay; // accumulate
      } // Exploit
      return totPay / nPulls; // avg payout per pull
    } // ExploreExploit

    static int BestIdx(double[] pays)
    {
      // Index of array with largest value
      int result = 0;
      double maxVal = pays[0];
      for (int i = 0; i < pays.Length; ++i)
      {
        if (pays[i] > maxVal)
        {
          result = i;
          maxVal = pays[i];
        }
      }
      return result;
    }

  } // Program class

  public class Machine
  {
    public double mean; // Avg payout per pull
    public double sd; // Variability about the mean
    private Gaussian g; // Payout generator

    public Machine(double mean, double sd, int seed)
    {
      this.mean = mean;
      this.sd = sd;
      this.g = new Gaussian(mean, sd, seed);
    }

    public double Pay()
    {
      return this.g.Next();
    }

    // -----

    private class Gaussian
    {
      private Random r;
      private double mean;
      private double sd;

      public Gaussian(double mean, double sd, int seed)
      {
        this.r = new Random(seed);
        this.mean = mean;
        this.sd = sd;
      }

      public double Next()
      {
        double u1 = r.NextDouble();
        double u2 = r.NextDouble();
        while (u2 == 0.0) // NextDouble can return 0.0; avoid Math.Log(0)
          u2 = r.NextDouble();
        double left = Math.Cos(2.0 * Math.PI * u1);
        double right = Math.Sqrt(-2.0 * Math.Log(u2));
        double z = left * right;
        return this.mean + (z * this.sd);
      }
    }

    // -----

  } // Machine
} // ns

The Demo Program

To create the demo program, I launched Visual Studio and selected the C# console application template. I named the project MultiBandit. I used Visual Studio 2015, but the demo has no significant .NET version dependencies so any version of Visual Studio will work.

After the template code loaded, in the Solution Explorer window I right-clicked on file Program.cs and renamed it to the more descriptive MultiBanditProgram.cs, then allowed Visual Studio to automatically rename class Program to MultiBanditProgram. At the top of the code in the editor window, I deleted all unnecessary using statements, leaving just the one reference to the top-level System namespace.

All the control logic is in the Main method, which calls method ExploreExploit. The demo has a program-defined Machine class, which in turn has a program-defined nested class named Gaussian.

After some introductory WriteLine statements, the demo creates three machines:

int nMachines = 3;
Machine[] machines = new Machine[nMachines];
machines[0] = new Machine(0.0, 1.0, 0);
machines[1] = new Machine(-0.5, 2.0, 1);
machines[2] = new Machine(0.1, 0.5, 2);

The Machine class constructor accepts three arguments: the mean payout, the standard deviation of the payouts and a seed for random number generation. So, machine [1] will pay out -0.5 units per pull on average, where most of the payouts (roughly 68 percent) will be between -0.5 - 2.0 = -2.5 units and -0.5 + 2.0 = +1.5 units. Notice that unlike real slot machines, which pay out either zero or a positive amount, the demo machines can pay out a negative amount.
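The 68 percent figure can be checked empirically. This sketch uses the same Box-Muller variation as the demo's Gaussian class, rewritten as a standalone helper with a guard against a zero uniform value, draws many payouts with mean -0.5 and standard deviation 2.0, and counts how many land between -2.5 and +1.5. The names, sample size and seed are arbitrary choices for illustration.

```csharp
using System;

public static class OneSdCheck
{
    // Same Box-Muller variation as the demo's Gaussian.Next(), as a standalone helper.
    public static double NextGaussian(Random r, double mean, double sd)
    {
        double u1 = r.NextDouble();
        double u2 = r.NextDouble();
        while (u2 == 0.0) u2 = r.NextDouble(); // avoid Math.Log(0)
        double z = Math.Cos(2.0 * Math.PI * u1) * Math.Sqrt(-2.0 * Math.Log(u2));
        return mean + z * sd;
    }

    // Fraction of n draws that fall within one standard deviation of the mean.
    public static double FractionWithinOneSd(double mean, double sd, int n, int seed)
    {
        Random r = new Random(seed);
        int inside = 0;
        for (int i = 0; i < n; ++i)
        {
            double x = NextGaussian(r, mean, sd);
            if (x >= mean - sd && x <= mean + sd) ++inside;
        }
        return (double)inside / n;
    }
}
```

With a large sample, for example OneSdCheck.FractionWithinOneSd(-0.5, 2.0, 100000, 1), the fraction should come out close to 0.68.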

The statements that perform the explore-exploit algorithm on the three machines are:

int nPulls = 20;
double pctExplore = 0.40;
double avgPay = ExploreExploit(machines, pctExplore, nPulls);
double totPay = avgPay * nPulls;

Method ExploreExploit returns the average gain (or loss if negative) per pull after nPulls random events. Therefore, the total pay from the session is the number of pulls times the average pay per pull. An alternative design is for ExploreExploit to return the total pay instead of the average pay.

The regret is calculated like so:

double avgBase = machines[2].mean;
double totBase = avgBase * nPulls;
double regret = totBase - totPay;

Variable avgBase is the average payout per pull of the best machine, machine [2], which is 0.1 units. So the expected total payout over 20 pulls is 20 * 0.10 = 2.0 units.

Generating Gaussian Random Values

As I mentioned, each machine in the demo program pays out a value that follows a Gaussian (also called normal, or bell-shaped) distribution. For example, machine [0] has a mean payout of 0.0 with a standard deviation of 1.0 units. Using the code from the demo that generates Gaussian values, I wrote a short program to produce 100 random payouts from machine [0]. The results are shown in the graph in Figure 3.

Figure 3 100 Random Gaussian Values

Notice that the majority of generated values are close to the mean. The variability of the generated values is controlled by the value of the standard deviation. A larger standard deviation produces a larger spread of values. In a multi-armed bandit problem, one of the most important factors for all algorithms is the variability of machine payouts. If a machine has highly variable payouts, it becomes very difficult to evaluate the machine’s true average payout.
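One way to see why variability matters is to measure how much a sample average wanders when it's based on only a few pulls. The following sketch, with hypothetical names and an arbitrary seed and repetition count, repeatedly averages three payouts from a Gaussian machine and computes the standard deviation of those averages. Statistically, that spread should be close to sd / sqrt(n), so with three pulls a machine with sd = 2.0 gives estimates about four times noisier (roughly 1.15) than a machine with sd = 0.5 (roughly 0.29).

```csharp
using System;

public static class EstimateSpread
{
    // Box-Muller variation used by the demo, with a guard against Math.Log(0).
    static double NextGaussian(Random r, double mean, double sd)
    {
        double u1 = r.NextDouble();
        double u2 = r.NextDouble();
        while (u2 == 0.0) u2 = r.NextDouble();
        return mean + sd * Math.Cos(2.0 * Math.PI * u1) * Math.Sqrt(-2.0 * Math.Log(u2));
    }

    // Standard deviation of the average of nPulls payouts, measured over many repetitions.
    public static double SpreadOfAverage(double mean, double sd, int nPulls, int reps, int seed)
    {
        Random r = new Random(seed);
        double sum = 0.0, sumSq = 0.0;
        for (int rep = 0; rep < reps; ++rep)
        {
            double avg = 0.0;
            for (int i = 0; i < nPulls; ++i) avg += NextGaussian(r, mean, sd);
            avg /= nPulls;
            sum += avg;
            sumSq += avg * avg;
        }
        double m = sum / reps;
        return Math.Sqrt(sumSq / reps - m * m); // population sd of the averages
    }
}
```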

There are several algorithms that can be used to generate Gaussian distributed random values with a specified mean and standard deviation. My preferred method is called the Box-Muller algorithm. The Box-Muller algorithm first generates two uniformly distributed values (the kind produced by the NextDouble method of the .NET System.Random class) and then uses some very clever mathematics to transform them into a value that’s Gaussian distributed. There are several variations of Box-Muller. The demo program uses a variation that’s somewhat inefficient compared to other variations, but very simple.
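For comparison, a common variation of Box-Muller uses both the cosine and the sine of the same angle to produce two independent Gaussian values from each pair of uniform values, caching the second one for the next call; the demo's version computes only the cosine term and discards the rest. This is a sketch of that basic two-value variation, not the demo's code, and the class name GaussianPair is hypothetical.

```csharp
using System;

public class GaussianPair
{
    private readonly Random r;
    private readonly double mean;
    private readonly double sd;
    private bool haveSpare = false; // true when a second value is cached
    private double spare;           // cached standard normal value

    public GaussianPair(double mean, double sd, int seed)
    {
        this.r = new Random(seed);
        this.mean = mean;
        this.sd = sd;
    }

    public double Next()
    {
        if (haveSpare)
        {
            haveSpare = false;
            return mean + spare * sd; // use the value cached by the previous call
        }
        double u1 = r.NextDouble();
        double u2 = r.NextDouble();
        while (u2 == 0.0) u2 = r.NextDouble(); // avoid Math.Log(0)
        double radius = Math.Sqrt(-2.0 * Math.Log(u2));
        double angle = 2.0 * Math.PI * u1;
        spare = radius * Math.Sin(angle); // second value, saved for the next call
        haveSpare = true;
        return mean + radius * Math.Cos(angle) * sd;
    }
}
```

The trade-off is a little extra state in exchange for half as many calls to NextDouble, Log and Sqrt per generated value.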

In the demo program, class Gaussian is defined inside class Machine. In the Microsoft .NET Framework, nested class definitions are mainly a convenience for situations where the nested class is a utility class used by the outer containing class. If you’re porting this demo code to a non-.NET language, I recommend refactoring class Gaussian to a standalone class. The Gaussian class has a single constructor that accepts a mean payout, a standard deviation for the payout and a seed value for the underlying uniform random number generator.

The Machine Class

The demo program defines class Machine in a very simple way. There are three class fields:

public class Machine
{
  public double mean; // Avg payout per pull
  public double sd; // Variability about the mean
  private Gaussian g; // Payout generator
...

The Machine class is primarily a wrapper around a Gaussian random number generator. There are many possible design alternatives, but in general I prefer to keep my class definitions as simple as possible. Instead of using the standard deviation, as I’ve done here, some research articles use the mathematical variance. The two parameterizations are interchangeable because the variance is just the standard deviation squared.

The Machine class has a single constructor that sets up the Gaussian generator:

public Machine(double mean, double sd, int seed)
{
  this.mean = mean;
  this.sd = sd;
  this.g = new Gaussian(mean, sd, seed);
}

The Machine class has a single public method that returns a Gaussian distributed random payout:

public double Pay()
{
  return this.g.Next();
}

An alternative to returning a Gaussian distributed payout is to return a uniformly distributed value between specified endpoints. For example, a machine could return a random value between -2.0 and +3.0, where the average payout would be (-2 + 3) / 2 = +0.5 units.
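A uniform-payout machine along those lines might look like the following sketch. The class name UniformMachine and its members are hypothetical; it maps System.Random.NextDouble, which returns a value in [0.0, 1.0), onto the interval [lo, hi).

```csharp
using System;

public class UniformMachine
{
    private readonly double lo; // smallest possible payout
    private readonly double hi; // payouts are strictly less than hi
    private readonly Random r;

    public UniformMachine(double lo, double hi, int seed)
    {
        this.lo = lo;
        this.hi = hi;
        this.r = new Random(seed);
    }

    // Average payout per pull is the midpoint of the interval.
    public double Mean { get { return (lo + hi) / 2.0; } }

    // Uniformly distributed payout in [lo, hi).
    public double Pay()
    {
        return lo + (hi - lo) * r.NextDouble();
    }
}
```

For example, new UniformMachine(-2.0, 3.0, 0) reports a Mean of 0.5, and the average of many calls to Pay should approach that value.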

The Explore-Exploit Implementation

The definition of method ExploreExploit begins with:

static double ExploreExploit(Machine[] machines, double pctExplore,
  int nPulls)
{
  int nMachines = machines.Length;
  Random r = new Random(2); // Which machine
  double[] explorePays = new double[nMachines];
  double totPay = 0.0;
...

The Random object r is used to select a machine at random during the explore phase. The array named explorePays holds the cumulative payouts for each machine during the explore phase. A single variable, totPay, accumulates the payout over both phases; the exploit phase needs no per-machine tracking because only a single machine is used.

Next, the number of pulls for the explore and exploit phases is calculated:

int nExplore = (int)(nPulls * pctExplore);
int nExploit = nPulls - nExplore;

It would be a mistake to calculate the number of exploit pulls using the term (1.0 - pctExplore), because of possible round-off error: the two truncated products might not add up to nPulls.
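A concrete case shows the problem. With the hypothetical values nPulls = 100 and pctExplore = 0.57, the double-precision product 100 * 0.57 comes out slightly below 57, so the cast truncates it to 56 (a side effect the demo's approach shares), while 100 * (1.0 - 0.57) comes out slightly above 43 and truncates to 43. Subtracting gives 56 + 44 = 100 pulls, but the complement approach gives 56 + 43 = 99, silently losing a pull. The sketch below, with illustrative names, compares the two calculations.

```csharp
using System;

public static class PullSplit
{
    // The demo's approach: truncate once, then subtract, so the counts always sum to nPulls.
    public static int[] BySubtraction(int nPulls, double pctExplore)
    {
        int nExplore = (int)(nPulls * pctExplore);
        int nExploit = nPulls - nExplore;
        return new int[] { nExplore, nExploit };
    }

    // The approach to avoid: truncate two separately computed products.
    public static int[] ByComplement(int nPulls, double pctExplore)
    {
        int nExplore = (int)(nPulls * pctExplore);
        int nExploit = (int)(nPulls * (1.0 - pctExplore));
        return new int[] { nExplore, nExploit };
    }
}
```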

The explore phase, without WriteLine statements, is:

for (int pull = 0; pull < nExplore; ++pull)
{
  int m = r.Next(0, nMachines); // Pick a machine
  double pay = machines[m].Pay(); // Play
  explorePays[m] += pay; // Update
  totPay += pay;
}

The Random.Next(int minValue, int maxValue) method returns an integer value between minValue (inclusive) and maxValue (exclusive), so if nMachines = 3, r.Next(0, nMachines) returns a random integer value of 0, 1 or 2.
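A quick way to confirm the exclusive upper bound is to tally many calls. This hypothetical helper indexes an array of length nMachines directly with the result of r.Next(0, nMachines); if the method could ever return nMachines, the indexing would throw an IndexOutOfRangeException.

```csharp
using System;

public static class NextRangeDemo
{
    // Tallies how often r.Next(0, nMachines) returns each machine index.
    // Results run from 0 to nMachines - 1 because the upper bound is exclusive.
    public static int[] Tally(int nMachines, int nDraws, int seed)
    {
        Random r = new Random(seed);
        int[] counts = new int[nMachines];
        for (int i = 0; i < nDraws; ++i)
            ++counts[r.Next(0, nMachines)];
        return counts;
    }
}
```

With nMachines = 3 and a few thousand draws, each of the three counts should be roughly one third of the total.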

Next, the best machine found during the explore phase is determined, and used in the exploit phase:

int bestMach = BestIdx(explorePays);
for (int pull = 0; pull < nExploit; ++pull)
{
  double pay = machines[bestMach].Pay();
  totPay += pay; // Accumulate
}

Program-defined helper method BestIdx returns the index of the cell of its array argument that holds the largest value. There are dozens of variations of the multi-armed bandit problem. For example, some variations define the best machine found during the explore phase in a different way. In my cranky opinion, many of these variations are nothing more than solutions in search of a research problem.
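As an illustration of one such variation, the best machine could be chosen by average explore payout rather than total payout, which stops a machine from winning merely because it was selected more often. This sketch is not what the demo does; the name BestByAverage is hypothetical, and skipping machines that were never selected is just one possible design decision.

```csharp
using System;

public static class BestByAverage
{
    // Picks the machine with the highest average explore payout instead of the highest total.
    // Machines never selected during exploration (count == 0) are skipped.
    public static int BestIdx(double[] totals, int[] counts)
    {
        int result = -1;
        double maxAvg = double.NegativeInfinity;
        for (int i = 0; i < totals.Length; ++i)
        {
            if (counts[i] == 0) continue;
            double avg = totals[i] / counts[i];
            if (avg > maxAvg)
            {
                maxAvg = avg;
                result = i;
            }
        }
        return result; // -1 if no machine was played at all
    }
}
```

With hypothetical totals of 0.60 (six pulls) and 0.45 (two pulls), the total rule picks machine 0, but the average rule picks machine 1 (0.225 versus 0.10). For the demo's explore results, both rules pick machine 2.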

Method ExploreExploit finishes by calculating and returning the average payout per pull over all nPulls plays:

  . . .
  return totPay / nPulls;
}

Alternative designs could be to return the total payout instead of the average payout, or return the total regret value, or return both the total payout and the average payout values in a two-cell array or as two out-parameters.
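As a sketch of the out-parameter alternative, here's a version of the explore-exploit loop that reports both the total and the average payout. To stay self-contained it models each machine as a hypothetical Func<double> instead of the demo's Machine class, omits the WriteLine statements and takes the random seed as an argument; otherwise it follows the demo's logic, including choosing the best machine by explore total.

```csharp
using System;

public static class ExploreExploitOut
{
    // Explore-exploit as in the demo, returning total and average payout via out-parameters.
    public static void Run(Func<double>[] machines, double pctExplore, int nPulls, int seed,
        out double totPay, out double avgPay)
    {
        Random r = new Random(seed);
        double[] explorePays = new double[machines.Length];
        int nExplore = (int)(nPulls * pctExplore);
        int nExploit = nPulls - nExplore;
        totPay = 0.0;

        for (int pull = 0; pull < nExplore; ++pull) // explore: random machine each pull
        {
            int m = r.Next(0, machines.Length);
            double pay = machines[m]();
            explorePays[m] += pay;
            totPay += pay;
        }

        int bestMach = 0; // machine with the largest explore total
        for (int i = 1; i < explorePays.Length; ++i)
            if (explorePays[i] > explorePays[bestMach]) bestMach = i;

        for (int pull = 0; pull < nExploit; ++pull) // exploit: best machine only
            totPay += machines[bestMach]();

        avgPay = totPay / nPulls;
    }
}
```

A caller would declare two double variables and pass them with the out keyword, for example Run(machines, 0.40, 20, 2, out tot, out avg).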

Other Algorithms

Research suggests that there’s no single algorithm that works best for all types of multi-armed bandit problems. Different algorithms have different strengths and weaknesses, depending mostly on the number of machines in the problem, the number of pulls available, and the variability of the payout distribution functions.
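To give a flavor of the alternatives, here's a sketch of epsilon-greedy, a widely used bandit algorithm that the demo doesn't implement. Instead of a fixed explore phase, it plays each machine once, then on every pull picks a random machine with probability epsilon and otherwise plays the machine with the best average payout so far. The names, the Func<double> machine model and the initial round-robin pass are assumptions of this sketch.

```csharp
using System;

public static class EpsilonGreedy
{
    // Returns the average payout per pull. Each machine is modeled as a Func<double>.
    public static double Play(Func<double>[] machines, double epsilon, int nPulls, int seed)
    {
        int n = machines.Length;
        Random r = new Random(seed);
        double[] sums = new double[n]; // total payout per machine
        int[] counts = new int[n];     // pulls per machine
        double totPay = 0.0;

        for (int pull = 0; pull < nPulls; ++pull)
        {
            int m;
            if (pull < n)
                m = pull; // play every machine once first
            else if (r.NextDouble() < epsilon)
                m = r.Next(0, n); // explore: random machine
            else
            {
                m = 0; // exploit: best average so far (all counts are at least 1 here)
                for (int i = 1; i < n; ++i)
                    if (sums[i] / counts[i] > sums[m] / counts[m]) m = i;
            }
            double pay = machines[m]();
            sums[m] += pay;
            ++counts[m];
            totPay += pay;
        }
        return totPay / nPulls;
    }
}
```

Unlike explore-exploit, epsilon-greedy never stops exploring entirely, so it can recover from an unlucky early estimate, at the cost of spending a fixed fraction of pulls on machines that are probably worse.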

Dr. James McCaffrey works for Microsoft Research in Redmond, Wash. He has worked on several Microsoft products including Internet Explorer and Bing. Dr. McCaffrey can be reached at jammc@microsoft.com.

Thanks to the following Microsoft technical experts who reviewed this article: Miro Dudik and Kirk Olynyk

