package training.expert_iteration.gradients;

import features.FeatureVector;
import gnu.trove.list.array.TIntArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import main.collections.FVector;
import metadata.ai.heuristics.Heuristics;
import optimisers.Optimiser;
import policies.softmax.SoftmaxPolicyLinear;
import training.expert_iteration.ExItExperience;

/* loaded from: input_file:training/expert_iteration/gradients/Gradients.class */
public class Gradients {

    /** Not instantiable: this class only exposes static helpers. */
    private Gradients() {
    }

    /**
     * Computes the element-wise difference between an estimated and a target distribution.
     *
     * @param estimatedDistribution distribution to subtract from (it is copied, not modified)
     * @param targetDistribution distribution to subtract
     * @return a new vector equal to {@code estimatedDistribution - targetDistribution}
     */
    public static FVector computeDistributionErrors(FVector estimatedDistribution, FVector targetDistribution) {
        final FVector errors = estimatedDistribution.copy();
        errors.subtract(targetDistribution);
        return errors;
    }

    /**
     * Computes per-action errors (policy distribution minus target distribution) for a linear
     * softmax policy. For a softmax output trained with cross-entropy loss, this difference is the
     * gradient of the loss with respect to the logits.
     *
     * @param policy policy whose distribution over {@code featureVectors} is computed
     * @param expertDistribution target distribution, aligned index-for-index with
     *        {@code featureVectors}; never modified
     * @param featureVectors one feature vector per action
     * @param player forwarded as the second argument of {@code policy.computeDistribution}
     *        (presumably the player index — confirm against SoftmaxPolicyLinear)
     * @param handleAliasing if {@code true}, actions whose feature vectors are equal (per
     *        {@code FeatureVector.equals}/{@code hashCode}) all receive the maximum target value
     *        of their group, and the target is then renormalised
     * @return a new vector of per-action errors
     */
    public static FVector computeCrossEntropyErrors(SoftmaxPolicyLinear policy, FVector expertDistribution,
            FeatureVector[] featureVectors, int player, boolean handleAliasing) {
        final FVector predicted = policy.computeDistribution(featureVectors, player);
        final FVector target = handleAliasing
                ? mergeAliasedTargets(expertDistribution, featureVectors)
                : expertDistribution;
        return computeDistributionErrors(predicted, target);
    }

    /**
     * Returns a renormalised copy of {@code target} in which every group of two or more actions
     * sharing an equal feature vector is assigned the largest target value in that group
     * (clamped below at 0). Presumably this is done because a linear policy cannot tell such
     * actions apart, so they should not be pushed towards different probabilities.
     */
    private static FVector mergeAliasedTargets(FVector target, FeatureVector[] featureVectors) {
        // Group action indices by (equal) feature vector.
        final Map<FeatureVector, TIntArrayList> groups = new HashMap<>();
        for (int a = 0; a < featureVectors.length; a++) {
            groups.computeIfAbsent(featureVectors[a], key -> new TIntArrayList()).add(a);
        }

        final FVector merged = target.copy();
        for (final TIntArrayList group : groups.values()) {
            if (group.size() <= 1) {
                continue;
            }
            float max = 0.0f;
            for (int k = 0; k < group.size(); k++) {
                final float value = merged.get(group.getQuick(k));
                if (value > max) {
                    max = value;
                }
            }
            for (int k = 0; k < group.size(); k++) {
                merged.set(group.getQuick(k), max);
            }
        }
        merged.normalise();
        return merged;
    }

    /**
     * Computes the gradient of the squared error between a tanh-squashed linear value estimate
     * and the observed outcome for the player to move in the experience's state:
     * <pre>
     *   prediction = tanh(params . phi(s))
     *   loss       = (prediction - outcome)^2
     *   dLoss/dParams_j = 2 * (prediction - outcome) * (1 - prediction^2) * phi_j(s)
     * </pre>
     *
     * @param valueFunction heuristics providing the value-function parameter vector
     * @param player must be positive; non-positive values produce {@code null}
     * @param experience experience supplying the state features and the outcomes
     * @return a new gradient vector with the same dimension as the parameter vector, or
     *         {@code null} if {@code valueFunction} is {@code null}, {@code player <= 0}, or
     *         {@code valueFunction.paramsVector()} is {@code null}
     */
    public static FVector computeValueGradients(Heuristics valueFunction, int player, ExItExperience experience) {
        if (valueFunction == null || player <= 0) {
            return null;
        }
        final FVector params = valueFunction.paramsVector();
        if (params == null) {
            return null;
        }

        final FVector stateFeatures = experience.stateFeatureVector();
        final float predictedValue = (float) Math.tanh(params.dot(stateFeatures));
        final float outcome = (float) experience.playerOutcomes()[experience.state().state().mover()];
        final float valueError = predictedValue - outcome;

        // Shared factor of every partial derivative; multiply by each feature value to get the gradient.
        final float gradDivFeature = 2.0f * valueError * (1.0f - predictedValue * predictedValue);

        final FVector gradients = new FVector(params.dim());
        for (int j = 0; j < gradients.dim(); j++) {
            gradients.set(j, gradDivFeature * stateFeatures.get(j));
        }
        // Previously the computed vector was discarded and null returned; return the gradients.
        return gradients;
    }

    /**
     * Computes the mean of a list of gradient vectors.
     *
     * @param gradientVectors gradient vectors to average
     * @return the mean vector (via {@code FVector.mean}), or {@code null} if the list is empty
     */
    public static FVector meanGradients(List<FVector> gradientVectors) {
        if (gradientVectors.isEmpty()) {
            return null;
        }
        return FVector.mean(gradientVectors);
    }

    /**
     * Sums a list of gradient vectors and divides the sum by {@code sumWeights} when it is
     * positive. The "wis" name presumably refers to weighted importance sampling, in which case
     * {@code sumWeights} is the sum of importance-sampling weights — confirm against callers.
     *
     * @param gradientVectors gradient vectors to combine; none of them is modified
     * @param sumWeights divisor applied to the summed gradients; ignored when not positive
     * @return a new combined vector, or {@code null} if the list is empty
     */
    public static FVector wisGradients(List<FVector> gradientVectors, float sumWeights) {
        if (gradientVectors.isEmpty()) {
            return null;
        }
        final FVector sum = gradientVectors.get(0).copy();
        for (int k = 1; k < gradientVectors.size(); k++) {
            sum.add(gradientVectors.get(k));
        }
        if (sumWeights > 0.0f) {
            sum.div(sumWeights);
        }
        return sum;
    }

    /**
     * Takes one minimisation step with {@code optimiser}, then applies weight decay by
     * subtracting {@code weightDecayLambda} times the parameter values from before the step.
     * The optimiser is expected to update {@code params} in place.
     *
     * @param optimiser optimiser used for the step
     * @param params parameter vector; modified in place
     * @param gradients gradients of the objective with respect to {@code params}
     * @param weightDecayLambda weight-decay coefficient
     */
    public static void minimise(Optimiser optimiser, FVector params, FVector gradients, float weightDecayLambda) {
        // Snapshot before the step so the decay is based on the pre-update parameters.
        final FVector weightDecay = new FVector(params);
        weightDecay.mult(weightDecayLambda);
        optimiser.minimiseObjective(params, gradients);
        params.subtract(weightDecay);
    }

    /**
     * Takes one maximisation step with {@code optimiser}, then applies weight decay by
     * subtracting {@code weightDecayLambda} times the parameter values from before the step.
     * The optimiser is expected to update {@code params} in place.
     *
     * @param optimiser optimiser used for the step
     * @param params parameter vector; modified in place
     * @param gradients gradients of the objective with respect to {@code params}
     * @param weightDecayLambda weight-decay coefficient
     */
    public static void maximise(Optimiser optimiser, FVector params, FVector gradients, float weightDecayLambda) {
        // Snapshot before the step so the decay is based on the pre-update parameters.
        final FVector weightDecay = new FVector(params);
        weightDecay.mult(weightDecayLambda);
        optimiser.maximiseObjective(params, gradients);
        params.subtract(weightDecay);
    }
}
