scopriamo ora come realizzare una rete neurale per la classificazione di piante in base a particolari caratteristiche. Utilizziamo il set di dati IRIS. Questo set definisce quattro attributi di input e uno di output:
- Sepal length in cm;
- Sepal width in cm;
- Petal length in cm;
- Petal width in cm.
Classificazione:
- Iris Setosa;
- oppure Iris Versicolour;
- oppure Iris Virginica.
Introduzione della classe Java
Definiamo il seguente scheletro di classe all'interno del quale andremo a definire l'implementazione della rete:
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.apache.commons.io.FilenameUtils;
import org.nd4j.common.resources.Downloader;
import java.io.File;
import java.net.URL;
public class Neural {
public static void main(String[] args) throws Exception {
}
Il primo step è recuperare il dataset utilizzando un reader.
int numLinesToSkip = 0;
char delimiter = ',';
RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
System.out.println(downloadIris());
recordReader.initialize(new FileSplit(new File(downloadIris(),"iris.txt")));
int labelIndex = 4;
int numClasses = 3;
int batchSize = 150;
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
DataSet allData = iterator.next();
allData.shuffle();
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
DataSet trainingData = testAndTrain.getTrain();
DataSet testData = testAndTrain.getTest();
Il passo successivo è normalizzare i dati e costruire la rete:
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(trainingData);
normalizer.transform(trainingData);
normalizer.transform(testData);
final int numInputs = 4;
final int outputNum = 3;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(5)
.activation(Activation.TANH)
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(0.1))
.l2(1e-4)
.list()
.layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)
.build())
.layer(new DenseLayer.Builder().nIn(3).nOut(3)
.build())
.layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(3).nOut(outputNum).build())
.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
La rete neurale è costituita da un input layer, due hidden layer e un output layer. In particolare il layer di output utilizza la funzione softmax
come funzione di attivazione per fornire in output le probabilità di appartenza ad una delle possibili categorie.
Training e testing
La funzione di errore da minimizzare è la negative log likelihood
, particolarmente adatta per problemi di classificazione. Concludiamo quinti con il training e testing:
model.setListeners(new ScoreIterationListener(100));
for(int i=0; i<1000; i++ ) {
model.fit(trainingData);
}
Evaluation eval = new Evaluation(3);
INDArray output = model.output(testData.getFeatures());
eval.eval(testData.getLabels(), output);
System.out.println(eval.stats());
Il codice realizzato fa uso di un metodo ausiliario per il recupero del dataset IRIS:
public static String downloadIris() throws Exception {
String dataURL = "https://dl4jdata.blob.core.windows.net/dl4j-examples/datavec-examples/IrisData.zip";
String downloadPath = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "IrisData.zip");
String extractDir = FilenameUtils.concat(System.getProperty("user.home"), "dl4j-examples-data/" + "datavec-examples/IrisData");
if (!new File(extractDir).exists())
new File(extractDir).mkdirs();
String dataPathLocal = extractDir;
int downloadRetries = 10;
if (!new File(dataPathLocal).exists() || new File(dataPathLocal).list().length == 0) {
System.out.println("_______________________________________________________________________");
System.out.println("Downloading data (1KB) and extracting to \n\t" + dataPathLocal);
System.out.println("_______________________________________________________________________");
Downloader.downloadAndExtract("files",
new URL(dataURL),
new File(downloadPath),
new File(extractDir),
"bb49e38bb91089634d7ef37ad8e430b8",
downloadRetries);
} else {
System.out.println("_______________________________________________________________________");
System.out.println("Example data present in \n\t" + dataPathLocal);
System.out.println("_______________________________________________________________________");
}
return dataPathLocal;
}