This example trains a three-layer network using Fisher's Iris data, which has four continuous input attributes and three output classification categories. This is perhaps the best-known database in the pattern recognition literature, and Fisher's paper is a classic in the field. The data set contains three classes of 50 instances each, where each class refers to a type of iris plant.
The structure of the network consists of four input nodes and three layers, with four perceptrons in the first hidden layer, three perceptrons in the second hidden layer, and three perceptrons in the output layer.
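For reference, the calls that produce this topology are condensed below from the full listing that follows; the FeedForwardNetwork and Activation classes come from com.imsl.datamining.neural.

    FeedForwardNetwork network = new FeedForwardNetwork();
    network.getInputLayer().createInputs(4);                  // four continuous inputs
    network.createHiddenLayer().
            createPerceptrons(4, Activation.LOGISTIC, 0.0);   // first hidden layer
    network.createHiddenLayer().
            createPerceptrons(3, Activation.LOGISTIC, 0.0);   // second hidden layer
    network.getOutputLayer().
            createPerceptrons(3, Activation.SOFTMAX, 0.0);    // output layer
    network.linkAll();                                        // fully connect adjacent layers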
The four input attributes represent the sepal length, sepal width, petal length, and petal width of the sampled iris plants.
The output attribute represents the class of the iris plant (1 = Iris Setosa, 2 = Iris Versicolour, 3 = Iris Virginica) and is encoded using binary encoding, as illustrated in the sketch below.
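To make the encoding concrete, the sketch below shows the binary (one-of-N) pattern that corresponds to each class label. The helper class and method are hypothetical and are not part of the example, which passes the integer labels (1-3) directly to MultiClassification.

    // Hypothetical helper, for illustration only: shows the binary (one-of-N)
    // encoding of an iris class label (1 = Setosa, 2 = Versicolour, 3 = Virginica).
    public class BinaryEncodingSketch {
        static double[] encodeClass(int irisClass, int nClasses) {
            double[] encoded = new double[nClasses];  // initialized to all zeros
            encoded[irisClass - 1] = 1.0;             // e.g. class 2 -> {0.0, 1.0, 0.0}
            return encoded;
        }
    }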
There are a total of 47 weights in this network, including the bias weights: 4×4 + 4×3 + 3×3 = 37 connection weights plus 4 + 3 + 3 = 10 bias weights, one for each perceptron. Both hidden layers use the logistic activation function. Since the target output is a multi-class classification, the softmax activation function is used in the output layer and the MultiClassification error function class is used by the trainer. The MultiClassification error class combines the cross-entropy error calculations and the softmax function.
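To make the error criterion concrete, the standalone helpers below sketch the two pieces being combined: the softmax function that turns the output-perceptron potentials into class probabilities, and the cross-entropy error that the trainer minimizes. These helpers are for illustration only and are not the JMSL implementation.

    // Illustrative helpers only; not the JMSL implementation used by
    // MultiClassification.
    public class SoftmaxCrossEntropySketch {

        // Softmax: converts output-perceptron potentials into class
        // probabilities that are positive and sum to one.
        static double[] softmax(double[] z) {
            double max = Double.NEGATIVE_INFINITY;
            for (double v : z) {
                max = Math.max(max, v);               // shift for numerical stability
            }
            double[] p = new double[z.length];
            double sum = 0.0;
            for (int i = 0; i < z.length; i++) {
                p[i] = Math.exp(z[i] - max);
                sum += p[i];
            }
            for (int i = 0; i < p.length; i++) {
                p[i] /= sum;
            }
            return p;
        }

        // Cross-entropy error for a single pattern whose true class is
        // targetClass (1, 2, or 3); training minimizes the sum of this
        // quantity over all patterns.
        static double crossEntropy(double[] probabilities, int targetClass) {
            return -Math.log(probabilities[targetClass - 1]);
        }

        public static void main(String[] args) {
            double[] p = softmax(new double[]{2.0, 0.5, -1.0});
            System.out.println("Probabilities: " + p[0] + ", " + p[1] + ", " + p[2]);
            System.out.println("Cross-entropy if true class is 1: " + crossEntropy(p, 1));
        }
    }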
import com.imsl.datamining.neural.*;
import com.imsl.math.*;
import java.io.*;
import java.util.logging.*;

//***************************************************************************
// Three Layer Feed-Forward Network with 4 inputs, all
// continuous, and 3 classification categories.
//
// new classification training_ex5.c
//
// This is perhaps the best known database to be found in the pattern
// recognition literature.  Fisher's paper is a classic in the field.
// The data set contains 3 classes of 50 instances each,
// where each class refers to a type of iris plant.  One class is
// linearly separable from the other 2; the latter are NOT linearly
// separable from each other.
//
// Predicted attribute: class of iris plant.
// 1=Iris Setosa, 2=Iris Versicolour, and 3=Iris Virginica
//
// Input Attributes (4 Continuous Attributes)
// X1: Sepal length, X2: Sepal width, X3: Petal length,
// and X4: Petal width
//***************************************************************************
public class MultiClassificationEx1 implements Serializable {

    private static int nObs = 150;       // number of training patterns
    private static int nInputs = 4;      // number of continuous input attributes
    private static int nOutputs = 3;     // number of classification categories
    private static boolean trace = true; // Turns on/off training log

    // irisData[]: The raw data matrix.  This is a 2-D matrix with 150 rows
    //             and 5 columns.  The first 4 columns are the continuous
    //             input attributes and the 5th column is the
    //             classification category (1-3).  These data contain no
    //             categorical input attributes.
    private static double[][] irisData = {
        // Class 1: Iris Setosa
        {5.1, 3.5, 1.4, 0.2, 1}, {4.9, 3.0, 1.4, 0.2, 1}, {4.7, 3.2, 1.3, 0.2, 1},
        {4.6, 3.1, 1.5, 0.2, 1}, {5.0, 3.6, 1.4, 0.2, 1}, {5.4, 3.9, 1.7, 0.4, 1},
        {4.6, 3.4, 1.4, 0.3, 1}, {5.0, 3.4, 1.5, 0.2, 1}, {4.4, 2.9, 1.4, 0.2, 1},
        {4.9, 3.1, 1.5, 0.1, 1}, {5.4, 3.7, 1.5, 0.2, 1}, {4.8, 3.4, 1.6, 0.2, 1},
        {4.8, 3.0, 1.4, 0.1, 1}, {4.3, 3.0, 1.1, 0.1, 1}, {5.8, 4.0, 1.2, 0.2, 1},
        {5.7, 4.4, 1.5, 0.4, 1}, {5.4, 3.9, 1.3, 0.4, 1}, {5.1, 3.5, 1.4, 0.3, 1},
        {5.7, 3.8, 1.7, 0.3, 1}, {5.1, 3.8, 1.5, 0.3, 1}, {5.4, 3.4, 1.7, 0.2, 1},
        {5.1, 3.7, 1.5, 0.4, 1}, {4.6, 3.6, 1.0, 0.2, 1}, {5.1, 3.3, 1.7, 0.5, 1},
        {4.8, 3.4, 1.9, 0.2, 1}, {5.0, 3.0, 1.6, 0.2, 1}, {5.0, 3.4, 1.6, 0.4, 1},
        {5.2, 3.5, 1.5, 0.2, 1}, {5.2, 3.4, 1.4, 0.2, 1}, {4.7, 3.2, 1.6, 0.2, 1},
        {4.8, 3.1, 1.6, 0.2, 1}, {5.4, 3.4, 1.5, 0.4, 1}, {5.2, 4.1, 1.5, 0.1, 1},
        {5.5, 4.2, 1.4, 0.2, 1}, {4.9, 3.1, 1.5, 0.1, 1}, {5.0, 3.2, 1.2, 0.2, 1},
        {5.5, 3.5, 1.3, 0.2, 1}, {4.9, 3.1, 1.5, 0.1, 1}, {4.4, 3.0, 1.3, 0.2, 1},
        {5.1, 3.4, 1.5, 0.2, 1}, {5.0, 3.5, 1.3, 0.3, 1}, {4.5, 2.3, 1.3, 0.3, 1},
        {4.4, 3.2, 1.3, 0.2, 1}, {5.0, 3.5, 1.6, 0.6, 1}, {5.1, 3.8, 1.9, 0.4, 1},
        {4.8, 3.0, 1.4, 0.3, 1}, {5.1, 3.8, 1.6, 0.2, 1}, {4.6, 3.2, 1.4, 0.2, 1},
        {5.3, 3.7, 1.5, 0.2, 1}, {5.0, 3.3, 1.4, 0.2, 1},
        // Class 2: Iris Versicolour
        {7.0, 3.2, 4.7, 1.4, 2}, {6.4, 3.2, 4.5, 1.5, 2}, {6.9, 3.1, 4.9, 1.5, 2},
        {5.5, 2.3, 4.0, 1.3, 2}, {6.5, 2.8, 4.6, 1.5, 2}, {5.7, 2.8, 4.5, 1.3, 2},
        {6.3, 3.3, 4.7, 1.6, 2}, {4.9, 2.4, 3.3, 1.0, 2}, {6.6, 2.9, 4.6, 1.3, 2},
        {5.2, 2.7, 3.9, 1.4, 2}, {5.0, 2.0, 3.5, 1.0, 2}, {5.9, 3.0, 4.2, 1.5, 2},
        {6.0, 2.2, 4.0, 1.0, 2}, {6.1, 2.9, 4.7, 1.4, 2}, {5.6, 2.9, 3.6, 1.3, 2},
        {6.7, 3.1, 4.4, 1.4, 2}, {5.6, 3.0, 4.5, 1.5, 2}, {5.8, 2.7, 4.1, 1.0, 2},
        {6.2, 2.2, 4.5, 1.5, 2}, {5.6, 2.5, 3.9, 1.1, 2}, {5.9, 3.2, 4.8, 1.8, 2},
        {6.1, 2.8, 4.0, 1.3, 2}, {6.3, 2.5, 4.9, 1.5, 2}, {6.1, 2.8, 4.7, 1.2, 2},
        {6.4, 2.9, 4.3, 1.3, 2}, {6.6, 3.0, 4.4, 1.4, 2}, {6.8, 2.8, 4.8, 1.4, 2},
        {6.7, 3.0, 5.0, 1.7, 2}, {6.0, 2.9, 4.5, 1.5, 2}, {5.7, 2.6, 3.5, 1.0, 2},
        {5.5, 2.4, 3.8, 1.1, 2}, {5.5, 2.4, 3.7, 1.0, 2}, {5.8, 2.7, 3.9, 1.2, 2},
        {6.0, 2.7, 5.1, 1.6, 2}, {5.4, 3.0, 4.5, 1.5, 2}, {6.0, 3.4, 4.5, 1.6, 2},
        {6.7, 3.1, 4.7, 1.5, 2}, {6.3, 2.3, 4.4, 1.3, 2}, {5.6, 3.0, 4.1, 1.3, 2},
        {5.5, 2.5, 4.0, 1.3, 2}, {5.5, 2.6, 4.4, 1.2, 2}, {6.1, 3.0, 4.6, 1.4, 2},
        {5.8, 2.6, 4.0, 1.2, 2}, {5.0, 2.3, 3.3, 1.0, 2}, {5.6, 2.7, 4.2, 1.3, 2},
        {5.7, 3.0, 4.2, 1.2, 2}, {5.7, 2.9, 4.2, 1.3, 2}, {6.2, 2.9, 4.3, 1.3, 2},
        {5.1, 2.5, 3.0, 1.1, 2}, {5.7, 2.8, 4.1, 1.3, 2},
        // Class 3: Iris Virginica
        {6.3, 3.3, 6.0, 2.5, 3}, {5.8, 2.7, 5.1, 1.9, 3}, {7.1, 3.0, 5.9, 2.1, 3},
        {6.3, 2.9, 5.6, 1.8, 3}, {6.5, 3.0, 5.8, 2.2, 3}, {7.6, 3.0, 6.6, 2.1, 3},
        {4.9, 2.5, 4.5, 1.7, 3}, {7.3, 2.9, 6.3, 1.8, 3}, {6.7, 2.5, 5.8, 1.8, 3},
        {7.2, 3.6, 6.1, 2.5, 3}, {6.5, 3.2, 5.1, 2.0, 3}, {6.4, 2.7, 5.3, 1.9, 3},
        {6.8, 3.0, 5.5, 2.1, 3}, {5.7, 2.5, 5.0, 2.0, 3}, {5.8, 2.8, 5.1, 2.4, 3},
        {6.4, 3.2, 5.3, 2.3, 3}, {6.5, 3.0, 5.5, 1.8, 3}, {7.7, 3.8, 6.7, 2.2, 3},
        {7.7, 2.6, 6.9, 2.3, 3}, {6.0, 2.2, 5.0, 1.5, 3}, {6.9, 3.2, 5.7, 2.3, 3},
        {5.6, 2.8, 4.9, 2.0, 3}, {7.7, 2.8, 6.7, 2.0, 3}, {6.3, 2.7, 4.9, 1.8, 3},
        {6.7, 3.3, 5.7, 2.1, 3}, {7.2, 3.2, 6.0, 1.8, 3}, {6.2, 2.8, 4.8, 1.8, 3},
        {6.1, 3.0, 4.9, 1.8, 3}, {6.4, 2.8, 5.6, 2.1, 3}, {7.2, 3.0, 5.8, 1.6, 3},
        {7.4, 2.8, 6.1, 1.9, 3}, {7.9, 3.8, 6.4, 2.0, 3}, {6.4, 2.8, 5.6, 2.2, 3},
        {6.3, 2.8, 5.1, 1.5, 3}, {6.1, 2.6, 5.6, 1.4, 3}, {7.7, 3.0, 6.1, 2.3, 3},
        {6.3, 3.4, 5.6, 2.4, 3}, {6.4, 3.1, 5.5, 1.8, 3}, {6.0, 3.0, 4.8, 1.8, 3},
        {6.9, 3.1, 5.4, 2.1, 3}, {6.7, 3.1, 5.6, 2.4, 3}, {6.9, 3.1, 5.1, 2.3, 3},
        {5.8, 2.7, 5.1, 1.9, 3}, {6.8, 3.2, 5.9, 2.3, 3}, {6.7, 3.3, 5.7, 2.5, 3},
        {6.7, 3.0, 5.2, 2.3, 3}, {6.3, 2.5, 5.0, 1.9, 3}, {6.5, 3.0, 5.2, 2.0, 3},
        {6.2, 3.4, 5.4, 2.3, 3}, {5.9, 3.0, 5.1, 1.8, 3}
    };

    public static void main(String[] args) throws Exception {
        double xData[][] = new double[nObs][nInputs];
        int yData[] = new int[nObs];
        for (int i = 0; i < nObs; i++) {
            for (int j = 0; j < nInputs; j++) {
                xData[i][j] = irisData[i][j];
            }
            yData[i] = (int) irisData[i][4];
        }

        // Create network
        FeedForwardNetwork network = new FeedForwardNetwork();
        network.getInputLayer().createInputs(nInputs);
        network.createHiddenLayer().
                createPerceptrons(4, Activation.LOGISTIC, 0.0);
        network.createHiddenLayer().
                createPerceptrons(3, Activation.LOGISTIC, 0.0);
        network.getOutputLayer().
                createPerceptrons(nOutputs, Activation.SOFTMAX, 0.0);
        network.linkAll();
        MultiClassification classification = new MultiClassification(network);

        // Create trainer
        QuasiNewtonTrainer trainer = new QuasiNewtonTrainer();
        trainer.setError(classification.getError());
        trainer.setMaximumTrainingIterations(1000);

        // If tracing is requested, set up the training logger
        if (trace) {
            Handler handler =
                    new FileHandler("ClassificationNetworkTraining.log");
            Logger logger = Logger.getLogger("com.imsl.datamining.neural");
            logger.setLevel(Level.FINEST);
            logger.addHandler(handler);
            handler.setFormatter(QuasiNewtonTrainer.getFormatter());
        }

        // Train Network
        long t0 = System.currentTimeMillis();
        classification.train(trainer, xData, yData);

        // Display Network Errors
        double stats[] = classification.computeStatistics(xData, yData);
        System.out.println("***********************************************");
        System.out.println("--> Cross-entropy error: " + (float) stats[0]);
        System.out.println("--> Classification error rate: " + (float) stats[1]);
        System.out.println("***********************************************");
        System.out.println("");

        // Display the trained weights and their error gradients
        double weight[] = network.getWeights();
        double gradient[] = trainer.getErrorGradient();
        double wg[][] = new double[weight.length][2];
        for (int i = 0; i < weight.length; i++) {
            wg[i][0] = weight[i];
            wg[i][1] = gradient[i];
        }
        PrintMatrixFormat pmf = new PrintMatrixFormat();
        pmf.setNumberFormat(new java.text.DecimalFormat("0.000000"));
        pmf.setColumnLabels(new String[]{"Weights", "Gradients"});
        new PrintMatrix().print(pmf, wg);

        // Display the forecast table: inputs, expected class, predicted class
        double report[][] = new double[nObs][nInputs + 2];
        for (int i = 0; i < nObs; i++) {
            for (int j = 0; j < nInputs; j++) {
                report[i][j] = xData[i][j];
            }
            report[i][nInputs] = irisData[i][4];
            report[i][nInputs + 1] = classification.predictedClass(xData[i]);
        }
        pmf = new PrintMatrixFormat();
        pmf.setColumnLabels(new String[]{
            "Sepal Length", "Sepal Width", "Petal Length",
            "Petal Width", "Expected", "Predicted"}
        );
        new PrintMatrix("Forecast").print(pmf, report);

        // ******************************************************************
        // DISPLAY CLASSIFICATION STATISTICS
        // ******************************************************************
        double statsClass[] = classification.computeStatistics(xData, yData);
        // Display Network Errors
        System.out.println("***********************************************");
        System.out.println("--> Cross-Entropy Error: " + (float) statsClass[0]);
        System.out.println("--> Classification Error: " + (float) statsClass[1]);
        System.out.println("***********************************************");
        System.out.println("");

        long t1 = System.currentTimeMillis();
        double time = t1 - t0;
        time = time / 1000;
        System.out.println("****************Time: " + time);
        System.out.println("Cross-Entropy Error Value = "
                + trainer.getErrorValue());
    }
}
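Training is the expensive step, so in practice the trained network is usually saved and reloaded for later forecasting rather than retrained each time. A minimal sketch using standard Java object serialization is shown below; the helper class and file name are arbitrary, and it assumes the trained network object (or the MultiClassification object wrapping it) is serializable. The output of the example follows the sketch.

    import java.io.*;

    // Hypothetical persistence helper using standard Java serialization.
    // Assumes the object passed in implements java.io.Serializable.
    public class NetworkPersistenceSketch {
        public static void save(Object trainedNetwork, String fileName)
                throws IOException {
            try (ObjectOutputStream oos =
                    new ObjectOutputStream(new FileOutputStream(fileName))) {
                oos.writeObject(trainedNetwork);   // write the trained network to disk
            }
        }

        public static Object load(String fileName)
                throws IOException, ClassNotFoundException {
            try (ObjectInputStream ois =
                    new ObjectInputStream(new FileInputStream(fileName))) {
                return ois.readObject();           // read the trained network back
            }
        }
    }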
***********************************************
--> Cross-entropy error: 4.653512
--> Classification error rate: 0.006666667
***********************************************

          Weights   Gradients
 0     -42.381828    0.030801
 1     193.055878    0.000000
 2     -30.384656    0.000000
 3      95.352605    0.000000
 4     -33.692976    0.012782
 5    -282.844912    0.000000
 6    -422.581218    0.000000
 7    -317.896968    0.000000
 8      60.505928    0.023458
 9     -94.286590    0.000000
10     109.828939    0.000000
11    -168.351914    0.000000
12      42.250439    0.008464
13     691.012987    0.000000
14     602.474794    0.000000
15     694.122349    0.000000
16      -3.036409   -1.035514
17    -151.673802   -5.466157
18      75.471899    0.000003
19       3.479083   -1.035303
20      46.614896   -5.466182
21      56.347637    0.032349
22     153.010906   -1.035303
23     204.597137   -5.466182
24      21.557752    0.032349
25      64.075577   -1.035303
26      67.599168   -5.466182
27      78.904088    0.032349
28   -5672.118029    0.000000
29    1244.905099    0.010077
30    4428.212930   -0.010077
31   -5600.671740    0.000000
32    1746.525173    0.004710
33    3855.146566   -0.004710
34   -5562.390356    0.000000
35    1230.760279    0.010431
36    4332.630076   -0.010431
37     -15.417798    0.004103
38     328.841061    0.000000
39     323.847338    0.000000
40     306.946067    0.000000
41    -214.124377   -1.035303
42    -167.320245   -5.466182
43    -156.514239    0.032349
44   13108.354735    0.000000
45   -2985.466557    0.010413
46  -10122.888178   -0.010413

Forecast
     Sepal Length  Sepal Width  Petal Length  Petal Width  Expected  Predicted
  0  5.1  3.5  1.4  0.2  1  1
  1  4.9  3  1.4  0.2  1  1
  2  4.7  3.2  1.3  0.2  1  1
  3  4.6  3.1  1.5  0.2  1  1
  4  5  3.6  1.4  0.2  1  1
  5  5.4  3.9  1.7  0.4  1  1
  6  4.6  3.4  1.4  0.3  1  1
  7  5  3.4  1.5  0.2  1  1
  8  4.4  2.9  1.4  0.2  1  1
  9  4.9  3.1  1.5  0.1  1  1
 10  5.4  3.7  1.5  0.2  1  1
 11  4.8  3.4  1.6  0.2  1  1
 12  4.8  3  1.4  0.1  1  1
 13  4.3  3  1.1  0.1  1  1
 14  5.8  4  1.2  0.2  1  1
 15  5.7  4.4  1.5  0.4  1  1
 16  5.4  3.9  1.3  0.4  1  1
 17  5.1  3.5  1.4  0.3  1  1
 18  5.7  3.8  1.7  0.3  1  1
 19  5.1  3.8  1.5  0.3  1  1
 20  5.4  3.4  1.7  0.2  1  1
 21  5.1  3.7  1.5  0.4  1  1
 22  4.6  3.6  1  0.2  1  1
 23  5.1  3.3  1.7  0.5  1  1
 24  4.8  3.4  1.9  0.2  1  1
 25  5  3  1.6  0.2  1  1
 26  5  3.4  1.6  0.4  1  1
 27  5.2  3.5  1.5  0.2  1  1
 28  5.2  3.4  1.4  0.2  1  1
 29  4.7  3.2  1.6  0.2  1  1
 30  4.8  3.1  1.6  0.2  1  1
 31  5.4  3.4  1.5  0.4  1  1
 32  5.2  4.1  1.5  0.1  1  1
 33  5.5  4.2  1.4  0.2  1  1
 34  4.9  3.1  1.5  0.1  1  1
 35  5  3.2  1.2  0.2  1  1
 36  5.5  3.5  1.3  0.2  1  1
 37  4.9  3.1  1.5  0.1  1  1
 38  4.4  3  1.3  0.2  1  1
 39  5.1  3.4  1.5  0.2  1  1
 40  5  3.5  1.3  0.3  1  1
 41  4.5  2.3  1.3  0.3  1  1
 42  4.4  3.2  1.3  0.2  1  1
 43  5  3.5  1.6  0.6  1  1
 44  5.1  3.8  1.9  0.4  1  1
 45  4.8  3  1.4  0.3  1  1
 46  5.1  3.8  1.6  0.2  1  1
 47  4.6  3.2  1.4  0.2  1  1
 48  5.3  3.7  1.5  0.2  1  1
 49  5  3.3  1.4  0.2  1  1
 50  7  3.2  4.7  1.4  2  2
 51  6.4  3.2  4.5  1.5  2  2
 52  6.9  3.1  4.9  1.5  2  2
 53  5.5  2.3  4  1.3  2  2
 54  6.5  2.8  4.6  1.5  2  2
 55  5.7  2.8  4.5  1.3  2  2
 56  6.3  3.3  4.7  1.6  2  2
 57  4.9  2.4  3.3  1  2  2
 58  6.6  2.9  4.6  1.3  2  2
 59  5.2  2.7  3.9  1.4  2  2
 60  5  2  3.5  1  2  2
 61  5.9  3  4.2  1.5  2  2
 62  6  2.2  4  1  2  2
 63  6.1  2.9  4.7  1.4  2  2
 64  5.6  2.9  3.6  1.3  2  2
 65  6.7  3.1  4.4  1.4  2  2
 66  5.6  3  4.5  1.5  2  2
 67  5.8  2.7  4.1  1  2  2
 68  6.2  2.2  4.5  1.5  2  2
 69  5.6  2.5  3.9  1.1  2  2
 70  5.9  3.2  4.8  1.8  2  2
 71  6.1  2.8  4  1.3  2  2
 72  6.3  2.5  4.9  1.5  2  2
 73  6.1  2.8  4.7  1.2  2  2
 74  6.4  2.9  4.3  1.3  2  2
 75  6.6  3  4.4  1.4  2  2
 76  6.8  2.8  4.8  1.4  2  2
 77  6.7  3  5  1.7  2  2
 78  6  2.9  4.5  1.5  2  2
 79  5.7  2.6  3.5  1  2  2
 80  5.5  2.4  3.8  1.1  2  2
 81  5.5  2.4  3.7  1  2  2
 82  5.8  2.7  3.9  1.2  2  2
 83  6  2.7  5.1  1.6  2  3
 84  5.4  3  4.5  1.5  2  2
 85  6  3.4  4.5  1.6  2  2
 86  6.7  3.1  4.7  1.5  2  2
 87  6.3  2.3  4.4  1.3  2  2
 88  5.6  3  4.1  1.3  2  2
 89  5.5  2.5  4  1.3  2  2
 90  5.5  2.6  4.4  1.2  2  2
 91  6.1  3  4.6  1.4  2  2
 92  5.8  2.6  4  1.2  2  2
 93  5  2.3  3.3  1  2  2
 94  5.6  2.7  4.2  1.3  2  2
 95  5.7  3  4.2  1.2  2  2
 96  5.7  2.9  4.2  1.3  2  2
 97  6.2  2.9  4.3  1.3  2  2
 98  5.1  2.5  3  1.1  2  2
 99  5.7  2.8  4.1  1.3  2  2
100  6.3  3.3  6  2.5  3  3
101  5.8  2.7  5.1  1.9  3  3
102  7.1  3  5.9  2.1  3  3
103  6.3  2.9  5.6  1.8  3  3
104  6.5  3  5.8  2.2  3  3
105  7.6  3  6.6  2.1  3  3
106  4.9  2.5  4.5  1.7  3  3
107  7.3  2.9  6.3  1.8  3  3
108  6.7  2.5  5.8  1.8  3  3
109  7.2  3.6  6.1  2.5  3  3
110  6.5  3.2  5.1  2  3  3
111  6.4  2.7  5.3  1.9  3  3
112  6.8  3  5.5  2.1  3  3
113  5.7  2.5  5  2  3  3
114  5.8  2.8  5.1  2.4  3  3
115  6.4  3.2  5.3  2.3  3  3
116  6.5  3  5.5  1.8  3  3
117  7.7  3.8  6.7  2.2  3  3
118  7.7  2.6  6.9  2.3  3  3
119  6  2.2  5  1.5  3  3
120  6.9  3.2  5.7  2.3  3  3
121  5.6  2.8  4.9  2  3  3
122  7.7  2.8  6.7  2  3  3
123  6.3  2.7  4.9  1.8  3  3
124  6.7  3.3  5.7  2.1  3  3
125  7.2  3.2  6  1.8  3  3
126  6.2  2.8  4.8  1.8  3  3
127  6.1  3  4.9  1.8  3  3
128  6.4  2.8  5.6  2.1  3  3
129  7.2  3  5.8  1.6  3  3
130  7.4  2.8  6.1  1.9  3  3
131  7.9  3.8  6.4  2  3  3
132  6.4  2.8  5.6  2.2  3  3
133  6.3  2.8  5.1  1.5  3  3
134  6.1  2.6  5.6  1.4  3  3
135  7.7  3  6.1  2.3  3  3
136  6.3  3.4  5.6  2.4  3  3
137  6.4  3.1  5.5  1.8  3  3
138  6  3  4.8  1.8  3  3
139  6.9  3.1  5.4  2.1  3  3
140  6.7  3.1  5.6  2.4  3  3
141  6.9  3.1  5.1  2.3  3  3
142  5.8  2.7  5.1  1.9  3  3
143  6.8  3.2  5.9  2.3  3  3
144  6.7  3.3  5.7  2.5  3  3
145  6.7  3  5.2  2.3  3  3
146  6.3  2.5  5  1.9  3  3
147  6.5  3  5.2  2  3  3
148  6.2  3.4  5.4  2.3  3  3
149  5.9  3  5.1  1.8  3  3

***********************************************
--> Cross-Entropy Error: 4.653512
--> Classification Error: 0.006666667
***********************************************

****************Time: 0.858
Cross-Entropy Error Value = 4.653511831588219

Link to Java source.