#pragma once

// Argument-activity markers. In this header, `expand_args` places one of these
// ints directly in front of the wrapped argument's payload before everything is
// handed to `__enzyme_autodiff` / `__enzyme_fwddiff`. They are only declared
// here; no definition is visible in this file (presumably they are resolved by
// the Enzyme toolchain -- confirm).
extern int enzyme_dup;        // used for Duplicated<T>: marker, value, shadow
extern int enzyme_dupnoneed;  // used for DuplicatedNoNeed<T>: marker, value, shadow
extern int enzyme_out;        // used for Active<T>: marker, value
extern int enzyme_const;      // used for Const<T>: marker, value (its address also precedes a functor pointer)

// Return-activity markers. `detail::ret_global` maps a return-activity wrapper
// to the address of one of these; that address is passed as `ret_attr`.
extern int enzyme_const_return;
extern int enzyme_active_return;
extern int enzyme_dup_return;

// Primal-return markers. `detail::ret_used` selects the address of one of these
// from the differentiation mode and the return activity; `autodiff` puts that
// address at the front of the expanded argument tuple.
extern int enzyme_primal_return;
extern int enzyme_noret;

// Reverse-mode entry point called by `detail::autodiff_apply<ReverseMode<...>>`.
// Declared only: `Return` must be given explicitly, the arguments are taken by
// value as an arbitrary pack. No definition exists in this header (presumably
// the call is rewritten by the Enzyme compiler plugin -- confirm).
template<typename Return, typename... T>
Return __enzyme_autodiff(T...);

// Forward-mode counterpart, called by `detail::autodiff_apply<ForwardMode>`.
// Same shape as `__enzyme_autodiff`; likewise declared but not defined here.
template<typename Return, typename... T>
Return __enzyme_fwddiff(T...);

#include <enzyme/tuple>

namespace enzyme {

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wmissing-braces"

    /// Empty placeholder type. The primary `type_info` template reports it as
    /// the `type` of any argument wrapper it has no specialization for.
    struct nodiff {};

    /// Tag type selecting reverse-mode differentiation.
    /// @tparam ReturnPrimal distinguishes the two reverse-mode tags below;
    ///         `Reverse` uses `false`, `ReverseWithPrimal` uses `true`.
    template <bool ReturnPrimal = false>
    struct ReverseMode {};

    using Reverse           = ReverseMode<>;
    using ReverseWithPrimal = ReverseMode<true>;

    /// Tag type selecting forward-mode differentiation.
    struct ForwardMode {};

    using Forward = ForwardMode;

    /**
     * @brief Wrapper marking a by-value argument as active.
     *
     * Stores its own copy of the argument in `value`. Reference and pointer
     * types are rejected at compile time. Implicitly constructible from `T`
     * and implicitly convertible to `T&` (non-const objects only).
     */
    template <typename T>
    struct Active {
      static_assert(
        !std::is_reference_v<T> && !std::is_pointer_v<T>,
        "Reference/pointer active arguments don't make sense for AD!"
      );

      T value;

      Active(const Active&) = default;
      Active(Active&&) = default;

      // Copy the wrapped value in from an lvalue ...
      Active(const T& init) : value(init) {}
      // ... or steal it from an rvalue.
      Active(T&& init) : value(std::move(init)) {}

      // Gives access to the stored value, not to the object it was built from.
      operator T&() { return value; }
    };

    namespace detail {

        /**
         * @brief Helper base for all `Duplicated`-type arguments
         *
         * Due to essentially the same implementation, it prevents code
         * duplication.
         *
         * The primary template stores the primal `value` and its `shadow` by
         * value (this also covers pointer types, where the pointers themselves
         * are stored). The `T&` specialization stores the addresses of the two
         * referenced objects instead.
         */
        template < typename T >
        struct DupArg {
            // Argument as regular value. Also works for pointer types
            T value;
            T shadow;

            /// Copies both the primal value and the shadow.
            DupArg(const T& v, const T& s) : value(v), shadow(s) {}

            /// Moves both the primal value and the shadow.
            /// `v` and `s` are named rvalue references and therefore lvalues
            /// inside this constructor; without std::move they would be copied
            /// (and move-only types would not compile at all).
            DupArg(T&& v, T&& s) : value(std::move(v)), shadow(std::move(s)) {}

            DupArg(const DupArg<T>&) = default;
            DupArg(DupArg<T>&&) = default;
        };

        /// Reference form: records the addresses of the caller's primal and
        /// shadow objects. The caller must keep both objects alive for as long
        /// as the stored pointers are used.
        template < typename T >
        struct DupArg<T&> {
            T* value;
            T* shadow;
            DupArg(T& v, T& s) : value(&v), shadow(&s) {}
            DupArg(const DupArg<T&>&) = default;
            DupArg(DupArg<T&>&&) = default;
        };

    } // <-- namespace detail

    /// Argument wrapper carrying a primal `value` and a `shadow`; storage and
    /// constructors come entirely from `detail::DupArg<T>`. `expand_args`
    /// tags it with `enzyme_dup`.
    template <typename T>
    struct Duplicated : detail::DupArg<T> {
      // Re-export every DupArg constructor unchanged.
      using detail::DupArg<T>::DupArg;
    };

    /// Same layout and constructors as `Duplicated` (both derive from
    /// `detail::DupArg<T>`), but `expand_args` tags it with `enzyme_dupnoneed`.
    template <typename T>
    struct DuplicatedNoNeed : detail::DupArg<T> {
      // Re-export every DupArg constructor unchanged.
      using detail::DupArg<T>::DupArg;
    };

    /// Wrapper for an argument tagged with `enzyme_const` by `expand_args`.
    /// The primary template keeps its own copy of the argument in `value`.
    template <typename T>
    struct Const {
      T value;

      Const(const Const&) = default;
      Const(Const&&) = default;

      Const(const T& init) : value(init) {}
      Const(T&& init) : value(std::move(init)) {}

      // Access to the stored copy (non-const objects only).
      operator T&() { return value; }
    };

    /// Reference form: stores the address of the caller's object rather than
    /// a copy, and converts back to a reference to that same object.
    template <typename T>
    struct Const<T&> {
        T* value;

        Const(const Const&) = default;
        Const(Const&&) = default;

        Const(T& target) : value(&target) {}

        operator T&() { return *value; }
    };

    // Deduction guides so the wrappers can be written without template
    // arguments, e.g. `Active(x)`. CTAD is available in C++17 or later, hence
    // the guard. Every guide takes its parameter(s) by value, so `T` is never
    // deduced as a reference type; reference forms such as `Const<U&>` have to
    // be spelled out explicitly.
    #if __cplusplus >= 201703L
      template < typename T >
      Active(T) -> Active<T>;

      template < typename T >
      Const(T) -> Const<T>;

      // Both arguments must deduce to the same T.
      template < typename T >
      Duplicated(T,T) -> Duplicated<T>;

      // Both arguments must deduce to the same T.
      template < typename T >
      DuplicatedNoNeed(T,T) -> DuplicatedNoNeed<T>;
    #endif

    /// Compile-time classification of an argument wrapper.
    ///
    /// Fallback for every type without a dedicated specialization below (this
    /// includes `Const<T>`): neither active nor duplicated, payload `nodiff`.
    /// Only exact wrapper types match the specializations; cv- or
    /// reference-qualified wrappers land here.
    template <class Wrapper>
    struct type_info {
      using type = nodiff;
      static constexpr bool is_active     = false;
      static constexpr bool is_duplicated = false;
    };

    /// `Active<T>`: active, payload type `T`.
    template <class T>
    struct type_info<Active<T>> {
        using type = T;
        static constexpr bool is_active     = true;
        static constexpr bool is_duplicated = false;
    };

    /// `Duplicated<T>`: duplicated, payload type `T`.
    template <class T>
    struct type_info<Duplicated<T>> {
        using type = T;
        static constexpr bool is_active     = false;
        static constexpr bool is_duplicated = true;
    };

    /// `DuplicatedNoNeed<T>`: classified exactly like `Duplicated<T>`.
    template <class T>
    struct type_info<DuplicatedNoNeed<T>> {
        using type = T;
        static constexpr bool is_active     = false;
        static constexpr bool is_duplicated = true;
    };

    /// Builds a `tuple` type by appending types one at a time.
    ///
    /// `concatenated<tuple<A...>, B, C, ...>::type` is `tuple<A..., B, C, ...>`.
    /// Every trailing type is appended as-is (nothing is filtered out).
    template <class... Ts>
    struct concatenated;

    // Recursive step: move the next pending type to the end of the tuple.
    template <class... Collected, class Next, class... Pending>
    struct concatenated<tuple<Collected...>, Next, Pending...> {
        using type = typename concatenated<tuple<Collected..., Next>, Pending...>::type;
    };

    // Base case: nothing left to append, the single remaining type is the result.
    template <class Result>
    struct concatenated<Result> {
        using type = Result;
    };

    // Computes the type that `autodiff` asks the Enzyme entry point to return,
    // from the differentiation mode, the return activity and (for reverse mode)
    // the wrapped argument types. Only the specializations below are defined;
    // any other combination names an incomplete type.
    //
    // Would be slightly cleaner in C++20, with std::remove_cvref. The
    // `remove_cvref` used here is not declared in this file (presumably it comes
    // from <enzyme/tuple> -- confirm).
    template < typename ... T >
    struct autodiff_return;

    // Reverse mode without primal: a tuple with a single element, which is
    // itself a tuple holding `type_info<...>::type` for every argument in order
    // (so `nodiff` for arguments that type_info does not classify).
    // `RetType` is not used.
    template < typename RetType, typename ... T >
    struct autodiff_return<ReverseMode<false>, RetType, T...>
    {
        using type = tuple<typename concatenated< tuple< >,
            typename type_info<
                typename remove_cvref< T >::type
            >::type ...
        >::type>;
    };

    // Reverse mode with primal: `type_info<RetType>::type` first, followed by
    // the same per-argument tuple as above. Note that `RetType` is looked up
    // without cv/ref stripping.
    template < typename RetType, typename ... T >
    struct autodiff_return<ReverseMode<true>, RetType, T...>
    {
        using type = tuple<
            typename type_info<RetType>::type,
            typename concatenated< tuple< >,
                typename type_info<
                    typename remove_cvref< T >::type
                >::type ...
            >::type
        >;
    };

    // Forward mode, Const return activity: one slot of type T0. The argument
    // types are ignored in all forward-mode specializations.
    template < typename T0, typename ... T >
    struct autodiff_return<ForwardMode, Const<T0>, T...>
    {
        using type = tuple<T0>;
    };

    // Forward mode, Duplicated return activity: two slots of type T0.
    template < typename T0, typename ... T >
    struct autodiff_return<ForwardMode, Duplicated<T0>, T...>
    {
        using type = tuple<T0, T0>;
    };

    // Forward mode, DuplicatedNoNeed return activity: one slot of type T0.
    template < typename T0, typename ... T >
    struct autodiff_return<ForwardMode, DuplicatedNoNeed<T0>, T...>
    {
        using type = tuple<T0>;
    };

    // expand_args: turns one argument wrapper into the flat tuple that is
    // spliced into the Enzyme call -- an activity marker int followed by the
    // wrapper's payload. The element order inside each tuple is significant.

    // Duplicated -> (enzyme_dup, value, shadow). Element types are deduced by
    // the tuple's class template argument deduction.
    template < typename T >
    __attribute__((always_inline))
    auto expand_args(const enzyme::Duplicated<T> & arg) {
        return enzyme::tuple{enzyme_dup, arg.value, arg.shadow};
    }

    // DuplicatedNoNeed -> (enzyme_dupnoneed, value, shadow).
    template < typename T >
    __attribute__((always_inline))
    auto expand_args(const enzyme::DuplicatedNoNeed<T> & arg) {
        return enzyme::tuple{enzyme_dupnoneed, arg.value, arg.shadow};
    }

    // Active -> (enzyme_out, value). Unlike the other overloads the tuple type
    // is spelled out as tuple<int, T> instead of being deduced.
    template < typename T >
    __attribute__((always_inline))
    auto expand_args(const enzyme::Active<T> & arg) {
        return enzyme::tuple<int, T>{enzyme_out, arg.value};
    }

    // Const -> (enzyme_const, value).
    template < typename T >
    __attribute__((always_inline))
    auto expand_args(const enzyme::Const<T> & arg) {
        return enzyme::tuple{enzyme_const, arg.value};
    }

    // primal_args: extracts only the primal payload of an argument wrapper as a
    // one-element tuple<T>; the shadow (if any) and the activity are dropped.
    // All four overloads have the same body.
    // NOTE(review): for the reference forms (e.g. Duplicated<U&>, Const<U&>)
    // `arg.value` is a `U*` while the tuple element type is `U&` -- confirm
    // whether those forms are meant to be usable with primal_call.

    // Duplicated: primal value only.
    template < typename T >
    __attribute__((always_inline))
    auto primal_args(const enzyme::Duplicated<T> & arg) {
        return enzyme::tuple<T>{arg.value};
    }

    // DuplicatedNoNeed: primal value only.
    template < typename T >
    __attribute__((always_inline))
    auto primal_args(const enzyme::DuplicatedNoNeed<T> & arg) {
        return enzyme::tuple<T>{arg.value};
    }

    // Active: the wrapped value.
    template < typename T >
    __attribute__((always_inline))
    auto primal_args(const enzyme::Active<T> & arg) {
        return enzyme::tuple<T>{arg.value};
    }

    // Const: the wrapped value.
    template < typename T >
    __attribute__((always_inline))
    auto primal_args(const enzyme::Const<T> & arg) {
        return enzyme::tuple<T>{arg.value};
    }

    // Declaration-only helper for use in unevaluated contexts: given any
    // single-parameter wrapper `Template<Arg>`, `decltype(arg_type(w))` is
    // `Arg`. It is never defined, so it must not be called at run time.
    template < template <typename> class Template, typename Arg >
    Arg arg_type(const Template<Arg>&);

    namespace detail {
        /// Assembles the function type `RetType(ArgTypes...)` from a return
        /// type and a list of parameter types.
        template <class RetType, class... ArgTypes>
        struct function_type {
            using type = RetType(ArgTypes...);
        };

        /// Primary template: intentionally empty. Only the specialization for
        /// a function-type second argument provides `wrap`.
        template <class function, class prevfunc>
        struct templated_call {};

        /// Adapter giving a callable object a plain static function:
        /// `wrap` takes the call arguments by value, followed by a pointer to
        /// the callable, and returns the result of invoking it.
        template <class function, class RT, class... Params>
        struct templated_call<function, RT(Params...)> {
            static RT wrap(Params... args, function* __restrict__ f) {
                function& callee = *f;
                return callee(args...);
            }
        };




        // push_return_last: reorders the tuple produced by the reverse-mode
        // Enzyme call. The generic overload is declared but not defined here;
        // only the two tuple shapes below are implemented.
        template<typename T>
        __attribute__((always_inline))
        constexpr decltype(auto) push_return_last(T &&t);

        // Only the derivative tuple is present: rebuild the same single-element
        // tuple from element 0 (nothing to reorder).
        template<typename ...T>
        __attribute__((always_inline))
        constexpr decltype(auto) push_return_last(tuple<tuple<T...>> &&t) {
          return tuple<tuple<T...>>{get<0>(t)};
        }

        // Input is (R, derivative tuple): return the two elements swapped, i.e.
        // the derivative tuple first and R last. The result type is deduced by
        // the tuple's class template argument deduction.
        template<typename ...T, typename R>
        __attribute__((always_inline))
        constexpr decltype(auto) push_return_last(tuple<R, tuple<T...>> &&t) {
          return tuple{get<1>(t), get<0>(t)};
        }

        // autodiff_apply<Mode>::impl unpacks the expanded argument tuple and
        // issues the actual Enzyme call. The primary template is empty; only
        // the reverse- and forward-mode specializations provide `impl`.
        template <typename Mode>
        struct autodiff_apply {};

        // Reverse mode (with or without primal): calls __enzyme_autodiff with
        //   f, ret_attr, <every element of t in index order>, args...
        // and passes the result through push_return_last.
        // `return_type` must be supplied explicitly by the caller; `I...` are
        // the indices of `t`.
        template <bool Mode>
        struct autodiff_apply<ReverseMode<Mode>> {
        template <class return_type, class Tuple, std::size_t... I, typename... ExtraArgs>
        __attribute__((always_inline))
        static constexpr decltype(auto) impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>, ExtraArgs... args) {
            return push_return_last(__enzyme_autodiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))..., args...));
        }
        };

        // Forward mode: same argument layout, but calls __enzyme_fwddiff and
        // returns its result unchanged (no reordering).
        template <>
        struct autodiff_apply<ForwardMode> {
        template <class return_type, class Tuple, std::size_t... I, typename... ExtraArgs>
        __attribute__((always_inline))
        static constexpr return_type impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>, ExtraArgs... args) {
            return __enzyme_fwddiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))..., args...);
        }
        };

        // Calls `f` with every element of the tuple `t`, in index order, and
        // returns whatever `f` returns (references preserved via
        // decltype(auto)). `I...` must be the indices of `t`. `f` itself is
        // invoked as an lvalue; only the tuple is forwarded.
        template <typename function, class Tuple, std::size_t... I>
        __attribute__((always_inline))
        constexpr decltype(auto) primal_apply_impl(function &&f, Tuple&& t, std::index_sequence<I...>) {
            return f(enzyme::get<I>(impl::forward<Tuple>(t))...);
        }

        /// Chooses the return activity used by `autodiff` when the caller does
        /// not name one. Fallback for every (mode, return type) pair without a
        /// specialization below: `Const<T>`.
        template <class Mode, class T>
        struct default_ret_activity {
          using type = Const<T>;
        };

        // Reverse mode (either flavour), float return: Active.
        template <bool WithPrimal>
        struct default_ret_activity<ReverseMode<WithPrimal>, float> {
          using type = Active<float>;
        };

        // Reverse mode (either flavour), double return: Active.
        template <bool WithPrimal>
        struct default_ret_activity<ReverseMode<WithPrimal>, double> {
          using type = Active<double>;
        };

        // Forward mode, float return: DuplicatedNoNeed.
        template <>
        struct default_ret_activity<ForwardMode, float> {
          using type = DuplicatedNoNeed<float>;
        };

        // Forward mode, double return: DuplicatedNoNeed.
        template <>
        struct default_ret_activity<ForwardMode, double> {
          using type = DuplicatedNoNeed<double>;
        };

        /// Maps a return-activity wrapper to the address of the matching
        /// `enzyme_*_return` marker. Declared only; wrappers without a
        /// specialization below are a compile error when used.
        template <class Activity>
        struct ret_global;

        template <class T>
        struct ret_global<Const<T>> {
          static constexpr int* value = &enzyme_const_return;
        };

        template <class T>
        struct ret_global<Active<T>> {
          static constexpr int* value = &enzyme_active_return;
        };

        template <class T>
        struct ret_global<Duplicated<T>> {
          static constexpr int* value = &enzyme_dup_return;
        };

        // Shares the marker used for Duplicated.
        template <class T>
        struct ret_global<DuplicatedNoNeed<T>> {
          static constexpr int* value = &enzyme_dup_return;
        };

        /// Selects the address of `enzyme_primal_return` or `enzyme_noret`
        /// from the differentiation mode and the return activity. Declared
        /// only; combinations without a specialization below (for example
        /// ForwardMode with Active) are a compile error when used.
        template <class Mode, class RetAct>
        struct ret_used;

        // Reverse mode is decided by the mode tag alone.
        template <class RetAct>
        struct ret_used<ReverseMode<true>, RetAct> {
          static constexpr int* value = &enzyme_primal_return;
        };

        template <class RetAct>
        struct ret_used<ReverseMode<false>, RetAct> {
          static constexpr int* value = &enzyme_noret;
        };

        // Forward mode is decided by the return activity.
        template <class T>
        struct ret_used<ForwardMode, DuplicatedNoNeed<T>> {
          static constexpr int* value = &enzyme_noret;
        };

        template <class T>
        struct ret_used<ForwardMode, Const<T>> {
          static constexpr int* value = &enzyme_primal_return;
        };

        template <class T>
        struct ret_used<ForwardMode, Duplicated<T>> {
          static constexpr int* value = &enzyme_primal_return;
        };

        template <typename T, typename... Args>
        struct verify_dup_args {
          static constexpr bool value = true;
        };

        template <typename... Args>
        struct verify_dup_args<enzyme::Reverse, Args...> {
          static constexpr bool value = (
            (
              !type_info<Args>::is_duplicated
              || std::is_reference_v<typename type_info<Args>::type>
              || std::is_pointer_v<typename type_info<Args>::type>
            ) && ...
          );
        };

        /// True for function types and for (unqualified) pointers to function
        /// types, false for everything else. A cv-qualified pointer does not
        /// match the pointer specialization, so callers strip qualifiers first.
        template <typename Candidate>
        struct is_function_or_function_ptr : std::is_function<Candidate> {};

        // Pointer case: decided by what the pointer points to.
        template <typename Pointee>
        struct is_function_or_function_ptr<Pointee*> : std::is_function<Pointee> {};
    }  // namespace detail

    /// Invokes `f` with the elements of `arg_tup` as its arguments and returns
    /// the result.
    ///
    /// @tparam return_type Kept only so existing callers that pass an explicit
    ///         first template argument keep compiling; it no longer takes part
    ///         in the call.
    /// @param f       Callable to invoke; forwarded with its original value
    ///                category.
    /// @param arg_tup Tuple holding the call arguments in order.
    ///
    /// The previous implementation passed `std::move(f)` to
    /// `primal_apply_impl<return_type>`. When `f` was an lvalue class object
    /// (for example a named lambda) the deduced reference type made that call
    /// ill-formed, and moving from a caller-owned lvalue would have been wrong
    /// anyway. Forwarding `f` and letting `primal_apply_impl` deduce the
    /// callable type handles lvalues and rvalues alike.
    template < typename return_type, typename function, typename ... enz_arg_types >
    __attribute__((always_inline))
    auto primal_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
      using Tuple = enzyme::tuple< enz_arg_types ... >;
      return detail::primal_apply_impl(
          impl::forward<function>(f),
          impl::forward<Tuple>(arg_tup),
          std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
    }

    /// Calls `f` on the primal values of the wrapped arguments, without any
    /// differentiation: each wrapper is reduced to its primal payload by
    /// `primal_args`, the pieces are joined into one tuple, and `primal_impl`
    /// performs the call.
    template <typename function, typename... arg_types>
    auto primal_call(function&& f, arg_types&&... args) {
        auto unwrapped = enzyme::tuple_cat(primal_args(args)...);
        return primal_impl<function>(impl::forward<function>(f), std::move(unwrapped));
    }

    /// autodiff_impl, callable-object overload (selected when `function`,
    /// after cv/ref stripping, is neither a function nor a function pointer).
    ///
    /// Differentiates the static trampoline `templated_call<...>::wrap`
    /// instead of the object itself; the object is supplied as a trailing
    /// pointer argument preceded by `&enzyme_const`.
    ///
    /// @tparam return_type  Type requested from the Enzyme call.
    /// @tparam DiffMode     Mode tag selecting the `autodiff_apply` specialization.
    /// @tparam function     Callable type as deduced by the caller (may be a
    ///                      reference type).
    /// @tparam functy       Function type of the primal call.
    /// @tparam RetActivity  Return-activity wrapper, mapped via `ret_global`.
    /// @param f        The callable object; its address is passed on, so it
    ///                 must stay alive for the duration of the call.
    /// @param arg_tup  Expanded (marker, payload...) argument tuple.
    template < typename return_type, typename DiffMode, typename function, typename functy, typename RetActivity, typename ... enz_arg_types, std::enable_if_t<!detail::is_function_or_function_ptr<typename remove_cvref< function >::type>::value, int> = 0>
    __attribute__((always_inline))
    auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
      using Tuple = enzyme::tuple< enz_arg_types ... >;
      // `function` is `F&` when the caller passed an lvalue callable. The
      // trampoline takes a `function*`, and a pointer to a reference is
      // ill-formed, so drop the reference (constness is kept so that `&f`
      // still matches the pointer type).
      using callable = std::remove_reference_t<function>;
      return detail::autodiff_apply<DiffMode>::template impl<return_type>((void*)&detail::templated_call<callable, functy>::wrap, detail::ret_global<RetActivity>::value,
            impl::forward<Tuple>(arg_tup),
            std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{}, &enzyme_const, &f);
    }

    // autodiff_impl, function overload: selected when `function`, after cv/ref
    // stripping, is a function type or a pointer to function. `f` is converted
    // to a `functy*`, cast to void* and handed to autodiff_apply together with
    // the return-activity marker and the expanded argument tuple. Unlike the
    // callable-object overload, no extra trailing arguments are passed.
    template < typename return_type, typename DiffMode, typename function, typename functy, typename RetActivity, typename ... enz_arg_types, std::enable_if_t<detail::is_function_or_function_ptr<typename remove_cvref< function >::type>::value, int> = 0>
    __attribute__((always_inline))
    auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
      using Tuple = enzyme::tuple< enz_arg_types ... >;
      return detail::autodiff_apply<DiffMode>::template impl<return_type>((void*)static_cast<functy*>(f), detail::ret_global<RetActivity>::value, impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
    }

    // autodiff with an explicit return activity.
    //
    // Template parameters (the first two are given explicitly by the caller):
    //   DiffMode     mode tag (ReverseMode<...> or ForwardMode)
    //   RetActivity  return-activity wrapper (Const<T>, Active<T>, ...)
    // Parameters:
    //   f     function or callable object to differentiate
    //   args  argument wrappers (Active / Duplicated / DuplicatedNoNeed / Const)
    // Returns whatever detail::autodiff_apply<DiffMode>::impl produces for the
    // type computed by autodiff_return<DiffMode, RetActivity, arg_types...>.
    template < typename DiffMode, typename RetActivity, typename function, typename ... arg_types>
    __attribute__((always_inline))
    auto autodiff(function && f, arg_types && ... args) {
        static_assert(
            detail::verify_dup_args<DiffMode, arg_types...>::value,
            "Non-reference/pointer Duplicated/DuplicatedNoNeed args don't make "
            "sense for Reverse mode AD"
        );
        // Type of calling f on the unwrapped argument types (arg_type yields the
        // wrapper's template argument).
        using primal_return_type = decltype(f(arg_type(args)...));
        // Function type of the primal call: primal_return_type(unwrapped args...).
        using functy = typename detail::function_type<primal_return_type, decltype(arg_type(args))...>::type;
        using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
        // The argument tuple starts with the ret_used marker address, followed
        // by (activity marker, payload...) for every wrapper in order.
        return autodiff_impl<return_type, DiffMode, function, functy, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
    }

    // autodiff with the default return activity.
    //
    // Identical to the overload taking an explicit RetActivity, except that the
    // return activity is chosen by detail::default_ret_activity from DiffMode
    // and the primal return type (Active for float/double in reverse mode,
    // DuplicatedNoNeed for float/double in forward mode, Const otherwise).
    // Parameters:
    //   f     function or callable object to differentiate
    //   args  argument wrappers (Active / Duplicated / DuplicatedNoNeed / Const)
    template < typename DiffMode, typename function, typename ... arg_types>
    __attribute__((always_inline))
    auto autodiff(function && f, arg_types && ... args) {
        static_assert(
            detail::verify_dup_args<DiffMode, arg_types...>::value,
            "Non-reference/pointer Duplicated/DuplicatedNoNeed args don't make "
            "sense for Reverse mode AD"
        );
        // Type of calling f on the unwrapped argument types.
        using primal_return_type = decltype(f(arg_type(args)...));
        // Function type of the primal call: primal_return_type(unwrapped args...).
        using functy = typename detail::function_type<primal_return_type, decltype(arg_type(args))...>::type;
        using RetActivity = typename detail::default_ret_activity<DiffMode, primal_return_type>::type;
        using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
        // The argument tuple starts with the ret_used marker address, followed
        // by (activity marker, payload...) for every wrapper in order.
        return autodiff_impl<return_type, DiffMode, function, functy, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
    }
#pragma clang diagnostic pop

} // namespace enzyme
