/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.shuffle.celeborn;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornRuntimeException;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.protocol.message.ControlMessages;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.KeyLock;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.reflect.DynFields;
import org.apache.celeborn.shaded.com.google.common.annotations.VisibleForTesting;
import org.apache.celeborn.shaded.org.apache.commons.io.output.ByteArrayOutputStream;
import org.apache.spark.BarrierTaskContext;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkContext$;
import org.apache.spark.SparkEnv;
import org.apache.spark.SparkEnv$;
import org.apache.spark.TaskContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.scheduler.DAGScheduler;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.scheduler.ShuffleMapStage;
import org.apache.spark.scheduler.SparkListener;
import org.apache.spark.scheduler.SparkListenerInterface;
import org.apache.spark.scheduler.TaskInfo;
import org.apache.spark.scheduler.TaskSchedulerImpl;
import org.apache.spark.scheduler.TaskSetManager;
import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle;
import org.apache.spark.sql.execution.UnsafeRowSerializer;
import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.storage.BlockManagerId;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;
import scala.Some;
import scala.Tuple2;
import scala.collection.Iterable;
import scala.collection.JavaConverters;
import scala.collection.mutable.HashMap;
import scala.reflect.ClassManifestFactory;

public class SparkUtils {
    /** SLF4J logger shared by every static helper in this class. */
    private static final Logger logger = LoggerFactory.getLogger(SparkUtils.class);
    /**
     * Message prefix for Celeborn fetch failures. It is not referenced inside this class;
     * presumably callers append the shuffle id -- confirm against usages.
     */
    public static final String FETCH_FAILURE_ERROR_MSG = "Celeborn FetchFailure with shuffle id ";
    /** Reflective accessor for the hidden {@code shuffleIdToMapStage} field of {@link DAGScheduler}; read by {@code cancelShuffle}. */
    private static final DynFields.UnboundField shuffleIdToMapStage_FIELD = DynFields.builder().hiddenImpl(DAGScheduler.class, "shuffleIdToMapStage").build();
    /**
     * Reflective accessor for the hidden {@code taskIdToTaskSetManager} field of {@link TaskSchedulerImpl}.
     * NOTE(review): built with {@code defaultAlwaysNull()}, so reads presumably yield {@code null}
     * when the field cannot be resolved -- confirm against the DynFields documentation.
     */
    private static final DynFields.UnboundField<ConcurrentHashMap<Long, TaskSetManager>> TASK_ID_TO_TASK_SET_MANAGER_FIELD = DynFields.builder().hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager").defaultAlwaysNull().build();
    /**
     * Reflective accessor for the hidden {@code taskInfos} field of {@link TaskSetManager}
     * (same {@code defaultAlwaysNull()} caveat as above).
     */
    private static final DynFields.UnboundField<HashMap<Long, TaskInfo>> TASK_INFOS_FIELD = DynFields.builder().hiddenImpl(TaskSetManager.class, "taskInfos").defaultAlwaysNull().build();
    /**
     * Task ids that have been evaluated by {@code shouldReportShuffleFetchFailure}, grouped under
     * the key {@code "<stageId>-<stageAttemptId>"}. Entries are dropped by
     * {@code removeStageReportedShuffleFetchFailureTaskIds}.
     */
    protected static Map<String, Set<Long>> reportedStageShuffleFetchFailureTaskIds = JavaUtils.newConcurrentHashMap();
    /** Most recent task id recorded by {@code shouldReportShuffleFetchFailure}; {@code null} until the first one. */
    protected static volatile Long lastReportedShuffleFetchFailureTaskId = null;
    /** Per-shuffle-id lock wrapped around the broadcast serialize / deserialize / invalidate helpers. */
    private static final KeyLock<Integer> shuffleBroadcastLock = new KeyLock();
    /** Incremented each time {@code serializeGetReducerFileGroupResponse} creates and caches a new broadcast. */
    @VisibleForTesting
    public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new AtomicInteger();
    /** Cache: shuffle id to (broadcast variable, its compressed Java-serialized bytes). */
    @VisibleForTesting
    public static Map<Integer, Tuple2<Broadcast<TransportMessage>, byte[]>> getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();

    /**
     * Creates a {@link MapStatus} by reflectively invoking an {@code apply} method on the
     * {@code MapStatus} companion object.
     *
     * <p>The first declared method named {@code apply} is used. A two-parameter variant receives
     * {@code (loc, uncompressedSizes)}; a three-parameter variant additionally receives
     * {@code uncompressedRecords}.
     *
     * @param loc block manager location passed as the first argument
     * @param uncompressedSizes per-partition sizes passed as the second argument
     * @param uncompressedRecords passed as the third argument when {@code apply} takes three parameters
     * @return the value returned by {@code apply}
     * @throws IOException if no {@code apply} method exists, or wrapping any failure of the
     *     reflective call (including an unsupported parameter count)
     */
    public static MapStatus createMapStatus(BlockManagerId loc, long[] uncompressedSizes, long[] uncompressedRecords) throws IOException {
        MapStatus$ companion = MapStatus$.MODULE$;
        Method[] declared = companion.getClass().getDeclaredMethods();
        Method apply = null;
        // Stop at the first declared method called "apply".
        for (int i = 0; i < declared.length && apply == null; ++i) {
            if ("apply".equals(declared[i].getName())) {
                apply = declared[i];
            }
        }
        if (apply == null) {
            throw new IOException("Could not find apply method in MapStatus object.");
        }
        try {
            int arity = apply.getParameterCount();
            if (arity == 2) {
                return (MapStatus) apply.invoke(companion, loc, uncompressedSizes);
            }
            if (arity == 3) {
                return (MapStatus) apply.invoke(companion, loc, uncompressedSizes, uncompressedRecords);
            }
            // Thrown inside the try on purpose: it is re-wrapped as IOException below.
            throw new IllegalStateException("Could not find apply method with correct parameter number in MapStatus object.");
        } catch (Exception e) {
            throw new IOException(e);
        }
    }

    /**
     * Reads the private {@code dataSize} field of the given serializer via reflection.
     *
     * @param serializer serializer whose runtime class declares a {@code dataSize} field
     * @return the field value, or {@code null} if the field is missing or inaccessible
     */
    public static SQLMetric getUnsafeRowSerializerDataSizeMetric(UnsafeRowSerializer serializer) {
        try {
            Field field = serializer.getClass().getDeclaredField("dataSize");
            field.setAccessible(true);
            return (SQLMetric) field.get(serializer);
        } catch (IllegalAccessException | NoSuchFieldException e) {
            // Best effort: keep returning null, but record the cause so the failure can be diagnosed.
            logger.warn("Failed to get dataSize metric, aqe won't work properly.", e);
            return null;
        }
    }

    /**
     * Snapshots the current value of every adder.
     *
     * @param adders adders to read; must not be {@code null} nor contain {@code null}
     * @return a new array of the same length holding each adder's {@code longValue()}, in order
     */
    public static long[] unwrap(LongAdder[] adders) {
        long[] sums = new long[adders.length];
        int slot = 0;
        for (LongAdder adder : adders) {
            sums[slot++] = adder.longValue();
        }
        return sums;
    }

    /**
     * Builds a {@link CelebornConf} from the {@code spark.celeborn.*} entries of a Spark configuration.
     *
     * <p>Each matching key is copied with its leading {@code "spark."} removed, so
     * {@code spark.celeborn.x} becomes {@code celeborn.x}. All other entries are ignored.
     *
     * @param conf Spark configuration to read
     * @return a new {@code CelebornConf} holding the translated entries
     */
    public static CelebornConf fromSparkConf(SparkConf conf) {
        CelebornConf celebornConf = new CelebornConf();
        int sparkPrefixLength = "spark.".length();
        for (Tuple2<String, String> entry : conf.getAll()) {
            String key = entry._1;
            if (key.startsWith("spark.celeborn.")) {
                celebornConf.set(key.substring(sparkPrefixLength), entry._2);
            }
        }
        return celebornConf;
    }

    /**
     * Returns an identifier for the running application attempt.
     *
     * @param context Spark context to read the ids from
     * @return {@code "<applicationId>_<attemptId>"} when an attempt id is defined, otherwise the
     *     bare application id
     */
    public static String appUniqueId(SparkContext context) {
        Option<String> attemptId = context.applicationAttemptId();
        if (attemptId.isEmpty()) {
            return context.applicationId();
        }
        return context.applicationId() + "_" + attemptId.get();
    }

    /**
     * Formats the identifier {@code "<appShuffleId>-<stageId>-<stageAttemptNumber>"}.
     *
     * @param appShuffleId Spark-side shuffle id
     * @param context task context supplying the stage id and stage attempt number
     * @return the dash-separated identifier
     */
    public static String getAppShuffleIdentifier(int appShuffleId, TaskContext context) {
        return new StringBuilder()
            .append(appShuffleId)
            .append('-')
            .append(context.stageId())
            .append('-')
            .append(context.stageAttemptNumber())
            .toString();
    }

    /**
     * Resolves the shuffle id to use for the given handle.
     *
     * <p>When stage rerun is disabled the handle's own shuffle id is returned unchanged. Otherwise
     * the id is requested from the client, keyed by the app shuffle identifier of the current
     * stage attempt.
     *
     * @param client shuffle client queried when stage rerun is enabled
     * @param handle shuffle handle providing the Spark shuffle id and the stage-rerun flag
     * @param context current task context; also used to detect barrier tasks
     * @param isWriter forwarded to the client lookup
     * @return the resolved shuffle id
     * @throws CelebornRuntimeException if the client flags the returned id as invalid
     */
    public static int celebornShuffleId(ShuffleClient client, CelebornShuffleHandle<?, ?, ?> handle, TaskContext context, Boolean isWriter) {
        if (!handle.stageRerunEnabled()) {
            return handle.shuffleId();
        }
        int appShuffleId = handle.shuffleId();
        String identifier = SparkUtils.getAppShuffleIdentifier(appShuffleId, context);
        boolean isBarrierTask = context instanceof BarrierTaskContext;
        Tuple2<Integer, Boolean> lookup = client.getShuffleId(appShuffleId, identifier, isWriter, isBarrierTask);
        if (lookup._2) {
            return lookup._1;
        }
        throw new CelebornRuntimeException(String.format("Get invalid shuffle id %s", lookup._1));
    }

    /**
     * Instantiates {@code className} by trying, in order, the public constructors
     * {@code (SparkConf, boolean)}, {@code (SparkConf)} and the no-arg constructor.
     *
     * <p>Any reflective failure of one attempt (a missing constructor, but also a constructor
     * that throws) moves on to the next attempt. If all three fail, the failures of the first two
     * attempts are attached as suppressed exceptions to the cause of the thrown exception, so the
     * real reason is not hidden behind the last {@code NoSuchMethodException}.
     *
     * @param className fully qualified name of the class to instantiate
     * @param conf Spark configuration passed to the constructors that accept one
     * @param isDriver passed to the two-argument constructor
     * @param <T> type the caller expects; the cast is unchecked
     * @return the new instance
     * @throws RuntimeException wrapping the failure of the no-arg attempt when every attempt fails
     */
    @SuppressWarnings("unchecked")
    public static <T> T instantiateClass(String className, SparkConf conf, Boolean isDriver) {
        Class<?> cls = Utils.classForName(className);
        try {
            return (T) cls.getConstructor(SparkConf.class, Boolean.TYPE).newInstance(conf, isDriver);
        } catch (ReflectiveOperationException roe1) {
            try {
                return (T) cls.getConstructor(SparkConf.class).newInstance(conf);
            } catch (ReflectiveOperationException roe2) {
                try {
                    return (T) cls.getConstructor().newInstance();
                } catch (ReflectiveOperationException roe3) {
                    // Keep the earlier failures visible in the stack trace.
                    roe3.addSuppressed(roe1);
                    roe3.addSuppressed(roe2);
                    throw new RuntimeException(roe3);
                }
            }
        }
    }

    /**
     * For barrier tasks, registers a task-failure listener that reports the failure to the
     * shuffle client. Does nothing for any other kind of task context.
     *
     * @param shuffleClient client notified through {@code reportBarrierTaskFailure}
     * @param taskContext current task context
     * @param handle shuffle handle providing the Spark shuffle id
     */
    public static void addFailureListenerIfBarrierTask(ShuffleClient shuffleClient, TaskContext taskContext, CelebornShuffleHandle<?, ?, ?> handle) {
        if (taskContext instanceof BarrierTaskContext) {
            // Captured eagerly so the listener reports the ids of the stage attempt that registered it.
            final int shuffleId = handle.shuffleId();
            final String identifier = SparkUtils.getAppShuffleIdentifier(shuffleId, taskContext);
            ((BarrierTaskContext) taskContext).addTaskFailureListener((context, error) -> shuffleClient.reportBarrierTaskFailure(shuffleId, identifier));
        }
    }

    /**
     * Cancels the shuffle map stage registered for {@code shuffleId} in the active context's
     * {@link DAGScheduler}, if there is one.
     *
     * <p>The active {@link SparkContext} is read exactly once. Checking it and then fetching it
     * again would fail with {@code NoSuchElementException} if the context were cleared in between.
     *
     * @param shuffleId Spark shuffle id whose map stage should be cancelled
     * @param reason reason forwarded to {@code DAGScheduler.cancelStage}
     */
    public static void cancelShuffle(int shuffleId, String reason) {
        Option<SparkContext> activeContext = SparkContext$.MODULE$.getActive();
        if (activeContext.isEmpty()) {
            logger.error("Can not get active SparkContext, skip cancelShuffle.");
            return;
        }
        DAGScheduler scheduler = activeContext.get().dagScheduler();
        // shuffleIdToMapStage is a private scheduler field, hence the reflective read.
        scala.collection.mutable.Map shuffleIdToMapStageValue = (scala.collection.mutable.Map) shuffleIdToMapStage_FIELD.bind(scheduler).get();
        Option shuffleMapStage = shuffleIdToMapStageValue.get((Object) shuffleId);
        if (shuffleMapStage.nonEmpty()) {
            scheduler.cancelStage(((ShuffleMapStage) shuffleMapStage.get()).id(), (Option) new Some((Object) reason));
        }
    }

    /**
     * Looks up the {@link TaskSetManager} that owns {@code taskId}, reading the scheduler's hidden
     * {@code taskIdToTaskSetManager} map while holding the scheduler's monitor.
     *
     * <p>NOTE(review): the field accessor is built with {@code defaultAlwaysNull()}; if it can
     * yield {@code null} here, the lookup below throws {@code NullPointerException} -- confirm.
     *
     * @param taskScheduler scheduler to inspect; also used as the lock
     * @param taskId task id to look up
     * @return the owning manager, or {@code null} if the map has no entry for {@code taskId}
     */
    @VisibleForTesting
    protected static TaskSetManager getTaskSetManager(TaskSchedulerImpl taskScheduler, long taskId) {
        synchronized (taskScheduler) {
            return TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get().get(taskId);
        }
    }

    /**
     * Returns the {@link TaskInfo} of {@code taskId} together with every attempt recorded for the
     * same task index.
     *
     * @param taskSetManager manager to inspect; may be {@code null}
     * @param taskId task id to look up in the manager's hidden {@code taskInfos} map
     * @return {@code (taskInfo, attempts)}, or {@code null} (after logging an error) when the
     *     manager is {@code null} or holds no {@code TaskInfo} for {@code taskId}
     */
    @VisibleForTesting
    protected static Tuple2<TaskInfo, List<TaskInfo>> getTaskAttempts(TaskSetManager taskSetManager, long taskId) {
        if (taskSetManager == null) {
            logger.error("Can not get TaskSetManager for taskId: {}", taskId);
            return null;
        }
        Option<TaskInfo> lookedUp = TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
        if (!lookedUp.isDefined()) {
            logger.error("Can not get TaskInfo for taskId: {}", taskId);
            return null;
        }
        TaskInfo info = lookedUp.get();
        // taskAttempts() is indexed by task index; copy that slot into a Java list.
        List<TaskInfo> attempts = JavaConverters.asJavaCollectionConverter(taskSetManager.taskAttempts()[info.index()])
            .asJavaCollection()
            .stream()
            .collect(Collectors.toList());
        return Tuple2.apply(info, attempts);
    }

    /**
     * Drops the recorded fetch-failure task ids of one stage attempt.
     *
     * @param stageId stage id
     * @param stageAttemptId stage attempt id
     */
    protected static void removeStageReportedShuffleFetchFailureTaskIds(int stageId, int stageAttemptId) {
        // Same "<stageId>-<stageAttemptId>" key that shouldReportShuffleFetchFailure writes under.
        String stageKey = stageId + "-" + stageAttemptId;
        reportedStageShuffleFetchFailureTaskIds.remove(stageKey);
    }

    /**
     * Decides whether a shuffle fetch failure seen by {@code taskId} should be reported.
     *
     * <p>The task id is first recorded for its stage attempt. Every other attempt of the same
     * task index is then inspected, in order:
     * <ul>
     *   <li>an attempt that already went through this method counts as a failed attempt;</li>
     *   <li>a successful attempt means the failure is not reported ({@code false});</li>
     *   <li>a running attempt is noted;</li>
     *   <li>an attempt whose status is {@code FAILED} or {@code UNKNOWN} counts as failed.</li>
     * </ul>
     * The failure is reported when the failed-attempt count (including this task) reaches the
     * task set's {@code maxTaskFailures}, or when no other attempt is running.
     *
     * <p>The active {@link SparkContext} is read through {@code isEmpty()}/{@code get()} rather than
     * {@code getOrElse(null)}: the latter passes {@code null} as the by-name default and throws
     * {@code NullPointerException} exactly when no context is active.
     *
     * @param taskId id of the task that hit the fetch failure
     * @return {@code true} to report the failure (also when there is no active context or the
     *     attempts cannot be determined); {@code false} when another attempt succeeded or may
     *     still succeed, or when the task's {@link TaskSetManager} is no longer known
     */
    public static boolean shouldReportShuffleFetchFailure(long taskId) {
        Option<SparkContext> activeContext = SparkContext$.MODULE$.getActive();
        if (activeContext.isEmpty()) {
            logger.error("Can not get active SparkContext.");
            return true;
        }
        TaskSchedulerImpl taskScheduler = (TaskSchedulerImpl) activeContext.get().taskScheduler();
        synchronized (taskScheduler) {
            TaskSetManager taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, taskId);
            if (taskSetManager == null) {
                logger.error("Can not get TaskSetManager for taskId: {}, ignore it. (This typically occurs when:  task completed/cleaned up, executor marked as failed, or stage cancelled/completed)", taskId);
                return false;
            }
            int stageId = taskSetManager.stageId();
            int stageAttemptId = taskSetManager.taskSet().stageAttemptId();
            int maxTaskFails = taskSetManager.maxTaskFailures();
            String stageUniqId = stageId + "-" + stageAttemptId;
            Set<Long> reportedStageTaskIds = reportedStageShuffleFetchFailureTaskIds.computeIfAbsent(stageUniqId, k -> new HashSet<>());
            reportedStageTaskIds.add(taskId);
            lastReportedShuffleFetchFailureTaskId = taskId;
            Tuple2<TaskInfo, List<TaskInfo>> taskAttempts = SparkUtils.getTaskAttempts(taskSetManager, taskId);
            if (taskAttempts == null) {
                return true;
            }
            TaskInfo taskInfo = taskAttempts._1();
            // The current task itself is the first failed attempt.
            int failedTaskAttempts = 1;
            boolean hasRunningAttempt = false;
            for (TaskInfo ti : taskAttempts._2()) {
                if (ti.taskId() == taskId) {
                    continue;
                }
                if (reportedStageTaskIds.contains(ti.taskId())) {
                    logger.info("StageId={} index={} taskId={} attempt={} another attempt {} has reported shuffle fetch failure.", stageId, taskInfo.index(), taskId, taskInfo.attemptNumber(), ti.attemptNumber());
                    ++failedTaskAttempts;
                } else if (ti.successful()) {
                    logger.info("StageId={} index={} taskId={} attempt={} another attempt {} is successful.", stageId, taskInfo.index(), taskId, taskInfo.attemptNumber(), ti.attemptNumber());
                    return false;
                } else if (ti.running()) {
                    logger.info("StageId={} index={} taskId={} attempt={} another attempt {} is running.", stageId, taskInfo.index(), taskId, taskInfo.attemptNumber(), ti.attemptNumber());
                    hasRunningAttempt = true;
                } else if ("FAILED".equals(ti.status()) || "UNKNOWN".equals(ti.status())) {
                    logger.info("StageId={} index={} taskId={} attempt={} another attempt {} status={}.", stageId, taskInfo.index(), taskId, taskInfo.attemptNumber(), ti.attemptNumber(), ti.status());
                    ++failedTaskAttempts;
                }
            }
            if (failedTaskAttempts >= maxTaskFails || !hasRunningAttempt) {
                logger.warn("StageId={}, index={}, taskId={}, attemptNumber={}: Task failure count {} reached maximum allowed failures {} or no running attempt exists.", stageId, taskInfo.index(), taskId, taskInfo.attemptNumber(), failedTaskAttempts, maxTaskFails);
                return true;
            }
            return false;
        }
    }

    /**
     * Registers {@code listener} on the active {@link SparkContext}, if there is one.
     *
     * <p>The context is read through {@code isEmpty()}/{@code get()} rather than
     * {@code getOrElse(null)}: the latter passes {@code null} as the by-name default and throws
     * {@code NullPointerException} exactly when no context is active.
     *
     * @param listener listener to register
     */
    public static void addSparkListener(SparkListener listener) {
        Option<SparkContext> activeContext = SparkContext$.MODULE$.getActive();
        if (activeContext.isEmpty()) {
            // Make the skipped registration visible instead of dropping it silently.
            logger.warn("Can not get active SparkContext, skip adding SparkListener.");
            return;
        }
        activeContext.get().addSparkListener((SparkListenerInterface) listener);
    }

    /**
     * Broadcasts {@code response} through the active {@link SparkContext} and returns the
     * compressed, Java-serialized {@link Broadcast} handle.
     *
     * <p>Results are cached per shuffle id: a cache hit returns the stored bytes without creating
     * another broadcast. The work runs under the per-shuffle lock.
     *
     * <p>The context is read through {@code isEmpty()}/{@code get()} rather than
     * {@code getOrElse(null)}: the latter passes {@code null} as the by-name default and throws
     * {@code NullPointerException} exactly when no context is active.
     *
     * @param shuffleId shuffle id used as lock key and cache key
     * @param response response to broadcast
     * @return the serialized broadcast handle, or {@code null} if there is no active context or
     *     broadcasting/serialization fails (the failure is logged)
     */
    public static byte[] serializeGetReducerFileGroupResponse(Integer shuffleId, ControlMessages.GetReducerFileGroupResponse response) {
        Option<SparkContext> activeContext = SparkContext$.MODULE$.getActive();
        if (activeContext.isEmpty()) {
            logger.error("Can not get active SparkContext.");
            return null;
        }
        SparkContext sparkContext = activeContext.get();
        return shuffleBroadcastLock.withLock(shuffleId, () -> {
            Tuple2<Broadcast<TransportMessage>, byte[]> cachedSerializeGetReducerFileGroupResponse = getReducerFileGroupResponseBroadcasts.get(shuffleId);
            if (cachedSerializeGetReducerFileGroupResponse != null) {
                return (byte[]) cachedSerializeGetReducerFileGroupResponse._2;
            }
            try {
                logger.info("Broadcasting GetReducerFileGroupResponse for shuffle: {}", shuffleId);
                TransportMessage transportMessage = (TransportMessage) Utils.toTransportMessage(response);
                Broadcast<TransportMessage> broadcast = sparkContext.broadcast(transportMessage, ClassManifestFactory.fromClass(TransportMessage.class));
                CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(sparkContext.conf());
                ByteArrayOutputStream out = new ByteArrayOutputStream();
                // Only the broadcast handle is serialized here, not the response itself.
                try (ObjectOutputStream oos = new ObjectOutputStream(codec.compressedOutputStream((OutputStream) out))) {
                    oos.writeObject(broadcast);
                }
                byte[] _serializeResult = out.toByteArray();
                getReducerFileGroupResponseBroadcasts.put(shuffleId, Tuple2.apply(broadcast, _serializeResult));
                getReducerFileGroupResponseBroadcastNum.incrementAndGet();
                return _serializeResult;
            } catch (Throwable e) {
                logger.error("Failed to serialize GetReducerFileGroupResponse for shuffle: {}", shuffleId, e);
                return null;
            }
        });
    }

    /**
     * Restores a {@code GetReducerFileGroupResponse} from a compressed, Java-serialized
     * {@link Broadcast} handle by deserializing the handle and reading its value.
     *
     * <p>The work runs under the per-shuffle lock. Any failure is logged and whatever was decoded
     * up to that point is returned.
     *
     * @param shuffleId shuffle id used as lock key and in log messages
     * @param bytes serialized broadcast handle
     * @return the decoded response, or {@code null} if {@link SparkEnv} is unavailable or decoding fails
     */
    @SuppressWarnings("unchecked")
    public static ControlMessages.GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse(Integer shuffleId, byte[] bytes) {
        SparkEnv sparkEnv = SparkEnv$.MODULE$.get();
        if (sparkEnv == null) {
            logger.error("Can not get SparkEnv.");
            return null;
        }
        return shuffleBroadcastLock.withLock(shuffleId, () -> {
            logger.info("Deserializing GetReducerFileGroupResponse broadcast for shuffle: {}", shuffleId);
            ControlMessages.GetReducerFileGroupResponse decoded = null;
            try {
                CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(sparkEnv.conf());
                try (ObjectInputStream objectIn = new ObjectInputStream(codec.compressedInputStream((InputStream) new ByteArrayInputStream(bytes)))) {
                    Broadcast<TransportMessage> handle = (Broadcast<TransportMessage>) objectIn.readObject();
                    decoded = (ControlMessages.GetReducerFileGroupResponse) Utils.fromTransportMessage(handle.value());
                }
            } catch (Throwable failure) {
                logger.error("Failed to deserialize GetReducerFileGroupResponse for shuffle: " + shuffleId, failure);
            }
            return decoded;
        });
    }

    /**
     * Evicts the cached broadcast of {@code shuffleId}, if any, and destroys it.
     *
     * <p>The work runs under the per-shuffle lock. Failures are logged and not propagated.
     *
     * @param shuffleId shuffle id whose cached broadcast should be removed
     */
    public static void invalidateSerializedGetReducerFileGroupResponse(Integer shuffleId) {
        shuffleBroadcastLock.withLock(shuffleId, () -> {
            try {
                Tuple2<Broadcast<TransportMessage>, byte[]> evicted = getReducerFileGroupResponseBroadcasts.remove(shuffleId);
                if (evicted != null) {
                    Broadcast<TransportMessage> stale = evicted._1();
                    stale.destroy();
                }
            } catch (Throwable failure) {
                logger.error("Failed to invalidate serialized GetReducerFileGroupResponse for shuffle: " + shuffleId, failure);
            }
            // withLock expects a value; there is nothing meaningful to return.
            return null;
        });
    }
}

