/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNodeId;
import org.apache.iotdb.db.queryengine.plan.relational.planner.DataOrganizationSpecification;
import org.apache.iotdb.db.queryengine.plan.relational.planner.OrderingScheme;
import org.apache.iotdb.db.queryengine.plan.relational.planner.SortOrder;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
import org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolAllocator;
import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.ExpressionRewriter;
import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.ExpressionTreeRewriter;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ApplyNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.LimitNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.Measure;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.PatternRecognitionNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.WindowNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.AggregationValuePointer;
import org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.ClassifierValuePointer;
import org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.ExpressionAndValuePointers;
import org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.IrLabel;
import org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.MatchNumberValuePointer;
import org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.ScalarValuePointer;
import org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.ValuePointer;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;

/**
 * Rewrites the {@link Symbol}s referenced by plan nodes, expressions and ordering schemes through a
 * single symbol-to-symbol mapping function.
 *
 * <p>Instances are created either directly from a function, from an alias map via
 * {@link #symbolMapper(Map)} / {@link #builder()}, or via
 * {@link #symbolReallocator(Map, SymbolAllocator)}, which additionally allocates fresh symbols for
 * inputs the map does not know about.
 */
public class SymbolMapper {
    private final Function<Symbol, Symbol> mappingFunction;

    /**
     * @param mappingFunction function applied to every symbol this mapper rewrites
     * @throws NullPointerException if {@code mappingFunction} is null
     */
    public SymbolMapper(Function<Symbol, Symbol> mappingFunction) {
        this.mappingFunction = Objects.requireNonNull(mappingFunction, "mappingFunction is null");
    }

    /**
     * Creates a mapper that resolves each symbol by following {@code mapping} transitively until it
     * reaches a symbol that is not a key or that maps to itself. Symbols absent from {@code mapping}
     * are returned unchanged.
     *
     * @param mapping alias map; it is read on every lookup, not copied
     * @return a mapper backed by {@code mapping}
     */
    public static SymbolMapper symbolMapper(Map<Symbol, Symbol> mapping) {
        return new SymbolMapper(symbol -> resolve(mapping, symbol));
    }

    /**
     * Creates a mapper that behaves like {@link #symbolMapper(Map)} for symbols already present in
     * {@code mapping}, and allocates a brand-new symbol (via {@code symbolAllocator}) for any symbol
     * that is not yet a key.
     *
     * <p>{@code mapping} is mutated on every lookup, so it must be a mutable map. Resolved and newly
     * allocated symbols are recorded as fixed points ({@code s -> s}), so they are returned as-is
     * when they are looked up later.
     *
     * @param mapping mutable alias map, updated as symbols are resolved or allocated
     * @param symbolAllocator allocator used to create fresh symbols for unmapped inputs
     * @return a mapper backed by {@code mapping}
     */
    public static SymbolMapper symbolReallocator(Map<Symbol, Symbol> mapping, SymbolAllocator symbolAllocator) {
        return new SymbolMapper(symbol -> {
            if (mapping.containsKey(symbol)) {
                Symbol canonical = resolve(mapping, symbol);
                // Record the end of the chain as a fixed point.
                mapping.put(canonical, canonical);
                return canonical;
            }
            Symbol newSymbol = symbolAllocator.newSymbol(symbol);
            mapping.put(symbol, newSymbol);
            mapping.put(newSymbol, newSymbol);
            return newSymbol;
        });
    }

    /**
     * Follows {@code mapping} starting at {@code symbol} until reaching a symbol that is either not a
     * key or maps to itself.
     *
     * @throws IllegalStateException if the chain starting at {@code symbol} is cyclic (without this
     *     check the lookup would loop forever)
     */
    private static Symbol resolve(Map<Symbol, Symbol> mapping, Symbol symbol) {
        Symbol current = symbol;
        // An acyclic chain visits each key at most once, so needing more than mapping.size() hops
        // means some key was revisited, i.e. the chain is a cycle.
        int remainingHops = mapping.size();
        while (mapping.containsKey(current)) {
            Symbol next = mapping.get(current);
            if (next.equals(current)) {
                break;
            }
            if (remainingHops-- == 0) {
                throw new IllegalStateException("Cyclic symbol mapping detected while resolving " + symbol);
            }
            current = next;
        }
        return current;
    }

    /** Applies the mapping function to a single symbol. */
    public Symbol map(Symbol symbol) {
        return this.mappingFunction.apply(symbol);
    }

    /**
     * Maps the symbols referenced by an {@link ApplyNode.SetExpression}.
     *
     * <p>{@link ApplyNode.Exists} carries no symbols here and is returned unchanged; {@link
     * ApplyNode.In} and {@link ApplyNode.QuantifiedComparison} get their value and reference symbols
     * mapped.
     *
     * @throws IllegalArgumentException for any other {@code SetExpression} subtype
     */
    public ApplyNode.SetExpression map(ApplyNode.SetExpression expression) {
        if (expression instanceof ApplyNode.Exists) {
            return expression;
        }
        if (expression instanceof ApplyNode.In) {
            ApplyNode.In in = (ApplyNode.In) expression;
            return new ApplyNode.In(map(in.getValue()), map(in.getReference()));
        }
        if (expression instanceof ApplyNode.QuantifiedComparison) {
            ApplyNode.QuantifiedComparison comparison = (ApplyNode.QuantifiedComparison) expression;
            return new ApplyNode.QuantifiedComparison(
                    comparison.getOperator(),
                    comparison.getQuantifier(),
                    map(comparison.getValue()),
                    map(comparison.getReference()));
        }
        throw new IllegalArgumentException("Unexpected value: " + expression);
    }

    /** Maps every symbol in order, keeping duplicates; returns an immutable list. */
    public List<Symbol> map(List<Symbol> symbols) {
        return symbols.stream().map(this::map).collect(ImmutableList.toImmutableList());
    }

    /**
     * Maps every symbol in order and drops duplicates produced by the mapping, keeping the first
     * occurrence; returns an immutable list.
     */
    public List<Symbol> mapAndDistinct(List<Symbol> symbols) {
        return symbols.stream().map(this::map).distinct().collect(ImmutableList.toImmutableList());
    }

    /**
     * Maps a specification: partition symbols are mapped and de-duplicated, and the optional ordering
     * scheme is mapped with {@link #map(OrderingScheme)}.
     */
    public DataOrganizationSpecification mapAndDistinct(DataOrganizationSpecification specification) {
        return new DataOrganizationSpecification(
                mapAndDistinct(specification.getPartitionBy()),
                specification.getOrderingScheme().map(this::map));
    }

    /** Rewrites every {@link SymbolReference} inside {@code expression} to its mapped symbol. */
    public Expression map(Expression expression) {
        return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>() {
            @Override
            public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                Symbol canonical = SymbolMapper.this.map(Symbol.from(node));
                return canonical.toSymbolReference();
            }
        }, expression);
    }

    /** Maps an aggregation node, keeping its plan node id. */
    public AggregationNode map(AggregationNode node, PlanNode source) {
        return map(node, source, node.getPlanNodeId());
    }

    /**
     * Maps an aggregation node onto {@code source} under a new plan node id.
     *
     * <p>Aggregation output symbols and their arguments are mapped, grouping keys are mapped and
     * de-duplicated, and the pre-grouped symbol list of the new node is empty. Two aggregations
     * whose output symbols map to the same symbol make {@code buildOrThrow} reject the map.
     */
    public AggregationNode map(AggregationNode node, PlanNode source, PlanNodeId newNodeId) {
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
            aggregations.put(map(entry.getKey()), map(entry.getValue()));
        }
        return new AggregationNode(
                newNodeId,
                source,
                aggregations.buildOrThrow(),
                AggregationNode.groupingSets(
                        mapAndDistinct(node.getGroupingKeys()),
                        node.getGroupingSetCount(),
                        node.getGlobalGroupingSets()),
                ImmutableList.<Symbol>of(),
                node.getStep(),
                node.getHashSymbol().map(this::map),
                node.getGroupIdSymbol().map(this::map));
    }

    /** Maps the arguments, filter, ordering scheme and mask of a single aggregation. */
    public AggregationNode.Aggregation map(AggregationNode.Aggregation aggregation) {
        return new AggregationNode.Aggregation(
                aggregation.getResolvedFunction(),
                aggregation.getArguments().stream().map(this::map).collect(ImmutableList.toImmutableList()),
                aggregation.isDistinct(),
                aggregation.getFilter().map(this::map),
                aggregation.getOrderingScheme().map(this::map),
                aggregation.getMask().map(this::map));
    }

    /** Maps a limit node's optional ties-resolving scheme onto {@code source}. */
    public LimitNode map(LimitNode node, PlanNode source) {
        return new LimitNode(node.getPlanNodeId(), source, node.getCount(), node.getTiesResolvingScheme().map(this::map));
    }

    /**
     * Maps an ordering scheme. When several ORDER BY symbols map to the same symbol, only the first
     * occurrence (and its sort order) is kept.
     */
    public OrderingScheme map(OrderingScheme orderingScheme) {
        ImmutableList.Builder<Symbol> newSymbols = ImmutableList.builder();
        ImmutableMap.Builder<Symbol, SortOrder> newOrderings = ImmutableMap.builder();
        Set<Symbol> added = new HashSet<>(orderingScheme.getOrderBy().size());
        for (Symbol symbol : orderingScheme.getOrderBy()) {
            Symbol canonical = map(symbol);
            if (added.add(canonical)) {
                newSymbols.add(canonical);
                // Look up the sort order under the ORIGINAL symbol; the scheme is keyed by it.
                newOrderings.put(canonical, orderingScheme.getOrdering(symbol));
            }
        }
        return new OrderingScheme(newSymbols.build(), newOrderings.buildOrThrow());
    }

    /**
     * Maps a window node onto {@code source}: function output symbols, arguments and frames, the
     * partition symbols (duplicates kept), the optional ordering scheme, the hash symbol and the
     * pre-partitioned inputs.
     */
    public WindowNode map(WindowNode node, PlanNode source) {
        ImmutableMap.Builder<Symbol, WindowNode.Function> newFunctions = ImmutableMap.builder();
        node.getWindowFunctions().forEach((symbol, function) -> {
            WindowNode.Frame newFrame = map(function.getFrame());
            newFunctions.put(
                    map(symbol),
                    new WindowNode.Function(
                            function.getResolvedFunction(),
                            function.getArguments().stream().map(this::map).collect(ImmutableList.toImmutableList()),
                            newFrame,
                            function.isIgnoreNulls()));
        });
        // NOTE(review): unlike mapAndDistinct(DataOrganizationSpecification), duplicates in the
        // partition list are kept here — confirm this is intentional.
        List<Symbol> newPartitionBy = map(node.getSpecification().getPartitionBy());
        Optional<OrderingScheme> newOrderingScheme = node.getSpecification().getOrderingScheme().map(this::map);
        DataOrganizationSpecification newSpecification = new DataOrganizationSpecification(newPartitionBy, newOrderingScheme);
        return new WindowNode(
                node.getPlanNodeId(),
                source,
                newSpecification,
                newFunctions.buildOrThrow(),
                node.getHashSymbol().map(this::map),
                node.getPrePartitionedInputs().stream().map(this::map).collect(ImmutableSet.toImmutableSet()),
                node.getPreSortedOrderPrefix());
    }

    /** Maps the optional symbols of a window frame; the frame type, bound types and original values are copied. */
    private WindowNode.Frame map(WindowNode.Frame frame) {
        return new WindowNode.Frame(
                frame.getType(),
                frame.getStartType(),
                frame.getStartValue().map(this::map),
                frame.getSortKeyCoercedForFrameStartComparison().map(this::map),
                frame.getEndType(),
                frame.getEndValue().map(this::map),
                frame.getSortKeyCoercedForFrameEndComparison().map(this::map),
                frame.getOriginalStartValue(),
                frame.getOriginalEndValue());
    }

    /** Maps a TopK node, keeping its plan node id. */
    public TopKNode map(TopKNode node, List<PlanNode> source) {
        return map(node, source, node.getPlanNodeId());
    }

    /**
     * Maps a TopK node's ordering scheme and output symbols (duplicates kept) under {@code nodeId}.
     */
    public TopKNode map(TopKNode node, List<PlanNode> source, PlanNodeId nodeId) {
        return new TopKNode(
                nodeId,
                source,
                map(node.getOrderingScheme()),
                node.getCount(),
                node.getOutputSymbols().stream().map(this::map).collect(Collectors.toList()),
                node.isChildrenDataInOrder());
    }

    /**
     * Maps a pattern-recognition node onto {@code source}: measures (output symbols and value
     * pointers), variable definitions, partition symbols (de-duplicated), the optional ordering
     * scheme and the hash symbol. Rows-per-match, skip-to and pattern settings are copied as-is.
     */
    public PatternRecognitionNode map(PatternRecognitionNode node, PlanNode source) {
        ImmutableMap.Builder<Symbol, Measure> newMeasures = ImmutableMap.builder();
        node.getMeasures().forEach((symbol, measure) -> {
            ExpressionAndValuePointers newExpression = map(measure.getExpressionAndValuePointers());
            newMeasures.put(map(symbol), new Measure(newExpression, measure.getType()));
        });

        ImmutableMap.Builder<IrLabel, ExpressionAndValuePointers> newVariableDefinitions = ImmutableMap.builder();
        node.getVariableDefinitions().forEach((label, definition) -> newVariableDefinitions.put(label, map(definition)));

        return new PatternRecognitionNode(
                node.getPlanNodeId(),
                source,
                mapAndDistinct(node.getPartitionBy()),
                // The ORDER BY symbols must be renamed like the partition symbols; otherwise they
                // could still refer to symbols the mapped plan no longer produces.
                node.getOrderingScheme().map(this::map),
                node.getHashSymbol().map(this::map),
                newMeasures.buildOrThrow(),
                node.getRowsPerMatch(),
                node.getSkipToLabels(),
                node.getSkipToPosition(),
                node.getPattern(),
                newVariableDefinitions.buildOrThrow());
    }

    /**
     * Maps the value pointers of an {@link ExpressionAndValuePointers}.
     *
     * <p>Only the input symbols read by value pointers are mapped. Classifier and match-number
     * pointers, assignment symbols and the expression itself are left untouched (presumably these
     * are synthetic symbols local to the pattern-recognition node — confirm).
     *
     * @throws IllegalArgumentException for an unknown {@link ValuePointer} subtype
     */
    private ExpressionAndValuePointers map(ExpressionAndValuePointers expressionAndValuePointers) {
        ImmutableList.Builder<ExpressionAndValuePointers.Assignment> newAssignments = ImmutableList.builder();
        for (ExpressionAndValuePointers.Assignment assignment : expressionAndValuePointers.getAssignments()) {
            ValuePointer valuePointer = assignment.getValuePointer();
            ValuePointer newPointer;
            if (valuePointer instanceof ClassifierValuePointer || valuePointer instanceof MatchNumberValuePointer) {
                newPointer = valuePointer;
            } else if (valuePointer instanceof ScalarValuePointer) {
                ScalarValuePointer pointer = (ScalarValuePointer) valuePointer;
                newPointer = new ScalarValuePointer(pointer.getLogicalIndexPointer(), map(pointer.getInputSymbol()));
            } else if (valuePointer instanceof AggregationValuePointer) {
                AggregationValuePointer pointer = (AggregationValuePointer) valuePointer;
                List<Expression> newArguments = pointer.getArguments().stream()
                        .map(argument -> mapAggregationArgument(pointer, argument))
                        .collect(ImmutableList.toImmutableList());
                newPointer = new AggregationValuePointer(
                        pointer.getFunction(),
                        pointer.getSetDescriptor(),
                        newArguments,
                        pointer.getClassifierSymbol(),
                        pointer.getMatchNumberSymbol());
            } else {
                throw new IllegalArgumentException("Unsupported ValuePointer type: " + valuePointer.getClass().getName());
            }
            newAssignments.add(new ExpressionAndValuePointers.Assignment(assignment.getSymbol(), newPointer));
        }
        return new ExpressionAndValuePointers(expressionAndValuePointers.getExpression(), newAssignments.build());
    }

    /**
     * Rewrites symbol references in one aggregation-pointer argument. References to the pointer's own
     * classifier or match-number symbol are preserved; every other reference is mapped.
     */
    private Expression mapAggregationArgument(AggregationValuePointer pointer, Expression argument) {
        return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>() {
            @Override
            public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                Symbol symbol = Symbol.from(node);
                boolean isClassifier = pointer.getClassifierSymbol().isPresent()
                        && symbol.equals(pointer.getClassifierSymbol().get());
                boolean isMatchNumber = pointer.getMatchNumberSymbol().isPresent()
                        && symbol.equals(pointer.getMatchNumberSymbol().get());
                if (isClassifier || isMatchNumber) {
                    return node;
                }
                return SymbolMapper.this.map(node);
            }
        }, argument);
    }

    /** Returns a builder that collects an alias map and produces a {@link #symbolMapper(Map)}. */
    public static Builder builder() {
        return new Builder();
    }

    /** Accumulates {@code from -> to} aliases for {@link SymbolMapper#symbolMapper(Map)}. */
    public static class Builder {
        private final ImmutableMap.Builder<Symbol, Symbol> mappings = ImmutableMap.builder();

        /** Registers an alias; registering the same {@code from} twice makes {@link #build()} fail. */
        public void put(Symbol from, Symbol to) {
            this.mappings.put(from, to);
        }

        /** Builds a mapper over the collected aliases. */
        public SymbolMapper build() {
            return SymbolMapper.symbolMapper(this.mappings.buildOrThrow());
        }
    }
}

