From d61a7c0cb93e8d10ab54f25d349e08126285cd86 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Sun, 28 Jan 2024 17:30:57 +0100 Subject: [PATCH] codegen for sparse kernels --- src/pystencils/nbackend/ast/collectors.py | 8 +- src/pystencils/nbackend/exceptions.py | 4 + .../nbackend/kernelcreation/analysis.py | 7 +- .../nbackend/kernelcreation/context.py | 28 +++-- .../nbackend/kernelcreation/defaults.py | 18 ++- .../nbackend/kernelcreation/freeze.py | 6 +- .../kernelcreation/iteration_space.py | 115 +++++++++++++++--- .../nbackend/kernelcreation/kernelcreation.py | 10 +- .../nbackend/kernelcreation/options.py | 17 ++- .../kernelcreation/platform/__init__.py | 4 +- .../kernelcreation/platform/basic_cpu.py | 48 ++++++-- .../kernelcreation/platform/platform.py | 4 +- .../nbackend/kernelcreation/typification.py | 23 ++-- src/pystencils/nbackend/types/__init__.py | 2 + src/pystencils/nbackend/types/basic_types.py | 115 +++++++++++++++++- src/pystencils/nbackend/types/parsing.py | 17 +++ src/pystencils/nbackend/types/quick.py | 8 +- 17 files changed, 360 insertions(+), 74 deletions(-) diff --git a/src/pystencils/nbackend/ast/collectors.py b/src/pystencils/nbackend/ast/collectors.py index 3c995ccf6..c94d0b18b 100644 --- a/src/pystencils/nbackend/ast/collectors.py +++ b/src/pystencils/nbackend/ast/collectors.py @@ -30,6 +30,8 @@ class UndefinedVariablesCollector: def __call__(self, node: PsAstNode) -> set[PsTypedVariable]: """Returns all `PsTypedVariable`s that occur in the given AST without being defined prior to their usage.""" + undefined_vars: set[PsTypedVariable] = set() + match node: case PsKernelFunction(block): return self(block) @@ -46,10 +48,12 @@ class UndefinedVariablesCollector: return cast(set[PsTypedVariable], variables) case PsAssignment(lhs, rhs): - return self(lhs) | self(rhs) + undefined_vars = self(lhs) | self(rhs) + if isinstance(lhs.expression, PsTypedVariable): + undefined_vars.remove(lhs.expression) + return undefined_vars case PsBlock(statements): - undefined_vars: set[PsTypedVariable] = set() for stmt in statements[::-1]: undefined_vars -= self.declared_variables(stmt) undefined_vars |= self(stmt) diff --git a/src/pystencils/nbackend/exceptions.py b/src/pystencils/nbackend/exceptions.py index 6fe83b3f9..29434d8ef 100644 --- a/src/pystencils/nbackend/exceptions.py +++ b/src/pystencils/nbackend/exceptions.py @@ -10,5 +10,9 @@ class PsInputError(Exception): pass +class KernelConstraintsError(Exception): + pass + + class PsMalformedAstException(Exception): pass diff --git a/src/pystencils/nbackend/kernelcreation/analysis.py b/src/pystencils/nbackend/kernelcreation/analysis.py index 82bd769b3..98f4887f7 100644 --- a/src/pystencils/nbackend/kernelcreation/analysis.py +++ b/src/pystencils/nbackend/kernelcreation/analysis.py @@ -13,11 +13,7 @@ from ...assignment import Assignment from ...simp import AssignmentCollection from ...transformations import NestedScopes -from ..exceptions import PsInternalCompilerError - - -class KernelConstraintsError(Exception): - pass +from ..exceptions import PsInternalCompilerError, KernelConstraintsError class KernelAnalysis: @@ -39,6 +35,7 @@ class KernelAnalysis: the same location. - **Independence of Writes:** A weaker requirement than access independence; each field may only be written once at each index. + - **Dimension of index fields:** Index fields occuring in the kernel must have exactly one spatial dimension. Knowledge Collection -------------------- diff --git a/src/pystencils/nbackend/kernelcreation/context.py b/src/pystencils/nbackend/kernelcreation/context.py index 7e4fad9ba..a77189331 100644 --- a/src/pystencils/nbackend/kernelcreation/context.py +++ b/src/pystencils/nbackend/kernelcreation/context.py @@ -1,15 +1,13 @@ from __future__ import annotations -from typing import cast - from ...field import Field, FieldType -from ...typing import TypedSymbol, BasicType +from ...typing import TypedSymbol, BasicType, StructType from ..arrays import PsLinearizedArray from ..types import PsIntegerType from ..types.quick import make_type from ..constraints import PsKernelConstraint -from ..exceptions import PsInternalCompilerError +from ..exceptions import PsInternalCompilerError, KernelConstraintsError from .options import KernelCreationOptions from .iteration_space import IterationSpace, FullIterationSpace, SparseIterationSpace @@ -80,12 +78,26 @@ class KernelCreationContext: match field.field_type: case FieldType.GENERIC | FieldType.STAGGERED | FieldType.STAGGERED_FLUX: self._fields_collection.domain_fields.add(field) + case FieldType.BUFFER: + if field.spatial_dimensions != 1: + raise KernelConstraintsError( + f"Invalid spatial shape of buffer field {field.name}: {field.spatial_dimensions}. " + "Buffer fields must be one-dimensional." + ) self._fields_collection.buffer_fields.add(field) + case FieldType.INDEXED: + if field.spatial_dimensions != 1: + raise KernelConstraintsError( + f"Invalid spatial shape of index field {field.name}: {field.spatial_dimensions}. " + "Index fields must be one-dimensional." + ) self._fields_collection.index_fields.add(field) + case FieldType.CUSTOM: self._fields_collection.custom_fields.add(field) + case _: assert False, "unreachable code" @@ -105,8 +117,8 @@ class KernelCreationContext: for s in field.strides ) - # TODO: frontend should use new type system - element_type = make_type(cast(BasicType, field.dtype).numpy_dtype.type) + assert isinstance(field.dtype, (BasicType, StructType)) + element_type = make_type(field.dtype.numpy_dtype) arr = PsLinearizedArray( field.name, element_type, arr_shape, arr_strides, self.index_dtype @@ -114,13 +126,11 @@ class KernelCreationContext: self._arrays[field] = arr - return self._arrays[field] + return self._arrays[field] # Iteration Space def set_iteration_space(self, ispace: IterationSpace): - if self._ispace is not None: - raise PsInternalCompilerError("Iteration space was already set.") self._ispace = ispace def get_iteration_space(self) -> IterationSpace: diff --git a/src/pystencils/nbackend/kernelcreation/defaults.py b/src/pystencils/nbackend/kernelcreation/defaults.py index c928ef4e4..fc0a602a1 100644 --- a/src/pystencils/nbackend/kernelcreation/defaults.py +++ b/src/pystencils/nbackend/kernelcreation/defaults.py @@ -17,7 +17,7 @@ A possibly incomplete list of symbols and types that need to be defined: """ from typing import TypeVar, Generic, Callable -from ..types import PsAbstractType, PsSignedIntegerType +from ..types import PsAbstractType, PsSignedIntegerType, PsStructType from ..typed_expressions import PsTypedVariable from ...typing import TypedSymbol @@ -40,14 +40,20 @@ class PsDefaults(Generic[SymbolT]): ) """Default spatial counters""" + self._index_struct_coordinate_names = ("x", "y", "z") + """Default names of spatial coordinate members in index list structures""" + self.index_struct_coordinates = ( - symcreate("x", self.index_dtype), - symcreate("y", self.index_dtype), - symcreate("z", self.index_dtype), + PsStructType.Member("x", self.index_dtype), + PsStructType.Member("y", self.index_dtype), + PsStructType.Member("z", self.index_dtype), ) - """Default symbols for spatial coordinates in index list structures""" + """Default spatial coordinate members in index list structures""" + + self.sparse_counter_name = "sparse_idx" + """Name of the default sparse iteration counter""" - self.sparse_iteration_counter = symcreate("list_idx", self.index_dtype) + self.sparse_counter = symcreate(self.sparse_counter_name, self.index_dtype) """Default sparse iteration counter.""" diff --git a/src/pystencils/nbackend/kernelcreation/freeze.py b/src/pystencils/nbackend/kernelcreation/freeze.py index 69e1dff7d..64d3fa2e9 100644 --- a/src/pystencils/nbackend/kernelcreation/freeze.py +++ b/src/pystencils/nbackend/kernelcreation/freeze.py @@ -100,9 +100,9 @@ class FreezeExpressions(SympyToPymbolicMapper): case FieldType.INDEXED: # flake8: noqa sparse_ispace = self._ctx.get_sparse_iteration_space() - # TODO: make sure index (and all offsets?) are zero - # TODO: Add sparse iteration counter - raise NotImplementedError() + # Add sparse iteration counter to offset + assert len(offsets) == 1 # must have been checked by the context + offsets = [offsets[0] + sparse_ispace.sparse_counter] case FieldType.CUSTOM: raise ValueError("Custom fields support only absolute accesses.") case unknown: diff --git a/src/pystencils/nbackend/kernelcreation/iteration_space.py b/src/pystencils/nbackend/kernelcreation/iteration_space.py index d38cb3534..a752a16d6 100644 --- a/src/pystencils/nbackend/kernelcreation/iteration_space.py +++ b/src/pystencils/nbackend/kernelcreation/iteration_space.py @@ -15,8 +15,10 @@ from ..typed_expressions import ( PsTypedConstant, ) from ..arrays import PsLinearizedArray +from ..ast.util import failing_cast +from ..types import PsStructType, constify from .defaults import Pymbolic as Defaults -from ..exceptions import PsInputError, PsInternalCompilerError +from ..exceptions import PsInputError, PsInternalCompilerError, KernelConstraintsError if TYPE_CHECKING: from .context import KernelCreationContext @@ -36,20 +38,24 @@ class IterationSpace(ABC): spatial indices. """ - def __init__(self, spatial_indices: tuple[PsTypedVariable, ...]): + def __init__(self, spatial_indices: Sequence[PsTypedVariable]): if len(spatial_indices) == 0: raise ValueError("Iteration space must be at least one-dimensional.") - self._spatial_indices = spatial_indices + self._spatial_indices = tuple(spatial_indices) @property def spatial_indices(self) -> tuple[PsTypedVariable, ...]: return self._spatial_indices + @property + def dim(self) -> int: + return len(self._spatial_indices) + class FullIterationSpace(IterationSpace): """N-dimensional full iteration space. - + Each dimension of the full iteration space is represented by an instance of `FullIterationSpace.Dimension`. Dimensions are ordered slowest-to-fastest: The first dimension corresponds to the slowest coordinate, translates to the outermost loop, while the last dimension is the fastest coordinate and translates @@ -144,24 +150,108 @@ class FullIterationSpace(IterationSpace): class SparseIterationSpace(IterationSpace): - # TODO: To properly implement sparse iteration, we still need struct data types def __init__( self, - spatial_indices: tuple[PsTypedVariable, ...], + spatial_indices: Sequence[PsTypedVariable], index_list: PsLinearizedArray, + coordinate_members: Sequence[PsStructType.Member], + sparse_counter: PsTypedVariable, ): super().__init__(spatial_indices) self._index_list = index_list + self._coord_members = tuple(coordinate_members) + self._sparse_counter = sparse_counter @property def index_list(self) -> PsLinearizedArray: return self._index_list + + @property + def coordinate_members(self) -> tuple[PsStructType.Member, ...]: + return self._coord_members + + @property + def sparse_counter(self) -> PsTypedVariable: + return self._sparse_counter + + +def get_archetype_field( + fields: set[Field], + check_compatible_shapes: bool = True, + check_same_layouts: bool = True, + check_same_dimensions: bool = True, +): + shapes = set(f.spatial_shape for f in fields) + fixed_shapes = set(f.spatial_shape for f in fields if f.has_fixed_shape) + layouts = set(f.layout for f in fields) + dimensionalities = set(f.spatial_dimensions for f in fields) + + if check_same_dimensions and len(dimensionalities) != 1: + raise KernelConstraintsError( + "All fields must have the same number of spatial dimensions." + ) + + if check_same_layouts and len(layouts) != 1: + raise KernelConstraintsError("All fields must have the same memory layout.") + + if check_compatible_shapes: + if len(fixed_shapes) > 0: + if len(fixed_shapes) != len(shapes): + raise KernelConstraintsError( + "Cannot mix fixed- and variable-shape fields." + ) + if len(fixed_shapes) != 0: + raise KernelConstraintsError( + "Fixed-shape fields of different sizes encountered." + ) + + archetype_field = sorted(fields, key=lambda f: str(f))[0] + return archetype_field def create_sparse_iteration_space( ctx: KernelCreationContext, assignments: AssignmentCollection ) -> IterationSpace: - return NotImplemented + # All domain and custom fields must have the same spatial dimensions + # TODO: Must all domain fields have the same shape? + archetype_field = get_archetype_field( + ctx.fields.domain_fields | ctx.fields.custom_fields, + check_compatible_shapes=False, + check_same_layouts=False, + check_same_dimensions=True, + ) + + dim = archetype_field.spatial_dimensions + coord_members = [ + PsStructType.Member(name, ctx.index_dtype) + for name in Defaults._index_struct_coordinate_names[:dim] + ] + + # Determine index field + if ctx.options.index_field is not None: + idx_field = ctx.options.index_field + idx_arr = ctx.get_array(idx_field) + idx_struct_type: PsStructType = failing_cast(PsStructType, idx_arr.element_type) + + for coord in coord_members: + if coord not in idx_struct_type.members: + raise PsInputError( + f"Given index field does not provide required coordinate member {coord}" + ) + else: + # TODO: Find index field from the fields list + raise NotImplementedError( + "Automatic inference of index field for sparse iteration not supported yet." + ) + + spatial_counters = [ + PsTypedVariable(name, constify(ctx.index_dtype)) + for name in Defaults.spatial_counter_names[:dim] + ] + + sparse_counter = PsTypedVariable(Defaults.sparse_counter_name, ctx.index_dtype) + + return SparseIterationSpace(spatial_counters, idx_arr, coord_members, sparse_counter) def create_full_iteration_space( @@ -185,15 +275,12 @@ def create_full_iteration_space( # - We have no domain fields, but at least one custom field -> determine common field from custom fields # - We have neither domain nor custom fields -> Error - # TODO: Re-implement as `get_archetype_field`, check not only shape but also layout equality - # The archetype field must encompass all information about the iteration space: shape, extents, and loop order. - from ...transformations import get_common_field - if len(domain_field_accesses) > 0: - archetype_field = get_common_field(ctx.fields.domain_fields) + archetype_field = get_archetype_field(ctx.fields.domain_fields) inferred_gls = max([fa.required_ghost_layers for fa in domain_field_accesses]) elif len(ctx.fields.custom_fields) > 0: - archetype_field = get_common_field(ctx.fields.custom_fields) + # TODO: Warn about inferring iteration space from custom fields + archetype_field = get_archetype_field(ctx.fields.custom_fields) inferred_gls = 0 else: raise PsInputError( @@ -204,8 +291,6 @@ def create_full_iteration_space( # Otherwise, if an iteration slice was specified, use that # Otherwise, use the inferred ghost layers - from .iteration_space import FullIterationSpace - if ctx.options.ghost_layers is not None: return FullIterationSpace.create_with_ghost_layers( ctx, archetype_field, ctx.options.ghost_layers diff --git a/src/pystencils/nbackend/kernelcreation/kernelcreation.py b/src/pystencils/nbackend/kernelcreation/kernelcreation.py index c87ef2fe4..07614da9a 100644 --- a/src/pystencils/nbackend/kernelcreation/kernelcreation.py +++ b/src/pystencils/nbackend/kernelcreation/kernelcreation.py @@ -23,7 +23,7 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti ispace: IterationSpace = ( create_sparse_iteration_space(ctx, assignments) - if len(ctx.fields.index_fields) > 0 + if len(ctx.fields.index_fields) > 0 or ctx.options.index_field is not None else create_full_iteration_space(ctx, assignments) ) @@ -37,22 +37,22 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti match options.target: case Target.CPU: - from .platform import BasicCpu + from .platform import BasicCpuGen # TODO: CPU platform should incorporate instruction set info, OpenMP, etc. - platform = BasicCpu(ctx) + platform_generator = BasicCpuGen(ctx) case _: # TODO: CUDA/HIP platform # TODO: SYCL platform (?) raise NotImplementedError("Target platform not implemented") - kernel_ast = platform.apply_iteration_space(kernel_body, ispace) + kernel_ast = platform_generator.materialize_iteration_space(kernel_body, ispace) # 7. Apply optimizations # - Vectorization # - OpenMP # - Loop Splitting, Tiling, Blocking - kernel_ast = platform.optimize(kernel_ast) + kernel_ast = platform_generator.optimize(kernel_ast) function = PsKernelFunction(kernel_ast, options.target, name=options.function_name) function.add_constraints(*ctx.constraints) diff --git a/src/pystencils/nbackend/kernelcreation/options.py b/src/pystencils/nbackend/kernelcreation/options.py index 355050b7e..5f5028a94 100644 --- a/src/pystencils/nbackend/kernelcreation/options.py +++ b/src/pystencils/nbackend/kernelcreation/options.py @@ -2,6 +2,8 @@ from typing import Sequence from dataclasses import dataclass from ...enums import Target +from ...field import Field + from ..exceptions import PsOptionsError from ..types import PsIntegerType, PsNumericType, PsIeeeFloatType @@ -43,6 +45,12 @@ class KernelCreationOptions: TODO: Specification of valid slices and their behaviour """ + index_field: Field | None = None + """Index field for a sparse kernel. + + If this option is set, a sparse kernel with the given field as index field will be generated. + """ + """Data Types""" index_dtype: PsIntegerType = SpDefaults.index_dtype @@ -55,8 +63,13 @@ class KernelCreationOptions: """ def __post_init__(self): - if self.iteration_slice is not None and self.ghost_layers is not None: + if ( + int(self.iteration_slice is not None) + + int(self.ghost_layers is not None) + + int(self.index_field is not None) + > 1 + ): raise PsOptionsError( - "Parameters `iteration_slice` and `ghost_layers` are mutually exclusive; " + "Parameters `iteration_slice`, `ghost_layers` and 'index_field` are mutually exclusive; " "at most one of them may be set." ) diff --git a/src/pystencils/nbackend/kernelcreation/platform/__init__.py b/src/pystencils/nbackend/kernelcreation/platform/__init__.py index 20e2c0aae..85b5af9c0 100644 --- a/src/pystencils/nbackend/kernelcreation/platform/__init__.py +++ b/src/pystencils/nbackend/kernelcreation/platform/__init__.py @@ -1,5 +1,5 @@ -from .basic_cpu import BasicCpu +from .basic_cpu import BasicCpuGen __all__ = [ - 'BasicCpu' + 'BasicCpuGen' ] diff --git a/src/pystencils/nbackend/kernelcreation/platform/basic_cpu.py b/src/pystencils/nbackend/kernelcreation/platform/basic_cpu.py index bba06a8f9..347061f19 100644 --- a/src/pystencils/nbackend/kernelcreation/platform/basic_cpu.py +++ b/src/pystencils/nbackend/kernelcreation/platform/basic_cpu.py @@ -1,17 +1,26 @@ -from pystencils.nbackend.ast import PsBlock, PsLoop, PsSymbolExpr, PsExpression -from pystencils.nbackend.kernelcreation.iteration_space import ( +from .platform import PlatformGen + +from ..iteration_space import ( IterationSpace, FullIterationSpace, + SparseIterationSpace, ) -from .platform import Platform + +from ...ast import PsDeclaration, PsSymbolExpr, PsExpression, PsLoop, PsBlock +from ...typed_expressions import PsTypedConstant +from ...arrays import PsArrayAccess -class BasicCpu(Platform): - def apply_iteration_space(self, block: PsBlock, ispace: IterationSpace) -> PsBlock: +class BasicCpuGen(PlatformGen): + def materialize_iteration_space( + self, body: PsBlock, ispace: IterationSpace + ) -> PsBlock: if isinstance(ispace, FullIterationSpace): - return self._create_domain_loops(block, ispace) + return self._create_domain_loops(body, ispace) + elif isinstance(ispace, SparseIterationSpace): + return self._create_sparse_loop(body, ispace) else: - raise NotImplementedError("Iteration space not supported yet.") + assert False, "unreachable code" def optimize(self, kernel: PsBlock) -> PsBlock: return kernel @@ -35,3 +44,28 @@ class BasicCpu(Platform): outer_block = PsBlock([loop]) return outer_block + + def _create_sparse_loop(self, body: PsBlock, ispace: SparseIterationSpace): + mappings = [ + PsDeclaration( + PsSymbolExpr(ctr), + PsExpression( + PsArrayAccess( + ispace.index_list.base_pointer, ispace.sparse_counter + ).a.__getattr__(coord.name) + ), + ) + for ctr, coord in zip(ispace.spatial_indices, ispace.coordinate_members) + ] + + body = PsBlock(mappings + body.statements) + + loop = PsLoop( + PsSymbolExpr(ispace.sparse_counter), + PsExpression(PsTypedConstant(0, self._ctx.index_dtype)), + PsExpression(ispace.index_list.shape[0]), + PsExpression(PsTypedConstant(1, self._ctx.index_dtype)), + body, + ) + + return PsBlock([loop]) diff --git a/src/pystencils/nbackend/kernelcreation/platform/platform.py b/src/pystencils/nbackend/kernelcreation/platform/platform.py index d7aeec39e..17dfa23f6 100644 --- a/src/pystencils/nbackend/kernelcreation/platform/platform.py +++ b/src/pystencils/nbackend/kernelcreation/platform/platform.py @@ -6,7 +6,7 @@ from ..context import KernelCreationContext from ..iteration_space import IterationSpace -class Platform(ABC): +class PlatformGen(ABC): """Abstract base class for all supported platforms. The platform performs all target-dependent tasks during code generation: @@ -18,7 +18,7 @@ class Platform(ABC): self._ctx = ctx @abstractmethod - def apply_iteration_space(self, block: PsBlock, ispace: IterationSpace) -> PsBlock: + def materialize_iteration_space(self, block: PsBlock, ispace: IterationSpace) -> PsBlock: ... @abstractmethod diff --git a/src/pystencils/nbackend/kernelcreation/typification.py b/src/pystencils/nbackend/kernelcreation/typification.py index df26a6471..9bc9a462f 100644 --- a/src/pystencils/nbackend/kernelcreation/typification.py +++ b/src/pystencils/nbackend/kernelcreation/typification.py @@ -6,7 +6,7 @@ import pymbolic.primitives as pb from pymbolic.mapper import Mapper from .context import KernelCreationContext -from ..types import PsAbstractType, PsNumericType +from ..types import PsAbstractType, PsNumericType, deconstify from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant from ..arrays import PsArrayAccess from ..ast import PsAstNode, PsBlock, PsExpression, PsAssignment @@ -52,7 +52,6 @@ class Typifier(Mapper): new_lhs, lhs_dtype = self.rec(lhs.expression, None) new_rhs, rhs_dtype = self.rec(rhs.expression, lhs_dtype) if lhs_dtype != rhs_dtype: - # todo: (optional) automatic cast insertion? raise TypificationError( "Mismatched types in assignment: \n" f" {lhs} <- {rhs}\n" @@ -67,7 +66,13 @@ class Typifier(Mapper): return node - # def rec(self, expr: Any, target_type: PsNumericType | None) + """ + def rec(self, expr: Any, target_type: PsNumericType | None) + + All visitor methods take an expression and the target type of the current context. + They shall return the typified expression together with its type. + The returned type shall always be non-const, so make sure to call deconstify if necessary. + """ def typify_expression( self, expr: Any, target_type: PsNumericType | None = None @@ -80,7 +85,7 @@ class Typifier(Mapper): self, var: PsTypedVariable, target_type: PsNumericType | None ): self._check_target_type(var, var.dtype, target_type) - return var, var.dtype + return var, deconstify(var.dtype) def map_variable( self, var: pb.Variable, target_type: PsNumericType | None @@ -88,20 +93,20 @@ class Typifier(Mapper): dtype = self._ctx.options.default_dtype typed_var = PsTypedVariable(var.name, dtype) self._check_target_type(typed_var, dtype, target_type) - return typed_var, dtype + return typed_var, deconstify(dtype) def map_constant( self, value: Any, target_type: PsNumericType | None ) -> tuple[PsTypedConstant, PsNumericType]: if isinstance(value, PsTypedConstant): self._check_target_type(value, value.dtype, target_type) - return value, value.dtype + return value, deconstify(value.dtype) elif target_type is None: raise TypificationError( f"Unable to typify constant {value}: Unknown target type in this context." ) else: - return PsTypedConstant(value, target_type), target_type + return PsTypedConstant(value, target_type), deconstify(target_type) # Array Access @@ -110,7 +115,7 @@ class Typifier(Mapper): ) -> tuple[PsArrayAccess, PsNumericType]: self._check_target_type(access, access.dtype, target_type) index, _ = self.rec(access.index_tuple[0], self._ctx.options.index_dtype) - return PsArrayAccess(access.base_ptr, index), cast(PsNumericType, access.dtype) + return PsArrayAccess(access.base_ptr, index), cast(PsNumericType, deconstify(access.dtype)) # Arithmetic Expressions @@ -157,7 +162,7 @@ class Typifier(Mapper): expr_type: PsAbstractType, target_type: PsNumericType | None, ): - if target_type is not None and expr_type != target_type: + if target_type is not None and deconstify(expr_type) != deconstify(target_type): raise TypificationError( f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n" f" Expression type: {expr_type}\n" diff --git a/src/pystencils/nbackend/types/__init__.py b/src/pystencils/nbackend/types/__init__.py index c398aea9d..13deab6b4 100644 --- a/src/pystencils/nbackend/types/__init__.py +++ b/src/pystencils/nbackend/types/__init__.py @@ -1,6 +1,7 @@ from .basic_types import ( PsAbstractType, PsCustomType, + PsStructType, PsNumericType, PsScalarType, PsPointerType, @@ -19,6 +20,7 @@ from .exception import PsTypeError __all__ = [ "PsAbstractType", "PsCustomType", + "PsStructType", "PsPointerType", "PsNumericType", "PsScalarType", diff --git a/src/pystencils/nbackend/types/basic_types.py b/src/pystencils/nbackend/types/basic_types.py index e6b918080..540f28334 100644 --- a/src/pystencils/nbackend/types/basic_types.py +++ b/src/pystencils/nbackend/types/basic_types.py @@ -1,11 +1,13 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import final, TypeVar, Any +from typing import final, TypeVar, Any, Sequence +from dataclasses import dataclass from copy import copy import numpy as np from .exception import PsTypeError +from ..exceptions import PsInternalCompilerError class PsAbstractType(ABC): @@ -40,12 +42,20 @@ class PsAbstractType(ABC): def required_headers(self) -> set[str]: """The set of header files required when this type occurs in generated code.""" return set() - + @property def itemsize(self) -> int | None: """If this type has a valid in-memory size, return that size.""" return None + @property + def numpy_dtype(self) -> np.dtype | None: + """A np.dtype object representing this data type. + + Available both for backward compatibility and for interaction with the numpy-based runtime system. + """ + return None + # ------------------------------------------------------------------------------------------- # Internal virtual operations # ------------------------------------------------------------------------------------------- @@ -142,6 +152,87 @@ class PsPointerType(PsAbstractType): return f"PsPointerType( {repr(self.base_type)}, const={self.const} )" +class PsStructType(PsAbstractType): + """Class to model structured data types. + + A struct type is defined by its sequence of members. + The struct may optionally have a name, although the code generator currently does not support named structs + and treats them the same way as anonymous structs. + """ + + @dataclass(frozen=True) + class Member: + name: str + dtype: PsAbstractType + + def __init__( + self, + members: Sequence[PsStructType.Member | tuple[str, PsAbstractType]], + name: str | None = None, + const: bool = False, + ): + super().__init__(const=const) + + self._name = name + self._members = tuple( + (PsStructType.Member(m[0], m[1]) if isinstance(m, tuple) else m) + for m in members + ) + + names: set[str] = set() + for member in self._members: + if member.name in names: + raise ValueError(f"Duplicate struct member name: {member.name}") + names.add(member.name) + + @property + def members(self) -> tuple[PsStructType.Member, ...]: + return self._members + + @property + def name(self) -> str: + if self._name is None: + raise PsInternalCompilerError( + "Cannot retrieve name from anonymous struct type" + ) + return self._name + + @property + def anonymous(self) -> bool: + return self._name is None + + @property + def numpy_dtype(self) -> np.dtype | None: + members = [(m.name, m.dtype.numpy_dtype) for m in self._members] + return np.dtype(members) + + def _c_string(self) -> str: + if self._name is None: + # raise PsInternalCompilerError( + # "Cannot retrieve C string for anonymous struct type" + # ) + return "<anonymous>" + return self._name + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsStructType): + return False + + return ( + self._base_equal(other) + and self._name == other._name + and self._members == other._members + ) + + def __hash__(self) -> int: + return hash(("PsStructTupe", self._name, self._members, self._const)) + + def __repr__(self) -> str: + members = ", ".join(f"{m.dtype} {m.name}" for m in self._members) + name = "<anonymous>" if self.anonymous else f"name={self._name}" + return f"PsStructType( [{members}], {name}, const={self.const} )" + + class PsNumericType(PsAbstractType, ABC): """Class to model numeric types, which are all types that may occur at the top-level inside arithmetic-logical expressions. @@ -244,6 +335,10 @@ class PsIntegerType(PsScalarType, ABC): def itemsize(self) -> int: return self.width // 8 + @property + def numpy_dtype(self) -> np.dtype | None: + return np.dtype(self.NUMPY_TYPES[self._width]) + def create_literal(self, value: Any) -> str: np_dtype = self.NUMPY_TYPES[self._width] if not isinstance(value, np_dtype): @@ -360,6 +455,10 @@ class PsIeeeFloatType(PsScalarType): def itemsize(self) -> int: return self.width // 8 + @property + def numpy_dtype(self) -> np.dtype | None: + return np.dtype(self.NUMPY_TYPES[self._width]) + @property def required_headers(self) -> set[str]: if self._width == 16: @@ -373,10 +472,14 @@ class PsIeeeFloatType(PsScalarType): raise PsTypeError(f"Given value {value} is not of required type {np_dtype}") match self.width: - case 16: return f"((half) {value})" # see include/half_precision.h - case 32: return f"{value}f" - case 64: return str(value) - case _: assert False, "unreachable code" + case 16: + return f"((half) {value})" # see include/half_precision.h + case 32: + return f"{value}f" + case 64: + return str(value) + case _: + assert False, "unreachable code" def create_constant(self, value: Any) -> Any: np_type = self.NUMPY_TYPES[self._width] diff --git a/src/pystencils/nbackend/types/parsing.py b/src/pystencils/nbackend/types/parsing.py index 14db20a92..952438f11 100644 --- a/src/pystencils/nbackend/types/parsing.py +++ b/src/pystencils/nbackend/types/parsing.py @@ -3,6 +3,7 @@ import numpy as np from .basic_types import ( PsAbstractType, PsPointerType, + PsStructType, PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType, @@ -41,6 +42,22 @@ def interpret_python_type(t: type) -> PsAbstractType: raise ValueError(f"Could not interpret Python data type {t} as a pystencils type.") +def interpret_numpy_dtype(t: np.dtype) -> PsAbstractType: + if t.fields is not None: + # it's a struct + members = [] + for fname, fspec in t.fields.items(): + members.append(PsStructType.Member(fname, interpret_numpy_dtype(fspec[0]))) + return PsStructType(members) + else: + try: + return interpret_python_type(t.type) + except ValueError: + raise ValueError( + f"Could not interpret numpy dtype object {t} as a pystencils type." + ) + + def parse_type_string(s: str) -> PsAbstractType: tokens = s.rsplit("*", 1) match tokens: diff --git a/src/pystencils/nbackend/types/quick.py b/src/pystencils/nbackend/types/quick.py index b1da0c5e2..cf65897d7 100644 --- a/src/pystencils/nbackend/types/quick.py +++ b/src/pystencils/nbackend/types/quick.py @@ -6,6 +6,8 @@ This module is meant to be included whole, e.g. as `from pystencils.nbackend.typ from __future__ import annotations +import numpy as np + from .basic_types import ( PsAbstractType, PsCustomType, @@ -17,7 +19,7 @@ from .basic_types import ( PsIeeeFloatType, ) -UserTypeSpec = str | type | PsAbstractType +UserTypeSpec = str | type | np.dtype | PsAbstractType def make_type(type_spec: UserTypeSpec) -> PsAbstractType: @@ -33,12 +35,14 @@ def make_type(type_spec: UserTypeSpec) -> PsAbstractType: - No others are supported at the moment - Supported Numpy scalar data types (see https://numpy.org/doc/stable/reference/arrays.scalars.html) are converted to pystencils scalar data types + - Instances of `np.dtype`: Attempt to interpret scalar types like above, and structured types as structs. - Instances of `PsAbstractType` will be returned as they are """ from .parsing import ( parse_type_string, interpret_python_type, + interpret_numpy_dtype ) if isinstance(type_spec, PsAbstractType): @@ -47,6 +51,8 @@ def make_type(type_spec: UserTypeSpec) -> PsAbstractType: return parse_type_string(type_spec) if isinstance(type_spec, type): return interpret_python_type(type_spec) + if isinstance(type_spec, np.dtype): + return interpret_numpy_dtype(type_spec) raise ValueError(f"{type_spec} is not a valid type specification.") -- GitLab