package supplementary.experiments.feature_trees.small_games;

import decision_trees.classifiers.ExperienceImbalancedBinaryClassificationTree2Learner;
import features.feature_sets.BaseFeatureSet;
import function_approx.LinearFunction;
import game.Game;
import game.types.play.RoleType;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.regex.Pattern;
import main.StringRoutines;
import metadata.ai.features.trees.FeatureTrees;
import metadata.ai.features.trees.classifiers.DecisionTree;
import org.apache.batik.constants.XMLConstants;
import other.GameLoader;
import policies.softmax.SoftmaxPolicyLinear;
import search.mcts.MCTS;
import utils.AIFactory;
import utils.ExperimentFileUtils;
import utils.data_structures.experience_buffers.ExperienceBuffer;
import utils.data_structures.experience_buffers.PrioritizedReplayBuffer;
import utils.data_structures.experience_buffers.UniformExperienceBuffer;

/* loaded from: input_file:supplementary/experiments/feature_trees/small_games/TrainImbalancedBinaryDecisionTrees2SmallGames.class */
/**
 * Script-style entry point that, for each configured game/ruleset pair, builds
 * "imbalanced binary classification (2)" decision trees from previously stored
 * policy weights and experience buffers, and writes them out as
 * {@link FeatureTrees} metadata text files.
 *
 * <p>For every game in {@link #GAMES}, every policy weight type in
 * {@link #POLICY_WEIGHT_TYPES} and every depth in {@link #TREE_DEPTHS}, one
 * output file is written under {@code RESULTS_DIR + "Trees/"}.
 */
public class TrainImbalancedBinaryDecisionTrees2SmallGames {
    /** Root directory that holds the training outputs read by this script, and under which trees are written. */
    private static final String RESULTS_DIR = "D:/Downloads/results.tar/results/Out/";

    /** Game files to process; index-aligned with {@link #RULESETS}. */
    private static final String[] GAMES = {"Tic-Tac-Toe.lud", "Mu Torere.lud", "Mu Torere.lud", "Jeu Militaire.lud", "Pong Hau K'i.lud", "Akidada.lud", "Alquerque de Tres.lud", "Ho-Bag Gonu.lud", "Madelinette.lud", "Haretavl.lud", "Kaooa.lud", "Hat Diviyan Keliya.lud", "Three Men's Morris.lud"};

    /** Ruleset name per game (empty string where none is given); index-aligned with {@link #GAMES}. */
    private static final String[] RULESETS = {"", "Ruleset/Complete (Observed)", "Ruleset/Simple (Suggested)", "", "", "", "", "", "", "", "", "", ""};

    /** Suffixes used to locate the policy weight files; index-aligned with {@link #BOOSTED}. */
    private static final String[] POLICY_WEIGHT_TYPES = {"Playout", "TSPG"};

    /** Whether {@code boosted=true} is appended to the playout string for the policy weight type at the same index. */
    private static final boolean[] BOOSTED = {false, true};

    /** Depth values passed to the tree learner; one tree file is written per depth. */
    private static final int[] TREE_DEPTHS = {1, 2, 3, 4, 5, 10};

    /** Instances are only created from {@link #main(String[])}. */
    private TrainImbalancedBinaryDecisionTrees2SmallGames() {
    }

    /**
     * Runs the full sweep over games, policy weight types and tree depths.
     *
     * @throws IllegalArgumentException if a configured game cannot be loaded
     */
    public void run() {
        for (int gameIdx = 0; gameIdx < GAMES.length; gameIdx++) {
            final Game game = GameLoader.loadGameFromName(GAMES[gameIdx], RULESETS[gameIdx]);
            if (game == null) {
                throw new IllegalArgumentException("Cannot load game: " + GAMES[gameIdx] + " " + RULESETS[gameIdx]);
            }

            // Directory name shared by all inputs and outputs of this game/ruleset pair.
            final String cleanGameName = StringRoutines.cleanGameName(GAMES[gameIdx].replace(".lud", ""));
            final String cleanRulesetName = StringRoutines.cleanRulesetName(RULESETS[gameIdx]).replace("/", "_");
            final String gameDirName = cleanGameName + "_" + cleanRulesetName;

            for (int typeIdx = 0; typeIdx < POLICY_WEIGHT_TYPES.length; typeIdx++) {
                final String playoutString = buildPlayoutString(game, gameDirName, typeIdx);

                // NOTE(review): XMLConstants.XML_CHAR_REF_SUFFIX is used as the join separator here;
                // this looks like a decompiler substitution for a plain string literal - confirm.
                final String agentString = StringRoutines.join(
                        XMLConstants.XML_CHAR_REF_SUFFIX,
                        "algorithm=MCTS",
                        "selection=noisyag0selection",
                        playoutString,
                        "final_move=robustchild",
                        "tree_reuse=true",
                        "learned_selection_policy=playout",
                        "friendly_name=BiasedMCTS");

                final MCTS mcts = (MCTS) AIFactory.createAI(agentString);
                final SoftmaxPolicyLinear playoutPolicy = (SoftmaxPolicyLinear) mcts.playoutStrategy();
                final BaseFeatureSet[] featureSets = playoutPolicy.featureSets();
                final LinearFunction[] linearFunctions = playoutPolicy.linearFunctions();
                playoutPolicy.initAI(game, -1);

                for (final int depth : TREE_DEPTHS) {
                    // Index 0 of the per-player arrays is skipped; players are numbered from 1.
                    final DecisionTree[] trees = new DecisionTree[featureSets.length - 1];
                    for (int player = 1; player < featureSets.length; player++) {
                        final String bufferPath = ExperimentFileUtils.getLastFilepath(
                                RESULTS_DIR + gameDirName + "/ExperienceBuffer_P" + player, "buf");
                        final ExperienceBuffer buffer = loadExperienceBuffer(game, bufferPath);

                        // The trailing literal 5 is passed through to the learner unchanged;
                        // its meaning is defined by buildTree and is not visible in this file.
                        trees[player - 1] = new DecisionTree(
                                RoleType.roleForPlayerId(player),
                                ExperienceImbalancedBinaryClassificationTree2Learner
                                        .buildTree(featureSets[player], linearFunctions[player], buffer, depth, 5)
                                        .toMetadataNode());
                    }

                    final String outPath = RESULTS_DIR + "Trees/" + gameDirName
                            + "/ImbalancedBinaryClassificationTree2_" + POLICY_WEIGHT_TYPES[typeIdx] + "_" + depth + ".txt";
                    writeTrees(outPath, trees);
                }
            }
        }
    }

    /**
     * Builds the {@code playout=softmax,...} part of the agent string, with one
     * {@code policyweights<p>} entry per player pointing at the last matching
     * weights file, plus {@code boosted=true} when configured for this type.
     *
     * @param game        loaded game, used for its player count
     * @param gameDirName directory name of the game/ruleset pair under {@link #RESULTS_DIR}
     * @param typeIdx     index into {@link #POLICY_WEIGHT_TYPES} and {@link #BOOSTED}
     * @return the playout configuration string
     */
    private static String buildPlayoutString(final Game game, final String gameDirName, final int typeIdx) {
        final StringBuilder playout = new StringBuilder("playout=softmax");
        final int numPlayers = game.players().count();
        for (int player = 1; player <= numPlayers; player++) {
            final String weightsPath = ExperimentFileUtils.getLastFilepath(
                    RESULTS_DIR + gameDirName + "/PolicyWeights" + POLICY_WEIGHT_TYPES[typeIdx] + "_P" + player, "txt");
            // NOTE(review): XMLConstants.XML_EQUAL_SIGN looks like a decompiler substitution
            // for a plain string literal - confirm.
            playout.append(",policyweights").append(player).append(XMLConstants.XML_EQUAL_SIGN).append(weightsPath);
        }
        if (BOOSTED[typeIdx]) {
            playout.append(",boosted=true");
        }
        return playout.toString();
    }

    /**
     * Loads an experience buffer, first as a {@link PrioritizedReplayBuffer} and,
     * if that throws, as a {@link UniformExperienceBuffer}.
     *
     * @param game     game the buffer belongs to
     * @param filepath path of the buffer file
     * @return the loaded buffer, or {@code null} if both attempts failed (both
     *         stack traces are printed in that case)
     */
    private static ExperienceBuffer loadExperienceBuffer(final Game game, final String filepath) {
        try {
            return PrioritizedReplayBuffer.fromFile(game, filepath);
        } catch (final Exception prioritizedFailure) {
            try {
                return UniformExperienceBuffer.fromFile(game, filepath);
            } catch (final Exception uniformFailure) {
                // Best effort: report both failures and let the caller continue with null.
                prioritizedFailure.printStackTrace();
                uniformFailure.printStackTrace();
                return null;
            }
        }
    }

    /**
     * Writes the given trees, wrapped in a {@link FeatureTrees} object, to a text
     * file, creating parent directories first. An {@link IOException} while
     * opening the file is printed and otherwise ignored.
     *
     * @param outPath destination file path
     * @param trees   decision trees to write
     */
    private static void writeTrees(final String outPath, final DecisionTree[] trees) {
        System.out.println("Writing Imbalanced Binary Classification (2) tree to: " + outPath);
        new File(outPath).getParentFile().mkdirs();
        try (PrintWriter writer = new PrintWriter(outPath)) {
            writer.println(new FeatureTrees(null, trees));
        } catch (final IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Program entry point; runs the whole sweep.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        new TrainImbalancedBinaryDecisionTrees2SmallGames().run();
    }
}
