package com.xiaomi.ai.nlp.ml.infer;

import com.xiaomi.ai.nlp.lm.util.Constant;
import com.xiaomi.ai.nlp.ml.base.MLMath;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: classes17.dex */
/**
 * Inference-only multinomial logistic regression (softmax) classifier over sparse,
 * string-keyed features.
 *
 * <p>Weights are loaded from a text model via {@link #load}; {@link #infer} then turns a sparse
 * feature vector into a probability distribution over the label set.
 *
 * <p>Not thread-safe: {@link #load} mutates plain {@link HashMap}s without synchronization.
 */
public class MultinomialLogisticRegression {
    /** Feature name to its per-label weights. */
    private final Map<String, LabelWeights> featureWeights = new HashMap<>();
    /** Label name to dense score index; assigned from the label set passed to {@link #load}. */
    private final Map<String, Integer> labelToIndex = new HashMap<>();
    /** Dense score index back to label name (inverse of {@link #labelToIndex}). */
    private final Map<Integer, String> indexToLabel = new HashMap<>();

    /**
     * Sparse weights of a single feature: label name to weight.
     *
     * <p>Kept as a non-static inner class for source compatibility with existing callers.
     */
    public class LabelWeights {
        private final Map<String, Double> weights = new HashMap<>();

        public LabelWeights() {
        }

        /** Sets the weight of label {@code str} to {@code d}, replacing any previous value. */
        public void put(String str, double d) {
            this.weights.put(str, d);
        }

        /**
         * Returns the live, mutable label-to-weight map (not a copy).
         *
         * <p>The misspelled name is kept for compatibility with existing callers.
         */
        public Map<String, Double> weigths() {
            return this.weights;
        }
    }

    /** A label together with its predicted probability; orders by descending probability. */
    public static class ProbInfo implements Comparable<ProbInfo> {
        private String label;
        private double prob;

        /** Orders higher probabilities first, so sorting yields the most likely label first. */
        @Override // java.lang.Comparable
        public int compareTo(ProbInfo probInfo) {
            return Double.compare(probInfo.prob, this.prob);
        }

        public String getLabel() {
            return this.label;
        }

        public double getProb() {
            return this.prob;
        }

        public void setLabel(String str) {
            this.label = str;
        }

        public void setProb(double d) {
            this.prob = d;
        }
    }

    /**
     * Scores a sparse feature vector and returns a softmax distribution over the labels.
     *
     * @param map feature name to feature value; features without loaded weights are ignored
     * @return one {@link ProbInfo} per label, sorted by descending probability
     * @throws IllegalArgumentException if a loaded weight refers to a label missing from the
     *     current label index
     */
    public List<ProbInfo> infer(Map<String, Double> map) {
        int size = this.labelToIndex.size();
        double[] scores = new double[size];
        for (Map.Entry<String, Double> feature : map.entrySet()) {
            LabelWeights labelWeights = this.featureWeights.get(feature.getKey());
            if (labelWeights == null) {
                continue; // feature unseen at load time contributes nothing
            }
            Double value = feature.getValue();
            for (Map.Entry<String, Double> weight : labelWeights.weights.entrySet()) {
                Integer index = this.labelToIndex.get(weight.getKey());
                if (index == null) {
                    throw new IllegalArgumentException(
                            "label set error, new label find: " + weight.getKey());
                }
                scores[index] += weight.getValue() * value;
            }
        }
        // Softmax normalization: p_i = exp(s_i - logSumExp(s)). Assumes MLMath.logSumExp
        // returns log(sum(exp(s))) — TODO confirm, including its behavior on an empty array.
        double logSumExp = MLMath.logSumExp(scores);
        List<ProbInfo> result = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            ProbInfo probInfo = new ProbInfo();
            probInfo.label = this.indexToLabel.get(i);
            probInfo.prob = Math.exp(scores[i] - logSumExp);
            result.add(probInfo);
        }
        Collections.sort(result);
        return result;
    }

    /**
     * Loads feature weights from a UTF-8 text model, then assigns label indices from {@code set}.
     *
     * <p>Each non-blank line has the form {@code feature label:weight [label:weight ...]}, with
     * tokens separated by {@code Constant.BLANK}. The stream is always closed, including when
     * parsing fails.
     *
     * <p>NOTE(review): repeated calls add to the feature weights and overwrite label indices
     * without clearing stale entries; confirm whether callers ever load more than once.
     *
     * @param inputStream model text; closed by this method
     * @param set the complete label set; its iteration order defines the label indices
     * @throws IllegalArgumentException if a line is malformed or names a label outside {@code set}
     * @throws NumberFormatException if a weight is not a valid double
     * @throws IllegalStateException if reading the stream fails (wraps the {@link IOException};
     *     {@code java.io.UncheckedIOException} is avoided because it needs Android API 24)
     */
    public void load(InputStream inputStream, Set<String> set) {
        try (BufferedReader reader =
                new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                parseLine(line.trim(), set);
            }
        } catch (IOException e) {
            throw new IllegalStateException("failed to read model stream", e);
        }
        int index = 0;
        for (String label : set) {
            this.labelToIndex.put(label, index);
            this.indexToLabel.put(index, label);
            index++;
        }
    }

    /** Parses one trimmed model line into {@link #featureWeights}; empty lines are skipped. */
    private void parseLine(String line, Set<String> labels) {
        if (line.isEmpty()) {
            return;
        }
        String[] tokens = line.split(Constant.BLANK);
        if (tokens.length < 2) {
            throw new IllegalArgumentException("feature weight not found: " + line);
        }
        LabelWeights labelWeights = new LabelWeights();
        for (int i = 1; i < tokens.length; i++) {
            String[] pair = tokens[i].split(":");
            if (pair.length != 2) {
                throw new IllegalArgumentException("label weight format error: " + tokens[i]);
            }
            String label = pair[0];
            if (!labels.contains(label)) {
                throw new IllegalArgumentException("label wasn't in label set: " + label);
            }
            labelWeights.put(label, parseWeight(pair[1]));
        }
        this.featureWeights.put(tokens[0], labelWeights);
    }

    /** Parses a weight, rethrowing failures with context while preserving the original cause. */
    private static double parseWeight(String text) {
        try {
            return Double.parseDouble(text);
        } catch (NumberFormatException e) {
            NumberFormatException wrapped =
                    new NumberFormatException("feature weight parse error: " + text);
            wrapped.initCause(e);
            throw wrapped;
        }
    }
}
