/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.metadata;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.SetOp;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.metadata.BuiltInMetadata;
import org.apache.calcite.rel.metadata.MetadataDef;
import org.apache.calcite.rel.metadata.MetadataHandler;
import org.apache.calcite.rel.metadata.ReflectiveRelMetadataProvider;
import org.apache.calcite.rel.metadata.RelMetadataProvider;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.util.Arrow;
import org.apache.calcite.util.ArrowSet;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.Mappings;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * Metadata handler that derives functional dependencies (FDs) between the
 * columns of a relational expression.
 *
 * <p>Every {@link ArrowSet} returned by this handler is expressed in terms of
 * the <em>output</em> column ordinals of the relational expression it
 * describes; input FDs are remapped (or dropped) when an operator renumbers
 * its columns.
 */
public class RelMdFunctionalDependency
implements MetadataHandler<BuiltInMetadata.FunctionalDependency> {
    public static final RelMetadataProvider SOURCE = ReflectiveRelMetadataProvider.reflectiveSource(new RelMdFunctionalDependency(), BuiltInMetadata.FunctionalDependency.Handler.class);

    protected RelMdFunctionalDependency() {
    }

    @Override
    public MetadataDef<BuiltInMetadata.FunctionalDependency> getDef() {
        return BuiltInMetadata.FunctionalDependency.DEF;
    }

    /**
     * Returns whether output column {@code determinant} functionally
     * determines output column {@code dependent}.
     */
    public @Nullable Boolean determines(RelNode rel, RelMetadataQuery mq, int determinant, int dependent) {
        return this.determinesSet(rel, mq, ImmutableBitSet.of(determinant), ImmutableBitSet.of(dependent));
    }

    /**
     * Returns whether the column set {@code determinants} functionally
     * determines the column set {@code dependents}, according to the FDs
     * returned by {@link RelMetadataQuery#getFDs}.
     */
    public Boolean determinesSet(RelNode rel, RelMetadataQuery mq, ImmutableBitSet determinants, ImmutableBitSet dependents) {
        final ArrowSet fdSet = mq.getFDs(rel);
        return fdSet.implies(determinants, dependents);
    }

    /** Returns the columns determined by {@code ordinals}, as computed by the FD set of {@code rel}. */
    public ImmutableBitSet dependents(RelNode rel, RelMetadataQuery mq, ImmutableBitSet ordinals) {
        final ArrowSet fdSet = mq.getFDs(rel);
        return fdSet.dependents(ordinals);
    }

    /** Returns the column sets that determine {@code ordinals}, as computed by the FD set of {@code rel}. */
    public Set<ImmutableBitSet> determinants(RelNode rel, RelMetadataQuery mq, ImmutableBitSet ordinals) {
        final ArrowSet fdSet = mq.getFDs(rel);
        return fdSet.determinants(ordinals);
    }

    /**
     * Computes the FD set of {@code rel} by dispatching on its (stripped) type.
     *
     * <p>Set operations and correlates yield no FDs. Any other operator not
     * handled explicitly inherits the FDs of its single input; operators with
     * zero or several inputs yield no FDs.
     */
    public ArrowSet getFDs(RelNode rel, RelMetadataQuery mq) {
        final RelNode node = rel.stripped();
        if (node instanceof TableScan) {
            return getTableScanFD((TableScan) node);
        }
        if (node instanceof Project) {
            return this.getProjectFD((Project) node, mq);
        }
        if (node instanceof Aggregate) {
            return this.getAggregateFD((Aggregate) node, mq);
        }
        if (node instanceof Join) {
            return this.getJoinFD((Join) node, mq);
        }
        if (node instanceof Calc) {
            return this.getCalcFD((Calc) node, mq);
        }
        if (node instanceof Filter) {
            return this.getFilterFD((Filter) node, mq);
        }
        if (node instanceof SetOp || node instanceof Correlate) {
            return ArrowSet.EMPTY;
        }
        return this.getFD(node.getInputs(), mq);
    }

    /** Passes through the FDs of a single input; returns no FDs otherwise. */
    private ArrowSet getFD(List<RelNode> inputs, RelMetadataQuery mq) {
        if (inputs.size() != 1) {
            return ArrowSet.EMPTY;
        }
        return mq.getFDs(inputs.get(0));
    }

    /** Each key declared by the table determines every other column of the scan. */
    private static ArrowSet getTableScanFD(TableScan rel) {
        final ArrowSet.Builder fdBuilder = new ArrowSet.Builder();
        final List<ImmutableBitSet> keys = rel.getTable().getKeys();
        if (keys == null || keys.isEmpty()) {
            return fdBuilder.build();
        }
        final ImmutableBitSet allColumns = ImmutableBitSet.range(rel.getRowType().getFieldCount());
        for (ImmutableBitSet key : keys) {
            final ImmutableBitSet dependents = allColumns.except(key);
            if (!dependents.isEmpty()) {
                fdBuilder.addArrow(key, dependents);
            }
        }
        return fdBuilder.build();
    }

    private ArrowSet getProjectFD(Project rel, RelMetadataQuery mq) {
        return this.getProjectionFD(rel.getInput(), rel.getProjects(), mq);
    }

    /**
     * Derives the FDs of a projection {@code projections} over {@code input}.
     *
     * <p>Sources of FDs:
     * <ul>
     *   <li>input FDs whose columns are projected as plain references;
     *   <li>identical deterministic expressions determine each other;
     *   <li>any deterministic non-literal column determines every literal column;
     *   <li>a projected input reference determines an expression whose inputs it
     *       determines in the input FD set.
     * </ul>
     * Non-deterministic expressions take part in none of these.
     */
    private ArrowSet getProjectionFD(RelNode input, List<RexNode> projections, RelMetadataQuery mq) {
        final ArrowSet inputFdSet = mq.getFDs(input);
        final ArrowSet.Builder fdBuilder = new ArrowSet.Builder();
        final Mapping inputToOutputMap = RelOptUtil.permutation(projections, input.getRowType()).inverse();
        mapInputFDs(inputFdSet, inputToOutputMap, fdBuilder);

        final int fieldCount = projections.size();
        // Input columns referenced by each deterministic non-literal output column.
        final ImmutableBitSet[] inputBits = new ImmutableBitSet[fieldCount];
        // First output position of each distinct deterministic expression.
        final Map<RexNode, Integer> uniqueExprToIndex = new HashMap<>();
        final Map<RexLiteral, Integer> literalToIndex = new HashMap<>();
        final Map<RexInputRef, Integer> refToIndex = new HashMap<>();
        for (int i = 0; i < fieldCount; ++i) {
            final RexNode expr = projections.get(i);
            if (!RexUtil.isDeterministic(expr)) {
                continue;
            }
            final Integer prev = uniqueExprToIndex.putIfAbsent(expr, i);
            if (prev != null) {
                fdBuilder.addBidirectionalArrow(prev, i);
            }
            if (expr instanceof RexLiteral) {
                literalToIndex.put((RexLiteral) expr, i);
                continue;
            }
            if (expr instanceof RexInputRef) {
                refToIndex.put((RexInputRef) expr, i);
            }
            // For a RexInputRef this is just the referenced column.
            inputBits[i] = RelOptUtil.InputFinder.bits(expr);
        }

        // Literals are handled as dependents only.
        uniqueExprToIndex.keySet().removeIf(key -> key instanceof RexLiteral);
        for (Map.Entry<RexNode, Integer> entry : uniqueExprToIndex.entrySet()) {
            final int target = entry.getValue();
            for (int literalIndex : literalToIndex.values()) {
                fdBuilder.addArrow(target, literalIndex);
            }
            final ImmutableBitSet exprInputs = inputBits[target];
            for (Map.Entry<RexInputRef, Integer> ref : refToIndex.entrySet()) {
                final ImmutableBitSet refIndex = ImmutableBitSet.of(ref.getKey().getIndex());
                if (inputFdSet.implies(refIndex, exprInputs)) {
                    final int source = ref.getValue();
                    fdBuilder.addArrow(source, target);
                }
            }
        }
        return fdBuilder.build();
    }

    /**
     * Copies each input FD into {@code outputFdBuilder} through {@code mapping}.
     * An FD is kept only if all of its determinants are mapped and at least one
     * dependent is mapped; unmapped dependents are dropped.
     */
    private static void mapInputFDs(ArrowSet inputFdSet, Mappings.TargetMapping mapping, ArrowSet.Builder outputFdBuilder) {
        for (Arrow inputFd : inputFdSet.getArrows()) {
            final ImmutableBitSet mappedDeterminants = mapAllCols(inputFd.getDeterminants(), mapping);
            if (mappedDeterminants.isEmpty()) {
                continue;
            }
            final ImmutableBitSet mappedDependents = mapAvailableCols(inputFd.getDependents(), mapping);
            if (mappedDependents.isEmpty()) {
                continue;
            }
            outputFdBuilder.addArrow(mappedDeterminants, mappedDependents);
        }
    }

    /** Maps every ordinal through {@code mapping}; returns an empty set if any ordinal is unmapped. */
    private static ImmutableBitSet mapAllCols(ImmutableBitSet ordinals, Mappings.TargetMapping mapping) {
        final ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
        for (int ord : ordinals) {
            final int mappedOrd = mapping.getTargetOpt(ord);
            if (mappedOrd < 0) {
                return ImmutableBitSet.of();
            }
            builder.set(mappedOrd);
        }
        return builder.build();
    }

    /** Maps the ordinals that {@code mapping} knows about, silently dropping the rest. */
    private static ImmutableBitSet mapAvailableCols(ImmutableBitSet ordinals, Mappings.TargetMapping mapping) {
        final ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
        for (int ord : ordinals) {
            final int mappedOrd = mapping.getTargetOpt(ord);
            if (mappedOrd >= 0) {
                builder.set(mappedOrd);
            }
        }
        return builder.build();
    }

    /**
     * Derives the FDs of an Aggregate, in output column ordinals.
     *
     * <p>The k-th set bit of the group set becomes output column k, and the
     * aggregate calls follow at {@code [groupCount, fieldCount)}. Input FDs
     * are therefore remapped rather than copied verbatim.
     *
     * <p>FDs are only derived for simple aggregates. With GROUPING SETS,
     * ROLLUP or CUBE, a rolled-up row can carry the same (NULL) key values as
     * a genuine group, so neither input FDs nor "key determines aggregates"
     * can be relied upon.
     */
    private ArrowSet getAggregateFD(Aggregate rel, RelMetadataQuery mq) {
        final ArrowSet.Builder fdBuilder = new ArrowSet.Builder();
        if (!Aggregate.isSimple(rel)) {
            return fdBuilder.build();
        }
        final ArrowSet inputFdSet = mq.getFDs(rel.getInput());
        final ImmutableBitSet groupSet = rel.getGroupSet();

        // Input FDs entirely within the group columns survive grouping.
        for (Arrow inputFd : inputFdSet.getArrows()) {
            final ImmutableBitSet determinants = inputFd.getDeterminants();
            final ImmutableBitSet dependents = inputFd.getDependents();
            if (groupSet.contains(determinants) && groupSet.contains(dependents)) {
                fdBuilder.addArrow(mapGroupCols(determinants, groupSet), mapGroupCols(dependents, groupSet));
            }
        }

        // Group columns determined (transitively) by a single group column.
        for (int groupCol : groupSet) {
            final ImmutableBitSet singleton = ImmutableBitSet.of(groupCol);
            final ImmutableBitSet groupDependents =
                inputFdSet.dependents(singleton).intersect(groupSet).except(singleton);
            if (!groupDependents.isEmpty()) {
                fdBuilder.addArrow(mapGroupCols(singleton, groupSet), mapGroupCols(groupDependents, groupSet));
            }
        }

        // The group key determines every aggregate call.
        if (!groupSet.isEmpty() && !rel.getAggCallList().isEmpty()) {
            final int groupCount = rel.getGroupCount();
            final ImmutableBitSet groupCols = ImmutableBitSet.range(0, groupCount);
            final ImmutableBitSet aggCols = ImmutableBitSet.range(groupCount, rel.getRowType().getFieldCount());
            fdBuilder.addArrow(groupCols, aggCols);
        }
        return fdBuilder.build();
    }

    /**
     * Translates input ordinals that belong to {@code groupSet} into Aggregate
     * output ordinals (the position of each column within {@code groupSet}).
     * Ordinals outside {@code groupSet} are ignored; callers only pass subsets.
     */
    private static ImmutableBitSet mapGroupCols(ImmutableBitSet inputCols, ImmutableBitSet groupSet) {
        final ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
        int outputOrdinal = 0;
        for (int groupCol : groupSet) {
            if (inputCols.get(groupCol)) {
                builder.set(outputOrdinal);
            }
            ++outputOrdinal;
        }
        return builder.build();
    }

    /** A Filter keeps its input's FDs and adds column equalities from its condition. */
    private ArrowSet getFilterFD(Filter rel, RelMetadataQuery mq) {
        final ArrowSet inputSet = mq.getFDs(rel.getInput());
        final ArrowSet.Builder fdBuilder = new ArrowSet.Builder();
        addFDsFromEqualityCondition(rel.getCondition(), fdBuilder);
        return fdBuilder.build().union(inputSet);
    }

    /**
     * Derives the FDs of a Join. Right-input ordinals are shifted past the left
     * input's columns.
     *
     * <p>Column equalities in the join condition yield FDs only for INNER joins.
     * In LEFT/RIGHT joins, null-padded rows do not satisfy the condition. For
     * example, in {@code l LEFT JOIN r ON l.x = r.y}, all unmatched left rows
     * share {@code r.y = NULL} but have different {@code l.x} values.
     */
    private ArrowSet getJoinFD(Join rel, RelMetadataQuery mq) {
        final ArrowSet leftFdSet = mq.getFDs(rel.getLeft());
        final ArrowSet rightFdSet = mq.getFDs(rel.getRight());
        final int leftFieldCount = rel.getLeft().getRowType().getFieldCount();
        final JoinRelType joinType = rel.getJoinType();
        switch (joinType) {
            case INNER: {
                final ArrowSet.Builder joinFdBuilder = new ArrowSet.Builder()
                    .addArrowSet(leftFdSet.union(this.shiftFdSet(rightFdSet, leftFieldCount)));
                addFDsFromEqualityCondition(rel.getCondition(), joinFdBuilder);
                return joinFdBuilder.build();
            }
            case LEFT:
            case RIGHT: {
                // NOTE(review): FDs of the null-supplying side are kept as-is;
                // confirm this is intended when that side can itself produce
                // NULL determinant values.
                return leftFdSet.union(this.shiftFdSet(rightFdSet, leftFieldCount));
            }
            case SEMI:
            case ANTI: {
                // Output consists of left-input columns only.
                return leftFdSet.clone();
            }
            default:
                return ArrowSet.EMPTY;
        }
    }

    /** A Calc is handled as the projection of its expanded program. */
    private ArrowSet getCalcFD(Calc rel, RelMetadataQuery mq) {
        final List<RexNode> projections = rel.getProgram().expandList(rel.getProgram().getProjectList());
        return this.getProjectionFD(rel.getInput(), projections, mq);
    }

    /** Returns a copy of {@code fdSet} with every column ordinal shifted by {@code offset}. */
    private ArrowSet shiftFdSet(ArrowSet fdSet, int offset) {
        final ArrowSet.Builder shiftedFdSetBuilder = new ArrowSet.Builder();
        for (Arrow fd : fdSet.getArrows()) {
            shiftedFdSetBuilder.addArrow(fd.getDeterminants().shift(offset), fd.getDependents().shift(offset));
        }
        return shiftedFdSetBuilder.build();
    }

    /**
     * Adds a bidirectional FD for every {@code $a = $b} or
     * {@code $a IS NOT DISTINCT FROM $b} between two input references in
     * {@code condition}, descending through AND conjunctions only.
     */
    private static void addFDsFromEqualityCondition(RexNode condition, ArrowSet.Builder builder) {
        if (!(condition instanceof RexCall)) {
            return;
        }
        final RexCall call = (RexCall) condition;
        final SqlKind kind = call.getOperator().getKind();
        if (kind == SqlKind.AND) {
            for (RexNode operand : call.getOperands()) {
                addFDsFromEqualityCondition(operand, builder);
            }
            return;
        }
        if (kind != SqlKind.EQUALS && kind != SqlKind.IS_NOT_DISTINCT_FROM) {
            return;
        }
        final List<RexNode> operands = call.getOperands();
        if (operands.size() != 2) {
            return;
        }
        final RexNode left = operands.get(0);
        final RexNode right = operands.get(1);
        if (left instanceof RexInputRef && right instanceof RexInputRef) {
            builder.addBidirectionalArrow(((RexInputRef) left).getIndex(), ((RexInputRef) right).getIndex());
        }
    }
}

