Classification tasks can be performed with any of the supervised neural networks that this book has covered so far. However, more complex architectures, such as MLPs, are recommended. In this chapter, we are going to use the NeuralNet class to build an MLP with one hidden layer and the sigmoid function at the output. Every output neuron denotes a class.
We've added a special class called Classification to the framework in order to handle concepts such as the confusion matrix, sensitivity, and specificity. The following table lists the methods and parameters contained in this class:
Class name: Classification

| Method | Description | Parameters | Returns |
| --- | --- | --- | --- |
| public double[][] calculateConfusionMatrix(double marginError, double[][] matrix) | Calculates the confusion matrix | Margin error; matrix with real output and estimated output | Confusion matrix |
| public void printConfusionMatrix(double[][] matrix) | Prints the confusion matrix | Confusion matrix | - |
| public double calculateSensitivity(double[][] matrix) | Calculates the sensitivity of the classification | Confusion matrix | Sensitivity value |
| public double calculateSpecificity(double[][] matrix) | Calculates the specificity of the classification | Confusion matrix | Specificity value |
| public double calculateAccuracy(double[][] matrix) | Calculates the accuracy of the classification | Confusion matrix | Accuracy value |
| public double[][] convertToOneColumn(double[][] matrix) | Converts a matrix with more than one column into a single column. Used when the neural network has more than one neuron in the output layer | Matrix with more than one column | Matrix with one column |

Class implementation with Java: file Classification.java
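The metrics named in the table follow the standard definitions for a binary classifier. The following is a minimal sketch of those definitions, not the code in Classification.java: the class name ConfusionMetricsSketch, the matrix layout (rows are the real class, columns the predicted class, index 1 is the positive class), and the simple threshold used in place of the margin error are all assumptions made for illustration.

```java
final class ConfusionMetricsSketch {

    // Builds a 2x2 confusion matrix from rows of {real, estimated} values.
    // Assumed layout: cm[real][predicted], with class 1 as the positive class.
    static double[][] confusionMatrix(double[][] realAndEstimated, double threshold) {
        double[][] cm = new double[2][2];
        for (double[] row : realAndEstimated) {
            int real = row[0] >= threshold ? 1 : 0;
            int predicted = row[1] >= threshold ? 1 : 0;
            cm[real][predicted] += 1.0;
        }
        return cm;
    }

    // Sensitivity (true positive rate) = TP / (TP + FN).
    static double sensitivity(double[][] cm) {
        return cm[1][1] / (cm[1][1] + cm[1][0]);
    }

    // Specificity (true negative rate) = TN / (TN + FP).
    static double specificity(double[][] cm) {
        return cm[0][0] / (cm[0][0] + cm[0][1]);
    }

    // Accuracy = (TP + TN) / total number of records.
    static double accuracy(double[][] cm) {
        double total = cm[0][0] + cm[0][1] + cm[1][0] + cm[1][1];
        return (cm[1][1] + cm[0][0]) / total;
    }

    public static void main(String[] args) {
        // Each row is {real output, estimated output}; the values are made up.
        double[][] data = {
            {1, 0.9}, {1, 0.8}, {1, 0.3}, {0, 0.2},
            {0, 0.1}, {0, 0.7}, {0, 0.4}, {1, 0.6}
        };
        double[][] cm = confusionMatrix(data, 0.5);
        System.out.println("SENSITIVITY = " + sensitivity(cm));
        System.out.println("SPECIFICITY = " + specificity(cm));
        System.out.println("ACCURACY = " + accuracy(cm));
    }
}
```

Note that sensitivity and specificity are undefined (NaN in this sketch) when the data contains no positive or no negative records, respectively.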
Implementing a neural network for classification involves the following steps.
First, let's load the data and normalize it:
// Training data
Data dataInput = new Data("data", "inputs_training.csv");
Data dataOutput = new Data("data", "output_training.csv");
// Test data
Data dataInputTestRNA = new Data("data", "inputs_test.csv");
Data dataOutputTestRNA = new Data("data", "output_test.csv");
// Normalization type
Data.NormalizationTypesENUM NORMALIZATION_TYPE = Data.NormalizationTypesENUM.MAX_MIN_EQUALIZED;
It is important to convert the data to the matrix format so that it can be fed into the neural network:
// Convert the raw data to matrices
double[][] matrixInput = dataInput.rawData2Matrix(dataInput);
double[][] matrixOutput = dataOutput.rawData2Matrix(dataOutput);
// Normalize the data. The normalization code for the test data is omitted.
double[][] matrixInputNorm = dataInput.normalize(matrixInput, NORMALIZATION_TYPE);
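The normalize method of the Data class is not shown in this section. As a rough illustration of what a max-min normalization typically does, the sketch below scales each column to the range [0, 1]. The class name MinMaxSketch is hypothetical, and the target range is an assumption: MAX_MIN_EQUALIZED in the framework may map values to a different interval.

```java
import java.util.Arrays;

final class MinMaxSketch {

    // Scales every column of the matrix to [0, 1] using that column's min and max.
    // A constant column (max == min) is mapped to 0.0 to avoid dividing by zero.
    static double[][] normalize(double[][] matrix) {
        int rows = matrix.length;
        int cols = matrix[0].length;
        double[][] result = new double[rows][cols];
        for (int j = 0; j < cols; j++) {
            double min = Double.POSITIVE_INFINITY;
            double max = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < rows; i++) {
                min = Math.min(min, matrix[i][j]);
                max = Math.max(max, matrix[i][j]);
            }
            double range = max - min;
            for (int i = 0; i < rows; i++) {
                result[i][j] = range == 0.0 ? 0.0 : (matrix[i][j] - min) / range;
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // Made-up values: three records with two input columns each.
        double[][] raw = { {2.0, 10.0}, {4.0, 30.0}, {6.0, 20.0} };
        for (double[] row : normalize(raw)) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```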
Now, let's create the neural network with 8 inputs, 3 hidden neurons, and 2 outputs:
NeuralNet n1 = new NeuralNet();
// Arguments as used here: 8 inputs, 1 hidden layer, 3 hidden neurons, 2 outputs
n1 = n1.initNet(8, 1, 3, 2);
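To make the shape of this network concrete, here is a minimal sketch of a forward pass through an 8-3-2 MLP. It does not use the NeuralNet class: ForwardPassSketch and its methods are hypothetical, the weights are placeholders, and using the sigmoid in the hidden layer is an assumption (the text only states that the sigmoid is used at the output).

```java
import java.util.Arrays;

final class ForwardPassSketch {

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Computes one layer: out[j] = sigmoid(bias[j] + sum over i of in[i] * weights[i][j]).
    static double[] layer(double[] in, double[][] weights, double[] bias) {
        double[] out = new double[bias.length];
        for (int j = 0; j < out.length; j++) {
            double sum = bias[j];
            for (int i = 0; i < in.length; i++) {
                sum += in[i] * weights[i][j];
            }
            out[j] = sigmoid(sum);
        }
        return out;
    }

    // Forward pass through one hidden layer followed by the output layer.
    static double[] forward(double[] input, double[][] wHidden, double[] bHidden,
                            double[][] wOut, double[] bOut) {
        return layer(layer(input, wHidden, bHidden), wOut, bOut);
    }

    public static void main(String[] args) {
        double[] input = new double[8];         // 8 inputs (all zeros here)
        double[][] wHidden = new double[8][3];  // 8 inputs -> 3 hidden neurons
        double[] bHidden = new double[3];
        double[][] wOut = new double[3][2];     // 3 hidden neurons -> 2 outputs
        double[] bOut = {2.0, -2.0};            // placeholder output biases
        double[] out = forward(input, wHidden, bHidden, wOut, bOut);
        // One value per output neuron, each in (0, 1) because of the sigmoid.
        System.out.println(Arrays.toString(out));
    }
}
```

Because each output neuron denotes a class, the two values can be read as per-class scores; a trained network would have learned weights in place of the zeros used here.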
Next, we perform the training. Since we've already seen how this can be set up in Chapter 3, Handling Perceptrons, we're leaving this out here to save space. Then, we create a new network to receive the trained network:
// Create a new network to receive the trained network
NeuralNet n1Trained = new NeuralNet();
n1Trained = n1.trainNet(n1);
// Plot the error
Chart c1 = new Chart();
c1.plotXYData(n1.getListOfMSE().toArray(), "MSE Error", "Epochs", "MSE Value");
Once training has finished, we instantiate a Classification object to analyze the results:
Classification classif = new Classification();
// Load the test data
n1Trained.setTrainSet(matrixInputTestRNANorm);
n1Trained.setRealMatrixOutputSet(matrixOutputTestRNA);
double[][] matrixOutputRNATest = n1Trained.getNetOutputValues(n1Trained);
// Check the number of outputs to adapt the test data to the network's multiple outputs
if (n1Trained.getOutputLayer().getNumberOfNeuronsInLayer() > 1) {
    matrixOutputTestRNA = classif.convertToOneColumn(matrixOutputTestRNA);
    matrixOutputRNATest = classif.convertToOneColumn(matrixOutputRNATest);
}
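The body of convertToOneColumn is not reproduced here. One common way to collapse a multi-neuron output into a single column is to keep, for each record, the index of the neuron with the largest value. The sketch below shows that approach under this assumption; OneColumnSketch is a hypothetical name, and the framework's method may encode the classes differently.

```java
final class OneColumnSketch {

    // Replaces each row with a single value: the index of its largest element.
    // Ties are resolved in favor of the lowest index.
    static double[][] toOneColumn(double[][] matrix) {
        double[][] result = new double[matrix.length][1];
        for (int i = 0; i < matrix.length; i++) {
            int best = 0;
            for (int j = 1; j < matrix[i].length; j++) {
                if (matrix[i][j] > matrix[i][best]) {
                    best = j;
                }
            }
            result[i][0] = best;
        }
        return result;
    }

    public static void main(String[] args) {
        // Made-up outputs of a network with two output neurons.
        double[][] outputs = { {0.9, 0.1}, {0.2, 0.7}, {1.0, 0.0} };
        for (double[] row : toOneColumn(outputs)) {
            System.out.println(row[0]);
        }
    }
}
```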
Finally, we process the results to display the charts and the confusion matrix:
// Requires: import java.util.ArrayList;
ArrayList<double[][]> listOfArraysToJoinTest = new ArrayList<double[][]>();
listOfArraysToJoinTest.add(matrixOutputTestRNA);
listOfArraysToJoinTest.add(matrixOutputRNATest);
double[][] matrixOutputsJoinedTest = new Data().joinArrays(listOfArraysToJoinTest);
// Plot a bar chart
Chart c3 = new Chart();
c3.plotBarChart(matrixOutputsJoinedTest, "Real x Estimated - Test Data", " Data", "Result (0: NO / 1: YES)");
// Print the confusion matrix and the sensitivity and specificity indexes
double[][] confusionMatrix = classif.calculateConfusionMatrix(0.6, matrixOutputsJoinedTest);
classif.printConfusionMatrix(confusionMatrix);
System.out.println("SENSITIVITY = " + classif.calculateSensitivity(confusionMatrix));
System.out.println("SPECIFICITY = " + classif.calculateSpecificity(confusionMatrix));
// Finally, print the accuracy of the classification
System.out.println("ACCURACY = " + classif.calculateAccuracy(confusionMatrix));