Skip to content
Snippets Groups Projects
Commit 0d12a2ac authored by Martin Bauer
Browse files

pystencils: support for approximate divisions and sqrt's (CUDA)

parent 61800b73
Branches
Tags
No related merge requests found
...@@ -3,6 +3,9 @@ from collections import namedtuple ...@@ -3,6 +3,9 @@ from collections import namedtuple
from sympy.core import S from sympy.core import S
from typing import Set from typing import Set
from sympy.printing.ccode import C89CodePrinter from sympy.printing.ccode import C89CodePrinter
from pystencils.fast_approximation import fast_division, fast_sqrt, fast_inv_sqrt
try: try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError: except ImportError:
...@@ -98,9 +101,9 @@ class CBackend: ...@@ -98,9 +101,9 @@ class CBackend:
signature_only=False, vector_instruction_set=None, dialect='c'): signature_only=False, vector_instruction_set=None, dialect='c'):
if sympy_printer is None: if sympy_printer is None:
if vector_instruction_set is not None: if vector_instruction_set is not None:
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set) self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, dialect)
else: else:
self.sympy_printer = CustomSympyPrinter() self.sympy_printer = CustomSympyPrinter(dialect)
else: else:
self.sympy_printer = sympy_printer self.sympy_printer = sympy_printer
...@@ -210,9 +213,10 @@ class CBackend: ...@@ -210,9 +213,10 @@ class CBackend:
# noinspection PyPep8Naming # noinspection PyPep8Naming
class CustomSympyPrinter(CCodePrinter): class CustomSympyPrinter(CCodePrinter):
def __init__(self): def __init__(self, dialect):
super(CustomSympyPrinter, self).__init__() super(CustomSympyPrinter, self).__init__()
self._float_type = create_type("float32") self._float_type = create_type("float32")
self._dialect = dialect
if 'Min' in self.known_functions: if 'Min' in self.known_functions:
del self.known_functions['Min'] del self.known_functions['Min']
if 'Max' in self.known_functions: if 'Max' in self.known_functions:
...@@ -259,7 +263,22 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -259,7 +263,22 @@ class CustomSympyPrinter(CCodePrinter):
if isinstance(arg, sp.Number): if isinstance(arg, sp.Number):
return self._typed_number(arg, data_type) return self._typed_number(arg, data_type)
else: else:
return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg)) return "((%s)(%s))" % (data_type, self._print(arg))
elif isinstance(expr, fast_division):
if self._dialect == "cuda":
return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
else:
return "({})".format(self._print(expr.args[0] / expr.args[1]))
elif isinstance(expr, fast_sqrt):
if self._dialect == "cuda":
return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
else:
return "({})".format(self._print(sp.sqrt(expr.args[0])))
elif isinstance(expr, fast_inv_sqrt):
if self._dialect == "cuda":
return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
else:
return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
elif expr.func in infix_functions: elif expr.func in infix_functions:
return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1])) return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1]))
else: else:
...@@ -285,8 +304,8 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -285,8 +304,8 @@ class CustomSympyPrinter(CCodePrinter):
class VectorizedCustomSympyPrinter(CustomSympyPrinter): class VectorizedCustomSympyPrinter(CustomSympyPrinter):
SummandInfo = namedtuple("SummandInfo", ['sign', 'term']) SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])
def __init__(self, instruction_set): def __init__(self, instruction_set, dialect):
super(VectorizedCustomSympyPrinter, self).__init__() super(VectorizedCustomSympyPrinter, self).__init__(dialect=dialect)
self.instruction_set = instruction_set self.instruction_set = instruction_set
def _scalarFallback(self, func_name, expr, *args, **kwargs): def _scalarFallback(self, func_name, expr, *args, **kwargs):
...@@ -306,7 +325,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -306,7 +325,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
arg, data_type = expr.args arg, data_type = expr.args
if type(data_type) is VectorType: if type(data_type) is VectorType:
return self.instruction_set['makeVec'].format(self._print(arg)) return self.instruction_set['makeVec'].format(self._print(arg))
elif expr.func == fast_division:
return self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]))
elif expr.func == fast_sqrt:
return "({})".format(self._print(sp.sqrt(expr.args[0])))
elif expr.func == fast_inv_sqrt:
return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
def _print_And(self, expr): def _print_And(self, expr):
......
...@@ -38,17 +38,18 @@ def show_code(ast: KernelFunction): ...@@ -38,17 +38,18 @@ def show_code(ast: KernelFunction):
Can either be displayed as HTML in Jupyter notebooks or printed as normal string. Can either be displayed as HTML in Jupyter notebooks or printed as normal string.
""" """
from pystencils.backends.cbackend import generate_c from pystencils.backends.cbackend import generate_c
dialect = 'cuda' if ast.backend == 'gpucuda' else 'c'
class CodeDisplay: class CodeDisplay:
def __init__(self, ast_input): def __init__(self, ast_input):
self.ast = ast_input self.ast = ast_input
def _repr_html_(self): def _repr_html_(self):
return highlight_cpp(generate_c(self.ast)).__html__() return highlight_cpp(generate_c(self.ast, dialect=dialect)).__html__()
def __str__(self): def __str__(self):
return generate_c(self.ast) return generate_c(self.ast, dialect=dialect)
def __repr__(self): def __repr__(self):
return generate_c(self.ast) return generate_c(self.ast, dialect=dialect)
return CodeDisplay(ast) return CodeDisplay(ast)
import sympy as sp
from typing import List, Union
from pystencils.astnodes import Node
from pystencils.simp import AssignmentCollection
# noinspection PyPep8Naming
class fast_division(sp.Function):
    """Symbolic marker for a division that may be evaluated approximately.

    Takes exactly two arguments: numerator and denominator, in that order.
    """
    nargs = (2,)
# noinspection PyPep8Naming
class fast_sqrt(sp.Function):
    """Symbolic marker for a square root that may be evaluated approximately.

    Takes exactly one argument: the radicand.
    """
    nargs = (1,)
# noinspection PyPep8Naming
class fast_inv_sqrt(sp.Function):
    """Symbolic marker for a reciprocal square root that may be approximated.

    Takes exactly one argument ``x`` and stands for ``1 / sqrt(x)``.
    """
    nargs = (1,)
def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
    """Replace square roots in ``term`` by ``fast_sqrt`` / ``fast_inv_sqrt``.

    Every power whose exponent is a rational number with denominator 2 is
    rewritten: ``x**(p/2)`` becomes ``fast_sqrt(x)**p`` for positive ``p`` and
    ``fast_inv_sqrt(x)**(-p)`` for negative ``p``. The base ``x`` is processed
    recursively as well, so nested roots are replaced at every level.

    Args:
        term: a single expression, a list of expressions, or an
            ``AssignmentCollection``.

    Returns:
        The transformed counterpart of ``term``: an expression for an
        expression, a list for a list, and for an ``AssignmentCollection``
        the result of ``term.copy`` called with the transformed main
        assignments and subexpressions.
    """
    def visit(expr):
        # pystencils AST nodes are passed through untouched.
        if isinstance(expr, Node):
            return expr
        if expr.func == sp.Pow and isinstance(expr.exp, sp.Rational) and expr.exp.q == 2:
            power = expr.exp.p
            # Recurse into the base first; otherwise a root nested inside
            # another root (e.g. sqrt(1 + sqrt(x))) would be left as it is.
            base = visit(expr.base)
            if power < 0:
                return fast_inv_sqrt(base) ** (-power)
            return fast_sqrt(base) ** power
        new_args = [visit(a) for a in expr.args]
        # Atoms have no args and cannot be rebuilt via func(*args).
        return expr.func(*new_args) if new_args else expr

    if isinstance(term, AssignmentCollection):
        new_main_assignments = insert_fast_sqrts(term.main_assignments)
        new_subexpressions = insert_fast_sqrts(term.subexpressions)
        return term.copy(new_main_assignments, new_subexpressions)
    elif isinstance(term, list):
        return [insert_fast_sqrts(e) for e in term]
    else:
        return visit(term)
def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
    """Replace divisions in ``term`` by ``fast_division`` nodes.

    Divisions are recognised as powers with a negative integer exponent.
    Inside a product, all such factors are collected into one denominator, so
    the whole product becomes a single ``fast_division(numerator, denominator)``.
    A stand-alone negative integer power ``x**(-n)`` becomes
    ``fast_division(1, x**n)``.

    Args:
        term: a single expression, a list of expressions, or an
            ``AssignmentCollection``.

    Returns:
        The transformed counterpart of ``term``: an expression for an
        expression, a list for a list, and for an ``AssignmentCollection``
        the result of ``term.copy`` called with the transformed main
        assignments and subexpressions.
    """
    def is_division_power(expr):
        # Use the assumption system rather than ``expr.exp < 0``: for a
        # symbolic integer exponent of unknown sign the comparison yields an
        # unevaluated relational whose truth value raises TypeError.
        return expr.func == sp.Pow and expr.exp.is_integer and expr.exp.is_negative

    def visit(expr):
        # pystencils AST nodes are passed through untouched.
        if isinstance(expr, Node):
            return expr
        if expr.func == sp.Mul:
            div_args = []
            other_args = []
            for a in expr.args:
                if is_division_power(a):
                    # Store the factor with positive exponent; it goes into
                    # the denominator of the fast division.
                    div_args.append(visit(a.base) ** (-a.exp))
                else:
                    other_args.append(visit(a))
            if div_args:
                return fast_division(sp.Mul(*other_args), sp.Mul(*div_args))
            else:
                return sp.Mul(*other_args)
        elif is_division_power(expr):
            return fast_division(1, visit(expr.base) ** (-expr.exp))
        else:
            new_args = [visit(a) for a in expr.args]
            # Atoms have no args and cannot be rebuilt via func(*args).
            return expr.func(*new_args) if new_args else expr

    if isinstance(term, AssignmentCollection):
        new_main_assignments = insert_fast_divisions(term.main_assignments)
        new_subexpressions = insert_fast_divisions(term.subexpressions)
        return term.copy(new_main_assignments, new_subexpressions)
    elif isinstance(term, list):
        return [insert_fast_divisions(e) for e in term]
    else:
        return visit(term)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment