Commit 0d12a2ac authored by Martin Bauer's avatar Martin Bauer
Browse files

pystencils: support for approximate divisions and sqrt's (CUDA)

parent 61800b73
......@@ -3,6 +3,9 @@ from collections import namedtuple
from sympy.core import S
from typing import Set
from sympy.printing.ccode import C89CodePrinter
from pystencils.fast_approximation import fast_division, fast_sqrt, fast_inv_sqrt
try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError:
......@@ -98,9 +101,9 @@ class CBackend:
signature_only=False, vector_instruction_set=None, dialect='c'):
if sympy_printer is None:
if vector_instruction_set is not None:
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, dialect)
else:
self.sympy_printer = CustomSympyPrinter()
self.sympy_printer = CustomSympyPrinter(dialect)
else:
self.sympy_printer = sympy_printer
......@@ -210,9 +213,10 @@ class CBackend:
# noinspection PyPep8Naming
class CustomSympyPrinter(CCodePrinter):
def __init__(self):
def __init__(self, dialect):
super(CustomSympyPrinter, self).__init__()
self._float_type = create_type("float32")
self._dialect = dialect
if 'Min' in self.known_functions:
del self.known_functions['Min']
if 'Max' in self.known_functions:
......@@ -259,7 +263,22 @@ class CustomSympyPrinter(CCodePrinter):
if isinstance(arg, sp.Number):
return self._typed_number(arg, data_type)
else:
return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg))
return "((%s)(%s))" % (data_type, self._print(arg))
elif isinstance(expr, fast_division):
if self._dialect == "cuda":
return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
else:
return "({})".format(self._print(expr.args[0] / expr.args[1]))
elif isinstance(expr, fast_sqrt):
if self._dialect == "cuda":
return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
else:
return "({})".format(self._print(sp.sqrt(expr.args[0])))
elif isinstance(expr, fast_inv_sqrt):
if self._dialect == "cuda":
return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
else:
return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
elif expr.func in infix_functions:
return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1]))
else:
......@@ -285,8 +304,8 @@ class CustomSympyPrinter(CCodePrinter):
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])
def __init__(self, instruction_set):
super(VectorizedCustomSympyPrinter, self).__init__()
def __init__(self, instruction_set, dialect):
super(VectorizedCustomSympyPrinter, self).__init__(dialect=dialect)
self.instruction_set = instruction_set
def _scalarFallback(self, func_name, expr, *args, **kwargs):
......@@ -306,7 +325,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
arg, data_type = expr.args
if type(data_type) is VectorType:
return self.instruction_set['makeVec'].format(self._print(arg))
elif expr.func == fast_division:
return self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]))
elif expr.func == fast_sqrt:
return "({})".format(self._print(sp.sqrt(expr.args[0])))
elif expr.func == fast_inv_sqrt:
return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
def _print_And(self, expr):
......
......@@ -38,17 +38,18 @@ def show_code(ast: KernelFunction):
Can either be displayed as HTML in Jupyter notebooks or printed as normal string.
"""
from pystencils.backends.cbackend import generate_c
dialect = 'cuda' if ast.backend == 'gpucuda' else 'c'
class CodeDisplay:
def __init__(self, ast_input):
self.ast = ast_input
def _repr_html_(self):
return highlight_cpp(generate_c(self.ast)).__html__()
return highlight_cpp(generate_c(self.ast, dialect=dialect)).__html__()
def __str__(self):
return generate_c(self.ast)
return generate_c(self.ast, dialect=dialect)
def __repr__(self):
return generate_c(self.ast)
return generate_c(self.ast, dialect=dialect)
return CodeDisplay(ast)
import sympy as sp
from typing import List, Union
from pystencils.astnodes import Node
from pystencils.simp import AssignmentCollection
# noinspection PyPep8Naming
class fast_division(sp.Function):
nargs = (2,)
# noinspection PyPep8Naming
class fast_sqrt(sp.Function):
nargs = (1, )
# noinspection PyPep8Naming
class fast_inv_sqrt(sp.Function):
nargs = (1, )
def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
def visit(expr):
if isinstance(expr, Node):
return expr
if expr.func == sp.Pow and isinstance(expr.exp, sp.Rational) and expr.exp.q == 2:
power = expr.exp.p
if power < 0:
return fast_inv_sqrt(expr.args[0]) ** (-power)
else:
return fast_sqrt(expr.args[0]) ** power
else:
new_args = [visit(a) for a in expr.args]
return expr.func(*new_args) if new_args else expr
if isinstance(term, AssignmentCollection):
new_main_assignments = insert_fast_sqrts(term.main_assignments)
new_subexpressions = insert_fast_sqrts(term.subexpressions)
return term.copy(new_main_assignments, new_subexpressions)
elif isinstance(term, list):
return [insert_fast_sqrts(e) for e in term]
else:
return visit(term)
def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
def visit(expr):
if isinstance(expr, Node):
return expr
if expr.func == sp.Mul:
div_args = []
other_args = []
for a in expr.args:
if a.func == sp.Pow and a.exp.is_integer and a.exp < 0:
div_args.append(visit(a.base) ** (-a.exp))
else:
other_args.append(visit(a))
if div_args:
return fast_division(sp.Mul(*other_args), sp.Mul(*div_args))
else:
return sp.Mul(*other_args)
elif expr.func == sp.Pow and expr.exp.is_integer and expr.exp < 0:
return fast_division(1, visit(expr.base) ** (-expr.exp))
else:
new_args = [visit(a) for a in expr.args]
return expr.func(*new_args) if new_args else expr
if isinstance(term, AssignmentCollection):
new_main_assignments = insert_fast_divisions(term.main_assignments)
new_subexpressions = insert_fast_divisions(term.subexpressions)
return term.copy(new_main_assignments, new_subexpressions)
elif isinstance(term, list):
return [insert_fast_divisions(e) for e in term]
else:
return visit(term)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment