diff --git a/backends/cbackend.py b/backends/cbackend.py index 54ef38149199d844d25561e7878d281875fc8ad1..7dd75d0a6f37b7051b65b40d5facc0f7d7cfad44 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -3,6 +3,9 @@ from collections import namedtuple from sympy.core import S from typing import Set from sympy.printing.ccode import C89CodePrinter + +from pystencils.fast_approximation import fast_division, fast_sqrt, fast_inv_sqrt + try: from sympy.printing.ccode import C99CodePrinter as CCodePrinter except ImportError: @@ -98,9 +101,9 @@ class CBackend: signature_only=False, vector_instruction_set=None, dialect='c'): if sympy_printer is None: if vector_instruction_set is not None: - self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set) + self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, dialect) else: - self.sympy_printer = CustomSympyPrinter() + self.sympy_printer = CustomSympyPrinter(dialect) else: self.sympy_printer = sympy_printer @@ -210,9 +213,10 @@ class CBackend: # noinspection PyPep8Naming class CustomSympyPrinter(CCodePrinter): - def __init__(self): + def __init__(self, dialect): super(CustomSympyPrinter, self).__init__() self._float_type = create_type("float32") + self._dialect = dialect if 'Min' in self.known_functions: del self.known_functions['Min'] if 'Max' in self.known_functions: @@ -259,7 +263,22 @@ class CustomSympyPrinter(CCodePrinter): if isinstance(arg, sp.Number): return self._typed_number(arg, data_type) else: - return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg)) + return "((%s)(%s))" % (data_type, self._print(arg)) + elif isinstance(expr, fast_division): + if self._dialect == "cuda": + return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args) + else: + return "({})".format(self._print(expr.args[0] / expr.args[1])) + elif isinstance(expr, fast_sqrt): + if self._dialect == "cuda": + return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) + 
else: + return "({})".format(self._print(sp.sqrt(expr.args[0]))) + elif isinstance(expr, fast_inv_sqrt): + if self._dialect == "cuda": + return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) + else: + return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) elif expr.func in infix_functions: return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1])) else: @@ -285,8 +304,8 @@ class CustomSympyPrinter(CCodePrinter): class VectorizedCustomSympyPrinter(CustomSympyPrinter): SummandInfo = namedtuple("SummandInfo", ['sign', 'term']) - def __init__(self, instruction_set): - super(VectorizedCustomSympyPrinter, self).__init__() + def __init__(self, instruction_set, dialect): + super(VectorizedCustomSympyPrinter, self).__init__(dialect=dialect) self.instruction_set = instruction_set def _scalarFallback(self, func_name, expr, *args, **kwargs): @@ -306,7 +325,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): arg, data_type = expr.args if type(data_type) is VectorType: return self.instruction_set['makeVec'].format(self._print(arg)) - + elif expr.func == fast_division: + return self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1])) + elif expr.func == fast_sqrt: + return "({})".format(self._print(sp.sqrt(expr.args[0]))) + elif expr.func == fast_inv_sqrt: + return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) def _print_And(self, expr): diff --git a/display_utils.py b/display_utils.py index 3607e45b99bd1f6cc81415a05ad300ef344672fc..55a4720c194e250cd549279ae76607c48968c676 100644 --- a/display_utils.py +++ b/display_utils.py @@ -38,17 +38,18 @@ def show_code(ast: KernelFunction): Can either be displayed as HTML in Jupyter notebooks or printed as normal string. 
""" from pystencils.backends.cbackend import generate_c + dialect = 'cuda' if ast.backend == 'gpucuda' else 'c' class CodeDisplay: def __init__(self, ast_input): self.ast = ast_input def _repr_html_(self): - return highlight_cpp(generate_c(self.ast)).__html__() + return highlight_cpp(generate_c(self.ast, dialect=dialect)).__html__() def __str__(self): - return generate_c(self.ast) + return generate_c(self.ast, dialect=dialect) def __repr__(self): - return generate_c(self.ast) + return generate_c(self.ast, dialect=dialect) return CodeDisplay(ast) diff --git a/fast_approximation.py b/fast_approximation.py new file mode 100644 index 0000000000000000000000000000000000000000..538e493f9bebcafa1eb788050758a4c93b556332 --- /dev/null +++ b/fast_approximation.py @@ -0,0 +1,75 @@ +import sympy as sp +from typing import List, Union + +from pystencils.astnodes import Node +from pystencils.simp import AssignmentCollection + + +# noinspection PyPep8Naming +class fast_division(sp.Function): + nargs = (2,) + +# noinspection PyPep8Naming +class fast_sqrt(sp.Function): + nargs = (1, ) + +# noinspection PyPep8Naming +class fast_inv_sqrt(sp.Function): + nargs = (1, ) + + +def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]): + def visit(expr): + if isinstance(expr, Node): + return expr + if expr.func == sp.Pow and isinstance(expr.exp, sp.Rational) and expr.exp.q == 2: + power = expr.exp.p + if power < 0: + return fast_inv_sqrt(expr.args[0]) ** (-power) + else: + return fast_sqrt(expr.args[0]) ** power + else: + new_args = [visit(a) for a in expr.args] + return expr.func(*new_args) if new_args else expr + + if isinstance(term, AssignmentCollection): + new_main_assignments = insert_fast_sqrts(term.main_assignments) + new_subexpressions = insert_fast_sqrts(term.subexpressions) + return term.copy(new_main_assignments, new_subexpressions) + elif isinstance(term, list): + return [insert_fast_sqrts(e) for e in term] + else: + return visit(term) + + +def 
def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
    """Replace divisions (negative integer powers) by ``fast_division`` nodes.

    Inside a product, all factors of the form ``b**(-n)`` with integer
    ``n > 0`` are collected into a single denominator, so that
    ``a * b**(-1) * c**(-2)`` becomes ``fast_division(a, b * c**2)``.
    A stand-alone ``b**(-n)`` becomes ``fast_division(1, b**n)``.

    Args:
        term: a single expression, a list of expressions (handled
            element-wise), or an ``AssignmentCollection`` (main assignments
            and subexpressions are both processed).

    Returns:
        The transformed object of the same kind as ``term``; for an
        ``AssignmentCollection`` the result of ``term.copy(...)`` with the
        rewritten main assignments and subexpressions.
    """

    def is_negative_int_power(expr):
        # Use the assumption query ``is_negative`` rather than ``exp < 0``:
        # for a symbolic integer exponent of unknown sign the comparison
        # yields a relational whose truth value cannot be determined
        # (TypeError), while ``is_negative`` simply reports None.
        return bool(expr.func == sp.Pow and expr.exp.is_integer and expr.exp.is_negative)

    def visit(expr):
        # pystencils AST nodes are passed through unchanged
        if isinstance(expr, Node):
            return expr
        if expr.func == sp.Mul:
            numerator = []
            denominator = []
            for a in expr.args:
                if is_negative_int_power(a):
                    denominator.append(visit(a.base) ** (-a.exp))
                else:
                    numerator.append(visit(a))
            if denominator:
                return fast_division(sp.Mul(*numerator), sp.Mul(*denominator))
            return sp.Mul(*numerator)
        if is_negative_int_power(expr):
            return fast_division(1, visit(expr.base) ** (-expr.exp))
        new_args = [visit(a) for a in expr.args]
        # atoms have no args and cannot be rebuilt via func(*args)
        return expr.func(*new_args) if new_args else expr

    if isinstance(term, AssignmentCollection):
        new_main_assignments = insert_fast_divisions(term.main_assignments)
        new_subexpressions = insert_fast_divisions(term.subexpressions)
        return term.copy(new_main_assignments, new_subexpressions)
    if isinstance(term, list):
        return [insert_fast_divisions(e) for e in term]
    return visit(term)