package com.xiaomi.ai.nlp.ml.base;

/* loaded from: classes17.dex */
public class MLMath {
    /**
     * Margin used by {@link #logSumExp(double[])}: entries smaller than the maximum by at least
     * this amount are skipped. Each skipped term would contribute at most {@code exp(-30)}
     * (about 9.4e-14) relative to the maximum term, which is negligible in the log domain.
     */
    private static final double LOG_SUM_EXP_CUTOFF = 30.0d;

    /**
     * Returns the index of the largest element of {@code dArr}.
     *
     * <p>Ties are resolved in favour of the earliest index, because an element only replaces the
     * current best when it is strictly greater. Since every comparison with NaN is false, a NaN
     * entry after index 0 is never selected, and a NaN at index 0 is never replaced.
     *
     * @param dArr the vector to scan; must be non-null and non-empty
     * @return the index of the first maximum element
     * @throws IllegalArgumentException if {@code dArr} is null or empty
     */
    public static int argmax(double[] dArr) {
        if (dArr == null || dArr.length == 0) {
            throw new IllegalArgumentException("Vector x is null or empty");
        }
        int i = 0;
        for (int i2 = 1; i2 < dArr.length; i2++) {
            if (dArr[i] < dArr[i2]) {
                i = i2;
            }
        }
        return i;
    }

    /**
     * Computes the dot product {@code sum(dArr[i] * dArr2[i])} of two equal-length vectors.
     *
     * @param dArr the first vector
     * @param dArr2 the second vector; must have the same length as {@code dArr}
     * @return the dot product, or {@code 0.0} for two empty vectors
     * @throws IllegalArgumentException if either vector is null or their lengths differ
     */
    public static double dotProd(double[] dArr, double[] dArr2) {
        if (dArr == null || dArr2 == null) {
            throw new IllegalArgumentException("dotProd has null input vector");
        }
        if (dArr.length != dArr2.length) {
            throw new IllegalArgumentException("dotProd's input vector size isn't equal");
        }
        double d = 0.0d;
        for (int i = 0; i < dArr.length; i++) {
            d += dArr[i] * dArr2[i];
        }
        return d;
    }

    /**
     * Computes the L1 norm (sum of absolute values) of a vector.
     *
     * @param dArr the vector
     * @return the sum of {@code |dArr[i]|}, or {@code 0.0} for an empty vector
     * @throws IllegalArgumentException if {@code dArr} is null
     */
    public static double l1norm(double[] dArr) {
        if (dArr == null) {
            throw new IllegalArgumentException("vector v is null");
        }
        double d = 0.0d;
        for (double d2 : dArr) {
            d += Math.abs(d2);
        }
        return d;
    }

    /**
     * Computes {@code log(sum(exp(dArr[i])))} in a numerically stable way.
     *
     * <p>The maximum element {@code m} is factored out, so the result is
     * {@code m + log(1 + sum_{i != argmax} exp(dArr[i] - m))}. This avoids overflow for large
     * inputs. Entries that are not greater than {@code m - 30} are skipped, since their
     * contribution is negligible.
     *
     * @param dArr the log-domain values; must be non-null and non-empty
     * @return the log of the sum of exponentials of the entries
     * @throws IllegalArgumentException if {@code dArr} is null or empty
     */
    public static double logSumExp(double[] dArr) {
        if (dArr == null || dArr.length == 0) {
            throw new IllegalArgumentException("logSumExp has null or empty input vector");
        }
        // argmax picks the first maximum, so exactly one copy of the max is factored out below.
        int maxIndex = argmax(dArr);
        double max = dArr[maxIndex];
        double cutoff = max - LOG_SUM_EXP_CUTOFF;
        double tail = 0.0d;
        for (int i = 0; i < dArr.length; i++) {
            if (i != maxIndex && dArr[i] > cutoff) {
                tail += Math.exp(dArr[i] - max);
            }
        }
        // Every accepted term is exp(x) with x > -30, hence strictly positive, so tail > 0 exactly
        // when some term was added. log1p keeps full precision when tail is tiny.
        return tail > 0.0d ? max + Math.log1p(tail) : max;
    }

    /**
     * Writes the linear combination {@code dArr[i] * d + dArr2[i] * d2} into {@code dArr3}.
     *
     * <p>The computation is element-wise, so {@code dArr3} may be the same array as
     * {@code dArr} or {@code dArr2}.
     *
     * @param dArr the first input vector
     * @param d the scale applied to {@code dArr}
     * @param dArr2 the second input vector
     * @param d2 the scale applied to {@code dArr2}
     * @param dArr3 the output vector, overwritten in place
     * @throws IllegalArgumentException if any vector is null or the lengths differ
     */
    public static void plusTo(double[] dArr, double d, double[] dArr2, double d2, double[] dArr3) {
        if (dArr == null || dArr2 == null || dArr3 == null) {
            throw new IllegalArgumentException("plusTo has null input vector");
        }
        if (dArr.length != dArr2.length || dArr.length != dArr3.length) {
            throw new IllegalArgumentException("plusTo's input vector size isn't equal");
        }
        for (int i = 0; i < dArr.length; i++) {
            dArr3[i] = (dArr[i] * d) + (dArr2[i] * d2);
        }
    }

    /**
     * Writes {@code dArr} scaled by {@code d} into {@code dArr2}.
     *
     * <p>{@code dArr2} may be the same array as {@code dArr} for in-place scaling.
     *
     * @param dArr the input vector
     * @param d the scale factor
     * @param dArr2 the output vector, overwritten in place
     * @throws IllegalArgumentException if either vector is null or their lengths differ
     */
    public static void transformTo(double[] dArr, double d, double[] dArr2) {
        if (dArr == null || dArr2 == null) {
            throw new IllegalArgumentException("transformTo has null input vector");
        }
        if (dArr.length != dArr2.length) {
            throw new IllegalArgumentException("transformTo's input vector size isn't equal");
        }
        for (int i = 0; i < dArr.length; i++) {
            dArr2[i] = dArr[i] * d;
        }
    }
}
