/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adaptivebatch;

import java.util.List;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.JobManagerOptions;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingResultInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.VertexParallelismDecider;
import org.apache.flink.util.MathUtils;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Default {@link VertexParallelismDecider} that derives the parallelism of a job vertex from the
 * amount of data it consumes.
 *
 * <p>A vertex with no consumed results gets the configured default source parallelism. Otherwise
 * the decision works as follows:
 *
 * <ul>
 *   <li>Broadcast bytes are capped at {@code CAP_RATIO_OF_BROADCAST * dataVolumePerTask}.
 *   <li>The non-broadcast bytes are divided by the per-task volume left over after the broadcast
 *       share.
 *   <li>The result is rounded to the nearest power of two.
 *   <li>It is finally clamped to {@code [minParallelism, maxParallelism]}.
 * </ul>
 *
 * <p>When built via {@link #from(Configuration)}, both bounds are normalized to powers of two.
 */
public class DefaultVertexParallelismDecider implements VertexParallelismDecider {

    private static final Logger LOG =
            LoggerFactory.getLogger(DefaultVertexParallelismDecider.class);

    /** Maximum share of the per-task data volume that broadcast data may account for. */
    private static final double CAP_RATIO_OF_BROADCAST = 0.5;

    /** Largest power of two representable as a positive {@code int} (2^30). */
    private static final int MAX_POWER_OF_TWO = 1 << 30;

    private final int maxParallelism;
    private final int minParallelism;

    /** Target data volume per task, in bytes. */
    private final long dataVolumePerTask;

    private final int defaultSourceParallelism;

    private DefaultVertexParallelismDecider(
            int maxParallelism,
            int minParallelism,
            MemorySize dataVolumePerTask,
            int defaultSourceParallelism) {
        Preconditions.checkArgument(
                minParallelism > 0, "The minimum parallelism must be larger than 0.");
        Preconditions.checkArgument(
                maxParallelism >= minParallelism,
                "Maximum parallelism should be greater than or equal to the minimum parallelism.");
        Preconditions.checkArgument(
                defaultSourceParallelism > 0,
                "The default source parallelism must be larger than 0.");
        Preconditions.checkNotNull(dataVolumePerTask);
        this.maxParallelism = maxParallelism;
        this.minParallelism = minParallelism;
        this.dataVolumePerTask = dataVolumePerTask.getBytes();
        this.defaultSourceParallelism = defaultSourceParallelism;
    }

    /**
     * Decides the parallelism of a vertex from the results it consumes.
     *
     * @param consumedResults the blocking results consumed by the vertex; an empty list yields the
     *     default source parallelism
     * @return the decided parallelism
     */
    @Override
    public int decideParallelismForVertex(List<BlockingResultInfo> consumedResults) {
        if (consumedResults.isEmpty()) {
            return defaultSourceParallelism;
        }
        return calculateParallelism(consumedResults);
    }

    /**
     * Computes a parallelism from the consumed data volume.
     *
     * <p>The value is normalized to a power of two and clamped to {@code [minParallelism,
     * maxParallelism]}.
     */
    private int calculateParallelism(List<BlockingResultInfo> consumedResults) {
        long broadcastBytes = 0L;
        long nonBroadcastBytes = 0L;
        for (BlockingResultInfo consumedResult : consumedResults) {
            long bytes = sumPartitionBytes(consumedResult);
            if (consumedResult.isBroadcast()) {
                broadcastBytes += bytes;
            } else {
                nonBroadcastBytes += bytes;
            }
        }

        // Broadcast data is read by every task, so it is capped to leave room for the
        // non-broadcast share of the per-task volume.
        long expectedMaxBroadcastBytes =
                (long) Math.ceil(dataVolumePerTask * CAP_RATIO_OF_BROADCAST);
        if (broadcastBytes > expectedMaxBroadcastBytes) {
            LOG.info(
                    "The size of broadcast data {} is larger than the expected maximum value {} ('{}' * {}). Use {} as the size of broadcast data to decide the parallelism.",
                    new MemorySize(broadcastBytes),
                    new MemorySize(expectedMaxBroadcastBytes),
                    JobManagerOptions.ADAPTIVE_BATCH_SCHEDULER_AVG_DATA_VOLUME_PER_TASK.key(),
                    CAP_RATIO_OF_BROADCAST,
                    new MemorySize(expectedMaxBroadcastBytes));
            broadcastBytes = expectedMaxBroadcastBytes;
        }

        // The per-task budget for non-broadcast data can only be 0 when dataVolumePerTask is at
        // most 1 byte. The double division then yields Infinity, which casts to Integer.MAX_VALUE,
        // or NaN, which casts to 0. normalizeParallelism and the clamp below turn these into
        // maxParallelism or minParallelism respectively.
        int initialParallelism =
                (int) Math.ceil((double) nonBroadcastBytes / (double) (dataVolumePerTask - broadcastBytes));
        int parallelism = normalizeParallelism(initialParallelism);

        LOG.debug(
                "The size of broadcast data is {}, the size of non-broadcast data is {}, the initially decided parallelism is {}, after normalize is {}",
                new MemorySize(broadcastBytes),
                new MemorySize(nonBroadcastBytes),
                initialParallelism,
                parallelism);

        if (parallelism < minParallelism) {
            LOG.info(
                    "The initially normalized parallelism {} is smaller than the normalized minimum parallelism {}. Use {} as the finally decided parallelism.",
                    parallelism,
                    minParallelism,
                    minParallelism);
            parallelism = minParallelism;
        } else if (parallelism > maxParallelism) {
            LOG.info(
                    "The initially normalized parallelism {} is larger than the normalized maximum parallelism {}. Use {} as the finally decided parallelism.",
                    parallelism,
                    maxParallelism,
                    maxParallelism);
            parallelism = maxParallelism;
        }
        return parallelism;
    }

    /** Sums the sizes reported by {@code getBlockingPartitionSizes()} for one consumed result. */
    private static long sumPartitionBytes(BlockingResultInfo consumedResult) {
        return consumedResult.getBlockingPartitionSizes().stream()
                .mapToLong(Long::longValue)
                .sum();
    }

    @VisibleForTesting
    int getMaxParallelism() {
        return maxParallelism;
    }

    @VisibleForTesting
    int getMinParallelism() {
        return minParallelism;
    }

    /**
     * Creates a decider from the adaptive batch scheduler options in {@code configuration}.
     *
     * <p>The max parallelism is rounded down to a power of two; the min parallelism is rounded up
     * to one.
     *
     * @throws IllegalStateException if the normalized maximum is smaller than the normalized
     *     minimum
     */
    static DefaultVertexParallelismDecider from(Configuration configuration) {
        int maxParallelism = getNormalizedMaxParallelism(configuration);
        int minParallelism = getNormalizedMinParallelism(configuration);
        Preconditions.checkState(
                maxParallelism >= minParallelism,
                String.format(
                        "Invalid configuration: '%s' should be greater than or equal to '%s' and the range must contain at least one power of 2.",
                        JobManagerOptions.ADAPTIVE_BATCH_SCHEDULER_MAX_PARALLELISM.key(),
                        JobManagerOptions.ADAPTIVE_BATCH_SCHEDULER_MIN_PARALLELISM.key()));
        return new DefaultVertexParallelismDecider(
                maxParallelism,
                minParallelism,
                configuration.get(JobManagerOptions.ADAPTIVE_BATCH_SCHEDULER_AVG_DATA_VOLUME_PER_TASK),
                configuration.get(JobManagerOptions.ADAPTIVE_BATCH_SCHEDULER_DEFAULT_SOURCE_PARALLELISM));
    }

    /** Returns the configured max parallelism rounded down to a power of two. */
    static int getNormalizedMaxParallelism(Configuration configuration) {
        return MathUtils.roundDownToPowerOf2(
                configuration.get(JobManagerOptions.ADAPTIVE_BATCH_SCHEDULER_MAX_PARALLELISM));
    }

    /** Returns the configured min parallelism rounded up to a power of two. */
    static int getNormalizedMinParallelism(Configuration configuration) {
        return MathUtils.roundUpToPowerOfTwo(
                configuration.get(JobManagerOptions.ADAPTIVE_BATCH_SCHEDULER_MIN_PARALLELISM));
    }

    /**
     * Rounds {@code parallelism} to the nearer of its neighbouring powers of two.
     *
     * <p>A value at or above the midpoint of the two neighbours rounds up. Values above 2^30 have
     * no representable upper neighbour and saturate at 2^30. Previously, rounding up wrapped to
     * {@code Integer.MIN_VALUE} in that range.
     *
     * @param parallelism the raw parallelism
     * @return the normalized parallelism
     */
    static int normalizeParallelism(int parallelism) {
        if (parallelism > MAX_POWER_OF_TWO) {
            return MAX_POWER_OF_TWO;
        }
        int down = MathUtils.roundDownToPowerOf2(parallelism);
        int up = MathUtils.roundUpToPowerOfTwo(parallelism);
        // long arithmetic: up + down would overflow int when both are 2^30.
        return parallelism < ((long) up + down) / 2 ? down : up;
    }
}

