package decision_trees.classifiers;

import features.FeatureVector;
import features.WeightVector;
import features.aspatial.AspatialFeature;
import features.feature_sets.BaseFeatureSet;
import function_approx.LinearFunction;
import gnu.trove.list.array.TFloatArrayList;
import gnu.trove.list.array.TIntArrayList;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import main.collections.ArrayUtils;
import main.collections.FVector;
import training.expert_iteration.ExItExperience;
import utils.data_structures.experience_buffers.ExperienceBuffer;

/* loaded from: input_file:decision_trees/classifiers/ExperienceImbalancedBinaryClassificationTree2Learner.class */
/**
 * Learns "imbalanced" binary classification trees from Expert Iteration experience.
 * <p>
 * Every internal node tests a single binary feature. An aspatial feature counts as active when its
 * value is non-zero. A spatial feature counts as active when its index appears in the feature
 * vector's active spatial feature indices. The child for samples where the tested feature is
 * active is always built with a depth budget of 0, so it is always a leaf. The other child is
 * grown recursively with the depth budget reduced by one, so the resulting tree is a chain of
 * tests.
 * <p>
 * Target values are derived from a linear function's move scores (see {@link #buildTree}). Split
 * quality is the sum of squared errors of those targets around the mean target of each side of
 * the split.
 */
public class ExperienceImbalancedBinaryClassificationTree2Learner {

    /**
     * Running record of the best candidate split seen so far while scoring features at one node.
     */
    private static final class BestSplit {
        /** Smallest sum of squared errors seen so far. */
        double minSse = Double.POSITIVE_INFINITY;
        /** Largest sum of squared errors seen so far; used to detect "all candidates equal". */
        double maxSse = Double.NEGATIVE_INFINITY;
        /** Index of the best feature within its feature family, or -1 if no feature qualified. */
        int featureIdx = -1;
        /** Number of samples on the active side of the best split (used as a tie-breaker). */
        int numActive = -1;
        /** {@code true} if {@link #featureIdx} indexes aspatial features, {@code false} for spatial. */
        boolean aspatial = true;

        /**
         * Updates this record with one candidate split.
         *
         * @param candidateSse       sum of squared errors of the candidate split
         * @param candidateIdx       feature index of the candidate within its family
         * @param candidateNumActive number of samples on the candidate's active side
         * @param candidateAspatial  whether the candidate is an aspatial feature
         */
        void consider(
                final double candidateSse,
                final int candidateIdx,
                final int candidateNumActive,
                final boolean candidateAspatial) {
            if (candidateSse < minSse) {
                minSse = candidateSse;
                take(candidateIdx, candidateNumActive, candidateAspatial);
            } else if (candidateSse == minSse && candidateNumActive > numActive) {
                // Equal error: prefer the split that puts more samples on the active side.
                take(candidateIdx, candidateNumActive, candidateAspatial);
            }
            if (candidateSse > maxSse) {
                maxSse = candidateSse;
            }
        }

        /** Records the candidate as the current best (error is updated by the caller). */
        private void take(final int candidateIdx, final int candidateNumActive, final boolean candidateAspatial) {
            featureIdx = candidateIdx;
            numActive = candidateNumActive;
            aspatial = candidateAspatial;
        }
    }

    /**
     * Builds a tree from all usable samples in an experience buffer.
     * <p>
     * An experience entry is used only if it is non-null, has more than one move, and its moves do
     * not all receive the same score. For each used entry:
     * <ol>
     * <li>each move's score is the dot product of the linear function's effective parameters with
     * the move's feature vector;</li>
     * <li>moves listed in {@code winningMoves()} are raised to the maximum score, and moves listed
     * in {@code losingMoves()} are lowered to the minimum score (losing is applied last);</li>
     * <li>the scores are passed through {@code FVector.softmax()};</li>
     * <li>each result is divided by the largest one, so the top move gets a target of 1.</li>
     * </ol>
     * Every move of every used entry becomes one (feature vector, target) training sample.
     *
     * @param featureSet        feature set used to generate feature vectors and to resolve the
     *                          features chosen for splits
     * @param linFunc           linear function whose effective parameters score the moves
     * @param experienceBuffer  source of experience samples
     * @param maxDepth          depth budget for the root; a node reached with a budget of exactly 0
     *                          becomes a leaf. Only an exact 0 stops recursion, so a negative value
     *                          imposes no depth limit.
     * @param minSamplesPerLeaf minimum number of samples required on each side of a split
     * @return root node of the learned tree
     * @throws IllegalArgumentException if {@code minSamplesPerLeaf <= 0}
     */
    public static DecisionTreeNode buildTree(
            final BaseFeatureSet featureSet,
            final LinearFunction linFunc,
            final ExperienceBuffer experienceBuffer,
            final int maxDepth,
            final int minSamplesPerLeaf) {
        final WeightVector weights = linFunc.effectiveParams();
        final List<FeatureVector> allFeatureVectors = new ArrayList<>();
        final TFloatArrayList allTargets = new TFloatArrayList();

        for (final ExItExperience sample : experienceBuffer.allExperience()) {
            // Skip empty buffer slots and states with only a single move.
            if (sample == null || sample.moves().size() <= 1) {
                continue;
            }

            final FeatureVector[] featureVectors = sample.generateFeatureVectors(featureSet);
            final float[] logits = new float[featureVectors.length];
            for (int m = 0; m < featureVectors.length; ++m) {
                logits[m] = weights.dot(featureVectors[m]);
            }

            final float maxLogit = ArrayUtils.max(logits);
            final float minLogit = ArrayUtils.min(logits);
            if (maxLogit == minLogit) {
                // All moves score the same, so the sample carries no useful signal.
                continue;
            }

            for (final FeatureVector featureVector : featureVectors) {
                allFeatureVectors.add(featureVector);
            }

            for (int m = sample.winningMoves().nextSetBit(0); m >= 0; m = sample.winningMoves().nextSetBit(m + 1)) {
                logits[m] = maxLogit;
            }
            for (int m = sample.losingMoves().nextSetBit(0); m >= 0; m = sample.losingMoves().nextSetBit(m + 1)) {
                logits[m] = minLogit;
            }

            final FVector policy = new FVector(logits);
            policy.softmax();
            final float maxProb = policy.max();
            for (int m = 0; m < logits.length; ++m) {
                allTargets.add(policy.get(m) / maxProb);
            }
        }

        return buildNode(
                featureSet, allFeatureVectors, allTargets, new BitSet(), new BitSet(),
                featureSet.getNumAspatialFeatures(), featureSet.getNumSpatialFeatures(),
                maxDepth, minSamplesPerLeaf);
    }

    /**
     * Recursively builds one node of the tree.
     *
     * @param featureSet        feature set used to resolve the chosen splitting feature
     * @param featureVectors    samples reaching this node
     * @param targets           target value of each sample, parallel to {@code featureVectors}
     * @param usedAspatial      aspatial feature indices already split on along this path
     * @param usedSpatial       spatial feature indices already split on along this path
     * @param numAspatial       number of aspatial features
     * @param numSpatial        number of spatial features
     * @param depth             remaining depth budget; exactly 0 forces a leaf
     * @param minSamplesPerLeaf minimum samples required on each side of a split
     * @return the built node
     * @throws IllegalArgumentException if {@code minSamplesPerLeaf <= 0}
     */
    private static DecisionTreeNode buildNode(
            final BaseFeatureSet featureSet,
            final List<FeatureVector> featureVectors,
            final TFloatArrayList targets,
            final BitSet usedAspatial,
            final BitSet usedSpatial,
            final int numAspatial,
            final int numSpatial,
            final int depth,
            final int minSamplesPerLeaf) {
        if (minSamplesPerLeaf <= 0) {
            throw new IllegalArgumentException("minSamplesPerLeaf must be greater than 0");
        }
        if (featureVectors.isEmpty()) {
            return new BinaryLeafNode(0.5f);
        }
        if (depth == 0) {
            return meanLeaf(targets);
        }

        final BitSet[] activeAspatial = activeAspatialFeatures(featureVectors, usedAspatial, numAspatial);
        final BitSet[] activeSpatial = activeSpatialFeatures(featureVectors);

        // Aspatial candidates are scored first; a spatial candidate only wins a tie
        // if it puts strictly more samples on its active side.
        final BestSplit best = new BestSplit();
        scoreSplits(activeAspatial, targets, usedAspatial, numAspatial, minSamplesPerLeaf, true, best);
        scoreSplits(activeSpatial, targets, usedSpatial, numSpatial, minSamplesPerLeaf, false, best);

        // NOTE(review): when exactly one feature satisfies minSamplesPerLeaf, minSse == maxSse
        // trivially, so no split is made even if that split would reduce the error. Confirm this
        // is intended.
        if (best.featureIdx == -1 || best.minSse == best.maxSse) {
            return meanLeaf(targets);
        }

        // Only the family that was split on gains a newly excluded feature.
        BitSet childUsedAspatial = usedAspatial;
        BitSet childUsedSpatial = usedSpatial;
        if (best.aspatial) {
            childUsedAspatial = (BitSet) usedAspatial.clone();
            childUsedAspatial.set(best.featureIdx);
        } else {
            childUsedSpatial = (BitSet) usedSpatial.clone();
            childUsedSpatial.set(best.featureIdx);
        }

        final BitSet[] splitActive = best.aspatial ? activeAspatial : activeSpatial;
        final List<FeatureVector> activeVectors = new ArrayList<>();
        final TFloatArrayList activeTargets = new TFloatArrayList();
        final List<FeatureVector> inactiveVectors = new ArrayList<>();
        final TFloatArrayList inactiveTargets = new TFloatArrayList();
        for (int s = 0; s < featureVectors.size(); ++s) {
            if (splitActive[s].get(best.featureIdx)) {
                activeVectors.add(featureVectors.get(s));
                activeTargets.add(targets.getQuick(s));
            } else {
                inactiveVectors.add(featureVectors.get(s));
                inactiveTargets.add(targets.getQuick(s));
            }
        }

        // The active-side child gets depth 0 and is therefore always a leaf; only the
        // inactive side keeps growing, which makes the tree "imbalanced".
        final DecisionTreeNode activeChild = buildNode(
                featureSet, activeVectors, activeTargets, childUsedAspatial, childUsedSpatial,
                numAspatial, numSpatial, 0, minSamplesPerLeaf);
        final DecisionTreeNode inactiveChild = buildNode(
                featureSet, inactiveVectors, inactiveTargets, childUsedAspatial, childUsedSpatial,
                numAspatial, numSpatial, depth - 1, minSamplesPerLeaf);

        // Separate branches so each feature is passed with its own concrete type; a single
        // AspatialFeature-typed local cannot hold a spatial feature.
        if (best.aspatial) {
            return new DecisionConditionNode(featureSet.aspatialFeatures()[best.featureIdx], activeChild, inactiveChild);
        }
        return new DecisionConditionNode(featureSet.spatialFeatures()[best.featureIdx], activeChild, inactiveChild);
    }

    /**
     * Computes, for each sample, which aspatial features are active (non-zero value).
     * Features in {@code excluded} are never marked active.
     */
    private static BitSet[] activeAspatialFeatures(
            final List<FeatureVector> featureVectors, final BitSet excluded, final int numAspatial) {
        final BitSet[] active = new BitSet[featureVectors.size()];
        for (int s = 0; s < active.length; ++s) {
            final FeatureVector featureVector = featureVectors.get(s);
            final BitSet bits = new BitSet();
            for (int f = 0; f < numAspatial; ++f) {
                if (!excluded.get(f) && featureVector.aspatialFeatureValues().get(f) != 0.0f) {
                    bits.set(f);
                }
            }
            active[s] = bits;
        }
        return active;
    }

    /**
     * Computes, for each sample, the set of spatial feature indices listed as active,
     * so later membership tests are constant-time instead of list scans.
     */
    private static BitSet[] activeSpatialFeatures(final List<FeatureVector> featureVectors) {
        final BitSet[] active = new BitSet[featureVectors.size()];
        for (int s = 0; s < active.length; ++s) {
            final TIntArrayList indices = featureVectors.get(s).activeSpatialFeatureIndices();
            final BitSet bits = new BitSet();
            for (int k = 0; k < indices.size(); ++k) {
                bits.set(indices.getQuick(k));
            }
            active[s] = bits;
        }
        return active;
    }

    /**
     * Scores every non-excluded feature of one family as a split candidate and feeds the
     * results into {@code best}.
     * <p>
     * A feature qualifies only if both sides of its split hold at least
     * {@code minSamplesPerLeaf} samples. Its score is the sum of squared differences between
     * each sample's target and the mean target of that sample's side.
     */
    private static void scoreSplits(
            final BitSet[] active,
            final TFloatArrayList targets,
            final BitSet excluded,
            final int numFeatures,
            final int minSamplesPerLeaf,
            final boolean aspatial,
            final BestSplit best) {
        final int numSamples = active.length;
        final double[] sumActive = new double[numFeatures];
        final int[] countActive = new int[numFeatures];
        final double[] sumInactive = new double[numFeatures];
        final int[] countInactive = new int[numFeatures];

        for (int s = 0; s < numSamples; ++s) {
            final float target = targets.getQuick(s);
            for (int f = 0; f < numFeatures; ++f) {
                if (excluded.get(f)) {
                    continue;
                }
                if (active[s].get(f)) {
                    sumActive[f] += target;
                    ++countActive[f];
                } else {
                    sumInactive[f] += target;
                    ++countInactive[f];
                }
            }
        }

        for (int f = 0; f < numFeatures; ++f) {
            // Excluded features have zero counts, so they are rejected here as well
            // (minSamplesPerLeaf is validated to be positive).
            if (countInactive[f] < minSamplesPerLeaf || countActive[f] < minSamplesPerLeaf) {
                continue;
            }
            final double meanActive = sumActive[f] / countActive[f];
            final double meanInactive = sumInactive[f] / countInactive[f];
            double sse = 0.0;
            for (int s = 0; s < numSamples; ++s) {
                final double error = targets.getQuick(s) - (active[s].get(f) ? meanActive : meanInactive);
                sse += error * error;
            }
            best.consider(sse, f, countActive[f], aspatial);
        }
    }

    /** Creates a leaf predicting the mean of the given (non-empty) targets. */
    private static DecisionTreeNode meanLeaf(final TFloatArrayList targets) {
        return new BinaryLeafNode(targets.sum() / targets.size());
    }
}
