Nessun risultato. Prova con un altro termine.
Guide
Notizie
Software
Tutorial

Una rete neurale con Deeplearning4j

Realizzare una rete neurale artificiale con la deep-learning library Deeplearning4j per classificare oggetti in base a particolari caratteristiche
Realizzare una rete neurale artificiale con la deep-learning library Deeplearning4j per classificare oggetti in base a particolari caratteristiche
Link copiato negli appunti

scopriamo ora come realizzare una rete neurale per la classificazione di piante in base a particolari caratteristiche. Utilizziamo il set di dati IRIS. Questo set definisce quattro attributi di input e uno di output:

  • Sepal length
  • Sepal width
  • Petal length
  • Petal width

Classificazione:

  • Iris Setosa;
  • oppure Iris Versicolour;
  • oppure Iris Virginica.
  • Introduzione della classe Java

    Definiamo il seguente scheletro di classe all'interno del quale andremo a definire l'implementazione della rete:

    import org.datavec.api.records.reader.RecordReader;
    import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
    import org.datavec.api.split.FileSplit;
    import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.nn.weights.WeightInit;
    import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
    import org.nd4j.evaluation.classification.Evaluation;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.dataset.SplitTestAndTrain;
    import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
    import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
    import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
    import org.nd4j.linalg.learning.config.Sgd;
    import org.nd4j.linalg.lossfunctions.LossFunctions;
    import org.apache.commons.io.FilenameUtils;
    import org.nd4j.common.resources.Downloader;
    import java.io.File;
    import java.net.URL;
    
    public class Neural {
        public static void main(String[] args) throws  Exception {
    	}

    Il primo step è recuperare il dataset utilizzando un reader.

    int numLinesToSkip = 0;
       char delimiter = ',';
       RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
       System.out.println(downloadIris());
       recordReader.initialize(new FileSplit(new File(downloadIris(),"iris.txt")));
    
       int labelIndex = 4;
       int numClasses = 3;
       int batchSize = 150;
       DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
       DataSet allData = iterator.next();
       allData.shuffle();
       SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
    
       DataSet trainingData = testAndTrain.getTrain();
       DataSet testData = testAndTrain.getTest();

    Il passo successivo è normalizzare i dati e costruire la rete:

    DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(trainingData);
        normalizer.transform(trainingData);
        normalizer.transform(testData);
        final int numInputs = 4;
        final int outputNum = 3;
    
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .seed(5)
                    .activation(Activation.TANH)
                    .weightInit(WeightInit.XAVIER)
                    .updater(new Sgd(0.1))
                    .l2(1e-4)
                    .list()
                    .layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)
                            .build())
                    .layer(new DenseLayer.Builder().nIn(3).nOut(3)
                            .build())
                    .layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                            .activation(Activation.SOFTMAX)
                            .nIn(3).nOut(outputNum).build())
                    .build();
         MultiLayerNetwork model = new MultiLayerNetwork(conf);
         model.init();

    La rete neurale è costituita da un input layer, due hidden layer e un output layer. In particolare il layer di output utilizza la funzione softmax

    Training e testing

    La funzione di errore da minimizzare è la negative log likelihood, particolarmente adatta per problemi di classificazione. Concludiamo quinti con il training e testing:

    model.setListeners(new ScoreIterationListener(100));
    
        for(int i=0; i<1000; i++ ) {
            model.fit(trainingData);
        }
        Evaluation eval = new Evaluation(3);
        INDArray output = model.output(testData.getFeatures());
        eval.eval(testData.getLabels(), output);
        System.out.println(eval.stats());

    Il codice realizzato fa uso di un metodo ausiliario per il recupero del dataset IRIS:

    public static String downloadIris() throws Exception {
            String dataURL = "https://dl4jdata.blob.core.windows.net/dl4j-examples/datavec-examples/IrisData.zip";
            String downloadPath = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "IrisData.zip");
            String extractDir = FilenameUtils.concat(System.getProperty("user.home"), "dl4j-examples-data/" + "datavec-examples/IrisData");
            if (!new File(extractDir).exists())
                new File(extractDir).mkdirs();
            String dataPathLocal = extractDir;
    
            int downloadRetries = 10;
            if (!new File(dataPathLocal).exists() || new File(dataPathLocal).list().length == 0) {
                System.out.println("_______________________________________________________________________");
                System.out.println("Downloading data (1KB) and extracting to \n\t" + dataPathLocal);
                System.out.println("_______________________________________________________________________");
                Downloader.downloadAndExtract("files",
                        new URL(dataURL),
                        new File(downloadPath),
                        new File(extractDir),
                        "bb49e38bb91089634d7ef37ad8e430b8",
                        downloadRetries);
            } else {
                System.out.println("_______________________________________________________________________");
                System.out.println("Example data present in \n\t" + dataPathLocal);
                System.out.println("_______________________________________________________________________");
            }
            return dataPathLocal;
        }

Ti consigliamo anche