import sympy as sp from collections import namedtuple from sympy.core import S from typing import Set from sympy.printing.ccode import C89CodePrinter from pystencils.fast_approximation import fast_division, fast_sqrt, fast_inv_sqrt try: from sympy.printing.ccode import C99CodePrinter as CCodePrinter except ImportError: from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1 from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \ bitwise_or, modulo_ceil from pystencils.astnodes import Node, KernelFunction from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \ vector_memory_access, reinterpret_cast_func __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str: """Prints an abstract syntax tree node as C or CUDA code. This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different create_kernel functions. Args: ast_node: signature_only: dialect: 'c' or 'cuda' Returns: C-like code for the ast node and its descendants """ printer = CBackend(signature_only=signature_only, vector_instruction_set=ast_node.instruction_set, dialect=dialect) return printer(ast_node) def get_headers(ast_node: Node) -> Set[str]: """Return a set of header files, necessary to compile the printed C-like code.""" headers = set() if isinstance(ast_node, KernelFunction) and ast_node.instruction_set: headers.update(ast_node.instruction_set['headers']) if hasattr(ast_node, 'headers'): headers.update(ast_node.headers) for a in ast_node.args: if isinstance(a, Node): headers.update(get_headers(a)) return headers # --------------------------------------- Backend Specific Nodes ------------------------------------------------------- class CustomCodeNode(Node): def __init__(self, code, symbols_read, symbols_defined, parent=None): super(CustomCodeNode, self).__init__(parent=parent) self._code = "\n" + code self._symbolsRead = set(symbols_read) self._symbolsDefined = set(symbols_defined) self.headers = [] def get_code(self, dialect, vector_instruction_set): return self._code @property def args(self): return [] @property def symbols_defined(self): return self._symbolsDefined @property def undefined_symbols(self): return self.symbols_defined - self._symbolsRead class PrintNode(CustomCodeNode): # noinspection SpellCheckingInspection def __init__(self, symbol_to_print): code = '\nstd::cout << "%s = " << %s << std::endl; \n' % (symbol_to_print.name, symbol_to_print.name) super(PrintNode, self).__init__(code, symbols_read=[symbol_to_print], symbols_defined=set()) self.headers.append("<iostream>") # ------------------------------------------- Printer ------------------------------------------------------------------ # noinspection PyPep8Naming class CBackend: def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'): if sympy_printer is None: if vector_instruction_set is not None: self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, dialect) else: self.sympy_printer = CustomSympyPrinter(dialect) else: self.sympy_printer = sympy_printer self._vector_instruction_set = vector_instruction_set self._indent = " " self._dialect = dialect self._signatureOnly = signature_only def __call__(self, node): prev_is = VectorType.instruction_set VectorType.instruction_set = self._vector_instruction_set result = str(self._print(node)) VectorType.instruction_set = prev_is return result def _print(self, node): for cls in type(node).__mro__: method_name = "_print_" + cls.__name__ if hasattr(self, method_name): return getattr(self, method_name)(node) raise NotImplementedError("CBackend does not support node of type " + str(type(node))) def _print_KernelFunction(self, node): function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()] func_declaration = "FUNC_PREFIX void %s(%s)" % (node.function_name, ", ".join(function_arguments)) if self._signatureOnly: return func_declaration body = self._print(node.body) return func_declaration + "\n" + body def _print_Block(self, node): block_contents = "\n".join([self._print(child) for child in node.args]) return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True))) def _print_PragmaBlock(self, node): return "%s\n%s" % (node.pragma_line, self._print_Block(node)) def _print_LoopOverCoordinate(self, node): counter_symbol = node.loop_counter_name start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start)) condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop)) update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),) loop_str = "for (%s; %s; %s)" % (start, condition, update) prefix = "\n".join(node.prefix_lines) if prefix: prefix += "\n" return "%s%s\n%s" % (prefix, loop_str, self._print(node.body)) def _print_SympyAssignment(self, node): if node.is_declaration: data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " " return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs)) else: lhs_type = get_type_of_expression(node.lhs) if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func): arg, data_type, aligned, nontemporal = node.lhs.args instr = 'storeU' if aligned: instr = 'stream' if nontemporal else 'storeA' rhs_type = get_type_of_expression(node.rhs) if type(rhs_type) is not VectorType: rhs = cast_func(node.rhs, VectorType(rhs_type)) else: rhs = node.rhs return self._vector_instruction_set[instr].format("&" + self.sympy_printer.doprint(node.lhs.args[0]), self.sympy_printer.doprint(rhs)) + ';' else: return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs)) def _print_TemporaryMemoryAllocation(self, node): align = 64 np_dtype = node.symbol.dtype.base_type.numpy_dtype required_size = np_dtype.itemsize * node.size + align size = modulo_ceil(required_size, align) code = "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};" return code.format(dtype=node.symbol.dtype, name=self.sympy_printer.doprint(node.symbol.name), size=self.sympy_printer.doprint(size), offset=int(node.offset(align)), align=align) def _print_TemporaryMemoryFree(self, node): align = 64 return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align)) def _print_CustomCodeNode(self, node): return node.get_code(self._dialect, self._vector_instruction_set) def _print_Conditional(self, node): condition_expr = self.sympy_printer.doprint(node.condition_expr) true_block = self._print_Block(node.true_block) result = "if (%s)\n%s " % (condition_expr, true_block) if node.false_block: false_block = self._print_Block(node.false_block) result += "else " + false_block return result # ------------------------------------------ Helper function & classes ------------------------------------------------- # noinspection PyPep8Naming class CustomSympyPrinter(CCodePrinter): def __init__(self, dialect): super(CustomSympyPrinter, self).__init__() self._float_type = create_type("float32") self._dialect = dialect if 'Min' in self.known_functions: del self.known_functions['Min'] if 'Max' in self.known_functions: del self.known_functions['Max'] def _print_Pow(self, expr): """Don't use std::pow function, for small integer exponents, write as multiplication""" if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: return "1 / ({})".format(self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False))) else: return super(CustomSympyPrinter, self)._print_Pow(expr) def _print_Rational(self, expr): """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0""" res = str(expr.evalf().num) return res def _print_Equality(self, expr): """Equality operator is not printable in default printer""" return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))' def _print_Piecewise(self, expr): """Print piecewise in one line (remove newlines)""" result = super(CustomSympyPrinter, self)._print_Piecewise(expr) return result.replace("\n", "") def _print_Function(self, expr): infix_functions = { bitwise_xor: '^', bit_shift_right: '>>', bit_shift_left: '<<', bitwise_or: '|', bitwise_and: '&', } if hasattr(expr, 'to_c'): return expr.to_c(self._print) if isinstance(expr, reinterpret_cast_func): arg, data_type = expr.args return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg)) elif isinstance(expr, cast_func): arg, data_type = expr.args if isinstance(arg, sp.Number): return self._typed_number(arg, data_type) else: return "((%s)(%s))" % (data_type, self._print(arg)) elif isinstance(expr, fast_division): if self._dialect == "cuda": return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args) else: return "({})".format(self._print(expr.args[0] / expr.args[1])) elif isinstance(expr, fast_sqrt): if self._dialect == "cuda": return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) else: return "({})".format(self._print(sp.sqrt(expr.args[0]))) elif isinstance(expr, fast_inv_sqrt): if self._dialect == "cuda": return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) else: return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) elif expr.func in infix_functions: return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1])) else: return super(CustomSympyPrinter, self)._print_Function(expr) def _typed_number(self, number, dtype): res = self._print(number) if dtype.is_float(): if dtype == self._float_type: if '.' not in res: res += ".0f" else: res += "f" return res else: return res _print_Max = C89CodePrinter._print_Max _print_Min = C89CodePrinter._print_Min # noinspection PyPep8Naming class VectorizedCustomSympyPrinter(CustomSympyPrinter): SummandInfo = namedtuple("SummandInfo", ['sign', 'term']) def __init__(self, instruction_set, dialect): super(VectorizedCustomSympyPrinter, self).__init__(dialect=dialect) self.instruction_set = instruction_set def _scalarFallback(self, func_name, expr, *args, **kwargs): expr_type = get_type_of_expression(expr) if type(expr_type) is not VectorType: return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs) else: assert self.instruction_set['width'] == expr_type.width return None def _print_Function(self, expr): if isinstance(expr, vector_memory_access): arg, data_type, aligned, _ = expr.args instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] return instruction.format("& " + self._print(arg)) elif isinstance(expr, cast_func): arg, data_type = expr.args if type(data_type) is VectorType: return self.instruction_set['makeVec'].format(self._print(arg)) elif expr.func == fast_division: return self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1])) elif expr.func == fast_sqrt: return "({})".format(self._print(sp.sqrt(expr.args[0]))) elif expr.func == fast_inv_sqrt: if self.instruction_set['rsqrt']: return self.instruction_set['rsqrt'].format(self._print(expr.args[0])) else: return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) def _print_And(self, expr): result = self._scalarFallback('_print_And', expr) if result: return result arg_strings = [self._print(a) for a in expr.args] assert len(arg_strings) > 0 result = arg_strings[0] for item in arg_strings[1:]: result = self.instruction_set['&'].format(result, item) return result def _print_Or(self, expr): result = self._scalarFallback('_print_Or', expr) if result: return result arg_strings = [self._print(a) for a in expr.args] assert len(arg_strings) > 0 result = arg_strings[0] for item in arg_strings[1:]: result = self.instruction_set['|'].format(result, item) return result def _print_Add(self, expr, order=None): result = self._scalarFallback('_print_Add', expr) if result: return result summands = [] for term in expr.args: if term.func == sp.Mul: sign, t = self._print_Mul(term, inside_add=True) else: t = self._print(term) sign = 1 summands.append(self.SummandInfo(sign, t)) # Use positive terms first summands.sort(key=lambda e: e.sign, reverse=True) # if no positive term exists, prepend a zero if summands[0].sign == -1: summands.insert(0, self.SummandInfo(1, "0")) assert len(summands) >= 2 processed = summands[0].term for summand in summands[1:]: func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+'] processed = func.format(processed, summand.term) return processed def _print_Pow(self, expr): result = self._scalarFallback('_print_Pow', expr) if result: return result if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" elif expr.exp == -1: one = self.instruction_set['makeVec'].format(1.0) return self.instruction_set['/'].format(one, self._print(expr.base)) elif expr.exp == 0.5: return self.instruction_set['sqrt'].format(self._print(expr.base)) elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: one = self.instruction_set['makeVec'].format(1.0) return self.instruction_set['/'].format(one, self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False))) else: raise ValueError("Generic exponential not supported: " + str(expr)) def _print_Mul(self, expr, inside_add=False): # noinspection PyProtectedMember from sympy.core.mul import _keep_coeff result = self._scalarFallback('_print_Mul', expr) if result: return result c, e = expr.as_coeff_Mul() if c < 0: expr = _keep_coeff(-c, e) sign = -1 else: sign = 1 a = [] # items in the numerator b = [] # items that are in the denominator (if any) # Gather args for numerator/denominator for item in expr.as_ordered_factors(): if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative: if item.exp != -1: b.append(sp.Pow(item.base, -item.exp, evaluate=False)) else: b.append(sp.Pow(item.base, -item.exp)) else: a.append(item) a = a or [S.One] a_str = [self._print(x) for x in a] b_str = [self._print(x) for x in b] result = a_str[0] for item in a_str[1:]: result = self.instruction_set['*'].format(result, item) if len(b) > 0: denominator_str = b_str[0] for item in b_str[1:]: denominator_str = self.instruction_set['*'].format(denominator_str, item) result = self.instruction_set['/'].format(result, denominator_str) if inside_add: return sign, result else: if sign < 0: return self.instruction_set['*'].format(self._print(S.NegativeOne), result) else: return result def _print_Relational(self, expr): result = self._scalarFallback('_print_Relational', expr) if result: return result return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs)) def _print_Equality(self, expr): result = self._scalarFallback('_print_Equality', expr) if result: return result return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs)) def _print_Piecewise(self, expr): result = self._scalarFallback('_print_Piecewise', expr) if result: return result if expr.args[-1].cond.args[0] is not sp.sympify(True): # We need the last conditional to be a True, otherwise the resulting # function may not return a result. raise ValueError("All Piecewise expressions must contain an " "(expr, True) statement to be used as a default " "condition. Without one, the generated " "expression may not evaluate to anything under " "some condition.") result = self._print(expr.args[-1][0]) for true_expr, condition in reversed(expr.args[:-1]): # noinspection SpellCheckingInspection result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition)) return result