Commit 2956a1ab authored by Martin Bauer's avatar Martin Bauer
Browse files

Removed unused "use_float_constants" from C backend

- constants are typed - no need any more for this parameter
parent 9a147ba2
...@@ -10,14 +10,14 @@ except ImportError: ...@@ -10,14 +10,14 @@ except ImportError:
from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \ from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \
bitwise_or, modulo_floor, modulo_ceil bitwise_or, modulo_floor, modulo_ceil
from pystencils.astnodes import Node, ResolvedFieldAccess, KernelFunction from pystencils.astnodes import Node, KernelFunction
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \ from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \
vector_memory_access vector_memory_access
__all__ = ['generate_c', 'CustomCppCode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] __all__ = ['generate_c', 'CustomCppCode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
def generate_c(ast_node: Node, signature_only: bool = False, use_float_constants: Optional[bool] = None) -> str: def generate_c(ast_node: Node, signature_only: bool = False) -> str:
"""Prints an abstract syntax tree node as C or CUDA code. """Prints an abstract syntax tree node as C or CUDA code.
This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
...@@ -27,17 +27,11 @@ def generate_c(ast_node: Node, signature_only: bool = False, use_float_constants ...@@ -27,17 +27,11 @@ def generate_c(ast_node: Node, signature_only: bool = False, use_float_constants
Args: Args:
ast_node: ast_node:
signature_only: signature_only:
use_float_constants:
Returns: Returns:
C-like code for the ast node and its descendants C-like code for the ast node and its descendants
""" """
if use_float_constants is None: printer = CBackend(signature_only=signature_only,
field_types = set(o.field.dtype for o in ast_node.atoms(ResolvedFieldAccess))
double = create_type('double')
use_float_constants = double not in field_types
printer = CBackend(constants_as_floats=use_float_constants, signature_only=signature_only,
vector_instruction_set=ast_node.instruction_set) vector_instruction_set=ast_node.instruction_set)
return printer(ast_node) return printer(ast_node)
...@@ -100,13 +94,13 @@ class PrintNode(CustomCppCode): ...@@ -100,13 +94,13 @@ class PrintNode(CustomCppCode):
# noinspection PyPep8Naming # noinspection PyPep8Naming
class CBackend: class CBackend:
def __init__(self, constants_as_floats=False, sympy_printer=None, def __init__(self, sympy_printer=None,
signature_only=False, vector_instruction_set=None): signature_only=False, vector_instruction_set=None):
if sympy_printer is None: if sympy_printer is None:
if vector_instruction_set is not None: if vector_instruction_set is not None:
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, constants_as_floats) self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
else: else:
self.sympy_printer = CustomSympyPrinter(constants_as_floats) self.sympy_printer = CustomSympyPrinter()
else: else:
self.sympy_printer = sympy_printer self.sympy_printer = sympy_printer
...@@ -211,8 +205,7 @@ class CBackend: ...@@ -211,8 +205,7 @@ class CBackend:
# noinspection PyPep8Naming # noinspection PyPep8Naming
class CustomSympyPrinter(CCodePrinter): class CustomSympyPrinter(CCodePrinter):
def __init__(self, constants_as_floats=False): def __init__(self):
self._constantsAsFloats = constants_as_floats
super(CustomSympyPrinter, self).__init__() super(CustomSympyPrinter, self).__init__()
self._float_type = create_type("float32") self._float_type = create_type("float32")
...@@ -280,8 +273,8 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -280,8 +273,8 @@ class CustomSympyPrinter(CCodePrinter):
class VectorizedCustomSympyPrinter(CustomSympyPrinter): class VectorizedCustomSympyPrinter(CustomSympyPrinter):
SummandInfo = namedtuple("SummandInfo", ['sign', 'term']) SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])
def __init__(self, instruction_set, constants_as_floats=False): def __init__(self, instruction_set):
super(VectorizedCustomSympyPrinter, self).__init__(constants_as_floats) super(VectorizedCustomSympyPrinter, self).__init__()
self.instruction_set = instruction_set self.instruction_set = instruction_set
def _scalarFallback(self, func_name, expr, *args, **kwargs): def _scalarFallback(self, func_name, expr, *args, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment