cpp-utils/include/thread.hpp
2022-06-23 17:26:19 +01:00

164 lines
4.3 KiB
C++

#pragma once
#include "types.hpp"
#include <algorithm>
#include <optional>
#include <string_view>
#include <vector>
#ifdef _WIN32
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <Windows.h>
#include <TlHelp32.h>
namespace util {
namespace threads {
auto get_all_threads(u32 pid) -> std::vector<u32>;
auto get_all_threads() -> std::vector<u32>;
auto get_other_threads() -> std::vector<u32>;
auto suspend_threads(const std::vector<u32> &thread_ids) -> std::vector<u32>;
auto resume_threads(const std::vector<u32> &thread_ids) -> bool;
/// suspends all threads, only then executes function, and resumes threads
/// returns `nullopt` if not all threads were suspended.
template <typename T, typename F>
auto with_suspended_threads(F &&fn) -> std::optional<T> {
auto threads = get_other_threads();
auto suspended_threads = suspend_threads(threads);
const auto suspended_all_threads = suspended_threads.size() == threads.size();
std::optional<T> result = std::nullopt;
if (suspended_all_threads) {
result = fn();
}
return result;
}
} // namespace threads
namespace processes {
auto get_process_id(std::string_view name) -> std::optional<u32>;
auto get_process_id(std::wstring_view name) -> std::optional<u32>;
} // namespace processes
/// impl
namespace threads {
/// Returns the IDs of all threads owned by process `pid`.
///
/// Takes a Toolhelp thread snapshot, walks every entry in it, and keeps the
/// entries whose owner process ID equals `pid`. Returns an empty vector if
/// the snapshot cannot be created.
inline auto get_all_threads(u32 pid) -> std::vector<u32> {
  std::vector<u32> threads;
  const auto snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, 0);
  if (snapshot == INVALID_HANDLE_VALUE) {
    return threads;
  }

  THREADENTRY32 entry{};
  entry.dwSize = sizeof(entry);
  if (Thread32First(snapshot, &entry)) {
    do {
      // The enumeration call may report a dwSize smaller than requested; only
      // trust th32OwnerProcessID if the returned structure actually covers it.
      if (entry.dwSize >= FIELD_OFFSET(THREADENTRY32, th32OwnerProcessID) +
                              sizeof(entry.th32OwnerProcessID) &&
          entry.th32OwnerProcessID == pid) {
        threads.push_back(entry.th32ThreadID);
      }
      // dwSize may have been overwritten (see check above); restore it before
      // asking for the next entry.
      entry.dwSize = sizeof(entry);
    } while (Thread32Next(snapshot, &entry));
  }

  CloseHandle(snapshot);
  return threads;
}
/// Returns the IDs of every thread belonging to the calling process.
inline auto get_all_threads() -> std::vector<u32> {
  const u32 self_pid = GetCurrentProcessId();
  return get_all_threads(self_pid);
}
/// Returns the IDs of every thread in the calling process except the calling
/// thread itself.
inline auto get_other_threads() -> std::vector<u32> {
  const u32 self_tid = GetCurrentThreadId();
  auto ids = get_all_threads(GetCurrentProcessId());
  const auto kept_end =
      std::remove_if(ids.begin(), ids.end(),
                     [self_tid](u32 tid) { return tid == self_tid; });
  ids.erase(kept_end, ids.end());
  return ids;
}
inline auto resume_threads(const std::vector<u32> &thread_ids) -> bool {
// return false if any threads failed to resume
return std::count_if(thread_ids.begin(), thread_ids.end(), [](auto &&id) {
if (auto handle = OpenThread(THREAD_SUSPEND_RESUME, false, id)) {
const auto result = ResumeThread(handle);
CloseHandle(handle);
return result != (u32)-1;
}
return false;
}) == thread_ids.size();
}
/// Calls SuspendThread on each thread in `thread_ids`.
///
/// Threads that cannot be opened or suspended are skipped, so the result may
/// be shorter than the input (with_suspended_threads compares the sizes to
/// detect this).
///
/// @return the IDs that were successfully suspended, in input order.
inline auto suspend_threads(const std::vector<u32> &thread_ids)
    -> std::vector<u32> {
  std::vector<u32> suspended_threads;
  // Allocate up front: once some threads are suspended, one of them could be
  // holding a heap lock, and growing the vector then might deadlock.
  suspended_threads.reserve(thread_ids.size());
  for (const auto id : thread_ids) {
    if (const auto handle = OpenThread(THREAD_SUSPEND_RESUME, FALSE, id)) {
      // SuspendThread signals failure with (DWORD)-1.
      if (SuspendThread(handle) != static_cast<DWORD>(-1)) {
        suspended_threads.push_back(id);
      }
      CloseHandle(handle);
    }
  }
  return suspended_threads;
}
} // namespace threads
namespace processes {
/// Looks up a running process by executable name (narrow-string variant).
///
/// @param name  executable file name to match, e.g. "notepad.exe". The
///              comparison is exact (case-sensitive).
/// @return      the PID of the first matching process in snapshot order, or
///              `nullopt` if none matches or the snapshot cannot be taken.
///
/// NOTE(review): with UNICODE defined, the Windows headers presumably map
/// PROCESSENTRY32/Process32First/Process32Next to their wide variants, in
/// which case this overload would not compile — confirm build settings.
inline auto get_process_id(std::string_view name) -> std::optional<u32> {
  const auto snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0);
  if (snapshot == INVALID_HANDLE_VALUE) {
    return std::nullopt;
  }

  std::optional<u32> result = std::nullopt;
  PROCESSENTRY32 process_entry{};
  // Must be set before the first call; Process32First fails otherwise.
  process_entry.dwSize = sizeof(process_entry);
  if (Process32First(snapshot, &process_entry)) {
    do {
      if (name == std::string_view(process_entry.szExeFile)) {
        result = process_entry.th32ProcessID;
        break;
      }
    } while (Process32Next(snapshot, &process_entry));
  }

  CloseHandle(snapshot);
  return result;
}
/// Looks up a running process by executable name (wide-string variant).
///
/// @param name  executable file name to match, e.g. L"notepad.exe". The
///              comparison is exact (case-sensitive).
/// @return      the PID of the first matching process in snapshot order, or
///              `nullopt` if none matches or the snapshot cannot be taken.
inline auto get_process_id(std::wstring_view name) -> std::optional<u32> {
  const auto snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0);
  if (snapshot == INVALID_HANDLE_VALUE) {
    return std::nullopt;
  }

  std::optional<u32> result = std::nullopt;
  PROCESSENTRY32W process_entry{};
  // Must be set before the first call; Process32FirstW fails otherwise.
  process_entry.dwSize = sizeof(process_entry);
  if (Process32FirstW(snapshot, &process_entry)) {
    do {
      if (name == std::wstring_view(process_entry.szExeFile)) {
        result = process_entry.th32ProcessID;
        break;
      }
    } while (Process32NextW(snapshot, &process_entry));
  }

  CloseHandle(snapshot);
  return result;
}
} // namespace processes
} // namespace util
#endif