From 62e6ddf8b6bd769c2c6088a10b80e55bf4d174d7 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 11 Jan 2024 16:56:47 +0100 Subject: [PATCH] PsTypedConstant: Creation and Arithmetic, Part One --- pystencils/nbackend/ast/nodes.py | 41 +++-- pystencils/nbackend/typed_expressions.py | 165 ++++++++++++++---- pystencils/nbackend/types/__init__.py | 11 +- pystencils/nbackend/types/basic_types.py | 136 ++++++++++++++- pystencils/nbackend/types/exception.py | 4 + pystencils/nbackend/types/parsing.py | 3 - pystencils/nbackend/types/quick.py | 4 +- .../nbackend/test_constant_folding.py | 45 +++++ .../nbackend/types/test_constants.py | 61 +++++++ 9 files changed, 405 insertions(+), 65 deletions(-) create mode 100644 pystencils/nbackend/types/exception.py create mode 100644 pystencils_tests/nbackend/test_constant_folding.py create mode 100644 pystencils_tests/nbackend/types/test_constants.py diff --git a/pystencils/nbackend/ast/nodes.py b/pystencils/nbackend/ast/nodes.py index adb3e2a01..40200ff38 100644 --- a/pystencils/nbackend/ast/nodes.py +++ b/pystencils/nbackend/ast/nodes.py @@ -19,7 +19,7 @@ def failing_cast(target: type, obj: T): class PsAstNode(ABC): """Base class for all nodes in the pystencils AST. - + This base class provides a common interface to inspect and update the AST's branching structure. The four methods `num_children`, `children`, `get_child` and `set_child` must be implemented by each subclass. @@ -150,7 +150,7 @@ class PsAssignment(PsAstNode): return (self._lhs, self._rhs)[idx] def set_child(self, idx: int, c: PsAstNode): - idx = [0, 1][idx] # trick to normalize index + idx = [0, 1][idx] # trick to normalize index if idx == 0: self._lhs = failing_cast(PsLvalueExpr, c) elif idx == 1: @@ -180,7 +180,7 @@ class PsDeclaration(PsAssignment): self._lhs = lvalue def set_child(self, idx: int, c: PsAstNode): - idx = [0, 1][idx] # trick to normalize index + idx = [0, 1][idx] # trick to normalize index if idx == 0: self._lhs = failing_cast(PsSymbolExpr, c) elif idx == 1: @@ -190,12 +190,14 @@ class PsDeclaration(PsAssignment): class PsLoop(PsAstNode): - def __init__(self, - ctr: PsSymbolExpr, - start: PsExpression, - stop: PsExpression, - step: PsExpression, - body: PsBlock): + def __init__( + self, + ctr: PsSymbolExpr, + start: PsExpression, + stop: PsExpression, + step: PsExpression, + body: PsBlock, + ): self._ctr = ctr self._start = start self._stop = stop @@ -205,7 +207,7 @@ class PsLoop(PsAstNode): @property def counter(self) -> PsSymbolExpr: return self._ctr - + @counter.setter def counter(self, expr: PsSymbolExpr): self._ctr = expr @@ -251,13 +253,18 @@ class PsLoop(PsAstNode): def get_child(self, idx: int): return (self._ctr, self._start, self._stop, self._step, self._body)[idx] - def set_child(self, idx: int, c: PsAstNode): idx = list(range(5))[idx] match idx: - case 0: self._ctr = failing_cast(PsSymbolExpr, c) - case 1: self._start = failing_cast(PsExpression, c) - case 2: self._stop = failing_cast(PsExpression, c) - case 3: self._step = failing_cast(PsExpression, c) - case 4: self._body = failing_cast(PsBlock, c) - case _: assert False, "unreachable code" + case 0: + self._ctr = failing_cast(PsSymbolExpr, c) + case 1: + self._start = failing_cast(PsExpression, c) + case 2: + self._stop = failing_cast(PsExpression, c) + case 3: + self._step = failing_cast(PsExpression, c) + case 4: + self._body = failing_cast(PsBlock, c) + case _: + assert False, "unreachable code" diff --git a/pystencils/nbackend/typed_expressions.py b/pystencils/nbackend/typed_expressions.py index 5d0c0fc89..c3e4488e8 100644 --- a/pystencils/nbackend/typed_expressions.py +++ b/pystencils/nbackend/typed_expressions.py @@ -1,11 +1,19 @@ from __future__ import annotations from functools import reduce -from typing import TypeAlias, Union, Any, Tuple +from typing import TypeAlias, Union, Any, Tuple, Callable +import operator import pymbolic.primitives as pb -from .types import PsAbstractType, PsScalarType, PsPointerType, constify +from .types import ( + PsAbstractType, + PsScalarType, + PsNumericType, + PsPointerType, + constify, + PsTypeError, +) class PsTypedVariable(pb.Variable): @@ -96,7 +104,7 @@ class PsArrayAccess(pb.Subscript): @property def array(self) -> PsArray: return self._base_ptr.array - + @property def dtype(self) -> PsAbstractType: """Data type of this expression, i.e. the element type of the underlying array""" @@ -107,50 +115,145 @@ PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess] class PsTypedConstant: + """Represents typed constants occuring in the pystencils AST. + + A Note On Divisions + ------------------- + + The semantics of divisions in C and Python differ greatly. + Python has two division operators: `/` (`truediv`) and `//` (`floordiv`). + `truediv` is pure floating-point division, and so applied to floating-point numbers maps exactly to + floating-point division in C, but not when applied to integers. + `floordiv` has no C equivalent: + While `floordiv` performs euclidean division and always rounds its result + downward (`3 // 2 == 1`, and `-3 // 2 = -2`), + the C `/` operator on integers always rounds toward zero (in C, `-3 / 2 = -1`.) + + The same applies to the `%` operator: + In Python, `%` computes the euclidean modulus (e.g. `-3 % 2 = 1`), + while in C, `%` computes the remainder (e.g. `-3 % 2 = -1`). + These two differ whenever negative numbers are involved. + + Pymbolic provides `Quotient` to model Python's `/`, `FloorDiv` to model `//`, and `Remainder` to model `%`. + The last one is a misnomer: it should instead be called `Modulus`. + + Since the pystencils backend has to accurately capture the behaviour of C, + the behaviour of `/` is overridden in `PsTypedConstant`. + In a floating-point context, it behaves as usual, while in an integer context, + it implements C-style integer division. + Similarily, `%` is only legal in integer contexts, where it implements the C-style remainder. + Usage of `//` and the pymbolic `FloorDiv` is illegal. + """ + @staticmethod - def _cast(value, target_dtype: PsAbstractType): - if isinstance(value, PsTypedConstant): - if value._dtype != target_dtype: - raise ValueError( - f"Incompatible types: {value._dtype} and {target_dtype}" - ) - return value - - # TODO check legality - return PsTypedConstant(value, target_dtype) - - def __init__(self, value, dtype: PsAbstractType): - """Represents typed constants occuring in the pystencils AST""" - if isinstance(dtype, PsScalarType): - dtype = constify(dtype) - self._value = value # todo: cast to given type - else: + def try_create(value: Any, dtype: PsNumericType): + try: + return PsTypedConstant(value, dtype) + except PsTypeError: + return None + + def __init__(self, value: Any, dtype: PsNumericType): + """Create a new `PsTypedConstant`. + + The constructor of `PsTypedConstant` will first convert the given `dtype` to its const version. + The given `value` will then be interpreted as that data type. The constructor will fail with an + exception if that is not possible. + """ + if not isinstance(dtype, PsNumericType): raise ValueError(f"Cannot create constant of type {dtype}") - self._dtype = dtype + self._dtype = constify(dtype) + self._value = self._dtype.create_constant(value) def __str__(self) -> str: return str(self._value) + def __repr__(self) -> str: + return f"PsTypedConstant( {self._value}, {repr(self._dtype)} )" + + @staticmethod + def _fix(v: Any, dtype: PsNumericType) -> PsTypedConstant: + if not isinstance(v, PsTypedConstant): + return PsTypedConstant(v, dtype) + else: + return v + + @staticmethod + def _bin_op( + lhs: PsTypedConstant, rhs: PsTypedConstant, op: Callable[[Any, Any], Any] + ) -> PsTypedConstant: + """Backend for binary operators. Never call directly!""" + + if lhs._dtype != rhs._dtype: + raise PsTypeError( + f"Incompatible operand types in constant folding: {lhs._dtype} and {rhs._dtype}" + ) + + try: + return PsTypedConstant(op(lhs._value, rhs._value), lhs._dtype) + except PsTypeError: + raise PsTypeError( + f"Invalid operation in constant folding: {op.__name__}( {repr(lhs)}, {repr(rhs)} )" + ) + def __add__(self, other: Any): - return NotImplemented # todo - # other = PsTypedConstant._cast(other, self._dtype) + return PsTypedConstant._bin_op( + self, PsTypedConstant._fix(other, self._dtype), operator.add + ) - # return PsTypedConstant(self._value + other._value, self._dtype) + def __radd__(self, other: Any): + return PsTypedConstant._bin_op( + PsTypedConstant._fix(other, self._dtype), self, operator.add + ) def __mul__(self, other: Any): - return NotImplemented # todo - # other = PsTypedConstant._cast(other, self._dtype) + return PsTypedConstant._bin_op( + self, PsTypedConstant._fix(other, self._dtype), operator.mul + ) - # return PsTypedConstant(self._value * other._value, self._dtype) + def __rmul__(self, other: Any): + return PsTypedConstant._bin_op( + PsTypedConstant._fix(other, self._dtype), self, operator.mul + ) def __sub__(self, other: Any): - return NotImplemented # todo - # other = PsTypedConstant._cast(other, self._dtype) + return PsTypedConstant._bin_op( + self, PsTypedConstant._fix(other, self._dtype), operator.sub + ) + + def __rsub__(self, other: Any): + return PsTypedConstant._bin_op( + PsTypedConstant._fix(other, self._dtype), self, operator.sub + ) + + def __truediv__(self, other: Any): + other2 = PsTypedConstant._fix(other, self._dtype) + if self._dtype.is_float(): + return PsTypedConstant._bin_op(self, other2, operator.truediv) + else: + return NotImplemented # todo: C integer division + + def __rtruediv__(self, other: Any): + other2 = PsTypedConstant._fix(other, self._dtype) + if self._dtype.is_float(): + return PsTypedConstant._bin_op(other2, self, operator.truediv) + else: + return NotImplemented # todo: C integer division + + def __mod__(self, other: Any): + return NotImplemented # todo: C integer division + + def __neg__(self): + return PsTypedConstant(-self._value, self._dtype) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsTypedConstant): + return False - # return PsTypedConstant(self._value - other._value, self._dtype) + return self._dtype == other._dtype and self._value == other._value - # TODO: Remaining operators + def __hash__(self) -> int: + return hash((self._value, self._dtype)) pb.VALID_CONSTANT_CLASSES += (PsTypedConstant,) diff --git a/pystencils/nbackend/types/__init__.py b/pystencils/nbackend/types/__init__.py index 5319afa33..1f15c4516 100644 --- a/pystencils/nbackend/types/__init__.py +++ b/pystencils/nbackend/types/__init__.py @@ -1,6 +1,7 @@ from .basic_types import ( PsAbstractType, PsCustomType, + PsNumericType, PsScalarType, PsPointerType, PsIntegerType, @@ -8,18 +9,22 @@ from .basic_types import ( PsSignedIntegerType, PsIeeeFloatType, constify, - deconstify + deconstify, ) +from .exception import PsTypeError + __all__ = [ "PsAbstractType", "PsCustomType", - "PsScalarType", "PsPointerType", + "PsNumericType", + "PsScalarType", "PsIntegerType", "PsUnsignedIntegerType", "PsSignedIntegerType", "PsIeeeFloatType", "constify", "deconstify", -] \ No newline at end of file + "PsTypeError", +] diff --git a/pystencils/nbackend/types/basic_types.py b/pystencils/nbackend/types/basic_types.py index dcdb300ce..0a56f2f57 100644 --- a/pystencils/nbackend/types/basic_types.py +++ b/pystencils/nbackend/types/basic_types.py @@ -1,8 +1,12 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import final, TypeVar +from typing import final, TypeVar, Any from copy import copy +import numpy as np + +from .exception import PsTypeError + class PsAbstractType(ABC): """Base class for all pystencils types. @@ -51,8 +55,9 @@ class PsAbstractType(ABC): def __str__(self) -> str: return self._c_string() + @abstractmethod def __hash__(self) -> int: - return hash(self._c_string()) + ... class PsCustomType(PsAbstractType): @@ -73,6 +78,9 @@ class PsCustomType(PsAbstractType): return False return self._base_equal(other) and self._name == other._name + def __hash__(self) -> int: + return hash(("PsCustomType", self._name, self._const)) + def _c_string(self) -> str: return f"{self._const_string()} {self._name}" @@ -106,6 +114,9 @@ class PsPointerType(PsAbstractType): return False return self._base_equal(other) and self._base_type == other._base_type + def __hash__(self) -> int: + return hash(("PsPointerType", self._base_type, self._restrict, self._const)) + def _c_string(self) -> str: base_str = self._base_type._c_string() return f"{base_str} * {self._const_string()}" @@ -114,8 +125,51 @@ class PsPointerType(PsAbstractType): return f"PsPointerType( {repr(self.base_type)}, const={self.const} )" -class PsScalarType(PsAbstractType, ABC): - """Class to model scalar types""" +class PsNumericType(PsAbstractType, ABC): + """Class to model numeric types, which are all types that may occur at the top-level inside + arithmetic-logical expressions. + + Constants + --------- + + Every numeric type has to act as a factory for compile-time constants of that type. + The `PsTypedConstant` class relies on `create_constant` to instantiate constants + of a given numeric type. The object returned by `create_constant` must implement the + necessary arithmetic operations, and its arithmetic behaviour must match the given type. + + `create_constant` should fail whenever its input cannot safely be interpreted as the given + type. As for which interpretations are considered 'safe', it should be as restrictive as possible. + However, `create_constant` must *never* fail for the literals `0`, `1` and `-1`. + """ + + @abstractmethod + def create_constant(self, value: Any) -> Any: + """ + Create the internal representation of a constant with this type. + + Raises: + PsTypeError: If the given value cannot be interpreted in this type. + """ + + @abstractmethod + def is_int(self) -> bool: + ... + + @abstractmethod + def is_sint(self) -> bool: + ... + + @abstractmethod + def is_uint(self) -> bool: + ... + + @abstractmethod + def is_float(self) -> bool: + ... + + +class PsScalarType(PsNumericType, ABC): + """Class to model scalar numeric types.""" def is_int(self) -> bool: return isinstance(self, PsIntegerType) @@ -130,7 +184,7 @@ class PsScalarType(PsAbstractType, ABC): return isinstance(self, PsIeeeFloatType) -class PsIntegerType(PsAbstractType, ABC): +class PsIntegerType(PsScalarType, ABC): """Class to model signed and unsigned integer types. `PsIntegerType` cannot be instantiated on its own, but only through `PsSignedIntegerType` @@ -170,6 +224,9 @@ class PsIntegerType(PsAbstractType, ABC): and self._signed == other._signed ) + def __hash__(self) -> int: + return hash(("PsIntegerType", self._width, self._signed, self._const)) + def _c_string(self) -> str: prefix = "" if self._signed else "u" return f"{self._const_string()} {prefix}int{self._width}_t" @@ -184,9 +241,27 @@ class PsSignedIntegerType(PsIntegerType): __match_args__ = ("width",) + NUMPY_TYPES = { + 8: np.int8, + 16: np.int16, + 32: np.int32, + 64: np.int64, + } + def __init__(self, width: int, const: bool = False): super().__init__(width, True, const) + def create_constant(self, value: Any) -> Any: + np_type = self.NUMPY_TYPES[self._width] + + if isinstance(value, int): + return np_type(value) + + if isinstance(value, np_type): + return value + + raise PsTypeError(f"Could not interpret {value} as {repr(self)}") + @final class PsUnsignedIntegerType(PsIntegerType): @@ -194,18 +269,42 @@ class PsUnsignedIntegerType(PsIntegerType): __match_args__ = ("width",) + NUMPY_TYPES = { + 8: np.uint8, + 16: np.uint16, + 32: np.uint32, + 64: np.uint64, + } + def __init__(self, width: int, const: bool = False): - super().__init__(width, True, const) + super().__init__(width, False, const) + + def create_constant(self, value: Any) -> Any: + np_type = self.NUMPY_TYPES[self._width] + + if isinstance(value, int) and value >= 0: + return np_type(value) + + if isinstance(value, np_type): + return value + + raise PsTypeError(f"Could not interpret {value} as {repr(self)}") @final -class PsIeeeFloatType(PsAbstractType): +class PsIeeeFloatType(PsScalarType): """Class to model IEEE-754 floating point data types""" __match_args__ = ("width",) SUPPORTED_WIDTHS = (16, 32, 64) + NUMPY_TYPES = { + 16: np.float16, + 32: np.float32, + 64: np.float64, + } + def __init__(self, width: int, const: bool = False): if width not in self.SUPPORTED_WIDTHS: raise ValueError( @@ -219,11 +318,28 @@ class PsIeeeFloatType(PsAbstractType): def width(self) -> int: return self._width + def create_constant(self, value: Any) -> Any: + np_type = self.NUMPY_TYPES[self._width] + + if isinstance(value, int) and value in (0, 1, -1): + return np_type(value) + + if isinstance(value, float): + return np_type(value) + + if isinstance(value, np_type): + return value + + raise PsTypeError(f"Could not interpret {value} as {repr(self)}") + def __eq__(self, other: object) -> bool: if not isinstance(other, PsIeeeFloatType): return False return self._base_equal(other) and self._width == other._width + def __hash__(self) -> int: + return hash(("PsIeeeFloatType", self._width, self._const)) + def _c_string(self) -> str: match self._width: case 32: @@ -239,13 +355,15 @@ class PsIeeeFloatType(PsAbstractType): T = TypeVar("T", bound=PsAbstractType) -def constify(t: T): + +def constify(t: T) -> T: """Adds the const qualifier to a given type.""" t_copy = copy(t) t_copy._const = True return t_copy -def deconstify(t: T): + +def deconstify(t: T) -> T: """Removes the const qualifier from a given type.""" t_copy = copy(t) t_copy._const = False diff --git a/pystencils/nbackend/types/exception.py b/pystencils/nbackend/types/exception.py new file mode 100644 index 000000000..7c0cb97af --- /dev/null +++ b/pystencils/nbackend/types/exception.py @@ -0,0 +1,4 @@ + + +class PsTypeError(Exception): + """Indicates a type error in the pystencils AST.""" diff --git a/pystencils/nbackend/types/parsing.py b/pystencils/nbackend/types/parsing.py index a6f2ac42f..8a5e687aa 100644 --- a/pystencils/nbackend/types/parsing.py +++ b/pystencils/nbackend/types/parsing.py @@ -2,10 +2,7 @@ import numpy as np from .basic_types import ( PsAbstractType, - PsCustomType, - PsScalarType, PsPointerType, - PsIntegerType, PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType, diff --git a/pystencils/nbackend/types/quick.py b/pystencils/nbackend/types/quick.py index 8c6b4d006..b1da0c5e2 100644 --- a/pystencils/nbackend/types/quick.py +++ b/pystencils/nbackend/types/quick.py @@ -31,8 +31,8 @@ def make_type(type_spec: UserTypeSpec) -> PsAbstractType: - `int` becomes a signed 64-bit integer - `float` becomes a double-precision IEEE-754 float - No others are supported at the moment - - Supported Numpy scalar data types (see https://numpy.org/doc/stable/reference/arrays.scalars.html) are converted to pystencils - scalar data types + - Supported Numpy scalar data types (see https://numpy.org/doc/stable/reference/arrays.scalars.html) + are converted to pystencils scalar data types - Instances of `PsAbstractType` will be returned as they are """ diff --git a/pystencils_tests/nbackend/test_constant_folding.py b/pystencils_tests/nbackend/test_constant_folding.py new file mode 100644 index 000000000..12149c3ce --- /dev/null +++ b/pystencils_tests/nbackend/test_constant_folding.py @@ -0,0 +1,45 @@ +import pytest + +import pymbolic.primitives as pb +from pymbolic.mapper.constant_folder import ConstantFoldingMapper + +from pystencils.nbackend.types.quick import * +from pystencils.nbackend.typed_expressions import PsTypedConstant + + +@pytest.mark.parametrize("width", (8, 16, 32, 64)) +def test_constant_folding_int(width): + folder = ConstantFoldingMapper() + + expr = pb.Sum( + ( + PsTypedConstant(13, UInt(width)), + PsTypedConstant(5, UInt(width)), + PsTypedConstant(3, UInt(width)), + ) + ) + + assert folder(expr) == PsTypedConstant(21, UInt(width)) + + expr = pb.Product( + (PsTypedConstant(-1, SInt(width)), PsTypedConstant(41, SInt(width))) + ) - PsTypedConstant(12, SInt(width)) + + assert folder(expr) == PsTypedConstant(-53, SInt(width)) + + +@pytest.mark.parametrize("width", (32, 64)) +def test_constant_folding_float(width): + folder = ConstantFoldingMapper() + + expr = pb.Quotient( + PsTypedConstant(14.0, Fp(width)), + pb.Sum( + ( + PsTypedConstant(2.5, Fp(width)), + PsTypedConstant(4.5, Fp(width)), + ) + ), + ) + + assert folder(expr) == PsTypedConstant(7.0, Fp(width)) diff --git a/pystencils_tests/nbackend/types/test_constants.py b/pystencils_tests/nbackend/types/test_constants.py new file mode 100644 index 000000000..f166a1570 --- /dev/null +++ b/pystencils_tests/nbackend/types/test_constants.py @@ -0,0 +1,61 @@ +import pytest + +from pystencils.nbackend.types.quick import * +from pystencils.nbackend.types import PsTypeError +from pystencils.nbackend.typed_expressions import PsTypedConstant + + +@pytest.mark.parametrize("width", (8, 16, 32, 64)) +def test_integer_constants(width): + dtype = SInt(width) + a = PsTypedConstant(42, dtype) + b = PsTypedConstant(2, dtype) + + assert a + b == PsTypedConstant(44, dtype) + assert a - b == PsTypedConstant(40, dtype) + assert a * b == PsTypedConstant(84, dtype) + + assert a - b != PsTypedConstant(-12, dtype) + + # Typed constants only compare to themselves + assert a + b != 44 + + +@pytest.mark.parametrize("width", (32, 64)) +def test_float_constants(width): + a = PsTypedConstant(32.0, Fp(width)) + b = PsTypedConstant(0.5, Fp(width)) + c = PsTypedConstant(2.0, Fp(width)) + + assert a + b == PsTypedConstant(32.5, Fp(width)) + assert a * b == PsTypedConstant(16.0, Fp(width)) + assert a - b == PsTypedConstant(31.5, Fp(width)) + assert a / c == PsTypedConstant(16.0, Fp(width)) + + +def test_illegal_ops(): + # Cannot interpret negative numbers as unsigned types + with pytest.raises(PsTypeError): + _ = PsTypedConstant(-3, UInt(32)) + + # Mixed ops are illegal + with pytest.raises(PsTypeError): + _ = PsTypedConstant(32.0, Fp(32)) + PsTypedConstant(2, UInt(32)) + + with pytest.raises(PsTypeError): + _ = PsTypedConstant(32.0, Fp(32)) - PsTypedConstant(2, UInt(32)) + + with pytest.raises(PsTypeError): + _ = PsTypedConstant(32.0, Fp(32)) * PsTypedConstant(2, UInt(32)) + + with pytest.raises(PsTypeError): + _ = PsTypedConstant(32.0, Fp(32)) / PsTypedConstant(2, UInt(32)) + + +@pytest.mark.parametrize("width", (8, 16, 32, 64)) +def test_integer_division(width): + a = PsTypedConstant(-5, SInt(width)) + b = PsTypedConstant(2, SInt(width)) + + assert a / b == PsTypedConstant(-2, SInt(width)) + assert a % b == PsTypedConstant(-1, SInt(width)) -- GitLab