/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.examples.sql;

import java.io.Serializable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.functions;

public class JavaUserDefinedUntypedAggregation {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().appName("Java Spark SQL user-defined DataFrames aggregation example").getOrCreate();
        spark.udf().register("myAverage", functions.udaf((Aggregator)new MyAverage(), (Encoder)Encoders.LONG()));
        Dataset df = spark.read().json("examples/src/main/resources/employees.json");
        df.createOrReplaceTempView("employees");
        df.show();
        Dataset result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
        result.show();
        spark.stop();
    }

    public static class MyAverage
    extends Aggregator<Long, Average, Double> {
        public Average zero() {
            return new Average(0L, 0L);
        }

        public Average reduce(Average buffer, Long data) {
            long newSum = buffer.getSum() + data;
            long newCount = buffer.getCount() + 1L;
            buffer.setSum(newSum);
            buffer.setCount(newCount);
            return buffer;
        }

        public Average merge(Average b1, Average b2) {
            long mergedSum = b1.getSum() + b2.getSum();
            long mergedCount = b1.getCount() + b2.getCount();
            b1.setSum(mergedSum);
            b1.setCount(mergedCount);
            return b1;
        }

        public Double finish(Average reduction) {
            return (double)reduction.getSum() / (double)reduction.getCount();
        }

        public Encoder<Average> bufferEncoder() {
            return Encoders.bean(Average.class);
        }

        public Encoder<Double> outputEncoder() {
            return Encoders.DOUBLE();
        }
    }

    public static class Average
    implements Serializable {
        private long sum;
        private long count;

        public Average() {
        }

        public Average(long sum, long count) {
            this.sum = sum;
            this.count = count;
        }

        public long getSum() {
            return this.sum;
        }

        public void setSum(long sum) {
            this.sum = sum;
        }

        public long getCount() {
            return this.count;
        }

        public void setCount(long count) {
            this.count = count;
        }
    }
}

