package supplementary.experiments.scripts;

import game.Game;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.UncheckedIOException;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;
import main.CommandLineArgParse;
import main.StringRoutines;
import main.UnixPrintWriter;
import main.collections.ArrayUtils;
import org.apache.batik.constants.XMLConstants;
import org.apache.batik.svggen.SVGSyntax;
import org.apache.batik.util.SVGConstants;
import other.GameLoader;
import supplementary.experiments.analysis.RulesetConceptsUCT;
import utils.RulesetNames;

/* loaded from: input_file:supplementary/experiments/scripts/ExItTrainingScriptsGenSnellius2.class */
public class ExItTrainingScriptsGenSnellius2 {
    private static final int MAX_JOBS_PER_BATCH = 800;
    private static final String JVM_MEM = "5120";
    private static final int MEM_PER_PROCESS = 6;
    private static final int MEM_PER_NODE = 256;
    private static final int MAX_REQUEST_MEM = 234;
    private static final int MAX_SELFPLAY_TRIALS = 200;
    private static final int MAX_WALL_TIME = 2880;
    private static final int CORES_PER_NODE = 128;
    private static final int CORES_PER_PROCESS = 3;
    private static final int EXCLUSIVE_CORES_THRESHOLD = 96;
    private static final int EXCLUSIVE_PROCESSES_THRESHOLD = 32;
    private static final int PROCESSES_PER_JOB = 42;
    private static final String[] GAMES = {"Alquerque.lud", "Amazons.lud", "ArdRi.lud", "Arimaa.lud", "Ataxx.lud", "Bao Ki Arabu (Zanzibar 1).lud", "Bizingo.lud", "Breakthrough.lud", "Chess.lud", "English Draughts.lud", "Fanorona.lud", "Fox and Geese.lud", "Go.lud", "Gomoku.lud", "Gonnect.lud", "Havannah.lud", "Hex.lud", "Knightthrough.lud", "Konane.lud", "Lines of Action.lud", "Pentalath.lud", "Pretwa.lud", "Reversi.lud", "Royal Game of Ur.lud", "Surakarta.lud", "Shobu.lud", "Tablut.lud", "XII Scripta.lud", "Yavalath.lud"};
    private static final String[] VARIANTS = {"Baseline", "ReinforceGamma1", "ReinforceGamma099", "ReinforceGamma09", "SpecialMovesExpander", "SpecialMovesExpanderSplit", "SignCorrelationExpander", "RandomExpander"};

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:supplementary/experiments/scripts/ExItTrainingScriptsGenSnellius2$ProcessData.class */
    public static class ProcessData {
        public final String gameName;
        public final int numPlayers;
        public final String trainingVariant;

        public ProcessData(String str, int i, String str2) {
            this.gameName = str;
            this.numPlayers = i;
            this.trainingVariant = str2;
        }
    }

    private ExItTrainingScriptsGenSnellius2() {
    }

    private static void generateScripts(CommandLineArgParse commandLineArgParse) {
        UnixPrintWriter unixPrintWriter;
        Throwable th;
        UnixPrintWriter unixPrintWriter2;
        Throwable th2;
        ArrayList arrayList = new ArrayList();
        String replaceAll = commandLineArgParse.getValueString("--scripts-dir").replaceAll(Pattern.quote("\\"), "/");
        if (!replaceAll.endsWith("/")) {
            replaceAll = replaceAll + "/";
        }
        String valueString = commandLineArgParse.getValueString("--user-name");
        Game[] gameArr = new Game[GAMES.length];
        final double[] dArr = new double[GAMES.length];
        for (int i = 0; i < gameArr.length; i++) {
            Game loadGameFromName = GameLoader.loadGameFromName(GAMES[i]);
            if (loadGameFromName == null) {
                throw new IllegalArgumentException("Cannot load game: " + GAMES[i]);
            }
            gameArr[i] = loadGameFromName;
            dArr[i] = RulesetConceptsUCT.getValue(RulesetNames.gameRulesetName(loadGameFromName), "DurationMoves");
            System.out.println("expected duration per trial for " + GAMES[i] + " = " + dArr[i]);
        }
        List<Integer> sortedIndices = ArrayUtils.sortedIndices(GAMES.length, new Comparator<Integer>() { // from class: supplementary.experiments.scripts.ExItTrainingScriptsGenSnellius2.1
            @Override // java.util.Comparator
            public int compare(Integer num, Integer num2) {
                double d = dArr[num2.intValue()] - dArr[num.intValue()];
                if (d < 0.0d) {
                    return -1;
                }
                return d > 0.0d ? 1 : 0;
            }
        });
        ArrayList arrayList2 = new ArrayList();
        Iterator<Integer> it = sortedIndices.iterator();
        while (it.hasNext()) {
            int intValue = it.next().intValue();
            Game game2 = gameArr[intValue];
            String str = GAMES[intValue];
            for (String str2 : VARIANTS) {
                arrayList2.add(new ProcessData(str, game2.players().count(), str2));
            }
        }
        int i2 = 0;
        while (i2 < arrayList2.size()) {
            String str3 = "TrainFeatures_" + arrayList.size() + ".sh";
            try {
                unixPrintWriter2 = new UnixPrintWriter(new File(replaceAll + str3), "UTF-8");
                th2 = null;
            } catch (FileNotFoundException | UnsupportedEncodingException e) {
                e.printStackTrace();
            }
            try {
                try {
                    unixPrintWriter2.println("#!/bin/bash");
                    unixPrintWriter2.println("#SBATCH -J TrainFeatures");
                    unixPrintWriter2.println("#SBATCH -p thin");
                    unixPrintWriter2.println("#SBATCH -o /home/" + valueString + "/TrainFeaturesSnellius/Out/Out_%J.out");
                    unixPrintWriter2.println("#SBATCH -e /home/" + valueString + "/TrainFeaturesSnellius/Out/Err_%J.err");
                    unixPrintWriter2.println("#SBATCH -t 2880");
                    unixPrintWriter2.println("#SBATCH -N 1");
                    int min = Math.min(arrayList2.size() - i2, 42);
                    boolean z = min > 32;
                    int min2 = z ? Math.min(256, MAX_REQUEST_MEM) : Math.min(min * 6, MAX_REQUEST_MEM);
                    unixPrintWriter2.println("#SBATCH --cpus-per-task=" + (min * 3));
                    unixPrintWriter2.println("#SBATCH --mem=" + min2 + SVGConstants.SVG_G_VALUE);
                    if (z) {
                        unixPrintWriter2.println("#SBATCH --exclusive");
                    } else {
                        unixPrintWriter2.println("#SBATCH --exclusive");
                    }
                    unixPrintWriter2.println("module load 2021");
                    unixPrintWriter2.println("module load Java/11.0.2");
                    for (int i3 = 0; i3 < min; i3++) {
                        ProcessData processData = (ProcessData) arrayList2.get(i2);
                        String join = StringRoutines.join(" ", "taskset", "-c", StringRoutines.join(SVGSyntax.COMMA, String.valueOf(i3 * 3), String.valueOf((i3 * 3) + 1), String.valueOf((i3 * 3) + 2)), "java", "-Xms5120M", "-Xmx5120M", "-XX:+HeapDumpOnOutOfMemoryError", "-da", "-dsa", "-XX:+UseStringDeduplication", "-jar", StringRoutines.quote("/home/" + valueString + "/TrainFeaturesSnellius/Ludii.jar"), "--expert-iteration", "--game", StringRoutines.quote("/" + processData.gameName), "-n", String.valueOf(MAX_SELFPLAY_TRIALS), "--game-length-cap 1000", "--thinking-time 1", "--is-episode-durations", "--prioritized-experience-replay", "--wis", "--handle-aliasing", "--playout-features-epsilon 0.5", "--no-value-learning", "--train-tspg", "--checkpoint-freq 5", "--num-agent-threads", String.valueOf(3), "--num-policy-gradient-threads", String.valueOf(3), " --post-pg-weight-scalar 0.0", "--num-feature-discovery-threads", String.valueOf(Math.min(processData.numPlayers, 3)), "--out-dir", StringRoutines.quote("/home/" + valueString + "/TrainFeaturesSnellius/Out/" + StringRoutines.cleanGameName(processData.gameName.replaceAll(Pattern.quote(".lud"), "")) + "_" + processData.trainingVariant + "/"), "--no-logging", "--max-wall-time", String.valueOf(MAX_WALL_TIME));
                        if (processData.trainingVariant.contains("Reinforce")) {
                            join = join + " --num-policy-gradient-epochs 100";
                            if (processData.trainingVariant.equals("ReinforceGamma1")) {
                                join = join + " --pg-gamma 1";
                            } else if (processData.trainingVariant.equals("ReinforceGamma099")) {
                                join = join + " --pg-gamma 0.99";
                            } else if (processData.trainingVariant.equals("ReinforceGamma09")) {
                                join = join + " --pg-gamma 0.9";
                            }
                        }
                        if (processData.trainingVariant.equals("SpecialMovesExpander")) {
                            join = join + " --special-moves-expander";
                        } else if (processData.trainingVariant.equals("SpecialMovesExpanderSplit")) {
                            join = join + " --special-moves-expander-split";
                        } else if (processData.trainingVariant.equals("SignCorrelationExpander")) {
                            join = join + " --expander-type CorrelationErrorSignExpander";
                        } else if (processData.trainingVariant.equals("RandomExpander")) {
                            join = join + " --expander-type Random";
                        }
                        unixPrintWriter2.println(join + " " + StringRoutines.join(" ", XMLConstants.XML_CLOSE_TAG_END, "/home/" + valueString + "/TrainFeaturesSnellius/Out/Out_${SLURM_JOB_ID}_" + i3 + ".out", "&"));
                        i2++;
                    }
                    unixPrintWriter2.println("wait");
                    arrayList.add(str3);
                    if (unixPrintWriter2 != null) {
                        if (0 != 0) {
                            try {
                                unixPrintWriter2.close();
                            } catch (Throwable th3) {
                                th2.addSuppressed(th3);
                            }
                        } else {
                            unixPrintWriter2.close();
                        }
                    }
                } catch (Throwable th4) {
                    if (unixPrintWriter2 != null) {
                        if (th2 != null) {
                            try {
                                unixPrintWriter2.close();
                            } catch (Throwable th5) {
                                th2.addSuppressed(th5);
                            }
                        } else {
                            unixPrintWriter2.close();
                        }
                    }
                    throw th4;
                    break;
                }
            } catch (Throwable th6) {
                th2 = th6;
                throw th6;
                break;
            }
        }
        ArrayList arrayList3 = new ArrayList();
        List list = arrayList;
        while (true) {
            List list2 = list;
            if (list2.size() <= 0) {
                break;
            }
            if (list2.size() > MAX_JOBS_PER_BATCH) {
                ArrayList arrayList4 = new ArrayList();
                for (int i4 = 0; i4 < MAX_JOBS_PER_BATCH; i4++) {
                    arrayList4.add(list2.get(i4));
                }
                arrayList3.add(arrayList4);
                list = list2.subList(MAX_JOBS_PER_BATCH, list2.size());
            } else {
                arrayList3.add(list2);
                list = new ArrayList();
            }
        }
        for (int i5 = 0; i5 < arrayList3.size(); i5++) {
            try {
                unixPrintWriter = new UnixPrintWriter(new File(replaceAll + "SubmitJobs_Part" + i5 + ".sh"), "UTF-8");
                th = null;
            } catch (FileNotFoundException | UnsupportedEncodingException e2) {
                e2.printStackTrace();
            }
            try {
                try {
                    Iterator it2 = ((List) arrayList3.get(i5)).iterator();
                    while (it2.hasNext()) {
                        unixPrintWriter.println("sbatch " + ((String) it2.next()));
                    }
                    if (unixPrintWriter != null) {
                        if (0 != 0) {
                            try {
                                unixPrintWriter.close();
                            } catch (Throwable th7) {
                                th.addSuppressed(th7);
                            }
                        } else {
                            unixPrintWriter.close();
                        }
                    }
                } catch (Throwable th8) {
                    if (unixPrintWriter != null) {
                        if (th != null) {
                            try {
                                unixPrintWriter.close();
                            } catch (Throwable th9) {
                                th.addSuppressed(th9);
                            }
                        } else {
                            unixPrintWriter.close();
                        }
                    }
                    throw th8;
                    break;
                }
            } catch (Throwable th10) {
                th = th10;
                throw th10;
                break;
            }
        }
    }

    public static void main(String[] strArr) {
        CommandLineArgParse commandLineArgParse = new CommandLineArgParse(true, "Creating feature training job scripts for Snellius cluster.");
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--user-name").help("Username on the cluster.").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).setRequired());
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--scripts-dir").help("Directory in which to store generated scripts.").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).setRequired());
        if (commandLineArgParse.parseArguments(strArr)) {
            generateScripts(commandLineArgParse);
        }
    }
}
