Example 3: Naive Bayes Classifier Using a User-Supplied Probability Function

This example repeats Example 1: it uses Fisher's (1936) Iris data to train a Naive Bayes classifier on 140 of the 150 continuous patterns, then classifies the last ten plants, treated as unknown, from their sepal and petal measurements.
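
For orientation, the textbook Naive Bayes decision rule for continuous attributes can be written as below. This is the standard formulation, not a statement of the library's exact internals. The symbols are notation introduced here: x_1..x_4 are the four measurements, c is one of the three classes, P(c) is the class prior, and f_{i,c} is the density of attribute i within class c.

```latex
P(c \mid x_1, \dots, x_4)
  = \frac{P(c) \prod_{i=1}^{4} f_{i,c}(x_i)}
         {\sum_{k=1}^{3} P(k) \prod_{i=1}^{4} f_{i,k}(x_i)},
\qquad
\hat{c} = \arg\max_{c} \; P(c \mid x_1, \dots, x_4)
```

Under this reading, each row of the probability table in the output corresponds to the three posterior values for one plant, and the predicted class is the one with the largest value.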

Instead of using the NormalDistribution class from the Imsl.Stat namespace, this example uses a user-supplied normal (Gaussian) distribution. Rather than calculating the means and standard deviations from the data, as the NormalDistribution's Eval(double[]) method does, the user-supplied class takes the means and standard deviations as constructor arguments. The output is the same as in Example 1 because the values supplied here are simply the rounded means and standard deviations of the actual data, computed separately for each target classification.
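
The density that the user-supplied class evaluates can be written out explicitly. The formula below is the usual normal density and matches what GaussianPdf computes in the listing; the symbols \mu_{i,j} and \sigma_{i,j} are notation introduced here for means[i][j] and stdev[i][j], with i indexing the attribute and j the class.

```latex
f_{i,j}(x) = \frac{1}{\sigma_{i,j}\sqrt{2\pi}}
             \exp\!\left( -\frac{(x - \mu_{i,j})^{2}}{2\,\sigma_{i,j}^{2}} \right)

% Worked value for the first attribute and first class,
% using means[0][0] = 5.06 and stdev[0][0] = 0.35, evaluated at x = 5.06:
f_{0,0}(5.06) = \frac{1}{0.35 \times 2.506628} \approx 1.14
```

The value above is a density, not a probability, so it may exceed 1. Only the relative sizes of the densities across classes affect the classification.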

using System;
using Imsl.DataMining;
using IProbabilityDistribution = Imsl.Stat.IProbabilityDistribution;

public class NaiveBayesClassifierEx3
{

   private static double[][] irisFisherData = new double[][]{
        new double[]{1.0, 5.1, 3.5, 1.4, .2}, 
        new double[]{1.0, 4.9, 3.0, 1.4, .2}, 
        new double[]{1.0, 4.7, 3.2, 1.3, .2}, 
        new double[]{1.0, 4.6, 3.1, 1.5, .2}, 
        new double[]{1.0, 5.0, 3.6, 1.4, .2}, 
        new double[]{1.0, 5.4, 3.9, 1.7, .4}, 
        new double[]{1.0, 4.6, 3.4, 1.4, .3}, 
        new double[]{1.0, 5.0, 3.4, 1.5, .2}, 
        new double[]{1.0, 4.4, 2.9, 1.4, .2}, 
        new double[]{1.0, 4.9, 3.1, 1.5, .1}, 
        new double[]{1.0, 5.4, 3.7, 1.5, .2}, 
        new double[]{1.0, 4.8, 3.4, 1.6, .2}, 
        new double[]{1.0, 4.8, 3.0, 1.4, .1}, 
        new double[]{1.0, 4.3, 3.0, 1.1, .1}, 
        new double[]{1.0, 5.8, 4.0, 1.2, .2}, 
        new double[]{1.0, 5.7, 4.4, 1.5, .4}, 
        new double[]{1.0, 5.4, 3.9, 1.3, .4}, 
        new double[]{1.0, 5.1, 3.5, 1.4, .3}, 
        new double[]{1.0, 5.7, 3.8, 1.7, .3}, 
        new double[]{1.0, 5.1, 3.8, 1.5, .3}, 
        new double[]{1.0, 5.4, 3.4, 1.7, .2}, 
        new double[]{1.0, 5.1, 3.7, 1.5, .4}, 
        new double[]{1.0, 4.6, 3.6, 1.0, .2}, 
        new double[]{1.0, 5.1, 3.3, 1.7, .5}, 
        new double[]{1.0, 4.8, 3.4, 1.9, .2}, 
        new double[]{1.0, 5.0, 3.0, 1.6, .2}, 
        new double[]{1.0, 5.0, 3.4, 1.6, .4}, 
        new double[]{1.0, 5.2, 3.5, 1.5, .2}, 
        new double[]{1.0, 5.2, 3.4, 1.4, .2}, 
        new double[]{1.0, 4.7, 3.2, 1.6, .2}, 
        new double[]{1.0, 4.8, 3.1, 1.6, .2}, 
        new double[]{1.0, 5.4, 3.4, 1.5, .4}, 
        new double[]{1.0, 5.2, 4.1, 1.5, .1}, 
        new double[]{1.0, 5.5, 4.2, 1.4, .2}, 
        new double[]{1.0, 4.9, 3.1, 1.5, .2}, 
        new double[]{1.0, 5.0, 3.2, 1.2, .2}, 
        new double[]{1.0, 5.5, 3.5, 1.3, .2}, 
        new double[]{1.0, 4.9, 3.6, 1.4, .1}, 
        new double[]{1.0, 4.4, 3.0, 1.3, .2}, 
        new double[]{1.0, 5.1, 3.4, 1.5, .2}, 
        new double[]{1.0, 5.0, 3.5, 1.3, .3}, 
        new double[]{1.0, 4.5, 2.3, 1.3, .3}, 
        new double[]{1.0, 4.4, 3.2, 1.3, .2}, 
        new double[]{1.0, 5.0, 3.5, 1.6, .6}, 
        new double[]{1.0, 5.1, 3.8, 1.9, .4}, 
        new double[]{1.0, 4.8, 3.0, 1.4, .3}, 
        new double[]{1.0, 5.1, 3.8, 1.6, .2}, 
        new double[]{1.0, 4.6, 3.2, 1.4, .2}, 
        new double[]{1.0, 5.3, 3.7, 1.5, .2}, 
        new double[]{1.0, 5.0, 3.3, 1.4, .2}, 
        new double[]{2.0, 7.0, 3.2, 4.7, 1.4}, 
        new double[]{2.0, 6.4, 3.2, 4.5, 1.5}, 
        new double[]{2.0, 6.9, 3.1, 4.9, 1.5}, 
        new double[]{2.0, 5.5, 2.3, 4.0, 1.3}, 
        new double[]{2.0, 6.5, 2.8, 4.6, 1.5}, 
        new double[]{2.0, 5.7, 2.8, 4.5, 1.3}, 
        new double[]{2.0, 6.3, 3.3, 4.7, 1.6}, 
        new double[]{2.0, 4.9, 2.4, 3.3, 1.0}, 
        new double[]{2.0, 6.6, 2.9, 4.6, 1.3}, 
        new double[]{2.0, 5.2, 2.7, 3.9, 1.4}, 
        new double[]{2.0, 5.0, 2.0, 3.5, 1.0}, 
        new double[]{2.0, 5.9, 3.0, 4.2, 1.5}, 
        new double[]{2.0, 6.0, 2.2, 4.0, 1.0}, 
        new double[]{2.0, 6.1, 2.9, 4.7, 1.4}, 
        new double[]{2.0, 5.6, 2.9, 3.6, 1.3}, 
        new double[]{2.0, 6.7, 3.1, 4.4, 1.4}, 
        new double[]{2.0, 5.6, 3.0, 4.5, 1.5}, 
        new double[]{2.0, 5.8, 2.7, 4.1, 1.0}, 
        new double[]{2.0, 6.2, 2.2, 4.5, 1.5}, 
        new double[]{2.0, 5.6, 2.5, 3.9, 1.1}, 
        new double[]{2.0, 5.9, 3.2, 4.8, 1.8}, 
        new double[]{2.0, 6.1, 2.8, 4.0, 1.3}, 
        new double[]{2.0, 6.3, 2.5, 4.9, 1.5}, 
        new double[]{2.0, 6.1, 2.8, 4.7, 1.2}, 
        new double[]{2.0, 6.4, 2.9, 4.3, 1.3}, 
        new double[]{2.0, 6.6, 3.0, 4.4, 1.4}, 
        new double[]{2.0, 6.8, 2.8, 4.8, 1.4}, 
        new double[]{2.0, 6.7, 3.0, 5.0, 1.7}, 
        new double[]{2.0, 6.0, 2.9, 4.5, 1.5}, 
        new double[]{2.0, 5.7, 2.6, 3.5, 1.0}, 
        new double[]{2.0, 5.5, 2.4, 3.8, 1.1}, 
        new double[]{2.0, 5.5, 2.4, 3.7, 1.0}, 
        new double[]{2.0, 5.8, 2.7, 3.9, 1.2}, 
        new double[]{2.0, 6.0, 2.7, 5.1, 1.6}, 
        new double[]{2.0, 5.4, 3.0, 4.5, 1.5}, 
        new double[]{2.0, 6.0, 3.4, 4.5, 1.6}, 
        new double[]{2.0, 6.7, 3.1, 4.7, 1.5}, 
        new double[]{2.0, 6.3, 2.3, 4.4, 1.3}, 
        new double[]{2.0, 5.6, 3.0, 4.1, 1.3}, 
        new double[]{2.0, 5.5, 2.5, 4.0, 1.3}, 
        new double[]{2.0, 5.5, 2.6, 4.4, 1.2}, 
        new double[]{2.0, 6.1, 3.0, 4.6, 1.4}, 
        new double[]{2.0, 5.8, 2.6, 4.0, 1.2}, 
        new double[]{2.0, 5.0, 2.3, 3.3, 1.0}, 
        new double[]{2.0, 5.6, 2.7, 4.2, 1.3}, 
        new double[]{2.0, 5.7, 3.0, 4.2, 1.2}, 
        new double[]{2.0, 5.7, 2.9, 4.2, 1.3}, 
        new double[]{2.0, 6.2, 2.9, 4.3, 1.3}, 
        new double[]{2.0, 5.1, 2.5, 3.0, 1.1}, 
        new double[]{2.0, 5.7, 2.8, 4.1, 1.3}, 
        new double[]{3.0, 6.3, 3.3, 6.0, 2.5}, 
        new double[]{3.0, 5.8, 2.7, 5.1, 1.9}, 
        new double[]{3.0, 7.1, 3.0, 5.9, 2.1}, 
        new double[]{3.0, 6.3, 2.9, 5.6, 1.8}, 
        new double[]{3.0, 6.5, 3.0, 5.8, 2.2}, 
        new double[]{3.0, 7.6, 3.0, 6.6, 2.1}, 
        new double[]{3.0, 4.9, 2.5, 4.5, 1.7}, 
        new double[]{3.0, 7.3, 2.9, 6.3, 1.8}, 
        new double[]{3.0, 6.7, 2.5, 5.8, 1.8}, 
        new double[]{3.0, 7.2, 3.6, 6.1, 2.5}, 
        new double[]{3.0, 6.5, 3.2, 5.1, 2.0}, 
        new double[]{3.0, 6.4, 2.7, 5.3, 1.9}, 
        new double[]{3.0, 6.8, 3.0, 5.5, 2.1}, 
        new double[]{3.0, 5.7, 2.5, 5.0, 2.0}, 
        new double[]{3.0, 5.8, 2.8, 5.1, 2.4}, 
        new double[]{3.0, 6.4, 3.2, 5.3, 2.3}, 
        new double[]{3.0, 6.5, 3.0, 5.5, 1.8}, 
        new double[]{3.0, 7.7, 3.8, 6.7, 2.2}, 
        new double[]{3.0, 7.7, 2.6, 6.9, 2.3}, 
        new double[]{3.0, 6.0, 2.2, 5.0, 1.5}, 
        new double[]{3.0, 6.9, 3.2, 5.7, 2.3}, 
        new double[]{3.0, 5.6, 2.8, 4.9, 2.0}, 
        new double[]{3.0, 7.7, 2.8, 6.7, 2.0}, 
        new double[]{3.0, 6.3, 2.7, 4.9, 1.8}, 
        new double[]{3.0, 6.7, 3.3, 5.7, 2.1}, 
        new double[]{3.0, 7.2, 3.2, 6.0, 1.8}, 
        new double[]{3.0, 6.2, 2.8, 4.8, 1.8}, 
        new double[]{3.0, 6.1, 3.0, 4.9, 1.8}, 
        new double[]{3.0, 6.4, 2.8, 5.6, 2.1}, 
        new double[]{3.0, 7.2, 3.0, 5.8, 1.6}, 
        new double[]{3.0, 7.4, 2.8, 6.1, 1.9}, 
        new double[]{3.0, 7.9, 3.8, 6.4, 2.0}, 
        new double[]{3.0, 6.4, 2.8, 5.6, 2.2}, 
        new double[]{3.0, 6.3, 2.8, 5.1, 1.5},
        new double[]{3.0, 6.1, 2.6, 5.6, 1.4}, 
        new double[]{3.0, 7.7, 3.0, 6.1, 2.3}, 
        new double[]{3.0, 6.3, 3.4, 5.6, 2.4}, 
        new double[]{3.0, 6.4, 3.1, 5.5, 1.8}, 
        new double[]{3.0, 6.0, 3.0, 4.8, 1.8}, 
        new double[]{3.0, 6.9, 3.1, 5.4, 2.1}, 
        new double[]{3.0, 6.7, 3.1, 5.6, 2.4}, 
        new double[]{3.0, 6.9, 3.1, 5.1, 2.3}, 
        new double[]{3.0, 5.8, 2.7, 5.1, 1.9}, 
        new double[]{3.0, 6.8, 3.2, 5.9, 2.3}, 
        new double[]{3.0, 6.7, 3.3, 5.7, 2.5}, 
        new double[]{3.0, 6.7, 3.0, 5.2, 2.3}, 
        new double[]{3.0, 6.3, 2.5, 5.0, 1.9}, 
        new double[]{3.0, 6.5, 3.0, 5.2, 2.0}, 
        new double[]{3.0, 6.2, 3.4, 5.4, 2.3}, 
        new double[]{3.0, 5.9, 3.0, 5.1, 1.8}};

    public static void Main(System.String[] args)
    {
        
        /* Data corrections described in the KDD data mining archive */
        irisFisherData[34][4] = 0.1;
        irisFisherData[37][2] = 3.1;
        irisFisherData[37][3] = 1.5;
        
        /*  Train on the first 140 patterns of the Fisher iris data */
        int[] irisClassificationData = 
            new int[irisFisherData.Length - 10];
        double[][] irisContinuousData = 
            new double[irisFisherData.Length - 10][];
        for (int i = 0; i < irisFisherData.Length - 10; i++)
        {
            irisContinuousData[i] = 
                new double[irisFisherData[0].Length - 1];
        }
        
        for (int i = 0; i < irisFisherData.Length - 10; i++)
        {
            irisClassificationData[i] = (int) irisFisherData[i][0] - 1;
            Array.Copy(irisFisherData[i], 1, irisContinuousData[i], 0, 
                irisFisherData[0].Length - 1);
        }
        
        int nNominal = 0; /* no nominal input attributes      */
        int nContinuous = 4; /* four continuous input attributes */
        int nClasses = 3; /* three classification categories  */
        
        NaiveBayesClassifier nbTrainer = 
            new NaiveBayesClassifier(nContinuous, nNominal, nClasses);
        
        double[][] means = new double[][]{
            new double[]{5.06, 5.94, 6.58}, 
            new double[]{3.42, 2.8, 2.97}, 
            new double[]{1.5, 4.3, 5.6}, 
            new double[]{0.25, 1.33, 2.1}
        };
        double[][] stdev = new double[][]{
            new double[]{0.35, 0.52, 0.64}, 
            new double[]{0.38, 0.3, 0.32}, 
            new double[]{0.17, 0.47, 0.55}, 
            new double[]{0.12, 0.198, 0.275}
        };
        
        for (int i = 0; i < nContinuous; i++)
        {
            IProbabilityDistribution[] pdf = 
                new IProbabilityDistribution[nClasses];
            for (int j = 0; j < nClasses; j++)
            {
                pdf[j] = new TestGaussFcn1(means[i][j], stdev[i][j]);
            }
            nbTrainer.CreateContinuousAttribute(pdf);
        }
        nbTrainer.Train(irisContinuousData, null, 
            irisClassificationData);
        
        int[][] classErrors = nbTrainer.GetTrainingErrors();
        
        System.Console.Out.WriteLine(
            "     Iris Classification Error Rates");
        System.Console.Out.WriteLine(
            "------------------------------------------------");
        System.Console.Out.WriteLine(
            "  Setosa   Versicolour   Virginica    |   Total");
        System.Console.Out.WriteLine("  " + classErrors[0][0] + "/" + 
            classErrors[0][1] + "         " + classErrors[1][0] + "/" +
            classErrors[1][1] + "        " + classErrors[2][0] + "/" + 
            classErrors[2][1] + "       |   " + classErrors[3][0] + 
            "/" + classErrors[3][1]);
        System.Console.Out.WriteLine(
            "------------------------------------------------\n\n\n");
        
        
        /*  Classify last 10 iris data patterns 
         *  with the trained classifier 
         */
        double[] continuousInput = 
            new double[(irisFisherData[0].Length - 1)];
        double[] classifiedProbabilities = new double[nClasses];
        
        System.Console.Out.WriteLine(
            "Probabilities for Incorrect Classifications");
        System.Console.Out.WriteLine(" Predicted   ");
        System.Console.Out.WriteLine(
            "   Class     |  Class       |   P(0)     P(1)     P(2) ");
        System.Console.Out.WriteLine(
            "-------------------------------------------------------");
        for (int i = 0; i < 10; i++)
        {
            int targetClassification = (int) 
                irisFisherData[(irisFisherData.Length - 10) + i][0] - 1;
            Array.Copy(irisFisherData[(irisFisherData.Length - 10) + i], 
                1, continuousInput, 0, (irisFisherData[0].Length - 1));
            
            classifiedProbabilities = 
                nbTrainer.Probabilities(continuousInput, null);
            int classification = 
                nbTrainer.PredictClass(continuousInput, null);
            if (classification == 0)
            {
                System.Console.Out.Write("Setosa       |");
            }
            else if (classification == 1)
            {
                System.Console.Out.Write("Versicolour  |");
            }
            else if (classification == 2)
            {
                System.Console.Out.Write("Virginica    |");
            }
            else
            {
                System.Console.Out.Write("Missing      |");
            }
            if (targetClassification == 0)
            {
                System.Console.Out.Write(" Setosa       |");
            }
            else if (targetClassification == 1)
            {
                System.Console.Out.Write(" Versicolour  |");
            }
            else if (targetClassification == 2)
            {
                System.Console.Out.Write(" Virginica    |");
            }
            else
            {
                System.Console.Out.Write(" Missing      |");
            }
            for (int j = 0; j < nClasses; j++)
            {
                System.Object[] pArgs = new System.Object[] { 
                   (double)classifiedProbabilities[j] };
                System.Console.Out.Write("   {0, 2:f3} ", pArgs);
            }
            System.Console.Out.WriteLine("");
        }
    }
    
    public class TestGaussFcn1 : IProbabilityDistribution
    {
        public virtual System.Object[] GetParameters()
        {
            System.Object[] parms = new System.Object[2];
            parms[0] = this.mean;
            parms[1] = this.stdev;
            return parms;
        }
        
        private double mean;
        private double stdev;
        
        public TestGaussFcn1(double mean, double stdev)
        {
            this.mean = mean;
            this.stdev = stdev;
        }
        
        public virtual double[] Eval(double[] xData)
        {
            double[] pdf = new double[xData.Length];
            for (int i = 0; i < xData.Length; i++)
            {
                pdf[i] = Eval(xData[i], null);
            }
            return pdf;
        }
        
        public virtual double[] Eval(double[] xData, 
            System.Object[] Params)
        {
            double[] pdf = new double[xData.Length];
            for (int i = 0; i < xData.Length; i++)
            {
                pdf[i] = Eval(xData[i], Params);
            }
            return pdf;
        }
        
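        /* Params is ignored here; the mean and standard deviation 
         * supplied to the constructor are always used. 
         */
        public virtual double Eval(double xData, System.Object[] Params)
        {
            return GaussianPdf(xData, mean, stdev);
        }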
        
        private double GaussianPdf(double x, double mean, double stdev)
        {
            double e, phi2, z, s;
            double sqrt_pi2 = 2.506628274631; /* sqrt(2*pi) */
            if (System.Double.IsNaN(x))
            {
                return System.Double.NaN;
            }
            if (System.Double.IsNaN(mean) || System.Double.IsNaN(stdev))
            {
                return System.Double.NaN;
            }
            else
            {
                z = x;
                z -= mean;
                s = stdev;
                phi2 = sqrt_pi2 * s;
                e = (- 0.5) * (z * z) / (s * s);
                return System.Math.Exp(e) / phi2;
            }
        }
    }
}

Output

The Naive Bayes classifier incorrectly classifies 6 of the 140 training patterns.

     Iris Classification Error Rates
------------------------------------------------
  Setosa   Versicolour   Virginica    |   Total
  0/50         2/50        4/40       |   6/140
------------------------------------------------



Probabilities for Incorrect Classifications
 Predicted   
   Class     |  Class       |   P(0)     P(1)     P(2) 
-------------------------------------------------------
Virginica    | Virginica    |   0.000    0.000    1.000 
Virginica    | Virginica    |   0.000    0.000    1.000 
Virginica    | Virginica    |   0.000    0.051    0.949 
Virginica    | Virginica    |   0.000    0.000    1.000 
Virginica    | Virginica    |   0.000    0.000    1.000 
Virginica    | Virginica    |   0.000    0.000    1.000 
Virginica    | Virginica    |   0.000    0.048    0.952 
Virginica    | Virginica    |   0.000    0.001    0.999 
Virginica    | Virginica    |   0.000    0.000    1.000 
Virginica    | Virginica    |   0.000    0.126    0.874 
