# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Benchmark API overhead of different python FFI API calling overhead through DLPack API.

Specifically, we would like to understand the overall overhead of Python/C++ API calls.
The general goal is to understand the overall space and get a sense of which operations are possible.

We pick function f(x, y, z) where x, y, z are length 1 tensors.
The benchmark is running in eager mode so we can see what is possible.
It is orthogonal to other optimizations. For example cudagraph can
eliminate these overheads completely. So the goal is to get a sense
of what is possible under eager mode.

Summary of some takeaways:
- numpy.add takes roughly 0.36 us per call, which indicates roughly what can be done in the Python environment.
- torch.add on GPU takes about 3.7 us per call, which gives a rough idea of what we need to reach in eager mode.
"""

from __future__ import annotations

import time
from typing import Any, Callable

import numpy as np
import torch
import tvm_ffi


class TestFFITensor:
    """Thin wrapper that hands out its tensor through ``__tvm_ffi_object__``.

    The benchmarks use it as a stand-in for a user-defined type that
    implements the ``__tvm_ffi_object__`` protocol.
    """

    def __init__(self, tensor: tvm_ffi.Tensor) -> None:
        """Keep ``tensor`` so ``__tvm_ffi_object__`` can return it later."""
        self._tensor = tensor

    def __tvm_ffi_object__(self) -> tvm_ffi.Tensor:
        """Return the wrapped tensor unchanged."""
        wrapped = self._tensor
        return wrapped


def print_speed(name: str, speed: float) -> None:
    """Print one result row: ``name`` left-aligned to 60 columns, then sec/call."""
    label = f"{name:<60}"
    print(f"{label} {speed} sec/call")


def print_error(name: str, error: Any) -> None:
    """Print ``name`` left-aligned to 60 columns followed by ``error``.

    Used in place of a speed row when a benchmark could not run.
    """
    label = f"{name:<60}"
    print(f"{label} {error}")


def baseline_torch_add(repeat: int) -> None:
    """Time ``torch.add`` on one-element tensors, on CPU and then on CUDA.

    ``repeat`` is the number of timed calls per device. One result row per
    device is printed via ``print_speed``.
    """

    def run_bench(device: str) -> None:
        """Run the warm-up plus timed loop on ``device`` and print the result."""
        x = torch.arange(1, device=device)
        y = torch.arange(1, device=device)
        z = torch.arange(1, device=device)

        # Warm-up call, kept out of the timed region.
        torch.add(x, y, out=z)
        if device == "cuda":
            # Wait for the warm-up work so it is not counted below.
            torch.cuda.synchronize()
        # perf_counter() is the monotonic, high-resolution clock intended for
        # short intervals; time.time() is wall-clock and may be coarse or jump.
        start = time.perf_counter()
        for _ in range(repeat):
            torch.add(x, y, out=z)
        # note we deliberately do not use torch.cuda.synchronize()
        # because we want to see the overhead of the FFI call.
        end = time.perf_counter()
        print_speed(f"torch.add[{device}]", (end - start) / repeat)

    # rough take away: add on cuda roughly takes 3e-6 sec/call
    run_bench("cpu")
    run_bench("cuda")


def baseline_numpy_add(repeat: int) -> None:
    """Time ``numpy.add`` on one-element arrays.

    ``repeat`` is the number of timed calls; the per-call average is printed
    via ``print_speed``.
    """
    x = np.arange(1)
    y = np.arange(1)
    z = np.arange(1)

    # Warm-up call, kept out of the timed region.
    np.add(x, y, out=z)
    # perf_counter() is the monotonic, high-resolution clock intended for
    # short intervals; time.time() is wall-clock and may be coarse or jump.
    start = time.perf_counter()
    for _ in range(repeat):
        np.add(x, y, out=z)
    end = time.perf_counter()
    speed = (end - start) / repeat
    print_speed("numpy.add", speed)


def baseline_cupy_add(repeat: int) -> None:
    """Time ``cupy.add`` on one-element arrays; do nothing if cupy is missing.

    ``repeat`` is the number of timed calls; the per-call average is printed
    via ``print_speed``.
    """
    try:
        import cupy  # noqa: PLC0415
    except ImportError:
        # skip if cupy is not installed
        return
    x = cupy.arange(1)
    y = cupy.arange(1)
    z = cupy.arange(1)

    # Warm-up call, kept out of the timed region.
    cupy.add(x, y, out=z)
    # perf_counter() is the monotonic, high-resolution clock intended for
    # short intervals; time.time() is wall-clock and may be coarse or jump.
    start = time.perf_counter()
    for _ in range(repeat):
        cupy.add(x, y, out=z)
    end = time.perf_counter()
    speed = (end - start) / repeat
    print_speed("cupy.add", speed)


def tvm_ffi_nop(repeat: int) -> None:
    """Time a tvm FFI Python call by invoking a NOP with pre-converted tensors.

    testing.nop is defined in C++ and does nothing. The three arguments are
    converted with ``tvm_ffi.from_dlpack`` once, before timing starts, so only
    the call itself is inside the timed loop.
    """
    nop = tvm_ffi.get_global_func("testing.nop")
    x = tvm_ffi.from_dlpack(torch.arange(1))
    y = tvm_ffi.from_dlpack(torch.arange(1))
    z = tvm_ffi.from_dlpack(torch.arange(1))
    # Warm-up call, kept out of the timed region.
    nop(x, y, z)
    # perf_counter() is the monotonic, high-resolution clock intended for
    # short intervals; time.time() is wall-clock and may be coarse or jump.
    start = time.perf_counter()
    for _ in range(repeat):
        nop(x, y, z)
    end = time.perf_counter()
    print_speed("tvm_ffi.nop", (end - start) / repeat)


def bench_ffi_nop_from_dlpack(name: str, x: Any, y: Any, z: Any, repeat: int) -> None:
    """Time explicit DLPack conversion of three arguments plus one NOP call.

    Each timed iteration calls ``tvm_ffi.from_dlpack`` on ``x``, ``y`` and
    ``z`` and then invokes ``testing.nop`` with the converted values. The
    per-iteration average is printed under ``name``.
    """
    nop = tvm_ffi.get_global_func("testing.nop")
    # Warm-up: one full convert-and-call round outside the timed region.
    tx = tvm_ffi.from_dlpack(x)
    ty = tvm_ffi.from_dlpack(y)
    tz = tvm_ffi.from_dlpack(z)
    nop(tx, ty, tz)

    # perf_counter() is the monotonic, high-resolution clock intended for
    # short intervals; time.time() is wall-clock and may be coarse or jump.
    start = time.perf_counter()
    for _ in range(repeat):
        tx = tvm_ffi.from_dlpack(x)
        ty = tvm_ffi.from_dlpack(y)
        tz = tvm_ffi.from_dlpack(z)
        nop(tx, ty, tz)
    end = time.perf_counter()
    print_speed(name, (end - start) / repeat)


def tvm_ffi_nop_from_torch_dlpack(repeat: int) -> None:
    """Benchmark explicit ``from_dlpack`` conversion + NOP call on torch CPU tensors.

    Builds three one-element torch tensors and hands them to
    ``bench_ffi_nop_from_dlpack``.
    """
    x, y, z = (torch.arange(1) for _ in range(3))
    bench_ffi_nop_from_dlpack("tvm_ffi.nop+from_dlpack(torch)", x, y, z, repeat)


def tvm_ffi_nop_from_numpy_dlpack(repeat: int) -> None:
    """Benchmark explicit ``from_dlpack`` conversion + NOP call on numpy arrays.

    Builds three one-element numpy arrays and hands them to
    ``bench_ffi_nop_from_dlpack``.
    """
    x, y, z = (np.arange(1) for _ in range(3))
    bench_ffi_nop_from_dlpack("tvm_ffi.nop+from_dlpack(numpy)", x, y, z, repeat)


def tvm_ffi_self_dlpack_nop(repeat: int) -> None:
    """Benchmark explicit ``from_dlpack`` conversion + NOP call on tvm_ffi tensors.

    The inputs are already results of ``tvm_ffi.from_dlpack``, so the timed
    loop in ``bench_ffi_nop_from_dlpack`` converts them a second time.
    """
    x, y, z = (tvm_ffi.from_dlpack(torch.arange(1)) for _ in range(3))
    bench_ffi_nop_from_dlpack("tvm_ffi.nop+from_dlpack(tvm)", x, y, z, repeat)


def tvm_ffi_nop_from_torch_utils_to_dlpack(repeat: int) -> None:
    """Time DLPack conversion + NOP call using the legacy torch export API.

    Same measurement as the explicit ``from_dlpack`` benchmarks, but each
    argument is first exported with ``torch.utils.dlpack.to_dlpack``.

    This helps to measure possible implementation overhead of torch.
    """
    nop = tvm_ffi.get_global_func("testing.nop")
    x = torch.arange(1)
    y = torch.arange(1)
    z = torch.arange(1)

    # Warm-up: one full export-convert-call round outside the timed region.
    tx = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(x))
    ty = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(y))
    tz = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(z))
    nop(tx, ty, tz)

    # perf_counter() is the monotonic, high-resolution clock intended for
    # short intervals; time.time() is wall-clock and may be coarse or jump.
    start = time.perf_counter()
    for _ in range(repeat):
        tx = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(x))
        ty = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(y))
        tz = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(z))
        nop(tx, ty, tz)
    end = time.perf_counter()
    speed = (end - start) / repeat
    print_speed("tvm_ffi.nop+from_dlpack(torch.utils)", speed)


def bench_tvm_ffi_nop_autodlpack(name: str, x: Any, y: Any, z: Any, repeat: int) -> None:
    """Time ``testing.nop`` when the arguments are passed without explicit conversion.

    ``x``, ``y`` and ``z`` are handed to the FFI function as-is, so whatever
    conversion tvm_ffi performs for their type is part of the measured cost.
    The per-call average is printed under ``name``.
    """
    nop = tvm_ffi.get_global_func("testing.nop")
    # Warm-up call, kept out of the timed region.
    nop(x, y, z)
    # perf_counter() is the monotonic, high-resolution clock intended for
    # short intervals; time.time() is wall-clock and may be coarse or jump.
    start = time.perf_counter()
    for _ in range(repeat):
        nop(x, y, z)
    end = time.perf_counter()
    speed = (end - start) / repeat
    print_speed(name, speed)


def tvm_ffi_nop_autodlpack_from_torch(
    repeat: int, device: str = "cpu", stream: bool = False
) -> None:
    """Benchmark passing ``torch.Tensor`` arguments directly to ``testing.nop``.

    Three one-element tensors are created on ``device``. With ``stream=True``
    the benchmark runs inside a freshly created ``torch.cuda.Stream`` context
    and the result label gains a ``[stream]`` suffix.
    """
    # NOTE(review): an earlier comment here said "use larger to ensure
    # alignment req is met", but the tensors have length 1 -- confirm intent.
    x, y, z = (torch.arange(1, device=device) for _ in range(3))
    if not stream:
        bench_tvm_ffi_nop_autodlpack(f"tvm_ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat)
        return
    with torch.cuda.stream(torch.cuda.Stream()):
        bench_tvm_ffi_nop_autodlpack(
            f"tvm_ffi.nop.autodlpack(torch[{device}][stream])", x, y, z, repeat
        )


def tvm_ffi_nop_autodlpack_from_numpy(repeat: int) -> None:
    """Benchmark passing ``numpy.ndarray`` arguments directly to ``testing.nop``."""
    # use larger to ensure alignment req is met
    length = 256
    x, y, z = (np.arange(length) for _ in range(3))
    bench_tvm_ffi_nop_autodlpack("tvm_ffi.nop.autodlpack(numpy)", x, y, z, repeat)


def tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat: int, device: str) -> None:
    """Benchmark passing ``DLTensorTestWrapper`` arguments directly to ``testing.nop``.

    Each argument is a tvm_ffi tensor (converted from a one-element torch
    tensor on ``device``) wrapped in ``tvm_ffi.core.DLTensorTestWrapper``.
    Per the original note, this effectively measures DLPack exchange in tvm ffi.
    """
    tensors = [tvm_ffi.from_dlpack(torch.arange(1, device=device)) for _ in range(3)]
    x, y, z = (tvm_ffi.core.DLTensorTestWrapper(tensor) for tensor in tensors)
    label = f"tvm_ffi.nop.autodlpack(DLTensorTestWrapper[{device}])"
    bench_tvm_ffi_nop_autodlpack(label, x, y, z, repeat)


def tvm_ffi_nop_autodlpack_from_test_ffi_tensor(repeat: int, device: str) -> None:
    """Benchmark passing ``TestFFITensor`` arguments directly to ``testing.nop``.

    Each argument is a tvm_ffi tensor (converted from a one-element torch
    tensor on ``device``) wrapped in ``TestFFITensor``, which exposes the
    tensor through ``__tvm_ffi_object__``.
    """
    tensors = [tvm_ffi.from_dlpack(torch.arange(1, device=device)) for _ in range(3)]
    x, y, z = (TestFFITensor(tensor) for tensor in tensors)
    label = f"tvm_ffi.nop.autodlpack(TestFFITensor[{device}])"
    bench_tvm_ffi_nop_autodlpack(label, x, y, z, repeat)


def bench_to_dlpack(x: Any, name: str, repeat: int) -> None:
    """Time ``x.__dlpack__()`` called with no arguments.

    ``repeat`` is the number of timed calls; the per-call average is printed
    under ``name``.
    """
    # Warm-up call, kept out of the timed region.
    x.__dlpack__()
    # perf_counter() is the monotonic, high-resolution clock intended for
    # short intervals; time.time() is wall-clock and may be coarse or jump.
    start = time.perf_counter()
    for _ in range(repeat):
        x.__dlpack__()
    end = time.perf_counter()
    speed = (end - start) / repeat
    print_speed(name, speed)


def bench_to_dlpack_versioned(
    x: Any, name: str, repeat: int, max_version: tuple[int, int] = (1, 1)
) -> None:
    """Time ``x.__dlpack__(max_version=...)``; the default targets DLPack 1.1.

    If anything in the benchmark raises (for example a producer that does not
    accept ``max_version``), the exception is printed under ``name`` via
    ``print_error`` instead of propagating, so later benchmarks still run.
    """
    try:
        # Warm-up call, kept out of the timed region.
        x.__dlpack__(max_version=max_version)
        # perf_counter() is the monotonic, high-resolution clock intended for
        # short intervals; time.time() is wall-clock and may be coarse or jump.
        start = time.perf_counter()
        for _ in range(repeat):
            x.__dlpack__(max_version=max_version)
        end = time.perf_counter()
        speed = (end - start) / repeat
        print_speed(name, speed)
    except Exception as e:
        # Deliberately broad: this is a best-effort report row, not a hard failure.
        print_error(name, e)


def bench_torch_utils_to_dlpack(repeat: int) -> None:
    """Time ``torch.utils.dlpack.to_dlpack`` on a one-element CPU tensor.

    ``repeat`` is the number of timed calls; the per-call average is printed
    via ``print_speed``.
    """
    x = torch.arange(1)
    # Warm-up call, kept out of the timed region.
    torch.utils.dlpack.to_dlpack(x)
    # perf_counter() is the monotonic, high-resolution clock intended for
    # short intervals; time.time() is wall-clock and may be coarse or jump.
    start = time.perf_counter()
    for _ in range(repeat):
        torch.utils.dlpack.to_dlpack(x)
    end = time.perf_counter()
    speed = (end - start) / repeat
    print_speed("torch.utils.dlpack.to_dlpack", speed)


def torch_get_cuda_stream_native(device_id: int) -> int:
    """Return the ``cuda_stream`` attribute of torch's current stream on ``device_id``.

    Pure-Python counterpart to the C++ extension variant benchmarked in this file.
    """
    current = torch.cuda.current_stream(device_id)
    return current.cuda_stream


def load_torch_get_current_cuda_stream() -> Callable[[int], int]:
    """Build an inline C++ extension that returns torch's current CUDA stream.

    The extension exposes ``get_current_cuda_stream(device_id)``, which
    returns 0 when the stream id is 0 and otherwise the ``cudaStream_t``
    handle cast to an integer. The compiled function is returned.
    """
    from torch.utils import cpp_extension  # noqa: PLC0415

    source = """
    #include <c10/cuda/CUDAStream.h>

    int64_t get_current_cuda_stream(int device_id) {
        at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id);
        // fast invariant, default stream is always 0
        if (stream.id() == 0) return 0;
        // convert to cudaStream_t
        return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
    }
    """
    cuda_include_dirs = cpp_extension.include_paths("cuda")
    module = cpp_extension.load_inline(
        name="get_current_cuda_stream",
        cpp_sources=[source],
        cuda_sources=[],
        extra_cflags=["-O3"],
        extra_include_paths=cuda_include_dirs,
        functions=["get_current_cuda_stream"],
    )
    return module.get_current_cuda_stream


def bench_torch_get_current_stream(repeat: int, name: str, func: Callable[[int], int]) -> None:
    """Time ``func(0)``, a callable that looks up the current CUDA stream of device 0.

    ``repeat`` is the number of timed calls; the per-call average is printed
    as ``torch.cuda.current_stream[<name>]``.
    """
    # The tensor itself is unused; presumably allocating it makes sure CUDA is
    # initialized before timing starts -- TODO confirm.
    x = torch.arange(1, device="cuda")  # noqa: F841
    # Warm-up call, kept out of the timed region.
    func(0)
    # perf_counter() is the monotonic, high-resolution clock intended for
    # short intervals; time.time() is wall-clock and may be coarse or jump.
    start = time.perf_counter()
    for _ in range(repeat):
        func(0)
    end = time.perf_counter()
    speed = (end - start) / repeat
    print_speed(f"torch.cuda.current_stream[{name}]", speed)


def populate_object_table(num_classes: int) -> None:
    """Pass one instance of each of ``num_classes`` fresh classes through ``testing.nop``.

    Classes are named ``DummyClass0`` .. ``DummyClass{num_classes - 1}``.
    Presumably this grows tvm_ffi's per-type dispatch state so its effect on
    call overhead can be observed -- confirm against the caller.
    """
    nop = tvm_ffi.get_global_func("testing.nop")
    instances = []
    for index in range(num_classes):
        dummy_cls = type(f"DummyClass{index}", (object,), {})
        instances.append(dummy_cls())
    for obj in instances:
        nop(obj)


def main() -> None:  # noqa: PLR0915
    """Run every benchmark in this file and print the results section by section.

    Sections, in order: f(x, y, z) call overhead, ``__dlpack__`` overhead,
    versioned ``__dlpack__`` overhead, current-CUDA-stream lookup on the
    default and on a non-default stream, then tvm_ffi debug information.
    """
    repeat = 10000
    # measures impact of object dispatch table size
    # takeaway so far is that there is no impact on the performance
    num_classes = 0
    populate_object_table(num_classes)
    print("-----------------------------")
    print("Benchmark f(x, y, z) overhead")
    print("-----------------------------")
    baseline_numpy_add(repeat)
    baseline_torch_add(repeat)
    baseline_cupy_add(repeat)
    tvm_ffi_nop_from_torch_dlpack(repeat)
    tvm_ffi_nop_from_numpy_dlpack(repeat)
    tvm_ffi_self_dlpack_nop(repeat)
    tvm_ffi_nop_from_torch_utils_to_dlpack(repeat)
    tvm_ffi_nop_autodlpack_from_torch(repeat, "cpu")
    tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda")
    tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda", stream=True)

    tvm_ffi_nop_autodlpack_from_numpy(repeat)
    tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cpu")
    tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cuda")
    tvm_ffi_nop_autodlpack_from_test_ffi_tensor(repeat, "cpu")
    tvm_ffi_nop_autodlpack_from_test_ffi_tensor(repeat, "cuda")
    tvm_ffi_nop(repeat)
    print("-------------------------------")
    print("Benchmark x.__dlpack__ overhead")
    print("-------------------------------")
    bench_torch_utils_to_dlpack(repeat)
    bench_to_dlpack(torch.arange(1), "torch.__dlpack__", repeat)
    bench_to_dlpack(np.arange(1), "numpy.__dlpack__", repeat)
    bench_to_dlpack(tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__", repeat)
    print("---------------------------------------------------")
    print("Benchmark x.__dlpack__(max_version=(1,1)) overhead")
    print("---------------------------------------------------")
    bench_to_dlpack_versioned(torch.arange(1), "torch.__dlpack__(max_version=(1,1))", repeat)
    bench_to_dlpack_versioned(np.arange(1), "numpy.__dlpack__(max_version=(1,1))", repeat)
    bench_to_dlpack_versioned(
        tvm_ffi.from_dlpack(torch.arange(1)),
        "tvm.__dlpack__(max_version=(1,1))",
        repeat,
    )
    print("---------------------------------------------------")
    print("Benchmark torch.get_cuda_stream[default stream]")
    print("---------------------------------------------------")
    # Load the inline C++ extension once and reuse it for both stream
    # sections; the second load used identical arguments and was redundant.
    get_stream_cpp = load_torch_get_current_cuda_stream()
    bench_torch_get_current_stream(repeat, "cpp-extension", get_stream_cpp)
    bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native)
    print("---------------------------------------------------")
    print("Benchmark torch.get_cuda_stream[non-default stream]")
    print("---------------------------------------------------")
    with torch.cuda.stream(torch.cuda.Stream()):
        bench_torch_get_current_stream(repeat, "cpp-extension", get_stream_cpp)
        bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native)
    print("---------------------------------------------------")
    print("Debug information")
    print("---------------------------------------------------")
    tvm_ffi.core._print_debug_info()
    release_gil = tvm_ffi.get_global_func("testing.nop").release_gil
    print(f"TVM_FFI_RELEASE_GIL_BY_DEFAULT={int(release_gil)}")
    print("---------------------------------------------------")


# Run the benchmarks only when executed as a script, not on import.
if __name__ == "__main__":
    main()
