/*
 * Decompiled with CFR 0.152.
 */
package com.metamatrix.query.optimizer.relational.rules;

import com.metamatrix.api.exception.query.QueryPlannerException;
import com.metamatrix.common.types.DataTypeManager;
import com.metamatrix.query.analysis.AnalysisRecord;
import com.metamatrix.query.metadata.QueryMetadataInterface;
import com.metamatrix.query.optimizer.capabilities.CapabilitiesFinder;
import com.metamatrix.query.optimizer.relational.OptimizerRule;
import com.metamatrix.query.optimizer.relational.RuleStack;
import com.metamatrix.query.optimizer.relational.plantree.NodeConstants;
import com.metamatrix.query.optimizer.relational.plantree.NodeEditor;
import com.metamatrix.query.optimizer.relational.plantree.NodeFactory;
import com.metamatrix.query.optimizer.relational.plantree.PlanNode;
import com.metamatrix.query.optimizer.relational.rules.FrameUtil;
import com.metamatrix.query.optimizer.relational.rules.RuleConstants;
import com.metamatrix.query.resolver.util.ResolverVisitorUtil;
import com.metamatrix.query.sql.LanguageObject;
import com.metamatrix.query.sql.LanguageVisitor;
import com.metamatrix.query.sql.lang.CompareCriteria;
import com.metamatrix.query.sql.lang.Criteria;
import com.metamatrix.query.sql.lang.JoinType;
import com.metamatrix.query.sql.navigator.AggregateStopNavigator;
import com.metamatrix.query.sql.navigator.DeepPreOrderNavigator;
import com.metamatrix.query.sql.symbol.AggregateSymbol;
import com.metamatrix.query.sql.symbol.AliasSymbol;
import com.metamatrix.query.sql.symbol.Constant;
import com.metamatrix.query.sql.symbol.ElementSymbol;
import com.metamatrix.query.sql.symbol.Expression;
import com.metamatrix.query.sql.symbol.ExpressionSymbol;
import com.metamatrix.query.sql.symbol.Function;
import com.metamatrix.query.sql.symbol.SingleElementSymbol;
import com.metamatrix.query.sql.visitor.AggregateSymbolCollectorVisitor;
import com.metamatrix.query.sql.visitor.ElementCollectorVisitor;
import com.metamatrix.query.util.CommandContext;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;

public class RulePushAggregates
implements OptimizerRule {
    public PlanNode execute(PlanNode plan, QueryMetadataInterface metadata, CapabilitiesFinder capFinder, RuleStack rules, AnalysisRecord analysisRecord, CommandContext context) throws QueryPlannerException {
        List groupNodes = NodeEditor.findAllNodes((PlanNode)plan, (int)23, (int)3);
        if (groupNodes.isEmpty()) {
            return plan;
        }
        boolean pushed = false;
        Iterator iter = groupNodes.iterator();
        while (iter.hasNext()) {
            PlanNode groupNode = (PlanNode)iter.next();
            List groupingExpressions = (List)groupNode.getProperty((Object)NodeConstants.Info.GROUP_COLS);
            if (groupingExpressions == null) continue;
            ArrayList aggregates = new ArrayList();
            ArrayList elements = new ArrayList();
            this.collectAggregates(groupNode, aggregates, elements);
            if (!this.possibleToPush(elements, aggregates) || !this.pushGroupNode(groupNode, groupingExpressions, aggregates, elements, metadata)) continue;
            pushed = true;
        }
        if (pushed) {
            rules.push(RuleConstants.RAISE_ACCESS);
        }
        return plan;
    }

    private boolean possibleToPush(List elements, List aggregates) {
        if (elements.size() == 0) {
            return true;
        }
        if (aggregates.size() == 0) {
            return false;
        }
        Iterator iter = aggregates.iterator();
        while (iter.hasNext()) {
            AggregateSymbol obj = (AggregateSymbol)iter.next();
            if (obj.getExpression() != null) continue;
            return false;
        }
        final ArrayList aggregateExprs = new ArrayList();
        LanguageVisitor visitor = new LanguageVisitor(){

            public void visit(AggregateSymbol obj) {
                aggregateExprs.add(obj);
            }
        };
        iter = elements.iterator();
        while (iter.hasNext()) {
            LanguageObject obj = (LanguageObject)iter.next();
            if (obj instanceof AggregateSymbol) continue;
            DeepPreOrderNavigator.doVisit((LanguageObject)obj, (LanguageVisitor)visitor);
        }
        return aggregateExprs.isEmpty();
    }

    private List collectAggregates(PlanNode groupNode, List aggregates, List elements) {
        for (PlanNode currentNode = groupNode.getParent(); currentNode != null && (currentNode.getType() == 13 || currentNode.getType() == 11); currentNode = currentNode.getParent()) {
            if (currentNode.getType() == 11) {
                List projectedSymbols = (List)currentNode.getProperty((Object)NodeConstants.Info.PROJECT_COLS);
                Iterator iter = projectedSymbols.iterator();
                while (iter.hasNext()) {
                    SingleElementSymbol symbol = (SingleElementSymbol)iter.next();
                    this.collectAggregates((LanguageObject)symbol, aggregates, elements);
                }
                continue;
            }
            Criteria crit = (Criteria)currentNode.getProperty((Object)NodeConstants.Info.SELECT_CRITERIA);
            this.collectAggregates((LanguageObject)crit, aggregates, elements);
        }
        return aggregates;
    }

    private void collectAggregates(LanguageObject obj, List aggregates, List elements) {
        AggregateSymbolCollectorVisitor visitor = new AggregateSymbolCollectorVisitor((Collection)aggregates, (Collection)elements);
        AggregateStopNavigator nav = new AggregateStopNavigator((LanguageVisitor)visitor);
        obj.acceptVisitor((LanguageVisitor)nav);
    }

    private boolean pushGroupNode(PlanNode groupNode, List groupingExpressions, List aggregateExpressions, List elements, QueryMetadataInterface metadata) throws QueryPlannerException {
        boolean pushed = false;
        PlanNode node = groupNode.getFirstChild();
        boolean usesComposable = this.usesComposableAggregates(aggregateExpressions);
        if (node.getType() == 7) {
            boolean pushingFullAggregate = true;
            while (node.getType() == 7) {
                HashSet aggGroups = new HashSet();
                if (aggregateExpressions.size() > 0) {
                    this.collectSymbolGroups(aggregateExpressions, aggGroups);
                } else {
                    this.collectSymbolGroups(groupingExpressions, aggGroups);
                }
                PlanNode targetChild = this.findTargetChild(node, aggGroups);
                if (targetChild == null || !this.joinNodeIsInnerEquijoin(node)) break;
                if (pushingFullAggregate && usesComposable && this.joinNodeIsGroupInvariant(node, groupingExpressions, elements, aggGroups, targetChild, metadata)) {
                    PlanNode groupParent = groupNode.getParent();
                    NodeEditor.removeChildNode((PlanNode)groupParent, (PlanNode)groupNode);
                    NodeEditor.insertNode((PlanNode)node, (PlanNode)targetChild, (PlanNode)groupNode);
                    node = node.getFirstChild();
                    pushed = true;
                    pushingFullAggregate = false;
                    continue;
                }
                pushingFullAggregate = false;
                node = this.insertStagingGroup(groupNode, targetChild, aggGroups, aggregateExpressions, metadata);
                if (node == null) break;
                groupNode = node.getParent();
                pushed = true;
            }
        }
        return pushed;
    }

    private PlanNode findTargetChild(PlanNode node, Collection aggGroups) {
        PlanNode targetChild = null;
        Iterator childIter = node.getChildren().iterator();
        while (childIter.hasNext()) {
            PlanNode child = (PlanNode)childIter.next();
            if (!child.getGroups().containsAll(aggGroups)) continue;
            if (targetChild != null) {
                return null;
            }
            targetChild = child;
        }
        return targetChild;
    }

    private boolean usesComposableAggregates(List aggregateExpressions) {
        Iterator iter = aggregateExpressions.iterator();
        while (iter.hasNext()) {
            AggregateSymbol agg = (AggregateSymbol)iter.next();
            if (!agg.isDistinct() && !agg.getAggregateFunction().equals("COUNT") && !agg.getAggregateFunction().equals("AVG")) continue;
            return false;
        }
        return true;
    }

    private boolean joinNodeIsInnerEquijoin(PlanNode joinNode) {
        JoinType type = (JoinType)joinNode.getProperty((Object)NodeConstants.Info.JOIN_TYPE);
        if (type.equals((Object)JoinType.JOIN_CROSS)) {
            return true;
        }
        if (!type.equals((Object)JoinType.JOIN_INNER)) {
            return false;
        }
        List crits = (List)joinNode.getProperty((Object)NodeConstants.Info.JOIN_CRITERIA);
        Iterator joinCrit = crits.iterator();
        while (joinCrit.hasNext()) {
            Criteria crit = (Criteria)joinCrit.next();
            if (!(crit instanceof CompareCriteria)) {
                return false;
            }
            CompareCriteria comp = (CompareCriteria)crit;
            if (comp.getLeftExpression() instanceof ElementSymbol && comp.getRightExpression() instanceof ElementSymbol) continue;
            return false;
        }
        return true;
    }

    private boolean joinNodeIsGroupInvariant(PlanNode joinNode, List groupingExpressions, List elements, Set aggGroups, PlanNode targetChild, QueryMetadataInterface metadata) {
        HashSet groupingElementGroups = new HashSet();
        this.collectSymbolGroups(elements, groupingElementGroups);
        if (!aggGroups.containsAll(groupingElementGroups)) {
            return false;
        }
        List joinCrits = (List)joinNode.getProperty((Object)NodeConstants.Info.JOIN_CRITERIA);
        if (joinCrits == null || joinCrits.size() == 0) {
            return false;
        }
        Iterator joinCrit = joinCrits.iterator();
        while (joinCrit.hasNext()) {
            Criteria crit = (Criteria)joinCrit.next();
            Collection joinElem = ElementCollectorVisitor.getElements((LanguageObject)crit, (boolean)true);
            Iterator iter = joinElem.iterator();
            while (iter.hasNext()) {
                ElementSymbol element = (ElementSymbol)iter.next();
                if (!targetChild.getGroups().contains(element.getGroupSymbol()) || groupingExpressions.contains(element)) continue;
                return false;
            }
        }
        return true;
    }

    private void collectSymbolGroups(List symbols, Set groups) {
        LinkedList working = new LinkedList(symbols);
        while (working.size() > 0) {
            Object symbol = working.removeFirst();
            if (symbol instanceof ElementSymbol) {
                groups.add(((ElementSymbol)symbol).getGroupSymbol());
                continue;
            }
            if (!(symbol instanceof ExpressionSymbol)) continue;
            ElementCollectorVisitor.getElements((LanguageObject)((ExpressionSymbol)symbol), working);
        }
    }

    private PlanNode insertStagingGroup(PlanNode groupNode, PlanNode targetChild, Set aggGroups, List aggregateExpressions, QueryMetadataInterface metadata) throws QueryPlannerException {
        HashSet<Object> joinElements = new HashSet<Object>();
        List groupCols = (List)groupNode.getProperty((Object)NodeConstants.Info.GROUP_COLS);
        Iterator groupColIter = groupCols.iterator();
        while (groupColIter.hasNext()) {
            SingleElementSymbol symbol = (SingleElementSymbol)groupColIter.next();
            if (joinElements.contains(symbol)) continue;
            if (symbol instanceof ElementSymbol) {
                if (!targetChild.getGroups().contains(((ElementSymbol)symbol).getGroupSymbol())) continue;
                joinElements.add(symbol);
                continue;
            }
            ExpressionSymbol expr = (ExpressionSymbol)symbol;
            Collection elements = ElementCollectorVisitor.getElements((LanguageObject)expr, (boolean)true);
            Iterator elemIter = elements.iterator();
            boolean includeElement = true;
            while (elemIter.hasNext()) {
                ElementSymbol elem = (ElementSymbol)elemIter.next();
                if (aggGroups.contains(elem.getGroupSymbol())) continue;
                includeElement = false;
                break;
            }
            if (!includeElement) continue;
            joinElements.add(symbol);
        }
        PlanNode joinNode = groupNode.getFirstChild();
        List crits = (List)joinNode.getProperty((Object)NodeConstants.Info.JOIN_CRITERIA);
        if (crits != null) {
            Iterator critIter = crits.iterator();
            while (critIter.hasNext()) {
                Criteria crit = (Criteria)critIter.next();
                Collection elements = ElementCollectorVisitor.getElements((LanguageObject)crit, (boolean)true);
                Iterator elemIter = elements.iterator();
                while (elemIter.hasNext()) {
                    ElementSymbol elem = (ElementSymbol)elemIter.next();
                    if (!targetChild.getGroups().contains(elem.getGroupSymbol())) continue;
                    joinElements.add(elem);
                }
            }
        }
        PlanNode stageGroup = NodeFactory.getNewNode((int)23);
        stageGroup.addGroups((Collection)targetChild.getGroups());
        stageGroup.setProperty((Object)NodeConstants.Info.GROUP_COLS, new ArrayList(joinElements));
        Iterator iter = aggregateExpressions.iterator();
        HashMap<SingleElementSymbol, SingleElementSymbol> aggMap = new HashMap<SingleElementSymbol, SingleElementSymbol>();
        while (iter.hasNext()) {
            SingleElementSymbol partitionAgg = (SingleElementSymbol)iter.next();
            SingleElementSymbol combineAgg = this.createStagedAggregate(partitionAgg, aggregateExpressions, metadata);
            if (combineAgg == null) {
                return null;
            }
            aggMap.put(partitionAgg, combineAgg);
        }
        this.swapAggregates(groupNode, aggMap);
        ListIterator aggIter = aggregateExpressions.listIterator();
        while (aggIter.hasNext()) {
            Object agg = aggIter.next();
            Object replacement = aggMap.get(agg);
            if (replacement == null) continue;
            aggIter.set(replacement);
        }
        NodeEditor.insertNode((PlanNode)joinNode, (PlanNode)targetChild, (PlanNode)stageGroup);
        if (joinElements.isEmpty() && ((JoinType)joinNode.getProperty((Object)NodeConstants.Info.JOIN_TYPE)).equals((Object)JoinType.JOIN_CROSS)) {
            PlanNode selectNode = NodeFactory.getNewNode((int)13);
            AggregateSymbol count = new AggregateSymbol("count", "COUNT", false, null);
            selectNode.setProperty((Object)NodeConstants.Info.SELECT_CRITERIA, (Object)new CompareCriteria((Expression)count, 4, (Expression)new Constant((Object)new Integer(0))));
            selectNode.setProperty((Object)NodeConstants.Info.IS_HAVING, (Object)Boolean.TRUE);
            NodeEditor.insertNode((PlanNode)joinNode, (PlanNode)stageGroup, (PlanNode)selectNode);
        }
        return stageGroup.getFirstChild();
    }

    private SingleElementSymbol createStagedAggregate(SingleElementSymbol parentSymbol, List aggregateExpressions, QueryMetadataInterface metadata) {
        if (parentSymbol instanceof AggregateSymbol) {
            AggregateSymbol parentAgg = (AggregateSymbol)parentSymbol;
            if (parentAgg.isDistinct()) {
                return null;
            }
            if (parentAgg.getExpression() instanceof Constant) {
                return null;
            }
            String aggFunction = parentAgg.getAggregateFunction();
            AggregateSymbol innerExpr = parentAgg;
            if (aggFunction.equals("COUNT")) {
                aggFunction = "SUM";
                AggregateSymbol newAgg = new AggregateSymbol(this.getNewAggregateName(parentAgg, aggregateExpressions), aggFunction, false, (Expression)innerExpr);
                ArrayList<AggregateSymbol> aggExprUpdated = new ArrayList<AggregateSymbol>(aggregateExpressions);
                aggExprUpdated.add(newAgg);
                Class outputType = parentAgg.getType();
                Constant convertTargetType = new Constant((Object)DataTypeManager.getDataTypeName((Class)outputType), DataTypeManager.DefaultDataClasses.STRING);
                Function convertFunc = new Function("convert", new Expression[]{newAgg, convertTargetType});
                try {
                    ResolverVisitorUtil.resolveFunction(convertFunc, metadata);
                }
                catch (Exception e) {
                    return null;
                }
                ExpressionSymbol exprSymbol = new ExpressionSymbol(this.getNewAggregateName(newAgg, aggExprUpdated), (Expression)convertFunc);
                return exprSymbol;
            }
            if (aggFunction.equals("AVG")) {
                return null;
            }
            return new AggregateSymbol(this.getNewAggregateName(parentAgg, aggregateExpressions), aggFunction, false, (Expression)innerExpr);
        }
        ExpressionSymbol exprSymbol = (ExpressionSymbol)parentSymbol;
        Function func = (Function)exprSymbol.getExpression();
        AggregateSymbol innerAgg = (AggregateSymbol)func.getArg(0);
        AggregateSymbol newAgg = new AggregateSymbol(this.getNewAggregateName(innerAgg, aggregateExpressions), "SUM", false, (Expression)innerAgg);
        ArrayList<AggregateSymbol> aggExprUpdated = new ArrayList<AggregateSymbol>(aggregateExpressions);
        aggExprUpdated.add(newAgg);
        Function copyFunc = (Function)func.clone();
        copyFunc.getArgs()[0] = newAgg;
        ExpressionSymbol newSymbol = new ExpressionSymbol(this.getNewAggregateName(newAgg, aggExprUpdated), (Expression)copyFunc);
        return newSymbol;
    }

    String getNewAggregateName(AggregateSymbol agg, Collection aggregates) {
        String oldName = agg.getShortName();
        HashSet<String> aggNames = new HashSet<String>();
        Iterator iter = aggregates.iterator();
        while (iter.hasNext()) {
            aggNames.add(((SingleElementSymbol)iter.next()).getCanonicalName());
        }
        int nameCounter = 1;
        String nameRoot = oldName.toUpperCase() + "_";
        String proposedName = nameRoot + nameCounter;
        while (aggNames.contains(proposedName)) {
            proposedName = nameRoot + ++nameCounter;
        }
        return proposedName;
    }

    private void swapAggregates(PlanNode node, Map aggMap) throws QueryPlannerException {
        block6: for (PlanNode current = node; current != null; current = current.getParent()) {
            switch (current.getType()) {
                case 13: {
                    Criteria crit = (Criteria)current.getProperty((Object)NodeConstants.Info.SELECT_CRITERIA);
                    FrameUtil.convertCriteria((Criteria)crit, (Map)aggMap);
                    continue block6;
                }
                case 11: {
                    List projectedSymbols = (List)current.getProperty((Object)NodeConstants.Info.PROJECT_COLS);
                    List converted = this.convertSymbols(projectedSymbols, aggMap);
                    current.setProperty((Object)NodeConstants.Info.PROJECT_COLS, (Object)converted);
                    continue block6;
                }
                case 19: {
                    Map symbolMap = (Map)current.getProperty((Object)NodeConstants.Info.SYMBOL_MAP);
                    Iterator iter = symbolMap.entrySet().iterator();
                    while (iter.hasNext()) {
                        Map.Entry entry = iter.next();
                        Object replacement = aggMap.get(entry.getValue());
                        if (replacement == null) continue;
                        entry.setValue(replacement);
                    }
                    return;
                }
                case 17: {
                    List sortCols = (List)current.getProperty((Object)NodeConstants.Info.SORT_ORDER);
                    ArrayList<Object> copyCols = new ArrayList<Object>(sortCols.size());
                    Iterator iter = sortCols.iterator();
                    while (iter.hasNext()) {
                        SingleElementSymbol symbol = (SingleElementSymbol)iter.next();
                        if (symbol instanceof AliasSymbol) {
                            AliasSymbol alias = (AliasSymbol)symbol;
                            Object replacement = aggMap.get(symbol = alias.getSymbol());
                            if (replacement != null) {
                                alias.setSymbol((SingleElementSymbol)replacement);
                                copyCols.add(alias);
                                continue;
                            }
                            copyCols.add(alias);
                            continue;
                        }
                        Object replacement = aggMap.get(symbol);
                        if (replacement != null) {
                            copyCols.add(replacement);
                            continue;
                        }
                        copyCols.add(symbol);
                    }
                    current.setProperty((Object)NodeConstants.Info.SORT_ORDER, copyCols);
                    continue block6;
                }
            }
        }
    }

    private List convertSymbols(List objs, Map aggMap) {
        if (objs == null) {
            return objs;
        }
        ArrayList<Object> converted = new ArrayList<Object>(objs.size());
        Iterator iter = objs.iterator();
        while (iter.hasNext()) {
            SingleElementSymbol replacement;
            SingleElementSymbol symbol = (SingleElementSymbol)iter.next();
            AliasSymbol alias = null;
            if (symbol instanceof AliasSymbol) {
                alias = (AliasSymbol)symbol;
                symbol = alias.getSymbol();
            }
            if ((replacement = (SingleElementSymbol)aggMap.get(symbol)) != null) {
                if (alias == null) {
                    converted.add(replacement);
                    continue;
                }
                alias.setSymbol(replacement);
                converted.add(alias);
                continue;
            }
            if (alias == null) {
                converted.add(symbol);
                continue;
            }
            converted.add(alias);
        }
        return converted;
    }

    public String toString() {
        return "PushAggregates";
    }
}

