/*
 * Decompiled with CFR 0.152.
 */
package aima.learning.neural;

import aima.learning.neural.FunctionApproximator;
import aima.learning.neural.HardLimitActivationFunction;
import aima.learning.neural.Layer;
import aima.learning.neural.NNDataSet;
import aima.learning.neural.NNExample;
import aima.learning.neural.Vector;
import aima.util.Matrix;

public class Perceptron
implements FunctionApproximator {
    private final Layer layer;
    private Vector lastInput;

    public Perceptron(int numberOfNeurons, int numberOfInputs) {
        this.layer = new Layer(numberOfNeurons, numberOfInputs, 2.0, -2.0, new HardLimitActivationFunction());
    }

    @Override
    public Vector processInput(Vector input) {
        this.lastInput = input;
        return this.layer.feedForward(input);
    }

    @Override
    public void processError(Vector error) {
        Matrix weightUpdate = error.times(this.lastInput.transpose());
        this.layer.acceptNewWeightUpdate(weightUpdate);
        Vector biasUpdate = this.layer.getBiasVector().plus(error);
        this.layer.acceptNewBiasUpdate(biasUpdate);
    }

    public void trainOn(NNDataSet innds, int numberofEpochs) {
        for (int i = 0; i < numberofEpochs; ++i) {
            innds.refreshDataset();
            while (innds.hasMoreExamples()) {
                NNExample nne = innds.getExampleAtRandom();
                this.processInput(nne.getInput());
                Vector error = this.layer.errorVectorFrom(nne.getTarget());
                this.processError(error);
            }
        }
    }

    public Vector predict(NNExample nne) {
        return this.processInput(nne.getInput());
    }

    public int[] testOnDataSet(NNDataSet nnds) {
        int[] result = new int[]{0, 0};
        nnds.refreshDataset();
        while (nnds.hasMoreExamples()) {
            Vector prediction;
            NNExample nne = nnds.getExampleAtRandom();
            if (nne.isCorrect(prediction = this.predict(nne))) {
                result[0] = result[0] + 1;
                continue;
            }
            result[1] = result[1] + 1;
        }
        return result;
    }
}

