Commit b72ef215 authored by Martin Bauer's avatar Martin Bauer
Browse files

Removed 'symbol_name_to_variable_name'

- was not used consistently before
- symbol names are expected to be valid C identifiers
- for complicated field names, the latex_name of field should be used
parent b98decd5
......@@ -198,7 +198,7 @@ class KernelFunction(Node):
This function is expensive, cache the result where possible!
"""
field_map = {symbol_name_to_variable_name(f.name): f for f in self.fields_accessed}
field_map = {f.name: f for f in self.fields_accessed}
def get_fields(symbol):
if hasattr(symbol, 'field_name'):
......
......@@ -379,7 +379,8 @@ class Field:
return Field.Access(self, center)(*args, **kwargs)
def hashable_contents(self):
return self._layout, self.shape, self.strides, hash(self._dtype), self.field_type, self._field_name
dth = hash(self._dtype)
return self._layout, self.shape, self.strides, dth, self.field_type, self._field_name, self.latex_name
def __hash__(self):
return hash(self.hashable_contents())
......
import numpy as np
from pystencils.backends.cbackend import generate_c
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.sympyextensions import symbol_name_to_variable_name
from pystencils.data_types import StructType
from pystencils.field import FieldType
......@@ -67,7 +66,7 @@ def make_python_function(kernel_function_node, argument_dict=None):
def _build_numpy_argument_list(parameters, argument_dict):
argument_dict = {symbol_name_to_variable_name(k): v for k, v in argument_dict.items()}
argument_dict = {k: v for k, v in argument_dict.items()}
result = []
for param in parameters:
......@@ -101,7 +100,7 @@ def _check_arguments(parameter_specification, argument_dict):
Checks if parameters passed to kernel match the description in the AST function node.
If not it raises a ValueError, on success it returns the array shape that determines the CUDA blocks and threads
"""
argument_dict = {symbol_name_to_variable_name(k): v for k, v in argument_dict.items()}
argument_dict = {k: v for k, v in argument_dict.items()}
array_shapes = set()
index_arr_shapes = set()
......
......@@ -15,7 +15,6 @@ would reference back to the field.
"""
from sympy.core.cache import cacheit
from pystencils.data_types import TypedSymbol, create_composite_type_from_string, PointerType, get_base_type
from pystencils.sympyextensions import symbol_name_to_variable_name
SHAPE_DTYPE = create_composite_type_from_string("const int64")
STRIDE_DTYPE = create_composite_type_from_string("const int64")
......@@ -28,7 +27,7 @@ class FieldStrideSymbol(TypedSymbol):
return obj
def __new_stage2__(cls, field_name, coordinate):
name = "_stride_{name}_{i}".format(name=symbol_name_to_variable_name(field_name), i=coordinate)
name = "_stride_{name}_{i}".format(name=field_name, i=coordinate)
obj = super(FieldStrideSymbol, cls).__xnew__(cls, name, STRIDE_DTYPE)
obj.field_name = field_name
obj.coordinate = coordinate
......@@ -52,7 +51,7 @@ class FieldShapeSymbol(TypedSymbol):
return obj
def __new_stage2__(cls, field_names, coordinate):
names = "_".join([symbol_name_to_variable_name(field_name) for field_name in field_names])
names = "_".join([field_name for field_name in field_names])
name = "_size_{names}_{i}".format(names=names, i=coordinate)
obj = super(FieldShapeSymbol, cls).__xnew__(cls, name, SHAPE_DTYPE)
obj.field_names = tuple(field_names)
......@@ -76,7 +75,7 @@ class FieldPointerSymbol(TypedSymbol):
return obj
def __new_stage2__(cls, field_name, field_dtype, const):
name = "_data_{name}".format(name=symbol_name_to_variable_name(field_name))
name = "_data_{name}".format(name=field_name)
dtype = PointerType(get_base_type(field_dtype), const=const, restrict=False)
obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype)
obj.field_name = field_name
......
......@@ -5,12 +5,11 @@ import ctypes as ct
from pystencils.data_types import create_composite_type_from_string
from ..data_types import to_ctypes, ctypes_from_llvm, StructType
from .llvm import generate_llvm
from pystencils.sympyextensions import symbol_name_to_variable_name
from pystencils.field import FieldType
def build_ctypes_argument_list(parameter_specification, argument_dict):
argument_dict = {symbol_name_to_variable_name(k): v for k, v in argument_dict.items()}
argument_dict = {k: v for k, v in argument_dict.items()}
ct_arguments = []
array_shapes = set()
index_arr_shapes = set()
......
......@@ -549,8 +549,3 @@ def sort_assignments_topologically(assignments: Sequence[Assignment]) -> List[As
class SymbolCreator:
def __getattribute__(self, name):
return sp.Symbol(name)
def symbol_name_to_variable_name(symbol_name):
"""Replaces characters which are allowed in sympy symbol names but not in C/C++ variable names"""
return symbol_name.replace("^", "_")
......@@ -14,7 +14,6 @@ from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base
from pystencils.kernelparameters import FieldPointerSymbol, FieldStrideSymbol
from pystencils.slicing import normalize_slice
import pystencils.astnodes as ast
from pystencils.sympyextensions import symbol_name_to_variable_name
def filtered_tree_iteration(node, node_type, stop_type=None):
......@@ -755,7 +754,7 @@ class KernelConstraintsCheck:
elif isinstance(rhs, TypedSymbol):
return rhs
elif isinstance(rhs, sp.Symbol):
return TypedSymbol(symbol_name_to_variable_name(rhs.name), self._type_for_symbol[rhs.name])
return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
elif type_constants and isinstance(rhs, sp.Number):
return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
elif isinstance(rhs, sp.Mul):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment