root/tools/find_inverse.cpp

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. r
  2. sdiv
  3. u_method_0
  4. u_method_1
  5. u_method_2
  6. s_method_0
  7. s_method_1
  8. main

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <assert.h>
#include <algorithm>

// Return a pseudo-random integer in the inclusive range [min, max],
// drawn from rand().
//
// The C standard only guarantees RAND_MAX >= 32767 (15 random bits), so
// three rand() results are combined at 15-bit offsets. That yields at least
// 45 random bits, enough to cover the 32-bit ranges used in this tool with
// negligible modulo bias. The span is computed in unsigned arithmetic so
// that max - min cannot overflow, and a single-value range (min == max) is
// valid instead of dividing by zero.
int64_t r(int64_t min, int64_t max) {
    assert(min <= max);
    uint64_t bits = (uint64_t)rand();
    bits ^= (uint64_t)rand() << 15;
    bits ^= (uint64_t)rand() << 30;
    // Number of values in [min, max]; wraps to 0 only if the range covers
    // all 2^64 int64_t values.
    uint64_t span = (uint64_t)max - (uint64_t)min + 1;
    uint64_t offset = span ? bits % span : bits;
    // min + offset lies in [min, max], so the result is representable.
    return (int64_t)((uint64_t)min + offset);
}

// Floor division: divide a by b and round the quotient toward negative
// infinity. The built-in '/' truncates toward zero instead, so step the
// truncated quotient down by one whenever there is a nonzero remainder
// whose sign differs from the divisor's.
int64_t sdiv(int64_t a, int64_t b) {
    int64_t quotient = a / b;
    int64_t remainder = a % b;
    bool signs_differ = (remainder < 0) != (b < 0);
    if (remainder != 0 && signs_differ) {
        quotient -= 1;
    }
    return quotient;
}

// Method 0 (unsigned): check empirically whether num / den equals
// num >> sh_post for every unsigned `bits`-bit numerator.
//
// Random sampling (10^6 numerators from r(), plus 0 and the maximum value)
// is used instead of exhaustive enumeration to bound the cost for 32-bit
// types. Returns false on the first mismatch.
bool u_method_0(int den, int sh_post, int bits) {
    // 1ULL is at least 64 bits everywhere. 1L is only 32 bits on LLP64
    // targets (e.g. 64-bit Windows), where 1L << 32 is undefined.
    uint64_t max = (1ULL << bits) - 1;
    for (unsigned iter = 0; iter < 1000000UL; iter++) {
        uint64_t num = r(0, max);
        // Make sure we hit the extremes
        if (iter == 0) num = 0;
        if (iter == 1) num = max;
        uint64_t result = num;
        result >>= sh_post;
        if (num / den != result) return false;
    }
    return true;
}

// Method 1 (unsigned): check empirically whether num / den equals
// "multiply num by mul, drop the low `bits` bits, then shift right by
// sh_post" for unsigned `bits`-bit numerators.
//
// Tested on 10^6 numerators from r(), with 0 and the maximum value always
// included. Returns false on the first mismatch, or if the multiply-high
// value does not fit in `bits` bits.
bool u_method_1(int den, int64_t mul, int sh_post, int bits) {
    // 1ULL is at least 64 bits everywhere. 1L is only 32 bits on LLP64
    // targets (e.g. 64-bit Windows), where 1L << 32 is undefined.
    uint64_t max = (1ULL << bits) - 1;
    for (unsigned iter = 0; iter < 1000000UL; iter++) {
        uint64_t num = r(0, max);
        // Make sure we hit the extremes
        if (iter == 0) num = 0;
        if (iter == 1) num = max;
        uint64_t result = num;
        // mul is converted to uint64_t; the product wraps modulo 2^64.
        result *= mul;
        result >>= bits;
        if (result > max) return false;
        result >>= sh_post;
        if (num / den != result) return false;
    }
    return true;
}

// Method 2 (unsigned): check empirically whether num / den equals the
// "multiply-high, add back half the difference, then shift" sequence below
// for unsigned `bits`-bit numerators.
//
// Tested on 10^6 numerators from r(), with 0 and the maximum value always
// included. Returns false on the first mismatch, or if an intermediate
// value does not fit in `bits` bits.
bool u_method_2(int den, int64_t mul, int sh_post, int bits) {
    // 1ULL is at least 64 bits everywhere. 1UL is only 32 bits on LLP64
    // targets (e.g. 64-bit Windows), where 1UL << 32 is undefined.
    uint64_t max = (1ULL << bits) - 1;
    for (unsigned iter = 0; iter < 1000000UL; iter++) {
        uint64_t num = r(0, max);
        // Make sure we hit the extremes
        if (iter == 0) num = 0;
        if (iter == 1) num = max;
        uint64_t result = num;
        // mul is converted to uint64_t; the product wraps modulo 2^64.
        result *= mul;
        result >>= bits;
        if (result > max) return false;
        // Presumably written this way (rather than (num + result) >> 1) so
        // that a `bits`-wide implementation never needs a wider sum.
        result += (num - result)>>1;
        if (result > max) return false;
        result >>= sh_post;
        if (num / den != result) return false;
    }
    return true;
}

// Method 0 (signed): check empirically whether floor division sdiv(num, den)
// equals num >> sh_post for every signed `bits`-bit numerator.
//
// Tested on 10^6 numerators from r(), with the minimum and maximum values
// always included. Returns false on the first mismatch.
bool s_method_0(int den, int sh_post, int bits) {
    // 1LL is at least 64 bits everywhere. 1L is only 32 bits on LLP64
    // targets (e.g. 64-bit Windows), where 1L << 31 overflows.
    int64_t min = -(1LL << (bits-1)), max = (1LL << (bits-1))-1;
    for (int iter = 0; iter < 1000000L; iter++) {
        int64_t num = r(min, max);
        // Make sure we hit the extremes
        if (iter == 0) num = min;
        if (iter == 1) num = max;
        int64_t result = num;
        // Assumes >> on a negative int64_t is an arithmetic (flooring) shift.
        // This is implementation-defined before C++20.
        result >>= sh_post;
        if (sdiv(num, den) != result) return false;
    }
    return true;
}

// Method 1 (signed): check empirically whether floor division sdiv(num, den)
// equals the sign-folded multiply-high sequence below for signed
// `bits`-bit numerators.
//
// Tested on 10^6 numerators from r(), with the minimum and maximum values
// always included. Returns false on the first mismatch.
bool s_method_1(int den, int64_t mul, int sh_post, int bits) {
    // Use 1LL, not 1: with a plain int, 1 << 31 (bits == 32) and the
    // negation/subtraction that follow are signed overflow (undefined).
    int64_t min = -(1LL << (bits-1)), max = (1LL << (bits-1))-1;

    for (int iter = 0; iter < 1000000L; iter++) {
        int64_t num = r(min, max);
        // Make sure we hit the extremes
        if (iter == 0) num = min;
        if (iter == 1) num = max;
        int64_t result = num;
        // xsign is 0 for non-negative num and all ones for negative num.
        // This assumes >> on a negative int64_t is an arithmetic shift,
        // which is implementation-defined before C++20.
        uint64_t xsign = result >> (bits-1);
        // xsign ^ num is num when num >= 0 and ~num (= -num - 1) otherwise,
        // so the multiply always sees a non-negative operand.
        uint64_t q0 = (mul * (xsign ^ result)) >> bits;
        // Undo the sign folding on the shifted quotient.
        result = xsign ^ (q0 >> sh_post);
        if (sdiv(num, den) != result) return false;
    }
    return true;
}

// Write one row {divisor, method, multiplier, post-shift} of an unsigned
// table for divisor `den`, choosing the first method that verifies.
// Compile-time tables (runtime == false) try method 0 (plain shift), then
// method 1 (multiply-high + shift), then method 2. Runtime tables always use
// method 2. If nothing works, an "ERROR!" line is written in place of the row
// and false is returned.
static bool emit_unsigned_row(FILE *out, int den, int bits, bool runtime) {
    if (!runtime) {
        for (int shift = 0; shift < 16; shift++) {
            if (u_method_0(den, shift, bits)) {
                fprintf(out, "    {%d, 0, 0, %d},\n", den, shift);
                return true;
            }
        }

        for (int shift = 0; shift < 8; shift++) {
            // 1LL: bits + shift reaches 39, too wide for a 32-bit long.
            int64_t mul = (1LL << (bits + shift)) / den + 1;
            if (u_method_1(den, mul, shift, bits)) {
                fprintf(out, "    {%d, 1, %lldULL, %d},\n", den, (long long) mul, shift);
                return true;
            }
        }
    }

    for (int shift = 0; shift < 8; shift++) {
        int64_t mul = (1LL << (bits + shift + 1)) / den - (1LL << bits) + 1;
        if (u_method_2(den, mul, shift, bits)) {
            fprintf(out, "    {%d, 2, %lldULL, %d},\n", den, (long long) mul, shift);
            return true;
        }
    }
    fprintf(out, "ERROR! No solution found for unsigned %d\n", den);
    return false;
}

// Write one row {divisor, method, multiplier, post-shift} of a signed table
// for divisor `den`. Compile-time tables try method 0 (plain arithmetic
// shift) before method 1. Runtime tables always use method 1. If nothing
// works, an "ERROR!" line is written in place of the row and false is
// returned.
static bool emit_signed_row(FILE *out, int den, int bits, bool runtime) {
    if (!runtime) {
        for (int shift = 0; shift < 8; shift++) {
            if (s_method_0(den, shift, bits)) {
                fprintf(out, "    {%d, 0, 0, %d},\n", den, shift);
                return true;
            }
        }
    }

    for (int shift = 0; shift < 8; shift++) {
        // 1LL: shift + bits reaches 39, too wide for a 32-bit long.
        int64_t mul = (1LL << (shift + bits)) / den + 1;
        if (s_method_1(den, mul, shift, bits)) {
            fprintf(out, "    {%d, 1, %lldLL, %d},\n", den, (long long) mul, shift);
            return true;
        }
    }
    fprintf(out, "ERROR! No solution found for signed %d\n", den);
    return false;
}

// Declare (in h_out) and define (in c_out) the 256-row table named
// table[_runtime]_{u,s}<bits>. Row d describes divisor d, except row 0,
// which describes divisor 256. Returns the number of rows for which no
// method was found.
static int emit_table(FILE *c_out, FILE *h_out, bool is_signed, int bits, bool runtime) {
    const char *prefix = runtime ? "_runtime" : "";
    const char sign = is_signed ? 's' : 'u';

    printf("Generating table%s_%c%d...\n", prefix, sign, bits);
    fprintf(h_out, "extern const int64_t table%s_%c%d[256][4];\n", prefix, sign, bits);
    fprintf(c_out, "const int64_t table%s_%c%d[256][4] = {\n", prefix, sign, bits);

    int failures = 0;
    for (int d = 0; d < 256; d++) {
        int den = (d == 0) ? 256 : d;
        bool ok = is_signed ? emit_signed_row(c_out, den, bits, runtime)
                            : emit_unsigned_row(c_out, den, bits, runtime);
        if (!ok) failures++;
    }
    fprintf(c_out, "};\n");
    return failures;
}

int main(int argc, char **argv) {
    /* This program computes a table to help us do cheap integer
        division by a constant. It is based on the paper "Division by
        Invariant Integers using Multiplication" by Granlund and
        Montgomery.
    */

    FILE *c_out = fopen("IntegerDivisionTable.cpp", "w");
    if (!c_out) {
        perror("IntegerDivisionTable.cpp");
        return EXIT_FAILURE;
    }
    FILE *h_out = fopen("IntegerDivisionTable.h", "w");
    if (!h_out) {
        perror("IntegerDivisionTable.h");
        fclose(c_out);
        return EXIT_FAILURE;
    }

    fprintf(h_out, "%s",
        "#ifndef HALIDE_INTEGER_DIVISION_TABLE_H\n"
        "#define HALIDE_INTEGER_DIVISION_TABLE_H\n"
        "\n"
        "#include <cstdint>\n"
        "\n"
        "/** \\file\n"
        " * Tables telling us how to do integer division via fixed-point\n"
        " * multiplication for various small constants. This file is \n"
        " * automatically generated by find_inverse.cpp.\n"
        " */\n"
        "namespace Halide {\n"
        "namespace Internal {\n"
        "namespace IntegerDivision {\n");

    fprintf(c_out, "%s",
           "/** \\file\n"
           " * Tables telling us how to do integer division\n"
           " * via fixed-point multiplication for various small\n"
           " * constants. This file is automatically generated\n"
           " * by find_inverse.cpp. There are two sets of tables.\n"
           " * The first set is for compile-time-constant divisors\n"
           " * from 2 to 256. The second is for runtime divisors\n"
           " * from 1 to 255. The second set always uses the most\n"
           " * expensive method, while the compile-time set uses\n"
           " * the cheapest method for the given divisor.\n"
           " */\n"
           "\n"
           "#include \"IntegerDivisionTable.h\"\n"
           "\n"
           "namespace Halide {\n"
           "namespace Internal {\n"
           "namespace IntegerDivision {\n\n");

    int failures = 0;
    for (int runtime = 0; runtime < 2; runtime++) {
        for (int bits = 8; bits <= 32; bits *= 2) {
            failures += emit_table(c_out, h_out, false, bits, runtime != 0);
            failures += emit_table(c_out, h_out, true, bits, runtime != 0);
        }
    }

    fprintf(h_out, "}\n}\n}\n\n#endif\n");
    fprintf(c_out, "}\n}\n}\n");

    // fclose flushes buffered output, so a failure here means the generated
    // file may be incomplete.
    bool close_failed = false;
    if (fclose(h_out) != 0) {
        perror("IntegerDivisionTable.h");
        close_failed = true;
    }
    if (fclose(c_out) != 0) {
        perror("IntegerDivisionTable.cpp");
        close_failed = true;
    }
    if (close_failed) {
        return EXIT_FAILURE;
    }

    // The "ERROR!" rows make the generated .cpp fail to compile. Also report
    // the problem here, so a build step running this tool fails early.
    if (failures > 0) {
        fprintf(stderr, "%d divisor(s) had no solution; see ERROR! lines in IntegerDivisionTable.cpp\n", failures);
        return EXIT_FAILURE;
    }

    return 0;
}

/* [<][>][^][v][top][bottom][index][help] */