/*
 * Decompiled with CFR 0.152.
 */
package aima.learning.reinforcement;

import aima.probability.decision.MDPPolicy;
import aima.util.Pair;
import aima.util.Util;
import java.util.ArrayList;
import java.util.Hashtable;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * A table of Q-values keyed by (state, action) pairs.
 *
 * <p>Pairs that have never been written via {@link #setQValue} are treated as
 * having a Q-value of {@code 0.0}. The backing store is a {@link Hashtable},
 * so {@code null} keys and {@code null} values are rejected by it.
 *
 * @param <STATE_TYPE>  type used to represent states
 * @param <ACTION_TYPE> type used to represent actions
 */
public class QTable<STATE_TYPE, ACTION_TYPE> {
    /** Recorded Q-values; an absent key is read as 0.0 by {@link #getQValue}. */
    Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double> table =
            new Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double>();

    /** Candidate actions considered by {@link #maxDiff}. Held by reference, not copied. */
    private final List<ACTION_TYPE> allPossibleActions;

    /**
     * Creates an empty table.
     *
     * @param allPossibleActions the actions examined when searching for the
     *                           best next action; the list is stored as-is
     */
    public QTable(List<ACTION_TYPE> allPossibleActions) {
        this.allPossibleActions = allPossibleActions;
    }

    /**
     * Looks up the Q-value recorded for a (state, action) pair.
     *
     * @param state  the state component of the key
     * @param action the action component of the key
     * @return the recorded value, or {@code 0.0} if nothing has been recorded
     */
    public Double getQValue(STATE_TYPE state, ACTION_TYPE action) {
        // Hashtable never stores null values, so null here means "no entry".
        Double recorded = this.table.get(new Pair<STATE_TYPE, ACTION_TYPE>(state, action));
        if (recorded == null) {
            return 0.0;
        }
        return recorded;
    }

    /**
     * Finds the action whose Q-value in {@code endState} exceeds
     * {@code Q(startState, action)} by the largest margin.
     *
     * <p>The search is seeded with an action chosen by
     * {@code Util.selectRandomlyFromList}; another action replaces it only when
     * its difference is strictly greater, so the seed is kept whenever it is
     * already among the best.
     *
     * @param startState the state the transition started from
     * @param action     the action taken in {@code startState}
     * @param endState   the state the transition ended in
     * @return a pair of the selected action and its difference
     *         {@code Q(endState, selected) - Q(startState, action)}
     */
    public Pair<ACTION_TYPE, Double> maxDiff(STATE_TYPE startState, ACTION_TYPE action, STATE_TYPE endState) {
        // Loop-invariant: the baseline does not depend on the candidate action.
        double startQ = this.getQValue(startState, action);

        ACTION_TYPE maxAction = Util.selectRandomlyFromList(this.allPossibleActions);
        double maxDiff = this.getQValue(endState, maxAction) - startQ;

        for (ACTION_TYPE anAction : this.allPossibleActions) {
            double diff = this.getQValue(endState, anAction) - startQ;
            if (diff > maxDiff) {
                maxAction = anAction;
                maxDiff = diff;
            }
        }
        return new Pair<ACTION_TYPE, Double>(maxAction, maxDiff);
    }

    /**
     * Records a Q-value for a (state, action) pair, replacing any previous value.
     *
     * @param state  the state component of the key
     * @param action the action component of the key
     * @param d      the value to store
     */
    public void setQValue(STATE_TYPE state, ACTION_TYPE action, Double d) {
        this.table.put(new Pair<STATE_TYPE, ACTION_TYPE>(state, action), d);
    }

    /**
     * Updates {@code Q(startState, action)} after an observed transition and
     * returns the action selected by {@link #maxDiff} for {@code endState}.
     *
     * <p>The stored value becomes
     * {@code oldQ + alpha * (reward + phi * maxDiff)}, where {@code maxDiff} is
     * the difference returned by {@link #maxDiff}.
     *
     * <p>NOTE(review): {@code phi} scales the whole difference
     * {@code (max Q(endState, a') - Q(startState, action))}; the usual textbook
     * Q-learning update discounts only the {@code max Q(endState, a')} term.
     * Confirm this is intended before changing it, since callers may depend on
     * the current arithmetic.
     *
     * @param startState the state the transition started from
     * @param action     the action taken in {@code startState}
     * @param endState   the state the transition ended in
     * @param alpha      multiplier applied to the whole correction term
     * @param reward     reward added inside the correction term
     * @param phi        multiplier applied to the maximum difference
     * @return the action selected by {@link #maxDiff}
     */
    public ACTION_TYPE upDateQ(STATE_TYPE startState, ACTION_TYPE action, STATE_TYPE endState, double alpha, double reward, double phi) {
        double oldQValue = this.getQValue(startState, action);
        Pair<ACTION_TYPE, Double> actionAndMaxDiffValue = this.maxDiff(startState, action, endState);
        double addedValue = alpha * (reward + phi * actionAndMaxDiffValue.getSecond());
        this.setQValue(startState, action, oldQValue + addedValue);
        return actionAndMaxDiffValue.getFirst();
    }

    /**
     * Divides every recorded Q-value by the largest recorded Q-value.
     * Does nothing when the table is empty or the largest value is zero.
     */
    public void normalize() {
        double maxValue = this.findMaximumValue();
        if (maxValue != 0.0) {
            // setValue replaces the value in place; no second hash lookup needed.
            for (Map.Entry<Pair<STATE_TYPE, ACTION_TYPE>, Double> entry : this.table.entrySet()) {
                entry.setValue(entry.getValue() / maxValue);
            }
        }
    }

    /**
     * Returns the largest recorded Q-value, or {@code 0.0} for an empty table.
     * The scan is seeded with the first value in iteration order and replaced
     * only by strictly greater values.
     */
    private Double findMaximumValue() {
        Double maxValue = null;
        for (Double v : this.table.values()) {
            if (maxValue == null || v > maxValue) {
                maxValue = v;
            }
        }
        if (maxValue == null) {
            return 0.0;
        }
        return maxValue;
    }

    /**
     * Builds a policy that maps every state appearing in the table to the
     * recorded action with the highest Q-value for that state.
     *
     * @return a new policy populated via {@code MDPPolicy.setAction}
     */
    public MDPPolicy<STATE_TYPE, ACTION_TYPE> getPolicy() {
        MDPPolicy<STATE_TYPE, ACTION_TYPE> policy = new MDPPolicy<STATE_TYPE, ACTION_TYPE>();
        for (STATE_TYPE state : this.getAllStartingStates()) {
            policy.setAction(state, this.getRecordedActionWithMaximumQValue(state));
        }
        return policy;
    }

    /**
     * Returns the recorded action with the highest Q-value for {@code state},
     * or {@code null} when no recorded value for that state is strictly
     * greater than negative infinity (including when the state has no entries).
     */
    private ACTION_TYPE getRecordedActionWithMaximumQValue(STATE_TYPE state) {
        double maxValue = Double.NEGATIVE_INFINITY;
        ACTION_TYPE action = null;
        for (Map.Entry<Pair<STATE_TYPE, ACTION_TYPE>, Double> entry : this.table.entrySet()) {
            Pair<STATE_TYPE, ACTION_TYPE> stateActionPair = entry.getKey();
            if (!stateActionPair.getFirst().equals(state)) {
                continue;
            }
            double value = entry.getValue();
            if (value > maxValue) {
                maxValue = value;
                action = stateActionPair.getSecond();
            }
        }
        return action;
    }

    /**
     * Collects the distinct states that appear as the first component of a
     * table key, in the order they are first encountered.
     */
    private List<STATE_TYPE> getAllStartingStates() {
        List<STATE_TYPE> states = new ArrayList<STATE_TYPE>();
        for (Pair<STATE_TYPE, ACTION_TYPE> stateActionPair : this.table.keySet()) {
            STATE_TYPE state = stateActionPair.getFirst();
            // De-duplicate with equals() only, so no hashCode contract is assumed for states.
            if (!states.contains(state)) {
                states.add(state);
            }
        }
        return states;
    }

    /** Returns the string form of the backing table. */
    @Override
    public String toString() {
        return this.table.toString();
    }
}

