package supplementary.experiments.feature_trees;

import decision_trees.logits.ExactLogitTreeLearner;
import features.feature_sets.BaseFeatureSet;
import function_approx.LinearFunction;
import game.types.play.RoleType;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.List;
import main.CommandLineArgParse;
import main.StringRoutines;
import metadata.ai.features.trees.FeatureTrees;
import metadata.ai.features.trees.logits.LogitTree;
import org.apache.batik.constants.XMLConstants;
import policies.softmax.SoftmaxPolicyLinear;
import search.mcts.MCTS;
import utils.AIFactory;

/* loaded from: input_file:supplementary/experiments/feature_trees/GenerateExactFeatureTree.class */
/**
 * Command-line tool that builds one logit tree per player from the linear
 * softmax playout policy of an MCTS agent, and writes the resulting
 * {@link FeatureTrees} object to a file.
 *
 * <p>Instances are created and configured only by {@link #main(String[])}.
 */
public class GenerateExactFeatureTree {
    /** Value of {@code --tree-type} selecting {@code ExactLogitTreeLearner.buildTree}. */
    private static final String TREE_TYPE_EXACT = "Exact";

    /** Value of {@code --tree-type} selecting {@code ExactLogitTreeLearner.buildTreeNaiveMaxAbs}. */
    private static final String TREE_TYPE_NAIVE_MAX_ABS = "NaiveMaxAbs";

    /** Filepaths for trained feature weights ({@code --feature-weights-filepaths}). */
    protected List<String> featureWeightsFilepaths;

    /** File to write the generated trees to ({@code --out-file}). */
    protected File outFile;

    /** Whether the policy weight files are expected to be boosted ({@code --boosted}). */
    protected boolean boosted;

    /** Type of tree to build ({@code --tree-type}): "Exact" or "NaiveMaxAbs". */
    protected String treeType;

    /** Not publicly instantiable; use {@link #main(String[])}. */
    private GenerateExactFeatureTree() {
    }

    /**
     * Builds the trees and writes them to {@link #outFile}.
     *
     * <p>An {@link IOException} raised while opening the output file is
     * reported on standard error and otherwise ignored (best effort).
     *
     * @throws IllegalArgumentException if {@link #treeType} is neither
     *         "Exact" nor "NaiveMaxAbs"
     */
    public void run() {
        // Validate before doing any work, so that an unsupported type fails with
        // a clear message rather than a NullPointerException inside the loop.
        final boolean exact = TREE_TYPE_EXACT.equals(this.treeType);
        if (!exact && !TREE_TYPE_NAIVE_MAX_ABS.equals(this.treeType)) {
            throw new IllegalArgumentException("Unsupported tree type: " + this.treeType);
        }

        // Playout part of the agent description: one "policyweights<i>" entry
        // per weights file, numbered from 1.
        final StringBuilder playout = new StringBuilder("playout=softmax");
        for (int p = 1; p <= this.featureWeightsFilepaths.size(); p++) {
            playout.append(",policyweights")
                    .append(p)
                    .append('=')
                    .append(this.featureWeightsFilepaths.get(p - 1));
        }
        if (this.boosted) {
            playout.append(",boosted=true");
        }

        final String agentDescription = StringRoutines.join(
                ";",
                "algorithm=MCTS",
                "selection=noisyag0selection",
                playout.toString(),
                "final_move=robustchild",
                "tree_reuse=true",
                "learned_selection_policy=playout",
                "friendly_name=BiasedMCTS");

        final MCTS mcts = (MCTS) AIFactory.createAI(agentDescription);
        final SoftmaxPolicyLinear policy = (SoftmaxPolicyLinear) mcts.playoutStrategy();
        final BaseFeatureSet[] featureSets = policy.featureSets();
        final LinearFunction[] linearFunctions = policy.linearFunctions();

        // Index 0 of the per-player arrays is skipped; entry p is stored at p - 1.
        final LogitTree[] logitTrees = new LogitTree[featureSets.length - 1];
        for (int p = 1; p < featureSets.length; p++) {
            // NOTE(review): the literal 10 is passed through unchanged as the
            // learner's third argument; its meaning is not visible from this file.
            logitTrees[p - 1] = new LogitTree(
                    RoleType.roleForPlayerId(p),
                    (exact
                            ? ExactLogitTreeLearner.buildTree(featureSets[p], linearFunctions[p], 10)
                            : ExactLogitTreeLearner.buildTreeNaiveMaxAbs(featureSets[p], linearFunctions[p], 10))
                            .toMetadataNode());
        }

        try (PrintWriter writer = new PrintWriter(this.outFile)) {
            writer.println(new FeatureTrees(logitTrees, null));
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Parses the command line and, if parsing succeeds, runs the generator.
     *
     * @param strArr command-line arguments
     */
    public static void main(String[] strArr) {
        final CommandLineArgParse argParse = new CommandLineArgParse(true, "Write features to a file.");

        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--feature-weights-filepaths")
                .help("Filepaths for trained feature weights.")
                .withNumVals("+")
                .withType(CommandLineArgParse.OptionTypes.String));
        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--out-file")
                .help("Filepath to write to.")
                .withNumVals(1)
                .withType(CommandLineArgParse.OptionTypes.String)
                .setRequired());
        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--boosted")
                .help("Indicates that the policy weight files are expected to be boosted.")
                .withType(CommandLineArgParse.OptionTypes.Boolean));
        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--tree-type")
                .help("Type of tree to build.")
                .withNumVals(1)
                .withType(CommandLineArgParse.OptionTypes.String)
                .withDefault(TREE_TYPE_EXACT)
                .withLegalVals(TREE_TYPE_EXACT, TREE_TYPE_NAIVE_MAX_ABS));

        if (!argParse.parseArguments(strArr)) {
            return;
        }

        // The option is declared above with OptionTypes.String, so the parsed
        // value is expected to be a list of strings.
        @SuppressWarnings("unchecked")
        final List<String> weightsFilepaths = (List<String>) argParse.getValue("--feature-weights-filepaths");

        final GenerateExactFeatureTree generator = new GenerateExactFeatureTree();
        generator.featureWeightsFilepaths = weightsFilepaths;
        generator.outFile = new File(argParse.getValueString("--out-file"));
        generator.boosted = argParse.getValueBool("--boosted");
        generator.treeType = argParse.getValueString("--tree-type");
        generator.run();
    }
}
