/*
 * Decompiled with CFR 0.152.
 */
package aima.learning.reinforcement;

import aima.probability.decision.MDPPolicy;
import aima.util.Pair;
import aima.util.Util;
import java.util.ArrayList;
import java.util.Hashtable;
import java.util.List;
import java.util.Set;

/**
 * A table of Q-values keyed by (state, action) pairs.
 *
 * <p>Pairs that have never been written are reported as having a Q-value of
 * {@code 0.0}. Besides plain get/set access, the table can apply an incremental
 * update ({@link #upDateQ}), rescale its contents ({@link #normalize}) and derive
 * a policy that maps each recorded state to its best recorded action
 * ({@link #getPolicy}).
 *
 * @param <STATE_TYPE>  type used to represent states
 * @param <ACTION_TYPE> type used to represent actions
 */
public class QTable<STATE_TYPE, ACTION_TYPE> {
    /** Explicitly stored Q-values; a pair with no entry is treated as 0.0 by {@link #getQValue}. */
    Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double> table = new Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double>();

    /** Candidate actions examined by {@link #maxDiff}. The caller's list is kept by reference, not copied. */
    private final List<ACTION_TYPE> allPossibleActions;

    /**
     * Creates an empty table.
     *
     * @param allPossibleActions the actions considered when searching for the best
     *                           follow-up action; the list is stored as-is
     */
    public QTable(List<ACTION_TYPE> allPossibleActions) {
        this.allPossibleActions = allPossibleActions;
    }

    /**
     * Looks up the Q-value of a (state, action) pair.
     *
     * @param state  the state component of the key
     * @param action the action component of the key
     * @return the stored value, or {@code 0.0} when the pair has no entry
     */
    public Double getQValue(STATE_TYPE state, ACTION_TYPE action) {
        // Hashtable cannot hold null values, so a null result can only mean "no entry".
        Double stored = this.table.get(new Pair<STATE_TYPE, ACTION_TYPE>(state, action));
        if (stored == null) {
            return 0.0;
        }
        return stored;
    }

    /**
     * Finds the action {@code a'} that maximises
     * {@code Q(endState, a') - Q(startState, action)} over all possible actions.
     *
     * <p>The search starts from a randomly selected action and only replaces it on a
     * strictly greater difference, so ties are resolved in favour of the random pick.
     *
     * @param startState the state the action was taken in
     * @param action     the action that was taken
     * @param endState   the state that was reached
     * @return the chosen action paired with its difference value
     */
    public Pair<ACTION_TYPE, Double> maxDiff(STATE_TYPE startState, ACTION_TYPE action, STATE_TYPE endState) {
        // NOTE(review): the signature of Util.selectRandomlyFromList is not visible from this
        // file. The cast keeps this line compiling whether the helper is generic or returns
        // Object; it is a no-op if the helper already returns ACTION_TYPE.
        @SuppressWarnings("unchecked")
        ACTION_TYPE bestAction = (ACTION_TYPE) Util.selectRandomlyFromList(this.allPossibleActions);

        // Q(startState, action) does not depend on the candidate, so read it once.
        double baseline = this.getQValue(startState, action);
        double bestDiff = this.getQValue(endState, bestAction) - baseline;

        for (ACTION_TYPE candidate : this.allPossibleActions) {
            double diff = this.getQValue(endState, candidate) - baseline;
            if (diff > bestDiff) {
                bestAction = candidate;
                bestDiff = diff;
            }
        }
        return new Pair<ACTION_TYPE, Double>(bestAction, bestDiff);
    }

    /**
     * Stores (or overwrites) the Q-value of a (state, action) pair.
     *
     * @param state  the state component of the key
     * @param action the action component of the key
     * @param value  the value to store
     */
    public void setQValue(STATE_TYPE state, ACTION_TYPE action, Double value) {
        this.table.put(new Pair<STATE_TYPE, ACTION_TYPE>(state, action), value);
    }

    /**
     * Applies one incremental update to {@code Q(startState, action)}:
     * {@code Q += stepSize * (reward + discount * maxDiff)}, where {@code maxDiff} is the
     * difference value returned by {@link #maxDiff} for the same arguments.
     *
     * @param startState the state the action was taken in
     * @param action     the action that was taken
     * @param endState   the state that was reached
     * @param stepSize   factor applied to the whole increment
     * @param reward     term added inside the increment
     * @param discount   factor applied to the maximum difference
     * @return the action selected by {@link #maxDiff}
     */
    public ACTION_TYPE upDateQ(STATE_TYPE startState, ACTION_TYPE action, STATE_TYPE endState, double stepSize, double reward, double discount) {
        double oldValue = this.getQValue(startState, action);
        Pair<ACTION_TYPE, Double> best = this.maxDiff(startState, action, endState);
        double increment = stepSize * (reward + discount * best.getSecond());
        this.setQValue(startState, action, oldValue + increment);
        return best.getFirst();
    }

    /**
     * Divides every stored value by the largest stored value. Does nothing when that
     * maximum is {@code 0.0}, which includes the case of an empty table.
     */
    public void normalize() {
        double max = this.findMaximumValue();
        if (max == 0.0) {
            return;
        }
        for (Pair<STATE_TYPE, ACTION_TYPE> key : this.table.keySet()) {
            // put() on a key that is already present only replaces its value.
            this.table.put(key, this.table.get(key) / max);
        }
    }

    /**
     * Returns the largest stored value, or {@code 0.0} when the table is empty.
     */
    private Double findMaximumValue() {
        Double max = null;
        for (Double value : this.table.values()) {
            // The first value seeds the maximum; later ones must be strictly greater.
            if (max == null || value > max) {
                max = value;
            }
        }
        if (max == null) {
            return 0.0;
        }
        return max;
    }

    /**
     * Builds a policy that assigns, to every state appearing as a key in the table,
     * the recorded action with the highest stored Q-value for that state.
     *
     * @return a newly created policy
     */
    public MDPPolicy<STATE_TYPE, ACTION_TYPE> getPolicy() {
        MDPPolicy<STATE_TYPE, ACTION_TYPE> policy = new MDPPolicy<STATE_TYPE, ACTION_TYPE>();
        for (STATE_TYPE state : this.getAllStartingStates()) {
            policy.setAction(state, this.getRecordedActionWithMaximumQValue(state));
        }
        return policy;
    }

    /**
     * Scans the stored entries for the given state and returns the action whose stored
     * value is greatest.
     *
     * @return the best recorded action, or {@code null} when no stored entry for the
     *         state has a value greater than negative infinity
     */
    private ACTION_TYPE getRecordedActionWithMaximumQValue(STATE_TYPE state) {
        double bestValue = Double.NEGATIVE_INFINITY;
        ACTION_TYPE bestAction = null;
        for (Pair<STATE_TYPE, ACTION_TYPE> key : this.table.keySet()) {
            if (!key.getFirst().equals(state)) {
                continue;
            }
            double value = this.table.get(key);
            if (value > bestValue) {
                bestValue = value;
                bestAction = key.getSecond();
            }
        }
        return bestAction;
    }

    /**
     * Collects the distinct states that occur as the first component of a stored key,
     * in the table's iteration order.
     */
    private List<STATE_TYPE> getAllStartingStates() {
        List<STATE_TYPE> states = new ArrayList<STATE_TYPE>();
        for (Pair<STATE_TYPE, ACTION_TYPE> key : this.table.keySet()) {
            STATE_TYPE state = key.getFirst();
            if (!states.contains(state)) {
                states.add(state);
            }
        }
        return states;
    }

    /** Returns the string form of the underlying table. */
    @Override
    public String toString() {
        return this.table.toString();
    }
}

