diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 56cf6c2bb94971d0915603bffbf8886cf9b187c9..61016e14f11a536444c798441ba8be516d97a167 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -7,7 +7,7 @@ from . import stencil as stencil from .display_utils import get_code_obj, get_code_str, show_code, to_dot from .field import Field, FieldType, fields from .cache import clear_cache -from .config import CreateKernelConfig +from .config import CreateKernelConfig, CpuOptimConfig, VectorizationConfig from .kernel_decorator import kernel, kernel_config from .kernelcreation import create_kernel from .backend.kernelfunction import KernelFunction @@ -36,6 +36,8 @@ __all__ = [ "TypedSymbol", "make_slice", "CreateKernelConfig", + "CpuOptimConfig", + "VectorizationConfig", "create_kernel", "KernelFunction", "Target", diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 7c24b55b2b9dbc5ee2813fdf6dbe6a8a19a3f51e..73f34a1fd97642fdbad0e29bc41ab3d223b17fd4 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -60,15 +60,11 @@ class PsExpression(PsAstNode, ABC): pass -class PsLvalueExpr(PsExpression, ABC): - """Base class for all expressions that may occur as an lvalue""" - - @abstractmethod - def clone(self) -> PsLvalueExpr: - pass +class PsLvalue(ABC): + """Mix-in for all expressions that may occur as an lvalue""" -class PsSymbolExpr(PsLeafMixIn, PsLvalueExpr): +class PsSymbolExpr(PsLeafMixIn, PsLvalue, PsExpression): """A single symbol as an expression.""" __match_args__ = ("symbol",) @@ -124,7 +120,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): return f"Constant({repr(self._constant)})" -class PsSubscript(PsLvalueExpr): +class PsSubscript(PsLvalue, PsExpression): __match_args__ = ("base", "index") def __init__(self, base: PsExpression, index: PsExpression): @@ -271,7 +267,7 @@ class PsVectorArrayAccess(PsArrayAccess): ) -class PsLookup(PsExpression): +class PsLookup(PsExpression, PsLvalue): __match_args__ = ("aggregate", "member_name") def __init__(self, aggregate: PsExpression, member_name: str) -> None: @@ -384,7 +380,7 @@ class PsNeg(PsUnOp): return operator.neg -class PsDeref(PsUnOp): +class PsDeref(PsLvalue, PsUnOp): pass diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index e5b88891cbfc8f69edad1aef6c58f49e4cdc4091..441faa606fd615e75cdcf39db167399254fdec45 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -3,7 +3,7 @@ from typing import Sequence, cast from types import NoneType from .astnode import PsAstNode, PsLeafMixIn -from .expressions import PsExpression, PsLvalueExpr, PsSymbolExpr +from .expressions import PsExpression, PsLvalue, PsSymbolExpr from .util import failing_cast @@ -76,16 +76,20 @@ class PsAssignment(PsAstNode): "rhs", ) - def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression): - self._lhs = lhs + def __init__(self, lhs: PsExpression, rhs: PsExpression): + if not isinstance(lhs, PsLvalue): + raise ValueError("Assignment LHS must be an lvalue") + self._lhs: PsExpression = lhs self._rhs = rhs @property - def lhs(self) -> PsLvalueExpr: + def lhs(self) -> PsExpression: return self._lhs @lhs.setter - def lhs(self, lvalue: PsLvalueExpr): + def lhs(self, lvalue: PsExpression): + if not isinstance(lvalue, PsLvalue): + raise ValueError("Assignment LHS must be an lvalue") self._lhs = lvalue @property @@ -105,7 +109,7 @@ class PsAssignment(PsAstNode): def set_child(self, idx: int, c: PsAstNode): idx = [0, 1][idx] # trick to normalize index if idx == 0: - self._lhs = failing_cast(PsLvalueExpr, c) + self.lhs = failing_cast(PsExpression, c) elif idx == 1: self._rhs = failing_cast(PsExpression, c) else: @@ -125,11 +129,11 @@ class PsDeclaration(PsAssignment): super().__init__(lhs, rhs) @property - def lhs(self) -> PsLvalueExpr: + def lhs(self) -> PsExpression: return self._lhs @lhs.setter - def lhs(self, lvalue: PsLvalueExpr): + def lhs(self, lvalue: PsExpression): self._lhs = failing_cast(PsSymbolExpr, lvalue) @property @@ -146,7 +150,7 @@ class PsDeclaration(PsAssignment): def set_child(self, idx: int, c: PsAstNode): idx = [0, 1][idx] # trick to normalize index if idx == 0: - self._lhs = failing_cast(PsSymbolExpr, c) + self.lhs = failing_cast(PsSymbolExpr, c) elif idx == 1: self._rhs = failing_cast(PsExpression, c) else: diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index ebaf2281275cae2ad91d6fd517c3d58629ac90e7..0d1e34639d5dac3d33b211e652f7bb85c127ec32 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -131,10 +131,12 @@ class FreezeExpressions: if isinstance(lhs, PsSymbolExpr): return PsDeclaration(lhs, rhs) - elif isinstance(lhs, (PsArrayAccess, PsVectorArrayAccess)): # todo + elif isinstance(lhs, (PsArrayAccess, PsLookup, PsVectorArrayAccess)): # todo return PsAssignment(lhs, rhs) else: - assert False, "That should not have happened." + raise FreezeError( + f"Encountered unsupported expression on assignment left-hand side: {lhs}" + ) def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr: symb = self._ctx.get_symbol(spsym.name) diff --git a/src/pystencils/field.py b/src/pystencils/field.py index b055ccb6bf0c5a84f73a998c1498eb1e4685bc10..3f019f5660d67bda97ebf5ec74cd586d4a4ddbed 100644 --- a/src/pystencils/field.py +++ b/src/pystencils/field.py @@ -160,6 +160,7 @@ class Field: dtype = create_type(dtype) np_data_type = dtype.numpy_dtype assert np_data_type is not None + if np_data_type.fields is not None: if index_dimensions != 0: raise ValueError("Structured arrays/fields are not allowed to have an index dimension") @@ -207,7 +208,8 @@ class Field: @staticmethod def create_fixed_size(field_name: str, shape: Tuple[int, ...], index_dimensions: int = 0, - dtype=np.float64, layout: str = 'numpy', strides: Optional[Sequence[int]] = None, + dtype: UserTypeSpec = np.float64, layout: str = 'numpy', + strides: Optional[Sequence[int]] = None, field_type=FieldType.GENERIC) -> 'Field': """ Creates a field with fixed sizes i.e. can be called only with arrays of the same size and layout @@ -234,7 +236,10 @@ class Field: assert len(strides) == len(shape) strides = tuple([s // np.dtype(dtype).itemsize for s in strides]) - numpy_dtype = np.dtype(dtype) + dtype = create_type(dtype) + numpy_dtype = dtype.numpy_dtype + assert numpy_dtype is not None + if numpy_dtype.fields is not None: if index_dimensions != 0: raise ValueError("Structured arrays/fields are not allowed to have an index dimension")