package experiments.testUCThs;

import experiments.fastGameLengths.TrialRecord;
import game.Game;
import gnu.trove.list.array.TIntArrayList;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;
import main.math.statistics.Stats;
import org.apache.batik.dom.events.DOMKeyEvent;
import other.AI;
import other.GameLoader;
import other.context.Context;
import other.model.Model;
import other.trial.Trial;
import search.mcts.MCTS;
import search.mcts.backpropagation.MonteCarloBackprop;
import search.mcts.finalmoveselection.RobustChild;
import search.mcts.playout.HeuristicPlayout;
import search.mcts.selection.UCB1;
import search.minimax.AlphaBetaSearch;

/* loaded from: input_file:experiments/testUCThs/TestUCThs.class */
/**
 * Experiment driver that plays an alpha-beta searcher against a UCT agent with heuristic
 * playouts ("UCThs") on a chosen game, running many games concurrently and reporting each
 * agent's score.
 *
 * <p>Both agents read their heuristics from
 * {@code src/experiments/fastGameLengths/Heuristics_<GameName>_Good.txt}. The UCT iteration
 * budget is derived from a sampled branching factor as {@code round(sqrt(round(BF^DEPTH)))}.
 */
public class TestUCThs {
    /** Number of games played per head-to-head comparison; the seating alternates between games. */
    private static final int NUM_GAMES = 1000;

    /** Depth used to derive the UCT iteration budget; also passed as the depth argument of Model.startNewStep. */
    private static final int DEPTH = 4;

    /** Number of random playouts used to estimate the branching factor. */
    private static final int BF_SAMPLES = 10;

    /** Summary lines accumulated by a {@link #test(GameName)} run and printed at its end. */
    private final List<String> output = new ArrayList<>();

    /** Formatter for summary output; DecimalFormat is not thread-safe, so it is only used on the calling thread. */
    private static final DecimalFormat df = new DecimalFormat("#.###");

    /**
     * Games this experiment knows how to load, each carrying two per-game parameters.
     *
     * <p>Neither parameter is read in this class. {@code depth} is presumably a suggested search
     * depth and {@code expected} a reference game length (e.g. 9 for Tic-Tac-Toe), with -1
     * meaning "unknown" — TODO confirm.
     */
    public enum GameName {
        Breakthrough(2, -1),
        Tablut(4, -1),
        Yavalath(4, -1),
        Clobber(3, -1),
        NineMensMorris(3, 50),
        TicTacToe(3, 9),
        ConnectFour(3, 36),
        EnglishDraughts(3, 70),
        GoMoku(3, 30),
        LinesOfAction(3, 44),
        Halma(3, -1),
        Chess(3, 70),
        // Plain literal: the decompiled form referenced an unrelated key-code constant with value 115.
        Shogi(3, 115);

        private final int depth;
        private final int expected;

        GameName(final int depth, final int expected) {
            this.depth = depth;
            this.expected = expected;
        }

        /** @return the depth parameter for this game */
        public int depth() {
            return this.depth;
        }

        /** @return the expected-value parameter for this game (-1 for several games) */
        public int expected() {
            return this.expected;
        }
    }

    /** Runs the experiment on the default game (Connect Four). */
    void test() {
        test(GameName.ConnectFour);
    }

    /**
     * Loads the given game, estimates its branching factor, derives the UCT iteration budget and
     * runs the alpha-beta vs. UCThs comparison, then prints the accumulated summary lines.
     * Errors during the experiment are reported and do not abort printing of the summary.
     *
     * @param gameName the game to test
     */
    void test(final GameName gameName) {
        final Game game = loadGame(gameName);
        System.out.println("==================================================");
        System.out.println("Loaded game " + game.name() + ".");
        output.clear();
        output.add("   [");
        output.add("      [ (" + game.name() + ") ]");
        try {
            final double branchingFactor = branchingFactorParallel(game, BF_SAMPLES);
            final int iterations = iterationsForDepth(branchingFactor, DEPTH);
            System.out.println("depth=" + DEPTH + ", BF=" + branchingFactor + ", iterations=" + iterations + ".");
            compareUCThs(gameName, game, iterations, DEPTH);
        } catch (final InterruptedException e) {
            // Preserve the interrupt status for any caller further up.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (final Exception e) {
            e.printStackTrace();
        }
        output.add("   ]");
        for (final String line : output)
            System.out.println(line);
    }

    /**
     * Loads the Ludii game definition (with any fixed options) for the given enum constant.
     *
     * @throws IllegalArgumentException if no definition is mapped for the constant
     */
    private static Game loadGame(final GameName gameName) {
        switch (gameName) {
            case Tablut:
                return GameLoader.loadGameFromName("Tablut.lud");
            case Yavalath:
                return GameLoader.loadGameFromName("Yavalath.lud");
            case Clobber:
                return GameLoader.loadGameFromName("Clobber.lud", Arrays.asList("Rows/6", "Columns/6"));
            case NineMensMorris:
                return GameLoader.loadGameFromName("Nine Men's Morris.lud");
            case Chess:
                return GameLoader.loadGameFromName("Chess.lud");
            case ConnectFour:
                return GameLoader.loadGameFromName("Connect Four.lud");
            case EnglishDraughts:
                return GameLoader.loadGameFromName("English Draughts.lud");
            case GoMoku:
                return GameLoader.loadGameFromName("GoMoku.lud");
            case Halma:
                return GameLoader.loadGameFromName("Halma.lud", Arrays.asList("Board Size/6x6"));
            case Breakthrough:
                return GameLoader.loadGameFromName("Breakthrough.lud", Arrays.asList("Board Size/6x6"));
            case LinesOfAction:
                return GameLoader.loadGameFromName("Lines of Action.lud");
            case Shogi:
                return GameLoader.loadGameFromName("Shogi.lud");
            case TicTacToe:
                return GameLoader.loadGameFromName("Tic-Tac-Toe.lud");
            default:
                throw new IllegalArgumentException("No game definition mapped for " + gameName);
        }
    }

    /**
     * Returns the UCT iteration budget {@code round(sqrt(round(bf^depth)))}, i.e. roughly
     * {@code bf^(depth/2)}. The intermediate value is kept as a {@code long} so that large
     * branching factors are not silently capped at {@link Integer#MAX_VALUE} before the square root.
     */
    private static int iterationsForDepth(final double branchingFactor, final int depth) {
        final long leaves = Math.round(Math.pow(branchingFactor, depth));
        return (int) Math.round(Math.sqrt(leaves));
    }

    /**
     * @param trial a finished trial
     * @param game unused
     * @return the number of turns in the trial, excluding forced passes
     */
    public int gameLength(final Trial trial, final Game game) {
        return trial.numTurns() - trial.numForcedPasses();
    }

    /**
     * Plays {@link #NUM_GAMES} concurrent games of AlphaBetaSearch vs. UCThs (UCB1 selection,
     * heuristic playouts, Monte-Carlo backprop, robust-child final move).
     * Even-numbered games seat alpha-beta as player 1; odd-numbered games swap the seats.
     * Prints each winner as it finishes, then the overall timing and results.
     *
     * <p>Every game gets fresh AI instances because games run concurrently on separate threads.
     * Any failure (heuristics file missing, exception inside a game) propagates instead of being
     * swallowed, and the executor is always shut down.
     *
     * @param gameName   names the heuristics file to load
     * @param game       the game to play
     * @param iterations iteration limit passed to each move decision
     * @param depth      depth limit passed to each move decision
     * @throws Exception if an AI cannot be built or a game fails (wrapped in ExecutionException)
     */
    void compareUCThs(final GameName gameName, final Game game, final int iterations, final int depth) throws Exception {
        final long startTime = System.nanoTime();
        System.out.println("\nUCT (" + iterations + " iterations).");
        final String heuristicsPath = "src/experiments/fastGameLengths/Heuristics_" + gameName + "_Good.txt";

        // One thread per game, as in the original design of this experiment.
        final ExecutorService executor = Executors.newFixedThreadPool(NUM_GAMES);
        try {
            final List<Future<TrialRecord>> futures = new ArrayList<>(NUM_GAMES);
            AI alphaBeta = null;
            AI uct = null;
            for (int g = 0; g < NUM_GAMES; g++) {
                final int starter = g % 2;

                alphaBeta = new AlphaBetaSearch(heuristicsPath);
                final MCTS mcts = new MCTS(new UCB1(), new HeuristicPlayout(heuristicsPath), new MonteCarloBackprop(), new RobustChild());
                mcts.setFriendlyName("UCThs1/1");
                uct = mcts;

                // Index 0 is a placeholder: players are numbered from 1.
                final List<AI> ais = new ArrayList<>(3);
                ais.add(null);
                if (starter == 0) {
                    ais.add(alphaBeta);
                    ais.add(uct);
                } else {
                    ais.add(uct);
                    ais.add(alphaBeta);
                }
                futures.add(executor.submit(() -> playGame(game, ais, iterations, depth, starter)));
            }

            // Waiting on the futures (not a latch) cannot hang: get() rethrows a failed game's exception.
            for (final Future<TrialRecord> future : futures)
                future.get();

            final double seconds = (System.nanoTime() - startTime) / 1.0E9d;
            System.out.println("UCT (" + iterations + ") " + seconds + "s (" + (seconds / NUM_GAMES) + "s per game).");
            showResults(game, "UCThs Results", NUM_GAMES, futures, seconds, alphaBeta, uct);
        } finally {
            // No-op for already finished tasks; interrupts stragglers if a game failed.
            executor.shutdownNow();
        }
    }

    /**
     * Plays one complete game with the given seating, printing the winner when done.
     *
     * @param ais     AIs indexed by player number (index 0 unused)
     * @param starter 0 or 1, recorded with the trial to identify the seating
     */
    private static TrialRecord playGame(final Game game, final List<AI> ais, final int iterations, final int depth, final int starter) {
        final Trial trial = new Trial(game);
        final Context context = new Context(game, trial);
        game.start(context);
        for (int p = 1; p <= game.players().count(); p++)
            ais.get(p).initAI(game, p);
        final Model model = context.model();
        while (!trial.over())
            model.startNewStep(context, ais, -1.0d, iterations, depth, 0.0d);
        System.out.print(context.trial().status().winner());
        return new TrialRecord(starter, trial);
    }

    /**
     * Scores the first {@code numGames} results and prints each AI's success rate and statistics.
     * A win is worth 1, a draw (winner 0) is worth 0.5 to each side, and a loss is worth 0.
     *
     * @param game     unused
     * @param label    label for both Stats objects
     * @param numGames number of results to read from {@code results}
     * @param results  finished games; {@code starter() == 0} means {@code ai} was player 1
     * @param seconds  unused
     * @param ai       the AI seated as player 1 when starter is 0 (player 2 otherwise)
     * @param ai2      the other AI
     * @throws Exception if retrieving a result fails
     */
    void showResults(final Game game, final String label, final int numGames, final List<Future<TrialRecord>> results, final double seconds, final AI ai, final AI ai2) throws Exception {
        final Stats stats = new Stats(label);
        final Stats stats2 = new Stats(label);
        double total = 0.0d;
        double total2 = 0.0d;
        for (int g = 0; g < numGames; g++) {
            final TrialRecord record = results.get(g).get();
            final int winner = record.trial().status().winner();
            final double score;
            final double score2;
            if (winner == 0) {
                score = 0.5d;
                score2 = 0.5d;
            } else {
                // ai is player 1 iff starter == 0, so it won iff those two facts agree.
                final boolean aiWon = (record.starter() == 0) == (winner == 1);
                score = aiWon ? 1.0d : 0.0d;
                score2 = aiWon ? 0.0d : 1.0d;
            }
            total += score;
            total2 += score2;
            stats.addSample(score);
            stats2.addSample(score2);
        }
        System.out.println(ai.friendlyName() + " success rate " + ((total * 100.0d) / numGames) + "%.");
        System.out.println(ai2.friendlyName() + " success rate " + ((total2 * 100.0d) / numGames) + "%.");
        stats.measure();
        System.out.print(ai.friendlyName());
        stats.showFull();
        stats2.measure();
        System.out.print(ai2.friendlyName());
        stats2.showFull();
    }

    /**
     * Plays {@code numTrials} random playouts one after another on a single reused Context,
     * prints length statistics and elapsed time, and returns the mean game length.
     */
    double lengthRandomSerial(final Game game, final int numTrials) {
        final long startTime = System.nanoTime();
        final Context context = new Context(game, new Trial(game));
        final Stats stats = new Stats("Serial Random");
        for (int t = 0; t < numTrials; t++) {
            game.start(context);
            stats.addSample(gameLength(game.playout(context, null, 1.0d, null, -1, -1, ThreadLocalRandom.current()), game));
        }
        stats.measure();
        stats.showFull();
        System.out.println("Serial in " + ((System.nanoTime() - startTime) / 1.0E9d) + "s.");
        return stats.mean();
    }

    /**
     * Plays {@code numTrials} random playouts concurrently (one thread each), prints and records
     * length statistics and timing, and returns the mean game length.
     *
     * @throws Exception if a playout fails (wrapped in ExecutionException) or the wait is interrupted
     */
    double lengthRandomParallel(final Game game, final int numTrials) throws Exception {
        final long startTime = System.nanoTime();
        final ExecutorService executor = Executors.newFixedThreadPool(numTrials);
        try {
            final List<Future<Trial>> futures = new ArrayList<>(numTrials);
            for (int t = 0; t < numTrials; t++) {
                final Trial trial = new Trial(game);
                final Context context = new Context(game, trial);
                futures.add(executor.submit(() -> {
                    game.start(context);
                    game.playout(context, null, 1.0d, null, -1, -1, ThreadLocalRandom.current());
                    return trial;
                }));
            }
            // get() blocks until each playout is done and rethrows failures, so nothing can hang.
            final Stats stats = new Stats("Random");
            for (final Future<Trial> future : futures)
                stats.addSample(gameLength(future.get(), game));
            stats.measure();
            final double seconds = (System.nanoTime() - startTime) / 1.0E9d;
            stats.showFull();
            System.out.println("Random concurrent in " + seconds + "s (" + (seconds / numTrials) + "s per game).");
            formatOutput(stats, numTrials, seconds);
            return stats.mean();
        } finally {
            executor.shutdownNow();
        }
    }

    /**
     * Appends one summary line to the output: label, n, mean, min, max, sd, se, ci, and
     * {@code seconds / numTrials * 1000} (milliseconds per trial when {@code seconds} is a
     * total in seconds).
     */
    void formatOutput(final Stats stats, final int numTrials, final double seconds) {
        output.add("      [ (" + stats.label() + ") " + stats.n() + " " + df.format(stats.mean()) + " " + ((int) stats.min()) + " " + ((int) stats.max()) + " " + df.format(stats.sd()) + " " + df.format(stats.se()) + " " + df.format(stats.ci()) + " " + df.format((seconds / numTrials) * 1000.0d) + " ]");
    }

    /**
     * Estimates the game's branching factor as the mean, over {@code numTrials} concurrent random
     * playouts, of each playout's average recorded legal-move count. A playout with no recorded
     * sizes contributes 0.
     *
     * <p>Side effect: calls {@code game.disableMemorylessPlayouts()} (presumably required for
     * legal-move history sizes to be recorded — TODO confirm).
     *
     * @throws Exception if a playout fails (wrapped in ExecutionException) or the wait is interrupted
     */
    double branchingFactorParallel(final Game game, final int numTrials) throws Exception {
        game.disableMemorylessPlayouts();
        final ExecutorService executor = Executors.newFixedThreadPool(numTrials);
        try {
            final List<Future<Trial>> futures = new ArrayList<>(numTrials);
            for (int t = 0; t < numTrials; t++) {
                final Trial trial = new Trial(game);
                final Context context = new Context(game, trial);
                trial.storeLegalMovesHistorySizes();
                futures.add(executor.submit(() -> {
                    game.start(context);
                    game.playout(context, null, 1.0d, null, -1, -1, ThreadLocalRandom.current());
                    return trial;
                }));
            }
            double sumOfMeans = 0.0d;
            for (final Future<Trial> future : futures) {
                final TIntArrayList sizes = future.get().auxilTrialData().legalMovesHistorySizes();
                double mean = 0.0d;
                if (sizes.size() > 0) {
                    for (int k = 0; k < sizes.size(); k++)
                        mean += sizes.getQuick(k);
                    mean /= sizes.size();
                }
                sumOfMeans += mean;
            }
            return sumOfMeans / numTrials;
        } finally {
            executor.shutdownNow();
        }
    }

    /** Entry point: runs the default experiment. */
    public static void main(final String[] args) {
        new TestUCThs().test();
    }
}
