package com.xiaomi.ai.nlp.optimization;

import com.xiaomi.ai.nlp.loss.DiffFunction;
import com.xiaomi.ai.nlp.loss.L1RegOwlDiffFunction;
import com.xiaomi.ai.nlp.utils.MLMath;

/* loaded from: classes17.dex */
/**
 * Iterative minimizer for an {@link L1RegOwlDiffFunction}, driven by a
 * {@link NewtonDirection} (search direction) and a {@link LineSearch}
 * (step selection).
 *
 * <p>Each iteration computes a direction from the negated derivative, zeroes
 * every direction component whose sign disagrees with that negated
 * derivative, asks the line search for the next point, and feeds the
 * resulting point/loss-derivative differences back into the
 * {@link NewtonDirection}.
 */
public class OWLQNMinimizer implements Miniminizer {
    private final NewtonDirection newtonDirection;
    private final LineSearch lineSearch;

    /**
     * Creates a minimizer from its two collaborators.
     *
     * @param newtonDirection supplies search directions and receives the per-iteration updates
     * @param lineSearch      picks the next point along a direction
     * @throws IllegalArgumentException if either argument is {@code null}
     */
    public OWLQNMinimizer(NewtonDirection newtonDirection, LineSearch lineSearch) {
        if (newtonDirection == null) {
            throw new IllegalArgumentException("newton direction is null");
        }
        if (lineSearch == null) {
            throw new IllegalArgumentException("line search is null");
        }
        this.lineSearch = lineSearch;
        this.newtonDirection = newtonDirection;
    }

    /**
     * Zeroes, in place, every component of {@code direction} whose sign
     * (as reported by {@link Math#signum(double)}) differs from the sign of
     * the matching component of {@code reference}.
     *
     * @param reference array whose signs are authoritative; not modified
     * @param direction array that is modified in place
     */
    private static void zeroSignMismatches(double[] reference, double[] direction) {
        final int size = direction.length;
        for (int k = 0; k < size; k++) {
            if (Math.signum(direction[k]) == Math.signum(reference[k])) {
                continue;
            }
            // Signs disagree (or one side is NaN): drop this component.
            direction[k] = 0.0d;
        }
    }

    /**
     * Runs at most {@code i} iterations starting from a copy of {@code dArr}.
     *
     * <p>The loop stops early as soon as the value at the candidate point plus
     * {@code d} exceeds the value at the current point; in that case the
     * candidate is discarded and the current point is returned.
     *
     * @param diffFunction function to minimize; must be an {@link L1RegOwlDiffFunction}
     * @param d            minimum decrease required to accept a candidate point
     * @param dArr         starting point; it is copied, never modified
     * @param i            maximum number of iterations
     * @return a new array holding the last accepted point
     * @throws IllegalArgumentException if {@code diffFunction} is {@code null} or not an
     *                                  {@link L1RegOwlDiffFunction}, if {@code dArr} is
     *                                  {@code null}, or if its length differs from the
     *                                  dimension reported by the newton direction
     */
    @Override // com.xiaomi.ai.nlp.optimization.Miniminizer
    public double[] minimize(DiffFunction diffFunction, double d, double[] dArr, int i) {
        if (diffFunction == null) {
            throw new IllegalArgumentException("diff function is null");
        }
        if (!(diffFunction instanceof L1RegOwlDiffFunction)) {
            throw new IllegalArgumentException("diff function type is not L1RegOwlDiffFunction");
        }
        if (dArr == null) {
            throw new IllegalArgumentException("initial point is null");
        }
        if (dArr.length != this.newtonDirection.getDimension()) {
            throw new IllegalArgumentException("initial point dimension isn't equal to newton direction dimension");
        }

        final int dimension = dArr.length;
        final L1RegOwlDiffFunction owlFunction = (L1RegOwlDiffFunction) diffFunction;

        // Work on a private copy so the caller's starting point stays untouched.
        final double[] current = dArr.clone();
        // Scratch buffers reused across iterations for the two difference vectors.
        final double[] pointDelta = new double[dimension];
        final double[] lossGradientDelta = new double[dimension];

        for (int completed = 0; completed < i; completed++) {
            final double[] negatedDerivative = diffFunction.derivativeAt(current);
            // Scales the derivative by -1 into the same array (presumably an in-place negation).
            MLMath.transformTo(negatedDerivative, -1.0d, negatedDerivative);

            final double currentValue = diffFunction.valueAt(current);
            final double[] currentLossGradient = owlFunction.lossDerivativeAt(current);

            final double[] direction = this.newtonDirection.computeDirection(negatedDerivative);
            zeroSignMismatches(negatedDerivative, direction);

            final double[] candidate = this.lineSearch.nextX(current, direction, diffFunction, true);
            final double candidateValue = diffFunction.valueAt(candidate);
            if (candidateValue + d > currentValue) {
                // Improvement is smaller than the tolerance: keep the current point and stop.
                break;
            }

            final double[] candidateLossGradient = owlFunction.lossDerivativeAt(candidate);
            // Presumably writes (candidate - current) and the matching loss-derivative
            // difference into the scratch buffers.
            MLMath.plusTo(candidate, 1.0d, current, -1.0d, pointDelta);
            MLMath.plusTo(candidateLossGradient, 1.0d, currentLossGradient, -1.0d, lossGradientDelta);
            this.newtonDirection.updateSYRho(pointDelta, lossGradientDelta);

            // Accept the candidate by copying it, so no reference to the line search's array is kept.
            System.arraycopy(candidate, 0, current, 0, dimension);
            System.out.println("iter: " + (completed + 1) + " loss: " + candidateValue);
        }
        return current;
    }
}
