// This source file includes the following definitions:
// - bitonic_sort
// - merge_sort
// - main
#include "Halide.h"
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include "halide_benchmark.h"
using namespace Halide;
using namespace Halide::Tools;
// Pure variables shared by every Func definition in this file:
// x indexes elements, y indexes rows of the 2-D merge-sort stages.
Var x("x");
Var y("y");
// Sorts the first `size` elements of the 1-D Func `input` in ascending order
// using a bitonic sorting network built from one compute_root Func per
// compare/exchange stage.
//
// For each pass_size = 1, 2, 4, ..., size/2, stages run with chunk_size =
// pass_size, pass_size/2, ..., 1. Element x is paired with a partner inside its
// block of 2*chunk_size elements. The lower half of the block keeps the min of
// the pair and the upper half keeps the max. The first stage of every pass
// after the first pairs x with its mirror position in the block. All other
// stages pair x with x +/- chunk_size.
//
// size: number of elements; must be a power of two. Otherwise partner indices
//       would reach past `size` and read outside the input.
// Returns a Func whose values over [0, size) are the sorted input. When
// size == 1 there are no stages and `input` itself is returned.
Func bitonic_sort(Func input, int size) {
    assert(size > 0 && (size & (size - 1)) == 0);

    Func prev = input;
    Var xo("xo"), xi("xi");
    for (int pass_size = 1; pass_size < size; pass_size <<= 1) {
        for (int chunk_size = pass_size; chunk_size > 0; chunk_size >>= 1) {
            Func next("bitonic_pass");
            Expr chunk_start = (x/(2*chunk_size))*(2*chunk_size);
            Expr chunk_end = (x/(2*chunk_size) + 1)*(2*chunk_size);
            Expr chunk_middle = chunk_start + chunk_size;
            Expr chunk_index = x - chunk_start;
            if (pass_size == chunk_size && pass_size > 1) {
                // Mirror partner: chunk_start + chunk_end - 1 - x. The clamp keeps
                // the access provably inside the block for bounds inference.
                Expr partner = 2*chunk_middle - x - 1;
                partner = clamp(partner, chunk_start, chunk_end-1);
                next(x) = select(x < chunk_middle,
                                 min(prev(x), prev(partner)),
                                 max(prev(x), prev(partner)));
            } else {
                // Offset partner: x + chunk_size in the lower half,
                // x - chunk_size in the upper half.
                Expr partner = chunk_start + (chunk_index + chunk_size) % (chunk_size*2);
                next(x) = select(x < chunk_middle,
                                 min(prev(x), prev(partner)),
                                 max(prev(x), prev(partner)));
            }
            if (pass_size > 1) {
                next.split(x, xo, xi, 2*chunk_size);
            }
            // chunk_size > 128 implies pass_size > 1, so xo exists from the split above.
            if (chunk_size > 128) {
                next.parallel(xo);
            }
            next.compute_root();
            prev = next;
        }
    }
    // prev is the last stage, or the untouched input when no stage was built.
    return prev;
}
// Sorts the first `total_size` elements of the 1-D Func `input` in ascending
// order with a bottom-up merge sort. Returns a 1-D Func whose values over
// [0, total_size) are the sorted elements.
//
// Layout: every intermediate stage is a 2-D Func indexed (x, y). Row y is a
// sorted run and x is the position within it. The base stage cuts the input
// into rows of 4 and sorts each row with a sorting network. Each merge pass
// combines rows 2*y and 2*y+1 of the previous stage into row y, doubling the
// row length, until row 0 holds all `total_size` elements.
//
// Scheduling: the base stage and the early merge passes are computed per row
// inside `parallel_stage`, which is parallelized over rows. `parallel_stage` is
// the output of the pass that merges runs of `parallel_work_size` elements, or
// of the final pass if the input is too small to get there. Later, larger
// passes are compute_root.
//
// total_size: must be a power of two >= 4. The base stage reads whole groups of
//             4, and every pass merges rows in pairs.
Func merge_sort(Func input, int total_size) {
    assert(input.dimensions() == 1);
    assert(total_size >= 4 && (total_size & (total_size - 1)) == 0);

    const int parallel_work_size = 512;
    Func parallel_stage("parallel_stage");
    // parallel_stage only joins the pipeline once it has a definition, so
    // nothing may stay scheduled compute_at(parallel_stage, ...) unless this
    // ends up true.
    bool parallel_stage_defined = false;

    // Base stage: row y holds input(4*y) .. input(4*y+3) in ascending order.
    Func base;
    {
        Expr a0 = input(4*y);
        Expr a1 = input(4*y+1);
        Expr a2 = input(4*y+2);
        Expr a3 = input(4*y+3);
        Expr b0 = min(a0, a1);
        Expr b1 = max(a0, a1);
        Expr b2 = min(a2, a3);
        Expr b3 = max(a2, a3);
        a0 = min(b0, b2);
        a1 = max(b0, b2);
        a2 = min(b1, b3);
        a3 = max(b1, b3);
        b0 = a0;
        b1 = min(a1, a2);
        b2 = max(a1, a2);
        b3 = a3;
        base(x, y) = select(x == 0, b0,
                            select(x == 1, b1,
                                   select(x == 2, b2, b3)));
        base.bound(x, 0, 4).unroll(x);
    }

    Func result = base;
    for (int chunk_size = 4; chunk_size < total_size; chunk_size *= 2) {
        assert(result.dimensions() == 2);
        // merge_rows(r, y) = (elements taken so far from row 2*y, elements taken
        // so far from row 2*y+1, merged output value at position r).
        Func merge_rows("merge_rows");
        RDom r(0, chunk_size*2);
        merge_rows(x, y) = Tuple(0, 0, cast(input.value().type(), 0));
        Expr candidate_a = merge_rows(r-1, y)[0];
        Expr candidate_b = merge_rows(r-1, y)[1];
        Expr valid_a = candidate_a < chunk_size;
        Expr valid_b = candidate_b < chunk_size;
        Expr value_a = result(clamp(candidate_a, 0, chunk_size-1), 2*y);
        Expr value_b = result(clamp(candidate_b, 0, chunk_size-1), 2*y+1);
        // Take from run a while it has elements left and its head is smaller
        // than run b's head, or run b is exhausted. Otherwise take from run b.
        merge_rows(r, y) = tuple_select(valid_a && ((value_a < value_b) || !valid_b),
                                        Tuple(candidate_a + 1, candidate_b, value_a),
                                        Tuple(candidate_a, candidate_b + 1, value_b));

        const bool last_pass = 2*chunk_size >= total_size;
        if (parallel_stage_defined) {
            merge_rows.compute_root();
        } else {
            merge_rows.compute_at(parallel_stage, y);
        }

        // Materialize parallel_stage at the 512-element pass, or at the final
        // pass for small inputs, so it always exists when total_size >= 8.
        if (!parallel_stage_defined && (chunk_size == parallel_work_size || last_pass)) {
            parallel_stage(x, y) = merge_rows(x, y)[2];
            parallel_stage.compute_root().parallel(y);
            parallel_stage_defined = true;
            result = parallel_stage;
        } else {
            result = lambda(x, y, merge_rows(x, y)[2]);
        }
    }

    if (parallel_stage_defined) {
        base.compute_at(parallel_stage, y);
    } else {
        // total_size == 4: no merge pass ran, so parallel_stage is not part of
        // the pipeline and the base stage is scheduled on its own.
        base.compute_root();
    }

    return lambda(x, result(x, 0));
}
// Benchmarks bitonic_sort and merge_sort against std::sort on N pseudo-random
// ints and prints the timings. Then checks both Halide results element by
// element against std::sort. Returns 0 if both match, -1 on the first mismatch.
int main(int argc, char **argv) {
    // Both Halide sorts require a power-of-two element count.
    const int N = 1 << 10;

    // rand() is never seeded, so the input is identical on every run.
    Buffer<int> data(N);
    for (int i = 0; i < N; i++) {
        data(i) = rand() & 0xfffff;
    }
    Func input = lambda(x, data(x));

    printf("Bitonic sort...\n");
    Func f = bitonic_sort(input, N);
    f.bound(x, 0, N);
    f.compile_jit();
    printf("Running...\n");
    Buffer<int> bitonic_sorted(N);
    // First realization happens outside the timed region.
    f.realize(bitonic_sorted);
    double t_bitonic = benchmark(1, 10, [&]() {
        f.realize(bitonic_sorted);
    });

    printf("Merge sort...\n");
    f = merge_sort(input, N);
    f.bound(x, 0, N);
    f.compile_jit();
    printf("Running...\n");
    Buffer<int> merge_sorted(N);
    f.realize(merge_sorted);
    double t_merge = benchmark(1, 10, [&]() {
        f.realize(merge_sorted);
    });

    Buffer<int> correct(N);
    for (int i = 0; i < N; i++) {
        correct(i) = data(i);
    }
    printf("std::sort...\n");
    double t_std = benchmark(1, 1, [&]() {
        // `correct` is a freshly allocated dense 1-D buffer, so its elements are
        // contiguous. Build the end pointer from the base instead of calling
        // correct(N), which would access an element one past the end.
        int *first = &correct(0);
        std::sort(first, first + N);
    });

    printf("Times:\n"
           "bitonic sort: %fms \n"
           "merge sort: %fms \n"
           "std::sort %fms\n",
           t_bitonic * 1e3, t_merge * 1e3, t_std * 1e3);

    if (N <= 100) {
        for (int i = 0; i < N; i++) {
            printf("%8d %8d %8d\n",
                   correct(i), bitonic_sorted(i), merge_sorted(i));
        }
    }

    for (int i = 0; i < N; i++) {
        if (bitonic_sorted(i) != correct(i)) {
            printf("bitonic sort failed: %d -> %d instead of %d\n", i, bitonic_sorted(i), correct(i));
            return -1;
        }
        if (merge_sorted(i) != correct(i)) {
            printf("merge sort failed: %d -> %d instead of %d\n", i, merge_sorted(i), correct(i));
            return -1;
        }
    }
    return 0;
}