#ifndef HALIDE_THREAD_POOL_H
#define HALIDE_THREAD_POOL_H
#include <cassert>
#include <condition_variable>
#include <cstdlib>
#include <exception>
#include <functional>
#include <future>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
#ifdef _MSC_VER
#else
#include <unistd.h>
#endif
namespace Halide {
namespace Internal {
/// A fixed-size pool of worker threads that run submitted callables and
/// deliver their results through std::future<T>.
///
/// Jobs are dequeued in FIFO order by whichever worker is free first. All
/// mutable state is guarded by `mutex`.
template<typename T>
class ThreadPool {
    /// A queued unit of work: the callable to run and the promise that
    /// receives its result.
    struct Job {
        std::function<T()> func;
        std::promise<T> result;

        /// Runs `func` with `unique_lock` released, then fulfils `result`.
        /// `unique_lock` must be held on entry and is held again on return.
        void run_unlocked(std::unique_lock<std::mutex> &unique_lock);
    };

    // Guards every field below.
    std::mutex mutex;

    // Pending jobs, oldest first.
    std::queue<Job> jobs;

    // Signalled when a job is enqueued (async) or shutdown begins (~ThreadPool).
    std::condition_variable wakeup_threads;

    // Worker threads; joined by the destructor.
    std::vector<std::thread> threads;

    // Set by the destructor to make workers leave their loop.
    bool shutting_down{false};

    /// Body of every worker thread: dequeue and run jobs until shutdown.
    void worker_thread() {
        std::unique_lock<std::mutex> unique_lock(mutex);
        while (!shutting_down) {
            if (jobs.empty()) {
                // Sleep until notified. Spurious wakeups are harmless because
                // the loop re-checks both `shutting_down` and `jobs`.
                wakeup_threads.wait(unique_lock);
            } else {
                Job cur_job = std::move(jobs.front());
                jobs.pop();
                // Releases the lock while the job runs, re-acquires afterwards.
                cur_job.run_unlocked(unique_lock);
            }
        }
    }

public:
    /// Number of processors currently online; used as the default pool size.
    /// Never returns 0. If the platform query fails it falls back to
    /// std::thread::hardware_concurrency(), and then to 1.
    static size_t num_processors_online() {
#ifdef _WIN32
        // Keep the historical default of 8 when the variable is unset.
        const char *num_cores = getenv("NUMBER_OF_PROCESSORS");
        const long count = num_cores ? atoi(num_cores) : 8;
#else
        // sysconf() returns -1 on failure.
        const long count = sysconf(_SC_NPROCESSORS_ONLN);
#endif
        if (count > 0) {
            return static_cast<size_t>(count);
        }
        // Converting a failed (0 or negative) count to size_t would yield
        // either no workers or an absurdly large number of them.
        const unsigned hw = std::thread::hardware_concurrency();
        return hw > 0 ? hw : 1;
    }

    /// Starts `desired_num_threads` workers (default: one per online processor).
    ThreadPool(size_t desired_num_threads = num_processors_online()) {
        // Plain assert(): compiled out under NDEBUG.
        assert(desired_num_threads > 0);

        // Workers lock `mutex` first thing, so holding it here keeps them from
        // touching the pool until every thread has been created.
        std::lock_guard<std::mutex> lock(mutex);
        threads.reserve(desired_num_threads);
        for (size_t i = 0; i < desired_num_threads; ++i) {
            threads.emplace_back([this] { worker_thread(); });
        }
    }

    /// Stops and joins all workers. A job that is already running is allowed
    /// to finish. Jobs still queued are never run: their promises are
    /// destroyed unfulfilled, so the matching futures report
    /// std::future_errc::broken_promise.
    ~ThreadPool() {
        {
            std::lock_guard<std::mutex> lock(mutex);
            shutting_down = true;
            wakeup_threads.notify_all();
        }
        for (auto &t : threads) {
            t.join();
        }
    }

    /// Enqueues `func(args...)` and returns a future for its result.
    ///
    /// `func` and `args` are taken by value (never by reference) because they
    /// are used later from an arbitrary worker thread.
    template<typename Func, typename... Args>
    std::future<T> async(Func func, Args... args) {
        std::lock_guard<std::mutex> lock(mutex);

        Job job;
        // std::bind instead of a variadic-capture lambda, which some GCC
        // versions reject. The parameters are our own copies, so move them in.
        job.func = std::bind(std::move(func), std::move(args)...);
        std::future<T> result = job.result.get_future();
        jobs.emplace(std::move(job));

        // A single new job needs a single worker. Waking everyone would only
        // make the rest re-check an empty queue and go back to sleep.
        wakeup_threads.notify_one();

        return result;
    }
};
/// Runs the job's callable without holding the pool mutex, then publishes its
/// result (or the exception it threw) through `result`.
///
/// `unique_lock` must be locked on entry and is locked again on return, which
/// is what worker_thread() expects.
template<typename T>
inline void ThreadPool<T>::Job::run_unlocked(std::unique_lock<std::mutex> &unique_lock) {
    // Release the lock so other workers and async() callers can make progress.
    unique_lock.unlock();
#if defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)
    try {
        T r = func();
        unique_lock.lock();
        result.set_value(std::move(r));
    } catch (...) {
        // Hand the failure to whoever waits on the future, as std::async would.
        // Letting it escape the worker thread would call std::terminate.
        if (!unique_lock.owns_lock()) {
            unique_lock.lock();
        }
        result.set_exception(std::current_exception());
    }
#else
    // Exceptions disabled: nothing can be thrown, so no handler is needed.
    T r = func();
    unique_lock.lock();
    result.set_value(std::move(r));
#endif
}
/// Specialization for jobs with no return value. It runs the callable
/// unlocked, then marks the promise satisfied or forwards the exception the
/// callable threw.
///
/// `unique_lock` must be locked on entry and is locked again on return.
template<>
inline void ThreadPool<void>::Job::run_unlocked(std::unique_lock<std::mutex> &unique_lock) {
    // Release the lock so other workers and async() callers can make progress.
    unique_lock.unlock();
#if defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)
    try {
        func();
        unique_lock.lock();
        result.set_value();
    } catch (...) {
        // Hand the failure to whoever waits on the future instead of letting
        // it escape the worker thread, which would call std::terminate.
        if (!unique_lock.owns_lock()) {
            unique_lock.lock();
        }
        result.set_exception(std::current_exception());
    }
#else
    // Exceptions disabled: nothing can be thrown, so no handler is needed.
    func();
    unique_lock.lock();
    result.set_value();
#endif
}
}
}
#endif