Skip to content
Snippets Groups Projects
Commit e439a228 authored by Markus Holzer's avatar Markus Holzer
Browse files

Add pystencils to pymbolic mapper

parent d2520fd3
No related merge requests found
Pipeline #59688 failed with stages
in 2 minutes and 53 seconds
......@@ -9,6 +9,8 @@ from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue
T = TypeVar("T")
def failing_cast(target: type, obj: T):
if not isinstance(obj, target):
raise TypeError(f"Casting {obj} to {target} failed.")
......@@ -62,6 +64,7 @@ class PsBlock(PsAstNode):
def set_child(self, idx: int, c: PsAstNode):
self._children[idx] = c
class PsLeafNode(PsAstNode):
def num_children(self) -> int:
return 0
......
from pymbolic.interop.sympy import SympyToPymbolicMapper
from pystencils.typing import TypedSymbol
from pystencils.typing.typed_sympy import SHAPE_DTYPE
from .ast.nodes import PsAssignment, PsSymbolExpr
from .types import PsSignedIntegerType, PsIeeeFloatType, PsUnsignedIntegerType
from .typed_expressions import PsArrayBasePointer, PsLinearizedArray, PsTypedVariable, PsArrayAccess
CTR_SYMBOLS = [TypedSymbol(f"ctr_{i}", SHAPE_DTYPE) for i in range(3)]
class PystencilsToPymbolicMapper(SympyToPymbolicMapper):
def map_Assignment(self, expr): # noqa
lhs = self.rec(expr.lhs)
rhs = self.rec(expr.rhs)
return PsAssignment(lhs, rhs)
def map_BasicType(self, expr):
width = expr.numpy_dtype.itemsize * 8
const = expr.const
if expr.is_float():
return PsIeeeFloatType(width, const)
elif expr.is_uint():
return PsUnsignedIntegerType(width, const)
elif expr.is_int():
return PsSignedIntegerType(width, const)
else:
raise (NotImplementedError, "Not supported dtype")
def map_FieldShapeSymbol(self, expr):
dtype = self.rec(expr.dtype)
return PsTypedVariable(expr.name, dtype)
def map_TypedSymbol(self, expr):
dtype = self.rec(expr.dtype)
return PsTypedVariable(expr.name, dtype)
def map_Access(self, expr):
name = expr.field.name
shape = tuple([self.rec(s) for s in expr.field.shape])
strides = tuple([self.rec(s) for s in expr.field.strides])
dtype = self.rec(expr.dtype)
array = PsLinearizedArray(name, shape, strides, dtype)
ptr = PsArrayBasePointer(expr.name, array)
index = sum([ctr * stride for ctr, stride in zip(CTR_SYMBOLS, expr.field.strides)])
index = self.rec(index)
return PsSymbolExpr(PsArrayAccess(ptr, index))
......@@ -89,9 +89,9 @@ class PsArrayAccess(pb.Subscript):
def base_ptr(self):
return self._base_ptr
@property
def index(self):
return self._index
# @property
# def index(self):
# return self._index
@property
def array(self) -> PsArray:
......
......@@ -204,7 +204,7 @@ class PsIeeeFloatType(PsAbstractType):
__match_args__ = ("width",)
SUPPORTED_WIDTHS = (32, 64)
SUPPORTED_WIDTHS = (16, 32, 64)
def __init__(self, width: int, const: bool = False):
if width not in self.SUPPORTED_WIDTHS:
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment