Skip to content
Snippets Groups Projects
Commit 4d13952d authored by Frederik Hennig's avatar Frederik Hennig
Browse files

some cleanup. Notes on translation context.

parent 09d6b6e7
Branches
Tags
No related merge requests found
Pipeline #60311 failed with stages
in 3 minutes and 30 seconds
......@@ -56,7 +56,7 @@ from .types import (
constify,
)
from .typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant
from .typed_expressions import PsTypedVariable, ExprOrConstant
class PsLinearizedArray:
......@@ -67,27 +67,19 @@ class PsLinearizedArray:
name: str,
element_type: PsScalarType,
dim: int,
offsets: tuple[int, ...] | None = None,
index_dtype: PsIntegerType = PsSignedIntegerType(64),
):
self._name = name
if offsets is not None and len(offsets) != dim:
raise ValueError(f"Must have exactly {dim} offsets.")
self._shape = tuple(
PsArrayShapeVar(self, d, constify(index_dtype)) for d in range(dim)
)
self._strides = tuple(
PsArrayStrideVar(self, d, constify(index_dtype)) for d in range(dim)
)
self._element_type = element_type
if offsets is None:
offsets = (0,) * dim
self._element_type = element_type
self._dim = dim
self._offsets = tuple(PsTypedConstant(o, index_dtype) for o in offsets)
self._index_dtype = index_dtype
@property
......@@ -110,10 +102,6 @@ class PsLinearizedArray:
def element_type(self):
return self._element_type
@property
def offsets(self) -> tuple[PsTypedConstant, ...]:
return self._offsets
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsLinearizedArray):
return False
......@@ -122,13 +110,11 @@ class PsLinearizedArray:
self._name,
self._element_type,
self._dim,
self._offsets,
self._index_dtype,
) == (
other._name,
other._element_type,
other._dim,
other._offsets,
other._index_dtype,
)
......@@ -138,7 +124,6 @@ class PsLinearizedArray:
self._name,
self._element_type,
self._dim,
self._offsets,
self._index_dtype,
)
)
......
class PsTranslationContext:
"""The `PsTranslationContext` manages the translation process from the SymPy frontend
to the backend AST.
It does the following things:
- Default data types: The context knows the data types that should be applied by default
to SymPy expressions.
- Management of fields. The context manages all mappings from front-end `Field`s to their
underlying `PsLinearizedArray`s.
- Collection of constraints. All constraints that arise during translation are collected in the
context, and finally attached to the kernel function object once translation is complete.
Data Types
----------
- The `index_dtype` is the data type used throughout translation for all loop counters and array indexing.
- The `default_numeric_dtype` is the data type assigned by default to all symbols occuring in SymPy assignments
Fields and Arrays
-----------------
There's several types of fields that need to be mapped to arrays.
- `FieldType.GENERIC` corresponds to domain fields.
Domain fields can only be accessed by relative offsets, and therefore must always
be associated with an *iteration domain* that provides a spatial index tuple.
All domain fields associated with the same domain must have the same spatial shape, modulo ghost layers.
A field and its array may be associated with multiple iteration domains.
- `FieldType.INDEXED` are 1D arrays of index structures. They must be accessed by a single running index.
If there is at least one indexed field present there must also exist an index source for that field
(loop or device indexing).
An indexed field may itself be an index source for domain fields.
- `FieldType.BUFFER` are 1D arrays whose indices must be incremented with each access.
Within a domain, a buffer may be either written to or read from, never both.
"""
......@@ -27,6 +27,7 @@ def test_constant_folding_int(width):
assert folder(expr) == PsTypedConstant(-53, SInt(width))
@pytest.mark.xfail(reason="Current constant folder does not handle products")
@pytest.mark.parametrize("width", (8, 16, 32, 64))
def test_constant_folding_product(width):
"""
......@@ -46,8 +47,11 @@ def test_constant_folding_product(width):
assert folder(expr) == PsTypedConstant(-24, SInt(width))
@pytest.mark.xfail(reason="Current constant folder does not handle divisions")
@pytest.mark.parametrize("width", (32, 64))
def test_constant_folding_float(width):
"""The pymbolic constant folder does not fold quotients. This test serves as a reminder
to consider that behaviour"""
folder = ConstantFoldingMapper()
expr = pb.Quotient(
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment