Example 5: DecisionTree

This example uses the Kyphosis dataset. The 81 cases represent 81 children who have undergone surgery to correct a type of spinal deformity known as kyphosis. The response variable is the presence or absence of kyphosis after the surgery. The three predictors are Age, the patient's age in months; Start, the number of the vertebra where the surgery started; and Number, the number of vertebrae involved in the surgery. The example uses the QUEST method to produce a maximal tree, and it also requests predictions for a test set consisting of 10 "new" cases.


import com.imsl.datamining.decisionTree.*;

public class DecisionTreeEx5 {

    public static void main(String[] args) throws Exception {

        // Kyphosis training data; columns are {Kyphosis, Age, Number, Start}.
        double[][] xy = {
            {0, 71, 3, 5},
            {0, 158, 3, 14},
            {1, 128, 4, 5},
            {0, 2, 5, 1},
            {0, 1, 4, 15},
            {0, 1, 2, 16},
            {0, 61, 2, 17},
            {0, 37, 3, 16},
            {0, 113, 2, 16},
            {1, 59, 6, 12},
            {1, 82, 5, 14},
            {0, 148, 3, 16},
            {0, 18, 5, 2},
            {0, 1, 4, 12},
            {0, 168, 3, 18},
            {0, 1, 3, 16},
            {0, 78, 6, 15},
            {0, 175, 5, 13},
            {0, 80, 5, 16},
            {0, 27, 4, 9},
            {0, 22, 2, 16},
            {1, 105, 6, 5},
            {1, 96, 3, 12},
            {0, 131, 2, 3},
            {1, 15, 7, 2},
            {0, 9, 5, 13},
            {0, 8, 3, 6},
            {0, 100, 3, 14},
            {0, 4, 3, 16},
            {0, 151, 2, 16},
            {0, 31, 3, 16},
            {0, 125, 2, 11},
            {0, 130, 5, 13},
            {0, 112, 3, 16},
            {0, 140, 5, 11},
            {0, 93, 3, 16},
            {0, 1, 3, 9},
            {1, 52, 5, 6},
            {0, 20, 6, 9},
            {1, 91, 5, 12},
            {1, 73, 5, 1},
            {0, 35, 3, 13},
            {0, 143, 9, 3},
            {0, 61, 4, 1},
            {0, 97, 3, 16},
            {1, 139, 3, 10},
            {0, 136, 4, 15},
            {0, 131, 5, 13},
            {1, 121, 3, 3},
            {0, 177, 2, 14},
            {0, 68, 5, 10},
            {0, 9, 2, 17},
            {1, 139, 10, 6},
            {0, 2, 2, 17},
            {0, 140, 4, 15},
            {0, 72, 5, 15},
            {0, 2, 3, 13},
            {1, 120, 5, 8},
            {0, 51, 7, 9},
            {0, 102, 3, 13},
            {1, 130, 4, 1},
            {1, 114, 7, 8},
            {0, 81, 4, 1},
            {0, 118, 3, 16},
            {0, 118, 4, 16},
            {0, 17, 4, 10},
            {0, 195, 2, 17},
            {0, 159, 4, 13},
            {0, 18, 4, 11},
            {0, 15, 5, 16},
            {0, 158, 5, 14},
            {0, 127, 4, 12},
            {0, 87, 4, 16},
            {0, 206, 4, 10},
            {0, 11, 3, 15},
            {0, 178, 4, 15},
            {1, 157, 3, 13},
            {0, 26, 7, 13},
            {0, 120, 2, 13},
            {1, 42, 7, 6},
            {0, 36, 4, 13}
        };

        // Test data in the same column layout; the responses are known,
        // so a prediction error can be computed.
        double[][] xyTest = {
            {0, 71, 3, 5},
            {1, 128, 4, 5},
            {0, 1, 4, 15},
            {0, 61, 6, 10},
            {0, 113, 2, 16},
            {1, 82, 5, 14},
            {0, 148, 3, 16},
            {0, 1, 4, 12},
            {0, 1, 3, 16},
            {0, 175, 5, 13}
        };

        // The response (column 0) is categorical; the predictors are continuous.
        DecisionTree.VariableType[] varType = {
            DecisionTree.VariableType.CATEGORICAL,
            DecisionTree.VariableType.QUANTITATIVE_CONTINUOUS,
            DecisionTree.VariableType.QUANTITATIVE_CONTINUOUS,
            DecisionTree.VariableType.QUANTITATIVE_CONTINUOUS
        };

        String[] names = {"Age", "Number", "Start"};
        String[] classNames = {"Absent", "Present"};
        String responseName = "Kyphosis";

        // Fit a QUEST tree using column 0 of xy as the response variable.
        QUEST dt = new QUEST(xy, 0, varType);
        dt.setMinObsPerChildNode(5);   // smallest allowed child node
        dt.setMinObsPerNode(10);       // smallest node eligible for splitting
        dt.setMaxNodes(50);
        dt.setPrintLevel(2);           // trace the tree-growing steps
        dt.fitModel();

        // Predict the test cases and compare against their known responses.
        double[] predictions = dt.predict(xyTest);
        double predErrSS = dt.getMeanSquaredPredictionError();

        dt.printDecisionTree(responseName, names,
                classNames, null, true);

        System.out.println("\nPredictions for test data:");
        System.out.printf("%5s%8s%7s%10s\n", names[0], names[1], names[2],
                responseName);

        for (int i = 0; i < xyTest.length; i++) {
            System.out.printf("%5.0f%8.0f%7.0f", xyTest[i][1], xyTest[i][2],
                    xyTest[i][3]);
            int idx = (int) predictions[i];
            System.out.printf("%10s\n", classNames[idx]);
        }
        System.out.printf("\nMean squared prediction error: %f\n", predErrSS);
    }
}

Output

Growing the maximal tree using method QUEST:

Node 2 is a terminal node. It has 7.0 cases--too few cases to split.
Node 3 is a terminal node. It has 6.0 cases--too few cases to split.
Node 5 is a terminal node. It has 6.0 cases--too few cases to split.
Node 8 is a terminal node. The split is too thin having count 2.
Node 10 is a terminal node. It has 6.0 cases--too few cases to split.
Node 11 is a terminal node, because it is pure.
Node 11 is a terminal node. It has 7.0 cases--too few cases to split.
Node 13 is a terminal node. It has 5.0 cases--too few cases to split.
Node 14 is a terminal node, because it is pure.

Decision Tree:


Node 0: Cost = 0.210, N= 81, Level = 0, Child nodes:  1  4 
P(Y=0)= 0.790
P(Y=1)= 0.210
Predicted Kyphosis:  Absent 

Node 1: Cost = 0.074, N= 13, Level = 1, Child nodes:  2  3 
Rule:  Start <= 5.155
P(Y=0)= 0.538
P(Y=1)= 0.462
Predicted Kyphosis:  Absent 

Node 2: Cost = 0.025, N= 7, Level = 2
Rule:  Age <= 84.030
P(Y=0)= 0.714
P(Y=1)= 0.286
Predicted Kyphosis:  Absent 

Node 3: Cost = 0.025, N= 6, Level = 2
Rule:  Age > 84.030
P(Y=0)= 0.333
P(Y=1)= 0.667
Predicted Kyphosis:  Present 

Node 4: Cost = 0.136, N= 68, Level = 1, Child nodes:  5  6 
Rule:  Start > 5.155
P(Y=0)= 0.838
P(Y=1)= 0.162
Predicted Kyphosis:  Absent 

Node 5: Cost = 0.012, N= 6, Level = 2
Rule:  Start <= 8.862
P(Y=0)= 0.167
P(Y=1)= 0.833
Predicted Kyphosis:  Present 

Node 6: Cost = 0.074, N= 62, Level = 2, Child nodes:  7  12 
Rule:  Start > 8.862
P(Y=0)= 0.903
P(Y=1)= 0.097
Predicted Kyphosis:  Absent 

Node 7: Cost = 0.062, N= 28, Level = 3, Child nodes:  8  9 
Rule:  Start <= 13.092
P(Y=0)= 0.821
P(Y=1)= 0.179
Predicted Kyphosis:  Absent 

Node 8: Cost = 0.025, N= 15, Level = 4
Rule:  Age <= 91.722
P(Y=0)= 0.867
P(Y=1)= 0.133
Predicted Kyphosis:  Absent 

Node 9: Cost = 0.037, N= 13, Level = 4, Child nodes:  10  11 
Rule:  Age > 91.722
P(Y=0)= 0.769
P(Y=1)= 0.231
Predicted Kyphosis:  Absent 

Node 10: Cost = 0.037, N= 6, Level = 5
Rule:  Number <= 3.450
P(Y=0)= 0.500
P(Y=1)= 0.500
Predicted Kyphosis:  Absent 

Node 11: Cost = 0.000, N= 7, Level = 5
Rule:  Number > 3.450
P(Y=0)= 1.000
P(Y=1)= 0.000
Predicted Kyphosis:  Absent 

Node 12: Cost = 0.012, N= 34, Level = 3, Child nodes:  13  14 
Rule:  Start > 13.092
P(Y=0)= 0.971
P(Y=1)= 0.029
Predicted Kyphosis:  Absent 

Node 13: Cost = 0.012, N= 5, Level = 4
Rule:  Start <= 14.864
P(Y=0)= 0.800
P(Y=1)= 0.200
Predicted Kyphosis:  Absent 

Node 14: Cost = 0.000, N= 29, Level = 4
Rule:  Start > 14.864
P(Y=0)= 1.000
P(Y=1)= 0.000
Predicted Kyphosis:  Absent 

Predictions for test data:
  Age  Number  Start  Kyphosis
   71       3      5    Absent
  128       4      5   Present
    1       4     15    Absent
   61       6     10    Absent
  113       2     16    Absent
   82       5     14    Absent
  148       3     16    Absent
    1       4     12    Absent
    1       3     16    Absent
  175       5     13    Absent

Mean squared prediction error: 0.100000
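With the 0/1 coding of the response, the mean squared prediction error equals the misclassification rate on the test set: exactly one of the ten test cases (actual Present, predicted Absent) disagrees, giving 1/10 = 0.1. A minimal sketch of that computation, independent of the JMSL library:

```java
public class MSESketch {

    // Mean squared error between actual and predicted class labels.
    static double meanSquaredError(double[] actual, double[] predicted) {
        double sum = 0.0;
        for (int i = 0; i < actual.length; i++) {
            double diff = actual[i] - predicted[i];
            sum += diff * diff;
        }
        return sum / actual.length;
    }

    public static void main(String[] args) {
        // Actual responses (column 0 of xyTest) and the predictions printed above.
        double[] actual    = {0, 1, 0, 0, 0, 1, 0, 0, 0, 0};
        double[] predicted = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0};
        System.out.println(meanSquaredError(actual, predicted)); // prints 0.1
    }
}
```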