Applying neural networks for classification

Classification tasks can be performed by using any of the supervised neural networks that this book has covered so far. However, it is recommended to use more complex architectures, such as MLPs. In this chapter, we are going to use the NeuralNet class to build an MLP with one hidden layer and the sigmoid function at the output. Every output neuron denotes a class.

We've added a special class called Classification to the framework to handle concepts such as the confusion matrix, sensitivity, and specificity. The following table lists the methods and parameters of this class:

Class name: Classification

Methods:

public double[][] calculateConfusionMatrix( double marginError, double[][] matrix )
  Description: Calculates the confusion matrix
  Parameters: Margin error and a matrix with the real and estimated outputs
  Returns: Confusion matrix

public void printConfusionMatrix( double[][] matrix )
  Description: Prints the confusion matrix
  Parameters: Confusion matrix
  Returns: -

public double calculateSensitivity( double[][] matrix )
  Description: Calculates the sensitivity of the classification
  Parameters: Confusion matrix
  Returns: Sensitivity value

public double calculateSpecificity( double[][] matrix )
  Description: Calculates the specificity of the classification
  Parameters: Confusion matrix
  Returns: Specificity value

public double calculateAccuracy( double[][] matrix )
  Description: Calculates the accuracy of the classification
  Parameters: Confusion matrix
  Returns: Accuracy value

public double[][] convertToOneColumn( double[][] matrix )
  Description: Converts a matrix with more than one column into a matrix with a single column; used when the neural network has more than one neuron in the output layer
  Parameters: Matrix with more than one column
  Returns: Matrix with one column

Class implementation with Java: file Classification.java
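The table does not give the formulas behind the three indexes. In terms of true/false positives and negatives (TP, TN, FP, FN), the standard definitions are sensitivity = TP / (TP + FN), specificity = TN / (TN + FP), and accuracy = (TP + TN) / (TP + TN + FP + FN). The following sketch computes them in Java. The class name ConfusionMetrics and the matrix layout {{TP, FN}, {FP, TN}} (rows are real classes, columns are predicted classes) are assumptions for illustration, and Classification.java may arrange its confusion matrix differently.

```java
// Hypothetical helper, not part of the book's framework: computes the
// standard classification indexes from a 2x2 confusion matrix.
// Assumed layout: rows = real class, columns = predicted class
//   {{TP, FN},
//    {FP, TN}}
class ConfusionMetrics {

    // Sensitivity (true positive rate): TP / (TP + FN)
    static double sensitivity(double[][] cm) {
        double tp = cm[0][0], fn = cm[0][1];
        return tp / (tp + fn);
    }

    // Specificity (true negative rate): TN / (TN + FP)
    static double specificity(double[][] cm) {
        double fp = cm[1][0], tn = cm[1][1];
        return tn / (tn + fp);
    }

    // Accuracy: fraction of all samples that were classified correctly
    static double accuracy(double[][] cm) {
        double tp = cm[0][0], fn = cm[0][1], fp = cm[1][0], tn = cm[1][1];
        return (tp + tn) / (tp + tn + fp + fn);
    }

    public static void main(String[] args) {
        // 40 real positives (32 detected) and 60 real negatives (54 correctly rejected)
        double[][] cm = {{32, 8}, {6, 54}};
        System.out.println("SENSITIVITY = " + sensitivity(cm)); // 32 / 40
        System.out.println("SPECIFICITY = " + specificity(cm)); // 54 / 60
        System.out.println("ACCURACY    = " + accuracy(cm));    // 86 / 100
    }
}
```

Sensitivity measures how many real positives the classifier catches, and specificity measures how many real negatives it correctly rejects. Accuracy alone can hide a poor result on the smaller class, which is why all three are reported.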

Implementing a neural network for classification involves the following steps:

  1. Loading the data (training and test sets)
  2. Normalizing the data
  3. Creating the neural network
  4. Training the neural network
  5. Analyzing the classifier and drawing conclusions with a Classification object

First, let's load the training and test data and choose the normalization type:

    // Training data
    Data dataInput  = new Data("data", "inputs_training.csv");
    Data dataOutput = new Data("data", "output_training.csv");
    // Test data
    Data dataInputTestRNA  = new Data("data", "inputs_test.csv");
    Data dataOutputTestRNA = new Data("data", "output_test.csv");

    // Normalization type
    Data.NormalizationTypesENUM NORMALIZATION_TYPE = Data.NormalizationTypesENUM.MAX_MIN_EQUALIZED;
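The Data constructor reads a CSV file from a folder, and rawData2Matrix() later turns that raw data into a double[][]. Neither implementation is shown in this section. As a rough stand-in, the hypothetical CsvSketch class below parses comma-separated lines into a matrix. The delimiter and the absence of a header row are assumptions, and the framework's Data class may handle both differently.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

// Hypothetical CSV loader; not the framework's Data class.
class CsvSketch {

    // Parses lines such as "6,148,72" into one row of doubles each.
    // 'delimiter' is a regular expression, e.g. "," or ";". Blank lines are skipped.
    static double[][] parse(List<String> lines, String delimiter) {
        return lines.stream()
                .map(String::trim)
                .filter(line -> !line.isEmpty())
                .map(line -> {
                    String[] cells = line.split(delimiter);
                    double[] row = new double[cells.length];
                    for (int j = 0; j < cells.length; j++) {
                        row[j] = Double.parseDouble(cells[j].trim());
                    }
                    return row;
                })
                .toArray(double[][]::new);
    }

    // Reads folder/fileName (for example "data", "inputs_training.csv") into a matrix.
    static double[][] load(String folder, String fileName, String delimiter) throws IOException {
        Path path = Paths.get(folder, fileName);
        return parse(Files.readAllLines(path), delimiter);
    }

    public static void main(String[] args) {
        double[][] m = parse(List.of("1,2,3", "4,5,6"), ",");
        System.out.println(m.length + " rows x " + m[0].length + " columns");
    }
}
```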

The raw data must then be converted to matrix format so that it can be fed into the neural network:

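      // Convert the raw data to matrices
      double[][] matrixInput  = dataInput.rawData2Matrix( dataInput );
      double[][] matrixOutput = dataOutput.rawData2Matrix( dataOutput );
      double[][] matrixInputTestRNA  = dataInputTestRNA.rawData2Matrix( dataInputTestRNA );
      double[][] matrixOutputTestRNA = dataOutputTestRNA.rawData2Matrix( dataOutputTestRNA );

      // Normalize the inputs of both the training and the test sets
      double[][] matrixInputNorm        = dataInput.normalize(matrixInput, NORMALIZATION_TYPE);
      double[][] matrixInputTestRNANorm = dataInputTestRNA.normalize(matrixInputTestRNA, NORMALIZATION_TYPE);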
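This section does not show what normalize() does internally, and it does not give the exact formula behind MAX_MIN_EQUALIZED. The sketch below therefore assumes plain column-wise min-max scaling, x' = (x - min) / (max - min), which maps every input column to [0, 1]. The class MinMaxSketch is hypothetical.

```java
import java.util.Arrays;

// Hypothetical sketch of column-wise min-max scaling; not the framework's
// Data.normalize() implementation.
class MinMaxSketch {

    // Scales each column of 'data' to [0, 1] using the given per-column min and max.
    static double[][] scale(double[][] data, double[] min, double[] max) {
        double[][] out = new double[data.length][];
        for (int i = 0; i < data.length; i++) {
            out[i] = new double[data[i].length];
            for (int j = 0; j < data[i].length; j++) {
                double range = max[j] - min[j];
                // A constant column carries no information; map it to 0
                out[i][j] = range == 0 ? 0.0 : (data[i][j] - min[j]) / range;
            }
        }
        return out;
    }

    // Per-column minimum (isMin = true) or maximum (isMin = false); 'data' must be non-empty.
    static double[] columnExtreme(double[][] data, boolean isMin) {
        double[] result = data[0].clone();
        for (double[] row : data) {
            for (int j = 0; j < row.length; j++) {
                result[j] = isMin ? Math.min(result[j], row[j]) : Math.max(result[j], row[j]);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] train = {{1, 100}, {3, 300}, {5, 200}};
        double[] min = columnExtreme(train, true);
        double[] max = columnExtreme(train, false);
        System.out.println(Arrays.deepToString(scale(train, min, max)));
        // prints [[0.0, 0.0], [0.5, 1.0], [1.0, 0.5]]
    }
}
```

Passing min and max in explicitly lets the test set be scaled with the training set's statistics, which keeps both sets on the same scale. The code above does not show whether the framework's normalize() works this way.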

Now, let's create the neural network with 8 inputs, one hidden layer of 3 neurons, and 2 outputs:

      NeuralNet n1 = new NeuralNet();
      n1 = n1.initNet(8, 1, 3, 2);
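To make the 8-3-2 topology concrete, here is a minimal forward pass through one hidden layer of 3 neurons and 2 sigmoid output neurons. The sigmoid activation in the hidden layer and the fixed example weights are assumptions. The class ForwardPassSketch is a hypothetical illustration, not the NeuralNet class.

```java
import java.util.Arrays;

// Hypothetical forward pass of an 8-3-2 MLP; not the framework's NeuralNet.
class ForwardPassSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // One fully connected layer: out[k] = sigmoid(bias[k] + sum_j w[k][j] * in[j])
    static double[] layer(double[] in, double[][] w, double[] bias) {
        double[] out = new double[w.length];
        for (int k = 0; k < w.length; k++) {
            double z = bias[k];
            for (int j = 0; j < in.length; j++) {
                z += w[k][j] * in[j];
            }
            out[k] = sigmoid(z);
        }
        return out;
    }

    public static void main(String[] args) {
        double[] x = new double[8];              // 8 normalized inputs
        Arrays.fill(x, 0.5);
        double[][] wHidden = new double[3][8];   // 3 hidden neurons x 8 inputs
        double[][] wOut = new double[2][3];      // 2 output neurons x 3 hidden neurons
        for (double[] row : wHidden) Arrays.fill(row, 0.1);
        for (double[] row : wOut) Arrays.fill(row, -0.2);
        double[] hidden = layer(x, wHidden, new double[3]);
        double[] y = layer(hidden, wOut, new double[2]);
        // Each of the 2 outputs lies in (0, 1) and scores one class
        System.out.println(Arrays.toString(y));
    }
}
```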

Next, we train the network. Chapter 3, Handling Perceptrons, already showed how to set up training, so we omit those details here to save space. The trained network returned by trainNet() is stored in n1Trained:

      // Train the network; trainNet() returns the trained network
      NeuralNet n1Trained = n1.trainNet(n1);

      // Plot the MSE recorded at each training epoch
      Chart c1 = new Chart();
      c1.plotXYData(n1.getListOfMSE().toArray(), "MSE Error", "Epochs", "MSE Value");
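The chart plots the mean squared error (MSE) recorded during training. For reference, the MSE is the average squared difference between the real and estimated outputs. The helper below is a hypothetical sketch that averages over every sample and every output neuron. The framework may average or scale the value differently, for example by a factor of 1/2.

```java
// Hypothetical MSE helper; not the framework's training code.
class MseSketch {

    // Average of (real - estimated)^2 over all samples and all output neurons.
    static double mse(double[][] real, double[][] estimated) {
        double sum = 0.0;
        int count = 0;
        for (int i = 0; i < real.length; i++) {
            for (int j = 0; j < real[i].length; j++) {
                double diff = real[i][j] - estimated[i][j];
                sum += diff * diff;
                count++;
            }
        }
        return sum / count;
    }

    public static void main(String[] args) {
        double[][] real = {{1, 0}, {0, 1}};
        double[][] estimated = {{0.9, 0.1}, {0.2, 0.6}};
        System.out.println("MSE = " + mse(real, estimated));
    }
}
```

A curve that keeps falling and then flattens usually means training has converged. A curve that stays high suggests the network needs more epochs or a different configuration.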

Once training has finished, we instantiate a Classification object to analyze the results on the test data:

      Classification classif = new Classification();

      // Feed the normalized test inputs and the real test outputs to the trained network
      n1Trained.setTrainSet( matrixInputTestRNANorm );
      n1Trained.setRealMatrixOutputSet( matrixOutputTestRNA );

      // Estimated outputs of the trained network for the test set
      double[][] matrixOutputRNATest = n1Trained.getNetOutputValues(n1Trained);

      // With more than one output neuron, collapse the real and the estimated
      // outputs into a single column each so that they can be compared
      if (n1Trained.getOutputLayer().getNumberOfNeuronsInLayer() > 1) {
        matrixOutputTestRNA = classif.convertToOneColumn(matrixOutputTestRNA);
        matrixOutputRNATest = classif.convertToOneColumn(matrixOutputRNATest);
      }
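This section does not show how convertToOneColumn() collapses several output columns into one. A common approach, assumed here, is to take the index of the largest value in each row as the class label. With two output neurons, a row such as {0.2, 0.7} then becomes 1. The class OneColumnSketch is hypothetical, and the real method's mapping from columns to classes may differ.

```java
import java.util.Arrays;

// Hypothetical argmax-based conversion; not necessarily how
// Classification.convertToOneColumn() works.
class OneColumnSketch {

    // Replaces each row with a single column holding the index of its largest value.
    // Ties resolve to the lowest index.
    static double[][] toOneColumn(double[][] matrix) {
        double[][] out = new double[matrix.length][1];
        for (int i = 0; i < matrix.length; i++) {
            int best = 0;
            for (int j = 1; j < matrix[i].length; j++) {
                if (matrix[i][j] > matrix[i][best]) {
                    best = j;
                }
            }
            out[i][0] = best;
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] netOutputs = {{0.8, 0.1}, {0.2, 0.7}, {0.4, 0.6}};
        System.out.println(Arrays.deepToString(toOneColumn(netOutputs)));
        // prints [[0.0], [1.0], [1.0]]
    }
}
```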

Finally, we join the real and estimated test outputs, plot them as a bar chart, and print the confusion matrix together with the sensitivity, specificity, and accuracy:

      // Join the real and the estimated test outputs into one matrix
      ArrayList<double[][]> listOfArraysToJoinTest = new ArrayList<>();
      listOfArraysToJoinTest.add( matrixOutputTestRNA );
      listOfArraysToJoinTest.add( matrixOutputRNATest );
      double[][] matrixOutputsJoinedTest = new Data().joinArrays(listOfArraysToJoinTest);

      // Plot a bar chart of real vs. estimated outputs
      Chart c3 = new Chart();
      c3.plotBarChart(matrixOutputsJoinedTest, "Real x Estimated - Test Data", "Data", "Result (0: NO / 1: YES)");

      // Print the confusion matrix and the sensitivity and specificity indexes
      double[][] confusionMatrix = classif.calculateConfusionMatrix(0.6, matrixOutputsJoinedTest);
      classif.printConfusionMatrix(confusionMatrix);
      System.out.println("SENSITIVITY = " + classif.calculateSensitivity(confusionMatrix));
      System.out.println("SPECIFICITY = " + classif.calculateSpecificity(confusionMatrix));

      // Finally, print the overall accuracy of the classification
      System.out.println("ACCURACY    = " + classif.calculateAccuracy(confusionMatrix));
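calculateConfusionMatrix() receives the joined matrix of real and estimated outputs together with a margin of 0.6. This section does not define marginError precisely. The sketch below assumes it acts as a decision threshold, so an estimate at or above it counts as class 1. It also assumes that the joined matrix holds the real class in its first column and the estimate in its second. The class name, the threshold interpretation, and the {{TP, FN}, {FP, TN}} layout are all hypothetical.

```java
import java.util.Arrays;

// Hypothetical sketch: joins real and estimated one-column matrices and counts
// a 2x2 confusion matrix. Not the framework's Data.joinArrays() or
// Classification.calculateConfusionMatrix().
class ConfusionSketch {

    // Places two one-column matrices side by side: {real, estimated} per row.
    static double[][] join(double[][] real, double[][] estimated) {
        double[][] joined = new double[real.length][2];
        for (int i = 0; i < real.length; i++) {
            joined[i][0] = real[i][0];
            joined[i][1] = estimated[i][0];
        }
        return joined;
    }

    // Assumed interpretation of the margin: estimated >= threshold means class 1.
    // Layout: {{TP, FN}, {FP, TN}}; a real value of 1 is the positive class.
    static double[][] confusion(double threshold, double[][] joined) {
        double[][] cm = new double[2][2];
        for (double[] row : joined) {
            boolean realPositive = row[0] >= 0.5;      // real labels are 0 or 1
            boolean predictedPositive = row[1] >= threshold;
            int r = realPositive ? 0 : 1;
            int c = predictedPositive ? 0 : 1;
            cm[r][c]++;
        }
        return cm;
    }

    public static void main(String[] args) {
        double[][] real = {{1}, {1}, {0}, {0}, {1}};
        double[][] estimated = {{0.9}, {0.4}, {0.7}, {0.1}, {0.65}};
        double[][] cm = confusion(0.6, join(real, estimated));
        System.out.println(Arrays.deepToString(cm));
        // prints [[2.0, 1.0], [1.0, 1.0]]
    }
}
```

Raising the threshold makes the classifier more conservative. It tends to trade sensitivity for specificity, so the 0.6 margin is a tuning choice rather than a fixed rule.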