llvm/polly/lib/External/isl/imath/tests/gmp-compat-test/genctest.py

#!/usr/bin/env python
import sys
import gmpapi
from gmpapi import void
from gmpapi import ilong
from gmpapi import iint
from gmpapi import ulong
from gmpapi import mpz_t
from gmpapi import size_t
from gmpapi import charp
from gmpapi import mpq_t


class APITest(object):
    """Generate a C test driver for a single GMP-style API function.

    ``print_test_code`` emits ``void test_<name>(char *out, <params>)``.
    The generated function:

    * declares one local per API parameter and initializes it from the
      matching test parameter (mpz/mpq values arrive as base-10 strings);
    * calls the function under test;
    * appends the return value (if any) and every output parameter listed
      in ``api.out_params`` to ``out`` as space-separated text;
    * clears the mpz/mpq locals.

    Subclasses provide the library-specific spellings through
    ``api_call_prefix(kind)``, ``mpz_type()`` and ``mpq_type()``.
    """

    def __init__(self, gmpapi):
        # API description; this class reads its name, params, ret_ty and
        # out_params attributes.
        self.api = gmpapi

    def test_prefix(self):
        """Return the prefix used for generated test-function names."""
        return "test"

    def test_param_name(self, ty, i):
        """Return the C name of the test-function parameter at index i.

        Raises:
            RuntimeError: ty has no parameter-name mapping.
        """
        if ty == mpz_t:
            pname = "p_zs"
        elif ty == ilong:
            pname = "p_si"
        elif ty == ulong:
            pname = "p_ui"
        elif ty == iint:
            pname = "p_i"
        elif ty == charp:
            pname = "p_cs"
        elif ty == mpq_t:
            pname = "p_qs"
        else:
            raise RuntimeError("Unknown param type: " + str(ty))
        return pname + str(i)

    def test_param_type(self, ty):
        """Return the C type of a test-function parameter of type ty.

        mpz/mpq values are passed in as strings and parsed with set_str
        inside the generated code, hence ``char *``.
        """
        if ty == mpz_t or ty == mpq_t:
            return "char *"
        return str(ty)

    def test_var_name(self, ty, i):
        """Return the C name of the local holding parameter i.

        i may also be the suffix "_ret" to name the return-value local.

        Raises:
            RuntimeError: ty has no variable-name mapping.
        """
        if ty == mpz_t:
            vname = "v_z"
        elif ty == ilong:
            vname = "v_si"
        elif ty == ulong:
            vname = "v_ui"
        elif ty == iint:
            vname = "v_i"
        elif ty == size_t:
            vname = "v_st"
        elif ty == charp:
            vname = "v_cs"
        elif ty == mpq_t:
            vname = "v_q"
        else:
            raise RuntimeError("Unknown param type: " + str(ty))
        return vname + str(i)

    def test_var_type(self, ty):
        """Return the C type used to declare a local of type ty."""
        if ty == mpz_t:
            return self.mpz_type()
        elif ty == mpq_t:
            return self.mpq_type()
        else:
            return str(ty)

    def init_var_from_param(self, ty, var, param):
        """Return C code that initializes local ``var`` from ``param``.

        mpz/mpq locals are init'ed and parsed from base 10; mpq values are
        then canonicalized. Scalars are assigned directly. The returned code
        has no trailing semicolon; the caller appends it.
        """
        code = "\t"
        if ty == mpz_t or ty == mpq_t:
            code += self.api_call_prefix(ty) + "init(" + var + ");\n\t"
            code += (
                self.api_call_prefix(ty)
                + "set_str("
                + ",".join([var, param, "10"])
                + ")"
            )
            if ty == mpq_t:
                code += ";\n\t"
                code += self.api_call_prefix(ty) + "canonicalize(" + var + ")"
        else:
            code += var + "=" + param
        return code

    def init_vars_from_params(self):
        """Return C declarations plus initializations for every API param."""
        code = ""
        for (i, p) in enumerate(self.api.params):
            param = self.test_param_name(p, i)
            var = self.test_var_name(p, i)
            code += "\t" + self.test_var_type(p) + " " + var + ";\n"
            code += self.init_var_from_param(p, var, param) + ";\n\n"
        return code

    def make_api_call(self):
        """Return the C statement that calls the function under test.

        A non-void result is stored in a "_ret" local. Names starting with
        "mpz_" use the mpz call prefix; every other name uses the mpq prefix.
        """
        name = self.api.name
        # Strip only the leading family prefix; str.replace would also drop
        # an "mpz_"/"mpq_" appearing later in the name.
        bare_name = name
        for family in ("mpz_", "mpq_"):
            if bare_name.startswith(family):
                bare_name = bare_name[len(family):]
                break
        call_params = [
            self.test_var_name(p, i) for (i, p) in enumerate(self.api.params)
        ]
        ret = "\t"
        ret_ty = self.api.ret_ty
        if ret_ty != void:
            ret += (
                self.test_var_type(ret_ty)
                + " "
                + self.test_var_name(ret_ty, "_ret")
                + " = "
            )
        # call mpq or mpz function
        if name.startswith("mpz_"):
            prefix = self.api_call_prefix(mpz_t)
        else:
            prefix = self.api_call_prefix(mpq_t)
        return ret + prefix + bare_name + "(" + ",".join(call_params) + ");\n"

    def normalize_cmp(self, ty):
        """Return C code that clamps the "_ret" local of type ty to -1/0/1.

        Comparison functions only promise the sign of their result, so the
        magnitude is discarded before it is printed.
        """
        cmpval = self.test_var_name(ty, "_ret")
        template = (
            "\n"
            "\tif ({var} > 0)\n"
            "\t  {var} = 1;\n"
            "\telse if ({var} < 0)\n"
            "\t  {var} = -1;\n"
            "\t\n"
        )
        return template.format(var=cmpval)

    def extract_result(self, ty, pos):
        """Return C code that appends one result to the ``out`` buffer.

        ty:  type of the value to print.
        pos: index of the output parameter, or -1 for the return value.
             Only mpz/mpq values may be output parameters.

        The emitted code advances the C variable ``offset`` past the text it
        writes, so several results can be appended one after another.

        Raises:
            RuntimeError: ty is not a printable type.
        """
        code = ""
        if ty == mpz_t or ty == mpq_t:
            # The return value lives in the "_ret" local (see make_api_call),
            # not in a local named after index -1.
            var = self.test_var_name(ty, "_ret" if pos == -1 else pos)
            code += self.api_call_prefix(ty) + "get_str(out+offset, 10," + var + ");\n"
            # Measure only the text just written. strlen(out) would re-count
            # the bytes before out+offset and push offset past the NUL, so any
            # later result would be appended after the terminator and lost.
            code += "\toffset = offset + strlen(out+offset); "
            code += "out[offset] = ' '; out[offset+1] = 0; offset += 1;"
        else:
            assert pos == -1, "expected a return value, not a param value"
            formats = (
                (ilong, "%ld"),
                (ulong, "%lu"),
                (iint, "%d"),
                (size_t, "%zu"),
                (charp, "%s"),
            )
            for fmt_ty, conv in formats:
                if ty == fmt_ty:
                    break
            else:
                raise RuntimeError("Unknown param type: " + str(ty))
            var = self.test_var_name(ty, "_ret")
            # sprintf returns the number of characters written; accumulate it
            # so the code stays correct even if output precedes it.
            code += 'offset += sprintf(out+offset, " ' + conv + ' ", ' + var + ");"
        return code

    def extract_results(self):
        """Return C code that prints the return value and out params."""
        ret_ty = self.api.ret_ty
        code = "\tint offset = 0;\n\t"

        # normalize cmp return values
        if ret_ty == iint and "cmp" in self.api.name:
            code += self.normalize_cmp(ret_ty)

        # mpq_set_ui does not canonicalize its result, so do it before printing
        if self.api.name == "mpq_set_ui":
            code += (
                self.api_call_prefix(mpq_t)
                + "canonicalize("
                + self.test_var_name(mpq_t, 0)
                + ");\n\t"
            )

        # get return value
        if ret_ty != void:
            code += self.extract_result(ret_ty, -1) + "\n"

        # get out param values
        for pos in self.api.out_params:
            code += "\t"
            code += self.extract_result(self.api.params[pos], pos) + "\n"

        return code + "\n"

    def clear_local_vars(self):
        """Return C code that clears every mpz/mpq local."""
        code = ""
        for (i, p) in enumerate(self.api.params):
            if p == mpz_t or p == mpq_t:
                var = self.test_var_name(p, i)
                code += "\t" + self.api_call_prefix(p) + "clear(" + var + ");\n"
        return code

    def print_test_code(self, outf):
        """Write the complete C test function for self.api to outf."""
        api = self.api
        params = [
            self.test_param_type(p) + " " + self.test_param_name(p, i)
            for (i, p) in enumerate(api.params)
        ]
        code = "void {}_{}(char *out, {})".format(
            self.test_prefix(), api.name, ", ".join(params)
        )
        code += "{\n"
        code += self.init_vars_from_params()
        code += self.make_api_call()
        code += self.extract_results()
        code += self.clear_local_vars()
        code += "}\n"
        outf.write(code)
        outf.write("\n")


class GMPTest(APITest):
    """Test generator whose drivers call the GMP library directly.

    Generated code uses the ``mpz_``/``mpq_`` function prefixes and the
    ``mpz_t``/``mpq_t`` type names.
    """

    def __init__(self, gmpapi):
        super(GMPTest, self).__init__(gmpapi)

    def api_call_prefix(self, kind):
        """Map a value kind (mpz_t or mpq_t) to its function-name prefix.

        Raises:
            RuntimeError: kind is neither mpz_t nor mpq_t.
        """
        for family, prefix in ((mpz_t, "mpz_"), (mpq_t, "mpq_")):
            if kind == family:
                return prefix
        raise RuntimeError("Unknown call kind: " + str(kind))

    def mpz_type(self):
        """C type name used to declare integer locals."""
        return "mpz_t"

    def mpq_type(self):
        """C type name used to declare rational locals."""
        return "mpq_t"


class ImathTest(APITest):
    """Test generator whose drivers go through imath's gmp_compat layer.

    Generated code uses the ``impz_``/``impq_`` function prefixes and the
    ``impz_t``/``impq_t`` type names (typedef'd in the imath test header).
    """

    def __init__(self, gmpapi):
        super(ImathTest, self).__init__(gmpapi)

    def api_call_prefix(self, kind):
        """Map a value kind (mpz_t or mpq_t) to its function-name prefix.

        Raises:
            RuntimeError: kind is neither mpz_t nor mpq_t.
        """
        for family, prefix in ((mpz_t, "impz_"), (mpq_t, "impq_")):
            if kind == family:
                return prefix
        raise RuntimeError("Unknown call kind: " + str(kind))

    def mpz_type(self):
        """C type name used to declare integer locals."""
        return "impz_t"

    def mpq_type(self):
        """C type name used to declare rational locals."""
        return "impq_t"


def print_gmp_header(outf):
    """Write the C preamble of the GMP-backed test file to ``outf``."""
    preamble = (
        "#include <gmp.h>\n",
        "#include <stdio.h>\n",
        "#include <string.h>\n",
        '#include "gmp_custom_test.c"\n',
    )
    outf.write("".join(preamble))


def print_imath_header(outf):
    """Write the C preamble of the imath (gmp_compat) test file to ``outf``.

    The preamble typedefs impz_t/impq_t in terms of mpz_t/mpq_t
    (presumably declared by gmp_compat.h), so generated drivers can use
    the impz_/impq_ spellings.
    """
    preamble = (
        "#include <gmp_compat.h>\n",
        "#include <stdio.h>\n",
        "#include <string.h>\n",
        "typedef mpz_t impz_t[1];\n",
        "typedef mpq_t impq_t[1];\n",
        '#include "imath_custom_test.c"\n',
    )
    outf.write("".join(preamble))


def print_gmp_tests(outf):
    """Write the GMP test file: the header, then one generated driver per API.

    APIs flagged ``custom_test`` get no generated driver here.
    """
    print_gmp_header(outf)
    generated = (entry for entry in gmpapi.apis if not entry.custom_test)
    for entry in generated:
        GMPTest(entry).print_test_code(outf)


def print_imath_tests(outf):
    """Write the imath test file: the header, then one generated driver per API.

    APIs flagged ``custom_test`` get no generated driver here.
    """
    print_imath_header(outf)
    generated = (entry for entry in gmpapi.apis if not entry.custom_test)
    for entry in generated:
        ImathTest(entry).print_test_code(outf)


def main(argv=None):
    """Write the generated C test file for one backend to stdout.

    argv: command-line vector; defaults to sys.argv. argv[1] selects the
          backend and must be "gmp" or "imath".

    Exits with a usage message (non-zero status) when the backend is missing
    or unrecognized. Previously a missing argument raised IndexError and an
    unknown one silently produced no output.
    """
    if argv is None:
        argv = sys.argv
    prog = argv[0] if argv else "genctest.py"
    if len(argv) < 2 or argv[1] not in ("gmp", "imath"):
        sys.exit("usage: %s gmp|imath" % prog)

    if argv[1] == "gmp":
        print_gmp_tests(sys.stdout)
    else:
        print_imath_tests(sys.stdout)


if __name__ == "__main__":
    # Script entry point: generate the requested test file on stdout.
    main()