/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.codegen.template;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataGenOp;
import org.apache.sysds.hops.DnnOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.NaryOp;
import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.cplan.CNodeNary;
import org.apache.sysds.hops.codegen.cplan.CNodeRow;
import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable;
import org.apache.sysds.hops.codegen.template.TemplateBase;
import org.apache.sysds.hops.codegen.template.TemplateCell;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.Pair;

public class TemplateRow
extends TemplateBase {
    /** Aggregation operators accepted by this template wherever it checks {@code HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)}. */
    private static final Types.AggOp[] SUPPORTED_ROW_AGG = new Types.AggOp[]{Types.AggOp.SUM, Types.AggOp.MIN, Types.AggOp.MAX, Types.AggOp.MEAN};
    /** Unary operators for which the cplan construction requests a {@code VECT_<op>} unary type; any other unary matrix operation is rejected there with a RuntimeException. */
    private static final Types.OpOp1[] SUPPORTED_VECT_UNARY = new Types.OpOp1[]{Types.OpOp1.EXP, Types.OpOp1.SQRT, Types.OpOp1.LOG, Types.OpOp1.ABS, Types.OpOp1.ROUND, Types.OpOp1.CEIL, Types.OpOp1.FLOOR, Types.OpOp1.SIGN, Types.OpOp1.SIN, Types.OpOp1.COS, Types.OpOp1.TAN, Types.OpOp1.ASIN, Types.OpOp1.ACOS, Types.OpOp1.ATAN, Types.OpOp1.SINH, Types.OpOp1.COSH, Types.OpOp1.TANH, Types.OpOp1.CUMSUM, Types.OpOp1.CUMMIN, Types.OpOp1.CUMMAX, Types.OpOp1.SPROP, Types.OpOp1.SIGMOID};
    /** Binary operators for which the cplan construction requests a {@code VECT_<op>} or {@code VECT_<op>_SCALAR} binary type; any other binary matrix operation is rejected there with a RuntimeException. */
    private static final Types.OpOp2[] SUPPORTED_VECT_BINARY = new Types.OpOp2[]{Types.OpOp2.MULT, Types.OpOp2.DIV, Types.OpOp2.MINUS, Types.OpOp2.PLUS, Types.OpOp2.POW, Types.OpOp2.MIN, Types.OpOp2.MAX, Types.OpOp2.XOR, Types.OpOp2.EQUAL, Types.OpOp2.NOTEQUAL, Types.OpOp2.LESS, Types.OpOp2.LESSEQUAL, Types.OpOp2.GREATER, Types.OpOp2.GREATEREQUAL, Types.OpOp2.BITWAND};

    /**
     * Creates a row template; the template type passed to the base class is
     * {@code TemplateType.ROW} and no close type is supplied.
     */
    public TemplateRow() {
        super(TemplateBase.TemplateType.ROW);
    }

    /**
     * Creates a row template with an explicit close type.
     *
     * @param ctype close type forwarded unchanged to the base-class constructor
     */
    public TemplateRow(TemplateBase.CloseType ctype) {
        super(TemplateBase.TemplateType.ROW, ctype);
    }

    /**
     * Decides whether a row template can be opened at the given hop.
     *
     * <p>The hop qualifies as soon as any one of the patterns below matches; the
     * patterns are tested in a fixed order and the first match wins.
     *
     * @param hop candidate operator
     * @return {@code true} if one of the supported opening patterns matches
     */
    @Override
    public boolean open(Hop hop) {
        // binary operation of known size whose first input has several rows and columns
        if (hop instanceof BinaryOp && hop.dimsKnown() && TemplateRow.isValidBinaryOperation(hop)) {
            Hop lhs = hop.getInput().get(0);
            if (lhs.getDim1() > 1L && lhs.getDim2() > 1L) {
                return true;
            }
        }

        // unary or parameterized builtin accepted by the cell template, with more than one row
        if ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp)
            && TemplateCell.isValidOperation(hop)
            && hop.getDim1() > 1L) {
            return true;
        }

        // ternary +* and -*
        if (HopRewriteUtils.isTernary(hop, Types.OpOp3.PLUS_MULT, Types.OpOp3.MINUS_MULT)) {
            return true;
        }

        // binary / n-ary cbind
        if (TemplateRow.isValidBinaryNaryCBind(hop)) {
            return true;
        }

        // n-ary min / max / plus producing a matrix
        if (HopRewriteUtils.isNary(hop, Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS) && hop.isMatrix()) {
            return true;
        }

        // matrix multiplications of known size
        if (hop instanceof AggBinaryOp && hop.dimsKnown()) {
            Hop lhs = hop.getInput().get(0);
            // single-column result
            if (hop.getDim2() == 1L && lhs.getDim1() > 1L && lhs.getDim2() > 1L) {
                return true;
            }
            // skinny right-hand side, excluding outer-product-like multiplications
            Hop rhs = hop.getInput().get(1);
            if (LibMatrixMult.isSkinnyRightHandSide(lhs.getDim1(), lhs.getDim2(), rhs.getDim1(), rhs.getDim2(), false)
                && lhs.getDim1() > 1L && lhs.getDim2() > 1L
                && !HopRewriteUtils.isOuterProductLikeMM(hop)) {
                return true;
            }
        }

        // transpose that is the sole, first input of a fusable skinny matrix multiplication
        if (HopRewriteUtils.isTransposeOperation(hop) && hop.getParent().size() == 1) {
            Hop consumer = hop.getParent().get(0);
            if (consumer instanceof AggBinaryOp
                && consumer.dimsKnown()
                && consumer.getInput().indexOf(hop) == 0
                && TemplateRow.isFuseSkinnyMatrixMult(consumer)) {
                return true;
            }
        }

        // row or column aggregate (not a full aggregate) over a multi-row, multi-column input
        if (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection() != Types.Direction.RowCol) {
            Hop in = hop.getInput().get(0);
            if (in.getDim1() > 1L && in.getDim2() > 1L
                && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)) {
                return true;
            }
        }

        // column range indexing over a multi-row input
        if (hop instanceof IndexingOp) {
            Hop in = hop.getInput().get(0);
            if (in.getDim1() > 1L && in.getDim2() >= 0L
                && HopRewriteUtils.isColumnRangeIndexing((IndexingOp)hop)) {
                return true;
            }
        }

        // bias add / bias multiply with known input sizes
        if (HopRewriteUtils.isDnn(hop, Types.OpOpDnn.BIASADD, Types.OpOpDnn.BIASMULT)
            && hop.getInput().get(0).dimsKnown()
            && hop.getInput().get(1).dimsKnown()
            && hop.getInput().get(0).getDim2() > 1L) {
            return true;
        }

        // pooling / conv2d restricted to what isStride1Pad0() accepts
        return HopRewriteUtils.isDnn(hop, Types.OpOpDnn.MAX_POOL, Types.OpOpDnn.AVG_POOL, Types.OpOpDnn.CONV2D)
            && hop.getInput().get(0).dimsKnown()
            && ((DnnOp)hop).isStride1Pad0()
            && hop.getInput().get(1).dimsKnown();
    }

    /**
     * Decides whether the given hop can be fused into an already open row template.
     *
     * <p>Nothing is fused once the template is closed. Otherwise the hop qualifies
     * as soon as one of the patterns below matches, tested in a fixed order.
     *
     * @param hop   candidate operator to add to the template
     * @param input the input of {@code hop} through which the template is extended;
     *              only consulted for the pooling / conv2d pattern
     * @return {@code true} if the hop can be fused
     */
    @Override
    public boolean fuse(Hop hop, Hop input) {
        if (this.isClosed()) {
            return false;
        }

        // supported binary operation
        if (hop instanceof BinaryOp && TemplateRow.isValidBinaryOperation(hop)) {
            return true;
        }

        // binary / n-ary cbind
        if (TemplateRow.isValidBinaryNaryCBind(hop)) {
            return true;
        }

        // n-ary min / max / plus producing a matrix
        if (HopRewriteUtils.isNary(hop, Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS) && hop.isMatrix()) {
            return true;
        }

        // unary or parameterized builtin accepted by the cell template
        if ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp) && TemplateCell.isValidOperation(hop)) {
            return true;
        }

        // ternary +* and -*
        if (HopRewriteUtils.isTernary(hop, Types.OpOp3.PLUS_MULT, Types.OpOp3.MINUS_MULT)) {
            return true;
        }

        if (hop instanceof AggUnaryOp) {
            AggUnaryOp agg = (AggUnaryOp)hop;
            if (agg.getDirection() != Types.Direction.RowCol) {
                // row or column aggregate with a supported operator
                if (HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)) {
                    return true;
                }
            } else if (agg.getOp() == Types.AggOp.SUM) {
                // full sum
                return true;
            }
        }

        if (hop instanceof AggBinaryOp) {
            // multi-row, single-column product with a transposed left input
            if (hop.getDim1() > 1L && hop.getDim2() == 1L
                && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) {
                return true;
            }
            // fusable skinny product with a transposed, multi-row, multi-column left input
            if (hop.dimsKnown() && TemplateRow.isFuseSkinnyMatrixMult(hop)
                && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))
                && hop.getInput().get(0).getDim1() > 1L
                && hop.getInput().get(0).getDim2() > 1L) {
                return true;
            }
        }

        // bias add / bias multiply with known input sizes
        if (HopRewriteUtils.isDnn(hop, Types.OpOpDnn.BIASADD, Types.OpOpDnn.BIASMULT)
            && hop.getInput().get(0).dimsKnown()
            && hop.getInput().get(1).dimsKnown()
            && hop.getInput().get(0).getDim2() > 1L) {
            return true;
        }

        // pooling / conv2d, but not when the template would be extended via the second input
        if (HopRewriteUtils.isDnn(hop, Types.OpOpDnn.MAX_POOL, Types.OpOpDnn.AVG_POOL, Types.OpOpDnn.CONV2D)
            && hop.getInput().get(0).dimsKnown()
            && ((DnnOp)hop).isStride1Pad0()
            && hop.getInput().get(1).dimsKnown()
            && hop.getInput().get(1) != input) {
            return true;
        }

        // transpose / cumulative-aggregate chains and transpose-matrix-multiply chains
        if (TemplateRow.isPartOfValidCumAggChain(hop)) {
            return true;
        }
        return TemplateRow.isPartOfValidTransposeMMChain(hop);
    }

    /**
     * Decides whether the plan rooted at {@code input} can be merged into this
     * row template at {@code hop}.
     *
     * <p>Nothing is merged once the template is closed. Otherwise the merge is
     * allowed as soon as one of the patterns below matches, tested in a fixed order.
     *
     * @param hop   operator that belongs to this template
     * @param input input of {@code hop} whose plan is a merge candidate
     * @return {@code true} if merging is allowed
     */
    @Override
    public boolean merge(Hop hop, Hop input) {
        if (this.isClosed()) {
            return false;
        }

        // supported binary operation where both the hop and the merged input have several rows
        if (hop instanceof BinaryOp && TemplateRow.isValidBinaryOperation(hop)
            && hop.getDim1() > 1L && input.getDim1() > 1L) {
            return true;
        }

        // binary / n-ary cbind
        if (TemplateRow.isValidBinaryNaryCBind(hop)) {
            return true;
        }

        // n-ary min / max / plus producing a matrix
        if (HopRewriteUtils.isNary(hop, Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS) && hop.isMatrix()) {
            return true;
        }

        // bias add / bias multiply with known input sizes
        if (HopRewriteUtils.isDnn(hop, Types.OpOpDnn.BIASADD, Types.OpOpDnn.BIASMULT)
            && hop.getInput().get(0).dimsKnown()
            && hop.getInput().get(1).dimsKnown()
            && hop.getInput().get(0).getDim2() > 1L) {
            return true;
        }

        // pooling / conv2d, but never merging the second input
        if (HopRewriteUtils.isDnn(hop, Types.OpOpDnn.MAX_POOL, Types.OpOpDnn.AVG_POOL, Types.OpOpDnn.CONV2D)
            && hop.getInput().get(0).dimsKnown()
            && ((DnnOp)hop).isStride1Pad0()
            && hop.getInput().get(1).dimsKnown()
            && hop.getInput().get(1) != input) {
            return true;
        }

        // seq() with literal inputs that only feeds unary / binary parents
        if (HopRewriteUtils.isDataGenOpWithLiteralInputs(input, Types.OpOpDG.SEQ)
            && HopRewriteUtils.hasOnlyUnaryBinaryParents(input, false)) {
            return true;
        }

        // matrix multiplication with transposed left input
        if (hop instanceof AggBinaryOp && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) {
            // merged input has a single column
            if (input.getDim2() == 1L) {
                return true;
            }
            // merged input is the right operand and contains the un-transposed left operand
            return input == hop.getInput().get(1)
                && HopRewriteUtils.containsInput(input, hop.getInput().get(0).getInput().get(0));
        }
        return false;
    }

    /**
     * Determines the close status of the template after the given hop.
     *
     * @param hop operator just added to the template
     * @return {@code CLOSED_VALID} for a unary aggregate whose direction is not
     *         {@code Row} or for a matrix multiplication with a transposed left
     *         input; {@code OPEN_INVALID} for a transpose; {@code OPEN_VALID} otherwise
     */
    @Override
    public TemplateBase.CloseType close(Hop hop) {
        // any unary aggregate that is not a pure row aggregate closes the template
        if (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection() != Types.Direction.Row) {
            return TemplateBase.CloseType.CLOSED_VALID;
        }
        // so does a matrix multiplication whose left operand is a transpose
        if (hop instanceof AggBinaryOp && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) {
            return TemplateBase.CloseType.CLOSED_VALID;
        }
        // a transpose keeps the template open but is not a valid root on its own
        return HopRewriteUtils.isTransposeOperation(hop)
            ? TemplateBase.CloseType.OPEN_INVALID
            : TemplateBase.CloseType.OPEN_VALID;
    }

    /**
     * Checks whether the given hop is accepted as a binary operation by this template.
     * The decision is delegated entirely to {@code TemplateUtils.isOperationSupported}.
     *
     * @param hop candidate operator
     * @return the result of {@code TemplateUtils.isOperationSupported(hop)}
     */
    private static boolean isValidBinaryOperation(Hop hop) {
        return TemplateUtils.isOperationSupported(hop);
    }

    /**
     * Checks whether the hop is a binary or n-ary cbind this template accepts:
     * its first input must be a matrix with more than one row and the size of
     * the hop itself must be known.
     *
     * @param hop candidate operator
     * @return {@code true} if the hop is an accepted cbind
     */
    private static boolean isValidBinaryNaryCBind(Hop hop) {
        boolean isCBind = HopRewriteUtils.isBinary(hop, Types.OpOp2.CBIND)
            || HopRewriteUtils.isNary(hop, Types.OpOpN.CBIND);
        if (!isCBind) {
            return false;
        }
        if (!hop.getInput().get(0).isMatrix()) {
            return false;
        }
        if (!hop.dimsKnown()) {
            return false;
        }
        return hop.getInput().get(0).getDim1() > 1L;
    }

    /**
     * Checks whether a matrix multiplication counts as "skinny" for fusion.
     *
     * <p>Two calls to {@code LibMatrixMult.isSkinnyRightHandSide} are tried in turn:
     * first with the left input's dimensions swapped against the output dimensions,
     * then with the right input's dimensions against the swapped output dimensions.
     *
     * @param hop matrix multiplication with at least two inputs
     * @return {@code true} if either check succeeds
     */
    private static boolean isFuseSkinnyMatrixMult(Hop hop) {
        final Hop left = hop.getInput().get(0);
        final Hop right = hop.getInput().get(1);
        if (LibMatrixMult.isSkinnyRightHandSide(left.getDim2(), left.getDim1(), hop.getDim1(), hop.getDim2(), false)) {
            return true;
        }
        return LibMatrixMult.isSkinnyRightHandSide(right.getDim1(), right.getDim2(), hop.getDim2(), hop.getDim1(), false);
    }

    /**
     * Checks whether the hop is part of a {@code transpose(cum*(transpose(X)))} chain
     * (cum* being CUMSUM, CUMMIN or CUMMAX) in which the inspected links each have
     * exactly one parent.
     *
     * <p>Three positions are recognized: the outer transpose, the inner transpose,
     * and the cumulative aggregate in between.
     *
     * <p>Parent counts are verified before any parent is dereferenced, so a hop
     * without parents yields {@code false} instead of an
     * {@code IndexOutOfBoundsException}.
     *
     * @param hop candidate operator
     * @return {@code true} if the hop is a link of such a chain
     */
    private static boolean isPartOfValidCumAggChain(Hop hop) {
        if (HopRewriteUtils.isTransposeOperation(hop)) {
            // both transpose positions require this transpose to have a single consumer
            if (hop.getParent().size() != 1) {
                return false;
            }

            // position 1: outer transpose, i.e. hop = t(cum*(t(X)))
            Hop below = hop.getInput().get(0);
            if (HopRewriteUtils.isUnary(below, Types.OpOp1.CUMSUM, Types.OpOp1.CUMMIN, Types.OpOp1.CUMMAX)) {
                Hop innerTranspose = below.getInput().get(0);
                if (HopRewriteUtils.isTransposeOperation(innerTranspose)
                    && innerTranspose.getParent().size() == 1) {
                    return true;
                }
            }

            // position 2: inner transpose, i.e. hop = t(X) consumed by cum*, consumed by a transpose
            Hop above = hop.getParent().get(0);
            return HopRewriteUtils.isUnary(above, Types.OpOp1.CUMSUM, Types.OpOp1.CUMMIN, Types.OpOp1.CUMMAX)
                && above.getParent().size() == 1
                && HopRewriteUtils.isTransposeOperation(above.getParent().get(0));
        }

        // position 3: the cumulative aggregate itself, between two transposes
        return HopRewriteUtils.isUnary(hop, Types.OpOp1.CUMSUM, Types.OpOp1.CUMMIN, Types.OpOp1.CUMMAX)
            && hop.getParent().size() == 1
            && HopRewriteUtils.isTransposeOperation(hop.getParent().get(0))
            && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))
            && hop.getInput().get(0).getParent().size() == 1;
    }

    /**
     * Checks whether the hop is a transpose with a single consumer that is a
     * fusable skinny matrix multiplication, where the other operand of that
     * multiplication also occurs as an input below the transpose.
     *
     * <p>Additionally, the transpose must have more than 128 times as many
     * columns as the multiplication result has rows, and likewise columns.
     *
     * @param hop candidate operator
     * @return {@code true} if the hop is such a transpose
     */
    private static boolean isPartOfValidTransposeMMChain(Hop hop) {
        if (!HopRewriteUtils.isTransposeOperation(hop) || hop.getParent().size() != 1) {
            return false;
        }
        final Hop consumer = hop.getParent().get(0);

        // sizes must be known for the size comparison below
        if (!hop.dimsKnown() || !consumer.dimsKnown()) {
            return false;
        }
        if (hop.getDim2() <= 128L * consumer.getDim1() || hop.getDim2() <= 128L * consumer.getDim2()) {
            return false;
        }

        if (!HopRewriteUtils.isMatrixMultiply(consumer) || !TemplateRow.isFuseSkinnyMatrixMult(consumer)) {
            return false;
        }

        // the transpose may be either operand; the other operand must be contained in its inputs
        if (consumer.getInput().get(0) == hop
            && HopRewriteUtils.containsInput(hop, consumer.getInput().get(1))) {
            return true;
        }
        return consumer.getInput().get(1) == hop
            && HopRewriteUtils.containsInput(hop, consumer.getInput().get(0));
    }

    /**
     * Builds the row-wise code-generation plan rooted at the given hop.
     *
     * <p>The hop DAG is traversed via {@code rConstructCplan}, which fills a map
     * from hop ID to CNode, the set of template input hops, and the named inputs
     * {@code "X"} and {@code "B1"}. Inputs that are scalar literals are dropped;
     * the remaining ones are ordered with {@link HopInputComparator}.
     *
     * <p>NOTE(review): the first sorted input is dereferenced unconditionally, so
     * a plan whose inputs are all scalar literals fails with an
     * {@code ArrayIndexOutOfBoundsException} (behavior unchanged here).
     *
     * @param hop             root operator of the fused plan
     * @param memo            memo table consulted for plan references
     * @param compileLiterals forwarded to the CNodeData construction
     * @return pair of the ordered input hops and the constructed row template
     */
    @Override
    public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
        // recursively construct the cnodes, collecting inputs on the way
        HashSet<Hop> inHops = new HashSet<>();
        HashMap<String, Hop> inHops2 = new HashMap<>();
        HashMap<Long, CNode> tmp = new HashMap<>();
        hop.resetVisitStatus();
        this.rConstructCplan(hop, memo, tmp, inHops, inHops2, compileLiterals);
        hop.resetVisitStatus();

        // drop scalar literals and order the remaining inputs
        Hop[] sinHops = inHops.stream()
            .filter(h -> !(h.getDataType().isScalar() && tmp.get(h.getHopID()).isLiteral()))
            .sorted(new HopInputComparator(inHops2.get("X"), inHops2.get("B1")))
            .toArray(Hop[]::new);
        // fall back to the first sorted input if no branch registered "X"
        inHops2.putIfAbsent("X", sinHops[0]);

        // assemble the template from the input cnodes and the root cnode
        ArrayList<CNode> inputs = new ArrayList<>();
        for (Hop in : sinHops) {
            inputs.add(tmp.get(in.getHopID()));
        }
        CNode output = tmp.get(hop.getHopID());
        CNodeRow tpl = new CNodeRow(inputs, output);
        tpl.setRowType(TemplateUtils.getRowType(hop, inHops2.get("X"), inHops2.get("B1")));

        // the candidate constant dimension is the row count for COL_AGG_B1, else the column count
        long n2 = tpl.getRowType() == SpoofRowwise.RowType.COL_AGG_B1 ? hop.getDim1() : hop.getDim2();
        if (tpl.getRowType().isConstDim2(n2)) {
            tpl.setConstDim2(n2);
        }
        tpl.setNumVectorIntermediates(TemplateUtils.determineMinVectorIntermediates(output, inputs.isEmpty() ? null : inputs.get(0)));
        tpl.getOutput().resetVisitStatus();
        tpl.rReorderCommutativeBinaryOps(tpl.getOutput(), sinHops[0].getHopID());
        tpl.setBeginLine(hop.getBeginLine());
        return new Pair<Hop[], CNodeTpl>(sinHops, tpl);
    }

    /**
     * Recursively constructs the CNode for {@code hop} and stores it in {@code tmp}
     * under the hop's ID.
     *
     * <p>Inputs that the best memo entry (ROW or CELL) marks as plan references are
     * processed recursively; every other input becomes a {@code CNodeData} leaf and
     * is added to {@code inHops}. The branch matching the hop's operator type then
     * builds the output node from the input nodes found in {@code tmp}.
     *
     * @param hop             operator to translate; skipped if already present in {@code tmp}
     * @param memo            memo table used to look up the best plan entry per hop
     * @param tmp             hop ID to CNode map, read and extended by this method
     * @param inHops          set of template input hops, extended (and in some branches pruned) here
     * @param inHops2         named inputs; this method writes the keys {@code "X"} and {@code "B1"}
     * @param compileLiterals forwarded to {@code TemplateUtils.createCNodeData} / {@code skipTranspose}
     * @throws HopsException    if no branch produced an output node for the hop
     * @throws RuntimeException if a unary / binary matrix operation is not in the supported vector lists
     */
    private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
        CNode cdata2;
        CNode cdata1;
        // memoization: each hop is translated at most once
        if (tmp.containsKey(hop.getHopID())) {
            return;
        }
        CPlanMemoTable.MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateBase.TemplateType.ROW, TemplateBase.TemplateType.CELL);
        // translate inputs first: plan references recursively, everything else as data leaves
        for (int i = 0; i < hop.getInput().size(); ++i) {
            Hop c = hop.getInput().get(i);
            if (me != null && me.isPlanRef(i)) {
                this.rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
                continue;
            }
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
        CNode out = null;
        if (hop instanceof AggUnaryOp) {
            cdata1 = tmp.get(hop.getInput().get(0).getHopID());
            if (((AggUnaryOp)hop).getDirection().isRow() && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)) {
                // row aggregate
                if (hop.getInput().get(0).getDim2() == 1L) {
                    // single-column input: reuse a scalar as is, otherwise wrap in a LOOKUP_R
                    out = cdata1.getDataType() == Types.DataType.SCALAR ? cdata1 : new CNodeUnary(cdata1, CNodeUnary.UnaryType.LOOKUP_R);
                } else {
                    // unary type name is derived from the aggregation operator, e.g. SUM -> ROW_SUMS
                    String opcode = "ROW_" + ((AggUnaryOp)hop).getOp().name().toUpperCase() + "S";
                    out = new CNodeUnary(cdata1, CNodeUnary.UnaryType.valueOf(opcode));
                    // register the raw data input as "X" unless one was registered before
                    if (cdata1 instanceof CNodeData && !inHops2.containsKey("X")) {
                        inHops2.put("X", hop.getInput().get(0));
                    }
                }
            } else if (HopRewriteUtils.isAggUnaryOp(hop, Types.AggOp.SUM, Types.AggOp.MEAN) && ((AggUnaryOp)hop).getDirection().isCol()) {
                // column sum / mean: a vector-scalar primitive input is replaced by its vector-add variant,
                // any other input is passed through unchanged
                out = cdata1 instanceof CNodeBinary && ((CNodeBinary)cdata1).getType().isVectorScalarPrimitive() ? new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), ((CNodeBinary)cdata1).getType().getVectorAddPrimitive()) : cdata1;
            } else if (((AggUnaryOp)hop).getDirection() == Types.Direction.RowCol && ((AggUnaryOp)hop).getOp() == Types.AggOp.SUM) {
                // full sum: matrix inputs get a ROW_SUMS node, non-matrix inputs are passed through
                out = cdata1.getDataType().isMatrix() ? new CNodeUnary(cdata1, CNodeUnary.UnaryType.ROW_SUMS) : cdata1;
            }
        } else if (hop instanceof AggBinaryOp) {
            cdata1 = tmp.get(hop.getInput().get(0).getHopID());
            cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            if (HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) {
                // transposed left operand: use the node below the transpose and fix up the input set
                cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
                inHops.remove(hop.getInput().get(0));
                if (cdata1 instanceof CNodeData) {
                    inHops.add(hop.getInput().get(0).getInput().get(0));
                }
                if (hop.getInput().get(1).getDim2() == 1L) {
                    // single-column right operand
                    out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.VECT_MULT_ADD);
                } else {
                    out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.VECT_OUTERMULT_ADD);
                    // first time "B1" is set: also (re)assign "X" when the left operand is a raw data input
                    if (!inHops2.containsKey("B1")) {
                        if (cdata1 instanceof CNodeData) {
                            inHops2.put("X", hop.getInput().get(0).getInput().get(0));
                        }
                        inHops2.put("B1", hop.getInput().get(1));
                    }
                }
                if (!inHops2.containsKey("X")) {
                    inHops2.put("X", hop.getInput().get(0).getInput().get(0));
                }
            } else if (hop.getInput().get(0).getDim2() == 1L && hop.getInput().get(1).getDim2() == 1L) {
                // both operands have a single column: scalar multiply of the (looked-up) values
                out = new CNodeBinary(cdata1.getDataType() == Types.DataType.SCALAR ? cdata1 : new CNodeUnary(cdata1, CNodeUnary.UnaryType.LOOKUP0), cdata2.getDataType() == Types.DataType.SCALAR ? cdata2 : new CNodeUnary(cdata2, CNodeUnary.UnaryType.LOOKUP0), CNodeBinary.BinType.MULT);
            } else if (hop.getInput().get(1).getDim2() == 1L) {
                // single-column right operand; note that "X" is overwritten unconditionally here
                out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.DOT_PRODUCT);
                inHops2.put("X", hop.getInput().get(0));
            } else {
                // general case; "X" and "B1" are overwritten unconditionally here
                out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.VECT_MATRIXMULT);
                inHops2.put("X", hop.getInput().get(0));
                inHops2.put("B1", hop.getInput().get(1));
            }
        } else if (HopRewriteUtils.isDataGenOp(hop, Types.OpOpDG.SEQ)) {
            // seq(from, to, incr): parameters are read as literals
            CNodeData from = TemplateUtils.getLiteral(tmp.get(((DataGenOp)hop).getParam("from").getHopID()));
            CNodeData to = TemplateUtils.getLiteral(tmp.get(((DataGenOp)hop).getParam("to").getHopID()));
            CNodeData incr = TemplateUtils.getLiteral(tmp.get(((DataGenOp)hop).getParam("incr").getHopID()));
            // descending sequence given with a positive increment: negate the increment
            if (Double.parseDouble(from.getVarname()) > Double.parseDouble(to.getVarname()) && Double.parseDouble(incr.getVarname()) > 0.0) {
                incr = TemplateUtils.createCNodeData(new LiteralOp("-" + incr.getVarname()), true);
            }
            out = new CNodeBinary(from, incr, CNodeBinary.BinType.SEQ_RIX);
        } else if (HopRewriteUtils.isTransposeOperation(hop)) {
            // transpose: resolved via TemplateUtils.skipTranspose; if that yields a data node,
            // make sure the transpose input is tracked as a template input
            out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
            if (out instanceof CNodeData && !inHops.contains(hop.getInput().get(0))) {
                inHops.add(hop.getInput().get(0));
            }
        } else if (hop instanceof UnaryOp) {
            cdata1 = tmp.get(hop.getInput().get(0).getHopID());
            // vector path: input has more than one column, or sizes are unknown and the node is a matrix
            if (hop.getInput().get(0).getDim1() >= 1L && hop.getInput().get(0).getDim2() > 1L || !hop.dimsKnown() && cdata1.getDataType() == Types.DataType.MATRIX) {
                if (!HopRewriteUtils.isUnary(hop, SUPPORTED_VECT_UNARY)) throw new RuntimeException("Unsupported unary matrix operation: " + ((UnaryOp)hop).getOp().name());
                String opname = "VECT_" + ((UnaryOp)hop).getOp().name();
                out = new CNodeUnary(cdata1, CNodeUnary.UnaryType.valueOf(opname));
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X")) {
                    inHops2.put("X", hop.getInput().get(0));
                }
            } else {
                // scalar path: unary type carries the operator name without prefix
                cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
                String primitiveOpName = ((UnaryOp)hop).getOp().name();
                out = new CNodeUnary(cdata1, CNodeUnary.UnaryType.valueOf(primitiveOpName));
            }
        } else if (HopRewriteUtils.isBinary(hop, Types.OpOp2.CBIND)) {
            // binary cbind; must be tested before the generic BinaryOp branch below
            cdata1 = tmp.get(hop.getInput().get(0).getHopID());
            cdata2 = null;
            if (HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1))) {
                // constant-valued data generator: embed its value and drop it from the template inputs
                cdata2 = TemplateUtils.createCNodeData(HopRewriteUtils.getDataGenOpConstantValue(hop.getInput().get(1)), true);
                inHops.remove(hop.getInput().get(1));
            } else {
                cdata2 = tmp.get(hop.getInput().get(1).getHopID());
                cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1), true);
            }
            out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.VECT_CBIND);
            if (cdata1 instanceof CNodeData && !inHops2.containsKey("X")) {
                inHops2.put("X", hop.getInput().get(0));
            }
        } else if (hop instanceof BinaryOp) {
            cdata1 = tmp.get(hop.getInput().get(0).getHopID());
            cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            // vector path: either input has more than one column, or some size is unknown,
            // the output is not known to be single-column, and one of the nodes is a matrix
            if (hop.getInput().get(0).getDim1() >= 1L && hop.getInput().get(0).getDim2() > 1L || hop.getInput().get(1).getDim1() >= 1L && hop.getInput().get(1).getDim2() > 1L || (!hop.dimsKnown() || !hop.getInput().get(0).dimsKnown() || !hop.getInput().get(1).dimsKnown()) && hop.getDim2() != 1L && (cdata1.getDataType().isMatrix() || cdata2.getDataType().isMatrix())) {
                if (!HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY)) throw new RuntimeException("Unsupported binary matrix operation: " + ((BinaryOp)hop).getOp().name());
                // column vectors are accessed through a LOOKUP_R
                if (TemplateUtils.isColVector(cdata1)) {
                    cdata1 = new CNodeUnary(cdata1, CNodeUnary.UnaryType.LOOKUP_R);
                }
                if (TemplateUtils.isColVector(cdata2)) {
                    cdata2 = new CNodeUnary(cdata2, CNodeUnary.UnaryType.LOOKUP_R);
                }
                out = TemplateRow.getVectorBinary(cdata1, cdata2, ((BinaryOp)hop).getOp().name());
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X") && cdata1.getDataType() != Types.DataType.SCALAR) {
                    inHops2.put("X", hop.getInput().get(0));
                }
            } else {
                // scalar path: binary type carries the operator name without prefix
                String primitiveOpName = ((BinaryOp)hop).getOp().name();
                if (TemplateUtils.isColVector(cdata1)) {
                    cdata1 = new CNodeUnary(cdata1, CNodeUnary.UnaryType.LOOKUP_R);
                }
                // the right node is also wrapped when the left hop is a column vector
                // and the right side is a matrix-typed data node
                if (TemplateUtils.isColVector(cdata2) || TemplateUtils.isColVector(hop.getInput().get(0)) && cdata2 instanceof CNodeData && hop.getInput().get(1).getDataType().isMatrix()) {
                    cdata2 = new CNodeUnary(cdata2, CNodeUnary.UnaryType.LOOKUP_R);
                }
                out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.valueOf(primitiveOpName));
            }
        } else if (hop instanceof TernaryOp) {
            TernaryOp top = (TernaryOp)hop;
            CNode cdata12 = tmp.get(hop.getInput().get(0).getHopID());
            CNode cdata22 = tmp.get(hop.getInput().get(1).getHopID());
            CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
            if (hop.getDim2() >= 2L) {
                // multi-column output: in1 (+|-) (in2 VECT_MULT_SCALAR in3); anything but PLUS_MULT maps to VECT_MINUS
                out = new CNodeBinary(cdata12, new CNodeBinary(cdata22, cdata3, CNodeBinary.BinType.VECT_MULT_SCALAR), top.getOp() == Types.OpOp3.PLUS_MULT ? CNodeBinary.BinType.VECT_PLUS : CNodeBinary.BinType.VECT_MINUS);
            } else {
                // otherwise: scalar ternary over (possibly looked-up) inputs
                cdata12 = TemplateUtils.wrapLookupIfNecessary(cdata12, hop.getInput().get(0));
                cdata22 = TemplateUtils.wrapLookupIfNecessary(cdata22, hop.getInput().get(1));
                cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
                out = new CNodeTernary(cdata12, cdata22, cdata3, CNodeTernary.TernaryType.valueOf(top.getOp().name()));
            }
        } else if (HopRewriteUtils.isDnn(hop, Types.OpOpDnn.BIASADD, Types.OpOpDnn.BIASMULT)) {
            // bias add / multiply: binary type is VECT_<dnn op name>
            cdata1 = tmp.get(hop.getInput().get(0).getHopID());
            cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.valueOf("VECT_" + ((DnnOp)hop).getOp().name()));
        } else if (HopRewriteUtils.isDnn(hop, Types.OpOpDnn.MAX_POOL, Types.OpOpDnn.AVG_POOL)) {
            // pooling: n-ary node over all inputs, type VECT_<dnn op name>
            CNode[] in = (CNode[])hop.getInput().stream().map(h -> (CNode)tmp.get(h.getHopID())).toArray(CNode[]::new);
            out = new CNodeNary(in, CNodeNary.NaryType.valueOf("VECT_" + ((DnnOp)hop).getOp().name()));
        } else if (HopRewriteUtils.isDnn(hop, Types.OpOpDnn.CONV2D)) {
            // conv2d in two steps: VECT_IM2COL over all inputs except the second,
            // then VECT_CONV2DMM over all inputs with the first one replaced by the im2col node
            CNode[] in1 = (CNode[])hop.getInput().stream().filter(h -> h != hop.getInput().get(1)).map(h -> (CNode)tmp.get(h.getHopID())).toArray(CNode[]::new);
            CNodeNary im2col = new CNodeNary(in1, CNodeNary.NaryType.VECT_IM2COL);
            CNode[] in2 = (CNode[])hop.getInput().stream().map(h -> h == hop.getInput().get(0) ? im2col : (CNode)tmp.get(h.getHopID())).toArray(CNode[]::new);
            out = new CNodeNary(in2, CNodeNary.NaryType.VECT_CONV2DMM);
        } else if (hop instanceof NaryOp) {
            int i;
            CNode[] inputs = new CNode[hop.getInput().size()];
            for (i = 0; i < hop.getInput().size(); ++i) {
                Hop c = hop.getInput().get(i);
                CNode cdata = tmp.get(c.getHopID());
                // row and column vectors may need a lookup wrapper
                if (TemplateUtils.isColVector(cdata) || TemplateUtils.isRowVector(cdata)) {
                    cdata = TemplateUtils.wrapLookupIfNecessary(cdata, c);
                }
                inputs[i] = cdata;
                // only the first input, if it is a raw data node, may become "X"
                if (i != 0 || !(cdata instanceof CNodeData) || inHops2.containsKey("X")) continue;
                inHops2.put("X", c);
            }
            if (HopRewriteUtils.isNary(hop, Types.OpOpN.CBIND)) {
                out = new CNodeNary(inputs, CNodeNary.NaryType.VECT_CBIND);
            } else if (HopRewriteUtils.isNary(hop, Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS)) {
                // n-ary min / max / plus is folded left-to-right into a chain of binary nodes
                out = TemplateRow.getVectorOrScalarBinary(inputs[0], inputs[1], ((NaryOp)hop).getOp().name());
                for (i = 2; i < hop.getInput().size(); ++i) {
                    out = TemplateRow.getVectorOrScalarBinary(out, inputs[i], ((NaryOp)hop).getOp().name());
                }
            }
        } else if (hop instanceof ParameterizedBuiltinOp) {
            // reads the "pattern" and "replacement" parameters; a literal Double.NaN pattern selects REPLACE_NAN
            cdata1 = tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID());
            cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
            cdata2 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("pattern").getHopID());
            CNode cdata3 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("replacement").getHopID());
            CNodeTernary.TernaryType ttype = cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN") ? CNodeTernary.TernaryType.REPLACE_NAN : CNodeTernary.TernaryType.REPLACE;
            out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
        } else if (hop instanceof IndexingOp) {
            // indexing: lookup over the input, its column count, and input 4 of the indexing hop
            // (presumably the column lower bound -- TODO confirm against IndexingOp);
            // single-column results use LOOKUP_RC1, all others LOOKUP_RVECT1
            cdata1 = tmp.get(hop.getInput().get(0).getHopID());
            out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), hop.getDim2() != 1L ? CNodeTernary.TernaryType.LOOKUP_RVECT1 : CNodeTernary.TernaryType.LOOKUP_RC1);
        }
        // no branch matched (or a matching branch left out unset): fail with hop ID and op string
        if (out == null) {
            throw new HopsException(hop.getHopID() + " " + hop.getOpString());
        }
        // matrix outputs carry the hop's dimensions
        if (out.getDataType().isMatrix()) {
            out.setNumRows(hop.getDim1());
            out.setNumCols(hop.getDim2());
        }
        tmp.put(hop.getHopID(), out);
    }

    /**
     * Creates a binary node for the named operator, choosing between the plain
     * (scalar) binary type and a vector binary type.
     *
     * <p>The plain type {@code BinType.valueOf(name)} is used when each operand is
     * either a column vector or a scalar; every other combination is delegated to
     * {@code getVectorBinary}.
     *
     * @param cdata1 left operand
     * @param cdata2 right operand
     * @param name   operator name, used to look up the binary type
     * @return the new binary node
     */
    private static CNodeBinary getVectorOrScalarBinary(CNode cdata1, CNode cdata2, String name) {
        if (TemplateUtils.isColVector(cdata1) || cdata1.getDataType().isScalar()) {
            if (TemplateUtils.isColVector(cdata2) || cdata2.getDataType().isScalar()) {
                CNodeBinary.BinType scalarType = CNodeBinary.BinType.valueOf(name);
                return new CNodeBinary(cdata1, cdata2, scalarType);
            }
        }
        return TemplateRow.getVectorBinary(cdata1, cdata2, name);
    }

    /**
     * Creates a vector binary node for the named operator.
     *
     * <p>If the left operand is a matrix and the right operand is a matrix or a
     * row vector, the type {@code VECT_<name>} is used; otherwise
     * {@code VECT_<name>_SCALAR}.
     *
     * @param cdata1 left operand
     * @param cdata2 right operand
     * @param name   operator name, used to build the binary type name
     * @return the new binary node
     */
    private static CNodeBinary getVectorBinary(CNode cdata1, CNode cdata2, String name) {
        boolean vectorVector = TemplateUtils.isMatrix(cdata1)
            && (TemplateUtils.isMatrix(cdata2) || TemplateUtils.isRowVector(cdata2));
        String typeName = vectorVector ? "VECT_" + name : "VECT_" + name + "_SCALAR";
        return new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.valueOf(typeName));
    }

    /**
     * Orders template inputs by descending sort key so that the main input
     * {@code X} comes first, the side input {@code B1} second, then inputs of
     * unknown size, then inputs of known size by decreasing {@code getLength()},
     * and scalars last.
     */
    public static class HopInputComparator implements Comparator<Hop> {
        private final Hop _X;
        private final Hop _B1;

        /**
         * @param X  hop to be ordered first; may be {@code null} if not yet known
         * @param B1 hop to be ordered directly after {@code X}; may be {@code null}
         */
        public HopInputComparator(Hop X, Hop B1) {
            this._X = X;
            this._B1 = B1;
        }

        /**
         * Compares two hops by their sort keys in descending order.
         *
         * @return -1 if {@code h1} has the larger key, 1 if {@code h2} has, 0 on a tie
         */
        @Override
        public int compare(Hop h1, Hop h2) {
            // arguments swapped on purpose: larger keys sort first
            return Long.compare(sortKey(h2), sortKey(h1));
        }

        /**
         * Computes the sort key of a hop: scalars get the smallest possible key,
         * {@code X} the largest, {@code B1} the second largest, hops of unknown
         * size the third largest, and all others their {@code getLength()}.
         */
        private long sortKey(Hop h) {
            if (h.isScalar()) {
                return Long.MIN_VALUE;
            }
            if (h == this._X) {
                return Long.MAX_VALUE;
            }
            if (h == this._B1) {
                return Long.MAX_VALUE - 1;
            }
            return h.dimsKnown() ? h.getLength() : Long.MAX_VALUE - 2;
        }
    }
}

