From 03651cbf052871121bdf89cd1f439d85b82d38e4 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 20 Feb 2024 12:49:42 +0100 Subject: [PATCH] integrate iteration slices. buffer test cases now complete. --- .flake8 | 1 + src/pystencils/backend/arrays.py | 15 +- src/pystencils/backend/ast/collectors.py | 7 +- src/pystencils/backend/ast/kernelfunction.py | 9 +- src/pystencils/backend/ast/nodes.py | 10 +- src/pystencils/backend/ast/transformations.py | 8 +- src/pystencils/backend/ast/tree_iteration.py | 8 +- src/pystencils/backend/constraints.py | 4 +- src/pystencils/backend/emission.py | 11 +- .../backend/jit/cpu_extension_module.py | 12 +- src/pystencils/backend/jit/legacy_cpu.py | 283 +++++++++++------- src/pystencils/backend/jit/msvc_detection.py | 48 +-- .../backend/kernelcreation/context.py | 2 +- .../backend/kernelcreation/defaults.py | 2 +- .../backend/kernelcreation/freeze.py | 3 + .../backend/kernelcreation/iteration_space.py | 99 +++++- .../backend/kernelcreation/transformations.py | 4 +- .../backend/kernelcreation/typification.py | 9 +- src/pystencils/backend/platforms/__init__.py | 4 +- src/pystencils/backend/platforms/platform.py | 9 +- src/pystencils/backend/typed_expressions.py | 2 +- src/pystencils/backend/types/basic_types.py | 4 +- src/pystencils/backend/types/exception.py | 2 - src/pystencils/config.py | 3 +- .../kernelcreation/test_iteration_space.py | 46 +++ tests/test_buffer.py | 7 +- 26 files changed, 411 insertions(+), 201 deletions(-) diff --git a/.flake8 b/.flake8 index 8b65d20eb..3f946922a 100644 --- a/.flake8 +++ b/.flake8 @@ -3,4 +3,5 @@ max-line-length=120 exclude=src/pystencils/jupyter.py, src/pystencils/plot.py src/pystencils/session.py + src/pystencils/old ignore = W293 W503 W291 C901 E741 diff --git a/src/pystencils/backend/arrays.py b/src/pystencils/backend/arrays.py index 8c8ae2001..8c7a9306f 100644 --- a/src/pystencils/backend/arrays.py +++ b/src/pystencils/backend/arrays.py @@ -46,7 +46,13 @@ from abc import ABC import pymbolic.primitives as pb -from .types import PsAbstractType, PsPointerType, PsIntegerType, PsUnsignedIntegerType, PsSignedIntegerType +from .types import ( + PsAbstractType, + PsPointerType, + PsIntegerType, + PsUnsignedIntegerType, + PsSignedIntegerType, +) from .typed_expressions import PsTypedVariable, ExprOrConstant, PsTypedConstant @@ -129,7 +135,7 @@ class PsLinearizedArray: def strides(self) -> tuple[PsArrayStrideVar | PsTypedConstant, ...]: """The array's strides, expressed using `PsTypedConstant` and `PsArrayStrideVar`""" return self._strides - + @property def strides_spec(self) -> tuple[EllipsisType | int, ...]: """The array's strides, expressed using `int` and `...`""" @@ -207,12 +213,13 @@ class PsArrayBasePointer(PsArrayAssocVar): def __getinitargs__(self): return self.name, self.array - + class TypeErasedBasePointer(PsArrayBasePointer): """Base pointer for arrays whose element type has been erased. 
- + Used primarily for arrays of anonymous structs.""" + def __init__(self, name: str, array: PsLinearizedArray): dtype = PsPointerType(PsUnsignedIntegerType(8)) super(PsArrayBasePointer, self).__init__(name, dtype, array) diff --git a/src/pystencils/backend/ast/collectors.py b/src/pystencils/backend/ast/collectors.py index c94d0b18b..e64efa1a2 100644 --- a/src/pystencils/backend/ast/collectors.py +++ b/src/pystencils/backend/ast/collectors.py @@ -61,12 +61,7 @@ class UndefinedVariablesCollector: return undefined_vars case PsLoop(ctr, start, stop, step, body): - undefined_vars = ( - self(start) - | self(stop) - | self(step) - | self(body) - ) + undefined_vars = self(start) | self(stop) | self(step) | self(body) undefined_vars.remove(ctr.symbol) return undefined_vars diff --git a/src/pystencils/backend/ast/kernelfunction.py b/src/pystencils/backend/ast/kernelfunction.py index d2e66de87..05ec77bde 100644 --- a/src/pystencils/backend/ast/kernelfunction.py +++ b/src/pystencils/backend/ast/kernelfunction.py @@ -50,7 +50,7 @@ class PsKernelParametersSpec: elif var in self.params: continue - + raise PsInternalCompilerError( "Constrained parameter was neither contained in kernel parameter list " "nor associated with a kernel array.\n" @@ -68,7 +68,9 @@ class PsKernelFunction(PsAstNode): __match_args__ = ("body",) - def __init__(self, body: PsBlock, target: Target, name: str = "kernel", jit: JitBase = no_jit): + def __init__( + self, body: PsBlock, target: Target, name: str = "kernel", jit: JitBase = no_jit + ): self._body: PsBlock = body self._target = target self._name = name @@ -137,7 +139,8 @@ class PsKernelFunction(PsAstNode): def get_required_headers(self) -> set[str]: # To Do: Headers from target/instruction set/... from .collectors import collect_required_headers + return collect_required_headers(self) - + def compile(self) -> Callable[..., None]: return self._jit.compile(self) diff --git a/src/pystencils/backend/ast/nodes.py b/src/pystencils/backend/ast/nodes.py index d11be448d..ef73d8b16 100644 --- a/src/pystencils/backend/ast/nodes.py +++ b/src/pystencils/backend/ast/nodes.py @@ -40,9 +40,9 @@ class PsAstNode(ABC): def __eq__(self, other: object) -> bool: if not isinstance(other, PsAstNode): return False - + return type(self) is type(other) and self.children == other.children - + def __hash__(self) -> int: return hash((type(self), self.children)) @@ -98,12 +98,12 @@ class PsExpression(PsLeafNode): def __repr__(self) -> str: return f"Expr({repr(self._expr)})" - + def __eq__(self, other: object) -> bool: if not isinstance(other, PsExpression): return False return type(self) is type(other) and self._expr == other._expr - + def __hash__(self) -> int: return hash((type(self), self._expr)) @@ -361,7 +361,7 @@ class PsComment(PsLeafNode): @property def text(self) -> str: return self._text - + @property def lines(self) -> tuple[str, ...]: return self._lines diff --git a/src/pystencils/backend/ast/transformations.py b/src/pystencils/backend/ast/transformations.py index 4260e18dc..dc438e52e 100644 --- a/src/pystencils/backend/ast/transformations.py +++ b/src/pystencils/backend/ast/transformations.py @@ -34,14 +34,18 @@ class PsVariablesSubstitutor(PsAstTransformer): def assignment(self, asm: PsAssignment): lhs_expr = asm.lhs.expression if isinstance(lhs_expr, PsTypedVariable) and lhs_expr in self._subs_dict: - raise ValueError(f"Cannot substitute symbol {lhs_expr} that occurs on a left-hand side of an assignment.") + raise ValueError( + f"Cannot substitute symbol {lhs_expr} that occurs on a 
left-hand side of an assignment." + ) self.transform_children(asm) return asm @visit.case(PsLoop) def loop(self, loop: PsLoop): if loop.counter.expression in self._subs_dict: - raise ValueError(f"Cannot substitute symbol {loop.counter.expression} that is defined as a loop counter.") + raise ValueError( + f"Cannot substitute symbol {loop.counter.expression} that is defined as a loop counter." + ) self.transform_children(loop) return loop diff --git a/src/pystencils/backend/ast/tree_iteration.py b/src/pystencils/backend/ast/tree_iteration.py index 2019e7d02..1549d7580 100644 --- a/src/pystencils/backend/ast/tree_iteration.py +++ b/src/pystencils/backend/ast/tree_iteration.py @@ -4,8 +4,7 @@ from .nodes import PsAstNode def dfs_preorder( - node: PsAstNode, - yield_pred: Callable[[PsAstNode], bool] = lambda _: True + node: PsAstNode, yield_pred: Callable[[PsAstNode], bool] = lambda _: True ) -> Generator[PsAstNode, None, None]: """Pre-Order depth-first traversal of an abstract syntax tree. @@ -21,8 +20,7 @@ def dfs_preorder( def dfs_postorder( - node: PsAstNode, - yield_pred: Callable[[PsAstNode], bool] = lambda _: True + node: PsAstNode, yield_pred: Callable[[PsAstNode], bool] = lambda _: True ) -> Generator[PsAstNode, None, None]: """Post-Order depth-first traversal of an abstract syntax tree. @@ -32,6 +30,6 @@ def dfs_postorder( """ for c in node.children: yield from dfs_postorder(c, yield_pred) - + if yield_pred(node): yield node diff --git a/src/pystencils/backend/constraints.py b/src/pystencils/backend/constraints.py index 0cda3f4dc..0225420b4 100644 --- a/src/pystencils/backend/constraints.py +++ b/src/pystencils/backend/constraints.py @@ -14,9 +14,9 @@ class PsKernelConstraint: def print_c_condition(self): return CCodeMapper()(self.condition) - + def get_variables(self) -> set[PsTypedVariable]: return DependencyMapper(False, False, False, False)(self.condition) - + def __str__(self) -> str: return f"{self.message} [{self.condition}]" diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 3fe16e9c1..c76d118db 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -11,7 +11,7 @@ from .ast import ( PsAssignment, PsLoop, PsConditional, - PsComment + PsComment, ) from .ast.kernelfunction import PsKernelFunction from .typed_expressions import PsTypedVariable @@ -25,13 +25,12 @@ def emit_code(kernel: PsKernelFunction): class CExpressionsPrinter(CCodeMapper): - def map_deref(self, deref: Deref, enclosing_prec): return "*" - + def map_address_of(self, addrof: AddressOf, enclosing_prec): return "&" - + def map_cast(self, cast: Cast, enclosing_prec): return f"({cast.target_type.c_string()})" @@ -57,7 +56,9 @@ class CAstPrinter: @visit.case(PsKernelFunction) def function(self, func: PsKernelFunction) -> str: params_spec = func.get_parameters() - params_str = ", ".join(f"{p.dtype.c_string()} {p.name}" for p in params_spec.params) + params_str = ", ".join( + f"{p.dtype.c_string()} {p.name}" for p in params_spec.params + ) decl = f"FUNC_PREFIX void {func.name} ({params_str})" body = self.visit(func.body) return f"{decl}\n{body}" diff --git a/src/pystencils/backend/jit/cpu_extension_module.py b/src/pystencils/backend/jit/cpu_extension_module.py index 739aefb27..f58e8cb1b 100644 --- a/src/pystencils/backend/jit/cpu_extension_module.py +++ b/src/pystencils/backend/jit/cpu_extension_module.py @@ -164,9 +164,15 @@ PyInit_{module_name}(void) def create_module_boilerplate_code(module_name, names): - method_definition = '{{"{name}", 
(PyCFunction){name}, METH_VARARGS | METH_KEYWORDS, ""}},' - method_definitions = "\n".join([method_definition.format(name=name) for name in names]) - return template_module_boilerplate.format(module_name=module_name, method_definitions=method_definitions) + method_definition = ( + '{{"{name}", (PyCFunction){name}, METH_VARARGS | METH_KEYWORDS, ""}},' + ) + method_definitions = "\n".join( + [method_definition.format(name=name) for name in names] + ) + return template_module_boilerplate.format( + module_name=module_name, method_definitions=method_definitions + ) class CallWrapperBuilder: diff --git a/src/pystencils/backend/jit/legacy_cpu.py b/src/pystencils/backend/jit/legacy_cpu.py index 7906bb5e5..df8eab673 100644 --- a/src/pystencils/backend/jit/legacy_cpu.py +++ b/src/pystencils/backend/jit/legacy_cpu.py @@ -114,11 +114,11 @@ def set_config(config): def get_configuration_file_path(): - config_path_in_home = os.path.join(user_config_dir('pystencils'), 'config.json') + config_path_in_home = os.path.join(user_config_dir("pystencils"), "config.json") # 1) Read path from environment variable if found - if 'PYSTENCILS_CONFIG' in os.environ: - return os.environ['PYSTENCILS_CONFIG'], True + if "PYSTENCILS_CONFIG" in os.environ: + return os.environ["PYSTENCILS_CONFIG"], True # 2) Look in current directory for pystencils.json elif os.path.exists("pystencils.json"): return "pystencils.json", True @@ -139,92 +139,122 @@ def create_folder(path, is_file): def read_config(): - if platform.system().lower() == 'linux': - default_compiler_config = OrderedDict([ - ('os', 'linux'), - ('command', 'g++'), - ('flags', '-Ofast -DNDEBUG -fPIC -march=native -fopenmp -std=c++11'), - ('restrict_qualifier', '__restrict__') - ]) - if platform.machine().startswith('ppc64') or platform.machine() == 'arm64': - default_compiler_config['flags'] = default_compiler_config['flags'].replace('-march=native', - '-mcpu=native') - elif platform.system().lower() == 'windows': - default_compiler_config = OrderedDict([ - ('os', 'windows'), - ('msvc_version', 'latest'), - ('arch', 'x64'), - ('flags', '/Ox /fp:fast /OpenMP /arch:avx'), - ('restrict_qualifier', '__restrict') - ]) - if platform.machine() == 'ARM64': - default_compiler_config['arch'] = 'ARM64' - default_compiler_config['flags'] = default_compiler_config['flags'].replace(' /arch:avx', '') - elif platform.system().lower() == 'darwin': - default_compiler_config = OrderedDict([ - ('os', 'darwin'), - ('command', 'clang++'), - ('flags', '-Ofast -DNDEBUG -fPIC -march=native -Xclang -fopenmp -std=c++11'), - ('restrict_qualifier', '__restrict__') - ]) - if platform.machine() == 'arm64': - default_compiler_config['flags'] = default_compiler_config['flags'].replace('-march=native ', '') - for libomp in ['/opt/local/lib/libomp/libomp.dylib', '/usr/local/lib/libomp.dylib', - '/opt/homebrew/lib/libomp.dylib']: + if platform.system().lower() == "linux": + default_compiler_config = OrderedDict( + [ + ("os", "linux"), + ("command", "g++"), + ("flags", "-Ofast -DNDEBUG -fPIC -march=native -fopenmp -std=c++11"), + ("restrict_qualifier", "__restrict__"), + ] + ) + if platform.machine().startswith("ppc64") or platform.machine() == "arm64": + default_compiler_config["flags"] = default_compiler_config["flags"].replace( + "-march=native", "-mcpu=native" + ) + elif platform.system().lower() == "windows": + default_compiler_config = OrderedDict( + [ + ("os", "windows"), + ("msvc_version", "latest"), + ("arch", "x64"), + ("flags", "/Ox /fp:fast /OpenMP /arch:avx"), + ("restrict_qualifier", 
"__restrict"), + ] + ) + if platform.machine() == "ARM64": + default_compiler_config["arch"] = "ARM64" + default_compiler_config["flags"] = default_compiler_config["flags"].replace( + " /arch:avx", "" + ) + elif platform.system().lower() == "darwin": + default_compiler_config = OrderedDict( + [ + ("os", "darwin"), + ("command", "clang++"), + ( + "flags", + "-Ofast -DNDEBUG -fPIC -march=native -Xclang -fopenmp -std=c++11", + ), + ("restrict_qualifier", "__restrict__"), + ] + ) + if platform.machine() == "arm64": + default_compiler_config["flags"] = default_compiler_config["flags"].replace( + "-march=native ", "" + ) + for libomp in [ + "/opt/local/lib/libomp/libomp.dylib", + "/usr/local/lib/libomp.dylib", + "/opt/homebrew/lib/libomp.dylib", + ]: if os.path.exists(libomp): - default_compiler_config['flags'] += ' ' + libomp + default_compiler_config["flags"] += " " + libomp break else: - raise NotImplementedError('Generation of default compiler flags for %s is not implemented' % - (platform.system(),)) - - default_cache_config = OrderedDict([ - ('object_cache', os.path.join(user_cache_dir('pystencils'), 'objectcache')), - ('clear_cache_on_start', False), - ]) - - default_config = OrderedDict([('compiler', default_compiler_config), - ('cache', default_cache_config)]) + raise NotImplementedError( + "Generation of default compiler flags for %s is not implemented" + % (platform.system(),) + ) + + default_cache_config = OrderedDict( + [ + ("object_cache", os.path.join(user_cache_dir("pystencils"), "objectcache")), + ("clear_cache_on_start", False), + ] + ) + + default_config = OrderedDict( + [("compiler", default_compiler_config), ("cache", default_cache_config)] + ) config_path, config_exists = get_configuration_file_path() config = default_config.copy() if config_exists: - with open(config_path, 'r') as json_config_file: + with open(config_path, "r") as json_config_file: loaded_config = json.load(json_config_file) config = recursive_dict_update(config, loaded_config) else: create_folder(config_path, True) - with open(config_path, 'w') as f: + with open(config_path, "w") as f: json.dump(config, f, indent=4) - if config['cache']['object_cache'] is not False: - config['cache']['object_cache'] = os.path.expanduser(config['cache']['object_cache']).format(pid=os.getpid()) + if config["cache"]["object_cache"] is not False: + config["cache"]["object_cache"] = os.path.expanduser( + config["cache"]["object_cache"] + ).format(pid=os.getpid()) clear_cache_on_start = False - cache_status_file = os.path.join(config['cache']['object_cache'], 'last_config.json') + cache_status_file = os.path.join( + config["cache"]["object_cache"], "last_config.json" + ) if os.path.exists(cache_status_file): # check if compiler config has changed - last_config = json.load(open(cache_status_file, 'r')) - if set(last_config.items()) != set(config['compiler'].items()): + last_config = json.load(open(cache_status_file, "r")) + if set(last_config.items()) != set(config["compiler"].items()): clear_cache_on_start = True else: for key in last_config.keys(): - if last_config[key] != config['compiler'][key]: + if last_config[key] != config["compiler"][key]: clear_cache_on_start = True - if config['cache']['clear_cache_on_start'] or clear_cache_on_start: - shutil.rmtree(config['cache']['object_cache'], ignore_errors=True) + if config["cache"]["clear_cache_on_start"] or clear_cache_on_start: + shutil.rmtree(config["cache"]["object_cache"], ignore_errors=True) - create_folder(config['cache']['object_cache'], False) - with 
tempfile.NamedTemporaryFile('w', dir=os.path.dirname(cache_status_file), delete=False) as f: - json.dump(config['compiler'], f, indent=4) + create_folder(config["cache"]["object_cache"], False) + with tempfile.NamedTemporaryFile( + "w", dir=os.path.dirname(cache_status_file), delete=False + ) as f: + json.dump(config["compiler"], f, indent=4) os.replace(f.name, cache_status_file) - if config['compiler']['os'] == 'windows': - msvc_env = get_environment(config['compiler']['msvc_version'], config['compiler']['arch']) - if 'env' not in config['compiler']: - config['compiler']['env'] = {} - config['compiler']['env'].update(msvc_env) + if config["compiler"]["os"] == "windows": + msvc_env = get_environment( + config["compiler"]["msvc_version"], config["compiler"]["arch"] + ) + if "env" not in config["compiler"]: + config["compiler"]["env"] = {} + config["compiler"]["env"].update(msvc_env) return config @@ -233,11 +263,11 @@ _config = read_config() def get_compiler_config(): - return _config['compiler'] + return _config["compiler"] def get_cache_config(): - return _config['cache'] + return _config["cache"] def add_or_change_compiler_flags(flags): @@ -246,25 +276,27 @@ def add_or_change_compiler_flags(flags): compiler_config = get_compiler_config() cache_config = get_cache_config() - cache_config['object_cache'] = False # disable cache + cache_config["object_cache"] = False # disable cache for flag in flags: flag = flag.strip() - if '=' in flag: - base = flag.split('=')[0].strip() + if "=" in flag: + base = flag.split("=")[0].strip() else: base = flag - new_flags = [c for c in compiler_config['flags'].split() if not c.startswith(base)] + new_flags = [ + c for c in compiler_config["flags"].split() if not c.startswith(base) + ] new_flags.append(flag) - compiler_config['flags'] = ' '.join(new_flags) + compiler_config["flags"] = " ".join(new_flags) def clear_cache(): cache_config = get_cache_config() - if cache_config['object_cache'] is not False: - shutil.rmtree(cache_config['object_cache'], ignore_errors=True) - create_folder(cache_config['object_cache'], False) + if cache_config["object_cache"] is not False: + shutil.rmtree(cache_config["object_cache"], ignore_errors=True) + create_folder(cache_config["object_cache"], False) def load_kernel_from_file(module_name, function_name, path): @@ -284,15 +316,17 @@ def load_kernel_from_file(module_name, function_name, path): def run_compile_step(command): compiler_config = get_compiler_config() - config_env = compiler_config['env'] if 'env' in compiler_config else {} + config_env = compiler_config["env"] if "env" in compiler_config else {} compile_environment = os.environ.copy() compile_environment.update(config_env) try: - shell = True if compiler_config['os'].lower() == 'windows' else False - subprocess.check_output(command, env=compile_environment, stderr=subprocess.STDOUT, shell=shell) + shell = True if compiler_config["os"].lower() == "windows" else False + subprocess.check_output( + command, env=compile_environment, stderr=subprocess.STDOUT, shell=shell + ) except subprocess.CalledProcessError as e: print(" ".join(command)) - print(e.output.decode('utf8')) + print(e.output.decode("utf8")) raise e @@ -301,15 +335,18 @@ def compile_module(code, code_hash, base_dir, compile_flags=None): compile_flags = [] compiler_config = get_compiler_config() - extra_flags = ['-I' + sysconfig.get_paths()['include'], '-I' + get_pystencils_include_path()] + compile_flags - - if compiler_config['os'].lower() == 'windows': - lib_suffix = '.pyd' - object_suffix = '.obj' + 
extra_flags = [ + "-I" + sysconfig.get_paths()["include"], + "-I" + get_pystencils_include_path(), + ] + compile_flags + + if compiler_config["os"].lower() == "windows": + lib_suffix = ".pyd" + object_suffix = ".obj" windows = True else: - lib_suffix = '.so' - object_suffix = '.o' + lib_suffix = ".so" + object_suffix = ".o" windows = False src_file = os.path.join(base_dir, code_hash + ".cpp") @@ -318,36 +355,60 @@ def compile_module(code, code_hash, base_dir, compile_flags=None): if not os.path.exists(object_file): try: - with open(src_file, 'x') as f: + with open(src_file, "x") as f: code.write_to_file(f) except FileExistsError: pass if windows: - compile_cmd = ['cl.exe', '/c', '/EHsc'] + compiler_config['flags'].split() - compile_cmd += [*extra_flags, src_file, '/Fo' + object_file] + compile_cmd = ["cl.exe", "/c", "/EHsc"] + compiler_config["flags"].split() + compile_cmd += [*extra_flags, src_file, "/Fo" + object_file] run_compile_step(compile_cmd) else: with atomic_file_write(object_file) as file_name: - compile_cmd = [compiler_config['command'], '-c'] + compiler_config['flags'].split() - compile_cmd += [*extra_flags, '-o', file_name, src_file] + compile_cmd = [compiler_config["command"], "-c"] + compiler_config[ + "flags" + ].split() + compile_cmd += [*extra_flags, "-o", file_name, src_file] run_compile_step(compile_cmd) # Linking if windows: config_vars = sysconfig.get_config_vars() - py_lib = os.path.join(config_vars["installed_base"], "libs", - f"python{config_vars['py_version_nodot']}.lib") - run_compile_step(['link.exe', py_lib, '/DLL', '/out:' + lib_file, object_file]) - elif platform.system().lower() == 'darwin': + py_lib = os.path.join( + config_vars["installed_base"], + "libs", + f"python{config_vars['py_version_nodot']}.lib", + ) + run_compile_step( + ["link.exe", py_lib, "/DLL", "/out:" + lib_file, object_file] + ) + elif platform.system().lower() == "darwin": with atomic_file_write(lib_file) as file_name: - run_compile_step([compiler_config['command'], '-shared', object_file, '-o', file_name, '-undefined', - 'dynamic_lookup'] - + compiler_config['flags'].split()) + run_compile_step( + [ + compiler_config["command"], + "-shared", + object_file, + "-o", + file_name, + "-undefined", + "dynamic_lookup", + ] + + compiler_config["flags"].split() + ) else: with atomic_file_write(lib_file) as file_name: - run_compile_step([compiler_config['command'], '-shared', object_file, '-o', file_name] - + compiler_config['flags'].split()) + run_compile_step( + [ + compiler_config["command"], + "-shared", + object_file, + "-o", + file_name, + ] + + compiler_config["flags"].split() + ) return lib_file @@ -355,26 +416,34 @@ def compile_and_load(ast: PsKernelFunction, custom_backend=None): cache_config = get_cache_config() compiler_config = get_compiler_config() - function_prefix = '__declspec(dllexport)' if compiler_config['os'].lower() == 'windows' else '' + function_prefix = ( + "__declspec(dllexport)" if compiler_config["os"].lower() == "windows" else "" + ) code = PsKernelExtensioNModule() code.add_function(ast, ast.function_name) - code.create_code_string(compiler_config['restrict_qualifier'], function_prefix) + code.create_code_string(compiler_config["restrict_qualifier"], function_prefix) code_hash_str = code.get_hash_of_code() compile_flags = [] - if ast.instruction_set and 'compile_flags' in ast.instruction_set: - compile_flags = ast.instruction_set['compile_flags'] + if ast.instruction_set and "compile_flags" in ast.instruction_set: + compile_flags = 
ast.instruction_set["compile_flags"] - if cache_config['object_cache'] is False: + if cache_config["object_cache"] is False: with tempfile.TemporaryDirectory() as base_dir: - lib_file = compile_module(code, code_hash_str, base_dir, compile_flags=compile_flags) + lib_file = compile_module( + code, code_hash_str, base_dir, compile_flags=compile_flags + ) result = load_kernel_from_file(code_hash_str, ast.function_name, lib_file) else: - lib_file = compile_module(code, code_hash_str, base_dir=cache_config['object_cache'], - compile_flags=compile_flags) + lib_file = compile_module( + code, + code_hash_str, + base_dir=cache_config["object_cache"], + compile_flags=compile_flags, + ) result = load_kernel_from_file(code_hash_str, ast.function_name, lib_file) return KernelWrapper(result, ast.get_parameters(), ast) diff --git a/src/pystencils/backend/jit/msvc_detection.py b/src/pystencils/backend/jit/msvc_detection.py index 9cc1fc5ad..a7724973c 100644 --- a/src/pystencils/backend/jit/msvc_detection.py +++ b/src/pystencils/backend/jit/msvc_detection.py @@ -2,7 +2,7 @@ import os import subprocess -def get_environment(version_specifier, arch='x64'): +def get_environment(version_specifier, arch="x64"): """Returns an environment dictionary, for activating the Visual Studio compiler. Args: @@ -10,32 +10,39 @@ def get_environment(version_specifier, arch='x64'): installed version or 'setuptools' for setuptools-based detection arch: x86 or x64 """ - if version_specifier == 'setuptools': + if version_specifier == "setuptools": return get_environment_from_setup_tools(arch) - elif '\\' in version_specifier: + elif "\\" in version_specifier: vc_vars_path = find_vc_vars_all_via_filesystem_search(version_specifier) return get_environment_from_vc_vars_file(vc_vars_path, arch) else: try: - if version_specifier in ('auto', 'latest'): + if version_specifier in ("auto", "latest"): version_nr = find_latest_msvc_version_using_environment_variables() else: version_nr = normalize_msvc_version(version_specifier) vc_vars_path = get_vc_vars_path_via_environment_variable(version_nr) except ValueError: - vc_vars_path = find_vc_vars_all_via_filesystem_search("C:\\Program Files (x86)\\Microsoft Visual Studio") + vc_vars_path = find_vc_vars_all_via_filesystem_search( + "C:\\Program Files (x86)\\Microsoft Visual Studio" + ) if vc_vars_path is None: - vc_vars_path = find_vc_vars_all_via_filesystem_search("C:\\Program Files\\Microsoft Visual Studio") + vc_vars_path = find_vc_vars_all_via_filesystem_search( + "C:\\Program Files\\Microsoft Visual Studio" + ) if vc_vars_path is None: - raise ValueError("Visual Studio not found. Write path to VS folder in pystencils config") + raise ValueError( + "Visual Studio not found. Write path to VS folder in pystencils config" + ) return get_environment_from_vc_vars_file(vc_vars_path, arch) def find_latest_msvc_version_using_environment_variables(): import re + # noinspection SpellCheckingInspection - regex = re.compile(r'VS(\d\d)\dCOMNTOOLS') + regex = re.compile(r"VS(\d\d)\dCOMNTOOLS") versions = [] for key, value in os.environ.items(): match = regex.match(key) @@ -54,15 +61,11 @@ def normalize_msvc_version(version): - version numbers with or without dot i.e. 11.0 or 11 :return: integer version number """ - if isinstance(version, str) and '.' in version: - version = version.split('.')[0] + if isinstance(version, str) and "." 
in version: + version = version.split(".")[0] version = int(version) - mapping = { - 2015: 14, - 2013: 12, - 2012: 11 - } + mapping = {2015: 14, 2013: 12, 2012: 11} if version in mapping: return mapping[version] else: @@ -73,22 +76,27 @@ def get_environment_from_vc_vars_file(vc_vars_file, arch): out = subprocess.check_output( f'cmd /u /c "{vc_vars_file}" {arch} && set', stderr=subprocess.STDOUT, - ).decode('utf-16le', errors='replace') + ).decode("utf-16le", errors="replace") - env = {key.upper(): value for key, _, value in (line.partition('=') for line in out.splitlines()) if key and value} + env = { + key.upper(): value + for key, _, value in (line.partition("=") for line in out.splitlines()) + if key and value + } return env def get_vc_vars_path_via_environment_variable(version_nr): # noinspection SpellCheckingInspection - environment_var_name = 'VS%d0COMNTOOLS' % (version_nr,) + environment_var_name = "VS%d0COMNTOOLS" % (version_nr,) vc_path = os.environ[environment_var_name] - path = os.path.join(vc_path, '..', '..', 'VC', 'vcvarsall.bat') + path = os.path.join(vc_path, "..", "..", "VC", "vcvarsall.bat") return os.path.abspath(path) def get_environment_from_setup_tools(arch): from setuptools.msvc import msvc14_get_vc_env + msvc_env = msvc14_get_vc_env(arch) return {k.upper(): v for k, v in msvc_env.items()} @@ -97,7 +105,7 @@ def find_vc_vars_all_via_filesystem_search(base_path): matches = [] for root, dir_names, file_names in os.walk(base_path): for filename in file_names: - if filename == 'vcvarsall.bat': + if filename == "vcvarsall.bat": matches.append(os.path.join(root, filename)) matches.sort(reverse=True) diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 68fdfd9c5..115a50b9a 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -56,7 +56,7 @@ class KernelCreationContext: self._field_arrays: dict[Field, PsLinearizedArray] = dict() self._fields_collection = FieldsInKernel() - + self._ispace: IterationSpace | None = None @property diff --git a/src/pystencils/backend/kernelcreation/defaults.py b/src/pystencils/backend/kernelcreation/defaults.py index c52f6c254..fe0e8ed4a 100644 --- a/src/pystencils/backend/kernelcreation/defaults.py +++ b/src/pystencils/backend/kernelcreation/defaults.py @@ -29,7 +29,7 @@ class PsDefaults(Generic[SymbolT]): def __init__(self, symcreate: Callable[[str, PsAbstractType], SymbolT]): self.numeric_dtype = PsIeeeFloatType(64) """Default data type for numerical computations""" - + self.index_dtype = PsSignedIntegerType(64) """Default data type for indices.""" diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index ba4b7ac82..fb18cc978 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -54,6 +54,9 @@ class FreezeExpressions(SympyToPymbolicMapper): else: raise PsInputError(f"Don't know how to freeze {obj}") + def freeze_expression(self, expr: sp.Basic) -> pb.Expression: + return self.rec(expr) + def map_Assignment(self, expr: Assignment): # noqa lhs = self.rec(expr.lhs) rhs = self.rec(expr.rhs) diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index dd9151160..63593c548 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -5,6 +5,8 @@ from 
dataclasses import dataclass from functools import reduce from operator import mul +import sympy as sp + from ...sympyextensions import AssignmentCollection from ...field import Field, FieldType @@ -18,7 +20,7 @@ from ..arrays import PsLinearizedArray from ..ast.util import failing_cast from ..types import PsStructType, constify from .defaults import Pymbolic as Defaults -from ..exceptions import PsInputError, PsInternalCompilerError, KernelConstraintsError +from ..exceptions import PsInputError, KernelConstraintsError if TYPE_CHECKING: from .context import KernelCreationContext @@ -64,9 +66,9 @@ class FullIterationSpace(IterationSpace): @dataclass class Dimension: - start: VarOrConstant - stop: VarOrConstant - step: VarOrConstant + start: ExprOrConstant + stop: ExprOrConstant + step: ExprOrConstant counter: PsTypedVariable @staticmethod @@ -79,7 +81,7 @@ class FullIterationSpace(IterationSpace): archetype_array = ctx.get_array(archetype_field) dim = archetype_field.spatial_dimensions - + counters = [ PsTypedVariable(name, ctx.index_dtype) for name in Defaults.spatial_counter_names[:dim] @@ -119,6 +121,65 @@ class FullIterationSpace(IterationSpace): return FullIterationSpace(ctx, dimensions) + @staticmethod + def create_from_slice( + ctx: KernelCreationContext, + archetype_field: Field, + iteration_slice: Sequence[slice], + ): + archetype_array = ctx.get_array(archetype_field) + dim = archetype_field.spatial_dimensions + + if len(iteration_slice) != dim: + raise ValueError( + f"Number of dimensions in slice ({len(iteration_slice)}) " + f" did not equal iteration space dimensionality ({dim})" + ) + + counters = [ + PsTypedVariable(name, ctx.index_dtype) + for name in Defaults.spatial_counter_names[:dim] + ] + + from .freeze import FreezeExpressions + from .typification import Typifier + + freeze = FreezeExpressions(ctx) + typifier = Typifier(ctx) + + def to_pb(expr): + if isinstance(expr, int): + return PsTypedConstant(expr, ctx.index_dtype) + elif isinstance(expr, sp.Expr): + return typifier.typify_expression( + freeze.freeze_expression(expr), ctx.index_dtype + ) + else: + raise ValueError(f"Invalid entry in slice: {expr}") + + def to_dim(slic: slice, size: VarOrConstant, ctr: PsTypedVariable): + start = to_pb(slic.start if slic.start is not None else 0) + stop = to_pb(slic.stop) if slic.stop is not None else size + step = to_pb(slic.step if slic.step is not None else 1) + + if isinstance(slic.stop, int) and slic.stop < 0: + stop = size + stop + + return FullIterationSpace.Dimension(start, stop, step, ctr) + + dimensions = [ + to_dim(slic, size, ctr) + for slic, size, ctr in zip( + iteration_slice, archetype_array.shape[:dim], counters, strict=True + ) + ] + + # Determine loop order by permuting dimensions + loop_order = archetype_field.layout + dimensions = [dimensions[coordinate] for coordinate in loop_order] + + return FullIterationSpace(ctx, dimensions) + def __init__(self, ctx: KernelCreationContext, dimensions: Sequence[Dimension]): super().__init__(tuple(dim.counter for dim in dimensions)) @@ -150,13 +211,15 @@ class FullIterationSpace(IterationSpace): dim = self.dimensions[dimension] one = PsTypedConstant(1, self._ctx.index_dtype) return one + (dim.stop - dim.start - one) / dim.step - + def compressed_counter(self) -> ExprOrConstant: """Expression counting the actual number of items processed at the iteration defined by the counter tuple. 
- + Used primarily for indexing buffers.""" actual_iters = [self.actual_iterations(d) for d in range(self.dim)] - compressed_counters = [(dim.counter - dim.start) / dim.step for dim in self.dimensions] + compressed_counters = [ + (dim.counter - dim.start) / dim.step for dim in self.dimensions + ] compressed_idx = compressed_counters[0] for ctr, iters in zip(compressed_counters[1:], actual_iters[1:]): compressed_idx = compressed_idx * iters + ctr @@ -179,7 +242,7 @@ class SparseIterationSpace(IterationSpace): @property def index_list(self) -> PsLinearizedArray: return self._index_list - + @property def coordinate_members(self) -> tuple[PsStructType.Member, ...]: return self._coord_members @@ -224,7 +287,9 @@ def get_archetype_field( def create_sparse_iteration_space( - ctx: KernelCreationContext, assignments: AssignmentCollection, index_field: Field | None = None + ctx: KernelCreationContext, + assignments: AssignmentCollection, + index_field: Field | None = None, ) -> IterationSpace: # All domain and custom fields must have the same spatial dimensions # TODO: Must all domain fields have the same shape? @@ -264,19 +329,23 @@ def create_sparse_iteration_space( sparse_counter = PsTypedVariable(Defaults.sparse_counter_name, ctx.index_dtype) - return SparseIterationSpace(spatial_counters, idx_arr, coord_members, sparse_counter) + return SparseIterationSpace( + spatial_counters, idx_arr, coord_members, sparse_counter + ) def create_full_iteration_space( ctx: KernelCreationContext, assignments: AssignmentCollection, ghost_layers: None | int | Sequence[int | tuple[int, int]] = None, - iteration_slice: None | tuple[slice, ...] = None + iteration_slice: None | Sequence[slice] = None, ) -> IterationSpace: assert not ctx.fields.index_fields if (ghost_layers is not None) and (iteration_slice is not None): - raise ValueError("At most one of `ghost_layers` and `iteration_slice` may be specified.") + raise ValueError( + "At most one of `ghost_layers` and `iteration_slice` may be specified." 
+ ) # Collect all relative accesses into domain fields def access_filter(acc: Field.Access): @@ -315,7 +384,9 @@ def create_full_iteration_space( ctx, archetype_field, ghost_layers ) elif iteration_slice is not None: - raise PsInternalCompilerError("Iteration slices not supported yet") + return FullIterationSpace.create_from_slice( + ctx, archetype_field, iteration_slice + ) else: return FullIterationSpace.create_with_ghost_layers( ctx, archetype_field, inferred_gls diff --git a/src/pystencils/backend/kernelcreation/transformations.py b/src/pystencils/backend/kernelcreation/transformations.py index c73fbce73..07e6a5b37 100644 --- a/src/pystencils/backend/kernelcreation/transformations.py +++ b/src/pystencils/backend/kernelcreation/transformations.py @@ -55,7 +55,9 @@ class EraseAnonymousStructTypes(IdentityMapper): bp = aggr.base_ptr type_erased_bp = TypeErasedBasePointer(bp.name, arr) - base_index = aggr.index_tuple[0] * PsTypedConstant(struct_size, self._ctx.index_dtype) + base_index = aggr.index_tuple[0] * PsTypedConstant( + struct_size, self._ctx.index_dtype + ) member_name = lookup.name member = struct_type.get_member(member_name) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 7c953c5c4..893f3dc03 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -147,7 +147,7 @@ class Typifier(Mapper): def typify_expression( self, expr: Any, target_type: PsNumericType | None = None - ) -> tuple[ExprOrConstant, PsNumericType]: + ) -> ExprOrConstant: tc = TypeContext(target_type) return self.rec(expr, tc) @@ -174,9 +174,7 @@ class Typifier(Mapper): def map_array_access(self, access: PsArrayAccess, tc: TypeContext) -> PsArrayAccess: self._apply_target_type(access, access.dtype, tc) - index = self.rec( - access.index_tuple[0], TypeContext(self._ctx.index_dtype) - ) + index = self.rec(access.index_tuple[0], TypeContext(self._ctx.index_dtype)) return PsArrayAccess(access.base_ptr, index) def map_lookup(self, lookup: pb.Lookup, tc: TypeContext) -> pb.Lookup: @@ -204,6 +202,9 @@ class Typifier(Mapper): def map_product(self, expr: pb.Product, tc: TypeContext) -> pb.Product: return pb.Product(tuple(self.rec(c, tc) for c in expr.children)) + def map_quotient(self, expr: pb.Quotient, tc: TypeContext) -> pb.Quotient: + return pb.Quotient(self.rec(expr.num, tc), self.rec(expr.den, tc)) + # Functions def map_call(self, expr: pb.Call, tc: TypeContext) -> pb.Call: diff --git a/src/pystencils/backend/platforms/__init__.py b/src/pystencils/backend/platforms/__init__.py index 20e2c0aae..72eb2b762 100644 --- a/src/pystencils/backend/platforms/__init__.py +++ b/src/pystencils/backend/platforms/__init__.py @@ -1,5 +1,3 @@ from .basic_cpu import BasicCpu -__all__ = [ - 'BasicCpu' -] +__all__ = ["BasicCpu"] diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py index 3abeebbe6..8013837f4 100644 --- a/src/pystencils/backend/platforms/platform.py +++ b/src/pystencils/backend/platforms/platform.py @@ -8,17 +8,20 @@ from ..kernelcreation.iteration_space import IterationSpace class Platform(ABC): """Abstract base class for all supported platforms. - + The platform performs all target-dependent tasks during code generation: - + - Translation of the iteration space to an index source (loop nest, GPU indexing, ...) - Platform-specific optimizations (e.g. 
vectorization, OpenMP) """ + def __init__(self, ctx: KernelCreationContext) -> None: self._ctx = ctx @abstractmethod - def materialize_iteration_space(self, block: PsBlock, ispace: IterationSpace) -> PsBlock: + def materialize_iteration_space( + self, block: PsBlock, ispace: IterationSpace + ) -> PsBlock: ... @abstractmethod diff --git a/src/pystencils/backend/typed_expressions.py b/src/pystencils/backend/typed_expressions.py index a6f74fb82..15e4278fc 100644 --- a/src/pystencils/backend/typed_expressions.py +++ b/src/pystencils/backend/typed_expressions.py @@ -221,7 +221,7 @@ class PsTypedConstant: return PsTypedConstant(rem, self._dtype) def __neg__(self): - return PsTypedConstant(- self._value, self._dtype) + return PsTypedConstant(-self._value, self._dtype) def __bool__(self): return bool(self._value) diff --git a/src/pystencils/backend/types/basic_types.py b/src/pystencils/backend/types/basic_types.py index 9df6858e9..49d15968e 100644 --- a/src/pystencils/backend/types/basic_types.py +++ b/src/pystencils/backend/types/basic_types.py @@ -211,7 +211,7 @@ class PsStructType(PsAbstractType): def numpy_dtype(self) -> np.dtype: members = [(m.name, m.dtype.numpy_dtype) for m in self._members] return np.dtype(members) - + @property def itemsize(self) -> int: return self.numpy_dtype.itemsize @@ -222,7 +222,7 @@ class PsStructType(PsAbstractType): "Cannot retrieve C string for anonymous struct type" ) return self._name - + def __str__(self) -> str: if self._name is None: return "<anonymous>" diff --git a/src/pystencils/backend/types/exception.py b/src/pystencils/backend/types/exception.py index 7c0cb97af..9cf7db5af 100644 --- a/src/pystencils/backend/types/exception.py +++ b/src/pystencils/backend/types/exception.py @@ -1,4 +1,2 @@ - - class PsTypeError(Exception): """Indicates a type error in the pystencils AST.""" diff --git a/src/pystencils/config.py b/src/pystencils/config.py index 3c6b40c66..ad5795a87 100644 --- a/src/pystencils/config.py +++ b/src/pystencils/config.py @@ -10,7 +10,6 @@ from .backend.types import PsIntegerType, PsNumericType, PsIeeeFloatType from .backend.kernelcreation.defaults import Sympy as SpDefaults -from .enums import Target @dataclass class CreateKernelConfig: @@ -46,7 +45,7 @@ class CreateKernelConfig: If `ghost_layers=None` is specified, the iteration region may otherwise be set using the `iteration_slice` option. """ - iteration_slice: None | tuple[slice, ...] = None + iteration_slice: None | Sequence[slice] = None """Specifies the kernel's iteration slice. `iteration_slice` may only be set if `ghost_layers = None`. 
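The `iteration_slice` option added to `CreateKernelConfig` above restricts the kernel's iteration region and may only be combined with `ghost_layers=None`. The following is a minimal usage sketch from the high-level API; the field setup and the `create_kernel`/`compile` entry points are assumed from the existing pystencils interface rather than shown in this patch, so treat it as an illustration of the intended option, not a verified call sequence.

import pystencils as ps

# Illustrative setup: two 2D fields and a plain copy update.
src, dst = ps.fields("src, dst: double[2D]")
update = [ps.Assignment(dst.center, src.center)]

# Restrict the kernel to the interior along the first coordinate and to every
# second cell along the second. iteration_slice requires ghost_layers=None,
# since create_full_iteration_space rejects setting both options at once.
config = ps.CreateKernelConfig(
    ghost_layers=None,
    iteration_slice=(slice(1, -1, 1), slice(0, None, 2)),
)
kernel = ps.create_kernel(update, config=config).compile()
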
diff --git a/tests/nbackend/kernelcreation/test_iteration_space.py b/tests/nbackend/kernelcreation/test_iteration_space.py index 6fe905f73..477684ff3 100644 --- a/tests/nbackend/kernelcreation/test_iteration_space.py +++ b/tests/nbackend/kernelcreation/test_iteration_space.py @@ -1,12 +1,19 @@ +import pytest + +import pymbolic.primitives as pb from pystencils.field import Field +from pystencils.sympyextensions.typed_sympy import TypedSymbol, create_type from pystencils.backend.kernelcreation import ( KernelCreationContext, FullIterationSpace ) +from pystencils.backend.kernelcreation.typification import TypificationError from pystencils.backend.kernelcreation.defaults import Pymbolic as PbDefaults +from pystencils.backend.typed_expressions import PsTypedConstant + def test_loop_order(): ctx = KernelCreationContext() @@ -46,3 +53,42 @@ def test_loop_order(): for dim, ctr in zip(ispace.dimensions, [ctr_symbols[2], ctr_symbols[0], ctr_symbols[1]]): assert dim.counter == ctr + + +def test_slices(): + ctx = KernelCreationContext() + ctr_symbols = PbDefaults.spatial_counters + + archetype_field = Field.create_generic("f", spatial_dimensions=3, layout='fzyx') + ctx.add_field(archetype_field) + + islice = (slice(1, -1, 1), slice(3, -3, 3), slice(0, None, -1)) + ispace = FullIterationSpace.create_from_slice(ctx, archetype_field, islice) + + archetype_arr = ctx.get_array(archetype_field) + + dims = ispace.dimensions[::-1] + + for sl, size, dim in zip(islice, archetype_arr.shape, dims): + assert isinstance(dim.start, PsTypedConstant) and dim.start.value == sl.start + assert isinstance(dim.step, PsTypedConstant) and dim.step.value == sl.step + + assert isinstance(dims[0].stop, pb.Sum) and archetype_arr.shape[0] in dims[0].stop.children + assert isinstance(dims[1].stop, pb.Sum) and archetype_arr.shape[1] in dims[1].stop.children + assert dims[2].stop == archetype_arr.shape[2] + + +def test_invalid_slices(): + ctx = KernelCreationContext() + ctr_symbols = PbDefaults.spatial_counters + + archetype_field = Field.create_generic("f", spatial_dimensions=1, layout='fzyx') + ctx.add_field(archetype_field) + + islice = (slice(1, -1, 0.5),) + with pytest.raises(ValueError): + FullIterationSpace.create_from_slice(ctx, archetype_field, islice) + + islice = (slice(1, -1, TypedSymbol("w", dtype=create_type("double"))),) + with pytest.raises(TypificationError): + FullIterationSpace.create_from_slice(ctx, archetype_field, islice) diff --git a/tests/test_buffer.py b/tests/test_buffer.py index 32d2eade1..0620e0540 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -1,4 +1,3 @@ -#%% """Tests (un)packing (from)to buffers.""" import pytest @@ -194,7 +193,6 @@ def test_field_layouts(): unpack_kernel = unpack_code.compile() unpack_kernel(buffer=bufferArr, dst_field=dst_arr) -@pytest.mark.xfail(reason="iteration slices not implemented yet") def test_iteration_slices(): num_cell_values = 19 dt = np.uint64 @@ -205,7 +203,7 @@ def test_iteration_slices(): # dst_field = Field.create_from_numpy_array("dst_field", dst_arr, index_dimensions=1) src_field = Field.create_generic("src_field", spatial_dimensions, index_shape=(num_cell_values,), dtype=dt) dst_field = Field.create_generic("dst_field", spatial_dimensions, index_shape=(num_cell_values,), dtype=dt) - buffer = Field.create_generic("buffer", spatial_dimensions=1, index_dimensions=1, + buffer = Field.create_generic("buffer", spatial_dimensions=1, index_shape=(num_cell_values,), field_type=FieldType.BUFFER, dtype=src_arr.dtype) pack_eqs = [] @@ -218,7 +216,7 @@ def 
test_iteration_slices(): dim = src_field.spatial_dimensions # Pack only the leftmost slice, only every second cell - pack_slice = (slice(None, None, 2),) * (dim - 1) + (0,) + pack_slice = (slice(None, None, 2),) * (dim - 1) + (slice(0, 1, None),) # Fill the entire array with data src_arr[(slice(None, None, 1),) * dim] = np.arange(num_cell_values) @@ -247,5 +245,4 @@ def test_iteration_slices(): np.testing.assert_equal(dst_arr[(slice(1, None, 2),) * (dim - 1) + (0,)], 0) np.testing.assert_equal(dst_arr[(slice(None, None, 1),) * (dim - 1) + (slice(1, None),)], 0) -#%% # test_all_cell_values() \ No newline at end of file -- GitLab
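For reference, the slice handling introduced in `FullIterationSpace.create_from_slice` can be summarized by the small standalone sketch below. It mirrors the `to_dim` helper using plain Python integers in place of `PsTypedConstant` and frozen sympy expressions, so it illustrates the translation rule only and is not code from the patch.

def slice_to_dimension(slic: slice, size: int) -> tuple[int, int, int]:
    """Map one per-dimension slice to (start, stop, step), as to_dim does."""
    start = slic.start if slic.start is not None else 0
    stop = slic.stop if slic.stop is not None else size
    step = slic.step if slic.step is not None else 1

    # As in the patch, only negative integer stops are remapped; they count
    # backward from the end of the dimension: stop = size + stop.
    if isinstance(slic.stop, int) and slic.stop < 0:
        stop = size + stop

    return start, stop, step


# Example: a dimension of size 20 iterated with slice(1, -1, 2) yields
# (start, stop, step) = (1, 19, 2), i.e. counters 1, 3, ..., 17, which is
# consistent with the start/step/stop checks in test_slices above.
print(slice_to_dimension(slice(1, -1, 2), 20))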