/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.parfor;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.io.Writable;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.util.LongAccumulator;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.DataPartitionerRemoteSparkMapper;
import org.apache.sysds.runtime.controlprogram.parfor.RemoteDPParForSparkWorker;
import org.apache.sysds.runtime.controlprogram.parfor.RemoteParForJobReturn;
import org.apache.sysds.runtime.controlprogram.parfor.RemoteParForUtils;
import org.apache.sysds.runtime.controlprogram.parfor.util.PairWritableBlock;
import org.apache.sysds.runtime.instructions.spark.data.DatasetObject;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.instructions.spark.utils.RDDConverterUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.Statistics;
import scala.Tuple2;

/**
 * Static entry point for running a data-partitioned parfor job on Spark.
 *
 * <p>The partitioned input is built by {@link #getPartitionedInput}, optionally
 * grouped by partition key, and then handed per Spark partition to a
 * {@link RemoteDPParForSparkWorker}; the collected worker outputs are
 * converted into the job return value.
 */
public class RemoteDPParForSpark {
    protected static final Log LOG = LogFactory.getLog(RemoteDPParForSpark.class.getName());

    /**
     * Runs the data-partitioned parfor job and collects its results.
     *
     * @param pfid parfor id (not read by this method)
     * @param itervar name of the iteration variable, forwarded to the worker
     * @param matrixvar name of the partitioned input matrix variable
     * @param program serialized program, forwarded to the worker
     * @param clsMap class map, forwarded to the worker
     * @param resultFile result file name (not read by this method)
     * @param input input matrix object (not read by this method; the matrix is
     *        looked up via {@code matrixvar} in the execution context)
     * @param ec execution context; must be a {@link SparkExecutionContext}
     * @param dpf data partition format
     * @param fmt file format, forwarded to the worker and the partitioner
     * @param tSparseCol forwarded to the worker
     * @param enableCPCaching forwarded to the worker
     * @param numReducers lower bound for the number of partitions used when grouping
     * @return job return holding the task/iteration accumulator values and the collected results
     */
    public static RemoteParForJobReturn runJob(long pfid, String itervar, String matrixvar, String program, HashMap<String, byte[]> clsMap, String resultFile, MatrixObject input, ExecutionContext ec, ParForProgramBlock.PartitionFormat dpf, Types.FileFormat fmt, boolean tSparseCol, boolean enableCPCaching, int numReducers) {
        String jobname = "ParFor-DPESP";
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;

        SparkExecutionContext sec = (SparkExecutionContext) ec;
        JavaSparkContext sc = sec.getSparkContext();
        boolean isLocal = sc.isLocal();

        // input matrix, its characteristics, and its binary-block RDD
        MatrixObject mo = sec.getMatrixObject(matrixvar);
        DataCharacteristics mc = mo.getDataCharacteristics();
        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable(matrixvar);

        // accumulators read back after the job to report task and iteration counts
        LongAccumulator aTasks = sc.sc().longAccumulator("tasks");
        LongAccumulator aIters = sc.sc().longAccumulator("iterations");

        // Take the minimum in long arithmetic and narrow afterwards: the result is
        // bounded by numParts, so the cast cannot overflow even for huge partition counts.
        int numParts = SparkUtils.getNumPreferredPartitions(mc, in);
        int numReducers2 = Math.max(numReducers, (int) Math.min(numParts, dpf.getNumParts(mc)));

        RemoteDPParForSparkWorker efun = new RemoteDPParForSparkWorker(program, isLocal, clsMap, matrixvar, itervar, enableCPCaching, mc, tSparseCol, dpf, fmt, aTasks, aIters);
        JavaPairRDD<Long, Writable> tmp = getPartitionedInput(sec, matrixvar, fmt, dpf);

        // group by partition key only if required; otherwise wrap each value in a
        // single-element iterable so the worker sees the same input shape
        List<Tuple2<Long, String>> out = (requiresGrouping(dpf, mo)
                ? tmp.groupByKey(numReducers2)
                : tmp.map(new PseudoGrouping()))
            .mapPartitionsToPair(efun)
            .collect();

        LocalVariableMap[] results = RemoteParForUtils.getResults(out, LOG);
        int numTasks = aTasks.value().intValue();
        int numIters = aIters.value().intValue();
        RemoteParForJobReturn ret = new RemoteParForJobReturn(true, numTasks, numIters, results);

        Statistics.incrementNoOfCompiledSPInst();
        Statistics.incrementNoOfExecutedSPInst();
        if (DMLScript.STATISTICS) {
            Statistics.maintainCPHeavyHitters(jobname, System.nanoTime() - t0);
        }
        return ret;
    }

    /**
     * Builds the RDD of (partition key, partition data) pairs for the given matrix variable.
     *
     * <p>If the matrix is directly backed by a dataset (see {@link #hasInputDataSet}),
     * each dataset row is converted into a single-row block. Otherwise the
     * binary-block RDD is partitioned via {@link DataPartitionerRemoteSparkMapper}.
     *
     * @param sec Spark execution context
     * @param matrixvar name of the matrix variable
     * @param fmt file format passed to the partitioner
     * @param dpf data partition format
     * @return pair RDD keyed by a long partition key
     */
    @SuppressWarnings("unchecked") // cast of the lineage-child RDD to its binary-block type
    private static JavaPairRDD<Long, Writable> getPartitionedInput(SparkExecutionContext sec, String matrixvar, Types.FileFormat fmt, ParForProgramBlock.PartitionFormat dpf) {
        MatrixObject mo = sec.getMatrixObject(matrixvar);
        DataCharacteristics mc = mo.getDataCharacteristics();

        // dataset input: convert rows directly, bypassing the binary-block RDD
        if (hasInputDataSet(dpf, mo)) {
            DatasetObject dsObj = (DatasetObject) mo.getRDDHandle().getLineageChilds().get(0).getLineageChilds().get(0);
            Dataset<Row> ds = dsObj.getDataset();
            // pair every row with an index: taken from the "__INDEX" column if
            // the dataset contains one, otherwise assigned by zipWithIndex
            JavaPairRDD<Row, Long> prepinput = dsObj.containsID()
                ? ds.javaRDD().mapToPair(new RDDConverterUtils.DataFrameExtractIDFunction(ds.schema().fieldIndex("__INDEX")))
                : ds.javaRDD().zipWithIndex();
            // argument order must match the constructor: (clen, containsID, isVector)
            return prepinput.mapToPair(new DataFrameToRowBinaryBlockFunction(mc.getCols(), dsObj.containsID(), dsObj.isVectorBased()));
        }

        // binary-block input
        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable(matrixvar);
        // with grouping, if the handle is a checkpoint that is not cached,
        // read from the checkpoint's lineage child instead
        if (requiresGrouping(dpf, mo) && mo.getRDDHandle().isCheckpointRDD() && !sec.isRDDCached(in.id())) {
            in = (JavaPairRDD<MatrixIndexes, MatrixBlock>) ((RDDObject) mo.getRDDHandle().getLineageChilds().get(0)).getRDD();
        }
        DataPartitionerRemoteSparkMapper dpfun = new DataPartitionerRemoteSparkMapper(mc, fmt, dpf._dpf, dpf._N);
        return in.flatMapToPair(dpfun);
    }

    /**
     * Indicates whether partitions must be assembled by a group-by-key.
     *
     * <p>True if the partition format is row-wise / row-block-wise with more than one
     * column block, or column-wise / column-block-wise with more than one row block,
     * and the input is not a dataset.
     */
    private static boolean requiresGrouping(ParForProgramBlock.PartitionFormat dpf, MatrixObject mo) {
        DataCharacteristics mc = mo.getDataCharacteristics();
        boolean multiColBlocks = mc.getNumColBlocks() > 1L;
        boolean multiRowBlocks = mc.getNumRowBlocks() > 1L;
        boolean spansBlocks =
            (dpf == ParForProgramBlock.PartitionFormat.ROW_WISE && multiColBlocks)
            || (dpf == ParForProgramBlock.PartitionFormat.COLUMN_WISE && multiRowBlocks)
            || (dpf._dpf == ParForProgramBlock.PDataPartitionFormat.ROW_BLOCK_WISE_N && multiColBlocks)
            || (dpf._dpf == ParForProgramBlock.PDataPartitionFormat.COLUMN_BLOCK_WISE_N && multiRowBlocks);
        return spansBlocks && !hasInputDataSet(dpf, mo);
    }

    /**
     * Indicates whether the matrix is a row-wise partitioned checkpoint whose RDD handle
     * has exactly one lineage child, which in turn has exactly one lineage child that is
     * a {@link DatasetObject}.
     */
    private static boolean hasInputDataSet(ParForProgramBlock.PartitionFormat dpf, MatrixObject mo) {
        return dpf == ParForProgramBlock.PartitionFormat.ROW_WISE
            && mo.getRDDHandle().isCheckpointRDD()
            && mo.getRDDHandle().getLineageChilds().size() == 1
            && mo.getRDDHandle().getLineageChilds().get(0).getLineageChilds().size() == 1
            && mo.getRDDHandle().getLineageChilds().get(0).getLineageChilds().get(0) instanceof DatasetObject;
    }

    /**
     * Converts a (row, zero-based index) pair into a (one-based row index, single-row block) pair.
     */
    private static class DataFrameToRowBinaryBlockFunction
    implements PairFunction<Tuple2<Row, Long>, Long, Writable> {
        private static final long serialVersionUID = -3162404379379461523L;
        private final long _clen;
        private final boolean _containsID;
        private final boolean _isVector;

        /**
         * @param clen number of columns of the produced row block
         * @param containsID whether the first row field is an ID column to be skipped
         * @param isVector whether the row data is a single vector field
         */
        public DataFrameToRowBinaryBlockFunction(long clen, boolean containsID, boolean isVector) {
            this._clen = clen;
            this._containsID = containsID;
            this._isVector = isVector;
        }

        @Override
        public Tuple2<Long, Writable> call(Tuple2<Row, Long> arg0) throws Exception {
            long rowix = arg0._2() + 1L; // emitted key is one-based
            int off = _containsID ? 1 : 0; // skip the ID field if present
            Object obj = _isVector ? arg0._1().get(off) : arg0._1();
            boolean sparse = obj instanceof SparseVector;
            MatrixBlock mb = new MatrixBlock(1, (int) _clen, sparse);

            if (_isVector) {
                Vector vect = (Vector) obj;
                if (vect instanceof SparseVector) {
                    // Iterate over all stored entries rather than numNonzeros(): the
                    // stored arrays may contain explicit zeros, in which case bounding
                    // the loop by the non-zero count would drop trailing entries.
                    SparseVector svect = (SparseVector) vect;
                    int[] indices = svect.indices();
                    double[] values = svect.values();
                    for (int k = 0; k < indices.length; k++) {
                        if (values[k] != 0) {
                            mb.appendValue(0, indices[k], values[k]);
                        }
                    }
                } else {
                    for (int j = 0; j < _clen; j++) {
                        mb.appendValue(0, j, vect.apply(j));
                    }
                }
            } else {
                Row row = (Row) obj;
                for (int j = off; j < off + _clen; j++) {
                    mb.appendValue(0, j - off, UtilFunctions.getDouble(row.get(j)));
                }
            }
            mb.examSparsity();
            return new Tuple2<Long, Writable>(rowix, new PairWritableBlock(new MatrixIndexes(1L, 1L), mb));
        }
    }

    /**
     * Wraps each value into a single-element iterable, producing the same element
     * shape as a group-by-key without performing one.
     */
    private static class PseudoGrouping
    implements Function<Tuple2<Long, Writable>, Tuple2<Long, Iterable<Writable>>> {
        private static final long serialVersionUID = 2016614593596923995L;

        private PseudoGrouping() {
        }

        @Override
        public Tuple2<Long, Iterable<Writable>> call(Tuple2<Long, Writable> arg0) {
            return new Tuple2<Long, Iterable<Writable>>(arg0._1(), Collections.singletonList(arg0._2()));
        }
    }
}

