package policies.softmax;

import features.Feature;
import features.FeatureVector;
import features.WeightVector;
import features.aspatial.AspatialFeature;
import features.feature_sets.BaseFeatureSet;
import features.feature_sets.network.JITSPatterNetFeatureSet;
import features.spatial.SpatialFeature;
import function_approx.BoostedLinearFunction;
import function_approx.LinearFunction;
import game.Game;
import game.rules.play.moves.Moves;
import game.types.play.RoleType;
import gnu.trove.list.array.TFloatArrayList;
import gnu.trove.list.array.TIntArrayList;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.ThreadLocalRandom;
import main.collections.FVector;
import main.collections.FastArrayList;
import metadata.ai.features.FeatureSet;
import metadata.ai.features.Features;
import org.apache.batik.constants.XMLConstants;
import org.apache.batik.util.SVGConstants;
import other.context.Context;
import other.move.Move;
import other.trial.Trial;
import playout_move_selectors.EpsilonGreedyWrapper;
import playout_move_selectors.FeaturesSoftmaxMoveSelector;
import search.mcts.MCTS;
import utils.ExperimentFileUtils;

/* loaded from: input_file:policies/softmax/SoftmaxPolicyLinear.class */
/**
 * Softmax policy whose per-move logits are linear functions of feature vectors.
 *
 * <p>{@link #linearFunctions} and {@link #featureSets} either hold a single shared entry (arrays
 * of length 1, always read at index 0) or one entry per player index, selected with
 * {@code context.state().mover()}. The same policy can also drive MCTS playouts, optionally
 * wrapped in an epsilon-greedy selector.
 */
public class SoftmaxPolicyLinear extends SoftmaxPolicy {

    /** Highest player index accepted by the {@code policyweightsN=} option of {@code customise}. */
    private static final int MAX_PLAYER_INDEX = 16;

    /** Linear function(s) mapping feature vectors to logits; length 1 means shared by all movers. */
    protected LinearFunction[] linearFunctions;

    /** Feature set(s) used to build feature vectors; length 1 means shared by all movers. */
    protected BaseFeatureSet[] featureSets;

    /** Action limit passed to {@code Game.playout} by {@code runPlayout}; -1 by default. */
    protected int playoutActionLimit = -1;

    /** Turn limit passed to {@code Game.playout} by {@code runPlayout}; -1 by default. */
    protected int playoutTurnLimit = -1;

    /**
     * Playout exploration parameter: {@code <= 0} pure softmax selector, strictly between 0 and 1
     * epsilon-greedy wrapper, otherwise no selector is passed to {@code Game.playout}.
     */
    protected double epsilon = 0.0;

    /** Creates an unconfigured policy; arrays stay {@code null} until set by a factory or {@code customise}. */
    public SoftmaxPolicyLinear() {
        // All fields keep their declared defaults.
    }

    /**
     * Creates a policy from existing linear functions and feature sets.
     *
     * @param linearFunctions linear functions, stored by reference
     * @param featureSets feature sets; the array is copied, its elements are shared
     */
    public SoftmaxPolicyLinear(final LinearFunction[] linearFunctions, final BaseFeatureSet[] featureSets) {
        this(linearFunctions, featureSets, -1);
    }

    /**
     * Creates a policy from existing linear functions and feature sets with a playout action limit.
     *
     * @param linearFunctions linear functions, stored by reference
     * @param featureSets feature sets; the array is copied, its elements are shared
     * @param playoutActionLimit action limit forwarded to {@code Game.playout}
     */
    public SoftmaxPolicyLinear(
            final LinearFunction[] linearFunctions,
            final BaseFeatureSet[] featureSets,
            final int playoutActionLimit) {
        this.linearFunctions = linearFunctions;
        this.featureSets = Arrays.copyOf(featureSets, featureSets.length);
        this.playoutActionLimit = playoutActionLimit;
    }

    /**
     * Builds a policy from game metadata features using their selection weights.
     *
     * @param metadataFeatures feature metadata; Shared/Neutral roles map to index 0, others to their owner
     * @param epsilon value stored in {@link #epsilon}
     * @return the configured policy
     */
    public static SoftmaxPolicyLinear constructSelectionPolicy(final Features metadataFeatures, final double epsilon) {
        return constructFromMetadata(metadataFeatures, epsilon, true);
    }

    /**
     * Builds a policy from game metadata features using their playout weights.
     *
     * @param metadataFeatures feature metadata; Shared/Neutral roles map to index 0, others to their owner
     * @param epsilon value stored in {@link #epsilon}
     * @return the configured policy
     */
    public static SoftmaxPolicyLinear constructPlayoutPolicy(final Features metadataFeatures, final double epsilon) {
        return constructFromMetadata(metadataFeatures, epsilon, false);
    }

    /** Shared implementation of the two metadata factories; only the weight source differs. */
    private static SoftmaxPolicyLinear constructFromMetadata(
            final Features metadataFeatures, final double epsilon, final boolean useSelectionWeights) {
        final List<BaseFeatureSet> collectedFeatureSets = new ArrayList<>();
        final List<LinearFunction> collectedFunctions = new ArrayList<>();

        for (final FeatureSet featureSet : metadataFeatures.featureSets()) {
            final RoleType role = featureSet.role();
            final int playerIdx = (role == RoleType.Shared || role == RoleType.Neutral) ? 0 : role.owner();
            final float[] weights = useSelectionWeights ? featureSet.selectionWeights() : featureSet.playoutWeights();
            addFeatureSetWeights(playerIdx, featureSet.featureStrings(), weights, collectedFeatureSets, collectedFunctions);
        }

        final SoftmaxPolicyLinear policy = new SoftmaxPolicyLinear();
        policy.featureSets = collectedFeatureSets.toArray(new BaseFeatureSet[0]);
        policy.linearFunctions = collectedFunctions.toArray(new LinearFunction[0]);
        policy.epsilon = epsilon;
        return policy;
    }

    /** Returns the feature set used for {@code player}: index 0 when shared, else {@code player}. */
    private BaseFeatureSet featureSetFor(final int player) {
        return featureSets.length == 1 ? featureSets[0] : featureSets[player];
    }

    /** Returns the linear function used for {@code player}: index 0 when shared, else {@code player}. */
    private LinearFunction functionFor(final int player) {
        return linearFunctions.length == 1 ? linearFunctions[0] : linearFunctions[player];
    }

    /**
     * Computes the softmax distribution over {@code actions} for the current mover.
     *
     * @param context current game context; its mover selects the feature set and linear function
     * @param actions candidate moves
     * @param thresholded flag forwarded unchanged to {@code computeFeatureVectors}
     * @return softmax probabilities, one per action, in the order of {@code actions}
     */
    @Override
    public FVector computeDistribution(final Context context, final FastArrayList<Move> actions, final boolean thresholded) {
        final int mover = context.state().mover();
        return computeDistribution(featureSetFor(mover).computeFeatureVectors(context, actions, thresholded), mover);
    }

    /**
     * Computes the (pre-softmax) logit of a single move for the current mover.
     *
     * @param context current game context
     * @param move move to score
     * @return linear-function prediction on the move's feature vector
     */
    @Override
    public float computeLogit(final Context context, final Move move) {
        final int mover = context.state().mover();
        final FeatureVector featureVector = featureSetFor(mover).computeFeatureVector(context, move, true);
        return functionFor(mover).predict(featureVector);
    }

    /**
     * Applies softmax to the linear predictions of the given feature vectors.
     *
     * @param featureVectors one feature vector per move
     * @param player player index selecting the linear function (ignored when it is shared)
     * @return softmax probabilities, one per feature vector
     */
    public FVector computeDistribution(final FeatureVector[] featureVectors, final int player) {
        final LinearFunction function = functionFor(player);
        final float[] logits = new float[featureVectors.length];
        for (int m = 0; m < logits.length; m++) {
            logits[m] = function.predict(featureVectors[m]);
        }
        final FVector distribution = FVector.wrap(logits);
        distribution.softmax();
        return distribution;
    }

    /**
     * Accumulates per-move scalars times feature vectors into a dense vector sized like the
     * player's weight vector. Aspatial values fill the first entries; each active spatial feature
     * index is offset by the number of aspatial values.
     *
     * @param errors one scalar per move (e.g. a loss derivative; the caller defines its meaning)
     * @param featureVectors one feature vector per move, aligned with {@code errors}
     * @param player player index selecting the linear function (ignored when it is shared)
     * @return the accumulated gradient vector
     */
    public FVector computeParamGradients(final FVector errors, final FeatureVector[] featureVectors, final int player) {
        final FVector gradients = new FVector(functionFor(player).trainableParams().allWeights().dim());
        final int numMoves = errors.dim();

        for (int m = 0; m < numMoves; m++) {
            final float error = errors.get(m);
            final FeatureVector featureVector = featureVectors[m];
            final FVector aspatialValues = featureVector.aspatialFeatureValues();
            final int numAspatial = aspatialValues.dim();

            for (int a = 0; a < numAspatial; a++) {
                gradients.addToEntry(a, error * aspatialValues.get(a));
            }

            // Active spatial features are binary, so each contributes exactly `error`.
            final TIntArrayList activeSpatial = featureVector.activeSpatialFeatureIndices();
            for (int s = 0; s < activeSpatial.size(); s++) {
                gradients.addToEntry(activeSpatial.getQuick(s) + numAspatial, error);
            }
        }
        return gradients;
    }

    /**
     * Samples an index from a probability distribution.
     *
     * @param distribution probabilities, e.g. from {@code computeDistribution}
     * @return sampled index
     */
    public int selectActionFromDistribution(final FVector distribution) {
        return distribution.sampleFromDistribution();
    }

    /**
     * Replaces feature sets with grown versions and pads the matching linear functions with zero
     * weights for every newly added spatial feature.
     *
     * <p>A {@code null} entry at index {@code i} means "unchanged". In that case, if a new shared
     * set (index 0) is supplied, function {@code i} is padded by the number of spatial features
     * added to set 0.
     *
     * @param newFeatureSets new feature sets, indexed like {@link #linearFunctions}
     */
    public void updateFeatureSets(final BaseFeatureSet[] newFeatureSets) {
        final BaseFeatureSet newShared = newFeatureSets.length > 0 ? newFeatureSets[0] : null;
        // Measure growth of set 0 BEFORE the loop replaces it at i == 0. Otherwise every later
        // function sharing set 0 would see zero added features and never be padded.
        final int numAddedToShared =
                (newShared != null && featureSets.length > 0 && featureSets[0] != null)
                        ? newShared.getNumSpatialFeatures() - featureSets[0].getNumSpatialFeatures()
                        : 0;

        for (int i = 0; i < linearFunctions.length; i++) {
            if (newFeatureSets[i] != null) {
                final int numAdded = newFeatureSets[i].getNumSpatialFeatures() - featureSets[i].getNumSpatialFeatures();
                appendZeroWeights(linearFunctions[i], numAdded);
                featureSets[i] = newFeatureSets[i];
            } else if (newShared != null && linearFunctions[i] != null) {
                appendZeroWeights(linearFunctions[i], numAddedToShared);
            }
        }
    }

    /** Appends {@code count} zero weights to {@code function}; does nothing when {@code count <= 0}. */
    private static void appendZeroWeights(final LinearFunction function, final int count) {
        for (int k = 0; k < count; k++) {
            function.setTheta(new WeightVector(function.trainableParams().allWeights().append(0.0f)));
        }
    }

    /**
     * Runs one playout from {@code context} using this policy's effective weights.
     *
     * @param mcts calling search (unused here)
     * @param context context to play out from
     * @return the trial returned by {@code Game.playout}
     */
    @Override
    public Trial runPlayout(final MCTS mcts, final Context context) {
        final WeightVector[] params = new WeightVector[linearFunctions.length];
        for (int i = 0; i < linearFunctions.length; i++) {
            params[i] = (linearFunctions[i] == null) ? null : linearFunctions[i].effectiveParams();
        }

        // epsilon < 1: softmax selector, wrapped epsilon-greedy when epsilon > 0.
        // Otherwise (including NaN) no selector is passed. Presumably Game.playout then falls back
        // to its default move choice; confirm there.
        return context.game().playout(
                context,
                null,
                1.0d,
                (epsilon < 1.0d)
                        ? (epsilon <= 0.0d
                                ? new FeaturesSoftmaxMoveSelector(featureSets, params, true)
                                : new EpsilonGreedyWrapper(new FeaturesSoftmaxMoveSelector(featureSets, params, true), epsilon))
                        : null,
                playoutActionLimit,
                playoutTurnLimit,
                ThreadLocalRandom.current());
    }

    /** Playouts are supported exactly when the policy itself supports the game. */
    @Override
    public boolean playoutSupportsGame(final Game game) {
        return supportsGame(game);
    }

    /** This policy requests no extra backpropagation flags. */
    @Override
    public int backpropFlags() {
        return 0;
    }

    /**
     * Configures the policy from {@code key=value} options; {@code inputs[0]} is skipped.
     *
     * <p>Recognised keys (case-insensitive): {@code policyweights=} (shared weights file, discards
     * earlier per-player entries), {@code policyweightsN=} for N in 1..16, {@code playoutactionlimit=},
     * {@code playoutturnlimit=}, {@code friendly_name=}, {@code boosted=} (ends with "true" to
     * enable), and {@code epsilon=}. A weights path that does not exist is resolved to the latest
     * {@code PolicyWeights{Selection,Playout,TSPG}_P<index>} txt file in the same directory, based
     * on which keyword the path contains.
     *
     * @param inputs option strings
     * @throws IllegalArgumentException if a missing weights path contains none of the keywords
     */
    @Override
    public void customise(final String[] inputs) {
        final List<String> weightsPaths = new ArrayList<>();
        boolean boosted = false;

        for (int i = 1; i < inputs.length; i++) {
            final String input = inputs[i];
            // Locale.ROOT: default-locale lowercasing breaks 'I' under e.g. Turkish locales.
            final String lower = input.toLowerCase(Locale.ROOT);

            if (lower.startsWith("policyweights=")) {
                weightsPaths.clear();
                weightsPaths.add(input.substring("policyweights=".length()));
            } else if (lower.startsWith("policyweights")) {
                for (int p = 1; p <= MAX_PLAYER_INDEX; p++) {
                    final String prefix = "policyweights" + p + "=";
                    if (lower.startsWith(prefix)) {
                        while (weightsPaths.size() <= p) {
                            weightsPaths.add(null);
                        }
                        weightsPaths.set(p, input.substring(prefix.length()));
                        break; // prefixes end in '=', so at most one index can match
                    }
                }
            } else if (lower.startsWith("playoutactionlimit=")) {
                playoutActionLimit = Integer.parseInt(input.substring("playoutactionlimit=".length()));
            } else if (lower.startsWith("playoutturnlimit=")) {
                playoutTurnLimit = Integer.parseInt(input.substring("playoutturnlimit=".length()));
            } else if (lower.startsWith("friendly_name=")) {
                friendlyName = input.substring("friendly_name=".length());
            } else if (lower.startsWith("boosted=")) {
                if (lower.endsWith("true")) {
                    boosted = true;
                }
            } else if (lower.startsWith("epsilon=")) {
                epsilon = Double.parseDouble(input.substring("epsilon=".length()));
            }
        }

        if (weightsPaths.isEmpty()) {
            System.err.println("Cannot construct linear Softmax Policy from: " + Arrays.toString(inputs));
            return;
        }

        linearFunctions = new LinearFunction[weightsPaths.size()];
        featureSets = new BaseFeatureSet[linearFunctions.length];

        for (int p = 0; p < weightsPaths.size(); p++) {
            final String requested = weightsPaths.get(p);
            if (requested == null) {
                continue;
            }

            final String parentDir = new File(requested).getParent();
            final String weightsPath = resolveWeightsPath(requested, parentDir, p);

            if (boosted) {
                linearFunctions[p] = BoostedLinearFunction.boostedFromFile(weightsPath, null);
            } else {
                linearFunctions[p] = LinearFunction.fromFile(weightsPath);
            }
            featureSets[p] = JITSPatterNetFeatureSet.construct(parentDir + File.separator + linearFunctions[p].featureSetFile());
        }
    }

    /**
     * Returns {@code requested} if it exists, else the latest matching experiment weights file.
     *
     * @throws IllegalArgumentException if the path is missing and names no known weights kind
     */
    private static String resolveWeightsPath(final String requested, final String parentDir, final int playerIdx) {
        if (new File(requested).exists()) {
            return requested;
        }

        final String baseName;
        if (requested.contains("Selection")) {
            baseName = "/PolicyWeightsSelection_P";
        } else if (requested.contains("Playout")) {
            baseName = "/PolicyWeightsPlayout_P";
        } else if (requested.contains("TSPG")) {
            baseName = "/PolicyWeightsTSPG_P";
        } else {
            throw new IllegalArgumentException("Policy weights file does not exist: " + requested);
        }
        return ExperimentFileUtils.getLastFilepath(parentDir + baseName + playerIdx, "txt");
    }

    /**
     * Samples a legal move for the current mover from the softmax distribution.
     *
     * @param game game providing legal moves
     * @param context current context
     * @param maxSeconds unused here
     * @param maxIterations unused here
     * @param maxDepth unused here
     * @return the sampled legal move
     */
    @Override
    public Move selectAction(final Game game, final Context context, final double maxSeconds, final int maxIterations, final int maxDepth) {
        final Moves legalMoves = game.moves(context);
        final int mover = context.state().mover();
        final FVector distribution =
                computeDistribution(featureSetFor(mover).computeFeatureVectors(context, legalMoves.moves(), true), mover);
        return legalMoves.moves().get(selectActionFromDistribution(distribution));
    }

    /**
     * Initialises feature sets for {@code game}. A single shared set is initialised for players
     * 1..count; otherwise set {@code p} (p >= 1) is initialised for player {@code p} alone.
     *
     * @param game game being played
     * @param playerID unused here
     */
    @Override
    public void initAI(final Game game, final int playerID) {
        if (featureSets.length == 1) {
            final int numPlayers = game.players().count();
            final int[] players = new int[numPlayers];
            for (int p = 0; p < numPlayers; p++) {
                players[p] = p + 1;
            }
            featureSets[0].init(game, players, linearFunctions[0].effectiveParams());
            return;
        }

        for (int p = 1; p < featureSets.length; p++) {
            featureSets[p].init(game, new int[] {p}, linearFunctions[p].effectiveParams());
        }
    }

    /** Closes the caches of the feature sets that {@code initAI} initialises; no-op if unconfigured. */
    @Override
    public void closeAI() {
        if (featureSets == null) {
            return;
        }
        if (featureSets.length == 1) {
            featureSets[0].closeCache();
            return;
        }
        for (int p = 1; p < featureSets.length; p++) {
            featureSets[p].closeCache();
        }
    }

    /**
     * @param player player index
     * @return the linear function used for {@code player} (index 0 when shared)
     */
    public LinearFunction linearFunction(final int player) {
        return functionFor(player);
    }

    /** @return the internal linear-function array (not a copy) */
    public LinearFunction[] linearFunctions() {
        return linearFunctions;
    }

    /** @return the internal feature-set array (not a copy) */
    public BaseFeatureSet[] featureSets() {
        return featureSets;
    }

    /**
     * Creates and configures a policy from option lines. If any line equals
     * {@code features=from_metadata} (ignoring case), a {@code SoftmaxFromMetadataSelection} is
     * created instead of a plain linear policy.
     *
     * @param lines options passed to {@code customise}
     * @return the configured policy
     */
    public static SoftmaxPolicyLinear fromLines(final String[] lines) {
        SoftmaxPolicyLinear policy = null;
        for (final String line : lines) {
            if (line.equalsIgnoreCase("features=from_metadata")) {
                policy = new SoftmaxFromMetadataSelection(0.0d);
                break;
            }
        }
        if (policy == null) {
            policy = new SoftmaxPolicyLinear();
        }
        policy.customise(lines);
        return policy;
    }

    /**
     * Creates a policy from a weights file. The file is treated as boosted unless its last line
     * starts with {@code FeatureSet=}. An empty or unreadable file is treated as non-boosted;
     * read errors are printed and loading still proceeds via {@code customise}.
     *
     * @param file weights file
     * @return the configured policy
     */
    public static SoftmaxPolicyLinear fromFile(final File file) {
        final SoftmaxPolicyLinear policy = new SoftmaxPolicyLinear();
        boolean boosted = false;

        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(file.getAbsolutePath()), "UTF-8"))) {
            String lastLine = null;
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                lastLine = line;
            }
            boosted = lastLine != null && !lastLine.startsWith("FeatureSet=");
        } catch (final IOException e) {
            e.printStackTrace();
        }

        policy.customise(new String[] {
                "softmax",
                "policyweights=" + file.getAbsolutePath(),
                "boosted=" + boosted
        });
        return policy;
    }

    /**
     * Parses feature strings into a feature set plus a linear function and stores both at
     * {@code playerIdx} in the output lists, padding them with {@code null} as needed.
     *
     * @param playerIdx index to write into both output lists
     * @param featureStrings feature descriptions, parsed with {@code Feature.fromString}
     * @param featureWeights weights aligned with {@code featureStrings}
     * @param outFeatureSets receives the constructed feature set
     * @param outLinearFunctions receives the constructed linear function
     * @throws IllegalArgumentException if there are fewer weights than feature strings
     */
    public static void addFeatureSetWeights(
            final int playerIdx,
            final String[] featureStrings,
            final float[] featureWeights,
            final List<BaseFeatureSet> outFeatureSets,
            final List<LinearFunction> outLinearFunctions) {
        if (featureWeights.length < featureStrings.length) {
            throw new IllegalArgumentException(
                    "Expected at least " + featureStrings.length + " feature weights, got " + featureWeights.length);
        }

        while (outFeatureSets.size() <= playerIdx) {
            outFeatureSets.add(null);
        }
        while (outLinearFunctions.size() <= playerIdx) {
            outLinearFunctions.add(null);
        }

        final List<AspatialFeature> aspatialFeatures = new ArrayList<>();
        final List<SpatialFeature> spatialFeatures = new ArrayList<>();
        final TFloatArrayList aspatialWeights = new TFloatArrayList();
        final TFloatArrayList spatialWeights = new TFloatArrayList();

        for (int i = 0; i < featureStrings.length; i++) {
            final Feature feature = Feature.fromString(featureStrings[i]);
            if (feature instanceof AspatialFeature) {
                aspatialFeatures.add((AspatialFeature) feature);
                aspatialWeights.add(featureWeights[i]);
            } else {
                spatialFeatures.add((SpatialFeature) feature);
                spatialWeights.add(featureWeights[i]);
            }
        }

        // Lay weights out as [aspatial..., spatial...], the layout computeParamGradients uses.
        // Keeping raw input order would misalign weights whenever the two kinds are interleaved.
        // NOTE(review): assumes the feature set indexes each kind in list order; confirm in
        // JITSPatterNetFeatureSet.construct.
        aspatialWeights.addAll(spatialWeights);

        outFeatureSets.set(playerIdx, JITSPatterNetFeatureSet.construct(aspatialFeatures, spatialFeatures));
        outLinearFunctions.set(playerIdx, new LinearFunction(new WeightVector(new FVector(aspatialWeights.toArray()))));
    }
}
