Completiamo la classe Regression
che abbiamo iniziato a sviluppare nella lezione precedente, dedicata alla costruzione di un modello lineare, aggiungendo i metodi per il training ed il testing.
I metodi della classe
Per la fase di training, dobbiamo inizializzare il modello ed inserire un listener per stampare una valutazione del comportamento. In genere per questa fase si impostano un certo numero di "epoche". Un'epoca è costituita dall'insieme di iterazioni necessarie al passaggio di tutte le istanze di dato presenti nel batch corrente.
Il metodo non fa altro che addestrare il modello sottoponendo batch di un certa dimensione per un certo numero di epoche. Il tipo di apprendimento è quindi supervisionato.
La fase di testing è simile, si recuperano batch di una certa dimensione, si sottopongono al modello e si calcola l'errore rispetto ai dati attesi. La valutazione è compiuta attraverso un oggetto RegressionEvaluation
.
public void training(MultiLayerNetwork model, DataSet[] dataset, int epochs) {
model.init();
model.setListeners(new ScoreIterationListener(5));
for (int i = 0; i < epochs; i++) {
System.out.println("Epoch:"+i);
for(DataSet data: dataset[0].dataSetBatches(BATCH_SIZE)) {
model.fit(data);
}
}
}
public void testing(MultiLayerNetwork model, DataSet[] dataset, int epochs) {
RegressionEvaluation eval = new RegressionEvaluation(1);
for (int i = 0; i < epochs; i++) {
for(DataSet data: dataset[1].dataSetBatches(BATCH_SIZE)) {
INDArray output = model.output(data.getFeatures());
eval.eval(data.getLabels(), output);
System.out.println(eval.stats());
}
}
}
Prova finale
Introduciamo la classe RegressionTest
come classe di prova:
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.DataSet;
public class RegressionTest {
public static void main(String[] args){
Regression logistic = new Regression(32,8, 0.70, 12345);
DataSet[] dataset = logistic.readingData("Concrete_Data.csv");
MultiLayerNetwork model = logistic.getModel(0.001);
logistic.training(model, dataset, 7);
logistic.testing(model, dataset, 7);
}
}
ed eseguiamo l'applicazione. Il risultato che si dovrebbe ottenere se tutti i passaggi sono stati eseguiti correttamente viene mostrato nelle immagini seguenti:
Come è possibile notare, l'errore scende fino ad un valore di circa 0.3, quest'ultimo viene generato approssimativamente anche nella fase di testing, non abbiamo overfitting e nel contempo otteniamo un buon modello di regressione.