diff --git a/conftest.py b/conftest.py index 040ddf59505d29d1235a91279006ad34e6bcb8db..4af58abc28902146267d07401c178097707638f8 100644 --- a/conftest.py +++ b/conftest.py @@ -8,7 +8,6 @@ import nbformat import pytest from nbconvert import PythonExporter -from pystencils.boundaries.createindexlistcython import * # NOQA # Trigger config file reading / creation once - to avoid race conditions when multiple instances are creating it # at the same time from pystencils.cpu import cpujit @@ -18,10 +17,10 @@ from pystencils.cpu import cpujit try: import pyximport pyximport.install(language_level=3) + from pystencils.boundaries.createindexlistcython import * # NOQA except ImportError: pass - SCRIPT_FOLDER = os.path.dirname(os.path.realpath(__file__)) sys.path.insert(0, os.path.abspath('pystencils')) diff --git a/debug.py b/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..c8b2577319aac4ff52bc9a5158c5a5c46ebeefe8 --- /dev/null +++ b/debug.py @@ -0,0 +1,26 @@ +#%% +import pytest +from pystencils.nbackend.types.quick import * + + +def test_parsing_positive(): + assert make_type("const uint32_t * restrict") == Ptr(UInt(32, const=True), restrict=True) + assert make_type("float * * const") == Ptr(Ptr(Fp(32)), const=True) + +def test_parsing_negative(): + bad_specs = [ + "const notatype * const", + "cnost uint32_t", + "int", # plain ints are ambiguous + "float float", + "double * int", + "bool" + ] + + for spec in bad_specs: + with pytest.raises(ValueError): + make_type(spec) + + +#%% +test_parsing_positive() \ No newline at end of file diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000000000000000000000000000000000000..296657f7c197198c9f0c873b8233334b0f573986 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,11 @@ +[mypy] +python_version=3.10 + +[mypy-pymbolic.*] +ignore_missing_imports=true + +[mypy-pystencils.*] +ignore_errors=true + +[mypy-pystencils.nbackend.*] +ignore_errors = False diff --git a/pystencils/nbackend/typed_expressions.py b/pystencils/nbackend/typed_expressions.py index 24d52bec8a2ce4bdd58cefbc384551282267309d..22b565f701b57a85eee8fb597af0fcd3d2c60323 100644 --- a/pystencils/nbackend/typed_expressions.py +++ b/pystencils/nbackend/typed_expressions.py @@ -5,16 +5,16 @@ from typing import TypeAlias, Union, Any, Tuple import pymbolic.primitives as pb -from ..typing import AbstractType, BasicType, PointerType +from .types import PsAbstractType, PsScalarType, PsPointerType, constify class PsTypedVariable(pb.Variable): - def __init__(self, name: str, dtype: AbstractType): + def __init__(self, name: str, dtype: PsAbstractType): super(PsTypedVariable, self).__init__(name) self._dtype = dtype @property - def dtype(self) -> AbstractType: + def dtype(self) -> PsAbstractType: return self._dtype @@ -23,7 +23,7 @@ class PsArray: self, name: str, length: pb.Expression, - element_type: BasicType, # todo Frederik: is BasicType correct? + element_type: PsScalarType, # todo Frederik: is PsScalarType correct? ): self._name = name self._length = length @@ -50,7 +50,7 @@ class PsLinearizedArray(PsArray): name: str, shape: Tuple[pb.Expression, ...], strides: Tuple[pb.Expression], - element_type: BasicType, + element_type: PsScalarType, ): length = reduce(lambda x, y: x * y, shape, 1) super().__init__(name, length, element_type) @@ -69,7 +69,7 @@ class PsLinearizedArray(PsArray): class PsArrayBasePointer(PsTypedVariable): def __init__(self, name: str, array: PsArray): - dtype = PointerType(array.element_type) + dtype = PsPointerType(array.element_type) super().__init__(name, dtype) self._array = array @@ -98,7 +98,7 @@ class PsArrayAccess(pb.Subscript): return self._base_ptr.array @property - def dtype(self) -> AbstractType: + def dtype(self) -> PsAbstractType: """Data type of this expression, i.e. the element type of the underlying array""" return self._base_ptr.array.element_type @@ -108,7 +108,7 @@ PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess] class PsTypedConstant: @staticmethod - def _cast(value, target_dtype: AbstractType): + def _cast(value, target_dtype: PsAbstractType): if isinstance(value, PsTypedConstant): if value._dtype != target_dtype: raise ValueError( @@ -119,11 +119,11 @@ class PsTypedConstant: # TODO check legality return PsTypedConstant(value, target_dtype) - def __init__(self, value, dtype: AbstractType): + def __init__(self, value, dtype: PsAbstractType): """Represents typed constants occuring in the pystencils AST""" - if isinstance(dtype, BasicType): - dtype = BasicType(dtype, const=True) - self._value = dtype.numpy_dtype.type(value) + if isinstance(dtype, PsScalarType): + dtype = constify(dtype) + self._value = value # todo: cast to given type else: raise ValueError(f"Cannot create constant of type {dtype}") @@ -133,19 +133,22 @@ class PsTypedConstant: return str(self._value) def __add__(self, other: Any): - other = PsTypedConstant._cast(other, self._dtype) + return NotImplemented # todo + # other = PsTypedConstant._cast(other, self._dtype) - return PsTypedConstant(self._value + other._value, self._dtype) + # return PsTypedConstant(self._value + other._value, self._dtype) def __mul__(self, other: Any): - other = PsTypedConstant._cast(other, self._dtype) + return NotImplemented # todo + # other = PsTypedConstant._cast(other, self._dtype) - return PsTypedConstant(self._value * other._value, self._dtype) + # return PsTypedConstant(self._value * other._value, self._dtype) def __sub__(self, other: Any): - other = PsTypedConstant._cast(other, self._dtype) + return NotImplemented # todo + # other = PsTypedConstant._cast(other, self._dtype) - return PsTypedConstant(self._value - other._value, self._dtype) + # return PsTypedConstant(self._value - other._value, self._dtype) # TODO: Remaining operators diff --git a/pystencils/nbackend/types/__init__.py b/pystencils/nbackend/types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5319afa33abf54bbd74382459bdbe709243723bb --- /dev/null +++ b/pystencils/nbackend/types/__init__.py @@ -0,0 +1,25 @@ +from .basic_types import ( + PsAbstractType, + PsCustomType, + PsScalarType, + PsPointerType, + PsIntegerType, + PsUnsignedIntegerType, + PsSignedIntegerType, + PsIeeeFloatType, + constify, + deconstify +) + +__all__ = [ + "PsAbstractType", + "PsCustomType", + "PsScalarType", + "PsPointerType", + "PsIntegerType", + "PsUnsignedIntegerType", + "PsSignedIntegerType", + "PsIeeeFloatType", + "constify", + "deconstify", +] \ No newline at end of file diff --git a/pystencils/nbackend/types/basic_types.py b/pystencils/nbackend/types/basic_types.py new file mode 100644 index 0000000000000000000000000000000000000000..698418e73e3ee5cca2f4f1b7cb31653f0f8da528 --- /dev/null +++ b/pystencils/nbackend/types/basic_types.py @@ -0,0 +1,252 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import final, TypeVar +from copy import copy + + +class PsAbstractType(ABC): + """Base class for all pystencils types. + + Implementation Notes + ==================== + + **Type Equality:** Subclasses must implement `__eq__`, but may rely on `_base_equal` to implement + type equality checks. + """ + + def __init__(self, const: bool = False): + """ + Args: + name: Name of this type + const: Const-qualification of this type + """ + self._const = const + + @property + def const(self) -> bool: + return self._const + + # ------------------------------------------------------------------------------------------- + # Internal virtual operations + # ------------------------------------------------------------------------------------------- + + def _base_equal(self, other: PsAbstractType) -> bool: + return type(self) is type(other) and self._const == other._const + + def _const_string(self) -> str: + return "const" if self._const else "" + + @abstractmethod + def _c_string(self) -> str: + ... + + # ------------------------------------------------------------------------------------------- + # Dunder Methods + # ------------------------------------------------------------------------------------------- + + @abstractmethod + def __eq__(self, other: object) -> bool: + ... + + def __str__(self) -> str: + return self._c_string() + + def __hash__(self) -> int: + return hash(self._c_string()) + + +class PsCustomType(PsAbstractType): + """Class to model custom types by their names.""" + + __match_args__ = ("name",) + + def __init__(self, name: str, const: bool = False): + super().__init__(const) + self._name = name + + @property + def name(self) -> str: + return self._name + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsCustomType): + return False + return self._base_equal(other) and self._name == other._name + + def _c_string(self) -> str: + return f"{self._const_string()} {self._name}" + + def __repr__(self) -> str: + return f"CustomType( {self.name}, const={self.const} )" + + +@final +class PsPointerType(PsAbstractType): + """Class to model C pointer types.""" + + __match_args__ = ("base_type",) + + def __init__( + self, base_type: PsAbstractType, const: bool = False, restrict: bool = True + ): + super().__init__(const) + self._base_type = base_type + self._restrict = restrict + + @property + def base_type(self) -> PsAbstractType: + return self._base_type + + @property + def restrict(self) -> bool: + return self._restrict + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsPointerType): + return False + return self._base_equal(other) and self._base_type == other._base_type + + def _c_string(self) -> str: + base_str = self._base_type._c_string() + return f"{base_str} * {self._const_string()}" + + def __repr__(self) -> str: + return f"PsPointerType( {repr(self.base_type)}, const={self.const} )" + + +class PsScalarType(PsAbstractType, ABC): + """Class to model scalar types""" + + def is_int(self) -> bool: + return isinstance(self, PsIntegerType) + + def is_sint(self) -> bool: + return isinstance(self, PsIntegerType) and self.signed + + def is_uint(self) -> bool: + return isinstance(self, PsIntegerType) and not self.signed + + def is_float(self) -> bool: + return isinstance(self, PsIeeeFloatType) + + +class PsIntegerType(PsAbstractType, ABC): + """Class to model signed and unsigned integer types. + + `PsIntegerType` cannot be instantiated on its own, but only through `PsSignedIntegerType` + and `PsUnsignedIntegerType`. This distinction is meant mostly to help in pattern matching. + """ + + __match_args__ = ("width",) + + SUPPORTED_WIDTHS = (8, 16, 32, 64) + + def __init__(self, width: int, signed: bool = True, const: bool = False): + if width not in self.SUPPORTED_WIDTHS: + raise ValueError( + f"Invalid integer width; must be one of {self.SUPPORTED_WIDTHS}." + ) + + super().__init__(const) + + self._width = width + self._signed = signed + + @property + def width(self) -> int: + return self._width + + @property + def signed(self) -> bool: + return self._signed + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsIntegerType): + return False + + return ( + self._base_equal(other) + and self._width == other._width + and self._signed == other._signed + ) + + def _c_string(self) -> str: + prefix = "" if self._signed else "u" + return f"{self._const_string()} {prefix}int{self._width}_t" + + def __repr__(self) -> str: + return f"PsIntegerType( width={self.width}, signed={self.signed}, const={self.const} )" + + +@final +class PsSignedIntegerType(PsIntegerType): + """Class to model signed integers.""" + + __match_args__ = ("width",) + + def __init__(self, width: int, const: bool = False): + super().__init__(width, True, const) + + +@final +class PsUnsignedIntegerType(PsIntegerType): + """Class to model unsigned integers.""" + + __match_args__ = ("width",) + + def __init__(self, width: int, const: bool = False): + super().__init__(width, True, const) + + +@final +class PsIeeeFloatType(PsAbstractType): + """Class to model IEEE-754 floating point data types""" + + __match_args__ = ("width",) + + SUPPORTED_WIDTHS = (32, 64) + + def __init__(self, width: int, const: bool = False): + if width not in self.SUPPORTED_WIDTHS: + raise ValueError( + f"Invalid integer width; must be one of {self.SUPPORTED_WIDTHS}." + ) + + super().__init__(const) + self._width = width + + @property + def width(self) -> int: + return self._width + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsIeeeFloatType): + return False + return self._base_equal(other) and self._width == other._width + + def _c_string(self) -> str: + match self._width: + case 32: + return f"{self._const_string()} float" + case 64: + return f"{self._const_string()} double" + case _: + assert False, "unreachable code" + + def __repr__(self) -> str: + return f"PsIeeeFloatType( width={self.width}, const={self.const} )" + + +T = TypeVar("T", bound=PsAbstractType) + +def constify(t: T): + """Adds the const qualifier to a given type.""" + t_copy = copy(t) + t_copy._const = True + return t_copy + +def deconstify(t: T): + """Removes the const qualifier from a given type.""" + t_copy = copy(t) + t_copy._const = False + return t_copy diff --git a/pystencils/nbackend/types/parsing.py b/pystencils/nbackend/types/parsing.py new file mode 100644 index 0000000000000000000000000000000000000000..a6f2ac42f6a1aea46c55a4dc309d295937c2ca93 --- /dev/null +++ b/pystencils/nbackend/types/parsing.py @@ -0,0 +1,103 @@ +import numpy as np + +from .basic_types import ( + PsAbstractType, + PsCustomType, + PsScalarType, + PsPointerType, + PsIntegerType, + PsUnsignedIntegerType, + PsSignedIntegerType, + PsIeeeFloatType, +) + + +def interpret_python_type(t: type) -> PsAbstractType: + if t is int: + return PsSignedIntegerType(64) + if t is float: + return PsIeeeFloatType(64) + + if t is np.uint8: + return PsUnsignedIntegerType(8) + if t is np.uint16: + return PsUnsignedIntegerType(16) + if t is np.uint32: + return PsUnsignedIntegerType(32) + if t is np.uint64: + return PsUnsignedIntegerType(64) + + if t is np.int8: + return PsSignedIntegerType(8) + if t is np.int16: + return PsSignedIntegerType(16) + if t is np.int32: + return PsSignedIntegerType(32) + if t is np.int64: + return PsSignedIntegerType(64) + + if t is np.float32: + return PsIeeeFloatType(32) + if t is np.float64: + return PsIeeeFloatType(64) + + raise ValueError(f"Could not interpret Python data type {t} as a pystencils type.") + + +def parse_type_string(s: str) -> PsAbstractType: + tokens = s.rsplit("*", 1) + match tokens: + case [base]: # input contained no '*', is no pointer + match base.split(): # split at whitespace to find `const` qualifiers (C typenames cannot contain spaces) + case [typename]: + return parse_type_name(typename, False) + case ["const", typename] | [typename, "const"]: + return parse_type_name(typename, True) + case _: + raise ValueError(f"Could not parse token '{base}' as C type.") + + case [base, suffix]: # input was "base * suffix" + base_type = parse_type_string(base) + match suffix.split(): + case []: + return PsPointerType(base_type, const=False, restrict=False) + case ["const"]: + return PsPointerType(base_type, const=True, restrict=False) + case ["restrict"]: + return PsPointerType(base_type, const=False, restrict=True) + case ["const", "restrict"] | ["restrict", "const"]: + return PsPointerType(base_type, const=True, restrict=True) + case _: + raise ValueError(f"Could not parse token '{s}' as C type.") + + case _: + raise ValueError(f"Could not parse token '{s}`' as C type.") + + +def parse_type_name(typename: str, const: bool): + match typename: + case "int64" | "int64_t": + return PsSignedIntegerType(64, const=const) + case "int32" | "int32_t": + return PsSignedIntegerType(32, const=const) + case "int16" | "int16_t": + return PsSignedIntegerType(16, const=const) + case "int8" | "int8_t": + return PsSignedIntegerType(8, const=const) + + case "uint64" | "uint64_t": + return PsUnsignedIntegerType(64, const=const) + case "uint32" | "uint32_t": + return PsUnsignedIntegerType(32, const=const) + case "uint16" | "uint16_t": + return PsUnsignedIntegerType(16, const=const) + case "uint8" | "uint8_t": + return PsUnsignedIntegerType(8, const=const) + + case "float" | "float32": + return PsIeeeFloatType(32, const=const) + case "double" | "float64": + return PsIeeeFloatType(64, const=const) + + case _: + raise ValueError(f"Could not parse token '{typename}' as C type.") diff --git a/pystencils/nbackend/types/quick.py b/pystencils/nbackend/types/quick.py new file mode 100644 index 0000000000000000000000000000000000000000..8c6b4d006a30178dca5e79acfea92da88df02e6c --- /dev/null +++ b/pystencils/nbackend/types/quick.py @@ -0,0 +1,75 @@ +"""Abbreviations and creation functions for pystencils type-modelling classes +for quick, user-friendly construction and compact pattern matching. + +This module is meant to be included whole, e.g. as `from pystencils.nbackend.types.quick import *` +""" + +from __future__ import annotations + +from .basic_types import ( + PsAbstractType, + PsCustomType, + PsScalarType, + PsPointerType, + PsIntegerType, + PsUnsignedIntegerType, + PsSignedIntegerType, + PsIeeeFloatType, +) + +UserTypeSpec = str | type | PsAbstractType + + +def make_type(type_spec: UserTypeSpec) -> PsAbstractType: + """Create a pystencils type object from a variety of specifications. + + Possible arguments are: + - Strings ('str'): will be parsed as common C types, throwing an exception if that fails. + To construct a `PsCustomType` instead, use the constructor of `PsCustomType` or its abbreviation + `types.quick.Custom` instead + - Python builtin data types (instances of `type`): Attempts to interpret Python numeric types like so: + - `int` becomes a signed 64-bit integer + - `float` becomes a double-precision IEEE-754 float + - No others are supported at the moment + - Supported Numpy scalar data types (see https://numpy.org/doc/stable/reference/arrays.scalars.html) are converted to pystencils + scalar data types + - Instances of `PsAbstractType` will be returned as they are + """ + + from .parsing import ( + parse_type_string, + interpret_python_type, + ) + + if isinstance(type_spec, PsAbstractType): + return type_spec + if isinstance(type_spec, str): + return parse_type_string(type_spec) + if isinstance(type_spec, type): + return interpret_python_type(type_spec) + raise ValueError(f"{type_spec} is not a valid type specification.") + + +Custom = PsCustomType +"""`Custom(name)` matches `PsCustomType(name)`""" + +Scalar = PsScalarType +"""`Scalar()` matches any subclass of `PsScalarType`""" + +Ptr = PsPointerType +"""`Ptr(t)` matches `PsPointerType(base_type=t)`""" + +AnyInt = PsIntegerType +"""`AnyInt(width)` matches both `PsUnsignedIntegerType(width)` and `PsSignedIntegerType(width)`""" + +UInt = PsUnsignedIntegerType +"""`UInt(width)` matches `PsUnsignedIntegerType(width)`""" + +Int = PsSignedIntegerType +"""`Int(width)` matches `PsSignedIntegerType(width)`""" + +SInt = PsSignedIntegerType +"""`SInt(width)` matches `PsSignedIntegerType(width)`""" + +Fp = PsIeeeFloatType +"""`Fp(width)` matches `PsIeeeFloatType(width)`""" diff --git a/pystencils_tests/nbackend/types/test_quick_types.py b/pystencils_tests/nbackend/types/test_quick_types.py new file mode 100644 index 0000000000000000000000000000000000000000..f45bf565d3ad9ff3e13bcefdf32585edf903c68b --- /dev/null +++ b/pystencils_tests/nbackend/types/test_quick_types.py @@ -0,0 +1,30 @@ +import pytest +from pystencils.nbackend.types.quick import * + + +def test_parsing_positive(): + assert make_type("const uint32_t * restrict") == Ptr(UInt(32, const=True), restrict=True) + assert make_type("float * * const") == Ptr(Ptr(Fp(32)), const=True) + assert make_type("uint16 * const") == Ptr(UInt(16), const=True) + assert make_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True) + +def test_parsing_negative(): + bad_specs = [ + "const notatype * const", + "cnost uint32_t", + "uint45_t", + "int", # plain ints are ambiguous + "float float", + "double * int", + "bool" + ] + + for spec in bad_specs: + with pytest.raises(ValueError): + make_type(spec) + +def test_numpy(): + import numpy as np + assert make_type(np.single) == make_type(np.float32) == PsIeeeFloatType(32) + assert make_type(float) == make_type(np.double) == make_type(np.float64) == PsIeeeFloatType(64) + assert make_type(int) == make_type(np.int64) == PsSignedIntegerType(64)