/*
 * Decompiled with CFR 0.152.
 */
package aima.probability.decision;

import aima.probability.decision.MDPPolicy;
import aima.probability.decision.MDPTransition;
import aima.probability.decision.MDPUtilityFunction;
import aima.util.Pair;
import aima.util.Util;
import java.util.ArrayList;
import java.util.Hashtable;
import java.util.List;

/**
 * Transition model of a Markov decision process: a table mapping
 * (initial state, action, destination state) triples to probabilities,
 * together with a list of terminal states.
 *
 * <p>Transitions that start in a terminal state are never stored, and any
 * triple that has not been stored has probability {@code 0.0}.
 *
 * <p>NOTE(review): lookups build a fresh {@link MDPTransition} as the key, so
 * they presumably rely on {@code MDPTransition} implementing value-based
 * {@code equals}/{@code hashCode} - confirm against that class.
 *
 * @param <STATE_TYPE>  type used to represent states
 * @param <ACTION_TYPE> type used to represent actions
 */
public class MDPTransitionModel<STATE_TYPE, ACTION_TYPE> {
    private final Hashtable<MDPTransition<STATE_TYPE, ACTION_TYPE>, Double> transitionToProbability = new Hashtable<MDPTransition<STATE_TYPE, ACTION_TYPE>, Double>();
    private final List<STATE_TYPE> terminalStates;

    /**
     * Creates an empty model.
     *
     * @param terminalStates states treated as terminal; the list is kept by
     *                       reference, not copied
     */
    public MDPTransitionModel(List<STATE_TYPE> terminalStates) {
        this.terminalStates = terminalStates;
    }

    /**
     * Records the probability of reaching {@code finalState} when
     * {@code action} is taken in {@code initialState}, replacing any value
     * stored earlier for the same triple. The call is silently ignored when
     * {@code initialState} is terminal.
     *
     * @param initialState state the transition starts in
     * @param action       action taken
     * @param finalState   state the transition ends in
     * @param probability  probability of the transition
     */
    public void setTransitionProbability(STATE_TYPE initialState, ACTION_TYPE action, STATE_TYPE finalState, double probability) {
        if (this.isTerminal(initialState)) {
            return;
        }
        MDPTransition<STATE_TYPE, ACTION_TYPE> transition = new MDPTransition<STATE_TYPE, ACTION_TYPE>(initialState, action, finalState);
        this.transitionToProbability.put(transition, probability);
    }

    /**
     * Returns the stored probability of the given transition.
     *
     * @param initialState state the transition starts in
     * @param action       action taken
     * @param finalState   state the transition ends in
     * @return the stored probability, or {@code 0.0} if none was stored
     */
    public double getTransitionProbability(STATE_TYPE initialState, ACTION_TYPE action, STATE_TYPE finalState) {
        MDPTransition<STATE_TYPE, ACTION_TYPE> key = new MDPTransition<STATE_TYPE, ACTION_TYPE>(initialState, action, finalState);
        // A single lookup is enough: Hashtable never stores null values, so
        // null unambiguously means "no entry".
        Double probability = this.transitionToProbability.get(key);
        return probability == null ? 0.0 : probability;
    }

    /**
     * Lists every stored transition as {@code "<transition> -> <probability> "},
     * one per line, in the table's iteration order.
     */
    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder();
        for (MDPTransition<STATE_TYPE, ACTION_TYPE> transition : this.transitionToProbability.keySet()) {
            buf.append(transition.toString())
               .append(" -> ")
               .append(this.transitionToProbability.get(transition))
               .append(" \n");
        }
        return buf.toString();
    }

    /**
     * Finds the action available in {@code s} whose expected utility
     * (sum over destinations of probability times utility) is largest.
     *
     * @param s  state to evaluate
     * @param uf utility function supplying the utility of each destination
     * @return the best action paired with its expected utility;
     *         {@code (null, 0.0)} if {@code s} is terminal; and
     *         {@code (null, Double.MIN_VALUE)} if no action qualifies
     */
    public Pair<ACTION_TYPE, Double> getTransitionWithMaximumExpectedUtility(STATE_TYPE s, MDPUtilityFunction<STATE_TYPE> uf) {
        if (this.isTerminal(s)) {
            return new Pair<ACTION_TYPE, Double>(null, 0.0);
        }
        List<MDPTransition<STATE_TYPE, ACTION_TYPE>> transitionsStartingWithS = this.getTransitionsStartingWith(s);
        Hashtable<ACTION_TYPE, Double> actionsToUtilities = this.getExpectedUtilityForSelectedTransitions(transitionsStartingWithS, uf);
        return this.getActionWithMaximumUtility(actionsToUtilities);
    }

    /**
     * Computes the expected utility of following {@code policy} in {@code s},
     * considering only transitions whose action is {@code policy.getAction(s)}.
     *
     * @param policy policy that chooses the action for {@code s}
     * @param s      state to evaluate
     * @param uf     utility function supplying the utility of each destination
     * @return the policy's action paired with its expected utility;
     *         {@code (null, 0.0)} if {@code s} is terminal; and
     *         {@code (null, Double.MIN_VALUE)} if no stored transition matches
     */
    public Pair<ACTION_TYPE, Double> getTransitionWithMaximumExpectedUtilityUsingPolicy(MDPPolicy<STATE_TYPE, ACTION_TYPE> policy, STATE_TYPE s, MDPUtilityFunction<STATE_TYPE> uf) {
        if (this.isTerminal(s)) {
            return new Pair<ACTION_TYPE, Double>(null, 0.0);
        }
        List<MDPTransition<STATE_TYPE, ACTION_TYPE>> transitionsForPolicyAction = this.getTransitionsWithStartingStateAndAction(s, policy.getAction(s));
        Hashtable<ACTION_TYPE, Double> actionsToUtilities = this.getExpectedUtilityForSelectedTransitions(transitionsForPolicyAction, uf);
        return this.getActionWithMaximumUtility(actionsToUtilities);
    }

    /** Returns whether {@code s} is contained in the terminal-state list. */
    private boolean isTerminal(STATE_TYPE s) {
        return this.terminalStates.contains(s);
    }

    /**
     * Picks the entry with the largest utility; on ties the entry met first
     * in the table's iteration order wins.
     *
     * <p>The running maximum starts at negative infinity rather than
     * {@code Double.MIN_VALUE}. The latter is the smallest <em>positive</em>
     * double, so seeding with it meant that zero or negative utilities could
     * never be selected.
     *
     * @return the best action and its utility, or the legacy placeholder
     *         {@code (null, Double.MIN_VALUE)} when no action qualifies
     *         (empty table, or only NaN / negative-infinity utilities)
     */
    private Pair<ACTION_TYPE, Double> getActionWithMaximumUtility(Hashtable<ACTION_TYPE, Double> actionsToUtilities) {
        ACTION_TYPE bestAction = null;
        double bestUtility = Double.NEGATIVE_INFINITY;
        for (ACTION_TYPE action : actionsToUtilities.keySet()) {
            double utility = actionsToUtilities.get(action);
            if (utility > bestUtility) {
                bestAction = action;
                bestUtility = utility;
            }
        }
        if (bestAction == null) {
            // Hashtable keys are never null, so this means nothing was
            // selected. Keep the value callers have always received here.
            return new Pair<ACTION_TYPE, Double>(null, Double.MIN_VALUE);
        }
        return new Pair<ACTION_TYPE, Double>(bestAction, bestUtility);
    }

    /**
     * Groups the given transitions by action and, for each action, sums
     * probability times destination utility.
     *
     * @param transitions transitions to aggregate
     * @param uf          utility function supplying the utility of each destination
     * @return map from action to its accumulated expected utility
     */
    private Hashtable<ACTION_TYPE, Double> getExpectedUtilityForSelectedTransitions(List<MDPTransition<STATE_TYPE, ACTION_TYPE>> transitions, MDPUtilityFunction<STATE_TYPE> uf) {
        Hashtable<ACTION_TYPE, Double> actionsToUtilities = new Hashtable<ACTION_TYPE, Double>();
        for (MDPTransition<STATE_TYPE, ACTION_TYPE> triplet : transitions) {
            STATE_TYPE s = triplet.getInitialState();
            ACTION_TYPE action = triplet.getAction();
            STATE_TYPE destinationState = triplet.getDestinationState();
            double probabilityOfTransition = this.getTransitionProbability(s, action, destinationState);
            double expectedUtility = probabilityOfTransition * uf.getUtility(destinationState);
            Double presentValue = actionsToUtilities.get(action);
            if (presentValue == null) {
                actionsToUtilities.put(action, expectedUtility);
            } else {
                actionsToUtilities.put(action, presentValue + expectedUtility);
            }
        }
        return actionsToUtilities;
    }

    /** Returns every stored transition whose initial state equals {@code s}. */
    private List<MDPTransition<STATE_TYPE, ACTION_TYPE>> getTransitionsStartingWith(STATE_TYPE s) {
        List<MDPTransition<STATE_TYPE, ACTION_TYPE>> result = new ArrayList<MDPTransition<STATE_TYPE, ACTION_TYPE>>();
        for (MDPTransition<STATE_TYPE, ACTION_TYPE> transition : this.transitionToProbability.keySet()) {
            if (transition.getInitialState().equals(s)) {
                result.add(transition);
            }
        }
        return result;
    }

    /**
     * Returns every stored transition whose initial state equals {@code s}
     * and whose action equals {@code a}.
     *
     * @param s initial state to match
     * @param a action to match
     * @return matching transitions; empty if there are none
     */
    public List<MDPTransition<STATE_TYPE, ACTION_TYPE>> getTransitionsWithStartingStateAndAction(STATE_TYPE s, ACTION_TYPE a) {
        List<MDPTransition<STATE_TYPE, ACTION_TYPE>> result = new ArrayList<MDPTransition<STATE_TYPE, ACTION_TYPE>>();
        for (MDPTransition<STATE_TYPE, ACTION_TYPE> transition : this.transitionToProbability.keySet()) {
            if (transition.getInitialState().equals(s) && transition.getAction().equals(a)) {
                result.add(transition);
            }
        }
        return result;
    }

    /**
     * Returns the action of a randomly chosen stored transition that starts
     * in {@code s}.
     *
     * <p>Previously the random pick was discarded and the action of the first
     * transition was always returned. Because the choice is made over
     * transitions, an action with more stored destinations is proportionally
     * more likely to be drawn.
     *
     * <p>NOTE(review): behavior when {@code s} has no stored transitions
     * depends on {@code Util.selectRandomlyFromList} - confirm.
     *
     * @param s state to pick an action for
     * @return the action of the randomly selected transition
     */
    public ACTION_TYPE randomActionFor(STATE_TYPE s) {
        List<MDPTransition<STATE_TYPE, ACTION_TYPE>> transitions = this.getTransitionsStartingWith(s);
        MDPTransition<STATE_TYPE, ACTION_TYPE> randomTransition = Util.selectRandomlyFromList(transitions);
        return randomTransition.getAction();
    }
}

