/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.service.deploy.worker.congestcontrol;

import com.google.common.annotations.VisibleForTesting;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.common.quota.UserTrafficQuota;
import org.apache.celeborn.common.quota.WorkerTrafficQuota;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.ThreadUtils;
import org.apache.celeborn.server.common.service.config.ConfigService;
import org.apache.celeborn.service.deploy.worker.WorkerSource;
import org.apache.celeborn.service.deploy.worker.congestcontrol.BufferStatusHub;
import org.apache.celeborn.service.deploy.worker.congestcontrol.UserBufferInfo;
import org.apache.celeborn.service.deploy.worker.congestcontrol.UserCongestionControlContext;
import org.apache.celeborn.service.deploy.worker.memory.MemoryManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Worker-side congestion controller.
 *
 * <p>Tracks worker-wide produced/consumed byte rates (via {@link BufferStatusHub}s) and per-user
 * produce rates, periodically checks the worker's pending bytes and produce speed against the
 * {@link WorkerTrafficQuota} watermarks, and decides per user whether that user should be
 * congested. Quotas come from {@link CelebornConf} defaults, or from the {@link ConfigService}
 * caches when a config service is supplied (and are refreshed on config updates).
 *
 * <p>A process-wide instance is managed through {@link #initialize}, {@link #instance()} and
 * {@link #destroy()}.
 */
public class CongestionController {
    private static final Logger logger = LoggerFactory.getLogger(CongestionController.class);
    private static volatile CongestionController _INSTANCE = null;
    private final WorkerSource workerSource;
    private final int sampleTimeWindowSeconds;
    private final long userInactiveTimeMills;
    // Worker-level state: true after pending bytes or produce speed crossed a high watermark,
    // reset only once both drop below their low watermarks (see checkCongestion).
    private final AtomicBoolean overHighWatermark = new AtomicBoolean(false);
    private final BufferStatusHub consumedBufferStatusHub;
    private final BufferStatusHub producedBufferStatusHub;
    private final ConcurrentHashMap<UserIdentifier, UserBufferInfo> userBufferStatuses;
    private final ScheduledExecutorService removeUserExecutorService;
    private final ScheduledExecutorService checkService;
    private final ConcurrentHashMap<UserIdentifier, UserCongestionControlContext> userCongestionContextMap;
    private final ConfigService configService;
    private final UserTrafficQuota defaultUserQuota;
    private volatile WorkerTrafficQuota workerTrafficQuota;

    /**
     * Builds the controller, registers its gauges and starts its background tasks.
     *
     * @param workerSource metrics source used for gauges
     * @param sampleTimeWindowSeconds window size handed to every {@link BufferStatusHub}
     * @param conf source of the default user/worker watermarks and task intervals
     * @param configService optional dynamic config source; when {@code null} the defaults from
     *     {@code conf} are used for all users
     */
    protected CongestionController(WorkerSource workerSource, int sampleTimeWindowSeconds, CelebornConf conf, ConfigService configService) {
        this.workerSource = workerSource;
        this.sampleTimeWindowSeconds = sampleTimeWindowSeconds;
        this.userInactiveTimeMills = conf.workerCongestionControlUserInactiveIntervalMs();
        this.consumedBufferStatusHub = new BufferStatusHub(sampleTimeWindowSeconds);
        this.producedBufferStatusHub = new BufferStatusHub(sampleTimeWindowSeconds);
        this.userBufferStatuses = JavaUtils.newConcurrentHashMap();
        this.userCongestionContextMap = JavaUtils.newConcurrentHashMap();
        this.defaultUserQuota = new UserTrafficQuota(
                conf.workerCongestionControlUserProduceSpeedHighWatermark(),
                conf.workerCongestionControlUserProduceSpeedLowWatermark());
        this.workerTrafficQuota = new WorkerTrafficQuota(
                conf.workerCongestionControlDiskBufferHighWatermark(),
                conf.workerCongestionControlDiskBufferLowWatermark(),
                conf.workerCongestionControlWorkerProduceSpeedHighWatermark(),
                conf.workerCongestionControlWorkerProduceSpeedLowWatermark());
        this.configService = configService;
        if (configService != null) {
            this.updateQuota();
            configService.registerListenerOnConfigUpdate(this::updateQuota);
        }
        this.workerSource.addGauge(WorkerSource.POTENTIAL_CONSUME_SPEED(), this::getPotentialConsumeSpeed);
        this.workerSource.addGauge(WorkerSource.WORKER_CONSUME_SPEED(), this.consumedBufferStatusHub::avgBytesPerSec);
        // Start the periodic tasks last: they then never run against a partially initialized
        // instance, and no threads are leaked if anything above (e.g. updateQuota) throws.
        this.removeUserExecutorService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-congestion-controller-inactive-user-remover");
        this.removeUserExecutorService.scheduleWithFixedDelay(this::removeInactiveUsers, 0L, this.userInactiveTimeMills, TimeUnit.MILLISECONDS);
        this.checkService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-congestion-controller-checker");
        this.checkService.scheduleWithFixedDelay(this::checkCongestion, 0L, conf.workerCongestionControlCheckIntervalMs(), TimeUnit.MILLISECONDS);
    }

    /**
     * Creates a new controller and publishes it as the process-wide instance.
     *
     * <p>NOTE(review): a previously initialized instance is replaced without being closed; call
     * {@link #destroy()} first if its executors must be stopped.
     *
     * @return the newly created instance
     */
    public static synchronized CongestionController initialize(WorkerSource workSource, int sampleTimeWindowSeconds, CelebornConf conf, ConfigService configService) {
        _INSTANCE = new CongestionController(workSource, sampleTimeWindowSeconds, conf, configService);
        return _INSTANCE;
    }

    /** @return the process-wide instance, or {@code null} before initialize / after destroy */
    public static CongestionController instance() {
        return _INSTANCE;
    }

    /**
     * Decides whether the user behind {@code userCongestionControlContext} should be congested.
     *
     * <ul>
     *   <li>No tracked users: never congested.
     *   <li>Worker over its high watermark and the user's produce speed above the per-user
     *       potential consume speed: congested (the context's own state is not changed).
     *   <li>Otherwise hysteresis on the user's quota: congestion is switched on above the user
     *       produce-speed high watermark and off again below the low watermark.
     * </ul>
     *
     * @return {@code true} if the user should be congested
     */
    public boolean isUserCongested(UserCongestionControlContext userCongestionControlContext) {
        if (this.userBufferStatuses.isEmpty()) {
            return false;
        }
        UserIdentifier userIdentifier = userCongestionControlContext.getUserIdentifier();
        long userProduceSpeed = this.getUserProduceSpeed(userCongestionControlContext.getUserBufferInfo());
        UserTrafficQuota userTrafficQuota = userCongestionControlContext.getUserTrafficQuota();
        if (this.overHighWatermark.get()) {
            // A user producing faster than its fair share of the worker's consume speed is congested.
            long avgConsumeSpeed = this.getPotentialConsumeSpeed();
            if (userProduceSpeed > avgConsumeSpeed) {
                if (logger.isDebugEnabled()) {
                    logger.debug("The user {}, produceSpeed is {}, while consumeSpeed is {}, need to congest it.", userIdentifier, userProduceSpeed, avgConsumeSpeed);
                }
                return true;
            }
        }
        if (userProduceSpeed > userTrafficQuota.userProduceSpeedHighWatermark()) {
            userCongestionControlContext.onCongestionControl();
            if (logger.isDebugEnabled()) {
                logger.debug("The user {}, produceSpeed is {}, while userProduceSpeedHighWatermark is {}, need to congest it.", userIdentifier, userProduceSpeed, userTrafficQuota.userProduceSpeedHighWatermark());
            }
        } else if (userCongestionControlContext.inCongestionControl() && userProduceSpeed < userTrafficQuota.userProduceSpeedLowWatermark()) {
            userCongestionControlContext.offCongestionControl();
        }
        return userCongestionControlContext.inCongestionControl();
    }

    /** Returns the buffer info of {@code userIdentifier}, creating and registering it on first use. */
    public UserBufferInfo getUserBuffer(UserIdentifier userIdentifier) {
        return this.userBufferStatuses.computeIfAbsent(userIdentifier, user -> {
            logger.info("New user {} comes, initializing its rate status", user);
            BufferStatusHub bufferStatusHub = new BufferStatusHub(this.sampleTimeWindowSeconds);
            return new UserBufferInfo(System.currentTimeMillis(), bufferStatusHub);
        });
    }

    /** Records {@code numBytes} as consumed at the current wall-clock time. */
    public void consumeBytes(int numBytes) {
        long currentTimeMillis = System.currentTimeMillis();
        BufferStatusHub.BufferStatusNode node = new BufferStatusHub.BufferStatusNode(numBytes);
        this.consumedBufferStatusHub.add(currentTimeMillis, node);
    }

    /** @return the current memory usage reported by {@link MemoryManager} */
    public long getTotalPendingBytes() {
        return MemoryManager.instance().getMemoryUsage();
    }

    /** Asks {@link MemoryManager} to trim all of its listeners. */
    public void trimMemoryUsage() {
        MemoryManager.instance().trimAllListeners();
    }

    /** @return worker consume speed divided by the number of tracked users (0 if none) */
    public long getPotentialConsumeSpeed() {
        return this.averagePerTrackedUser(this.consumedBufferStatusHub);
    }

    /** @return worker produce speed divided by the number of tracked users (0 if none) */
    public long getPotentialProduceSpeed() {
        return this.averagePerTrackedUser(this.producedBufferStatusHub);
    }

    /** Divides {@code hub}'s average bytes/sec by the tracked-user count, or returns 0 if none. */
    private long averagePerTrackedUser(BufferStatusHub hub) {
        // Read the size exactly once: removeInactiveUsers may empty the map concurrently, and a
        // second size() call could then return 0 and throw ArithmeticException on division.
        int trackedUsers = this.userBufferStatuses.size();
        if (trackedUsers == 0) {
            return 0L;
        }
        return hub.avgBytesPerSec() / (long) trackedUsers;
    }

    /** @return the user's average produce speed, or 0 if {@code userBufferInfo} is null */
    private long getUserProduceSpeed(UserBufferInfo userBufferInfo) {
        if (userBufferInfo != null) {
            return userBufferInfo.getBufferStatusHub().avgBytesPerSec();
        }
        return 0L;
    }

    /**
     * Periodic task: drops users whose buffer-info timestamp is at least
     * {@code userInactiveTimeMills} old, together with their congestion context and gauge.
     * Exceptions are logged so the scheduled task keeps running.
     */
    private void removeInactiveUsers() {
        try {
            long currentTimeMillis = System.currentTimeMillis();
            for (Map.Entry<UserIdentifier, UserBufferInfo> entry : this.userBufferStatuses.entrySet()) {
                UserIdentifier userIdentifier = entry.getKey();
                UserBufferInfo userBufferInfo = entry.getValue();
                if (currentTimeMillis - userBufferInfo.getTimestamp() < this.userInactiveTimeMills) {
                    continue;
                }
                // Buffer info goes first so a context looked up in between still sees its old info.
                this.userBufferStatuses.remove(userIdentifier);
                this.userCongestionContextMap.remove(userIdentifier);
                this.workerSource.removeGauge(WorkerSource.USER_PRODUCE_SPEED(), userIdentifier.toMap());
                logger.info("User {} has been expired, remove from rate limit list", userIdentifier);
            }
        } catch (Exception e) {
            logger.error("Error occurs when removing inactive users", e);
        }
    }

    /**
     * Periodic task: updates the worker-level {@code overHighWatermark} state with hysteresis.
     *
     * <p>It is cleared only when both pending bytes and worker produce speed are below their low
     * watermarks, set when either exceeds its high watermark, and while set memory is trimmed.
     * Exceptions are logged so the scheduled task keeps running.
     */
    protected void checkCongestion() {
        try {
            long pendingConsume = this.getTotalPendingBytes();
            long workerProduceSpeed = this.producedBufferStatusHub.avgBytesPerSec();
            WorkerTrafficQuota quota = this.workerTrafficQuota;
            if (pendingConsume < quota.diskBufferLowWatermark() && workerProduceSpeed < quota.workerProduceSpeedLowWatermark()) {
                if (this.overHighWatermark.compareAndSet(true, false)) {
                    logger.info("Pending consume and produce speed is lower than low watermark, exit congestion control");
                }
                return;
            }
            if ((pendingConsume > quota.diskBufferHighWatermark() || workerProduceSpeed > quota.workerProduceSpeedHighWatermark()) && this.overHighWatermark.compareAndSet(false, true)) {
                logger.info("Pending consume or produce speed is higher than high watermark, need congestion control");
            }
            if (this.overHighWatermark.get()) {
                this.trimMemoryUsage();
            }
        } catch (Exception e) {
            logger.error("Congestion check error", e);
        }
    }

    /** @return whether the worker is currently above its high watermark */
    public Boolean isOverHighWatermark() {
        return this.overHighWatermark.get();
    }

    /** Stops the background tasks and clears all tracked state. */
    public void close() {
        logger.info("Closing {}", this.getClass().getSimpleName());
        ThreadUtils.shutdown((ExecutorService) this.removeUserExecutorService);
        ThreadUtils.shutdown((ExecutorService) this.checkService);
        this.userBufferStatuses.clear();
        // Contexts hold references to the buffer infos cleared above; drop them too.
        this.userCongestionContextMap.clear();
        this.consumedBufferStatusHub.clear();
        this.producedBufferStatusHub.clear();
    }

    /** Stops only the periodic congestion check. */
    @VisibleForTesting
    public void shutDownCheckService() {
        ThreadUtils.shutdown((ExecutorService) this.checkService);
    }

    /** Closes and clears the process-wide instance, if any. */
    public static synchronized void destroy() {
        if (_INSTANCE != null) {
            _INSTANCE.close();
            _INSTANCE = null;
        }
    }

    public BufferStatusHub getProducedBufferStatusHub() {
        return this.producedBufferStatusHub;
    }

    /**
     * Returns the congestion context of {@code userIdentifier}, creating it on first use with the
     * user's quota (from the config service cache, or the conf default without a config service).
     */
    public UserCongestionControlContext getUserCongestionContext(UserIdentifier userIdentifier) {
        return this.userCongestionContextMap.computeIfAbsent(userIdentifier, user -> {
            UserBufferInfo userBufferInfo = this.getUserBuffer(user);
            UserTrafficQuota userTrafficQuota = this.configService == null
                    ? this.defaultUserQuota
                    : this.configService.getTenantUserConfigFromCache(user.tenantId(), user.name()).getUserTrafficQuota();
            return new UserCongestionControlContext(userTrafficQuota, this.producedBufferStatusHub, userBufferInfo, this.workerSource, user);
        });
    }

    /** @return the live (mutable) map of user congestion contexts */
    public ConcurrentHashMap<UserIdentifier, UserCongestionControlContext> getUserCongestionContextMap() {
        return this.userCongestionContextMap;
    }

    public BufferStatusHub getConsumedBufferStatusHub() {
        return this.consumedBufferStatusHub;
    }

    /**
     * Reloads the worker quota from the system config cache and refreshes the quota of every
     * known user from the tenant/user config cache. Only called when a config service exists.
     */
    private void updateQuota() {
        this.workerTrafficQuota = this.configService.getSystemConfigFromCache().getWorkerTrafficQuota();
        for (Map.Entry<UserIdentifier, UserCongestionControlContext> entry : this.userCongestionContextMap.entrySet()) {
            UserIdentifier user = entry.getKey();
            UserCongestionControlContext context = entry.getValue();
            context.updateUserTrafficQuota(this.configService.getTenantUserConfigFromCache(user.tenantId(), user.name()).getUserTrafficQuota());
        }
    }
}

