diff --git a/backends/cbackend.py b/backends/cbackend.py index e5d9730ff83aeb46de8b0ecc840b063119946275..84efdb9ffe413aa9c8623a45aeb7c011359eba14 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -10,14 +10,14 @@ except ImportError: from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \ bitwise_or, modulo_floor, modulo_ceil -from pystencils.astnodes import Node, ResolvedFieldAccess, KernelFunction +from pystencils.astnodes import Node, KernelFunction from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \ vector_memory_access __all__ = ['generate_c', 'CustomCppCode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] -def generate_c(ast_node: Node, signature_only: bool = False, use_float_constants: Optional[bool] = None) -> str: +def generate_c(ast_node: Node, signature_only: bool = False) -> str: """Prints an abstract syntax tree node as C or CUDA code. This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded @@ -27,17 +27,11 @@ def generate_c(ast_node: Node, signature_only: bool = False, use_float_constants Args: ast_node: signature_only: - use_float_constants: Returns: C-like code for the ast node and its descendants """ - if use_float_constants is None: - field_types = set(o.field.dtype for o in ast_node.atoms(ResolvedFieldAccess)) - double = create_type('double') - use_float_constants = double not in field_types - - printer = CBackend(constants_as_floats=use_float_constants, signature_only=signature_only, + printer = CBackend(signature_only=signature_only, vector_instruction_set=ast_node.instruction_set) return printer(ast_node) @@ -100,13 +94,13 @@ class PrintNode(CustomCppCode): # noinspection PyPep8Naming class CBackend: - def __init__(self, constants_as_floats=False, sympy_printer=None, + def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None): if sympy_printer is None: if vector_instruction_set is not None: - self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, constants_as_floats) + self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set) else: - self.sympy_printer = CustomSympyPrinter(constants_as_floats) + self.sympy_printer = CustomSympyPrinter() else: self.sympy_printer = sympy_printer @@ -211,8 +205,7 @@ class CBackend: # noinspection PyPep8Naming class CustomSympyPrinter(CCodePrinter): - def __init__(self, constants_as_floats=False): - self._constantsAsFloats = constants_as_floats + def __init__(self): super(CustomSympyPrinter, self).__init__() self._float_type = create_type("float32") @@ -280,8 +273,8 @@ class CustomSympyPrinter(CCodePrinter): class VectorizedCustomSympyPrinter(CustomSympyPrinter): SummandInfo = namedtuple("SummandInfo", ['sign', 'term']) - def __init__(self, instruction_set, constants_as_floats=False): - super(VectorizedCustomSympyPrinter, self).__init__(constants_as_floats) + def __init__(self, instruction_set): + super(VectorizedCustomSympyPrinter, self).__init__() self.instruction_set = instruction_set def _scalarFallback(self, func_name, expr, *args, **kwargs):