#pragma once

#include "types.hpp"
#include <algorithm>
#include <optional>
#include <string_view>
#include <vector>

#ifdef _WIN32
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <Windows.h>

#include <TlHelp32.h>

namespace util {
namespace threads {

/// Thread IDs whose owning process is `pid`, read from a Toolhelp32 thread
/// snapshot.
auto get_all_threads(u32 pid) -> std::vector<u32>;

/// Thread IDs belonging to the current process.
auto get_all_threads() -> std::vector<u32>;
/// Thread IDs belonging to the current process, excluding the calling thread.
auto get_other_threads() -> std::vector<u32>;

/// Attempts to suspend each thread in `thread_ids`; returns the IDs that were
/// successfully suspended.
auto suspend_threads(const std::vector<u32> &thread_ids) -> std::vector<u32>;

/// Attempts to resume each thread in `thread_ids`; returns `false` if any of
/// them could not be opened or resumed.
auto resume_threads(const std::vector<u32> &thread_ids) -> bool;

/// suspends all threads, only then executes function, and resumes threads
/// returns `nullopt` if not all threads were suspended.
template <typename T, typename F>
auto with_suspended_threads(F &&fn) -> std::optional<T> {
  auto threads = get_other_threads();

  auto suspended_threads = suspend_threads(threads);

  const auto suspended_all_threads = suspended_threads.size() == threads.size();

  std::optional<T> result = std::nullopt;
  if (suspended_all_threads) {
    result = fn();
  }

  return result;
}

} // namespace threads

namespace processes {
/// Looks up a process ID by executable name, compared exactly against the
/// `szExeFile` field of a Toolhelp32 process snapshot. Returns `nullopt` when
/// no process matches. Narrow-character overload.
auto get_process_id(std::string_view name) -> std::optional<u32>;
/// Wide-character overload of get_process_id().
auto get_process_id(std::wstring_view name) -> std::optional<u32>;
} // namespace processes

/// impl

namespace threads {
/// Returns the IDs of all threads owned by process `pid`.
///
/// TH32CS_SNAPTHREAD snapshots contain every thread in the system (the
/// process argument is ignored for this flag), so entries are filtered by
/// `th32OwnerProcessID`. Returns an empty vector if the snapshot fails.
inline auto get_all_threads(u32 pid) -> std::vector<u32> {
  std::vector<u32> threads;
  const auto snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, 0);
  if (snapshot == INVALID_HANDLE_VALUE) {
    return threads;
  }

  THREADENTRY32 entry{};
  entry.dwSize = sizeof(entry);
  if (Thread32First(snapshot, &entry)) {
    do {
      // Only trust th32OwnerProcessID if the returned entry is large enough
      // to actually contain it.
      const bool has_owner =
          entry.dwSize >= FIELD_OFFSET(THREADENTRY32, th32OwnerProcessID) +
                              sizeof(entry.th32OwnerProcessID);
      if (has_owner && entry.th32OwnerProcessID == pid) {
        threads.push_back(entry.th32ThreadID);
      }
      // The API may shrink dwSize; reset it before fetching the next entry.
      entry.dwSize = sizeof(entry);
    } while (Thread32Next(snapshot, &entry));
  }

  CloseHandle(snapshot);
  return threads;
}

/// Thread IDs of the calling process; forwards to the pid overload.
inline auto get_all_threads() -> std::vector<u32> {
  const u32 current_pid = GetCurrentProcessId();
  return get_all_threads(current_pid);
}

/// Returns the IDs of all threads in the current process except the calling
/// thread.
inline auto get_other_threads() -> std::vector<u32> {
  auto threads = get_all_threads();

  // The caller's thread ID cannot change during the erase; query it once
  // instead of once per element.
  const u32 self_id = GetCurrentThreadId();
  threads.erase(std::remove(threads.begin(), threads.end(), self_id),
                threads.end());

  return threads;
}

inline auto resume_threads(const std::vector<u32> &thread_ids) -> bool {
  // return false if any threads failed to resume
  return std::count_if(thread_ids.begin(), thread_ids.end(), [](auto &&id) {
           if (auto handle = OpenThread(THREAD_SUSPEND_RESUME, false, id)) {
             const auto result = ResumeThread(handle);
             CloseHandle(handle);

             return result != (u32)-1;
           }
           return false;
         }) == thread_ids.size();
}

/// Attempts to suspend each thread in `thread_ids`.
///
/// Returns, in input order, the IDs for which both OpenThread and
/// SuspendThread succeeded; callers should resume exactly these IDs.
/// NOTE(review): SuspendThread may return before the target has actually
/// stopped running — confirm whether callers need a GetThreadContext
/// round-trip to guarantee the thread is halted.
inline auto suspend_threads(const std::vector<u32> &thread_ids)
    -> std::vector<u32> {
  std::vector<u32> suspended;
  for (const u32 thread_id : thread_ids) {
    const HANDLE thread = OpenThread(THREAD_SUSPEND_RESUME, false, thread_id);
    if (!thread) {
      continue; // could not open the thread; skip it
    }
    const bool did_suspend = SuspendThread(thread) != static_cast<u32>(-1);
    CloseHandle(thread);
    if (did_suspend) {
      suspended.push_back(thread_id);
    }
  }
  return suspended;
}
} // namespace threads

namespace processes {

/// Looks up the ID of a running process by executable name (narrow chars).
///
/// `name` is compared exactly (case-sensitively) against each snapshot
/// entry's `szExeFile`. If several processes match, the one enumerated last
/// is returned. Returns `nullopt` if none match or the snapshot fails.
///
/// NOTE(review): when `UNICODE` is defined, <TlHelp32.h> remaps
/// `PROCESSENTRY32`/`Process32First`/`Process32Next` to their wide variants,
/// which would make this overload fail to compile — confirm the project
/// builds without `UNICODE`.
inline auto get_process_id(std::string_view name) -> std::optional<u32> {
  const auto snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0);
  if (snapshot == INVALID_HANDLE_VALUE) {
    return std::nullopt;
  }

  std::optional<u32> result = std::nullopt;

  // Process32First fails unless dwSize holds the structure size.
  PROCESSENTRY32 process_entry{};
  process_entry.dwSize = sizeof(process_entry);
  if (Process32First(snapshot, &process_entry)) {
    do {
      const auto process_name = std::string_view(process_entry.szExeFile);
      if (name == process_name) {
        result = process_entry.th32ProcessID;
      }
    } while (Process32Next(snapshot, &process_entry));
  }

  CloseHandle(snapshot);

  return result;
}

/// Looks up the ID of a running process by executable name (wide chars).
///
/// `name` is compared exactly (case-sensitively) against each snapshot
/// entry's `szExeFile`. If several processes match, the one enumerated last
/// is returned. Returns `nullopt` if none match or the snapshot fails.
inline auto get_process_id(std::wstring_view name) -> std::optional<u32> {
  const auto snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0);
  if (snapshot == INVALID_HANDLE_VALUE) {
    return std::nullopt;
  }

  std::optional<u32> result = std::nullopt;

  // Process32FirstW fails unless dwSize holds the structure size.
  PROCESSENTRY32W process_entry{};
  process_entry.dwSize = sizeof(process_entry);
  if (Process32FirstW(snapshot, &process_entry)) {
    do {
      const auto process_name = std::wstring_view(process_entry.szExeFile);
      if (name == process_name) {
        result = process_entry.th32ProcessID;
      }
    } while (Process32NextW(snapshot, &process_entry));
  }

  CloseHandle(snapshot);

  return result;
}

} // namespace processes

} // namespace util
#endif