Skip to content
Snippets Groups Projects
Commit 2956a1ab authored by Martin Bauer's avatar Martin Bauer
Browse files

Removed unused "use_float_constants" from C backend

- constants are typed - no need any more for this parameter
parent 9a147ba2
No related merge requests found
......@@ -10,14 +10,14 @@ except ImportError:
from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \
bitwise_or, modulo_floor, modulo_ceil
from pystencils.astnodes import Node, ResolvedFieldAccess, KernelFunction
from pystencils.astnodes import Node, KernelFunction
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \
vector_memory_access
__all__ = ['generate_c', 'CustomCppCode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
def generate_c(ast_node: Node, signature_only: bool = False, use_float_constants: Optional[bool] = None) -> str:
def generate_c(ast_node: Node, signature_only: bool = False) -> str:
"""Prints an abstract syntax tree node as C or CUDA code.
This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
......@@ -27,17 +27,11 @@ def generate_c(ast_node: Node, signature_only: bool = False, use_float_constants
Args:
ast_node:
signature_only:
use_float_constants:
Returns:
C-like code for the ast node and its descendants
"""
if use_float_constants is None:
field_types = set(o.field.dtype for o in ast_node.atoms(ResolvedFieldAccess))
double = create_type('double')
use_float_constants = double not in field_types
printer = CBackend(constants_as_floats=use_float_constants, signature_only=signature_only,
printer = CBackend(signature_only=signature_only,
vector_instruction_set=ast_node.instruction_set)
return printer(ast_node)
......@@ -100,13 +94,13 @@ class PrintNode(CustomCppCode):
# noinspection PyPep8Naming
class CBackend:
def __init__(self, constants_as_floats=False, sympy_printer=None,
def __init__(self, sympy_printer=None,
signature_only=False, vector_instruction_set=None):
if sympy_printer is None:
if vector_instruction_set is not None:
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, constants_as_floats)
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
else:
self.sympy_printer = CustomSympyPrinter(constants_as_floats)
self.sympy_printer = CustomSympyPrinter()
else:
self.sympy_printer = sympy_printer
......@@ -211,8 +205,7 @@ class CBackend:
# noinspection PyPep8Naming
class CustomSympyPrinter(CCodePrinter):
def __init__(self, constants_as_floats=False):
self._constantsAsFloats = constants_as_floats
def __init__(self):
super(CustomSympyPrinter, self).__init__()
self._float_type = create_type("float32")
......@@ -280,8 +273,8 @@ class CustomSympyPrinter(CCodePrinter):
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])
def __init__(self, instruction_set, constants_as_floats=False):
super(VectorizedCustomSympyPrinter, self).__init__(constants_as_floats)
def __init__(self, instruction_set):
super(VectorizedCustomSympyPrinter, self).__init__()
self.instruction_set = instruction_set
def _scalarFallback(self, func_name, expr, *args, **kwargs):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment