Example 2: SVClassification

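This example uses stratified cross-validation to select the regularization parameter C and the kernel parameter γ by minimizing the cross-validation (CV) error. The model is then fitted on a 30-observation training sample (the first 10 observations of each species) using these "best" settings, and the fitted model is used to classify the entire data set. The resulting classification errors are shown.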


import com.imsl.datamining.*;
import com.imsl.datamining.supportvectormachine.*;
import com.imsl.datamining.neural.*;
import com.imsl.stat.*;

/**
 * Selects SVM parameters for Fisher's Iris data by stratified cross-validation,
 * then classifies the whole data set with the resulting model.
 *
 * <p>A 5-fold stratified cross-validation is run over a grid of regularization
 * parameters C = 2, 4, ..., 1024 and kernel parameters gamma = 0.1, 0.2, ..., 1.6.
 * The (C, gamma) pair with the smallest cross-validated error is used to fit the
 * model on the training sample (the first {@link #N_TRAIN_PER_CLASS} rows of each
 * species). The fitted model then classifies every observation, and the
 * per-class error counts are printed.
 */
public class SupportVectorMachineEx2 {

    /** Number of Iris species; the data rows are grouped by species in equal-sized runs. */
    private static final int N_CLASSES = 3;

    /** Number of training rows taken from the start of each species group. */
    private static final int N_TRAIN_PER_CLASS = 10;

    /** Number of predictor columns that follow the response column (column 0). */
    private static final int N_PREDICTORS = 4;

    public static void main(String[] args) throws Exception {

        SVClassification.VariableType[] irisVarType = {
            SVClassification.VariableType.CATEGORICAL,
            SVClassification.VariableType.QUANTITATIVE_CONTINUOUS,
            SVClassification.VariableType.QUANTITATIVE_CONTINUOUS,
            SVClassification.VariableType.QUANTITATIVE_CONTINUOUS,
            SVClassification.VariableType.QUANTITATIVE_CONTINUOUS
        };

        String dashes
                = "--------------------------------------------------------------";

        double[][] irisFisherData = {
            {1.0, 5.1, 3.5, 1.4, .2}, {1.0, 4.9, 3.0, 1.4, .2},
            {1.0, 4.7, 3.2, 1.3, .2}, {1.0, 4.6, 3.1, 1.5, .2},
            {1.0, 5.0, 3.6, 1.4, .2}, {1.0, 5.4, 3.9, 1.7, .4},
            {1.0, 4.6, 3.4, 1.4, .3}, {1.0, 5.0, 3.4, 1.5, .2},
            {1.0, 4.4, 2.9, 1.4, .2}, {1.0, 4.9, 3.1, 1.5, .1},
            {1.0, 5.4, 3.7, 1.5, .2}, {1.0, 4.8, 3.4, 1.6, .2},
            {1.0, 4.8, 3.0, 1.4, .1}, {1.0, 4.3, 3.0, 1.1, .1},
            {1.0, 5.8, 4.0, 1.2, .2}, {1.0, 5.7, 4.4, 1.5, .4},
            {1.0, 5.4, 3.9, 1.3, .4}, {1.0, 5.1, 3.5, 1.4, .3},
            {1.0, 5.7, 3.8, 1.7, .3}, {1.0, 5.1, 3.8, 1.5, .3},
            {1.0, 5.4, 3.4, 1.7, .2}, {1.0, 5.1, 3.7, 1.5, .4},
            {1.0, 4.6, 3.6, 1.0, .2}, {1.0, 5.1, 3.3, 1.7, .5},
            {1.0, 4.8, 3.4, 1.9, .2}, {1.0, 5.0, 3.0, 1.6, .2},
            {1.0, 5.0, 3.4, 1.6, .4}, {1.0, 5.2, 3.5, 1.5, .2},
            {1.0, 5.2, 3.4, 1.4, .2}, {1.0, 4.7, 3.2, 1.6, .2},
            {1.0, 4.8, 3.1, 1.6, .2}, {1.0, 5.4, 3.4, 1.5, .4},
            {1.0, 5.2, 4.1, 1.5, .1}, {1.0, 5.5, 4.2, 1.4, .2},
            {1.0, 4.9, 3.1, 1.5, .1}, {1.0, 5.0, 3.2, 1.2, .2},
            {1.0, 5.5, 3.5, 1.3, .2}, {1.0, 4.9, 3.1, 1.5, .1},
            {1.0, 4.4, 3.0, 1.3, .2}, {1.0, 5.1, 3.4, 1.5, .2},
            {1.0, 5.0, 3.5, 1.3, .3}, {1.0, 4.5, 2.3, 1.3, .3},
            {1.0, 4.4, 3.2, 1.3, .2}, {1.0, 5.0, 3.5, 1.6, .6},
            {1.0, 5.1, 3.8, 1.9, .4}, {1.0, 4.8, 3.0, 1.4, .3},
            {1.0, 5.1, 3.8, 1.6, .2}, {1.0, 4.6, 3.2, 1.4, .2},
            {1.0, 5.3, 3.7, 1.5, .2}, {1.0, 5.0, 3.3, 1.4, .2},
            {2.0, 7.0, 3.2, 4.7, 1.4}, {2.0, 6.4, 3.2, 4.5, 1.5},
            {2.0, 6.9, 3.1, 4.9, 1.5}, {2.0, 5.5, 2.3, 4.0, 1.3},
            {2.0, 6.5, 2.8, 4.6, 1.5}, {2.0, 5.7, 2.8, 4.5, 1.3},
            {2.0, 6.3, 3.3, 4.7, 1.6}, {2.0, 4.9, 2.4, 3.3, 1.0},
            {2.0, 6.6, 2.9, 4.6, 1.3}, {2.0, 5.2, 2.7, 3.9, 1.4},
            {2.0, 5.0, 2.0, 3.5, 1.0}, {2.0, 5.9, 3.0, 4.2, 1.5},
            {2.0, 6.0, 2.2, 4.0, 1.0}, {2.0, 6.1, 2.9, 4.7, 1.4},
            {2.0, 5.6, 2.9, 3.6, 1.3}, {2.0, 6.7, 3.1, 4.4, 1.4},
            {2.0, 5.6, 3.0, 4.5, 1.5}, {2.0, 5.8, 2.7, 4.1, 1.0},
            {2.0, 6.2, 2.2, 4.5, 1.5}, {2.0, 5.6, 2.5, 3.9, 1.1},
            {2.0, 5.9, 3.2, 4.8, 1.8}, {2.0, 6.1, 2.8, 4.0, 1.3},
            {2.0, 6.3, 2.5, 4.9, 1.5}, {2.0, 6.1, 2.8, 4.7, 1.2},
            {2.0, 6.4, 2.9, 4.3, 1.3}, {2.0, 6.6, 3.0, 4.4, 1.4},
            {2.0, 6.8, 2.8, 4.8, 1.4}, {2.0, 6.7, 3.0, 5.0, 1.7},
            {2.0, 6.0, 2.9, 4.5, 1.5}, {2.0, 5.7, 2.6, 3.5, 1.0},
            {2.0, 5.5, 2.4, 3.8, 1.1}, {2.0, 5.5, 2.4, 3.7, 1.0},
            {2.0, 5.8, 2.7, 3.9, 1.2}, {2.0, 6.0, 2.7, 5.1, 1.6},
            {2.0, 5.4, 3.0, 4.5, 1.5}, {2.0, 6.0, 3.4, 4.5, 1.6},
            {2.0, 6.7, 3.1, 4.7, 1.5}, {2.0, 6.3, 2.3, 4.4, 1.3},
            {2.0, 5.6, 3.0, 4.1, 1.3}, {2.0, 5.5, 2.5, 4.0, 1.3},
            {2.0, 5.5, 2.6, 4.4, 1.2}, {2.0, 6.1, 3.0, 4.6, 1.4},
            {2.0, 5.8, 2.6, 4.0, 1.2}, {2.0, 5.0, 2.3, 3.3, 1.0},
            {2.0, 5.6, 2.7, 4.2, 1.3}, {2.0, 5.7, 3.0, 4.2, 1.2},
            {2.0, 5.7, 2.9, 4.2, 1.3}, {2.0, 6.2, 2.9, 4.3, 1.3},
            {2.0, 5.1, 2.5, 3.0, 1.1}, {2.0, 5.7, 2.8, 4.1, 1.3},
            {3.0, 6.3, 3.3, 6.0, 2.5}, {3.0, 5.8, 2.7, 5.1, 1.9},
            {3.0, 7.1, 3.0, 5.9, 2.1}, {3.0, 6.3, 2.9, 5.6, 1.8},
            {3.0, 6.5, 3.0, 5.8, 2.2}, {3.0, 7.6, 3.0, 6.6, 2.1},
            {3.0, 4.9, 2.5, 4.5, 1.7}, {3.0, 7.3, 2.9, 6.3, 1.8},
            {3.0, 6.7, 2.5, 5.8, 1.8}, {3.0, 7.2, 3.6, 6.1, 2.5},
            {3.0, 6.5, 3.2, 5.1, 2.0}, {3.0, 6.4, 2.7, 5.3, 1.9},
            {3.0, 6.8, 3.0, 5.5, 2.1}, {3.0, 5.7, 2.5, 5.0, 2.0},
            {3.0, 5.8, 2.8, 5.1, 2.4}, {3.0, 6.4, 3.2, 5.3, 2.3},
            {3.0, 6.5, 3.0, 5.5, 1.8}, {3.0, 7.7, 3.8, 6.7, 2.2},
            {3.0, 7.7, 2.6, 6.9, 2.3}, {3.0, 6.0, 2.2, 5.0, 1.5},
            {3.0, 6.9, 3.2, 5.7, 2.3}, {3.0, 5.6, 2.8, 4.9, 2.0},
            {3.0, 7.7, 2.8, 6.7, 2.0}, {3.0, 6.3, 2.7, 4.9, 1.8},
            {3.0, 6.7, 3.3, 5.7, 2.1}, {3.0, 7.2, 3.2, 6.0, 1.8},
            {3.0, 6.2, 2.8, 4.8, 1.8}, {3.0, 6.1, 3.0, 4.9, 1.8},
            {3.0, 6.4, 2.8, 5.6, 2.1}, {3.0, 7.2, 3.0, 5.8, 1.6},
            {3.0, 7.4, 2.8, 6.1, 1.9}, {3.0, 7.9, 3.8, 6.4, 2.0},
            {3.0, 6.4, 2.8, 5.6, 2.2}, {3.0, 6.3, 2.8, 5.1, 1.5},
            {3.0, 6.1, 2.6, 5.6, 1.4}, {3.0, 7.7, 3.0, 6.1, 2.3},
            {3.0, 6.3, 3.4, 5.6, 2.4}, {3.0, 6.4, 3.1, 5.5, 1.8},
            {3.0, 6.0, 3.0, 4.8, 1.8}, {3.0, 6.9, 3.1, 5.4, 2.1},
            {3.0, 6.7, 3.1, 5.6, 2.4}, {3.0, 6.9, 3.1, 5.1, 2.3},
            {3.0, 5.8, 2.7, 5.1, 1.9}, {3.0, 6.8, 3.2, 5.9, 2.3},
            {3.0, 6.7, 3.3, 5.7, 2.5}, {3.0, 6.7, 3.0, 5.2, 2.3},
            {3.0, 6.3, 2.5, 5.0, 1.9}, {3.0, 6.5, 3.0, 5.2, 2.0},
            {3.0, 6.2, 3.4, 5.4, 2.3}, {3.0, 5.9, 3.0, 5.1, 1.8}
        };

        // Derive all sizes from the data instead of repeating literals.
        int nObs = irisFisherData.length;
        // Rows are ordered by species in equal-sized groups.
        int classSize = nObs / N_CLASSES;

        // Scaled copy of the predictor columns, one row per observation.
        double[][] xx = scalePredictors(irisFisherData);

        // Build the training sample: the first N_TRAIN_PER_CLASS rows of each species.
        int nTrain = N_CLASSES * N_TRAIN_PER_CLASS;
        double[][] xy = new double[nTrain][];
        int ii = 0;
        for (int k = 0; k < N_CLASSES; k++) {
            for (int j = 0; j < N_TRAIN_PER_CLASS; j++) {
                xy[ii++] = labeledRow(irisFisherData, xx, k * classSize + j);
            }
        }

        // Construct a Support Vector Machine.
        SVClassification svm = new SVClassification(xy, 0, irisVarType);

        CrossValidation svmCV = new CrossValidation(svm);
        svmCV.setNumberOfSampleFolds(5);
        svmCV.setRandomObject(new Random(123457));
        svmCV.setStratifiedCrossValidation(true);

        // Grid search over C = 2 * 2^i (i = 0..9) and gamma = 0.1 * 2^j (j = 0..4),
        // keeping the first pair that attains the smallest CV error.
        double[] gamma = new double[1];
        double minResult = Double.POSITIVE_INFINITY;
        double bestGamma = 0.0;
        double bestC = 0.0;
        double c = 2.0;
        for (int i = 0; i < 10; i++, c *= 2.0) {
            gamma[0] = 0.1;
            for (int j = 0; j < 5; j++, gamma[0] *= 2.0) {
                svm.setRegularizationParameter(c);
                svm.setKernelParameters(gamma);
                svmCV.crossValidate();
                double result = svmCV.getCrossValidatedError();
                if (result < minResult) {
                    minResult = result;
                    bestGamma = gamma[0];
                    bestC = c;
                }
            }
        }
        System.out.printf("Best C: %5.0f \n", bestC);
        System.out.printf("Best gamma: %5.3f \n", bestGamma);
        System.out.printf("Minimum CV error: %5.3f \n", minResult);

        svm.setRegularizationParameter(bestC);
        gamma[0] = bestGamma;
        svm.setKernelParameters(gamma);
        // Train the model on the training sample only.
        svm.fitModel();

        // Classify the entire data set with the fitted model
        //  using the "best" C and gamma parameter values.
        double[][] allXy = new double[nObs][];
        double[] knownClass = new double[nObs];
        for (int i = 0; i < nObs; i++) {
            allXy[i] = labeledRow(irisFisherData, xx, i);
            knownClass[i] = allXy[i][0];
        }

        double[] predictedClass = svm.predict(allXy);
        int[][] classErrors = svm.getClassErrors(knownClass, predictedClass);

        System.out.println("\n   Iris Classification Error Rates");
        System.out.println("\n" + dashes);
        System.out.println(" Setosa Versicolour Virginica | TOTAL");
        System.out.println("  " + classErrors[0][0] + "/" + classErrors[0][1]
                + "      " + classErrors[1][0] + "/" + classErrors[1][1]
                + "       " + classErrors[2][0] + "/" + classErrors[2][1]
                + "     " + classErrors[3][0] + "/" + classErrors[3][1]);

        System.out.println(dashes);
    }

    /**
     * Extracts columns 1..N_PREDICTORS of every row and rescales them with a
     * bounded {@code ScaleFilter} whose bounds map [0, 10] onto [0, 1].
     *
     * @param data raw rows; column 0 is the response, and the predictors follow it
     * @return one scaled predictor row per input row, in the same order
     */
    private static double[][] scalePredictors(double[][] data) {
        double realMin = 0.0, realMax = 10.0, targetMin = 0.0, targetMax = 1.0;

        ScaleFilter scaleFilter = new ScaleFilter(ScaleFilter.BOUNDED_SCALING);
        scaleFilter.setBounds(realMin, realMax, targetMin, targetMax);

        double[][] scaled = new double[data.length][];
        for (int i = 0; i < data.length; i++) {
            double[] x = new double[N_PREDICTORS];
            System.arraycopy(data[i], 1, x, 0, N_PREDICTORS);
            scaled[i] = scaleFilter.encode(x);
        }
        return scaled;
    }

    /**
     * Builds one model row for observation {@code i}: the 0-based class label
     * followed by that observation's scaled predictors.
     *
     * <p>The species in the raw data is encoded starting at 1, so 1 is subtracted
     * to give the 0-based categorical response the model expects.
     *
     * @param data raw rows, with the 1-based species code in column 0
     * @param scaled scaled predictor rows, parallel to {@code data}
     * @param i index of the observation
     * @return a new array of length {@code N_PREDICTORS + 1}
     */
    private static double[] labeledRow(double[][] data, double[][] scaled, int i) {
        double[] row = new double[N_PREDICTORS + 1];
        row[0] = data[i][0] - 1;
        System.arraycopy(scaled[i], 0, row, 1, N_PREDICTORS);
        return row;
    }
}

Output

Best C:    64 
Best gamma: 1.600 
Minimum CV error: 0.000 

   Iris Classification Error Rates

--------------------------------------------------------------
 Setosa Versicolour Virginica | TOTAL
  0/50      1/50       3/50     4/150
--------------------------------------------------------------
Link to Java source.