From 314aea65ea2d7ccfb3f7837c96d4783c2740a239 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 5 Mar 2024 14:35:38 +0100 Subject: [PATCH] remove old type system, replace by pystencils.types --- .../backend/kernelcreation/context.py | 9 +- .../backend/kernelcreation/freeze.py | 16 +- .../backend/kernelcreation/typification.py | 2 +- .../erase_anonymous_structs.py | 2 +- src/pystencils/field.py | 16 +- src/pystencils/functions.py | 4 +- src/pystencils/runhelper/db.py | 4 +- src/pystencils/spatial_coordinates.py | 8 +- src/pystencils/sympyextensions/math.py | 7 +- src/pystencils/sympyextensions/typed_sympy.py | 553 +++--------------- src/pystencils/types/__init__.py | 8 +- src/pystencils/types/basic_types.py | 55 +- src/pystencils/types/quick.py | 10 +- .../kernelcreation/test_typification.py | 4 +- tests/nbackend/types/test_types.py | 22 +- 15 files changed, 176 insertions(+), 544 deletions(-) diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 9ebc913d5..ffa400a16 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -6,12 +6,11 @@ from types import EllipsisType from ...defaults import DEFAULTS from ...field import Field, FieldType -from ...sympyextensions.typed_sympy import TypedSymbol, BasicType, StructType +from ...sympyextensions.typed_sympy import TypedSymbol from ..symbols import PsSymbol from ..arrays import PsLinearizedArray -from ...types import PsAbstractType, PsIntegerType, PsNumericType -from ...types.quick import make_type +from ...types import PsAbstractType, PsIntegerType, PsNumericType, PsScalarType, PsStructType from ..constraints import PsKernelParamsConstraint from ..exceptions import PsInternalCompilerError, KernelConstraintsError @@ -216,8 +215,8 @@ class KernelCreationContext: # Add array assert arr_strides is not None - assert isinstance(field.dtype, (BasicType, StructType)) - element_type = make_type(field.dtype.numpy_dtype) + assert isinstance(field.dtype, (PsScalarType, PsStructType)) + element_type = field.dtype arr = PsLinearizedArray( field.name, element_type, arr_shape, arr_strides, self.index_dtype diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index d0bca50a0..2d10b0c0c 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -5,7 +5,7 @@ from operator import add, mul import sympy as sp from ...sympyextensions import Assignment, AssignmentCollection -from ...sympyextensions.typed_sympy import BasicType +from ...sympyextensions.typed_sympy import TypedSymbol from ...field import Field, FieldType from .context import KernelCreationContext @@ -27,7 +27,7 @@ from ..ast.expressions import ( ) from ..constants import PsConstant -from ...types import constify, make_type, PsAbstractType, PsStructType +from ...types import PsStructType from ..exceptions import PsInputError from ..functions import PsMathFunction, MathFunctions @@ -118,16 +118,8 @@ class FreezeExpressions: denom = PsConstantExpr(PsConstant(expr.denominator)) return num / denom - def map_type(self, expr: BasicType) -> PsAbstractType: - # TODO: This should not be necessary; the frontend should use the new type system. - dtype = make_type(expr.numpy_dtype.type) - if expr.const: - return constify(dtype) - else: - return dtype - - def map_TypedSymbol(self, expr): - dtype = self.map_type(expr.dtype) + def map_TypedSymbol(self, expr: TypedSymbol): + dtype = expr.dtype symb = self._ctx.get_symbol(expr.name, dtype) return PsSymbolExpr(symb) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 5c041fb2a..2f6b5b33b 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -197,7 +197,7 @@ class Typifier: "Aggregate type of lookup was not a struct type." ) - member = aggr_type.get_member(member_name) + member = aggr_type.find_member(member_name) if member is None: raise TypificationError( f"Aggregate of type {aggr_type} does not have a member {member}." diff --git a/src/pystencils/backend/transformations/erase_anonymous_structs.py b/src/pystencils/backend/transformations/erase_anonymous_structs.py index ecca07aea..c946ae7bb 100644 --- a/src/pystencils/backend/transformations/erase_anonymous_structs.py +++ b/src/pystencils/backend/transformations/erase_anonymous_structs.py @@ -85,7 +85,7 @@ class EraseAnonymousStructTypes: ) member_name = lookup.member_name - member = struct_type.get_member(member_name) + member = struct_type.find_member(member_name) assert member is not None np_struct = struct_type.numpy_dtype diff --git a/src/pystencils/field.py b/src/pystencils/field.py index 389413f25..253c94fc0 100644 --- a/src/pystencils/field.py +++ b/src/pystencils/field.py @@ -14,8 +14,8 @@ from sympy.core.cache import cacheit from pystencils.alignedarray import aligned_empty from pystencils.spatial_coordinates import x_staggered_vector, x_vector from pystencils.stencil import direction_string_to_offset, inverse_direction, offset_to_direction_string -from pystencils.sympyextensions.typed_sympy import (FieldShapeSymbol, FieldStrideSymbol, StructType, - TypedSymbol, BasicType, create_type) +from pystencils.types import PsAbstractType, PsStructType, create_type +from pystencils.sympyextensions.typed_sympy import (FieldShapeSymbol, FieldStrideSymbol, TypedSymbol) from pystencils.sympyextensions.math import is_integer_sequence @@ -315,12 +315,12 @@ class Field: return self.strides[self.spatial_dimensions:] @property - def dtype(self): + def dtype(self) -> PsAbstractType: return self._dtype @property def itemsize(self): - return self.dtype.numpy_dtype.itemsize + return self.dtype.itemsize def __repr__(self): if any(isinstance(s, sp.Symbol) for s in self.spatial_shape): @@ -592,7 +592,7 @@ class Field: else: idx_str = ",".join([str(e) for e in idx]) superscript = idx_str - if field.has_fixed_index_shape and not isinstance(field.dtype, StructType): + if field.has_fixed_index_shape and not isinstance(field.dtype, PsStructType): for i, bound in zip(idx, field.index_shape): if i >= bound: raise ValueError("Field index out of bounds") @@ -604,7 +604,7 @@ class Field: if superscript is not None: symbol_name += "^" + superscript - if dtype: + if dtype is not None: obj = super(Field.Access, self).__xnew__(self, symbol_name, dtype) else: obj = super(Field.Access, self).__xnew__(self, symbol_name, field.dtype) @@ -652,7 +652,9 @@ class Field: if len(idx) != self.field.index_dimensions: raise ValueError(f"Wrong number of indices: Got {len(idx)}, expected {self.field.index_dimensions}") if len(idx) == 1 and isinstance(idx[0], str): - dtype = BasicType(self.field.dtype.numpy_dtype[idx[0]]) + struct_type = self.field.dtype + assert isinstance(struct_type, PsStructType) + dtype = struct_type.get_member(idx[0]).dtype return Field.Access(self.field, self._offsets, idx, is_absolute_access=self.is_absolute_access, dtype=dtype) else: diff --git a/src/pystencils/functions.py b/src/pystencils/functions.py index 550a2724f..df8d0ef6f 100644 --- a/src/pystencils/functions.py +++ b/src/pystencils/functions.py @@ -1,5 +1,5 @@ import sympy as sp -from .sympyextensions.typed_sympy import PointerType +from .types import PsPointerType class DivFunc(sp.Function): @@ -52,6 +52,6 @@ class AddressOf(sp.Function): @property def dtype(self): if hasattr(self.args[0], 'dtype'): - return PointerType(self.args[0].dtype, restrict=True) + return PsPointerType(self.args[0].dtype, const=True, restrict=True) else: raise ValueError(f'pystencils supports only non void pointers. Current address_of type: {self.args[0]}') diff --git a/src/pystencils/runhelper/db.py b/src/pystencils/runhelper/db.py index 1c8d3aa66..53d65b2ea 100644 --- a/src/pystencils/runhelper/db.py +++ b/src/pystencils/runhelper/db.py @@ -14,7 +14,7 @@ from pystencils import CreateKernelConfig, Target, Backend, Field import json import sympy as sp -from pystencils.typing import BasicType +from pystencils.types import PsAbstractType class PystencilsJsonEncoder(JsonEncoder): @@ -26,7 +26,7 @@ class PystencilsJsonEncoder(JsonEncoder): return float(obj) if isinstance(obj, sp.Integer): return int(obj) - if isinstance(obj, (BasicType, MappingProxyType)): + if isinstance(obj, (PsAbstractType, MappingProxyType)): return str(obj) if isinstance(obj, (Target, Backend, sp.Symbol)): return obj.name diff --git a/src/pystencils/spatial_coordinates.py b/src/pystencils/spatial_coordinates.py index cc244b11c..794bb713f 100644 --- a/src/pystencils/spatial_coordinates.py +++ b/src/pystencils/spatial_coordinates.py @@ -1,14 +1,14 @@ import sympy -from pystencils.sympyextensions.typed_sympy import get_loop_counter_symbol +from .defaults import DEFAULTS -x_, y_, z_ = tuple(get_loop_counter_symbol(i) for i in range(3)) +x_, y_, z_ = DEFAULTS.spatial_counters x_staggered, y_staggered, z_staggered = x_ + 0.5, y_ + 0.5, z_ + 0.5 def x_vector(ndim): - return sympy.Matrix(tuple(get_loop_counter_symbol(i) for i in range(ndim))) + return sympy.Matrix(DEFAULTS.spatial_counters) def x_staggered_vector(ndim): - return sympy.Matrix(tuple(get_loop_counter_symbol(i) + 0.5 for i in range(ndim))) + return sympy.Matrix(tuple(DEFAULTS.spatial_counters[i] + 0.5 for i in range(ndim))) diff --git a/src/pystencils/sympyextensions/math.py b/src/pystencils/sympyextensions/math.py index 00974809f..3e7eeaabc 100644 --- a/src/pystencils/sympyextensions/math.py +++ b/src/pystencils/sympyextensions/math.py @@ -12,7 +12,8 @@ from sympy.core.numbers import Zero from .astnodes import Assignment from pystencils.functions import DivFunc -from .typed_sympy import CastFunc, PointerType, VectorType, FieldPointerSymbol +from .typed_sympy import CastFunc, FieldPointerSymbol +from ..types import PsPointerType, PsVectorType T = TypeVar('T') @@ -572,9 +573,9 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]], base_type = get_type_of_expression(e) except ValueError: return False - if isinstance(base_type, VectorType): + if isinstance(base_type, PsVectorType): return False - if isinstance(base_type, PointerType): + if isinstance(base_type, PsPointerType): return only_type == 'int' if only_type == 'int' and (base_type.is_int() or base_type.is_uint()): return True diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index 1f20624ce..c437afd92 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -1,429 +1,37 @@ -from abc import abstractmethod -from itertools import groupby -from typing import Sequence, Union - -import numpy as np import sympy as sp +from ..types import PsAbstractType, PsNumericType, PsPointerType, PsBoolType +from ..types.quick import create_type -class AbstractType(sp.Atom): - # TODO: Is it necessary to ineherit from sp.Atom? - def __new__(cls, *args, **kwargs): - return sp.Basic.__new__(cls) - - def _sympystr(self, *args, **kwargs): - return str(self) - - @property - @abstractmethod - def base_type(self) -> Union[None, 'BasicType']: - """ - Returns: Returns BasicType of a Vector or Pointer type, None otherwise - """ - pass - - @property - @abstractmethod - def item_size(self) -> int: - """ - Returns: Number of items. - E.g. width * item_size(basic_type) in vector's case, or simple numpy itemsize in Struct's case. - """ - pass - -def is_supported_type(dtype: np.dtype): - scalar = dtype.type - c = np.issctype(dtype) - subclass = issubclass(scalar, np.floating) or issubclass(scalar, np.integer) or issubclass(scalar, np.bool_) - additional_checks = dtype.fields is None and dtype.hasobject is False and dtype.subdtype is None - return c and subclass and additional_checks - - -def numpy_name_to_c(name: str) -> str: - """ - Converts a np.dtype.name into a C type - Args: - name: np.dtype.name string - Returns: - type as a C string - """ - if name == 'float64': - return 'double' - elif name == 'float32': - return 'float' - elif name == 'float16' or name == 'half': - return 'half' - elif name.startswith('int'): - width = int(name[len("int"):]) - return f"int{width}_t" - elif name.startswith('uint'): - width = int(name[len("uint"):]) - return f"uint{width}_t" - elif name == 'bool': - return 'bool' - else: - raise NotImplementedError(f"Can't map numpy to C name for {name}") - - -def create_type(specification: Union[type, AbstractType, str]) -> AbstractType: - # TODO: Deprecated Use the constructor of BasicType or StructType instead - """Creates a subclass of Type according to a string or an object of subclass Type. +def assumptions_from_dtype(dtype: PsAbstractType): + """Derives SymPy assumptions from :class:`PsAbstractType` Args: - specification: Type object, or a string - - Returns: - Type object, or a new Type object parsed from the string - """ - if isinstance(specification, AbstractType): - return specification - else: - numpy_dtype = np.dtype(specification) - if numpy_dtype.fields is None: - return BasicType(numpy_dtype, const=False) - else: - return StructType(numpy_dtype, const=False) - - -def get_base_type(data_type): - """ - Returns the BasicType of a Pointer or a Vector - """ - while data_type.base_type is not None: - data_type = data_type.base_type - return data_type - - -class BasicType(AbstractType): - """ - BasicType is defined with a const qualifier and a np.dtype. - """ - - def __init__(self, dtype: Union[type, 'BasicType', str], const: bool = False): - if isinstance(dtype, BasicType): - self.numpy_dtype = dtype.numpy_dtype - self.const = dtype.const - else: - self.numpy_dtype = np.dtype(dtype) - self.const = const - assert is_supported_type(self.numpy_dtype), f'Type {self.numpy_dtype} is currently not supported!' - - def __getnewargs__(self): - return self.numpy_dtype, self.const - - def __getnewargs_ex__(self): - return (self.numpy_dtype, self.const), {} - - @property - def base_type(self): - return None - - @property - def item_size(self): # TODO: Do we want self.numpy_type.itemsize???? - return 1 - - def is_float(self): - return issubclass(self.numpy_dtype.type, np.floating) - - def is_half(self): - return issubclass(self.numpy_dtype.type, np.half) - - def is_int(self): - return issubclass(self.numpy_dtype.type, np.integer) - - def is_uint(self): - return issubclass(self.numpy_dtype.type, np.unsignedinteger) - - def is_sint(self): - return issubclass(self.numpy_dtype.type, np.signedinteger) - - def is_bool(self): - return issubclass(self.numpy_dtype.type, np.bool_) - - def dtype_eq(self, other): - if not isinstance(other, BasicType): - return False - else: - return self.numpy_dtype == other.numpy_dtype - - @property - def c_name(self) -> str: - return numpy_name_to_c(self.numpy_dtype.name) - - def __str__(self): - return f'{self.c_name}{" const" if self.const else ""}' - - def __repr__(self): - return f'BasicType( {str(self)} )' - - def _repr_html_(self): - return f'BasicType( {str(self)} )' - - def __eq__(self, other): - return self.dtype_eq(other) and self.const == other.const - - def __hash__(self): - return hash(str(self)) - - -class VectorType(AbstractType): - """ - VectorType consists of a BasicType and a width. - """ - instruction_set = None - - def __init__(self, base_type: BasicType, width: int): - self._base_type = base_type - self.width = width - - @property - def base_type(self): - return self._base_type - - @property - def item_size(self): - return self.width * self.base_type.item_size - - def __eq__(self, other): - if not isinstance(other, VectorType): - return False - else: - return (self.base_type, self.width) == (other.base_type, other.width) - - def __str__(self): - if self.instruction_set is None: - return f"{self.base_type}[{self.width}]" - else: - # TODO VectorizationRevamp: this seems super weird. the instruction_set should know how to print a type out! - # TODO VectorizationRevamp: this is error prone. base_type could be cons=True. Use dtype instead - if self.base_type == create_type("int64") or self.base_type == create_type("int32"): - return self.instruction_set['int'] - elif self.base_type == create_type("float64"): - return self.instruction_set['double'] - elif self.base_type == create_type("float32"): - return self.instruction_set['float'] - elif self.base_type == create_type("bool"): - return self.instruction_set['bool'] - else: - raise NotImplementedError() - - def __hash__(self): - return hash((self.base_type, self.width)) - - def __getnewargs__(self): - return self._base_type, self.width - - def __getnewargs_ex__(self): - return (self._base_type, self.width), {} - - -class PointerType(AbstractType): - def __init__(self, base_type: BasicType, const: bool = False, restrict: bool = True, double_pointer: bool = False): - self._base_type = base_type - self.const = const - self.restrict = restrict - self.double_pointer = double_pointer - - def __getnewargs__(self): - return self.base_type, self.const, self.restrict, self.double_pointer - - def __getnewargs_ex__(self): - return (self.base_type, self.const, self.restrict, self.double_pointer), {} - - @property - def alias(self): - return not self.restrict - - @property - def base_type(self): - return self._base_type - - @property - def item_size(self): - if self.double_pointer: - raise NotImplementedError("The item_size for double_pointer is not implemented") - else: - return self.base_type.item_size - - def __eq__(self, other): - if not isinstance(other, PointerType): - return False - else: - own = (self.base_type, self.const, self.restrict, self.double_pointer) - return own == (other.base_type, other.const, other.restrict, other.double_pointer) - - def __str__(self): - restrict_str = "RESTRICT" if self.restrict else "" - const_str = "const" if self.const else "" - if self.double_pointer: - return f'{str(self.base_type)} ** {restrict_str} {const_str}' - else: - return f'{str(self.base_type)} * {restrict_str} {const_str}' - - def __repr__(self): - return str(self) - - def _repr_html_(self): - return str(self) - - def __hash__(self): - return hash((self._base_type, self.const, self.restrict, self.double_pointer)) - - -class StructType(AbstractType): - """ - A list of types (with C offsets). - It is implemented with uint8_t and casts to the correct datatype. - """ - def __init__(self, numpy_type, const=False): - self.const = const - self._dtype = np.dtype(numpy_type) - - def __getnewargs__(self): - return self.numpy_dtype, self.const - - def __getnewargs_ex__(self): - return (self.numpy_dtype, self.const), {} - - @property - def base_type(self): - return None - - @property - def numpy_dtype(self): - return self._dtype - - @property - def item_size(self): - return self.numpy_dtype.itemsize - - def get_element_offset(self, element_name): - return self.numpy_dtype.fields[element_name][1] - - def get_element_type(self, element_name): - np_element_type = self.numpy_dtype.fields[element_name][0] - return BasicType(np_element_type, self.const) - - def has_element(self, element_name): - return element_name in self.numpy_dtype.fields - - def __eq__(self, other): - if not isinstance(other, StructType): - return False - else: - return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) - - def __str__(self): - # structs are handled byte-wise - result = "uint8_t" - if self.const: - result += " const" - return result - - def __repr__(self): - return str(self) - - def _repr_html_(self): - return str(self) - - def __hash__(self): - return hash((self.numpy_dtype, self.const)) - - -def result_type(*args: np.dtype): - """Returns the type of the result if the np.dtype arguments would be collated. - We can't use numpy functionality, because numpy casts don't behave exactly like C casts""" - s = sorted(args, key=lambda x: x.itemsize) - - def kind_to_value(kind: str) -> int: - if kind == 'f': - return 3 - elif kind == 'i': - return 2 - elif kind == 'u': - return 1 - elif kind == 'b': - return 0 - else: - raise NotImplementedError(f'{kind=} is not a supported kind of a type. See "numpy.dtype.kind" for options') - s = sorted(s, key=lambda x: kind_to_value(x.kind)) - return s[-1] - - -def all_equal(iterable): - """ - Returns ``True`` if all the elements are equal to each other. - Copied from: more-itertools 8.12.0 - """ - g = groupby(iterable) - return next(g, True) and not next(g, False) - - -def collate_types(types: Sequence[Union[BasicType, VectorType]]): - """ - Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double - Uses the collation rules from numpy. - """ - # Pointer arithmetic case i.e. pointer + [int, uint] is allowed - if any(isinstance(t, PointerType) for t in types): - pointer_type = None - for t in types: - if isinstance(t, PointerType): - if pointer_type is not None: - raise ValueError(f'Cannot collate the combination of two pointer types "{pointer_type}" and "{t}"') - pointer_type = t - elif isinstance(t, BasicType): - if not (t.is_int() or t.is_uint()): - raise ValueError("Invalid pointer arithmetic") - else: - raise ValueError("Invalid pointer arithmetic") - return pointer_type - - # # peel of vector types, if at least one vector type occurred the result will also be the vector type - vector_type = [t for t in types if isinstance(t, VectorType)] - if not all_equal(t.width for t in vector_type): - raise ValueError("Collation failed because of vector types with different width") - - types = [t.base_type if isinstance(t, VectorType) else t for t in types] - - # now we should have a list of basic types - struct types are not yet supported - assert all(type(t) is BasicType for t in types) - - result_numpy_type = result_type(*(t.numpy_dtype for t in types)) - result = BasicType(result_numpy_type) - if vector_type: - result = VectorType(result, vector_type[0].width) - return result - - -def assumptions_from_dtype(dtype: Union[BasicType, np.dtype]): - """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype - - Args: - dtype (BasicType, np.dtype): a Numpy data type + dtype (PsAbstractType): a pystencils data type Returns: A dict of SymPy assumptions """ - if hasattr(dtype, 'numpy_dtype'): - dtype = dtype.numpy_dtype - assumptions = dict() - try: - if np.issubdtype(dtype, np.integer): - assumptions.update({'integer': True}) + if isinstance(dtype, PsNumericType): + if dtype.is_int: + assumptions.update({"integer": True}) + if dtype.is_uint: + assumptions.update({"negative": False}) + if dtype.is_int or dtype.is_float: + assumptions.update({"real": True}) - if np.issubdtype(dtype, np.unsignedinteger): - assumptions.update({'negative': False}) + return assumptions - if np.issubdtype(dtype, np.integer) or \ - np.issubdtype(dtype, np.floating): - assumptions.update({'real': True}) - except Exception: # TODO this is dirty - pass - return assumptions +def is_loop_counter_symbol(symbol): + from ..defaults import DEFAULTS + + try: + return DEFAULTS.spatial_counters.index(symbol) + except ValueError: + return None class TypedSymbol(sp.Symbol): @@ -431,27 +39,30 @@ class TypedSymbol(sp.Symbol): obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds) return obj - def __new_stage2__(cls, name, dtype, **kwargs): # TODO does not match signature of sp.Symbol??? + def __new_stage2__( + cls, name, dtype, **kwargs + ): # TODO does not match signature of sp.Symbol??? # TODO: also Symbol should be allowed ---> see sympy Variable + dtype = create_type(dtype) assumptions = assumptions_from_dtype(dtype) assumptions.update(kwargs) + obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions) - try: - obj.numpy_dtype = create_type(dtype) - except (TypeError, ValueError): - # on error keep the string - obj.numpy_dtype = dtype + obj._dtype = create_type(dtype) + return obj __xnew__ = staticmethod(__new_stage2__) __xnew_cached_ = staticmethod(sp.core.cacheit(__new_stage2__)) @property - def dtype(self): - return self.numpy_dtype + def dtype(self) -> PsAbstractType: + # mypy: ignore + return self._dtype def _hashable_content(self): - return super()._hashable_content(), hash(self.numpy_dtype) + # mypy: ignore + return super()._hashable_content(), hash(self._dtype) def __getnewargs__(self): return self.name, self.dtype @@ -468,38 +79,24 @@ class TypedSymbol(sp.Symbol): return self @property - def headers(self): - headers = [] - try: - if np.issubdtype(self.dtype.numpy_dtype, np.complexfloating): - headers.append('"cuda_complex.hpp"') - except Exception: - pass - try: - if np.issubdtype(self.dtype.base_type.numpy_dtype, np.complexfloating): - headers.append('"cuda_complex.hpp"') - except Exception: - pass - - return headers - - -SHAPE_DTYPE = BasicType('int64', const=True) -STRIDE_DTYPE = BasicType('int64', const=True) -LOOP_COUNTER_DTYPE = BasicType('int64', const=True) -LOOP_COUNTER_NAME_PREFIX = "ctr" -BLOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr" + def headers(self) -> set[str]: + return self.dtype.required_headers class FieldStrideSymbol(TypedSymbol): """Sympy symbol representing the stride value of a field in a specific coordinate.""" + def __new__(cls, *args, **kwds): obj = FieldStrideSymbol.__xnew_cached_(cls, *args, **kwds) return obj def __new_stage2__(cls, field_name, coordinate): + from ..defaults import DEFAULTS + name = f"_stride_{field_name}_{coordinate}" - obj = super(FieldStrideSymbol, cls).__xnew__(cls, name, STRIDE_DTYPE, positive=True) + obj = super(FieldStrideSymbol, cls).__xnew__( + cls, name, DEFAULTS.index_dtype, positive=True + ) obj.field_name = field_name obj.coordinate = coordinate return obj @@ -519,15 +116,21 @@ class FieldStrideSymbol(TypedSymbol): class FieldShapeSymbol(TypedSymbol): """Sympy symbol representing the shape value of a sequence of fields. In a kernel iterating over multiple fields - there is only one set of `FieldShapeSymbol`s since all the fields have to be of equal size.""" + there is only one set of `FieldShapeSymbol`s since all the fields have to be of equal size. + """ + def __new__(cls, *args, **kwds): obj = FieldShapeSymbol.__xnew_cached_(cls, *args, **kwds) return obj def __new_stage2__(cls, field_names, coordinate): + from ..defaults import DEFAULTS + names = "_".join([field_name for field_name in field_names]) name = f"_size_{names}_{coordinate}" - obj = super(FieldShapeSymbol, cls).__xnew__(cls, name, SHAPE_DTYPE, positive=True) + obj = super(FieldShapeSymbol, cls).__xnew__( + cls, name, DEFAULTS.index_dtype, positive=True + ) obj.field_names = tuple(field_names) obj.coordinate = coordinate return obj @@ -547,13 +150,14 @@ class FieldShapeSymbol(TypedSymbol): class FieldPointerSymbol(TypedSymbol): """Sympy symbol representing the pointer to the beginning of the field data.""" + def __new__(cls, *args, **kwds): obj = FieldPointerSymbol.__xnew_cached_(cls, *args, **kwds) return obj - def __new_stage2__(cls, field_name, field_dtype, const): + def __new_stage2__(cls, field_name, field_dtype: PsAbstractType, const: bool): name = f"_data_{field_name}" - dtype = PointerType(get_base_type(field_dtype), const=const, restrict=True) + dtype = PsPointerType(field_dtype, const=const, restrict=True) obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype) obj.field_name = field_name return obj @@ -571,30 +175,12 @@ class FieldPointerSymbol(TypedSymbol): __xnew_cached_ = staticmethod(sp.core.cacheit(__new_stage2__)) -def get_loop_counter_symbol(coordinate_to_loop_over): - return TypedSymbol(f"{LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}", - LOOP_COUNTER_DTYPE, nonnegative=True) - - -def get_block_loop_counter_symbol(coordinate_to_loop_over): - return TypedSymbol(f"{BLOCK_LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}", - LOOP_COUNTER_DTYPE, nonnegative=True) - - -def is_loop_counter_symbol(symbol): - if not symbol.name.startswith(LOOP_COUNTER_NAME_PREFIX): - return None - if symbol.dtype != LOOP_COUNTER_DTYPE: - return None - coordinate = int(symbol.name[len(LOOP_COUNTER_NAME_PREFIX) + 1:]) - return coordinate - - class CastFunc(sp.Function): """ CastFunc is used in order to introduce static casts. They are especially useful as a way to signal what type a certain node should have, if it is impossible to add a type to a node, e.g. a sp.Number. """ + is_Atom = True def __new__(cls, *args, **kwargs): @@ -606,8 +192,7 @@ class CastFunc(sp.Function): # This optimisation is only available for simple casts. Thus the == is intended here! if expr.__class__ == CastFunc: expr = expr.args[0] - if not isinstance(dtype, AbstractType): - dtype = BasicType(dtype) + dtype = create_type(dtype) # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads # to problems when for example comparing cast_func's for equality @@ -616,15 +201,16 @@ class CastFunc(sp.Function): # rhs = cast_func(0, 'int') # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans # -> thus a separate class boolean_cast_func is introduced - if (isinstance(expr, sp.logic.boolalg.Boolean) and - (not isinstance(expr, TypedSymbol) or expr.dtype == BasicType('bool'))): + if isinstance(expr, sp.logic.boolalg.Boolean) and ( + not isinstance(expr, TypedSymbol) or isinstance(expr.dtype, PsBoolType) + ): cls = BooleanCastFunc return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs) @property def canonical(self): - if hasattr(self.args[0], 'canonical'): + if hasattr(self.args[0], "canonical"): return self.args[0].canonical else: raise NotImplementedError() @@ -643,13 +229,8 @@ class CastFunc(sp.Function): @property def is_integer(self): - """ - Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate - - For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html - """ - if hasattr(self.dtype, 'numpy_dtype'): - return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer + if isinstance(self.dtype, PsNumericType): + return self.dtype.is_int() or super().is_integer else: return super().is_integer @@ -658,8 +239,8 @@ class CastFunc(sp.Function): """ See :func:`.TypedSymbol.is_integer` """ - if hasattr(self.dtype, 'numpy_dtype'): - if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger): + if isinstance(self.dtype, PsNumericType): + if self.dtype.is_uint(): return False return super().is_negative @@ -679,9 +260,8 @@ class CastFunc(sp.Function): """ See :func:`.TypedSymbol.is_integer` """ - if hasattr(self.dtype, 'numpy_dtype'): - return np.issubdtype(self.dtype.numpy_dtype, np.integer) or np.issubdtype(self.dtype.numpy_dtype, - np.floating) or super().is_real + if isinstance(self.dtype, PsNumericType): + return self.dtype.is_int() or self.dtype.is_float() or super().is_real else: return super().is_real @@ -696,6 +276,7 @@ class VectorMemoryAccess(CastFunc): Special memory access for vectorized kernel. Arguments: read/write expression, type, aligned, non-temporal, mask (or none), stride """ + nargs = (6,) @@ -703,6 +284,7 @@ class ReinterpretCastFunc(CastFunc): """ Reinterpret cast is necessary for the StructType """ + pass @@ -710,10 +292,7 @@ class PointerArithmeticFunc(sp.Function, sp.logic.boolalg.Boolean): # TODO: documentation, or deprecate! @property def canonical(self): - if hasattr(self.args[0], 'canonical'): + if hasattr(self.args[0], "canonical"): return self.args[0].canonical else: raise NotImplementedError() - - - diff --git a/src/pystencils/types/__init__.py b/src/pystencils/types/__init__.py index 4aee28086..f33d8a1f7 100644 --- a/src/pystencils/types/__init__.py +++ b/src/pystencils/types/__init__.py @@ -6,6 +6,7 @@ from .basic_types import ( PsScalarType, PsVectorType, PsPointerType, + PsBoolType, PsIntegerType, PsUnsignedIntegerType, PsSignedIntegerType, @@ -14,7 +15,7 @@ from .basic_types import ( deconstify, ) -from .quick import make_type, make_numeric_type +from .quick import create_type, create_numeric_type from .exception import PsTypeError @@ -27,12 +28,13 @@ __all__ = [ "PsScalarType", "PsVectorType", "PsIntegerType", + "PsBoolType", "PsUnsignedIntegerType", "PsSignedIntegerType", "PsIeeeFloatType", "constify", "deconstify", - "make_type", - "make_numeric_type", + "create_type", + "create_numeric_type", "PsTypeError", ] diff --git a/src/pystencils/types/basic_types.py b/src/pystencils/types/basic_types.py index 947895e2c..749a4d34c 100644 --- a/src/pystencils/types/basic_types.py +++ b/src/pystencils/types/basic_types.py @@ -187,12 +187,18 @@ class PsStructType(PsAbstractType): def members(self) -> tuple[PsStructType.Member, ...]: return self._members - def get_member(self, member_name: str) -> PsStructType.Member | None: + def find_member(self, member_name: str) -> PsStructType.Member | None: """Find a member by name""" for m in self._members: if m.name == member_name: return m return None + + def get_member(self, member_name: str) -> PsStructType.Member: + m = self.find_member(member_name) + if m is None: + raise KeyError(f"No struct member with name {member_name}") + return m @property def name(self) -> str: @@ -289,6 +295,10 @@ class PsNumericType(PsAbstractType, ABC): def is_float(self) -> bool: pass + @abstractmethod + def is_bool(self) -> bool: + pass + class PsScalarType(PsNumericType, ABC): """Class to model scalar numeric types.""" @@ -317,6 +327,9 @@ class PsScalarType(PsNumericType, ABC): def is_float(self) -> bool: return isinstance(self, PsIeeeFloatType) + + def is_bool(self) -> bool: + return isinstance(self, PsBoolType) class PsVectorType(PsNumericType): @@ -357,6 +370,9 @@ class PsVectorType(PsNumericType): def is_float(self) -> bool: return self._scalar_type.is_float() + + def is_bool(self) -> bool: + return self._scalar_type.is_bool() @property def itemsize(self) -> int | None: @@ -412,6 +428,43 @@ class PsVectorType(PsNumericType): ) +class PsBoolType(PsScalarType): + """Class to model the boolean type.""" + + NUMPY_TYPE = np.bool_ + + def __init__(self, const: bool = False): + super().__init__(const) + + @property + def width(self) -> int: + return 8 + + @property + def itemsize(self) -> int: + return self.width // 8 + + @property + def numpy_dtype(self) -> np.dtype | None: + return np.dtype(PsBoolType.NUMPY_TYPE) + + def create_literal(self, value: Any) -> str: + if value in (1, True, np.True_): + return "true" + elif value in (0, False, np.False_): + return "false" + else: + raise PsTypeError(f"Cannot create boolean literal from {value}") + + def create_constant(self, value: Any) -> Any: + if value in (1, True, np.True_): + return np.True_ + elif value in (0, False, np.False_): + return np.False_ + else: + raise PsTypeError(f"Cannot create boolean constant from value {value}") + + class PsIntegerType(PsScalarType, ABC): """Class to model signed and unsigned integer types. diff --git a/src/pystencils/types/quick.py b/src/pystencils/types/quick.py index 65683c80a..0268e9a79 100644 --- a/src/pystencils/types/quick.py +++ b/src/pystencils/types/quick.py @@ -13,6 +13,7 @@ from .basic_types import ( PsCustomType, PsNumericType, PsScalarType, + PsBoolType, PsPointerType, PsIntegerType, PsUnsignedIntegerType, @@ -23,7 +24,7 @@ from .basic_types import ( UserTypeSpec = str | type | np.dtype | PsAbstractType -def make_type(type_spec: UserTypeSpec) -> PsAbstractType: +def create_type(type_spec: UserTypeSpec) -> PsAbstractType: """Create a pystencils type object from a variety of specifications. Possible arguments are: @@ -53,9 +54,9 @@ def make_type(type_spec: UserTypeSpec) -> PsAbstractType: raise ValueError(f"{type_spec} is not a valid type specification.") -def make_numeric_type(type_spec: UserTypeSpec) -> PsNumericType: +def create_numeric_type(type_spec: UserTypeSpec) -> PsNumericType: """Like `make_type`, but only for numeric types.""" - dtype = make_type(type_spec) + dtype = create_type(type_spec) if not isinstance(dtype, PsNumericType): raise ValueError( f"Given type {type_spec} does not translate to a numeric type." @@ -72,6 +73,9 @@ Scalar = PsScalarType Ptr = PsPointerType """`Ptr(t)` matches `PsPointerType(base_type=t)`""" +Bool = PsBoolType +"""Bool() matches PsBoolType()""" + AnyInt = PsIntegerType """`AnyInt(width)` matches both `PsUnsignedIntegerType(width)` and `PsSignedIntegerType(width)`""" diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 9f49b2ed0..7625e22e3 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -7,7 +7,7 @@ from pystencils import Assignment, TypedSymbol, Field, FieldType from pystencils.backend.ast.structural import PsDeclaration from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp from pystencils.types import constify -from pystencils.types.quick import Fp, make_numeric_type +from pystencils.types.quick import Fp, create_numeric_type from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError @@ -92,7 +92,7 @@ def test_contextual_typing(): def test_erronous_typing(): - ctx = KernelCreationContext(default_dtype=make_numeric_type(np.float64)) + ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64)) freeze = FreezeExpressions(ctx) typify = Typifier(ctx) diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py index 5f11cd081..487f77783 100644 --- a/tests/nbackend/types/test_types.py +++ b/tests/nbackend/types/test_types.py @@ -19,12 +19,12 @@ def test_widths(Type): def test_parsing_positive(): - assert make_type("const uint32_t * restrict") == Ptr( + assert create_type("const uint32_t * restrict") == Ptr( UInt(32, const=True), restrict=True ) - assert make_type("float * * const") == Ptr(Ptr(Fp(32)), const=True) - assert make_type("uint16 * const") == Ptr(UInt(16), const=True) - assert make_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True) + assert create_type("float * * const") == Ptr(Ptr(Fp(32)), const=True) + assert create_type("uint16 * const") == Ptr(UInt(16), const=True) + assert create_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True) def test_parsing_negative(): @@ -40,20 +40,20 @@ def test_parsing_negative(): for spec in bad_specs: with pytest.raises(ValueError): - make_type(spec) + create_type(spec) def test_numpy(): import numpy as np - assert make_type(np.single) == make_type(np.float32) == PsIeeeFloatType(32) + assert create_type(np.single) == create_type(np.float32) == PsIeeeFloatType(32) assert ( - make_type(float) - == make_type(np.double) - == make_type(np.float64) + create_type(float) + == create_type(np.double) + == create_type(np.float64) == PsIeeeFloatType(64) ) - assert make_type(int) == make_type(np.int64) == PsSignedIntegerType(64) + assert create_type(int) == create_type(np.int64) == PsSignedIntegerType(64) @pytest.mark.parametrize( @@ -74,7 +74,7 @@ def test_numpy(): ) def test_numpy_translation(numpy_type): dtype_obj = np.dtype(numpy_type) - ps_type = make_type(numpy_type) + ps_type = create_type(numpy_type) assert isinstance(ps_type, PsNumericType) assert ps_type.numpy_dtype == dtype_obj -- GitLab