/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.evaluator;

import ai.djl.modality.cv.MultiBoxTarget;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.training.evaluator.AbstractAccuracy;
import ai.djl.util.Pair;

/**
 * Accuracy evaluator for single-shot detection output.
 *
 * <p>The class labels that predictions are scored against are not taken directly from the
 * supplied labels; they are first produced by a {@link MultiBoxTarget} from the anchors, the
 * head of the label list and the class predictions.
 */
public class SingleShotDetectionAccuracy extends AbstractAccuracy {

    /** Target generator built with the default {@link MultiBoxTarget} builder settings. */
    private final MultiBoxTarget targetGenerator = MultiBoxTarget.builder().build();

    /**
     * Creates a new evaluator with the given name.
     *
     * @param name the name passed on to {@link AbstractAccuracy}
     */
    public SingleShotDetectionAccuracy(String name) {
        // NOTE(review): the second argument is presumably the axis used by the superclass;
        // confirm against AbstractAccuracy.
        super(name, 0);
    }

    /**
     * Counts how many generated class labels equal the arg-max of the class predictions.
     *
     * @param labels the label list; only its head element is used
     * @param predictions the prediction list; element 0 is read as the anchors and element 1 as
     *     the class predictions
     * @return a pair holding the number of elements in the generated class labels and an array
     *     with the count of positions where label and predicted class are equal
     */
    @Override
    protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions) {
        NDArray anchorBoxes = predictions.get(0);
        NDArray classScores = predictions.get(1);

        // Build the inputs for the target generator; the class scores are passed with their
        // last two axes swapped.
        NDArray groundTruth = labels.head();
        NDArray swappedScores = classScores.transpose(0, 2, 1);
        NDList targetInputs = new NDList(anchorBoxes, groundTruth, swappedScores);
        NDList generated = targetGenerator.target(targetInputs);

        // The third entry of the generated targets is used as the class labels.
        NDArray expectedClasses = generated.get(2);
        checkLabelShapes(expectedClasses, classScores);

        // Predicted class index = position of the maximum along the last axis.
        NDArray predictedClasses = classScores.argMax(-1);
        long sampleCount = expectedClasses.size();

        // Cast the labels to INT64 before the element-wise equality test, then count the hits.
        NDArray expectedAsLong = expectedClasses.toType(DataType.INT64, false);
        NDArray matches = expectedAsLong.eq(predictedClasses);
        NDArray correctCount = matches.countNonzero();

        return new Pair<>(sampleCount, correctCount);
    }
}

