From 51092f67fe56f288f29a42c35d90cf572b6914c3 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 16 Jan 2024 15:41:46 +0100 Subject: [PATCH] JIT debugging --- src/pystencils/cpu/cpujit.py | 9 ++++++++- src/pystencils/nbackend/arrays.py | 8 +++----- src/pystencils/nbackend/ast/kernelfunction.py | 14 +++++++++++++- .../nbackend/jit/cpu_extension_module.py | 16 +++++++++------- src/pystencils/nbackend/types/basic_types.py | 8 ++++---- tests/nbackend/test_basic_printing.py | 5 +++-- 6 files changed, 40 insertions(+), 20 deletions(-) diff --git a/src/pystencils/cpu/cpujit.py b/src/pystencils/cpu/cpujit.py index b839f87cf..b3db1c096 100644 --- a/src/pystencils/cpu/cpujit.py +++ b/src/pystencils/cpu/cpujit.py @@ -69,6 +69,9 @@ from pystencils.kernel_wrapper import KernelWrapper from pystencils.typing import BasicType, CastFunc, VectorType, VectorMemoryAccess from pystencils.utils import atomic_file_write, recursive_dict_update +from ..nbackend.ast import PsKernelFunction +from ..nbackend.jit.cpu_extension_module import PsKernelExtensioNModule + def make_python_function(kernel_function_node, custom_backend=None): """ @@ -619,7 +622,11 @@ def compile_and_load(ast, custom_backend=None): compiler_config = get_compiler_config() function_prefix = '__declspec(dllexport)' if compiler_config['os'].lower() == 'windows' else '' - code = ExtensionModuleCode(custom_backend=custom_backend) + if isinstance(ast, PsKernelFunction): + code = PsKernelExtensioNModule() + else: + code = ExtensionModuleCode(custom_backend=custom_backend) + code.add_function(ast, ast.function_name) code.create_code_string(compiler_config['restrict_qualifier'], function_prefix) diff --git a/src/pystencils/nbackend/arrays.py b/src/pystencils/nbackend/arrays.py index 47394d2ac..368c2940b 100644 --- a/src/pystencils/nbackend/arrays.py +++ b/src/pystencils/nbackend/arrays.py @@ -43,7 +43,6 @@ kernel_function.add_constraints(*constraints) from __future__ import annotations -from typing import TYPE_CHECKING from abc import ABC import pymbolic.primitives as pb @@ -57,8 +56,7 @@ from .types import ( constify, ) -if TYPE_CHECKING: - from .typed_expressions import PsTypedVariable, PsTypedConstant +from .typed_expressions import PsTypedVariable, PsTypedConstant class PsLinearizedArray: @@ -152,7 +150,7 @@ class PsArrayShapeVar(PsArrayAssocVar): __match_args__ = ("array", "coordinate", "dtype") def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType): - name = f"{array}_size{coordinate}" + name = f"{array.name}_size{coordinate}" super().__init__(name, dtype, array) self._coordinate = coordinate @@ -169,7 +167,7 @@ class PsArrayStrideVar(PsArrayAssocVar): __match_args__ = ("array", "coordinate", "dtype") def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType): - name = f"{array}_size{coordinate}" + name = f"{array.name}_size{coordinate}" super().__init__(name, dtype, array) self._coordinate = coordinate diff --git a/src/pystencils/nbackend/ast/kernelfunction.py b/src/pystencils/nbackend/ast/kernelfunction.py index e2f0ac8e6..4aa963c1b 100644 --- a/src/pystencils/nbackend/ast/kernelfunction.py +++ b/src/pystencils/nbackend/ast/kernelfunction.py @@ -94,6 +94,17 @@ class PsKernelFunction(PsAstNode): def name(self, value: str): self._name = value + @property + def function_name(self) -> str: + """For backward compatibility.""" + return self._name + + @property + def instruction_set(self) -> str | None: + """For backward compatibility""" + return None + + def num_children(self) -> int: return 1 @@ -128,4 +139,5 @@ class PsKernelFunction(PsAstNode): return PsKernelParametersSpec(tuple(params_list), tuple(arrays), tuple(self._constraints)) def get_required_headers(self) -> set[str]: - raise NotImplementedError() + # TODO: Headers from types, vectorizer, ... + return set() diff --git a/src/pystencils/nbackend/jit/cpu_extension_module.py b/src/pystencils/nbackend/jit/cpu_extension_module.py index e4d021c5d..c70c45a39 100644 --- a/src/pystencils/nbackend/jit/cpu_extension_module.py +++ b/src/pystencils/nbackend/jit/cpu_extension_module.py @@ -6,6 +6,7 @@ from os import path import hashlib from itertools import chain +from textwrap import indent from ..exceptions import PsInternalCompilerError from ..ast import PsKernelFunction @@ -109,6 +110,8 @@ class PsKernelExtensioNModule: code += create_module_boilerplate_code(self._code_hash, self._kernels.keys()) + self._code_string = code + def get_hash_of_code(self): assert self._code_string is not None, "The code must be generated first" return self._code_hash @@ -151,9 +154,9 @@ if (buffer_{name}_res == -1) {{ return NULL; }} KWCHECK = """ if( !kwargs || !PyDict_Check(kwargs) ) {{ - PyErr_SetString(PyExc_TypeError, "No keyword arguments passed"); - return NULL; - }} + PyErr_SetString(PyExc_TypeError, "No keyword arguments passed"); + return NULL; +}} """ def __init__(self) -> None: @@ -196,8 +199,6 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ def extract_scalar(self, variable: PsTypedVariable) -> str: if variable not in self._scalar_extractions: - self.TMPL_EXTRACT_SCALAR.format() - extract_func = self._scalar_extractor(variable.dtype) code = self.TMPL_EXTRACT_SCALAR.format( name=variable.name, @@ -268,12 +269,13 @@ if(!({cond})) def resolve(self, function_name) -> str: assert self._call is not None - body = "\n".join( + body = "\n\n".join( chain( [self.KWCHECK], self._scalar_extractions.values(), self._array_extractions.values(), self._array_assoc_var_extractions.values(), + self._constraint_checks, [self._call], self._array_frees.values(), ["Py_RETURN_NONE;"], @@ -281,6 +283,6 @@ if(!({cond})) ) code = f"static PyObject * {function_name}(PyObject * self, PyObject * args, PyObject * kwargs)\n" - code += "{\n" + body + "\n}\n" + code += "{\n" + indent(body, prefix=" ") + "\n}\n" return code diff --git a/src/pystencils/nbackend/types/basic_types.py b/src/pystencils/nbackend/types/basic_types.py index 45deb88a0..be6de603e 100644 --- a/src/pystencils/nbackend/types/basic_types.py +++ b/src/pystencils/nbackend/types/basic_types.py @@ -40,7 +40,7 @@ class PsAbstractType(ABC): return type(self) is type(other) and self._const == other._const def _const_string(self) -> str: - return "const" if self._const else "" + return "const " if self._const else "" @abstractmethod def _c_string(self) -> str: @@ -241,7 +241,7 @@ class PsIntegerType(PsScalarType, ABC): def _c_string(self) -> str: prefix = "" if self._signed else "u" - return f"{self._const_string()} {prefix}int{self._width}_t" + return f"{self._const_string()}{prefix}int{self._width}_t" def __repr__(self) -> str: return f"PsIntegerType( width={self.width}, signed={self.signed}, const={self.const} )" @@ -359,9 +359,9 @@ class PsIeeeFloatType(PsScalarType): def _c_string(self) -> str: match self._width: case 32: - return f"{self._const_string()} float" + return f"{self._const_string()}float" case 64: - return f"{self._const_string()} double" + return f"{self._const_string()}double" case _: assert False, "unreachable code" diff --git a/tests/nbackend/test_basic_printing.py b/tests/nbackend/test_basic_printing.py index 4031c12fe..adf7b8d2e 100644 --- a/tests/nbackend/test_basic_printing.py +++ b/tests/nbackend/test_basic_printing.py @@ -4,13 +4,14 @@ from pystencils import Target from pystencils.nbackend.ast import * from pystencils.nbackend.typed_expressions import * +from pystencils.nbackend.arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAccess from pystencils.nbackend.types.quick import * from pystencils.nbackend.emission import CPrinter def test_basic_kernel(): u_size = PsTypedVariable("u_length", UInt(32, True)) - u_arr = PsArray("u", u_size, Fp(64)) + u_arr = PsLinearizedArray("u", Fp(64), 1) u_base = PsArrayBasePointer("u_data", u_arr) loop_ctr = PsTypedVariable("ctr", UInt(32)) @@ -34,7 +35,7 @@ def test_basic_kernel(): printer = CPrinter() code = printer.print(func) - paramlist = func.get_parameters() + paramlist = func.get_parameters().params params_str = ", ".join(f"{p.dtype} {p.name}" for p in paramlist) assert code.find("(" + params_str + ")") >= 0 -- GitLab