Articles

AI-Class: Implementation of MDP Grid World from Week 5, Unit 9

In AI-Class on 09/11/2011 by Michael Madden

Overview

This is a very basic implementation of the 3×4 grid world as used in AI-Class Week 5, Unit 9.

It uses the version of the Value Iteration equation that is given at the end of Unit 9.15, with minor modifications to conform to the algorithm as specified in Russell & Norvig, “Artificial Intelligence: A Modern Approach”, 3rd ed., Figure 17.4, p. 653.

With it, you can test the various scenarios outlined by Sebastian Thrun in the class. At the top of the code, you will see the following variables you can adjust:

  • Ra: reward in non-terminal states (used to initialise the reward matrix R)
  • gamma: discount factor
  • pGood: probability of taking intended action (1-pGood is split equally between the two orthogonal actions).

Implementation

It works as follows:

  • Initialise R to -3 for non-terminal states, -100/+100 for the two terminal states, and 0 for the blocked state
  • Initialise V’ for all states to 0
  • Repeat until convergence:
    • V = V’
    • Loop over all states:
      • In each state s:
        • V'(s) := R(s) if s is a terminal state (or the blocked state)
        • V'(s) := The Bellman equation otherwise (computed using V(s), not V'(s))

Convergence is achieved when the max difference between V’ and V is less than the specified tolerance. I also set a maximum number of iterations.

Testing

This gives the same values for all scenarios shown by Sebastian Thrun, to the number of significant figures he shows. It also gives the same results as reported in the Russell & Norvig textbook for the settings they use, which are a little different.

The Code

You can copy this code and paste it into a file called GridWorld2.java.

This code may be used freely without restriction, though attribution of my authorship would be appreciated.

/**
 * AI-Class Unit 9: Value Iteration on the simple 3x4 grid world.
 * By MichaelFromGalway, Nov 2011.
 * Further details: see http://MichaelFromGalway.wordpress.com.
 *
 * GridWorld2 performs simultaneous (synchronous) updates, as in
 * AIMA 3ed Figure 17.4 p653: each sweep computes every entry of Up
 * from the utilities left in U by the previous sweep.
 *
 * Grid layout, indexed [row][col] with row 0 at the top:
 *   (0,3) is the +100 terminal, (1,3) is the -100 terminal,
 *   (1,1) is the blocked cell.
 *
 * This code may be used freely without restriction,
 * though attribution of my authorship would be appreciated.
 */
public class GridWorld2 
{
	// General settings
	private static double Ra = -3;            // reward in non-terminal states (used to initialise R[][])
	private static double gamma = 1;          // discount factor
	private static double pGood = 0.8;        // probability of taking the intended action
	private static double pBad = (1-pGood)/2; // probability of each of the 2 orthogonal slips
	private static int N = 10000;             // max number of Value Iteration sweeps
	private static double deltaMin = 1e-9;    // stop once no utility moves by more than this

	// Main data structures, all indexed [row][col]
	private static double[][] U;  // utilities from the previous sweep
	private static double[][] Up; // U', the utilities being computed in the current sweep
	private static double[][] R;  // immediate reward of each state
	private static char[][] Pi;   // greedy policy: 'N', 'S', 'W' or 'E'

	private static int rMax = 3, cMax = 4; // grid dimensions: rows, columns

	/**
	 * Runs Value Iteration until no utility changes by more than deltaMin
	 * (or N sweeps have been done), then prints the utilities and the policy.
	 *
	 * @param args ignored
	 */
	public static void main(String[] args)
	{
		// Freshly allocated double arrays are zero-filled, so U' starts at 0 everywhere.
		Pi = new char[rMax][cMax];
		Up = new double[rMax][cMax];
		U = new double[rMax][cMax]; // contents irrelevant: overwritten from Up each sweep

		// Every state earns Ra, except the two terminals and the blocked cell.
		R = new double[rMax][cMax];
		for (double[] rewardRow : R) {
			for (int col = 0; col < rewardRow.length; col++) {
				rewardRow[col] = Ra;
			}
		}
		R[0][3] =  100;  // positive terminal
		R[1][3] = -100;  // negative terminal
		R[1][1] =    0;  // blocked cell

		int sweeps = 0;
		double maxChange;
		do {
			// Synchronous update: freeze the previous sweep in U, then rebuild Up from it.
			duplicate(Up, U);
			sweeps++;
			maxChange = 0;
			for (int row = 0; row < rMax; row++) {
				for (int col = 0; col < cMax; col++) {
					updateUPrime(row, col);
					double change = Math.abs(Up[row][col] - U[row][col]);
					if (change > maxChange) {
						maxChange = change;
					}
				}
			}
		} while (maxChange > deltaMin && sweeps < N);

		// Print U, the utilities the last sweep's greedy policy Pi was computed from.
		System.out.println("After " + sweeps + " iterations:\n");
		for (double[] utilityRow : U) {
			for (double utility : utilityRow) {
				System.out.printf("% 6.1f\t", utility);
			}
			System.out.print("\n");
		}

		// Cells that never get a move: the two terminals and the blocked cell.
		Pi[0][3] = '+';
		Pi[1][3] = '-';
		Pi[1][1] = '#';

		System.out.println("\nBest policy:\n");
		for (char[] policyRow : Pi) {
			for (char move : policyRow) {
				System.out.print(move + "   ");
			}
			System.out.print("\n");
		}
	}

	/**
	 * Computes U'(r,c) from the previous sweep's utilities in U.
	 * Terminal and blocked cells just take their reward. Every other cell
	 * applies the Bellman update and records its greedy action in Pi.
	 * IMPORTANT: writes Up (and Pi) while reading only U.
	 *
	 * @param r row of the cell
	 * @param c column of the cell
	 */
	public static void updateUPrime(int r, int c)
	{
		boolean fixedValue = (r==0 && c==3) || (r==1 && c==3) || (r==1 && c==1);
		if (fixedValue) {
			Up[r][c] = R[r][c];
			return;
		}

		// Utility of the cell reached by each move (or U[r][c] if the move bumps into something).
		double north = aNorth(r,c);
		double south = aSouth(r,c);
		double west = aWest(r,c);
		double east = aEast(r,c);

		// Expected next utility per action; a slip goes to one of the two orthogonal moves.
		double[] expected = {
			north*pGood + west*pBad + east*pBad,   // N
			south*pGood + west*pBad + east*pBad,   // S
			west*pGood + south*pBad + north*pBad,  // W
			east*pGood + south*pBad + north*pBad   // E
		};

		int best = maxindex(expected);
		Up[r][c] = R[r][c] + gamma * expected[best];
		Pi[r][c] = "NSWE".charAt(best);
	}

	/**
	 * Returns the index of the largest element; on ties the later index wins.
	 * Returns 0 for arrays of length 0 or 1.
	 */
	public static int maxindex(double a[]) 
	{
		int best = 0;
		for (int i = 1; i < a.length; i++) {
			if (a[best] > a[i]) {
				continue; // keep the current best only when it is strictly larger
			}
			best = i;
		}
		return best;
	}

	/** Utility after moving north from (r,c); stays put at the top wall or below the block. */
	public static double aNorth(int r, int c)
	{
		boolean bumps = (r == 0) || (r == 2 && c == 1);
		return bumps ? U[r][c] : U[r-1][c];
	}

	/** Utility after moving south from (r,c); stays put at the bottom wall or above the block. */
	public static double aSouth(int r, int c)
	{
		boolean bumps = (r == rMax-1) || (r == 0 && c == 1);
		return bumps ? U[r][c] : U[r+1][c];
	}

	/** Utility after moving west from (r,c); stays put at the left wall or right of the block. */
	public static double aWest(int r, int c)
	{
		boolean bumps = (c == 0) || (r == 1 && c == 2);
		return bumps ? U[r][c] : U[r][c-1];
	}

	/** Utility after moving east from (r,c); stays put at the right wall or left of the block. */
	public static double aEast(int r, int c)
	{
		boolean bumps = (c == cMax-1) || (r == 1 && c == 0);
		return bumps ? U[r][c] : U[r][c+1];
	}

	/** Copies every element of src into dst, row by row (dst must be at least as large). */
	public static void duplicate(double[][]src, double[][]dst)
	{
		for (int row = 0; row < src.length; row++) {
			double[] from = src[row];
			double[] to = dst[row];
			for (int col = 0; col < from.length; col++) {
				to[col] = from[col];
			}
		}
	}
}

Articles

AI-Class: Worked Example of Cheapest First Search

In AI-Class on 17/10/2011 by Michael Madden

While the lectures are great in general, I thought that the section on Uniform Cost Search (2.12-2.17) became slightly confusing, with the corrections and everything. Therefore, I have made my own video (12 minutes) walking through the example.

I also draw the search tree that arises from searching the graph, because I think it helps to clarify what the frontier is. At the end I address the issue of counting the number of nodes expanded, including the start and goal nodes.

I hope some of you find it helpful. Comments welcome!

Articles

AI-Class: Implementation of Gradient Descent from Week 3, Unit 5.31

In AI-Class on 15/10/2011 by Michael Madden

Overview

This is a very simple but correct implementation of the Gradient Descent algorithm described by Sebastian Thrun in Week 3, Unit 5.31.

The only difference between my implementation and the formulas he presents is that I add a factor of 1/m to the formula for the gradients:

  • $\frac{\partial L}{\partial w_1} = -\frac{2}{m} \sum_{j} (y_j - w_1 x_j - w_0)\, x_j$
  • $\frac{\partial L}{\partial w_0} = -\frac{2}{m} \sum_{j} (y_j - w_1 x_j - w_0)$

Here, m is the number of samples. This is a common form of the formula. Since m is a constant and these terms are multiplied by another constant, the learning rate α, it could be ignored. The benefit of including this term explicitly is that if you use a larger training set of similar data, you don’t have to adjust the learning rate to get similar behaviour.

Implementation Notes

  • The training data is specified in the arrays x and y at the top of the code.
  • The algorithm settings can be found at the start of the main method.
  • This program uses JMathPlot, a simple package for producing Matlab-style graphs:
    You can get it at http://code.google.com/p/jmathplot or just delete the blocks of code that use it.

Testing

I have verified that this program gives the correct answers for the examples shown in AI-Class Unit 5, and for some other test cases of my own.

The Code

You can copy this code and paste it into a file called GradDescent.java.

This code may be used freely without restriction, though attribution of my authorship would be appreciated.

/**
 * Implementation of gradient descent alg.
 * Ref: AI-Class.com, Week 3, Topic 5.31.
 *
 * By MichaelFromGalway, Oct 2011.
 * Further details: see http://MichaelFromGalway.wordpress.com.
 *
 * This code may be used freely without restriction,
 * though attribution of my authorship would be appreciated.
 *
 */

import javax.swing.JFrame;

// This program uses JMathPlot, a package for producing Matlab-style graphs
//   Get it at http://code.google.com/p/jmathplot/
//   or just delete the blocks of code that use it.
import org.math.plot.*;

public class GradDescent
{
   // data taken from one of the worked examples
   // Note: for this data, correct answers are w0=0.5, w1=0.9.
   static double[] x = {2, 4, 6, 8};
   static double[] y = {2, 5, 5, 8};

   static int trendline; // handle used for adding/removing trendline

   // parameters
   static double w0;
   static double w1;

   public static void main(String[] args)
   {
      // Algorithm settings
      double alpha = 0.01;  // learning rate
      double tol = 1e-11;   // tolerance to determine convergence
      int maxiter = 9000;   // maximum number of iterations (in case convergence is not reached)
      int dispiter = 100;   // interval for displaying results during iterations

      // Other parameters
      double delta0, delta1;
      int iters = 0;

      // initial guesses for parameters
      w0 = 0;
      w1 = 0;

      // keep track of results so I can plot convergence
      double[] w0plot = new double[maxiter+1];
      double[] w1plot = new double[maxiter+1];
      double[] tplot = new double[maxiter+1];

      // plot the data
      // create a PlotPanel
      Plot2DPanel plot = new Plot2DPanel();

      // add a line plot to the PlotPanel
      plot.addLinePlot("X-Y", x, y);

      // show the trendline
      addTrendline(plot, false);

      // put the PlotPanel in a JFrame, as a JPanel
      JFrame frame = new JFrame("Original X-Y Data");
      frame.setContentPane(plot);
      frame.setSize(600, 600);
      frame.setVisible(true);

      do {
         delta1 = alpha * dLdw1();
         delta0 = alpha * dLdw0();

         // Store data for plotting
         tplot[iters] = iters;
         w0plot[iters] = w0;
         w1plot[iters] = w1;

         iters++;
         w1 -= delta1;
         w0 -= delta0;

         // display progress
         if (iters % dispiter == 0) {
            System.out.println("Iteration " + iters + ": w0=" + w0 + " - " + delta0 + ", w1=" + w1 + " - "+ delta1);
            addTrendline(plot, true);
         }

         if (iters > maxiter) break;
      } while (Math.abs(delta1) > tol || Math.abs(delta0) > tol);

      System.out.println("\nConvergence after " + iters + " iterations: w0=" + w0 + ", w1=" + w1);

      addTrendline(plot, false);

      // Before plotting the data, extract an array of the right size from it
      double[] w0plot2 = new double[iters];
      double[] w1plot2 = new double[iters];
      double[] tplot2 = new double[iters];
      System.arraycopy(w0plot, 0, w0plot2, 0, iters);
      System.arraycopy(w1plot, 0, w1plot2, 0, iters);
      System.arraycopy(tplot, 0, tplot2, 0, iters);

      // Plot the convergence of data
      Plot2DPanel convPlot = new Plot2DPanel();

      // add a line plot to the PlotPanel
      convPlot.addLinePlot("w0", tplot2, w0plot2);
      convPlot.addLinePlot("w1", tplot2, w1plot2);

      // put the PlotPanel in a JFrame, as a JPanel
      JFrame frame2 = new JFrame("Convergence of parameters over time");
      frame2.setContentPane(convPlot);
      frame2.setSize(600, 600);
      frame2.setVisible(true);

      // Commented out System.exit() so that plots don't vanish
      // System.exit(0);
   }

   public static double dLdw1()
   {
      double sum = 0;

      for (int j=0; j<x.length; j++) {
         sum += (y[j] - f(x[j])) * x[j];
      }
      return -2 * sum / x.length;
   }

   public static double dLdw0()
   {
      double sum = 0;

      for (int j=0; j<x.length; j++) {
         sum += y[j] - f(x[j]);
      }
      return -2 * sum / x.length;
   }

   public static double f(double x)
   {
      return w1*x + w0;
   }

   public static void addTrendline(Plot2DPanel plot, boolean removePrev)
   {
      if (removePrev)
      plot.removePlot(trendline);

      double[] yEnd = new double[x.length];
      for (int i=0; i<x.length; i++)
      yEnd[i] = f(x[i]);
      trendline = plot.addLinePlot("final", x, yEnd);
   }
}