From 0168aaedcdba5295319954186c655e0682ef1585 Mon Sep 17 00:00:00 2001 From: Jan Hoenig <hrominium@gmail.com> Date: Thu, 8 Dec 2016 10:18:53 +0100 Subject: [PATCH] Created DataType class for storing information about data inside a class and not as a string. Changed name of the file TypedSymbol to types. Fixed usage of dtype accordingly, however i might not have found every usage of dtype. --- __init__.py | 2 +- ast.py | 2 +- backends/cbackend.py | 8 ++++---- cpu/kernelcreation.py | 2 +- field.py | 2 +- llvm/kernelcreation.py | 2 +- transformations.py | 7 ++++--- types.py | 16 ++++++++++++++++ 8 files changed, 29 insertions(+), 12 deletions(-) diff --git a/__init__.py b/__init__.py index 2f6f6bf2f..1a8183006 100644 --- a/__init__.py +++ b/__init__.py @@ -1,3 +1,3 @@ from pystencils.field import Field, extractCommonSubexpressions -from pystencils.typedsymbol import TypedSymbol +from pystencils.types import TypedSymbol from pystencils.slicing import makeSlice diff --git a/ast.py b/ast.py index 5d36c1e0d..04ab30cf9 100644 --- a/ast.py +++ b/ast.py @@ -2,7 +2,7 @@ import sympy as sp import textwrap as textwrap from sympy.tensor import IndexedBase, Indexed from pystencils.field import Field -from pystencils.typedsymbol import TypedSymbol +from pystencils.types import TypedSymbol class Node(object): diff --git a/backends/cbackend.py b/backends/cbackend.py index 8bdf166a8..f59741d5c 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -76,7 +76,7 @@ class CBackend: raise NotImplementedError("CBackend does not support node of type " + cls.__name__) def _print_KernelFunction(self, node): - functionArguments = ["%s %s" % (s.dtype, s.name) for s in node.parameters] + functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters] prefix = "__global__ void" if self.cuda else "void" funcDeclaration = "%s %s(%s)" % (prefix, node.functionName, ", ".join(functionArguments)) body = self._print(node.body) @@ -105,10 +105,10 @@ class CBackend: dtype = "" if node.isDeclaration: if node.isConst: - dtype = "const " + node.lhs.dtype + " " + dtype = "const " + str(node.lhs.dtype) + " " else: - dtype = node.lhs.dtype + " " - return "%s %s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs)) + dtype = str(node.lhs.dtype) + " " + return "%s %s = %s;" % (str(dtype), self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs)) def _print_TemporaryMemoryAllocation(self, node): return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol), diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index ac2db27ef..e8e722c13 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -1,7 +1,7 @@ import sympy as sp from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \ typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop -from pystencils.typedsymbol import TypedSymbol +from pystencils.types import TypedSymbol from pystencils.field import Field import pystencils.ast as ast diff --git a/field.py b/field.py index 50e6d36ae..e0835bc6f 100644 --- a/field.py +++ b/field.py @@ -3,7 +3,7 @@ import numpy as np import sympy as sp from sympy.core.cache import cacheit from sympy.tensor import IndexedBase -from pystencils.typedsymbol import TypedSymbol +from pystencils.types import TypedSymbol class Field: diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index 2fd0245a0..54d4ed0ce 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -1,7 +1,7 @@ import sympy as sp from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \ typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop -from pystencils.typedsymbol import TypedSymbol +from pystencils.types import TypedSymbol from pystencils.field import Field import pystencils.ast as ast diff --git a/transformations.py b/transformations.py index ddeef910e..f83fa9c64 100644 --- a/transformations.py +++ b/transformations.py @@ -4,7 +4,7 @@ from sympy.logic.boolalg import Boolean from sympy.tensor import IndexedBase from pystencils.field import Field, offsetComponentToDirectionString -from pystencils.typedsymbol import TypedSymbol +from pystencils.types import TypedSymbol, DataType from pystencils.slicing import normalizeSlice import pystencils.ast as ast @@ -220,9 +220,10 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn else: basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))] - dtype = "%s * __restrict__" % field.dtype + dtype = DataType(field.dtype) + dtype.alias = False if field.name in readOnlyFieldNames: - dtype = "const " + dtype + dtype.const = True fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, field.name), dtype) diff --git a/types.py b/types.py index 72ad8fe69..e4f579e68 100644 --- a/types.py +++ b/types.py @@ -28,3 +28,19 @@ class TypedSymbol(sp.Symbol): def __getnewargs__(self): return self.name, self.dtype + +_c_dtype_dict = {0: 'int', 1: 'double', 2: 'float'} +_dtype_dict = {'int': 0, 'double': 1, 'float': 2} + + +class DataType(object): + def __init__(self, dtype): + self.alias = True + self.const = False + if isinstance(dtype, str): + self.dtype = _dtype_dict[dtype] + else: + self.dtype = dtype + + def __repr__(self): + return "{!s} {!s} {!s}".format("const" if self.const else "", "__restrict__" if not self.alias else "", _c_dtype_dict[self.dtype]) -- GitLab