/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer;

import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.metadata.Partition;
import org.apache.hadoop.hive.ql.metadata.Table;
import org.apache.hadoop.hive.ql.optimizer.GroupByOptimizer;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.optimizer.ppr.PartitionPruner;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.PrunedPartitionList;
import org.apache.hadoop.hive.ql.parse.QBJoinTree;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.util.StringUtils;

/**
 * Optimizer that turns eligible map joins into bucket map joins.
 *
 * <p>For every {@link MapJoinOperator} that has not been marked as rejected by
 * one of the reject rules registered in {@link #transform(ParseContext)}, the
 * optimizer checks that every joined table (or, for a partitioned small table,
 * its single pruned partition) has its bucket columns covered by the join keys
 * and that the bucket counts of the small tables and the big table divide one
 * another. If all checks pass, it records on the {@link MapJoinDesc} which
 * small-table bucket files each big-table bucket file must be joined with, and
 * which alias is the big table.
 */
public class BucketMapJoinOptimizer
implements Transform {
    /** Logger for this optimizer. */
    private static final Log LOG = LogFactory.getLog(BucketMapJoinOptimizer.class.getName());

    /**
     * Walks the operator tree from the top operators and annotates eligible
     * map joins with bucket file mappings.
     *
     * @param pctx the current parse context; its map-join descriptors (and its
     *             pruned-partition cache) may be updated in place
     * @return the same {@code pctx}
     * @throws SemanticException if partition pruning or bucket file listing fails
     */
    @Override
    public ParseContext transform(ParseContext pctx) throws SemanticException {
        LinkedHashMap<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
        BucketMapjoinOptProcCtx bucketMapJoinOptimizeCtx = new BucketMapjoinOptProcCtx();
        // R1 tries the conversion; R2-R4 mark a map join as rejected when the
        // walked operator path reaches it through a reduce sink, a union or
        // another map join.
        // NOTE(review): unlike R3/R4, the R2 pattern has no trailing '%';
        // confirm against RuleRegExp matching whether R2 can ever match.
        opRules.put(new RuleRegExp("R1", "MAPJOIN%"), this.getBucketMapjoinProc(pctx));
        opRules.put(new RuleRegExp("R2", "RS%.*MAPJOIN"), this.getBucketMapjoinRejectProc(pctx));
        opRules.put(new RuleRegExp("R3", "UNION%.*MAPJOIN%"), this.getBucketMapjoinRejectProc(pctx));
        opRules.put(new RuleRegExp("R4", "MAPJOIN%.*MAPJOIN%"), this.getBucketMapjoinRejectProc(pctx));
        DefaultRuleDispatcher disp = new DefaultRuleDispatcher(this.getDefaultProc(), opRules, bucketMapJoinOptimizeCtx);
        DefaultGraphWalker ogw = new DefaultGraphWalker(disp);
        ArrayList<Node> topNodes = new ArrayList<Node>();
        topNodes.addAll(pctx.getTopOps().values());
        ogw.startWalking(topNodes, null);
        return pctx;
    }

    /**
     * Returns a processor that records the visited map join as rejected in the
     * walk context so that the conversion processor skips it.
     */
    private NodeProcessor getBucketMapjoinRejectProc(ParseContext pctx) {
        return new NodeProcessor(){

            @Override
            public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
                MapJoinOperator mapJoinOp = (MapJoinOperator)nd;
                BucketMapjoinOptProcCtx context = (BucketMapjoinOptProcCtx)procCtx;
                context.listOfRejectedMapjoins.add(mapJoinOp);
                return null;
            }
        };
    }

    /** Returns the processor that attempts the bucket map join conversion. */
    private NodeProcessor getBucketMapjoinProc(ParseContext pctx) {
        return new BucketMapjoinOptProc(pctx);
    }

    /** Returns a processor that does nothing, used for nodes matching no rule. */
    private NodeProcessor getDefaultProc() {
        return new NodeProcessor(){

            @Override
            public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
                return null;
            }
        };
    }

    /**
     * Walk context shared by the processors: holds the map joins that a reject
     * rule matched, which the conversion processor then leaves alone.
     */
    class BucketMapjoinOptProcCtx
    implements NodeProcessorCtx {
        Set<MapJoinOperator> listOfRejectedMapjoins = new HashSet<MapJoinOperator>();

        BucketMapjoinOptProcCtx() {
        }

        /** Returns the (mutable) set of rejected map joins. */
        public Set<MapJoinOperator> getListOfRejectedMapjoins() {
            return this.listOfRejectedMapjoins;
        }
    }

    /** Processor that converts a single eligible map join into a bucket map join. */
    class BucketMapjoinOptProc
    implements NodeProcessor {
        protected ParseContext pGraphContext;

        public BucketMapjoinOptProc(ParseContext pGraphContext) {
            this.pGraphContext = pGraphContext;
        }

        /**
         * Tries to convert {@code nd} (a {@link MapJoinOperator}) into a bucket
         * map join. Always returns {@code null}; the map-join descriptor is only
         * modified once every eligibility check has passed, so any early return
         * leaves it untouched.
         *
         * @throws SemanticException if partition pruning or file listing fails
         */
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            MapJoinOperator mapJoinOp = (MapJoinOperator)nd;
            BucketMapjoinOptProcCtx context = (BucketMapjoinOptProcCtx)procCtx;
            if (context.getListOfRejectedMapjoins().contains(mapJoinOp)) {
                return null;
            }
            QBJoinTree joinCxt = this.pGraphContext.getMapJoinContext().get(mapJoinOp);
            if (joinCxt == null) {
                return null;
            }

            // Distinct, non-null aliases of the join (left aliases first, then
            // base sources). The last one that is not a map alias is the big table.
            ArrayList<String> joinAliases = new ArrayList<String>();
            List<String> mapAlias = joinCxt.getMapAliases();
            String baseBigAlias = null;
            for (String[] aliases : new String[][]{joinCxt.getLeftAliases(), joinCxt.getBaseSrc()}) {
                for (String s : aliases) {
                    if (s == null || joinAliases.contains(s)) {
                        continue;
                    }
                    joinAliases.add(s);
                    if (!mapAlias.contains(s)) {
                        baseBigAlias = s;
                    }
                }
            }
            if (baseBigAlias == null) {
                // Every alias is a map alias: there is no big table to bucket against.
                return null;
            }

            MapJoinDesc desc = (MapJoinDesc)mapJoinOp.getConf();
            LinkedHashMap<String, Integer> aliasToBucketNumberMapping = new LinkedHashMap<String, Integer>();
            LinkedHashMap<String, List<String>> aliasToBucketFileNamesMapping = new LinkedHashMap<String, List<String>>();
            HashMap<String, Operator<? extends Serializable>> topOps = this.pGraphContext.getTopOps();
            HashMap<TableScanOperator, Table> topToTable = this.pGraphContext.getTopToTable();
            // Filled only when the big table is partitioned; both maps always get
            // the same keys in the same insertion order.
            LinkedHashMap<Partition, List<String>> bigTblPartsToBucketFileNames = new LinkedHashMap<Partition, List<String>>();
            LinkedHashMap<Partition, Integer> bigTblPartsToBucketNumber = new LinkedHashMap<Partition, Integer>();

            for (int index = 0; index < joinAliases.size(); ++index) {
                String alias = joinAliases.get(index);
                boolean isBigTable = alias.equals(baseBigAlias);
                Operator<? extends Serializable> topOp = topOps.get(alias);
                if (!(topOp instanceof TableScanOperator)) {
                    return null;
                }
                TableScanOperator tso = (TableScanOperator)topOp;
                Table tbl = topToTable.get(tso);
                if (tbl == null) {
                    return null;
                }

                if (!tbl.isPartitioned()) {
                    if (!this.checkBucketColumns(tbl.getBucketCols(), desc, index)) {
                        return null;
                    }
                    aliasToBucketNumberMapping.put(alias, tbl.getNumBuckets());
                    aliasToBucketFileNamesMapping.put(alias, this.getBucketFileNames(tbl.getDataLocation()));
                    continue;
                }

                PrunedPartitionList prunedParts = this.getPrunedPartitions(tso, tbl, alias);
                // Confirmed partitions first, then unknown ones.
                List<Partition> parts = new ArrayList<Partition>(prunedParts.getConfirmedPartns());
                parts.addAll(prunedParts.getUnknownPartns());
                if (parts.isEmpty()) {
                    // Nothing survived pruning: no buckets to map against.
                    return null;
                }
                if (parts.size() > 1) {
                    // Only the big table may span several partitions.
                    if (!isBigTable) {
                        return null;
                    }
                    for (Partition p : parts) {
                        if (!this.checkBucketColumns(p.getBucketCols(), desc, index)) {
                            return null;
                        }
                        bigTblPartsToBucketFileNames.put(p, this.getOnePartitionBucketFileNames(p));
                        bigTblPartsToBucketNumber.put(p, p.getBucketCount());
                    }
                    continue;
                }

                Partition part = parts.get(0);
                int num = part.getBucketCount();
                aliasToBucketNumberMapping.put(alias, num);
                if (!this.checkBucketColumns(part.getBucketCols(), desc, index)) {
                    return null;
                }
                List<String> fileNames = this.getOnePartitionBucketFileNames(part);
                aliasToBucketFileNamesMapping.put(alias, fileNames);
                if (isBigTable) {
                    bigTblPartsToBucketFileNames.put(part, fileNames);
                    bigTblPartsToBucketNumber.put(part, num);
                }
            }

            boolean bigTblPartitioned = !bigTblPartsToBucketNumber.isEmpty();

            // Every bucket count must divide (or be divisible by) the big table's.
            if (bigTblPartitioned) {
                for (int bucketNumberInPart : bigTblPartsToBucketNumber.values()) {
                    if (!this.checkBucketNumberAgainstBigTable(aliasToBucketNumberMapping, bucketNumberInPart)) {
                        return null;
                    }
                }
            } else {
                Integer bucketNoInBigTbl = aliasToBucketNumberMapping.get(baseBigAlias);
                if (bucketNoInBigTbl == null || !this.checkBucketNumberAgainstBigTable(aliasToBucketNumberMapping, bucketNoInBigTbl)) {
                    return null;
                }
            }

            // Sorting lines up list positions with bucket numbers.
            // NOTE(review): assumes bucket file names sort in bucket order.
            if (bigTblPartitioned) {
                for (List<String> partBucketNames : bigTblPartsToBucketFileNames.values()) {
                    Collections.sort(partBucketNames);
                }
            } else {
                Collections.sort(aliasToBucketFileNamesMapping.get(baseBigAlias));
            }

            LinkedHashMap<String, LinkedHashMap<String, ArrayList<String>>> aliasBucketFileNameMapping = new LinkedHashMap<String, LinkedHashMap<String, ArrayList<String>>>();
            for (String alias : joinAliases) {
                if (alias.equals(baseBigAlias)) {
                    continue;
                }
                Collections.sort(aliasToBucketFileNamesMapping.get(alias));
                LinkedHashMap<String, ArrayList<String>> mapping = new LinkedHashMap<String, ArrayList<String>>();
                aliasBucketFileNameMapping.put(alias, mapping);
                if (bigTblPartitioned) {
                    for (Partition p : bigTblPartsToBucketFileNames.keySet()) {
                        this.fillMapping(aliasToBucketNumberMapping, aliasToBucketFileNamesMapping, alias, mapping, bigTblPartsToBucketNumber.get(p), bigTblPartsToBucketFileNames.get(p), desc.getBucketFileNameMapping());
                    }
                } else {
                    this.fillMapping(aliasToBucketNumberMapping, aliasToBucketFileNamesMapping, alias, mapping, aliasToBucketNumberMapping.get(baseBigAlias), aliasToBucketFileNamesMapping.get(baseBigAlias), desc.getBucketFileNameMapping());
                }
            }
            desc.setAliasBucketFileNameMapping(aliasBucketFileNameMapping);
            desc.setBigTableAlias(baseBigAlias);
            return null;
        }

        /**
         * Returns the pruned partitions for {@code tso}. A cached result in the
         * parse context is used if present; otherwise the pruner is run and the
         * result is cached there.
         *
         * @throws SemanticException wrapping any {@link HiveException} from pruning
         */
        private PrunedPartitionList getPrunedPartitions(TableScanOperator tso, Table tbl, String alias) throws SemanticException {
            try {
                PrunedPartitionList prunedParts = this.pGraphContext.getOpToPartList().get(tso);
                if (prunedParts == null) {
                    prunedParts = PartitionPruner.prune(tbl, this.pGraphContext.getOpToPartPruner().get(tso), this.pGraphContext.getConf(), alias, this.pGraphContext.getPrunedPartitions());
                    this.pGraphContext.getOpToPartList().put(tso, prunedParts);
                }
                return prunedParts;
            }
            catch (HiveException e) {
                LOG.error(StringUtils.stringifyException(e));
                throw new SemanticException(e.getMessage(), e);
            }
        }

        /**
         * Maps each big-table bucket file to the small-table bucket files it must
         * be joined with, and records each big-table file's list position in
         * {@code bucketFileNameMapping}.
         *
         * <p>Bucket counts are positive and divide one another
         * (checkBucketNumberAgainstBigTable ran first). The list position
         * {@code index} is treated as the bucket number.
         * NOTE(review): assumes rows go to bucket {@code hash % numBuckets},
         * which is what the modulo branch below already relies on.
         */
        private void fillMapping(LinkedHashMap<String, Integer> aliasToBucketNumberMapping, LinkedHashMap<String, List<String>> aliasToBucketFileNamesMapping, String alias, LinkedHashMap<String, ArrayList<String>> mapping, int bigTblBucketNum, List<String> bigTblBucketNameList, LinkedHashMap<String, Integer> bucketFileNameMapping) {
            int smallTblBucketNum = aliasToBucketNumberMapping.get(alias);
            List<String> smallTblBucketNames = aliasToBucketFileNamesMapping.get(alias);
            for (int index = 0; index < bigTblBucketNameList.size(); ++index) {
                String inputBigTBLBucket = bigTblBucketNameList.get(index);
                ArrayList<String> resultFileNames = new ArrayList<String>();
                if (bigTblBucketNum >= smallTblBucketNum) {
                    // small divides big, so big bucket b's rows live in small bucket
                    // b % small, e.g. big=4, small=2 gives 0->0, 1->1, 2->0, 3->1.
                    int toAddSmallIndex = index % smallTblBucketNum;
                    if (toAddSmallIndex < smallTblBucketNames.size()) {
                        resultFileNames.add(smallTblBucketNames.get(toAddSmallIndex));
                    }
                } else {
                    // big divides small: small bucket j matches big bucket j % big,
                    // so big bucket b joins small buckets b, b + big, b + 2*big, ...
                    // e.g. big=2, small=8 gives 0->{0,2,4,6}, 1->{1,3,5,7}.
                    for (int i = index; i < smallTblBucketNames.size(); i += bigTblBucketNum) {
                        resultFileNames.add(smallTblBucketNames.get(i));
                    }
                }
                mapping.put(inputBigTBLBucket, resultFileNames);
                bucketFileNameMapping.put(inputBigTBLBucket, index);
            }
        }

        /**
         * Returns {@code true} when {@code bucketNumberInPart} and every bucket
         * count in {@code aliasToBucketNumber} are positive and each pair
         * divides one way or the other.
         */
        private boolean checkBucketNumberAgainstBigTable(LinkedHashMap<String, Integer> aliasToBucketNumber, int bucketNumberInPart) {
            if (bucketNumberInPart <= 0) {
                return false;
            }
            for (int nxt : aliasToBucketNumber.values()) {
                if (nxt <= 0) {
                    return false;
                }
                boolean ok = nxt >= bucketNumberInPart ? nxt % bucketNumberInPart == 0 : bucketNumberInPart % nxt == 0;
                if (!ok) {
                    return false;
                }
            }
            return true;
        }

        /** Lists the bucket file paths stored under a partition's data location. */
        private List<String> getOnePartitionBucketFileNames(Partition part) throws SemanticException {
            return this.getBucketFileNames(part.getDataLocation());
        }

        /**
         * Lists the paths of the files directly under {@code location}, in the
         * order returned by the file system.
         *
         * @throws SemanticException wrapping any {@link IOException}
         */
        private List<String> getBucketFileNames(URI location) throws SemanticException {
            ArrayList<String> fileNames = new ArrayList<String>();
            try {
                FileSystem fs = FileSystem.get(location, this.pGraphContext.getConf());
                FileStatus[] files = fs.listStatus(new Path(location.toString()));
                if (files != null) {
                    for (FileStatus file : files) {
                        fileNames.add(file.getPath().toString());
                    }
                }
            }
            catch (IOException e) {
                throw new SemanticException(e);
            }
            return fileNames;
        }

        /**
         * Returns {@code true} when the columns referenced by the join keys of
         * input {@code index} include every bucket column.
         *
         * <p>A join key may be a column or a deterministic generic UDF over
         * columns. Any other expression, or a non-deterministic UDF, fails the
         * check, as does a missing key list or an empty bucket column list.
         */
        private boolean checkBucketColumns(List<String> bucketColumns, MapJoinDesc mjDesc, int index) {
            List<ExprNodeDesc> keys = mjDesc.getKeys().get((byte)index);
            if (keys == null || bucketColumns == null || bucketColumns.isEmpty()) {
                return false;
            }
            ArrayList<String> joinCols = new ArrayList<String>();
            ArrayList<ExprNodeDesc> pending = new ArrayList<ExprNodeDesc>(keys);
            while (!pending.isEmpty()) {
                ExprNodeDesc node = pending.remove(0);
                if (node instanceof ExprNodeColumnDesc) {
                    joinCols.addAll(node.getCols());
                } else if (node instanceof ExprNodeGenericFuncDesc) {
                    ExprNodeGenericFuncDesc udfNode = (ExprNodeGenericFuncDesc)node;
                    GenericUDF udf = udfNode.getGenericUDF();
                    if (!FunctionRegistry.isDeterministic(udf)) {
                        return false;
                    }
                    // Descend into the UDF's arguments.
                    pending.addAll(0, udfNode.getChildExprs());
                } else {
                    return false;
                }
            }
            return !joinCols.isEmpty() && joinCols.containsAll(bucketColumns);
        }
    }
}

