/*
 * Decompiled with CFR 0.152.
 */
package org.apache.asterix.optimizer.rules;

import java.io.Serializable;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.apache.asterix.common.exceptions.CompilationException;
import org.apache.asterix.common.exceptions.ErrorCode;
import org.apache.asterix.om.base.AInt32;
import org.apache.asterix.om.base.IAObject;
import org.apache.asterix.om.constants.AsterixConstantValue;
import org.apache.asterix.om.functions.BuiltinFunctions;
import org.apache.asterix.om.types.ARecordType;
import org.apache.asterix.om.utils.ConstantExpressionUtil;
import org.apache.commons.lang3.mutable.Mutable;
import org.apache.hyracks.algebricks.common.exceptions.AlgebricksException;
import org.apache.hyracks.algebricks.common.utils.Pair;
import org.apache.hyracks.algebricks.common.utils.Triple;
import org.apache.hyracks.algebricks.core.algebra.base.ILogicalExpression;
import org.apache.hyracks.algebricks.core.algebra.base.ILogicalOperator;
import org.apache.hyracks.algebricks.core.algebra.base.IOptimizationContext;
import org.apache.hyracks.algebricks.core.algebra.base.LogicalExpressionTag;
import org.apache.hyracks.algebricks.core.algebra.base.LogicalOperatorTag;
import org.apache.hyracks.algebricks.core.algebra.base.LogicalVariable;
import org.apache.hyracks.algebricks.core.algebra.expressions.AbstractFunctionCallExpression;
import org.apache.hyracks.algebricks.core.algebra.expressions.ConstantExpression;
import org.apache.hyracks.algebricks.core.algebra.expressions.IAlgebricksConstantValue;
import org.apache.hyracks.algebricks.core.algebra.expressions.IVariableTypeEnvironment;
import org.apache.hyracks.algebricks.core.algebra.expressions.VariableReferenceExpression;
import org.apache.hyracks.algebricks.core.algebra.operators.logical.AbstractLogicalOperator;
import org.apache.hyracks.algebricks.core.algebra.operators.logical.UnionAllOperator;
import org.apache.hyracks.algebricks.core.algebra.visitors.ILogicalExpressionReferenceTransform;
import org.apache.hyracks.algebricks.rewriter.rules.PushMapOperatorThroughUnionRule;

/**
 * Asterix-specific variant of {@link PushMapOperatorThroughUnionRule}.
 * <p>
 * Only operators whose tag was passed to the constructor are considered for pushing through a union (in addition to
 * whatever the base rule allows). When the operator being pushed contains {@code field-access-by-index} calls on a
 * record variable produced by the union, the field index in each copy is re-mapped from the post-union record type to
 * the pre-union record type of the corresponding branch. If any such call cannot be re-mapped, the operator is not
 * pushed into that branch ({@link #deepCopyForBranch} returns {@code null}).
 */
public class AsterixPushMapOperatorThroughUnionRule extends PushMapOperatorThroughUnionRule {
    private final FieldAccessByIndexCollector fieldAccessByIndexCollector = new FieldAccessByIndexCollector();
    // Declared after the collector so that the collector (and its map) is already initialized here.
    private final FieldAccessByIndexTransformer fieldAccessByIndexTransformer =
            new FieldAccessByIndexTransformer(fieldAccessByIndexCollector.fieldIndexMap);
    private final Set<LogicalOperatorTag> allowedKinds;

    /**
     * Creates the rule.
     *
     * @param allowedKinds operator kinds this rule may push through a union; at least one is required
     * @throws IllegalArgumentException if no operator kind is given
     */
    public AsterixPushMapOperatorThroughUnionRule(LogicalOperatorTag... allowedKinds) {
        if (allowedKinds.length == 0) {
            throw new IllegalArgumentException();
        }
        this.allowedKinds = EnumSet.noneOf(LogicalOperatorTag.class);
        Collections.addAll(this.allowedKinds, allowedKinds);
    }

    /**
     * An operator is pushable only if its kind was allowed at construction time AND the base rule allows it.
     */
    @Override
    protected boolean isOperatorKindPushableThroughUnion(ILogicalOperator op) {
        return allowedKinds.contains(op.getOperatorTag()) && super.isOperatorKindPushableThroughUnion(op);
    }

    /**
     * Copies {@code op} for branch {@code branchIdx} of {@code unionAllOp}, re-mapping any
     * {@code field-access-by-index} field positions to the branch's pre-union record type.
     *
     * @return the copied operator and its variable mapping, or {@code null} if the operator cannot be pushed into
     *         this branch (a field index could not be re-mapped, or the base rule declined)
     * @throws CompilationException if {@code op} has nested plans (not supported by this rule)
     */
    @Override
    protected Pair<ILogicalOperator, Map<LogicalVariable, LogicalVariable>> deepCopyForBranch(ILogicalOperator op,
            Set<LogicalVariable> opUsedVars, UnionAllOperator unionAllOp, int branchIdx, IOptimizationContext context)
            throws AlgebricksException {
        if (((AbstractLogicalOperator) op).hasNestedPlans()) {
            throw new CompilationException(ErrorCode.COMPILATION_ILLEGAL_STATE, op.getSourceLocation(),
                    op.getOperatorTag().toString());
        }
        fieldAccessByIndexCollector.reset(unionAllOp, branchIdx, context);
        try {
            // First pass: compute post-union -> pre-union field index mappings on the ORIGINAL operator.
            op.acceptExpressionTransform(fieldAccessByIndexCollector);
            if (fieldAccessByIndexCollector.failed) {
                return null;
            }
            Pair<ILogicalOperator, Map<LogicalVariable, LogicalVariable>> newOpPair =
                    super.deepCopyForBranch(op, opUsedVars, unionAllOp, branchIdx, context);
            // Guard against the base rule declining the copy; previously this dereferenced null.
            if (newOpPair != null && fieldAccessByIndexCollector.hasFieldAccessMappings()) {
                // Second pass: rewrite field indexes in the COPY, whose variables now refer to pre-union vars.
                fieldAccessByIndexTransformer.reset(unionAllOp, branchIdx, context);
                try {
                    newOpPair.first.acceptExpressionTransform(fieldAccessByIndexTransformer);
                } finally {
                    fieldAccessByIndexTransformer.clear();
                }
            }
            return newOpPair;
        } finally {
            // Always drop references to the plan/context, even if a transform threw.
            fieldAccessByIndexCollector.clear();
        }
    }

    /**
     * Walks expressions and records, for every {@code field-access-by-index(var, constIdx)} where {@code var} is an
     * output variable of the union, the field index that the same field has in the pre-union record type of the
     * current branch. Sets {@link #failed} if any such call cannot be mapped. Never modifies expressions.
     */
    private static final class FieldAccessByIndexCollector extends AbstractFieldAccessByIndexTransformer {
        /** Key: (pre-union record variable, post-union field index); value: pre-union field index. */
        private final Map<Pair<LogicalVariable, Integer>, Integer> fieldIndexMap = new HashMap<>();
        private boolean failed;

        private FieldAccessByIndexCollector() {
        }

        @Override
        void reset(UnionAllOperator unionAllOp, int branchIdx, IOptimizationContext context) {
            super.reset(unionAllOp, branchIdx, context);
            fieldIndexMap.clear();
            failed = false;
        }

        @Override
        void clear() {
            super.clear();
            fieldIndexMap.clear();
        }

        boolean hasFieldAccessMappings() {
            return !fieldIndexMap.isEmpty();
        }

        /** Collects mappings only; always reports "not applied" because no expression is changed. */
        @Override
        public boolean transform(Mutable<ILogicalExpression> exprRef) throws AlgebricksException {
            if (!failed) {
                visit(exprRef);
            }
            return false;
        }

        private void visit(Mutable<ILogicalExpression> exprRef) throws AlgebricksException {
            ILogicalExpression expr = exprRef.getValue();
            if (expr.getExpressionTag() != LogicalExpressionTag.FUNCTION_CALL) {
                return;
            }
            AbstractFunctionCallExpression callExpr = (AbstractFunctionCallExpression) expr;
            for (Mutable<ILogicalExpression> argRef : callExpr.getArguments()) {
                visit(argRef);
                if (failed) {
                    return;
                }
            }
            if (callExpr.getFunctionIdentifier().equals(BuiltinFunctions.FIELD_ACCESS_BY_INDEX)
                    && !mapFieldIndex(callExpr)) {
                failed = true;
            }
        }

        /**
         * Records the pre-union field index for one {@code field-access-by-index} call.
         *
         * @return {@code true} if a mapping was recorded; {@code false} if the call cannot be re-mapped (record
         *         argument is not a union output variable, index is not a constant or is out of range, either type
         *         is not a record type, or the field does not exist in the pre-union type)
         */
        private boolean mapFieldIndex(AbstractFunctionCallExpression callExpr) throws AlgebricksException {
            ILogicalExpression recordExpr = callExpr.getArguments().get(0).getValue();
            if (recordExpr.getExpressionTag() != LogicalExpressionTag.VARIABLE) {
                return false;
            }
            Integer fieldIndexPostUnion = ConstantExpressionUtil.getIntArgument(callExpr, 1);
            if (fieldIndexPostUnion == null) {
                return false;
            }
            LogicalVariable recordVarPostUnion = ((VariableReferenceExpression) recordExpr).getVariableReference();
            LogicalVariable recordVarPreUnion = findPreUnionVariable(recordVarPostUnion);
            if (recordVarPreUnion == null) {
                return false;
            }

            IVariableTypeEnvironment typeEnvPostUnion = context.getOutputTypeEnvironment(unionAllOp);
            Object typePostUnion = typeEnvPostUnion.getVarType(recordVarPostUnion);
            // Previously a blind cast; a non-record type now makes the push fail instead of throwing.
            if (!(typePostUnion instanceof ARecordType)) {
                return false;
            }
            String[] fieldNamesPostUnion = ((ARecordType) typePostUnion).getFieldNames();
            if (fieldIndexPostUnion < 0 || fieldIndexPostUnion >= fieldNamesPostUnion.length) {
                return false;
            }
            String fieldName = fieldNamesPostUnion[fieldIndexPostUnion];

            ILogicalOperator inputOpToUnion = unionAllOp.getInputs().get(branchIdx).getValue();
            IVariableTypeEnvironment typeEnvPreUnion = context.getOutputTypeEnvironment(inputOpToUnion);
            Object typePreUnion = typeEnvPreUnion.getVarType(recordVarPreUnion);
            if (!(typePreUnion instanceof ARecordType)) {
                return false;
            }
            int fieldIndexPreUnion = ((ARecordType) typePreUnion).getFieldIndex(fieldName);
            if (fieldIndexPreUnion < 0) {
                return false;
            }
            fieldIndexMap.put(new Pair<>(recordVarPreUnion, fieldIndexPostUnion), fieldIndexPreUnion);
            return true;
        }

        /**
         * Returns the input variable of the current branch that feeds the given union output variable (the first
         * matching mapping wins), or {@code null} if the variable is not produced by the union.
         */
        private LogicalVariable findPreUnionVariable(LogicalVariable varPostUnion) {
            for (Triple<LogicalVariable, LogicalVariable, LogicalVariable> varMap : unionAllOp.getVariableMappings()) {
                if (varMap.third.equals(varPostUnion)) {
                    return branchIdx == 0 ? varMap.first : varMap.second;
                }
            }
            return null;
        }
    }

    /**
     * Rewrites the constant index argument of every {@code field-access-by-index} call in the copied operator using
     * the mappings previously recorded by {@link FieldAccessByIndexCollector}.
     */
    private static final class FieldAccessByIndexTransformer extends AbstractFieldAccessByIndexTransformer {
        /** Shared with the collector; populated before this transformer runs. */
        private final Map<Pair<LogicalVariable, Integer>, Integer> fieldIndexMap;

        private FieldAccessByIndexTransformer(Map<Pair<LogicalVariable, Integer>, Integer> fieldIndexMap) {
            this.fieldIndexMap = fieldIndexMap;
        }

        @Override
        public boolean transform(Mutable<ILogicalExpression> exprRef) throws AlgebricksException {
            ILogicalExpression expr = exprRef.getValue();
            if (expr.getExpressionTag() != LogicalExpressionTag.FUNCTION_CALL) {
                return false;
            }
            AbstractFunctionCallExpression callExpr = (AbstractFunctionCallExpression) expr;
            boolean applied = false;
            for (Mutable<ILogicalExpression> argRef : callExpr.getArguments()) {
                applied |= transform(argRef);
            }
            if (callExpr.getFunctionIdentifier().equals(BuiltinFunctions.FIELD_ACCESS_BY_INDEX)) {
                transformFieldIndex(callExpr);
                applied = true;
            }
            return applied;
        }

        /**
         * Replaces the index argument of one {@code field-access-by-index} call with its pre-union value.
         *
         * @throws CompilationException if the call does not match a mapping recorded by the collector (which would
         *         mean the collector and the copied operator are out of sync)
         */
        private void transformFieldIndex(AbstractFunctionCallExpression callExpr) throws AlgebricksException {
            ILogicalExpression recordExpr = callExpr.getArguments().get(0).getValue();
            if (recordExpr.getExpressionTag() != LogicalExpressionTag.VARIABLE) {
                throw new CompilationException(ErrorCode.COMPILATION_ILLEGAL_STATE, callExpr.getSourceLocation(),
                        recordExpr.getExpressionTag().toString());
            }
            Integer fieldIndexPostUnion = ConstantExpressionUtil.getIntArgument(callExpr, 1);
            if (fieldIndexPostUnion == null) {
                throw new CompilationException(ErrorCode.COMPILATION_ILLEGAL_STATE, callExpr.getSourceLocation(),
                        "");
            }
            // In the copy, the record variable already refers to the pre-union (branch) variable.
            LogicalVariable recordVarPreUnion = ((VariableReferenceExpression) recordExpr).getVariableReference();
            Integer fieldIndexPreUnion = fieldIndexMap.get(new Pair<>(recordVarPreUnion, fieldIndexPostUnion));
            if (fieldIndexPreUnion == null) {
                throw new CompilationException(ErrorCode.COMPILATION_ILLEGAL_STATE, callExpr.getSourceLocation(),
                        recordVarPreUnion.toString());
            }
            callExpr.getArguments().get(1)
                    .setValue(new ConstantExpression(new AsterixConstantValue(new AInt32(fieldIndexPreUnion))));
        }
    }

    /**
     * Common state for the two expression visitors: the union being pushed through, the branch being copied for,
     * and the optimization context. {@link #reset} before use; {@link #clear} afterwards to drop references.
     */
    private abstract static class AbstractFieldAccessByIndexTransformer
            implements ILogicalExpressionReferenceTransform {
        protected UnionAllOperator unionAllOp;
        protected int branchIdx;
        protected IOptimizationContext context;

        private AbstractFieldAccessByIndexTransformer() {
        }

        void reset(UnionAllOperator unionAllOp, int branchIdx, IOptimizationContext context) {
            this.unionAllOp = unionAllOp;
            this.branchIdx = branchIdx;
            this.context = context;
        }

        void clear() {
            this.unionAllOp = null;
            this.context = null;
        }
    }
}

