Example 2: MultiClassification

This example trains a 2-layer network using three binary inputs (X1, X2, X3) and one three-level classification variable (Y), where

Y = 1 if X1 = 1

Y = 2 if X2 = 1

Y = 3 if X3 = 1


import com.imsl.datamining.neural.*;
import com.imsl.math.*;
import java.io.*;
import java.util.logging.*;

//***************************************************************************
// Two-Layer FFN with 3 binary inputs (X1, X2, X3) and one three-level
// classification variable (Y)
// Y = 1 if X1 = 1
// Y = 2 if X2 = 1
// Y = 3 if X3 = 1
//  (training_ex6)
//***************************************************************************
public class MultiClassificationEx2 implements Serializable {

    private static int nObs = 6;     // number of training patterns
    private static int nInputs = 3;  // 3 inputs, all categorical
    private static int nOutputs = 3; // one output perceptron per class
    private static boolean trace = true; // Turns on/off training log  

    // Training patterns: each row is a one-hot encoding of (X1, X2, X3).
    private static double xData[][] = {
        {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1}, {0, 0, 1}
    };
    // Target class labels (1-based): Y = k exactly when Xk = 1 above.
    private static int yData[] = {1, 1, 2, 2, 3, 3};

    // Initial network weights (hidden-layer weights, output-layer weights,
    // then biases), giving the trainer a reproducible starting point.
    private static double weights[] = {
        1.29099444873580580000, -0.64549722436790280000, -0.64549722436790291000,
        0.00000000000000000000, 1.11803398874989490000, -1.11803398874989470000,
        0.57735026918962584000, 0.57735026918962584000, 0.57735026918962584000,
        0.33333333333333331000, 0.33333333333333331000, 0.33333333333333331000,
        0.33333333333333331000, 0.33333333333333331000, 0.33333333333333331000,
        0.33333333333333331000, 0.33333333333333331000, 0.33333333333333331000,
        -0.00000000000000005851, -0.00000000000000005851, -0.57735026918962573000,
        0.00000000000000000000, 0.00000000000000000000, 0.00000000000000000000
    };

    /**
     * Builds a 2-layer feed-forward network (3 inputs, 3 linear hidden
     * perceptrons, 3 softmax outputs), trains it on the six patterns with a
     * quasi-Newton trainer, then prints classification statistics, the
     * trained weights with their error gradient, and a per-pattern forecast.
     *
     * @param args unused
     * @throws Exception if training or log-file creation fails
     */
    public static void main(String[] args) throws Exception {
        FeedForwardNetwork network = new FeedForwardNetwork();
        network.getInputLayer().createInputs(nInputs);
        network.createHiddenLayer().
                createPerceptrons(3, Activation.LINEAR, 0.0);
        network.getOutputLayer().
                createPerceptrons(nOutputs, Activation.SOFTMAX, 0.0);
        network.linkAll();
        network.setWeights(weights);

        MultiClassification classification = new MultiClassification(network);

        // Very tight tolerances so training runs until the error is
        // essentially zero (or the iteration cap is hit).
        QuasiNewtonTrainer trainer = new QuasiNewtonTrainer();
        trainer.setError(classification.getError());
        trainer.setMaximumTrainingIterations(1000);
        trainer.setFalseConvergenceTolerance(1.0e-20);
        trainer.setGradientTolerance(1.0e-20);
        trainer.setRelativeTolerance(1.0e-20);
        trainer.setStepTolerance(1.0e-20);

        // If tracing is requested setup training logger
        Handler handler = null;
        if (trace) {
            handler = new FileHandler("ClassificationNetworkEx2.log");
            Logger logger = Logger.getLogger("com.imsl.datamining.neural");
            logger.setLevel(Level.FINEST);
            logger.addHandler(handler);
            handler.setFormatter(QuasiNewtonTrainer.getFormatter());
        }
        try {
            // Train Network
            classification.train(trainer, xData, yData);
        } finally {
            // Close the handler so the log file is flushed and released;
            // the original example leaked this FileHandler.
            if (handler != null) {
                handler.close();
            }
        }

        // Display Network Errors
        printStatistics(classification);

        printWeightsAndGradients(network, trainer);

        printForecast(classification);
        System.out.println("Cross-Entropy Error Value = "
                + trainer.getErrorValue());

        // ******************************************************************
        // DISPLAY CLASSIFICATION STATISTICS
        // ******************************************************************
        printStatistics(classification);
    }

    /**
     * Computes and prints the cross-entropy and classification errors for
     * the current network state, followed by a blank line.
     */
    private static void printStatistics(MultiClassification classification) {
        double stats[] = classification.computeStatistics(xData, yData);
        System.out.println("***********************************************");
        System.out.println("--> Cross-Entropy Error:      "
                + (float) stats[0]);
        System.out.println("--> Classification Error:     "
                + (float) stats[1]);
        System.out.println("***********************************************");
        System.out.println();
    }

    /** Prints the trained weights side by side with the error gradient. */
    private static void printWeightsAndGradients(FeedForwardNetwork network,
            QuasiNewtonTrainer trainer) {
        double weight[] = network.getWeights();
        double gradient[] = trainer.getErrorGradient();
        double wg[][] = new double[weight.length][2];
        for (int i = 0; i < weight.length; i++) {
            wg[i][0] = weight[i];
            wg[i][1] = gradient[i];
        }
        PrintMatrixFormat pmf = new PrintMatrixFormat();
        pmf.setNumberFormat(new java.text.DecimalFormat("0.000000"));
        pmf.setColumnLabels(new String[]{"Weights", "Gradients"});
        new PrintMatrix().print(pmf, wg);
    }

    /**
     * Prints one row per training pattern: the inputs, the target class,
     * the predicted probability of each class, and the predicted class.
     */
    private static void printForecast(MultiClassification classification) {
        double report[][] = new double[nObs][nInputs + nOutputs + 2];
        for (int i = 0; i < nObs; i++) {
            for (int j = 0; j < nInputs; j++) {
                report[i][j] = xData[i][j];
            }
            report[i][nInputs] = yData[i];
            double p[] = classification.probabilities(xData[i]);
            for (int j = 0; j < nOutputs; j++) {
                report[i][nInputs + 1 + j] = p[j];
            }
            report[i][nInputs + nOutputs + 1]
                    = classification.predictedClass(xData[i]);
        }
        PrintMatrixFormat pmf = new PrintMatrixFormat();
        pmf.setColumnLabels(new String[]{"X1", "X2", "X3", "Y", "P(C1)",
            "P(C2)", "P(C3)", "Predicted"});
        new PrintMatrix("Forecast").print(pmf, report);
    }
}

Output

***********************************************
--> Cross-Entropy Error:      0.0
--> Classification Error:     0.0
***********************************************

     Weights   Gradients  
 0   3.401208  -0.000000  
 1  -4.126657   0.000000  
 2  -2.201606  -0.000000  
 3  -2.009527   0.000000  
 4   3.173323  -0.000000  
 5  -4.200377  -0.000000  
 6   0.028736  -0.000000  
 7   2.657051   0.000000  
 8   4.868134  -0.000000  
 9   3.711295  -0.000000  
10  -2.723536  -0.000000  
11   0.012241   0.000000  
12  -4.996359   0.000000  
13   4.296983   0.000000  
14   1.699376  -0.000000  
15  -1.993114   0.000000  
16  -4.048833   0.000000  
17   7.041948  -0.000000  
18  -0.447927  -0.000000  
19   0.653830   0.000000  
20  -0.925019  -0.000000  
21  -0.078963   0.000000  
22   0.247835   0.000000  
23  -0.168872  -0.000000  

                     Forecast
   X1  X2  X3  Y  P(C1)  P(C2)  P(C3)  Predicted  
0  1   0   0   1    1      0      0        1      
1  1   0   0   1    1      0      0        1      
2  0   1   0   2    0      1      0        2      
3  0   1   0   2    0      1      0        2      
4  0   0   1   3    0      0      1        3      
5  0   0   1   3    0      0      1        3      

Cross-Entropy Error Value = 0.0
***********************************************
--> Cross-Entropy Error:      0.0
--> Classification Error:     0.0
***********************************************

Link to Java source.