/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.physical.batch;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.api.config.OptimizerConfigOptions;
import org.apache.flink.table.connector.source.DynamicTableSource;
import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.plan.abilities.source.AggregatePushDownSpec;
import org.apache.flink.table.planner.plan.abilities.source.ProjectPushDownSpec;
import org.apache.flink.table.planner.plan.abilities.source.SourceAbilityContext;
import org.apache.flink.table.planner.plan.abilities.source.SourceAbilitySpec;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalCalc;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan;
import org.apache.flink.table.planner.plan.schema.TableSourceTable;
import org.apache.flink.table.planner.plan.stats.FlinkStatistic;
import org.apache.flink.table.planner.plan.utils.RexNodeExtractor;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.table.types.logical.RowType;

/**
 * Base class for batch physical rules that push a local (non-final) aggregate into a table source
 * implementing {@link SupportsAggregatePushDown}.
 *
 * <p>Subclasses match concrete plan shapes; this class provides the shared eligibility checks, the
 * index translation through an intermediate input-ref-only calc, and the rewrite itself. The
 * rewrite expects the rule's root operand ({@code call.rel(0)}) to be a {@link
 * BatchPhysicalExchange}, and replaces it with an equivalent exchange over a new scan whose source
 * has absorbed the aggregate.
 */
public abstract class PushLocalAggIntoScanRuleBase extends RelOptRule {

    public PushLocalAggIntoScanRuleBase(RelOptRuleOperand operand, String description) {
        super(operand, description);
    }

    /**
     * Checks whether {@code aggregate} may be pushed into the source read by {@code
     * tableSourceScan}.
     *
     * @param call the current rule call, used to read the planner's {@link TableConfig}
     * @param aggregate the candidate aggregate
     * @param tableSourceScan the scan that (directly or through a calc) feeds the aggregate
     * @return true only if aggregate push-down is enabled, the aggregate is local (non-final) and
     *     has at least one call, every call is pushable (see {@link #isPushableAggCall}), and the
     *     source supports aggregate push-down and carries no {@link AggregatePushDownSpec} yet
     */
    protected boolean canPushDown(
            RelOptRuleCall call,
            BatchPhysicalGroupAggregateBase aggregate,
            BatchPhysicalTableSourceScan tableSourceScan) {
        TableConfig tableConfig = ShortcutUtils.unwrapContext(call.getPlanner()).getTableConfig();
        if (!tableConfig.get(
                OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED)) {
            return false;
        }

        if (aggregate.isFinal() || aggregate.getAggCallList().isEmpty()) {
            return false;
        }

        List<AggregateCall> aggCallList =
                JavaScalaConversionUtil.toJava(aggregate.getAggCallList());
        for (AggregateCall aggCall : aggCallList) {
            if (!isPushableAggCall(aggCall)) {
                return false;
            }
        }

        TableSourceTable tableSourceTable = tableSourceScan.tableSourceTable();
        // An aggregate must not be pushed into the same source twice.
        return tableSourceTable != null
                && tableSourceTable.tableSource() instanceof SupportsAggregatePushDown
                && Arrays.stream(tableSourceTable.abilitySpecs())
                        .noneMatch(spec -> spec instanceof AggregatePushDownSpec);
    }

    /**
     * Returns whether a single aggregate call is simple enough to push: not DISTINCT, not
     * approximate, at most one argument, no FILTER argument and no ordering collation.
     */
    private static boolean isPushableAggCall(AggregateCall aggCall) {
        return !aggCall.isDistinct()
                && !aggCall.isApproximate()
                && aggCall.getArgList().size() <= 1
                && !aggCall.hasFilter()
                && aggCall.getCollation().getFieldCollations().isEmpty();
    }

    /**
     * Pushes {@code localAgg} into {@code oldScan} when the aggregate reads the scan directly, so
     * its grouping and argument indices already refer to scan fields.
     */
    protected void pushLocalAggregateIntoScan(
            RelOptRuleCall call,
            BatchPhysicalGroupAggregateBase localAgg,
            BatchPhysicalTableSourceScan oldScan) {
        pushLocalAggregateIntoScan(call, localAgg, oldScan, null);
    }

    /**
     * Pushes {@code localAgg} into a copy of {@code oldScan}'s source and, on success, registers an
     * equivalent plan via {@link RelOptRuleCall#transformTo}. If the source rejects the aggregate,
     * the method returns without registering anything.
     *
     * @param call the rule call; {@code call.rel(0)} is expected to be a {@link
     *     BatchPhysicalExchange}
     * @param localAgg the local aggregate to push down
     * @param oldScan the scan the aggregate is pushed into
     * @param calcRefFields for each output position of an intermediate calc, the scan field index
     *     it references; {@code null} when the aggregate reads the scan directly
     */
    protected void pushLocalAggregateIntoScan(
            RelOptRuleCall call,
            BatchPhysicalGroupAggregateBase localAgg,
            BatchPhysicalTableSourceScan oldScan,
            int[] calcRefFields) {
        RowType inputType = FlinkTypeFactory.toLogicalRowType(oldScan.getRowType());
        List<int[]> groupingSets =
                Collections.singletonList(
                        ArrayUtils.addAll(localAgg.grouping(), localAgg.auxGrouping()));
        List<AggregateCall> aggCallList = JavaScalaConversionUtil.toJava(localAgg.getAggCallList());

        // Indices in the aggregate refer to the calc's output; map them back to scan fields.
        if (calcRefFields != null) {
            groupingSets = translateGroupingArgIndex(groupingSets, calcRefFields);
            aggCallList = translateAggCallArgIndex(aggCallList, calcRefFields);
        }

        RowType producedType = FlinkTypeFactory.toLogicalRowType(localAgg.getRowType());

        TableSourceTable oldTableSourceTable = oldScan.tableSourceTable();
        // Work on a copy so the original source is untouched if push-down is rejected.
        DynamicTableSource newTableSource = oldScan.tableSource().copy();

        boolean isPushDownSuccess =
                AggregatePushDownSpec.apply(
                        inputType,
                        groupingSets,
                        aggCallList,
                        producedType,
                        newTableSource,
                        SourceAbilityContext.from(oldScan));
        if (!isPushDownSuccess) {
            return;
        }

        AggregatePushDownSpec aggregatePushDownSpec =
                new AggregatePushDownSpec(inputType, groupingSets, aggCallList, producedType);
        // The new table exposes the aggregate's row type; statistics are reset to UNKNOWN.
        TableSourceTable newTableSourceTable =
                oldTableSourceTable
                        .copy(
                                newTableSource,
                                localAgg.getRowType(),
                                new SourceAbilitySpec[] {aggregatePushDownSpec})
                        .copy(FlinkStatistic.UNKNOWN());

        BatchPhysicalTableSourceScan newScan =
                oldScan.copy(oldScan.getTraitSet(), newTableSourceTable);
        BatchPhysicalExchange oldExchange = call.rel(0);
        BatchPhysicalExchange newExchange =
                oldExchange.copy(oldExchange.getTraitSet(), newScan, oldExchange.getDistribution());
        call.transformTo(newExchange);
    }

    /** Returns true if the scan's source carries no {@link ProjectPushDownSpec}. */
    protected boolean isProjectionNotPushedDown(BatchPhysicalTableSourceScan tableSourceScan) {
        TableSourceTable tableSourceTable = tableSourceScan.tableSourceTable();
        return tableSourceTable != null
                && Arrays.stream(tableSourceTable.abilitySpecs())
                        .noneMatch(spec -> spec instanceof ProjectPushDownSpec);
    }

    /**
     * Returns true if {@code calc} has no condition and a non-empty projection list in which every
     * expanded projection is a plain {@link RexInputRef}.
     */
    protected boolean isInputRefOnly(BatchPhysicalCalc calc) {
        RexProgram program = calc.getProgram();
        if (program.getCondition() != null) {
            return false;
        }
        return !program.getProjectList().isEmpty()
                && program.getProjectList().stream()
                        .map(program::expandLocalRef)
                        .allMatch(RexInputRef.class::isInstance);
    }

    /**
     * Returns, for each output position of {@code calc}, the input (scan) field index it
     * references.
     *
     * <p>When every projection is a plain {@link RexInputRef}, the result has exactly one entry per
     * calc output column. In that case a column projected twice (e.g. {@code [$0, $0]}) keeps both
     * positions, which {@link #translateGroupingArgIndex} and {@link #translateAggCallArgIndex}
     * rely on. Otherwise this falls back to {@link RexNodeExtractor#extractRefInputFields}.
     */
    protected int[] getRefFiledIndex(BatchPhysicalCalc calc) {
        RexProgram program = calc.getProgram();
        List<RexNode> projects =
                program.getProjectList().stream()
                        .map(program::expandLocalRef)
                        .collect(Collectors.toList());
        if (projects.stream().allMatch(RexInputRef.class::isInstance)) {
            // NOTE(review): extractRefInputFields appears to collect distinct referenced fields,
            // which would misalign positions for duplicated projections; map positionally here.
            return projects.stream().mapToInt(p -> ((RexInputRef) p).getIndex()).toArray();
        }
        return RexNodeExtractor.extractRefInputFields(projects);
    }

    /**
     * Maps every index in {@code groupingSets} through {@code refFields} (calc output position to
     * scan field index) and returns the translated grouping sets in the same order.
     */
    protected List<int[]> translateGroupingArgIndex(List<int[]> groupingSets, int[] refFields) {
        List<int[]> newGroupingSets = new ArrayList<>(groupingSets.size());
        for (int[] grouping : groupingSets) {
            newGroupingSets.add(Arrays.stream(grouping).map(argIndex -> refFields[argIndex]).toArray());
        }
        return newGroupingSets;
    }

    /**
     * Returns copies of {@code aggCallList} whose argument indices (and filter argument, if any) are
     * mapped through {@code refFields} (calc output position to scan field index).
     *
     * <p>The collation is copied unchanged; {@link #canPushDown} rejects calls with a non-empty
     * collation, so it is always empty on the push-down path.
     */
    protected List<AggregateCall> translateAggCallArgIndex(
            List<AggregateCall> aggCallList, int[] refFields) {
        List<AggregateCall> newAggCallList = new ArrayList<>(aggCallList.size());
        for (AggregateCall aggCall : aggCallList) {
            List<Integer> argList = new ArrayList<>(aggCall.getArgList().size());
            for (int argIndex : aggCall.getArgList()) {
                argList.add(refFields[argIndex]);
            }
            // A negative filterArg means "no filter" and is kept as-is; otherwise it is an input
            // index like the arguments and must be translated too.
            int filterArg = aggCall.filterArg >= 0 ? refFields[aggCall.filterArg] : aggCall.filterArg;
            newAggCallList.add(aggCall.copy(argList, filterArg, aggCall.collation));
        }
        return newAggCallList;
    }
}

