diff --git a/__init__.py b/__init__.py index 2f6f6bf2f64761f6127ce91e7c4c0c28e1e73ee2..1a81830067bba340a602f87f2e20369500502930 100644 --- a/__init__.py +++ b/__init__.py @@ -1,3 +1,3 @@ from pystencils.field import Field, extractCommonSubexpressions -from pystencils.typedsymbol import TypedSymbol +from pystencils.types import TypedSymbol from pystencils.slicing import makeSlice diff --git a/ast.py b/ast.py index 5d36c1e0d4f0c09bff77208b12efe087bc1a9e81..04ab30cf94c4bcbdaeb683d02ce2b6e82963a5ae 100644 --- a/ast.py +++ b/ast.py @@ -2,7 +2,7 @@ import sympy as sp import textwrap as textwrap from sympy.tensor import IndexedBase, Indexed from pystencils.field import Field -from pystencils.typedsymbol import TypedSymbol +from pystencils.types import TypedSymbol class Node(object): diff --git a/backends/cbackend.py b/backends/cbackend.py index 8bdf166a8dcf5ec2d23c7831e51903b07501c1ef..f59741d5cc1b153215e8cd901366cd6cdebf01ca 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -76,7 +76,7 @@ class CBackend: raise NotImplementedError("CBackend does not support node of type " + cls.__name__) def _print_KernelFunction(self, node): - functionArguments = ["%s %s" % (s.dtype, s.name) for s in node.parameters] + functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters] prefix = "__global__ void" if self.cuda else "void" funcDeclaration = "%s %s(%s)" % (prefix, node.functionName, ", ".join(functionArguments)) body = self._print(node.body) @@ -105,10 +105,10 @@ class CBackend: dtype = "" if node.isDeclaration: if node.isConst: - dtype = "const " + node.lhs.dtype + " " + dtype = "const " + str(node.lhs.dtype) + " " else: - dtype = node.lhs.dtype + " " - return "%s %s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs)) + dtype = str(node.lhs.dtype) + " " + return "%s %s = %s;" % (str(dtype), self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs)) def _print_TemporaryMemoryAllocation(self, node): return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol), diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index ac2db27ef99b437c8621af01dd1da7369c5dd955..e8e722c13940ad157b138ddc31e83074800873a5 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -1,7 +1,7 @@ import sympy as sp from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \ typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop -from pystencils.typedsymbol import TypedSymbol +from pystencils.types import TypedSymbol from pystencils.field import Field import pystencils.ast as ast diff --git a/field.py b/field.py index 50e6d36ae2d122ae8dc6fb79ab7f5790dc0df038..e0835bc6f5d96ea75da7b2d931b16b75bb31f3ae 100644 --- a/field.py +++ b/field.py @@ -3,7 +3,7 @@ import numpy as np import sympy as sp from sympy.core.cache import cacheit from sympy.tensor import IndexedBase -from pystencils.typedsymbol import TypedSymbol +from pystencils.types import TypedSymbol class Field: diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index 2fd0245a0d15ca4b04f5f59507a32b7ca2e04196..54d4ed0ce5d197d1c160092a07de4285a56efdfb 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -1,7 +1,7 @@ import sympy as sp from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \ typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop -from pystencils.typedsymbol import TypedSymbol +from pystencils.types import TypedSymbol from pystencils.field import Field import pystencils.ast as ast diff --git a/transformations.py b/transformations.py index ddeef910ecaebdfacf992826c88e9f99201bf3ef..f83fa9c6495295145e38de19a949996e7c708ae2 100644 --- a/transformations.py +++ b/transformations.py @@ -4,7 +4,7 @@ from sympy.logic.boolalg import Boolean from sympy.tensor import IndexedBase from pystencils.field import Field, offsetComponentToDirectionString -from pystencils.typedsymbol import TypedSymbol +from pystencils.types import TypedSymbol, DataType from pystencils.slicing import normalizeSlice import pystencils.ast as ast @@ -220,9 +220,10 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn else: basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))] - dtype = "%s * __restrict__" % field.dtype + dtype = DataType(field.dtype) + dtype.alias = False if field.name in readOnlyFieldNames: - dtype = "const " + dtype + dtype.const = True fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, field.name), dtype) diff --git a/types.py b/types.py index 72ad8fe6942551c3c24dcdeb5fa6a1bba72f868b..e4f579e6804f5f2ea3c151d5e33ee4614b483113 100644 --- a/types.py +++ b/types.py @@ -28,3 +28,19 @@ class TypedSymbol(sp.Symbol): def __getnewargs__(self): return self.name, self.dtype + +_c_dtype_dict = {0: 'int', 1: 'double', 2: 'float'} +_dtype_dict = {'int': 0, 'double': 1, 'float': 2} + + +class DataType(object): + def __init__(self, dtype): + self.alias = True + self.const = False + if isinstance(dtype, str): + self.dtype = _dtype_dict[dtype] + else: + self.dtype = dtype + + def __repr__(self): + return "{!s} {!s} {!s}".format("const" if self.const else "", "__restrict__" if not self.alias else "", _c_dtype_dict[self.dtype])