/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.estim.EstimatorMatrixHistogram;
import org.apache.sysds.hops.estim.MMNode;
import org.apache.sysds.hops.estim.SparsityEstimator;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysds.hops.rewrite.RewriteMatrixMultChainOptimization;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

public class RewriteMatrixMultChainOptimizationSparse
extends RewriteMatrixMultChainOptimization {
    @Override
    protected void optimizeMMChain(Hop hop, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators, ProgramRewriteStatus state) {
        double[] dimsArray = new double[mmChain.size() + 1];
        boolean dimsKnown = RewriteMatrixMultChainOptimizationSparse.getDimsArray(hop, mmChain, dimsArray);
        MMNode[] sketchArray = new MMNode[mmChain.size() + 1];
        boolean inputsAvail = RewriteMatrixMultChainOptimizationSparse.getInputMatrices(hop, mmChain, sketchArray, state);
        if (dimsKnown && inputsAvail) {
            RewriteMatrixMultChainOptimizationSparse.clearLinksWithinChain(hop, mmOperators);
            int size = mmChain.size();
            int[][] split = RewriteMatrixMultChainOptimizationSparse.mmChainDPSparse(dimsArray, sketchArray, mmChain.size());
            LOG.trace((Object)"Optimal MM Chain: ");
            this.mmChainRelinkHops(mmOperators.get(0), 0, size - 1, mmChain, mmOperators, new MutableInt(1), split, 1);
        }
    }

    private static int[][] mmChainDPSparse(double[] dimArray, MMNode[] sketchArray, int size) {
        double[][] dpMatrix = new double[size][size];
        MMNode[][] dpMatrixS = new MMNode[size][size];
        int[][] split = new int[size][size];
        for (int i = 0; i < size; ++i) {
            Arrays.fill(dpMatrix[i], 0.0);
            Arrays.fill(split[i], -1);
            dpMatrixS[i][i] = sketchArray[i];
        }
        EstimatorMatrixHistogram estim = new EstimatorMatrixHistogram(true);
        for (int l = 2; l <= size; ++l) {
            for (int i = 0; i < size - l + 1; ++i) {
                int j = i + l - 1;
                dpMatrix[i][j] = Double.MAX_VALUE;
                for (int k = i; k <= j - 1; ++k) {
                    MMNode tmp = new MMNode(dpMatrixS[i][k], dpMatrixS[k + 1][j], SparsityEstimator.OpCode.MM);
                    estim.estim(tmp, false);
                    EstimatorMatrixHistogram.MatrixHistogram lhs = (EstimatorMatrixHistogram.MatrixHistogram)dpMatrixS[i][k].getSynopsis();
                    EstimatorMatrixHistogram.MatrixHistogram rhs = (EstimatorMatrixHistogram.MatrixHistogram)dpMatrixS[k + 1][j].getSynopsis();
                    double cost = dpMatrix[i][k] + dpMatrix[k + 1][j] + RewriteMatrixMultChainOptimizationSparse.dotProduct(lhs.getColCounts(), rhs.getRowCounts());
                    if (!(cost < dpMatrix[i][j])) continue;
                    dpMatrix[i][j] = cost;
                    dpMatrixS[i][j] = tmp;
                    split[i][j] = k;
                }
                if (!LOG.isTraceEnabled()) continue;
                LOG.trace((Object)("mmchainopt [i=" + (i + 1) + ",j=" + (j + 1) + "]: costs = " + dpMatrix[i][j] + ", split = " + (split[i][j] + 1)));
            }
        }
        return split;
    }

    private static boolean getInputMatrices(Hop hop, ArrayList<Hop> chain, MMNode[] sketchArray, ProgramRewriteStatus state) {
        boolean inputsAvail = true;
        LocalVariableMap vars = state.getVariables();
        for (int i = 0; i < chain.size() && (inputsAvail &= HopRewriteUtils.isData(chain.get(0), Types.OpOpData.TRANSIENTREAD)); ++i) {
            sketchArray[i] = new MMNode(RewriteMatrixMultChainOptimizationSparse.getMatrix(chain.get(i).getName(), vars));
        }
        return inputsAvail;
    }

    private static MatrixBlock getMatrix(String name, LocalVariableMap vars) {
        Data dat = vars.get(name);
        if (!(dat instanceof MatrixObject)) {
            throw new HopsException("Input '" + name + "' not a matrix: " + dat.getDataType());
        }
        return (MatrixBlock)((MatrixObject)dat).acquireReadAndRelease();
    }

    private static double dotProduct(int[] h1cNnz, int[] h2rNnz) {
        long fp = 0L;
        for (int j = 0; j < h1cNnz.length; ++j) {
            fp += (long)h1cNnz[j] * (long)h2rNnz[j];
        }
        return fp;
    }
}

