Skip to content
Snippets Groups Projects
Commit bcd147bb authored by Frederik Hennig's avatar Frederik Hennig
Browse files

first compilation and execution test

parent 51092f67
Branches
Tags
No related merge requests found
......@@ -56,7 +56,7 @@ from .types import (
constify,
)
from .typed_expressions import PsTypedVariable, PsTypedConstant
from .typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant
class PsLinearizedArray:
......@@ -86,7 +86,9 @@ class PsLinearizedArray:
if offsets is None:
offsets = (0,) * dim
self._dim = dim
self._offsets = tuple(PsTypedConstant(o, index_dtype) for o in offsets)
self._index_dtype = index_dtype
@property
def name(self):
......@@ -100,6 +102,10 @@ class PsLinearizedArray:
def strides(self):
return self._strides
@property
def dim(self):
return self._dim
@property
def element_type(self):
return self._element_type
......@@ -108,6 +114,35 @@ class PsLinearizedArray:
def offsets(self) -> tuple[PsTypedConstant, ...]:
return self._offsets
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsLinearizedArray):
return False
return (
self._name,
self._element_type,
self._dim,
self._offsets,
self._index_dtype,
) == (
other._name,
other._element_type,
other._dim,
other._offsets,
other._index_dtype,
)
def __hash__(self) -> int:
return hash(
(
self._name,
self._element_type,
self._dim,
self._offsets,
self._index_dtype,
)
)
class PsArrayAssocVar(PsTypedVariable, ABC):
"""A variable that is associated to an array.
......@@ -180,7 +215,7 @@ class PsArrayStrideVar(PsArrayAssocVar):
class PsArrayAccess(pb.Subscript):
def __init__(self, base_ptr: PsArrayBasePointer, index: pb.Expression):
def __init__(self, base_ptr: PsArrayBasePointer, index: ExprOrConstant):
super(PsArrayAccess, self).__init__(base_ptr, index)
self._base_ptr = base_ptr
self._index = index
......
......@@ -16,7 +16,7 @@ from ...enums import Target
@dataclass
class PsKernelParametersSpec:
"""Specification of a kernel function's parameters.
Contains:
- Verbatim parameter list, a list of `PsTypedVariables`
- List of Arrays used in the kernel, in canonical order
......@@ -31,9 +31,9 @@ class PsKernelParametersSpec:
def params_for_array(self, arr: PsLinearizedArray):
def pred(p: PsTypedVariable):
return isinstance(p, PsArrayAssocVar) and p.array == arr
return tuple(filter(pred, self.params))
def __post_init__(self):
dep_mapper = DependencyMapper(False, False, False, False)
......@@ -59,7 +59,7 @@ class PsKernelParametersSpec:
class PsKernelFunction(PsAstNode):
"""A pystencils kernel function.
Objects of this class represent a full pystencils kernel and should provide all information required for
export, compilation, and inclusion of the kernel into a runtime system.
"""
......@@ -98,13 +98,12 @@ class PsKernelFunction(PsAstNode):
def function_name(self) -> str:
"""For backward compatibility."""
return self._name
@property
def instruction_set(self) -> str | None:
"""For backward compatibility"""
return None
def num_children(self) -> int:
return 1
......@@ -136,8 +135,10 @@ class PsKernelFunction(PsAstNode):
params_list = sorted(params_set, key=lambda p: p.name)
arrays = set(p.array for p in params_list if isinstance(p, PsArrayBasePointer))
return PsKernelParametersSpec(tuple(params_list), tuple(arrays), tuple(self._constraints))
return PsKernelParametersSpec(
tuple(params_list), tuple(arrays), tuple(self._constraints)
)
def get_required_headers(self) -> set[str]:
# TODO: Headers from types, vectorizer, ...
return set()
......@@ -4,10 +4,11 @@ from typing import Any
from os import path
import hashlib
from itertools import chain
from textwrap import indent
import numpy as np
from ..exceptions import PsInternalCompilerError
from ..ast import PsKernelFunction
from ..ast.constraints import PsParamConstraint
......@@ -19,7 +20,13 @@ from ..arrays import (
PsArrayShapeVar,
PsArrayStrideVar,
)
from ..types import PsAbstractType
from ..types import (
PsAbstractType,
PsScalarType,
PsUnsignedIntegerType,
PsSignedIntegerType,
PsIeeeFloatType,
)
from ..types.quick import Fp, SInt, UInt
from ..emission import emit_code
......@@ -150,6 +157,13 @@ if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '
Py_buffer buffer_{name};
int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE | PyBUF_FORMAT);
if (buffer_{name}_res == -1) {{ return NULL; }}
"""
TMPL_CHECK_ARRAY_TYPE = """
if(!({cond})) {{
PyErr_SetString(PyExc_TypeError, "Wrong {what} of array {name}. Expected {expected}");
return NULL;
}}
"""
KWCHECK = """
......@@ -185,10 +199,38 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
f"Don't know how to cast Python objects to {dtype}"
)
def _type_char(self, dtype: PsScalarType) -> str | None:
if isinstance(
dtype, (PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType)
):
np_dtype = dtype.NUMPY_TYPES[dtype.width]
return np.dtype(np_dtype).char
else:
return None
def extract_array(self, arr: PsLinearizedArray) -> str:
"""Adds an array, and returns the name of the underlying Py_Buffer."""
if arr not in self._array_extractions:
extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=arr.name)
# Check array type
type_char = self._type_char(arr.element_type)
if type_char is not None:
dtype_cond = f"buffer_{arr.name}.format[0] == '{type_char}'"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=dtype_cond,
what="data type",
name=arr.name,
expected=str(arr.element_type),
)
# Check item size
itemsize = arr.element_type.itemsize
item_size_cond = f"buffer_{arr.name}.itemsize == {itemsize}"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=item_size_cond, what="itemsize", name=arr.name, expected=itemsize
)
self._array_buffers[arr] = f"buffer_{arr.name}"
self._array_extractions[arr] = extraction_code
......@@ -219,8 +261,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
case PsArrayShapeVar():
coord = variable.coordinate
code = (
f"{variable.dtype} {variable.name} = "
f"{buffer}.shape[{coord}] / {arr.element_type.itemsize};"
f"{variable.dtype} {variable.name} = {buffer}.shape[{coord}];"
)
case PsArrayStrideVar():
coord = variable.coordinate
......
......@@ -13,7 +13,6 @@ from .types import (
class PsTypedVariable(pb.Variable):
init_arg_names: tuple[str, ...] = ("name", "dtype")
__match_args__ = ("name", "dtype")
......@@ -130,18 +129,27 @@ class PsTypedConstant:
return v
def __add__(self, other: Any):
if isinstance(other, pb.Expression): # let pymbolic handle this case
return NotImplemented
return PsTypedConstant(self._value + self._fix(other)._value, self._dtype)
def __radd__(self, other: Any):
return PsTypedConstant(self._rfix(other)._value + self._value, self._dtype)
def __mul__(self, other: Any):
if isinstance(other, pb.Expression): # let pymbolic handle this case
return NotImplemented
return PsTypedConstant(self._value * self._fix(other)._value, self._dtype)
def __rmul__(self, other: Any):
return PsTypedConstant(self._rfix(other)._value * self._value, self._dtype)
def __sub__(self, other: Any):
if isinstance(other, pb.Expression): # let pymbolic handle this case
return NotImplemented
return PsTypedConstant(self._value - self._fix(other)._value, self._dtype)
def __rsub__(self, other: Any):
......@@ -156,6 +164,9 @@ class PsTypedConstant:
return quotient, rem
def __truediv__(self, other: Any):
if isinstance(other, pb.Expression): # let pymbolic handle this case
return NotImplemented
if self._dtype.is_float():
return PsTypedConstant(self._value / self._fix(other)._value, self._dtype)
elif self._dtype.is_uint():
......@@ -183,6 +194,9 @@ class PsTypedConstant:
return NotImplemented
def __mod__(self, other: Any):
if isinstance(other, pb.Expression): # let pymbolic handle this case
return NotImplemented
if self._dtype.is_uint():
return PsTypedConstant(self._value % self._fix(other)._value, self._dtype)
else:
......
......@@ -10,8 +10,8 @@ from pystencils.nbackend.emission import CPrinter
def test_basic_kernel():
u_size = PsTypedVariable("u_length", UInt(32, True))
u_arr = PsLinearizedArray("u", Fp(64), 1)
u_size = u_arr.shape[0]
u_base = PsArrayBasePointer("u_data", u_arr)
loop_ctr = PsTypedVariable("ctr", UInt(32))
......
import pytest
from pystencils import Target
from pystencils.nbackend.ast import *
from pystencils.nbackend.ast.constraints import PsParamConstraint
from pystencils.nbackend.typed_expressions import *
from pystencils.nbackend.arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAccess
from pystencils.nbackend.types.quick import *
import numpy as np
from pystencils.cpu.cpujit import compile_and_load
def test_pairwise_addition():
idx_type = SInt(64)
u = PsLinearizedArray("u", Fp(64, const=True), 2, index_dtype=idx_type)
v = PsLinearizedArray("v", Fp(64), 2, index_dtype=idx_type)
u_data = PsArrayBasePointer("u_data", u)
v_data = PsArrayBasePointer("v_data", v)
loop_ctr = PsTypedVariable("ctr", idx_type)
zero = PsTypedConstant(0, idx_type)
one = PsTypedConstant(1, idx_type)
two = PsTypedConstant(2, idx_type)
update = PsAssignment(
PsLvalueExpr(PsArrayAccess(v_data, loop_ctr)),
PsExpression(PsArrayAccess(u_data, two * loop_ctr) + PsArrayAccess(u_data, two * loop_ctr + one))
)
loop = PsLoop(
PsSymbolExpr(loop_ctr),
PsExpression(zero),
PsExpression(v.shape[0]),
PsExpression(one),
PsBlock([update])
)
func = PsKernelFunction(PsBlock([loop]), target=Target.CPU)
sizes_constraint = PsParamConstraint(
u.shape[0].eq(2 * v.shape[0]),
"Array `u` must have twice the length of array `v`"
)
func.add_constraints(sizes_constraint)
kernel = compile_and_load(func)
# Positive case
N = 21
u_arr = np.arange(2 * N, dtype=np.float64)
v_arr = np.zeros((N,), dtype=np.float64)
assert u_arr.shape[0] == 2 * v_arr.shape[0]
kernel(u=u_arr, v=v_arr)
v_expected = np.zeros_like(v_arr)
for i in range(N):
v_expected[i] = u_arr[2 * i] + u_arr[2*i + 1]
np.testing.assert_allclose(v_arr, v_expected)
# Negative case - mismatched array sizes
u_arr = np.zeros((N + 2,), dtype=np.float64)
v_arr = np.zeros((N,), dtype=np.float64)
with pytest.raises(ValueError):
kernel(u=u_arr, v=v_arr)
# Negative case - mismatched types
u_arr = np.arange(2 * N, dtype=np.float64)
v_arr = np.zeros((N,), dtype=np.float32)
with pytest.raises(TypeError):
kernel(u=u_arr, v=v_arr)
from pystencils.nbackend.typed_expressions import PsTypedVariable
from pystencils.nbackend.arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayShapeVar, PsArrayStrideVar
from pystencils.nbackend.types.quick import *
def test_variable_equality():
var1 = PsTypedVariable("x", Fp(32))
var2 = PsTypedVariable("x", Fp(32))
assert var1 == var2
arr = PsLinearizedArray("arr", Fp(64), 3)
bp1 = PsArrayBasePointer("arr_data", arr)
bp2 = PsArrayBasePointer("arr_data", arr)
assert bp1 == bp2
arr1 = PsLinearizedArray("arr", Fp(64), 3)
bp1 = PsArrayBasePointer("arr_data", arr1)
arr2 = PsLinearizedArray("arr", Fp(64), 3)
bp2 = PsArrayBasePointer("arr_data", arr2)
assert bp1 == bp2
for v1, v2 in zip(arr1.shape, arr2.shape):
assert v1 == v2
for v1, v2 in zip(arr1.strides, arr2.strides):
assert v1 == v2
def test_variable_inequality():
var1 = PsTypedVariable("x", Fp(32))
var2 = PsTypedVariable("x", Fp(64))
assert var1 != var2
var1 = PsTypedVariable("x", Fp(32, True))
var2 = PsTypedVariable("x", Fp(32, False))
assert var1 != var2
# Arrays
arr1 = PsLinearizedArray("arr", Fp(64), 3)
bp1 = PsArrayBasePointer("arr_data", arr1)
arr2 = PsLinearizedArray("arr", Fp(32), 3)
bp2 = PsArrayBasePointer("arr_data", arr2)
assert bp1 != bp2
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment