Skip to content
Snippets Groups Projects
Commit 0168aaed authored by Jan Hoenig's avatar Jan Hoenig
Browse files

Created DataType class for storing information about data inside a class

and not as a string.
Changed name of the file TypedSymbol to types.
Fixed usage of dtype accordingly, however i might not have found every
usage of dtype.
parent ab84a4cb
No related merge requests found
from pystencils.field import Field, extractCommonSubexpressions
from pystencils.typedsymbol import TypedSymbol
from pystencils.types import TypedSymbol
from pystencils.slicing import makeSlice
......@@ -2,7 +2,7 @@ import sympy as sp
import textwrap as textwrap
from sympy.tensor import IndexedBase, Indexed
from pystencils.field import Field
from pystencils.typedsymbol import TypedSymbol
from pystencils.types import TypedSymbol
class Node(object):
......
......@@ -76,7 +76,7 @@ class CBackend:
raise NotImplementedError("CBackend does not support node of type " + cls.__name__)
def _print_KernelFunction(self, node):
functionArguments = ["%s %s" % (s.dtype, s.name) for s in node.parameters]
functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
prefix = "__global__ void" if self.cuda else "void"
funcDeclaration = "%s %s(%s)" % (prefix, node.functionName, ", ".join(functionArguments))
body = self._print(node.body)
......@@ -105,10 +105,10 @@ class CBackend:
dtype = ""
if node.isDeclaration:
if node.isConst:
dtype = "const " + node.lhs.dtype + " "
dtype = "const " + str(node.lhs.dtype) + " "
else:
dtype = node.lhs.dtype + " "
return "%s %s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
dtype = str(node.lhs.dtype) + " "
return "%s %s = %s;" % (str(dtype), self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
def _print_TemporaryMemoryAllocation(self, node):
return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol),
......
import sympy as sp
from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
from pystencils.typedsymbol import TypedSymbol
from pystencils.types import TypedSymbol
from pystencils.field import Field
import pystencils.ast as ast
......
......@@ -3,7 +3,7 @@ import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
from sympy.tensor import IndexedBase
from pystencils.typedsymbol import TypedSymbol
from pystencils.types import TypedSymbol
class Field:
......
import sympy as sp
from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
from pystencils.typedsymbol import TypedSymbol
from pystencils.types import TypedSymbol
from pystencils.field import Field
import pystencils.ast as ast
......
......@@ -4,7 +4,7 @@ from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase
from pystencils.field import Field, offsetComponentToDirectionString
from pystencils.typedsymbol import TypedSymbol
from pystencils.types import TypedSymbol, DataType
from pystencils.slicing import normalizeSlice
import pystencils.ast as ast
......@@ -220,9 +220,10 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
else:
basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))]
dtype = "%s * __restrict__" % field.dtype
dtype = DataType(field.dtype)
dtype.alias = False
if field.name in readOnlyFieldNames:
dtype = "const " + dtype
dtype.const = True
fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, field.name), dtype)
......
......@@ -28,3 +28,19 @@ class TypedSymbol(sp.Symbol):
def __getnewargs__(self):
return self.name, self.dtype
_c_dtype_dict = {0: 'int', 1: 'double', 2: 'float'}
_dtype_dict = {'int': 0, 'double': 1, 'float': 2}
class DataType(object):
def __init__(self, dtype):
self.alias = True
self.const = False
if isinstance(dtype, str):
self.dtype = _dtype_dict[dtype]
else:
self.dtype = dtype
def __repr__(self):
return "{!s} {!s} {!s}".format("const" if self.const else "", "__restrict__" if not self.alias else "", _c_dtype_dict[self.dtype])
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment