package search.mcts.selection;

import java.util.concurrent.ThreadLocalRandom;
import other.state.State;
import search.mcts.MCTS;
import search.mcts.nodes.BaseNode;

/* loaded from: input_file:search/mcts/selection/ProgressiveBias.class */
/**
 * Selection strategy that scores every legal move of a node as
 * {@code exploitation + explorationConstant * exploration + heuristicBias}
 * and picks the move with the highest score, breaking exact ties at random.
 *
 * <p>For a child that has not been created yet, the exploitation term and the
 * heuristic bias are both the node's value estimate for unvisited children.
 * For an existing child, the heuristic bias is the child's heuristic value
 * estimate for the moving agent, scaled by {@link #HEURISTIC_WEIGHT} and
 * divided by the child's visit count (real plus virtual visits).
 */
public final class ProgressiveBias implements SelectionStrategy {
    /** Multiplier applied to a visited child's heuristic estimate before dividing by its visits. */
    private static final double HEURISTIC_WEIGHT = 10.0;

    /** Prefix of the customisation token that overrides the exploration constant. */
    private static final String EXPLORATION_CONSTANT_KEY = "explorationconstant=";

    /** Weight of the exploration term in the selection score. */
    protected double explorationConstant;

    /** Creates a strategy with an exploration constant of {@code sqrt(2)}. */
    public ProgressiveBias() {
        this(Math.sqrt(2.0));
    }

    /**
     * Creates a strategy with the given exploration constant.
     *
     * @param d weight of the exploration term in the selection score
     */
    public ProgressiveBias(double d) {
        this.explorationConstant = d;
    }

    /**
     * Chooses the index of the legal move to follow from {@code baseNode}.
     *
     * @param mcts     the search object; its {@code heuristics()} are asserted to be non-null
     * @param baseNode the node whose legal moves are scored
     * @return index of the highest-scoring legal move (ties broken uniformly at random),
     *         or {@code -1} if no move was selected (e.g. the node has no legal moves)
     */
    @Override
    public int select(MCTS mcts, BaseNode baseNode) {
        assert mcts.heuristics() != null;

        int bestMove = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        int numBestFound = 0;

        // Clamp to 1 so the logarithm is never taken of zero.
        final double parentLog = Math.log(Math.max(1, baseNode.sumLegalChildVisits()));
        final int numLegalMoves = baseNode.numLegalMoves();
        final State state = baseNode.contextRef().state();
        final int moverAgent = state.playerToAgent(state.mover());
        final double unvisitedValueEstimate = baseNode.valueEstimateUnvisitedChildren(moverAgent);

        for (int move = 0; move < numLegalMoves; move++) {
            final BaseNode child = baseNode.childForNthLegalMove(move);
            final double exploit;
            final double explore;
            final double heuristicBias;

            if (child == null) {
                // No node for this move yet: fall back to the parent's estimate for both terms.
                exploit = unvisitedValueEstimate;
                explore = Math.sqrt(parentLog);
                heuristicBias = unvisitedValueEstimate;
            } else {
                exploit = child.exploitationScore(moverAgent);
                // NOTE(review): a child reporting zero total visits would make the two
                // divisions below produce Infinity/NaN — presumably existing children
                // always have at least one (virtual) visit; confirm against callers.
                final int numVisits = child.numVisits() + child.numVirtualVisits();
                explore = Math.sqrt(parentLog / numVisits);
                heuristicBias = (HEURISTIC_WEIGHT * child.heuristicValueEstimates()[moverAgent]) / numVisits;
            }

            final double score = exploit + (this.explorationConstant * explore) + heuristicBias;

            if (score > bestScore) {
                bestScore = score;
                bestMove = move;
                numBestFound = 1;
            } else if (score == bestScore) {
                // Replace the current pick with probability 1/numBestFound so that
                // every tied move ends up equally likely to be returned.
                numBestFound++;
                if (ThreadLocalRandom.current().nextInt(numBestFound) == 0) {
                    bestMove = move;
                }
            }
        }

        return bestMove;
    }

    /**
     * @return {@code 0}
     */
    @Override
    public int backpropFlags() {
        return 0;
    }

    /**
     * @return {@code 1} — presumably a bit flag understood by the MCTS expansion step;
     *         verify against the flag constants declared by {@code MCTS}
     */
    @Override
    public int expansionFlags() {
        return 1;
    }

    /**
     * Applies customisation tokens. The element at index 0 is skipped; every later
     * element of the form {@code explorationconstant=<double>} overrides the
     * exploration constant, and any other element is reported on standard error.
     *
     * @param strArr customisation tokens
     * @throws NumberFormatException if the value after {@code explorationconstant=}
     *                               is not a parsable double
     */
    @Override
    public void customise(String[] strArr) {
        for (int i = 1; i < strArr.length; i++) {
            final String input = strArr[i];
            if (input.startsWith(EXPLORATION_CONSTANT_KEY)) {
                this.explorationConstant = Double.parseDouble(input.substring(EXPLORATION_CONSTANT_KEY.length()));
            } else {
                System.err.println("Progressive Bias ignores unknown customisation: " + input);
            }
        }
    }
}
