Commit a8997a2f authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add TypedMatrixSymbol (for usage of `MatrixSymbol` in kernels)

parent 6e663484
Pipeline #21827 failed with stage
in 10 minutes and 15 seconds
......@@ -6,7 +6,8 @@ from typing import Any, List, Optional, Sequence, Set, Union
import sympy as sp
import pystencils
from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type
from pystencils.data_types import (
TypedImaginaryUnit, TypedMatrixSymbol, TypedSymbol, cast_func, create_type)
from pystencils.field import Field
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils.sympyextensions import fast_subs
......@@ -579,6 +580,9 @@ class SympyAssignment(Node):
def symbols_defined(self):
if not self._is_declaration:
return set()
if isinstance(self._lhs_symbol, TypedMatrixSymbol):
return {self._lhs_symbol, sp.Symbol(}
return {self._lhs_symbol}
......@@ -828,3 +828,40 @@ class TypedImaginaryUnit(TypedSymbol):
def __getnewargs__(self):
return (self.dtype,)
class MatrixDtype(Type):
def __init__(self, base_type, shape, ctype):
self.base_type = base_type
self.shape = shape
self.ctype = ctype
def _sympystr(self, *args, **kwargs):
return str(self.ctype)
class TypedMatrixSymbol(sp.MatrixSymbol):
def __new__(cls, name, n, m, base_dtype, ctype):
obj = sp.MatrixSymbol.__new__(cls, name, n, m)
obj.dtype = MatrixDtype(base_dtype, (n, m), ctype)
obj._name = name
return obj
def args(self):
return, self.shape[0], self.shape[1], self.dtype.base_type, self.dtype.ctype
def shape(self):
return super().args[1:3]
def _hashable_content(self):
return (, self.shape, self.dtype)
def name(self):
return self._name
def __getnewargs__(self):
return, self.shape[0], self.shape[1], self.dtype.base_type, self.dtype.ctype
......@@ -14,8 +14,9 @@ import pystencils.astnodes as ast
import pystencils.integer_functions
from pystencils.assignment import Assignment
from pystencils.data_types import (
PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type,
get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
PointerType, StructType, TypedImaginaryUnit, TypedMatrixSymbol, TypedSymbol, cast_func,
collate_types, create_type, get_base_type, get_type_of_expression, pointer_arithmetic_func,
from pystencils.field import AbstractField, Field, FieldType
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.simp.assignment_collection import AssignmentCollection
......@@ -518,7 +519,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
return visit_sympy_expr(result, enclosing_block, sympy_assignment)
if isinstance(expr, ast.ResolvedFieldAccess):
if isinstance(expr, (ast.ResolvedFieldAccess, TypedMatrixSymbol)):
return expr
if hasattr(expr, 'args'):
......@@ -828,6 +829,7 @@ class KernelConstraintsCheck:
def process_expression(self, rhs, type_constants=True):
from pystencils.interpolation_astnodes import InterpolatorAccess
from sympy.matrices.expressions.matexpr import MatrixElement
if isinstance(rhs, AbstractField.AbstractAccess):
......@@ -869,7 +871,7 @@ class KernelConstraintsCheck:
if arg not in (-1, 1) else arg for arg in rhs.args
return rhs.func(*new_args) if new_args else rhs
elif isinstance(rhs, sp.Indexed):
elif isinstance(rhs, (sp.Indexed, TypedMatrixSymbol, MatrixElement)):
return rhs
if isinstance(rhs, sp.Pow):
......@@ -889,9 +891,9 @@ class KernelConstraintsCheck:
return set(k.field for k, v in self._field_writes.items() if len(v))
def _process_lhs(self, lhs):
assert isinstance(lhs, sp.Symbol)
assert isinstance(lhs, (sp.Symbol, sp.MatrixSymbol))
if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol, TypedMatrixSymbol)):
return TypedSymbol(, self._type_for_symbol[])
return lhs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment