package policies;

import compiler.Compiler;
import decision_trees.classifiers.DecisionTreeNode;
import features.Feature;
import features.FeatureVector;
import features.aspatial.AspatialFeature;
import features.feature_sets.BaseFeatureSet;
import features.feature_sets.network.JITSPatterNetFeatureSet;
import features.spatial.SpatialFeature;
import game.Game;
import game.rules.play.moves.Moves;
import game.types.play.RoleType;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import main.FileHandling;
import main.collections.FVector;
import main.collections.FastArrayList;
import main.grammar.Report;
import metadata.ai.features.trees.FeatureTrees;
import metadata.ai.features.trees.classifiers.DecisionTree;
import other.context.Context;
import other.move.Move;
import other.trial.Trial;
import playout_move_selectors.DecisionTreeMoveSelector;
import playout_move_selectors.EpsilonGreedyWrapper;
import search.mcts.MCTS;

/* loaded from: input_file:policies/ProportionalPolicyClassificationTree.class */
public class ProportionalPolicyClassificationTree extends Policy {
    protected DecisionTreeNode[] decisionTreeRoots;
    protected BaseFeatureSet[] featureSets;
    protected int playoutActionLimit;
    protected int playoutTurnLimit;
    protected double epsilon;
    protected boolean greedy;

    public ProportionalPolicyClassificationTree() {
        this.playoutActionLimit = -1;
        this.playoutTurnLimit = -1;
        this.epsilon = 0.0d;
        this.greedy = false;
        this.decisionTreeRoots = null;
        this.featureSets = null;
    }

    public ProportionalPolicyClassificationTree(DecisionTreeNode[] decisionTreeNodeArr, BaseFeatureSet[] baseFeatureSetArr) {
        this.playoutActionLimit = -1;
        this.playoutTurnLimit = -1;
        this.epsilon = 0.0d;
        this.greedy = false;
        this.decisionTreeRoots = decisionTreeNodeArr;
        this.featureSets = (BaseFeatureSet[]) Arrays.copyOf(baseFeatureSetArr, baseFeatureSetArr.length);
    }

    public ProportionalPolicyClassificationTree(DecisionTreeNode[] decisionTreeNodeArr, BaseFeatureSet[] baseFeatureSetArr, int i) {
        this.playoutActionLimit = -1;
        this.playoutTurnLimit = -1;
        this.epsilon = 0.0d;
        this.greedy = false;
        this.decisionTreeRoots = decisionTreeNodeArr;
        this.featureSets = (BaseFeatureSet[]) Arrays.copyOf(baseFeatureSetArr, baseFeatureSetArr.length);
        this.playoutActionLimit = i;
    }

    public static ProportionalPolicyClassificationTree constructPolicy(FeatureTrees featureTrees, double d) {
        ProportionalPolicyClassificationTree proportionalPolicyClassificationTree = new ProportionalPolicyClassificationTree();
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        for (DecisionTree decisionTree : featureTrees.decisionTrees()) {
            if (decisionTree.role() == RoleType.Shared || decisionTree.role() == RoleType.Neutral) {
                addFeatureSetRoot(0, decisionTree.root(), arrayList, arrayList2);
            } else {
                addFeatureSetRoot(decisionTree.role().owner(), decisionTree.root(), arrayList, arrayList2);
            }
        }
        proportionalPolicyClassificationTree.featureSets = (BaseFeatureSet[]) arrayList.toArray(new BaseFeatureSet[arrayList.size()]);
        proportionalPolicyClassificationTree.decisionTreeRoots = (DecisionTreeNode[]) arrayList2.toArray(new DecisionTreeNode[arrayList2.size()]);
        proportionalPolicyClassificationTree.epsilon = d;
        return proportionalPolicyClassificationTree;
    }

    @Override // policies.Policy
    public FVector computeDistribution(Context context, FastArrayList<Move> fastArrayList, boolean z) {
        return computeDistribution((this.featureSets.length == 1 ? this.featureSets[0] : this.featureSets[context.state().mover()]).computeFeatureVectors(context, fastArrayList, z), context.state().mover());
    }

    public FVector computeDistribution(FeatureVector[] featureVectorArr, int i) {
        float[] fArr = new float[featureVectorArr.length];
        DecisionTreeNode decisionTreeNode = this.decisionTreeRoots.length == 1 ? this.decisionTreeRoots[0] : this.decisionTreeRoots[i];
        for (int i2 = 0; i2 < featureVectorArr.length; i2++) {
            fArr[i2] = decisionTreeNode.predict(featureVectorArr[i2]);
        }
        FVector wrap = FVector.wrap(fArr);
        wrap.normalise();
        return wrap;
    }

    @Override // search.mcts.playout.PlayoutStrategy
    public Trial runPlayout(MCTS mcts, Context context) {
        return context.game().playout(context, null, 1.0d, this.epsilon < 1.0d ? this.epsilon <= 0.0d ? new DecisionTreeMoveSelector(this.featureSets, this.decisionTreeRoots, this.greedy) : new EpsilonGreedyWrapper(new DecisionTreeMoveSelector(this.featureSets, this.decisionTreeRoots, this.greedy), this.epsilon) : null, this.playoutActionLimit, this.playoutTurnLimit, ThreadLocalRandom.current());
    }

    @Override // search.mcts.playout.PlayoutStrategy
    public boolean playoutSupportsGame(Game game2) {
        return supportsGame(game2);
    }

    @Override // search.mcts.playout.PlayoutStrategy
    public int backpropFlags() {
        return 0;
    }

    @Override // search.mcts.playout.PlayoutStrategy
    public void customise(String[] strArr) {
        String str = null;
        for (int i = 1; i < strArr.length; i++) {
            String str2 = strArr[i];
            if (str2.toLowerCase().startsWith("policytrees=")) {
                str = str2.substring("policytrees=".length());
            } else if (str2.toLowerCase().startsWith("playoutactionlimit=")) {
                this.playoutActionLimit = Integer.parseInt(str2.substring("playoutactionlimit=".length()));
            } else if (str2.toLowerCase().startsWith("playoutturnlimit=")) {
                this.playoutTurnLimit = Integer.parseInt(str2.substring("playoutturnlimit=".length()));
            } else if (str2.toLowerCase().startsWith("friendly_name=")) {
                this.friendlyName = str2.substring("friendly_name=".length());
            } else if (str2.toLowerCase().startsWith("epsilon=")) {
                this.epsilon = Double.parseDouble(str2.substring("epsilon=".length()));
            } else if (str2.toLowerCase().startsWith("greedy=")) {
                this.greedy = Boolean.parseBoolean(str2.substring("greedy=".length()));
            }
        }
        if (str == null) {
            System.err.println("Cannot construct Proportional Policy Classification Tree from: " + Arrays.toString(strArr));
            return;
        }
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        try {
            for (DecisionTree decisionTree : ((FeatureTrees) Compiler.compileObject(FileHandling.loadTextContentsFromFile(str), "metadata.ai.features.trees.FeatureTrees", new Report())).decisionTrees()) {
                if (decisionTree.role() == RoleType.Shared || decisionTree.role() == RoleType.Neutral) {
                    addFeatureSetRoot(0, decisionTree.root(), arrayList, arrayList2);
                } else {
                    addFeatureSetRoot(decisionTree.role().owner(), decisionTree.root(), arrayList, arrayList2);
                }
            }
            this.featureSets = (BaseFeatureSet[]) arrayList.toArray(new BaseFeatureSet[arrayList.size()]);
            this.decisionTreeRoots = (DecisionTreeNode[]) arrayList2.toArray(new DecisionTreeNode[arrayList2.size()]);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    @Override // other.AI
    public Move selectAction(Game game2, Context context, double d, int i, int i2) {
        Moves moves = game2.moves(context);
        FVector computeDistribution = computeDistribution((this.featureSets.length == 1 ? this.featureSets[0] : this.featureSets[context.state().mover()]).computeFeatureVectors(context, moves.moves(), true), context.state().mover());
        return this.greedy ? moves.moves().get(computeDistribution.argMaxRand()) : moves.moves().get(computeDistribution.sampleFromDistribution());
    }

    @Override // other.AI
    public void initAI(Game game2, int i) {
        if (this.featureSets.length != 1) {
            for (int i2 = 1; i2 < this.featureSets.length; i2++) {
                this.featureSets[i2].init(game2, new int[]{i2}, null);
            }
            return;
        }
        int[] iArr = new int[game2.players().count()];
        for (int i3 = 0; i3 < iArr.length; i3++) {
            iArr[i3] = i3 + 1;
        }
        this.featureSets[0].init(game2, iArr, null);
    }

    @Override // other.AI
    public void closeAI() {
        if (this.featureSets == null) {
            return;
        }
        if (this.featureSets.length == 1) {
            this.featureSets[0].closeCache();
            return;
        }
        for (int i = 1; i < this.featureSets.length; i++) {
            this.featureSets[i].closeCache();
        }
    }

    public BaseFeatureSet[] featureSets() {
        return this.featureSets;
    }

    public static ProportionalPolicyClassificationTree fromLines(String[] strArr) {
        ProportionalPolicyClassificationTree proportionalPolicyClassificationTree = new ProportionalPolicyClassificationTree();
        proportionalPolicyClassificationTree.customise(strArr);
        return proportionalPolicyClassificationTree;
    }

    protected static void addFeatureSetRoot(int i, metadata.ai.features.trees.classifiers.DecisionTreeNode decisionTreeNode, List<BaseFeatureSet> list, List<DecisionTreeNode> list2) {
        while (list.size() <= i) {
            list.add(null);
        }
        while (list2.size() <= i) {
            list2.add(null);
        }
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        HashSet hashSet = new HashSet();
        decisionTreeNode.collectFeatureStrings(hashSet);
        Iterator<String> it = hashSet.iterator();
        while (it.hasNext()) {
            Feature fromString = Feature.fromString(it.next());
            if (fromString instanceof AspatialFeature) {
                arrayList.add((AspatialFeature) fromString);
            } else {
                arrayList2.add((SpatialFeature) fromString);
            }
        }
        JITSPatterNetFeatureSet construct = JITSPatterNetFeatureSet.construct(arrayList, arrayList2);
        list.set(i, construct);
        list2.set(i, DecisionTreeNode.fromMetadataNode(decisionTreeNode, construct));
    }
}
