/*
 * Decompiled with CFR 0.152.
 */
package aima.learning.reinforcement;

import aima.learning.reinforcement.MDPAgent;
import aima.learning.reinforcement.QTable;
import aima.probability.decision.MDP;
import aima.probability.decision.MDPPerception;
import aima.util.FrequencyCounter;
import aima.util.Pair;
import java.util.Hashtable;
import java.util.List;

/**
 * A Q-learning agent for a Markov decision process.
 *
 * <p>On each percept the agent records the current state and reward. On the first percept of a
 * trial it only picks an opening action. Otherwise it bumps the count of the previous
 * (state, action) pair and asks its {@link QTable} to update that pair's Q-value; the action
 * returned by the table update becomes the agent's next action. When the current state is
 * terminal, the agent performs that final update, forgets its previous state/action/reward and
 * returns {@code null}.
 *
 * @param <STATE_TYPE> the MDP state type
 * @param <ACTION_TYPE> the MDP action type
 */
public class QLearningAgent<STATE_TYPE, ACTION_TYPE>
extends MDPAgent<STATE_TYPE, ACTION_TYPE> {

    /**
     * Value passed as {@code gamma} to {@link #updateQ(double)}, which forwards it as the last
     * argument of {@code QTable.upDateQ} (presumably the discount factor — confirm against
     * QTable).
     */
    private static final double GAMMA = 0.8;

    /** Exposed via {@link #getQ()}; nothing in this class ever writes to it. */
    private final Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double> Q =
            new Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double>();

    /** Counts of (state, action) pairs; the previous pair is counted once before each Q update. */
    private final FrequencyCounter<Pair<STATE_TYPE, ACTION_TYPE>> stateActionCount;

    /** Reward seen at the previous step; {@code null} at the start of a trial. */
    private Double previousReward;

    /** Table holding the learned Q-values; also chooses the next action on each update. */
    private final QTable<STATE_TYPE, ACTION_TYPE> qTable;

    /** Number of Q updates since the last reset performed by {@link #learningFunction(Double)}. */
    private int actionCounter;

    /**
     * Creates an agent for {@code mdp} whose Q table is built over all of the MDP's actions.
     *
     * @param mdp the decision process to learn
     */
    public QLearningAgent(MDP<STATE_TYPE, ACTION_TYPE> mdp) {
        super(mdp);
        this.qTable = new QTable<STATE_TYPE, ACTION_TYPE>(mdp.getAllActions());
        this.stateActionCount = new FrequencyCounter<Pair<STATE_TYPE, ACTION_TYPE>>();
        this.actionCounter = 0;
    }

    /**
     * Consumes one percept, updates the Q table and returns the next action.
     *
     * @param perception the current state and the reward received in it
     * @return the next action, or {@code null} when the current state is terminal
     */
    @Override
    public ACTION_TYPE decideAction(MDPPerception<STATE_TYPE> perception) {
        this.currentState = perception.getState();
        this.currentReward = perception.getReward();
        if (this.startingTrial()) {
            // First step of a trial: nothing to learn from yet, only choose an opening action.
            this.updateLearnerState(this.selectRandomAction());
            return this.previousAction;
        }
        boolean terminal = this.mdp.isTerminalState(this.currentState);
        this.incrementStateActionCount(this.previousState, this.previousAction);
        ACTION_TYPE chosenAction = this.updateQ(GAMMA);
        if (terminal) {
            // Forget the previous step so that the next percept of the initial state is
            // recognised by startingTrial() as the beginning of a new trial.
            this.previousAction = null;
            this.previousState = null;
            this.previousReward = null;
            return null;
        }
        this.updateLearnerState(chosenAction);
        return this.previousAction;
    }

    /** Remembers the chosen action together with the current state and reward for the next step. */
    private void updateLearnerState(ACTION_TYPE chosenAction) {
        this.previousAction = chosenAction;
        this.previousState = this.currentState;
        this.previousReward = this.currentReward;
    }

    /**
     * Updates the Q-value of the previous (state, action) pair from the current state and reward.
     *
     * @param gamma forwarded as the last argument of {@code QTable.upDateQ}
     * @return the action returned by the Q table update
     */
    private ACTION_TYPE updateQ(double gamma) {
        ++this.actionCounter;
        double alpha = this.calculateProbabilityOf(this.previousState, this.previousAction);
        return this.qTable.upDateQ(this.previousState, this.previousAction, this.currentState,
                alpha, this.currentReward, gamma);
    }

    /**
     * Computes the learning rate used for the Q update of {@code (state, action)}.
     *
     * <p>Among the pairs reported by {@code stateActionCount.getStates()} whose state equals
     * {@code state}, returns the fraction whose action equals {@code action}. If
     * {@code getStates()} yields each distinct pair once (presumably — TODO confirm against
     * FrequencyCounter), this is 1 / (number of distinct actions tried from {@code state}),
     * independent of how often each was tried.
     *
     * @return the fraction described above, or {@code 0.0} if no pair for {@code state} exists
     */
    private double calculateProbabilityOf(STATE_TYPE state, ACTION_TYPE action) {
        double den = 0.0;
        double num = 0.0;
        for (Pair<STATE_TYPE, ACTION_TYPE> stateActionPair : this.stateActionCount.getStates()) {
            if (!sameValue(stateActionPair.getFirst(), state)) {
                continue;
            }
            den += 1.0;
            if (sameValue(stateActionPair.getSecond(), action)) {
                num += 1.0;
            }
        }
        // Not reachable from updateQ (the pair is always counted first), but avoids feeding
        // 0/0 = NaN into the Q table.
        return den == 0.0 ? 0.0 : num / den;
    }

    /**
     * Returns the action with the largest {@link #learningFunction(Double)} value in the current
     * state, or {@code null} if no action scores above negative infinity.
     *
     * <p>NOTE(review): this method is not called anywhere in this class.
     */
    private ACTION_TYPE actionMaximizingLearningFunction() {
        List<ACTION_TYPE> actions = this.mdp.getAllActions();
        ACTION_TYPE maxAct = null;
        double maxValue = Double.NEGATIVE_INFINITY;
        for (ACTION_TYPE action : actions) {
            Double qValue = this.qTable.getQValue(this.currentState, action);
            Double lfv = this.learningFunction(qValue);
            if (lfv > maxValue) {
                maxValue = lfv;
                maxAct = action;
            }
        }
        return maxAct;
    }

    /**
     * Exploration function: returns {@code 1.0} once more than three Q updates have happened
     * since the last reset, otherwise {@code utility}.
     *
     * <p>Side effect: when it returns {@code 1.0} it resets {@link #actionCounter} to zero.
     */
    private Double learningFunction(Double utility) {
        if (this.actionCounter > 3) {
            this.actionCounter = 0;
            return 1.0;
        }
        return utility;
    }

    /**
     * Chooses the opening action of a trial.
     *
     * <p>Despite the name this is deterministic: it always returns the first element of
     * {@code mdp.getAllActions()}.
     */
    private ACTION_TYPE selectRandomAction() {
        List<ACTION_TYPE> allActions = this.mdp.getAllActions();
        return allActions.get(0);
    }

    /**
     * Returns {@code true} when no previous step is remembered and the agent is in the MDP's
     * initial state.
     */
    private boolean startingTrial() {
        return this.previousAction == null && this.previousState == null
                && this.previousReward == null
                && this.currentState.equals(this.mdp.getInitialState());
    }

    /** Records one more occurrence of the pair {@code (state, action)}. */
    private void incrementStateActionCount(STATE_TYPE state, ACTION_TYPE action) {
        Pair<STATE_TYPE, ACTION_TYPE> stateActionPair = new Pair<STATE_TYPE, ACTION_TYPE>(state, action);
        this.stateActionCount.incrementFor(stateActionPair);
    }

    /** Null-safe equality: {@code a.equals(b)}, or {@code true} when both are {@code null}. */
    private static boolean sameValue(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /**
     * Returns the (mutable) {@code Q} hashtable.
     *
     * <p>NOTE: this class never writes to it, so it stays empty unless a caller fills it through
     * the returned reference; the learned values live in {@link #getQTable()}.
     */
    public Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double> getQ() {
        return this.Q;
    }

    /** Returns the Q table this agent updates. */
    public QTable<STATE_TYPE, ACTION_TYPE> getQTable() {
        return this.qTable;
    }
}

