package supplementary.experiments.scripts;

import game.Game;
import gnu.trove.list.array.TIntArrayList;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.UncheckedIOException;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.regex.Pattern;
import main.CommandLineArgParse;
import main.StringRoutines;
import main.UnixPrintWriter;
import main.collections.ArrayUtils;
import main.collections.ListUtils;
import org.apache.batik.constants.XMLConstants;
import org.apache.batik.svggen.SVGSyntax;
import org.apache.batik.util.SVGConstants;
import other.GameLoader;
import supplementary.experiments.analysis.RulesetConceptsUCT;
import utils.RulesetNames;

/**
 * Generates SLURM job scripts (Snellius cluster) that evaluate MCTS agents whose playout and
 * selection policies were trained with different feature-learning {@link #VARIANTS}.
 * <p>
 * For every game in {@link #GAMES}, each matchup assigns one variant per player (combinations with
 * replacement of {@link #VARIANTS}); every matchup becomes one evaluation process. Up to
 * {@link #PROCESSES_PER_JOB} processes are packed into a single job script, each pinned to its own
 * {@link #CORES_PER_PROCESS} cores with {@code taskset}. Finally, {@code SubmitJobs_Part<i>.sh}
 * scripts are written that {@code sbatch} the job scripts, at most {@link #MAX_JOBS_PER_BATCH} per
 * submit script.
 */
public class EvalTrainedFeaturesSnellius2 {
    /** Maximum number of {@code sbatch} lines per generated submit script. */
    private static final int MAX_JOBS_PER_BATCH = 800;

    /** Heap size in MB passed to every JVM via {@code -Xms} / {@code -Xmx}. */
    private static final String JVM_MEM = "3072";

    /** Memory in GB requested per evaluation process for jobs below the exclusive threshold. */
    private static final int MEM_PER_PROCESS = 4;

    /** Memory in GB of a full node. */
    private static final int MEM_PER_NODE = 256;

    /** Upper bound in GB on any memory request written to a job script. */
    private static final int MAX_REQUEST_MEM = 234;

    /** Number of trials played by each evaluation process. */
    private static final int NUM_TRIALS = 100;

    /** Wall time in minutes; used for {@code #SBATCH -t} and for Ludii's {@code --max-wall-time}. */
    private static final int MAX_WALL_TIME = 2880;

    /** Number of cores on one node. */
    private static final int CORES_PER_NODE = 128;

    /** Number of cores (and MCTS threads) given to each evaluation process. */
    private static final int CORES_PER_PROCESS = 2;

    /** Currently not referenced in this class. */
    private static final int EXCLUSIVE_CORES_THRESHOLD = 96;

    /** Jobs with more processes than this request the memory of a full node. */
    private static final int EXCLUSIVE_PROCESSES_THRESHOLD = 48;

    /** Maximum number of evaluation processes packed into one job script. */
    private static final int PROCESSES_PER_JOB = 64;

    /** Games to evaluate on (names as understood by {@link GameLoader#loadGameFromName}). */
    private static final String[] GAMES = {"Alquerque.lud", "Amazons.lud", "ArdRi.lud", "Arimaa.lud", "Ataxx.lud", "Bao Ki Arabu (Zanzibar 1).lud", "Bizingo.lud", "Breakthrough.lud", "Chess.lud", "English Draughts.lud", "Fanorona.lud", "Fox and Geese.lud", "Go.lud", "Gomoku.lud", "Gonnect.lud", "Havannah.lud", "Hex.lud", "Knightthrough.lud", "Konane.lud", "Lines of Action.lud", "Pentalath.lud", "Pretwa.lud", "Reversi.lud", "Royal Game of Ur.lud", "Surakarta.lud", "Shobu.lud", "Tablut.lud", "XII Scripta.lud", "Yavalath.lud"};

    /**
     * Training variants; each one is the suffix of a training output directory
     * {@code TrainFeaturesSnellius/Out/<cleanGameName>_<variant>} holding the learned weights.
     */
    private static final String[] VARIANTS = {"Baseline", "ReinforceGamma1", "ReinforceGamma099", "ReinforceGamma09", "SpecialMovesExpander", "SpecialMovesExpanderSplit", "SignCorrelationExpander", "RandomExpander"};

    /** Description of a single evaluation process: one game and one variant per player. */
    public static class ProcessData {
        /** Game file name, e.g. {@code "Chess.lud"}. */
        public final String gameName;
        /** Number of players in the game (= length of {@link #matchup}). */
        public final int numPlayers;
        /** Variant name (a {@code String}) for each player. */
        public final Object[] matchup;

        public ProcessData(String str, int i, Object[] objArr) {
            this.gameName = str;
            this.numPlayers = i;
            this.matchup = objArr;
        }
    }

    /** Not instantiable; use {@link #main(String[])}. */
    private EvalTrainedFeaturesSnellius2() {
    }

    /**
     * Loads all games, builds the list of evaluation processes and writes the job and submit
     * scripts into {@code --scripts-dir}.
     *
     * @param argParse parsed command line ({@code --user-name}, {@code --scripts-dir})
     * @throws IllegalArgumentException if a game in {@link #GAMES} cannot be loaded
     * @throws UncheckedIOException if a job script cannot be created
     */
    private static void generateScripts(final CommandLineArgParse argParse) {
        String scriptsDir = argParse.getValueString("--scripts-dir").replace("\\", "/");
        if (!scriptsDir.endsWith("/")) {
            scriptsDir += "/";
        }
        final String userName = argParse.getValueString("--user-name");

        final Game[] games = new Game[GAMES.length];
        final double[] expectedDurations = new double[GAMES.length];
        for (int i = 0; i < GAMES.length; ++i) {
            final Game loadedGame = GameLoader.loadGameFromName(GAMES[i]);
            if (loadedGame == null) {
                throw new IllegalArgumentException("Cannot load game: " + GAMES[i]);
            }
            games[i] = loadedGame;
            expectedDurations[i] = RulesetConceptsUCT.getValue(RulesetNames.gameRulesetName(loadedGame), "DurationMoves");
            System.out.println("expected duration per trial for " + GAMES[i] + " = " + expectedDurations[i]);
        }

        // Games with the longest expected duration come first. Double.compare yields a consistent
        // total order even for NaN (a subtraction-based comparator returns 0 there, which breaks
        // the Comparator contract).
        final List<Integer> gameOrder = ArrayUtils.sortedIndices(
                GAMES.length, (a, b) -> Double.compare(expectedDurations[b], expectedDurations[a]));

        final List<ProcessData> processDataList = buildProcessList(games, gameOrder);

        long totalRequestedCoreHours = 0L;
        final List<String> jobScriptNames = new ArrayList<>();
        int processIdx = 0;
        while (processIdx < processDataList.size()) {
            final String jobScriptFilename = "EvalFeatures_" + jobScriptNames.size() + ".sh";
            final int numProcesses = Math.min(processDataList.size() - processIdx, PROCESSES_PER_JOB);

            // Throws instead of swallowing: swallowing here would never advance processIdx and
            // the loop would spin forever.
            writeJobScript(
                    scriptsDir + jobScriptFilename,
                    userName,
                    processDataList.subList(processIdx, processIdx + numProcesses));

            jobScriptNames.add(jobScriptFilename);
            // Every job requests the whole node, so it is charged for all of its cores.
            totalRequestedCoreHours += CORES_PER_NODE * (MAX_WALL_TIME / 60);
            processIdx += numProcesses;
        }

        System.out.println("Total requested core hours = " + totalRequestedCoreHours);

        writeSubmitScripts(scriptsDir, jobScriptNames);
    }

    /**
     * Builds one {@link ProcessData} per (game, matchup) pair, visiting games in {@code gameOrder}.
     * Games with the same player count share the same (possibly subsampled) set of matchups.
     */
    private static List<ProcessData> buildProcessList(final Game[] games, final List<Integer> gameOrder) {
        // Cap on matchups per game: the number of matchups a 3-player game would have.
        final int maxMatchupsPerGame = ListUtils.numCombinationsWithReplacement(VARIANTS.length, 3);

        // Index = number of players; lazily filled.
        final List<Object[][]> matchupsPerNumPlayers = new ArrayList<>();
        final List<ProcessData> processDataList = new ArrayList<>();

        for (final int gameIdx : gameOrder) {
            final int numPlayers = games[gameIdx].players().count();
            while (matchupsPerNumPlayers.size() <= numPlayers) {
                matchupsPerNumPlayers.add(null);
            }
            if (matchupsPerNumPlayers.get(numPlayers) == null) {
                matchupsPerNumPlayers.set(numPlayers, sampleMatchups(numPlayers, maxMatchupsPerGame));
            }
            for (final Object[] matchup : matchupsPerNumPlayers.get(numPlayers)) {
                processDataList.add(new ProcessData(GAMES[gameIdx], numPlayers, matchup));
            }
        }

        return processDataList;
    }

    /**
     * Returns all combinations with replacement of {@link #VARIANTS} of length {@code numPlayers},
     * or a uniformly random subset of {@code maxMatchups} of them if there are more.
     * Sampling uses {@link ThreadLocalRandom}, so it is not reproducible across runs.
     */
    private static Object[][] sampleMatchups(final int numPlayers, final int maxMatchups) {
        final Object[][] allMatchups = ListUtils.generateCombinationsWithReplacement(VARIANTS, numPlayers);
        if (allMatchups.length <= maxMatchups) {
            return allMatchups;
        }

        final TIntArrayList keptIndices = new TIntArrayList(allMatchups.length);
        for (int i = 0; i < allMatchups.length; ++i) {
            keptIndices.add(i);
        }
        while (keptIndices.size() > maxMatchups) {
            ListUtils.removeSwap(keptIndices, ThreadLocalRandom.current().nextInt(keptIndices.size()));
        }

        // Rows are taken from allMatchups, so no per-row allocation is needed.
        final Object[][] sampled = new Object[maxMatchups][];
        for (int i = 0; i < sampled.length; ++i) {
            sampled[i] = allMatchups[keptIndices.getQuick(i)];
        }
        return sampled;
    }

    /**
     * Writes a single SLURM job script running all given processes in parallel on one node.
     *
     * @throws UncheckedIOException if the script file cannot be created
     */
    private static void writeJobScript(final String filepath, final String userName, final List<ProcessData> processes) {
        try (final UnixPrintWriter writer = new UnixPrintWriter(new File(filepath), "UTF-8")) {
            final int numProcesses = processes.size();

            writer.println("#!/bin/bash");
            writer.println("#SBATCH -J EvalFeatures");
            writer.println("#SBATCH -p thin");
            writer.println("#SBATCH -o /home/" + userName + "/EvalFeaturesSnellius/Out/Out_%J.out");
            writer.println("#SBATCH -e /home/" + userName + "/EvalFeaturesSnellius/Out/Err_%J.err");
            writer.println("#SBATCH -t " + MAX_WALL_TIME);
            writer.println("#SBATCH -N 1");

            final boolean fillsNode = numProcesses > EXCLUSIVE_PROCESSES_THRESHOLD;
            final int memRequestGB = fillsNode
                    ? Math.min(MEM_PER_NODE, MAX_REQUEST_MEM)
                    : Math.min(numProcesses * MEM_PER_PROCESS, MAX_REQUEST_MEM);
            writer.println("#SBATCH --cpus-per-task=" + (numProcesses * CORES_PER_PROCESS));
            writer.println("#SBATCH --mem=" + memRequestGB + "G");
            // Always exclusive (the former exclusive/shared branch printed this line either way);
            // the taskset pinning below uses fixed core indices starting at 0.
            writer.println("#SBATCH --exclusive");

            writer.println("module load 2021");
            writer.println("module load Java/11.0.2");

            for (int slot = 0; slot < numProcesses; ++slot) {
                writer.println(evalCommand(processes.get(slot), slot, userName));
            }

            // Wait for all background processes started with '&'.
            writer.println("wait");
        } catch (final FileNotFoundException | UnsupportedEncodingException e) {
            throw new UncheckedIOException("Could not create job script: " + filepath, e);
        }
    }

    /**
     * Builds the shell line that runs one evaluation process in the background, pinned to cores
     * {@code slot * CORES_PER_PROCESS .. (slot + 1) * CORES_PER_PROCESS - 1}.
     */
    private static String evalCommand(final ProcessData processData, final int slot, final String userName) {
        final String cleanGameName =
                StringRoutines.cleanGameName(processData.gameName.replaceAll(Pattern.quote(".lud"), ""));

        final List<String> agentStrings = new ArrayList<>();
        for (final Object variant : processData.matchup) {
            agentStrings.add(StringRoutines.quote(
                    agentString(cleanGameName, (String) variant, processData.numPlayers, userName)));
        }

        final List<String> cores = new ArrayList<>(CORES_PER_PROCESS);
        for (int c = 0; c < CORES_PER_PROCESS; ++c) {
            cores.add(String.valueOf(slot * CORES_PER_PROCESS + c));
        }

        return StringRoutines.join(
                " ",
                "taskset", "-c", StringRoutines.join(",", cores),
                "java",
                "-Xms" + JVM_MEM + "M",
                "-Xmx" + JVM_MEM + "M",
                "-XX:+HeapDumpOnOutOfMemoryError",
                "-da",
                "-dsa",
                "-XX:+UseStringDeduplication",
                "-jar", StringRoutines.quote("/home/" + userName + "/EvalFeaturesSnellius/Ludii.jar"),
                "--eval-agents",
                "--game", StringRoutines.quote("/" + processData.gameName),
                "-n " + NUM_TRIALS,
                "--thinking-time 1",
                "--agents", StringRoutines.join(" ", agentStrings),
                "--out-dir", StringRoutines.quote(
                        "/home/" + userName + "/EvalFeaturesSnellius/Out/" + cleanGameName + "/"
                        + StringRoutines.join("_", processData.matchup)),
                "--output-summary",
                "--output-alpha-rank-data",
                "--max-wall-time", String.valueOf(MAX_WALL_TIME),
                ">", "/home/" + userName + "/EvalFeaturesSnellius/Out/Out_${SLURM_JOB_ID}_" + slot + ".out",
                "&");
    }

    /**
     * Builds the (unquoted) agent description for one variant: MCTS using the trained playout and
     * selection weights of every player, read from that variant's training output directory.
     */
    private static String agentString(final String cleanGameName, final String variant, final int numPlayers, final String userName) {
        final String weightsDir = "/home/" + userName + "/TrainFeaturesSnellius/Out/" + cleanGameName + "_" + variant;

        final List<String> playoutParts = new ArrayList<>();
        playoutParts.add("playout=softmax");
        final List<String> selectionParts = new ArrayList<>();
        selectionParts.add("learned_selection_policy=softmax");
        for (int p = 1; p <= numPlayers; ++p) {
            playoutParts.add("policyweights" + p + "=" + weightsDir + "/PolicyWeightsPlayout_P" + p + "_00201.txt");
            selectionParts.add("policyweights" + p + "=" + weightsDir + "/PolicyWeightsSelection_P" + p + "_00201.txt");
        }

        return StringRoutines.join(
                ";",
                "algorithm=MCTS",
                "selection=noisyag0selection",
                StringRoutines.join(",", playoutParts),
                "tree_reuse=true",
                "use_score_bounds=true",
                // One MCTS thread per pinned core.
                "num_threads=" + CORES_PER_PROCESS,
                "final_move=robustchild",
                StringRoutines.join(",", selectionParts),
                "friendly_name=" + variant);
    }

    /**
     * Writes {@code SubmitJobs_Part<i>.sh} scripts, each submitting at most
     * {@link #MAX_JOBS_PER_BATCH} job scripts. Failure to write one part is reported and the
     * remaining parts are still attempted.
     */
    private static void writeSubmitScripts(final String scriptsDir, final List<String> jobScriptNames) {
        int part = 0;
        for (int start = 0; start < jobScriptNames.size(); start += MAX_JOBS_PER_BATCH) {
            final List<String> batch =
                    jobScriptNames.subList(start, Math.min(start + MAX_JOBS_PER_BATCH, jobScriptNames.size()));
            try (final UnixPrintWriter writer =
                    new UnixPrintWriter(new File(scriptsDir + "SubmitJobs_Part" + part + ".sh"), "UTF-8")) {
                for (final String jobScriptName : batch) {
                    writer.println("sbatch " + jobScriptName);
                }
            } catch (final FileNotFoundException | UnsupportedEncodingException e) {
                e.printStackTrace();
            }
            ++part;
        }
    }

    /**
     * Entry point.
     *
     * @param strArr command-line arguments; requires {@code --user-name} and {@code --scripts-dir}
     */
    public static void main(String[] strArr) {
        CommandLineArgParse commandLineArgParse = new CommandLineArgParse(true, "Creating eval job scripts.");
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--user-name").help("Username on the cluster.").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).setRequired());
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--scripts-dir").help("Directory in which to store generated scripts.").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).setRequired());
        if (commandLineArgParse.parseArguments(strArr)) {
            generateScripts(commandLineArgParse);
        }
    }
}
