From 51092f67fe56f288f29a42c35d90cf572b6914c3 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 16 Jan 2024 15:41:46 +0100
Subject: [PATCH] JIT debugging

---
 src/pystencils/cpu/cpujit.py                     |  9 ++++++++-
 src/pystencils/nbackend/arrays.py                |  8 +++-----
 src/pystencils/nbackend/ast/kernelfunction.py    | 14 +++++++++++++-
 .../nbackend/jit/cpu_extension_module.py         | 16 +++++++++-------
 src/pystencils/nbackend/types/basic_types.py     |  8 ++++----
 tests/nbackend/test_basic_printing.py            |  5 +++--
 6 files changed, 40 insertions(+), 20 deletions(-)

diff --git a/src/pystencils/cpu/cpujit.py b/src/pystencils/cpu/cpujit.py
index b839f87cf..b3db1c096 100644
--- a/src/pystencils/cpu/cpujit.py
+++ b/src/pystencils/cpu/cpujit.py
@@ -69,6 +69,9 @@ from pystencils.kernel_wrapper import KernelWrapper
 from pystencils.typing import BasicType, CastFunc, VectorType, VectorMemoryAccess
 from pystencils.utils import atomic_file_write, recursive_dict_update
 
+from ..nbackend.ast import PsKernelFunction
+from ..nbackend.jit.cpu_extension_module import PsKernelExtensioNModule
+
 
 def make_python_function(kernel_function_node, custom_backend=None):
     """
@@ -619,7 +622,11 @@ def compile_and_load(ast, custom_backend=None):
     compiler_config = get_compiler_config()
     function_prefix = '__declspec(dllexport)' if compiler_config['os'].lower() == 'windows' else ''
 
-    code = ExtensionModuleCode(custom_backend=custom_backend)
+    if isinstance(ast, PsKernelFunction):
+        code = PsKernelExtensioNModule()
+    else:
+        code = ExtensionModuleCode(custom_backend=custom_backend)
+    
     code.add_function(ast, ast.function_name)
 
     code.create_code_string(compiler_config['restrict_qualifier'], function_prefix)
diff --git a/src/pystencils/nbackend/arrays.py b/src/pystencils/nbackend/arrays.py
index 47394d2ac..368c2940b 100644
--- a/src/pystencils/nbackend/arrays.py
+++ b/src/pystencils/nbackend/arrays.py
@@ -43,7 +43,6 @@ kernel_function.add_constraints(*constraints)
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
 from abc import ABC
 
 import pymbolic.primitives as pb
@@ -57,8 +56,7 @@ from .types import (
     constify,
 )
 
-if TYPE_CHECKING:
-    from .typed_expressions import PsTypedVariable, PsTypedConstant
+from .typed_expressions import PsTypedVariable, PsTypedConstant
 
 
 class PsLinearizedArray:
@@ -152,7 +150,7 @@ class PsArrayShapeVar(PsArrayAssocVar):
     __match_args__ = ("array", "coordinate", "dtype")
 
     def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType):
-        name = f"{array}_size{coordinate}"
+        name = f"{array.name}_size{coordinate}"
         super().__init__(name, dtype, array)
         self._coordinate = coordinate
 
@@ -169,7 +167,7 @@ class PsArrayStrideVar(PsArrayAssocVar):
     __match_args__ = ("array", "coordinate", "dtype")
 
     def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType):
-        name = f"{array}_size{coordinate}"
+        name = f"{array.name}_size{coordinate}"
         super().__init__(name, dtype, array)
         self._coordinate = coordinate
 
diff --git a/src/pystencils/nbackend/ast/kernelfunction.py b/src/pystencils/nbackend/ast/kernelfunction.py
index e2f0ac8e6..4aa963c1b 100644
--- a/src/pystencils/nbackend/ast/kernelfunction.py
+++ b/src/pystencils/nbackend/ast/kernelfunction.py
@@ -94,6 +94,17 @@ class PsKernelFunction(PsAstNode):
     def name(self, value: str):
         self._name = value
 
+    @property
+    def function_name(self) -> str:
+        """For backward compatibility."""
+        return self._name
+    
+    @property
+    def instruction_set(self) -> str | None:
+        """For backward compatibility"""
+        return None
+
+
     def num_children(self) -> int:
         return 1
 
@@ -128,4 +139,5 @@ class PsKernelFunction(PsAstNode):
         return PsKernelParametersSpec(tuple(params_list), tuple(arrays), tuple(self._constraints))
     
     def get_required_headers(self) -> set[str]:
-        raise NotImplementedError()
+        #   TODO: Headers from types, vectorizer, ...
+        return set()
diff --git a/src/pystencils/nbackend/jit/cpu_extension_module.py b/src/pystencils/nbackend/jit/cpu_extension_module.py
index e4d021c5d..c70c45a39 100644
--- a/src/pystencils/nbackend/jit/cpu_extension_module.py
+++ b/src/pystencils/nbackend/jit/cpu_extension_module.py
@@ -6,6 +6,7 @@ from os import path
 import hashlib
 
 from itertools import chain
+from textwrap import indent
 
 from ..exceptions import PsInternalCompilerError
 from ..ast import PsKernelFunction
@@ -109,6 +110,8 @@ class PsKernelExtensioNModule:
 
         code += create_module_boilerplate_code(self._code_hash, self._kernels.keys())
 
+        self._code_string = code
+
     def get_hash_of_code(self):
         assert self._code_string is not None, "The code must be generated first"
         return self._code_hash
@@ -151,9 +154,9 @@ if (buffer_{name}_res == -1) {{ return NULL; }}
 
     KWCHECK = """
 if( !kwargs || !PyDict_Check(kwargs) ) {{ 
-        PyErr_SetString(PyExc_TypeError, "No keyword arguments passed"); 
-        return NULL; 
-    }}
+    PyErr_SetString(PyExc_TypeError, "No keyword arguments passed"); 
+    return NULL; 
+}}
 """
 
     def __init__(self) -> None:
@@ -196,8 +199,6 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
 
     def extract_scalar(self, variable: PsTypedVariable) -> str:
         if variable not in self._scalar_extractions:
-            self.TMPL_EXTRACT_SCALAR.format()
-
             extract_func = self._scalar_extractor(variable.dtype)
             code = self.TMPL_EXTRACT_SCALAR.format(
                 name=variable.name,
@@ -268,12 +269,13 @@ if(!({cond}))
     def resolve(self, function_name) -> str:
         assert self._call is not None
 
-        body = "\n".join(
+        body = "\n\n".join(
             chain(
                 [self.KWCHECK],
                 self._scalar_extractions.values(),
                 self._array_extractions.values(),
                 self._array_assoc_var_extractions.values(),
+                self._constraint_checks,
                 [self._call],
                 self._array_frees.values(),
                 ["Py_RETURN_NONE;"],
@@ -281,6 +283,6 @@ if(!({cond}))
         )
 
         code = f"static PyObject * {function_name}(PyObject * self, PyObject * args, PyObject * kwargs)\n"
-        code += "{\n" + body + "\n}\n"
+        code += "{\n" + indent(body, prefix="    ") + "\n}\n"
 
         return code
diff --git a/src/pystencils/nbackend/types/basic_types.py b/src/pystencils/nbackend/types/basic_types.py
index 45deb88a0..be6de603e 100644
--- a/src/pystencils/nbackend/types/basic_types.py
+++ b/src/pystencils/nbackend/types/basic_types.py
@@ -40,7 +40,7 @@ class PsAbstractType(ABC):
         return type(self) is type(other) and self._const == other._const
 
     def _const_string(self) -> str:
-        return "const" if self._const else ""
+        return "const " if self._const else ""
 
     @abstractmethod
     def _c_string(self) -> str:
@@ -241,7 +241,7 @@ class PsIntegerType(PsScalarType, ABC):
 
     def _c_string(self) -> str:
         prefix = "" if self._signed else "u"
-        return f"{self._const_string()} {prefix}int{self._width}_t"
+        return f"{self._const_string()}{prefix}int{self._width}_t"
 
     def __repr__(self) -> str:
         return f"PsIntegerType( width={self.width}, signed={self.signed}, const={self.const} )"
@@ -359,9 +359,9 @@ class PsIeeeFloatType(PsScalarType):
     def _c_string(self) -> str:
         match self._width:
             case 32:
-                return f"{self._const_string()} float"
+                return f"{self._const_string()}float"
             case 64:
-                return f"{self._const_string()} double"
+                return f"{self._const_string()}double"
             case _:
                 assert False, "unreachable code"
 
diff --git a/tests/nbackend/test_basic_printing.py b/tests/nbackend/test_basic_printing.py
index 4031c12fe..adf7b8d2e 100644
--- a/tests/nbackend/test_basic_printing.py
+++ b/tests/nbackend/test_basic_printing.py
@@ -4,13 +4,14 @@ from pystencils import Target
 
 from pystencils.nbackend.ast import *
 from pystencils.nbackend.typed_expressions import *
+from pystencils.nbackend.arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAccess
 from pystencils.nbackend.types.quick import *
 from pystencils.nbackend.emission import CPrinter
 
 def test_basic_kernel():
 
     u_size = PsTypedVariable("u_length", UInt(32, True))
-    u_arr = PsArray("u", u_size, Fp(64))
+    u_arr = PsLinearizedArray("u", Fp(64), 1)
     u_base = PsArrayBasePointer("u_data", u_arr)
 
     loop_ctr = PsTypedVariable("ctr", UInt(32))
@@ -34,7 +35,7 @@ def test_basic_kernel():
     printer = CPrinter()
     code = printer.print(func)
 
-    paramlist = func.get_parameters()
+    paramlist = func.get_parameters().params
     params_str = ", ".join(f"{p.dtype} {p.name}" for p in paramlist)
 
     assert code.find("(" + params_str + ")") >= 0
-- 
GitLab