diff --git a/pystencils/__init__.py b/pystencils/__init__.py index f9d64fa006d6afdfa881971db037f9bb21721fcf..bfe5b3da4a27d3770afa067385381e5d8f3008dd 100644 --- a/pystencils/__init__.py +++ b/pystencils/__init__.py @@ -9,8 +9,7 @@ from .display_utils import get_code_obj, get_code_str, show_code, to_dot from .field import Field, FieldType, fields from .config import CreateKernelConfig from .kernel_decorator import kernel, kernel_config -from .kernelcreation import ( - create_domain_kernel, create_indexed_kernel, create_kernel, create_staggered_kernel) +from .kernelcreation import create_kernel from .simp import AssignmentCollection from .slicing import make_slice from .spatial_coordinates import x_, x_staggered, x_staggered_vector, x_vector, y_, y_staggered, z_, z_staggered @@ -20,7 +19,7 @@ __all__ = ['Field', 'FieldType', 'fields', 'TypedSymbol', 'make_slice', 'CreateKernelConfig', - 'create_kernel', 'create_domain_kernel', 'create_indexed_kernel', 'create_staggered_kernel', + 'create_kernel', 'Target', 'Backend', 'show_code', 'to_dot', 'get_code_obj', 'get_code_str', 'AssignmentCollection', diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index b3799f730520279e8edc288c0e85fc5cf092128e..4f7d82b6eb9d4d96e15a0be1ccbfbfa36de0cf68 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -12,11 +12,11 @@ from sympy.logic.boolalg import BooleanFalse, BooleanTrue from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize from pystencils.typing import ( - PointerType, VectorType, address_of, CastFunc, create_type, get_type_of_expression, + PointerType, VectorType, CastFunc, create_type, get_type_of_expression, ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol) from pystencils.enums import Backend from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt -from pystencils.functions import DivFunc +from pystencils.functions import DivFunc, AddressOf from pystencils.integer_functions import ( bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil) @@ -31,8 +31,6 @@ __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSy HEADER_REGEX = re.compile(r'^[<"].*[">]$') -KERNCRAFT_NO_TERNARY_MODE = False - def generate_c(ast_node: Node, signature_only: bool = False, @@ -275,7 +273,7 @@ class CBackend: self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs)) else: - lhs_type = get_type_of_expression(node.lhs) + lhs_type = get_type_of_expression(node.lhs) # TOOD: this should have been typed printed_mask = "" if type(lhs_type) is VectorType and isinstance(node.lhs, CastFunc): arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args @@ -301,7 +299,7 @@ class CBackend: elif self._vector_instruction_set['float'] == '__m128': printed_mask = f"_mm_castps_si128({printed_mask})" - rhs_type = get_type_of_expression(node.rhs) + rhs_type = get_type_of_expression(node.rhs) # TOOD: vector only??? if type(rhs_type) is not VectorType: rhs = CastFunc(node.rhs, VectorType(rhs_type)) else: @@ -417,7 +415,7 @@ class CBackend: return self._print_Block(node.true_block) elif type(node.condition_expr) is BooleanFalse: return self._print_Block(node.false_block) - cond_type = get_type_of_expression(node.condition_expr) + cond_type = get_type_of_expression(node.condition_expr) # TODO: Could be vector or bool? if isinstance(cond_type, VectorType): raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all") condition_expr = self.sympy_printer.doprint(node.condition_expr) @@ -441,7 +439,8 @@ class CustomSympyPrinter(CCodePrinter): def _print_Pow(self, expr): """Don't use std::pow function, for small integer exponents, write as multiplication""" if not expr.free_symbols: - return self._typed_number(expr.evalf(), get_type_of_expression(expr.base)) + raise NotImplementedError("This pow should be simplified already?") + # return self._typed_number(expr.evalf(), get_type_of_expression(expr.base)) return super(CustomSympyPrinter, self)._print_Pow(expr) # TODO don't print ones in sp.Mul @@ -482,12 +481,12 @@ class CustomSympyPrinter(CCodePrinter): if isinstance(expr, ReinterpretCastFunc): arg, data_type = expr.args return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))" - elif isinstance(expr, address_of): + elif isinstance(expr, AddressOf): assert len(expr.args) == 1, "address_of must only have one argument" return f"&({self._print(expr.args[0])})" elif isinstance(expr, CastFunc): arg, data_type = expr.args - if arg.is_Number: + if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)): return self._typed_number(arg, data_type) else: return f"(({data_type})({self._print(arg)}))" @@ -810,11 +809,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): result = self._print(expr.args[-1][0]) for true_expr, condition in reversed(expr.args[:-1]): if isinstance(condition, CastFunc) and get_type_of_expression(condition.args[0]) == create_type("bool"): - if not KERNCRAFT_NO_TERNARY_MODE: - result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr), - result, **self._kwargs) - else: - print("Warning - skipping ternary op") + result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr), + result, **self._kwargs) else: # noinspection SpellCheckingInspection result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition), diff --git a/pystencils/config.py b/pystencils/config.py index 936a92cf231749e571de771c17f3559da623c7de..22b0a5ffc0138cce1d715b5183d6bc56cde1c35c 100644 --- a/pystencils/config.py +++ b/pystencils/config.py @@ -125,7 +125,7 @@ class CreateKernelConfig: def __post_init__(self): # ---- Legacy parameters - # TODO adapt here the types + # TODO adapt here the types for example "float", python float, everything ambiguous should not be allowed if isinstance(self.target, str): new_target = Target[self.target.upper()] warnings.warn(f'Target "{self.target}" as str is deprecated. Use {new_target} instead', diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py index 2861d671fa8eb1232bc8750dfc7261aab5007301..ca4f267944de80d45e1754ca6d32cf7741b7000a 100644 --- a/pystencils/cpu/cpujit.py +++ b/pystencils/cpu/cpujit.py @@ -265,6 +265,7 @@ def clear_cache(): create_folder(cache_config['object_cache'], False) +# TODO don't hardcode C type. [1] of tuple output type_mapping = { np.float32: ('PyFloat_AsDouble', 'float'), np.float64: ('PyFloat_AsDouble', 'double'), @@ -274,8 +275,6 @@ type_mapping = { np.uint16: ('PyLong_AsUnsignedLong', 'uint16_t'), np.uint32: ('PyLong_AsUnsignedLong', 'uint32_t'), np.uint64: ('PyLong_AsUnsignedLong', 'uint64_t'), - np.complex64: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexFloat'), - np.complex128: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexDouble'), } template_extract_scalar = """ @@ -285,14 +284,6 @@ if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument ' if( PyErr_Occurred() ) {{ return NULL; }} """ -template_extract_complex = """ -PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}"); -if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }}; -{target_type} {name}{{ ({real_type}) {extract_function_real}( obj_{name} ), - ({real_type}) {extract_function_imag}( obj_{name} ) }}; -if( PyErr_Occurred() ) {{ return NULL; }} -""" - template_extract_array = """ PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}"); if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }}; @@ -453,17 +444,9 @@ def create_function_boilerplate_code(parameter_info, name, ast_node, insert_chec continue else: extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type] - if np.issubdtype(param.symbol.dtype.numpy_dtype, np.complexfloating): - pre_call_code += template_extract_complex.format(extract_function_real=extract_function[0], - extract_function_imag=extract_function[1], - target_type=target_type, - real_type="float" if target_type == "ComplexFloat" - else "double", - name=param.symbol.name) - else: - pre_call_code += template_extract_scalar.format(extract_function=extract_function, - target_type=target_type, - name=param.symbol.name) + pre_call_code += template_extract_scalar.format(extract_function=extract_function, + target_type=target_type, + name=param.symbol.name) parameters.append(param.symbol.name) diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py index 4e0573c89e9b66c92f81a97a97256460e75f711e..a937d3cf274290b7fe345c841878ccd2a396f9a8 100644 --- a/pystencils/cpu/kernelcreation.py +++ b/pystencils/cpu/kernelcreation.py @@ -12,13 +12,15 @@ from pystencils.cpu.cpujit import make_python_function from pystencils.typing import StructType, TypedSymbol, create_type from pystencils.typing.transformations import add_types from pystencils.field import Field, FieldType +from pystencils.node_collection import NodeCollection from pystencils.transformations import ( filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering, make_loop_over_domain, move_constants_before_loop, parse_base_pointer_info, resolve_buffer_accesses, resolve_field_accesses, split_inner_loop) -def create_kernel(assignments: AssignmentCollection, config: CreateKernelConfig) -> KernelFunction: +def create_kernel(assignments: Union[AssignmentCollection, NodeCollection], + config: CreateKernelConfig) -> KernelFunction: """Creates an abstract syntax tree for a kernel function, by taking a list of update rules. Loops are created according to the field accesses in the equations. @@ -36,7 +38,7 @@ def create_kernel(assignments: AssignmentCollection, config: CreateKernelConfig) iteration_slice = config.iteration_slice ghost_layers = config.ghost_layers fields_written = assignments.bound_fields - fields_read = assignments.free_fields + fields_read = assignments.rhs_fields split_groups = () if 'split_groups' in assignments.simplification_hints: @@ -88,6 +90,7 @@ def create_kernel(assignments: AssignmentCollection, config: CreateKernelConfig) if any(FieldType.is_buffer(f) for f in all_fields): resolve_buffer_accesses(ast_node, get_base_buffer_index(ast_node), read_only_fields) + # TODO think about typing resolve_field_accesses(ast_node, read_only_fields, field_to_base_pointer_info=base_pointer_info) move_constants_before_loop(ast_node) return ast_node diff --git a/pystencils/functions.py b/pystencils/functions.py index b1f349622208f801054990eded2d7b51b77e8cab..c499d04ed6fcb1274639eb9791955db69c9619bf 100644 --- a/pystencils/functions.py +++ b/pystencils/functions.py @@ -1,4 +1,5 @@ import sympy as sp +from pystencils.typing import PointerType class DivFunc(sp.Function): @@ -24,3 +25,30 @@ class DivFunc(sp.Function): def dividend(self): return self.args[1] + +class AddressOf(sp.Function): + # TODO: docstring + # this is '&' in C + is_Atom = True + + def __new__(cls, arg): + obj = sp.Function.__new__(cls, arg) + return obj + + @property + def canonical(self): + if hasattr(self.args[0], 'canonical'): + return self.args[0].canonical + else: + raise NotImplementedError() + + @property + def is_commutative(self): + return self.args[0].is_commutative + + @property + def dtype(self): + if hasattr(self.args[0], 'dtype'): + return PointerType(self.args[0].dtype, restrict=True) + else: + raise ValueError(f'pystencils supports only non void pointers. Current address_of type: {self.args[0]}') diff --git a/pystencils/kernel_contrains_check.py b/pystencils/kernel_contrains_check.py index 42204822a29321d611189ae9f9afd021344a5305..7456e3f618269a8e8da5fca5a151e84810c700f6 100644 --- a/pystencils/kernel_contrains_check.py +++ b/pystencils/kernel_contrains_check.py @@ -7,6 +7,7 @@ from sympy.codegen import Assignment from pystencils.simp import AssignmentCollection from pystencils import astnodes as ast, TypedSymbol from pystencils.field import Field +from pystencils.node_collection import NodeCollection from pystencils.transformations import NestedScopes @@ -42,7 +43,7 @@ class KernelConstraintsCheck: self.check_double_write_condition = check_double_write_condition def visit(self, obj): - if isinstance(obj, AssignmentCollection): + if isinstance(obj, (AssignmentCollection, NodeCollection)): [self.visit(e) for e in obj.all_assignments] elif isinstance(obj, list) or isinstance(obj, tuple): [self.visit(e) for e in obj] diff --git a/pystencils/kernel_decorator.py b/pystencils/kernel_decorator.py index 56453c5d4c4d18a9b5345edbd6e78d015ae322b8..a8db7cb979008b76cde1ea64d4856df93436b43e 100644 --- a/pystencils/kernel_decorator.py +++ b/pystencils/kernel_decorator.py @@ -89,7 +89,8 @@ def kernel_config(config: CreateKernelConfig, **kwargs) -> Callable[..., Dict]: decorator with config Examples: -import pystencils.kernel_creation_config >>> import pystencils as ps + >>> import pystencils.kernel_creation_config + >>> import pystencils as ps >>> config = pystencils.kernel_creation_config.CreateKernelConfig() >>> @kernel_config(config) ... def my_kernel(s): diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py index 2635eb035e7b3ff7ef84fec071bdbd8ef96019d8..8bfee1b21a70e2c660a869f44ba876fcc2aae5d0 100644 --- a/pystencils/kernelcreation.py +++ b/pystencils/kernelcreation.py @@ -1,4 +1,5 @@ import itertools +import logging import warnings from typing import Union, List @@ -6,10 +7,11 @@ import sympy as sp from pystencils.config import CreateKernelConfig from pystencils.assignment import Assignment -from pystencils.astnodes import Block, Conditional, LoopOverCoordinate, SympyAssignment +from pystencils.astnodes import Node, Block, Conditional, LoopOverCoordinate, SympyAssignment from pystencils.cpu.vectorization import vectorize from pystencils.enums import Target, Backend from pystencils.field import Field, FieldType +from pystencils.node_collection import NodeCollection from pystencils.gpucuda.indexing import indexing_creator_from_params from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.kernel_contrains_check import KernelConstraintsCheck @@ -19,7 +21,7 @@ from pystencils.transformations import ( loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel) -def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCollection, List[Conditional]], *, +def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCollection, List[Node]], *, config: CreateKernelConfig = None, **kwargs): """ Creates abstract syntax tree (AST) of kernel, using a list of update equations. @@ -63,7 +65,14 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol assignments = [assignments] assert assignments, "Assignments must not be empty!" if isinstance(assignments, list): - assignments = AssignmentCollection(assignments) + if all((isinstance(a, Assignment) for a in assignments)): + assignments = AssignmentCollection(assignments) + elif all((isinstance(n, Node) for n in assignments)): + assignments = NodeCollection(assignments) + logging.warning('Using Nodes is experimental and not fully tested. Double check your generated code!') + else: + raise ValueError(f'The list "{assignments}" is mixed. Pass either a list of "pystencils.Assignments" ' + f'or a list of "pystencils.astnodes.Node') if config.index_fields: return create_indexed_kernel(assignments, config=config) @@ -71,7 +80,7 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol return create_domain_kernel(assignments, config=config) -def create_domain_kernel(assignments: AssignmentCollection, *, config: CreateKernelConfig): +def create_domain_kernel(assignments: Union[AssignmentCollection, NodeCollection], *, config: CreateKernelConfig): """ Creates abstract syntax tree (AST) of kernel, using a list of update equations. @@ -84,13 +93,13 @@ def create_domain_kernel(assignments: AssignmentCollection, *, config: CreateKer can be compiled with through its 'compile()' member Example: - # TODO change to assignment collection >>> import pystencils as ps >>> import numpy as np + >>> from pystencils.kernelcreation import create_domain_kernel >>> s, d = ps.fields('s, d: [2D]') >>> assignment = ps.Assignment(d[0,0], s[0, 1] + s[0, -1] + s[1, 0] + s[-1, 0]) >>> kernel_config = ps.CreateKernelConfig(cpu_openmp=True) - >>> kernel_ast = ps.kernelcreation.create_domain_kernel([assignment], config=kernel_config) + >>> kernel_ast = create_domain_kernel(ps.AssignmentCollection([assignment]), config=kernel_config) >>> kernel = kernel_ast.compile() >>> d_arr = np.zeros([5, 5]) >>> kernel(d=d_arr, s=np.ones([5, 5])) @@ -103,15 +112,16 @@ def create_domain_kernel(assignments: AssignmentCollection, *, config: CreateKer """ # --- applying first default simplifications - try: - if config.default_assignment_simplifications and isinstance(assignments, AssignmentCollection): - simplification = create_simplification_strategy() - assignments = simplification(assignments) - except Exception as e: - warnings.warn(f"It was not possible to apply the default pystencils optimisations to the " - f"AssignmentCollection due to the following problem :{e}") + if isinstance(assignments, AssignmentCollection): + try: + if config.default_assignment_simplifications and isinstance(assignments, AssignmentCollection): + simplification = create_simplification_strategy() + assignments = simplification(assignments) + except Exception as e: + warnings.warn(f"It was not possible to apply the default pystencils optimisations to the " + f"AssignmentCollection due to the following problem :{e}") - assignments.evaluate_terms() + assignments.evaluate_terms() # --- eval # TODO split apply_sympy_optimisations and do the eval here @@ -121,8 +131,13 @@ def create_domain_kernel(assignments: AssignmentCollection, *, config: CreateKer check = KernelConstraintsCheck(check_independence_condition=config.skip_independence_check, check_double_write_condition=config.allow_double_writes) check.visit(assignments) - assert assignments.bound_fields == check.fields_written, f'WTF' - assert assignments.rhs_fields == check.fields_read, f'WTF' + + if isinstance(assignments, AssignmentCollection): + assert assignments.bound_fields == check.fields_written, f'WTF' + assert assignments.rhs_fields == check.fields_read, f'WTF' + else: + assignments.bound_fields = check.fields_written + assignments.rhs_fields = check.fields_read # ---- Creating ast ast = None @@ -191,8 +206,10 @@ def create_indexed_kernel(assignments: AssignmentCollection, *, config: CreateKe can be compiled with through its 'compile()' member Example: -import pystencils.kernel_creation_config >>> import pystencils as ps + >>> import pystencils.kernel_creation_config + >>> import pystencils as ps >>> import numpy as np + >>> from pystencils.kernelcreation import create_indexed_kernel >>> >>> # Index field stores the indices of the cell to visit together with optional values >>> index_arr_dtype = np.dtype([('x', np.int32), ('y', np.int32), ('val', np.double)]) @@ -202,17 +219,17 @@ import pystencils.kernel_creation_config >>> import pystencils as ps >>> # Additional values stored in index field can be accessed in the kernel as well >>> s, d = ps.fields('s, d: [2D]') >>> assignment = ps.Assignment(d[0, 0], 2 * s[0, 1] + 2 * s[1, 0] + idx_field('val')) - >>> kernel_config = pystencils.kernel_creation_config.CreateKernelConfig(index_fields=[idx_field], coordinate_names=('x', 'y')) - >>> kernel_ast = ps.create_indexed_kernel([assignment], config=kernel_config) + >>> kernel_config = ps.CreateKernelConfig(index_fields=[idx_field], coordinate_names=('x', 'y')) + >>> kernel_ast = create_indexed_kernel(ps.AssignmentCollection([assignment]), config=kernel_config) >>> kernel = kernel_ast.compile() >>> d_arr = np.zeros([5, 5]) >>> kernel(s=np.ones([5, 5]), d=d_arr, idx=index_arr) >>> d_arr - array([[0. , 0. , 0. , 0. , 0. ], - [0. , 4.1, 0. , 0. , 0. ], - [0. , 0. , 4.2, 0. , 0. ], - [0. , 0. , 0. , 4.3, 0. ], - [0. , 0. , 0. , 0. , 0. ]]) + array([[0., 0., 0., 0., 0.], + [0., 4.1, 0., 0., 0.], + [0., 0., 4.2, 0., 0.], + [0., 0., 0., 4.3, 0.], + [0., 0., 0., 0., 0.]]) """ # TODO do this in backends assignments = assignments.all_assignments @@ -260,6 +277,7 @@ def create_staggered_kernel(assignments, target: Target = Target.CPU, gpu_exclus Returns: AST, see `create_kernel` """ + # TODO: Add doku like in the other kernels if 'ghost_layers' in kwargs: assert kwargs['ghost_layers'] is None del kwargs['ghost_layers'] diff --git a/pystencils/node_collection.py b/pystencils/node_collection.py new file mode 100644 index 0000000000000000000000000000000000000000..545f3014930f1538f20900ff346f6277a6d58a23 --- /dev/null +++ b/pystencils/node_collection.py @@ -0,0 +1,15 @@ +from typing import List +from pystencils.astnodes import Node + + +# TODO ABC for NodeCollection and AssignmentCollection +class NodeCollection: + def __init__(self, nodes: List[Node]): + self.nodes = nodes + self.bound_fields = None + self.rhs_fields = None + self.simplification_hints = () + + @property + def all_assignments(self): + return self.nodes diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 8fd4dfbc020e0dcd62a24e5eadf3b40f1a3a5718..7f864f9af193c7ebe75f830d7adfaa595e10572b 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -353,8 +353,12 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): assert len(loops) == len(parents_of_innermost_loop) assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop)) - actual_sizes = [int_div((l.stop - l.start), l.step) for l in loops] - actual_steps = [int_div((l.loop_counter_symbol - l.start), l.step) for l in loops] + actual_sizes = [int_div((loop.stop - loop.start), loop.step) + if loop.step != 1 else loop.stop - loop.start for loop in loops] + + actual_steps = [int_div((loop.loop_counter_symbol - loop.start), loop.step) + if loop.step != 1 else loop.loop_counter_symbol - loop.start for loop in loops] + else: actual_sizes = loop_iterations actual_steps = loop_counters diff --git a/pystencils/typing/cast_functions.py b/pystencils/typing/cast_functions.py index 7d3f3d892797b798f6c980e7808f4be3b2a0bc7f..b70c8bb483a3dfe7a25bb7f45191bebb871cf18c 100644 --- a/pystencils/typing/cast_functions.py +++ b/pystencils/typing/cast_functions.py @@ -45,9 +45,6 @@ class CastFunc(sp.Function): def is_commutative(self): return self.args[0].is_commutative - def _eval_evalf(self, *args, **kwargs): - return self.args[0].evalf() - @property def dtype(self): return self.args[1] @@ -120,6 +117,7 @@ class ReinterpretCastFunc(CastFunc): class PointerArithmeticFunc(sp.Function, Boolean): # TODO: documentation + # TODO wtf is this???? @property def canonical(self): if hasattr(self.args[0], 'canonical'): diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index 04bfacbf4886ee9684ed02c0ed1ad71ba44d1bd4..8b47fb1137c1381335c912c4df09ee469a790387 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -1,5 +1,5 @@ -from collections import namedtuple, defaultdict -from typing import Union, Tuple, Any +from collections import namedtuple +from typing import Union, Tuple, Any, DefaultDict import logging import numpy as np @@ -13,9 +13,9 @@ from sympy.logic.boolalg import BooleanFunction from sympy.logic.boolalg import BooleanAtom from pystencils import astnodes as ast -from pystencils.functions import DivFunc +from pystencils.functions import DivFunc, AddressOf from pystencils.field import Field -from pystencils.typing.types import BasicType, create_type +from pystencils.typing.types import BasicType, create_type, PointerType from pystencils.typing.utilities import get_type_of_expression, collate_types from pystencils.typing.cast_functions import CastFunc, BooleanCastFunc from pystencils.typing.typed_sympy import TypedSymbol @@ -40,7 +40,7 @@ class TypeAdder: """ FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index']) - def __init__(self, type_for_symbol: defaultdict[str, BasicType], default_number_float: BasicType, + def __init__(self, type_for_symbol: DefaultDict[str, BasicType], default_number_float: BasicType, default_number_int: BasicType): self.type_for_symbol = type_for_symbol self.default_number_float = ContextVar(default_number_float) @@ -48,7 +48,6 @@ class TypeAdder: # TODO: check if this adds only types to leave nodes of AST, get type info def visit(self, obj): - if isinstance(obj, (list, tuple)): return [self.visit(e) for e in obj] if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)): @@ -105,7 +104,7 @@ class TypeAdder: # - Mixture in expression with int and float # - Mixture in expression with uint64 and sint64 # TODO: Lowest log level should log all casts ----> cast factory, make cast should contain logging - def figure_out_type(self, expr) -> Tuple[Any, BasicType]: # TODO or abstract type? vector type? + def figure_out_type(self, expr) -> Tuple[Any, Union[BasicType, PointerType]]: # Trivial cases from pystencils.field import Field import pystencils.integer_functions @@ -117,10 +116,12 @@ class TypeAdder: elif isinstance(expr, TypedSymbol): return expr, expr.dtype elif isinstance(expr, sp.Symbol): - t = TypedSymbol(expr.name, self.type_for_symbol[expr.name]) # TODO with or without name + t = TypedSymbol(expr.name, self.type_for_symbol[expr.name]) return t, t.dtype elif isinstance(expr, np.generic): assert False, f'Why do we have a np.generic in rhs???? {expr}' + elif isinstance(expr, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)): + return expr, BasicType('float32') # see https://en.cppreference.com/w/cpp/numeric/math/INFINITY elif isinstance(expr, sp.Number): if expr.is_Integer: data_type = self.default_number_int.get() @@ -129,6 +130,11 @@ class TypeAdder: else: assert False, f'{sp.Number} is neither Float nor Integer' return CastFunc(expr, data_type), data_type + elif isinstance(expr, AddressOf): + of = expr.args[0] + # TODO Basically this should do address_of already + assert isinstance(of, (Field.Access, TypedSymbol, Field)) + return expr, PointerType(of.dtype) elif isinstance(expr, BooleanAtom): return expr, bool_type elif isinstance(expr, Relational): @@ -197,13 +203,18 @@ class TypeAdder: else: new_args.append(a) return expr.func(*new_args) if new_args else expr, collated_type - elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc)): + elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)): args_types = [self.figure_out_type(arg) for arg in expr.args] collated_type = collate_types([t for _, t in args_types]) + if isinstance(collated_type, PointerType): + if isinstance(expr, sp.Add): + return expr.func(*[a for a, _ in args_types]), collated_type + else: + raise NotImplementedError(f'Pointer Arithmetic is implemented only for Add, not {expr}') new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types] return expr.func(*new_args) if new_args else expr, collated_type else: - raise NotImplementedError(f'expr {expr} unknown to typing') + raise NotImplementedError(f'expr {type(expr)}: {expr} unknown to typing') def process_expression(self, rhs, type_constants=True): # TODO DELETE import pystencils.integer_functions diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py index 7779017ce28dcff93c82d37d5d8314ce156b463a..821a5d2271cd949cf33428b46cd90edf593b1771 100644 --- a/pystencils/typing/utilities.py +++ b/pystencils/typing/utilities.py @@ -23,36 +23,6 @@ def typed_symbols(names, dtype, *args): return TypedSymbol(str(symbols), dtype) -# noinspection PyPep8Naming -class address_of(sp.Function): - # DONE: ask Martin - # TODO: docstring - # this is '&' in C - is_Atom = True - - def __new__(cls, arg): - obj = sp.Function.__new__(cls, arg) - return obj - - @property - def canonical(self): - if hasattr(self.args[0], 'canonical'): - return self.args[0].canonical - else: - raise NotImplementedError() - - @property - def is_commutative(self): - return self.args[0].is_commutative - - @property - def dtype(self): - if hasattr(self.args[0], 'dtype'): - return PointerType(self.args[0].dtype, restrict=True) - else: - return PointerType('void', restrict=True) # TODO this shouldn't work??? FIX: Allow BasicType to be Void and use that. Or raise exception - - def get_base_type(data_type): # TODO: WTF is this?? DOCS!!! # TODO: This is unsafe. @@ -96,21 +66,21 @@ def collate_types(types: Sequence[Union[BasicType, VectorType]]): Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double Uses the collation rules from numpy. """ - # # Pointer arithmetic case i.e. pointer + integer is allowed - # if any(type(t) is PointerType for t in types): - # pointer_type = None - # for t in types: - # if type(t) is PointerType: - # if pointer_type is not None: - # raise ValueError("Cannot collate the combination of two pointer types") - # pointer_type = t - # elif type(t) is BasicType: - # if not (t.is_int() or t.is_uint()): - # raise ValueError("Invalid pointer arithmetic") - # else: - # raise ValueError("Invalid pointer arithmetic") - # return pointer_type - # + # Pointer arithmetic case i.e. pointer + [int, uint] is allowed + if any(isinstance(t, PointerType) for t in types): + pointer_type = None + for t in types: + if isinstance(t, PointerType): + if pointer_type is not None: + raise ValueError(f'Cannot collate the combination of two pointer types "{pointer_type}" and "{t}"') + pointer_type = t + elif isinstance(t, BasicType): + if not (t.is_int() or t.is_uint()): + raise ValueError("Invalid pointer arithmetic") + else: + raise ValueError("Invalid pointer arithmetic") + return pointer_type + # # peel of vector types, if at least one vector type occurred the result will also be the vector type vector_type = [t for t in types if isinstance(t, VectorType)] # if not all_equal(t.width for t in vector_type): diff --git a/pystencils_tests/test_address_of.py b/pystencils_tests/test_address_of.py index 1cb9c8ed167cffbe184b41ad4f47c804470d4cec..c0a75e540237aa2e0ff46af37859c19cc069ce59 100644 --- a/pystencils_tests/test_address_of.py +++ b/pystencils_tests/test_address_of.py @@ -1,48 +1,50 @@ """ Test of pystencils.data_types.address_of """ -import sympy as sp +import pytest import pystencils -from pystencils.typing import PointerType, address_of, CastFunc, create_type +from pystencils.typing import PointerType, CastFunc, BasicType +from pystencils.functions import AddressOf from pystencils.simp.simplifications import sympy_cse +import sympy as sp + def test_address_of(): - x, y = pystencils.fields('x,y: int64[2d]') - s = pystencils.TypedSymbol('s', PointerType(create_type('int64'))) + x, y = pystencils.fields('x, y: int64[2d]') + s = pystencils.TypedSymbol('s', PointerType(BasicType('int64'))) - assert address_of(x[0, 0]).canonical() == x[0, 0] - assert address_of(x[0, 0]).dtype == PointerType(x[0, 0].dtype, restrict=True) - assert address_of(sp.Symbol("a")).dtype == PointerType('void', restrict=True) + assert AddressOf(x[0, 0]).canonical() == x[0, 0] + assert AddressOf(x[0, 0]).dtype == PointerType(x[0, 0].dtype, restrict=True) + with pytest.raises(ValueError): + assert AddressOf(sp.Symbol("a")).dtype assignments = pystencils.AssignmentCollection({ - s: address_of(x[0, 0]), - y[0, 0]: CastFunc(s, create_type('int64')) - }, {}) + s: AddressOf(x[0, 0]), + y[0, 0]: CastFunc(s, BasicType('int64')) + }) - ast = pystencils.create_kernel(assignments) - pystencils.show_code(ast) + kernel = pystencils.create_kernel(assignments).compile() + # pystencils.show_code(kernel.ast) assignments = pystencils.AssignmentCollection({ - y[0, 0]: CastFunc(address_of(x[0, 0]), create_type('int64')) - }, {}) + y[0, 0]: CastFunc(AddressOf(x[0, 0]), BasicType('int64')) + }) - ast = pystencils.create_kernel(assignments) - pystencils.show_code(ast) + kernel = pystencils.create_kernel(assignments).compile() + # pystencils.show_code(kernel.ast) def test_address_of_with_cse(): - x, y = pystencils.fields('x,y: int64[2d]') - s = pystencils.TypedSymbol('s', PointerType(create_type('int64'))) + x, y = pystencils.fields('x, y: int64[2d]') assignments = pystencils.AssignmentCollection({ - y[0, 0]: CastFunc(address_of(x[0, 0]), create_type('int64')) + s, - x[0, 0]: CastFunc(address_of(x[0, 0]), create_type('int64')) + 1 - }, {}) + x[0, 0]: CastFunc(AddressOf(x[0, 0]), BasicType('int64')) + 1 + }) - ast = pystencils.create_kernel(assignments) - pystencils.show_code(ast) + kernel = pystencils.create_kernel(assignments).compile() + # pystencils.show_code(kernel.ast) assignments_cse = sympy_cse(assignments) - ast = pystencils.create_kernel(assignments_cse) - pystencils.show_code(ast) + kernel = pystencils.create_kernel(assignments_cse).compile() + # pystencils.show_code(kernel.ast) diff --git a/pystencils_tests/test_blocking.py b/pystencils_tests/test_blocking.py index 3d6436a74e45f82f299bde4bf3a911f8811cb222..5ab66cd4e3a69c23d90d6c1d62005d7ca3d9da1f 100644 --- a/pystencils_tests/test_blocking.py +++ b/pystencils_tests/test_blocking.py @@ -77,4 +77,4 @@ def test_jacobi3d_fixed_field_size(): print("Fixed Field Size: Smaller than block sizes") arr = np.empty([3, 5, 6]) - check_equivalence(jacobi(dst, src), arr) \ No newline at end of file + check_equivalence(jacobi(dst, src), arr) diff --git a/pystencils_tests/test_buffer.py b/pystencils_tests/test_buffer.py index 3da6f04155d7d67e9d4573d062c28ad411a4cdd2..28665a294ae410c9dfe58038bfbd43380fca6c3d 100644 --- a/pystencils_tests/test_buffer.py +++ b/pystencils_tests/test_buffer.py @@ -2,7 +2,8 @@ import numpy as np -from pystencils import Assignment, Field, FieldType, create_kernel, make_slice +import pystencils as ps +from pystencils import Assignment, Field, FieldType, create_kernel from pystencils.field import create_numpy_array_with_layout, layout_string_to_tuple from pystencils.slicing import ( add_ghost_layers, get_ghost_region_slice, get_slice_before_ghost_layer) @@ -41,6 +42,8 @@ def test_full_scalar_field(): pack_eqs = [Assignment(buffer.center(), src_field.center())] pack_code = create_kernel(pack_eqs, data_type={'src_field': src_arr.dtype, 'buffer': buffer.dtype}) + code = ps.get_code_str(pack_code) + ps.show_code(pack_code) pack_kernel = pack_code.compile() pack_kernel(buffer=buffer_arr, src_field=src_arr) diff --git a/pystencils_tests/test_cuda_known_functions.py b/pystencils_tests/test_cuda_known_functions.py index 7e465da9f48f8c58abf9c0f01c3e0363ab14df06..7828c99f884450fd884084d6a2383729a30fcc66 100644 --- a/pystencils_tests/test_cuda_known_functions.py +++ b/pystencils_tests/test_cuda_known_functions.py @@ -5,7 +5,7 @@ import pytest import pystencils from pystencils.astnodes import get_dummy_symbol from pystencils.backends.cuda_backend import CudaSympyPrinter -from pystencils.typing import address_of +from pystencils.functions import address_of from pystencils.enums import Target diff --git a/pystencils_tests/test_custom_backends.py b/pystencils_tests/test_custom_backends.py index 696d1be2772a82de387c5da18376458e35000349..e5942fcbf7f102017f2d240f89f0f841b22a596c 100644 --- a/pystencils_tests/test_custom_backends.py +++ b/pystencils_tests/test_custom_backends.py @@ -25,10 +25,10 @@ class ScreamingGpuBackend(CudaBackend): def test_custom_backends_cpu(): - z, x, y = pystencils.fields("z, y, x: [2d]") + z, y, x = pystencils.fields("z, y, x: [2d]") normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( - z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) + z[0, 0], x[0, 0] * x[0, 0] * y[0, 0])], []) ast = pystencils.create_kernel(normal_assignments, target=Target.CPU) pystencils.show_code(ast, ScreamingBackend()) @@ -44,7 +44,7 @@ def test_custom_backends_gpu(): z, x, y = pystencils.fields("z, y, x: [2d]") normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( - z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) + z[0, 0], x[0, 0] * x[0, 0] * y[0, 0])], []) ast = pystencils.create_kernel(normal_assignments, target=Target.GPU) pystencils.show_code(ast, ScreamingGpuBackend()) diff --git a/pystencils_tests/test_dot_printer.ipynb b/pystencils_tests/test_dot_printer.ipynb index 67c0e14a947167b13ba012cc71fa6d46841f9aba..35ff1cecb5ec1d5e983dfc2141fa429ae0a8fba1 100644 --- a/pystencils_tests/test_dot_printer.ipynb +++ b/pystencils_tests/test_dot_printer.ipynb @@ -1,15 +1,5 @@ { "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import pytest\n", - "pytest.importorskip('graphviz')" - ] - }, { "cell_type": "code", "execution_count": 1, @@ -17,7 +7,7 @@ "outputs": [], "source": [ "from pystencils.session import *\n", - "from pystencils.astnodes import Block, Conditional" + "from pystencils.astnodes import Block, Conditional, SympyAssignment" ] }, { @@ -28,10 +18,10 @@ "source": [ "src, dst = ps.fields(\"src, dst: double[2D]\", layout='c')\n", "\n", - "true_block = Block([ps.Assignment(dst[0, 0], src[-1, 0])])\n", - "false_block = Block([ps.Assignment(dst[0, 0], src[1, 0])])\n", + "true_block = Block([SympyAssignment(dst[0, 0], src[-1, 0])])\n", + "false_block = Block([SympyAssignment(dst[0, 0], src[1, 0])])\n", "ur = [true_block, Conditional(dst.center() > 0.0, true_block, false_block)]\n", - " \n", + "\n", "ast = ps.create_kernel(ur)" ] }, @@ -44,265 +34,167 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", - "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", - " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", - "<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n", - " -->\n", - "<!-- Title: %3 Pages: 1 -->\n", - "<svg width=\"684pt\" height=\"290pt\"\n", - " viewBox=\"0.00 0.00 684.00 289.51\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", - "<g id=\"graph0\" class=\"graph\" transform=\"scale(.4128 .4128) rotate(0) translate(4 697.3797)\">\n", - "<title>%3</title>\n", - "<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-697.3797 1653.0784,-697.3797 1653.0784,4 -4,4\"/>\n", - "<!-- 140060050351120 -->\n", - "<g id=\"node1\" class=\"node\">\n", - "<title>140060050351120</title>\n", - "<ellipse fill=\"#a056db\" stroke=\"#000000\" cx=\"243.1436\" cy=\"-675.3797\" rx=\"111.5806\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"243.1436\" y=\"-671.6797\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Func: kernel (dst,src)</text>\n", - "</g>\n", - "<!-- 140060034299536 -->\n", - "<g id=\"node19\" class=\"node\">\n", - "<title>140060034299536</title>\n", - "<ellipse fill=\"#dbc256\" stroke=\"#000000\" cx=\"243.1436\" cy=\"-603.3797\" rx=\"37.0935\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"243.1436\" y=\"-599.6797\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Block</text>\n", - "</g>\n", - "<!-- 140060050351120->140060034299536 -->\n", - "<g id=\"edge18\" class=\"edge\">\n", - "<title>140060050351120->140060034299536</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M243.1436,-657.2111C243.1436,-649.5107 243.1436,-640.3541 243.1436,-631.7964\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"246.6437,-631.793 243.1436,-621.793 239.6437,-631.793 246.6437,-631.793\"/>\n", - "</g>\n", - "<!-- 140060034299984 -->\n", - "<g id=\"node2\" class=\"node\">\n", - "<title>140060034299984</title>\n", - "<ellipse fill=\"#3498db\" stroke=\"#000000\" cx=\"243.1436\" cy=\"-531.3797\" rx=\"86.3847\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"243.1436\" y=\"-527.6797\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Loop over dim 0</text>\n", - "</g>\n", - "<!-- 140060034299664 -->\n", - "<g id=\"node18\" class=\"node\">\n", - "<title>140060034299664</title>\n", - "<ellipse fill=\"#dbc256\" stroke=\"#000000\" cx=\"243.1436\" cy=\"-459.3797\" rx=\"37.0935\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"243.1436\" y=\"-455.6797\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Block</text>\n", - "</g>\n", - "<!-- 140060034299984->140060034299664 -->\n", - "<g id=\"edge16\" class=\"edge\">\n", - "<title>140060034299984->140060034299664</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M243.1436,-513.2111C243.1436,-505.5107 243.1436,-496.3541 243.1436,-487.7964\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"246.6437,-487.793 243.1436,-477.793 239.6437,-487.793 246.6437,-487.793\"/>\n", - "</g>\n", - "<!-- 140060034380240 -->\n", - "<g id=\"node3\" class=\"node\">\n", - "<title>140060034380240</title>\n", - "<ellipse fill=\"#56db7f\" stroke=\"#000000\" cx=\"72.1436\" cy=\"-387.3797\" rx=\"72.2875\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"72.1436\" y=\"-383.6797\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">_data_dst_00</text>\n", - "</g>\n", - "<!-- 140060034381584 -->\n", - "<g id=\"node4\" class=\"node\">\n", - "<title>140060034381584</title>\n", - "<ellipse fill=\"#56db7f\" stroke=\"#000000\" cx=\"243.1436\" cy=\"-387.3797\" rx=\"81.4863\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"243.1436\" y=\"-383.6797\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">_data_src_0m1</text>\n", - "</g>\n", - "<!-- 140060034300688 -->\n", - "<g id=\"node5\" class=\"node\">\n", - "<title>140060034300688</title>\n", - "<ellipse fill=\"#3498db\" stroke=\"#000000\" cx=\"429.1436\" cy=\"-387.3797\" rx=\"86.3847\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"429.1436\" y=\"-383.6797\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Loop over dim 1</text>\n", - "</g>\n", - "<!-- 140060034298960 -->\n", - "<g id=\"node17\" class=\"node\">\n", - "<title>140060034298960</title>\n", - "<ellipse fill=\"#dbc256\" stroke=\"#000000\" cx=\"429.1436\" cy=\"-315.3797\" rx=\"37.0935\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"429.1436\" y=\"-311.6797\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Block</text>\n", - "</g>\n", - "<!-- 140060034300688->140060034298960 -->\n", - "<g id=\"edge12\" class=\"edge\">\n", - "<title>140060034300688->140060034298960</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M429.1436,-369.2111C429.1436,-361.5107 429.1436,-352.3541 429.1436,-343.7964\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"432.6437,-343.793 429.1436,-333.793 425.6437,-343.793 432.6437,-343.793\"/>\n", - "</g>\n", - "<!-- 140060034298192 -->\n", - "<g id=\"node6\" class=\"node\">\n", - "<title>140060034298192</title>\n", - "<ellipse fill=\"#56db7f\" stroke=\"#000000\" cx=\"203.1436\" cy=\"-202.6899\" rx=\"170.8697\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"203.1436\" y=\"-198.9899\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">_data_dst_00[_stride_dst_1*ctr_1]</text>\n", - "</g>\n", - "<!-- 140060165603728 -->\n", - "<g id=\"node7\" class=\"node\">\n", - "<title>140060165603728</title>\n", - "<ellipse fill=\"#dbc256\" stroke=\"#000000\" cx=\"429.1436\" cy=\"-202.6899\" rx=\"37.0935\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"429.1436\" y=\"-198.9899\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Block</text>\n", - "</g>\n", - "<!-- 140060034299472 -->\n", - "<g id=\"node8\" class=\"node\">\n", - "<title>140060034299472</title>\n", - "<ellipse fill=\"#56bd7f\" stroke=\"#000000\" cx=\"857.1436\" cy=\"-202.6899\" rx=\"372.7906\" ry=\"58.8803\"/>\n", - "<text text-anchor=\"middle\" x=\"857.1436\" y=\"-228.9899\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">else: </text>\n", - "<text text-anchor=\"middle\" x=\"857.1436\" y=\"-213.9899\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">\tBlock _data_dst_00 ↠_data_dst + _stride_dst_0*ctr_0</text>\n", - "<text text-anchor=\"middle\" x=\"857.1436\" y=\"-198.9899\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">_data_src_01 ↠_data_src + _stride_src_0*ctr_0 + _stride_src_0</text>\n", - "<text text-anchor=\"middle\" x=\"857.1436\" y=\"-183.9899\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">_data_dst_00[_stride_dst_1*ctr_1] ↠_data_src_01[_stride_src_1*ctr_1]</text>\n", - "<text text-anchor=\"middle\" x=\"857.1436\" y=\"-168.9899\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\"> </text>\n", - "</g>\n", - "<!-- 140060037556304 -->\n", - "<g id=\"node12\" class=\"node\">\n", - "<title>140060037556304</title>\n", - "<ellipse fill=\"#dbc256\" stroke=\"#000000\" cx=\"659.1436\" cy=\"-90\" rx=\"37.0935\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"659.1436\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Block</text>\n", - "</g>\n", - "<!-- 140060034299472->140060037556304 -->\n", - "<g id=\"edge4\" class=\"edge\">\n", - "<title>140060034299472->140060037556304</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M757.5635,-146.0148C733.9458,-132.573 710.3721,-119.1562 692.2192,-108.8247\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"693.6713,-105.624 683.249,-103.7194 690.2088,-111.7077 693.6713,-105.624\"/>\n", - "</g>\n", - "<!-- 140060034298640 -->\n", - "<g id=\"node16\" class=\"node\">\n", - "<title>140060034298640</title>\n", - "<ellipse fill=\"#dbc256\" stroke=\"#000000\" cx=\"1136.1436\" cy=\"-90\" rx=\"37.0935\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"1136.1436\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Block</text>\n", - "</g>\n", - "<!-- 140060034299472->140060034298640 -->\n", - "<g id=\"edge8\" class=\"edge\">\n", - "<title>140060034299472->140060034298640</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M992.5572,-147.9955C1031.2605,-132.3629 1070.3836,-116.5609 1097.9961,-105.408\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1099.6649,-108.5088 1107.6263,-101.5183 1097.0432,-102.0182 1099.6649,-108.5088\"/>\n", - "</g>\n", - "<!-- 140060034382224 -->\n", - "<g id=\"node9\" class=\"node\">\n", - "<title>140060034382224</title>\n", - "<ellipse fill=\"#56db7f\" stroke=\"#000000\" cx=\"353.1436\" cy=\"-18\" rx=\"72.2875\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"353.1436\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">_data_dst_00</text>\n", - "</g>\n", - "<!-- 140060044051536 -->\n", - "<g id=\"node10\" class=\"node\">\n", - "<title>140060044051536</title>\n", - "<ellipse fill=\"#56db7f\" stroke=\"#000000\" cx=\"524.1436\" cy=\"-18\" rx=\"81.4863\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"524.1436\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">_data_src_0m1</text>\n", - "</g>\n", - "<!-- 140060034298704 -->\n", - "<g id=\"node11\" class=\"node\">\n", - "<title>140060034298704</title>\n", - "<ellipse fill=\"#56db7f\" stroke=\"#000000\" cx=\"794.1436\" cy=\"-18\" rx=\"170.8697\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"794.1436\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">_data_dst_00[_stride_dst_1*ctr_1]</text>\n", - "</g>\n", - "<!-- 140060037556304->140060034382224 -->\n", - "<g id=\"edge1\" class=\"edge\">\n", - "<title>140060037556304->140060034382224</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M625.5209,-82.0888C575.1201,-70.2298 479.132,-47.6443 415.6277,-32.7021\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"416.2623,-29.256 405.7265,-30.3724 414.659,-36.0699 416.2623,-29.256\"/>\n", - "</g>\n", - "<!-- 140060037556304->140060044051536 -->\n", - "<g id=\"edge2\" class=\"edge\">\n", - "<title>140060037556304->140060044051536</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M634.0675,-76.6261C614.6322,-66.2606 587.3057,-51.6865 564.7614,-39.6628\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"566.2191,-36.4736 555.7485,-34.8559 562.9249,-42.6501 566.2191,-36.4736\"/>\n", - "</g>\n", - "<!-- 140060037556304->140060034298704 -->\n", - "<g id=\"edge3\" class=\"edge\">\n", - "<title>140060037556304->140060034298704</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M684.2197,-76.6261C703.1859,-66.5108 729.6668,-52.3876 751.8851,-40.5378\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"753.6187,-43.58 760.7951,-35.7858 750.3245,-37.4035 753.6187,-43.58\"/>\n", - "</g>\n", - "<!-- 140060034383312 -->\n", - "<g id=\"node13\" class=\"node\">\n", - "<title>140060034383312</title>\n", - "<ellipse fill=\"#56db7f\" stroke=\"#000000\" cx=\"1055.1436\" cy=\"-18\" rx=\"72.2875\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"1055.1436\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">_data_dst_00</text>\n", - "</g>\n", - "<!-- 140060034383184 -->\n", - "<g id=\"node14\" class=\"node\">\n", - "<title>140060034383184</title>\n", - "<ellipse fill=\"#56db7f\" stroke=\"#000000\" cx=\"1217.1436\" cy=\"-18\" rx=\"72.2875\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"1217.1436\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">_data_src_01</text>\n", - "</g>\n", - "<!-- 140060034776592 -->\n", - "<g id=\"node15\" class=\"node\">\n", - "<title>140060034776592</title>\n", - "<ellipse fill=\"#56db7f\" stroke=\"#000000\" cx=\"1478.1436\" cy=\"-18\" rx=\"170.8697\" ry=\"18\"/>\n", - "<text text-anchor=\"middle\" x=\"1478.1436\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">_data_dst_00[_stride_dst_1*ctr_1]</text>\n", - "</g>\n", - "<!-- 140060034298640->140060034383312 -->\n", - "<g id=\"edge5\" class=\"edge\">\n", - "<title>140060034298640->140060034383312</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M1118.1671,-74.0209C1107.6147,-64.641 1094.0712,-52.6024 1082.2454,-42.0905\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1084.5702,-39.4741 1074.7708,-35.4464 1079.9196,-44.706 1084.5702,-39.4741\"/>\n", - "</g>\n", - "<!-- 140060034298640->140060034383184 -->\n", - "<g id=\"edge6\" class=\"edge\">\n", - "<title>140060034298640->140060034383184</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M1154.1201,-74.0209C1164.6724,-64.641 1178.216,-52.6024 1190.0418,-42.0905\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1192.3675,-44.706 1197.5164,-35.4464 1187.717,-39.4741 1192.3675,-44.706\"/>\n", - "</g>\n", - "<!-- 140060034298640->140060034776592 -->\n", - "<g id=\"edge7\" class=\"edge\">\n", - "<title>140060034298640->140060034776592</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M1170.1862,-82.8331C1221.3389,-72.0641 1319.6786,-51.3611 1391.5128,-36.2381\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1392.4009,-39.6279 1401.4653,-34.1428 1390.9588,-32.778 1392.4009,-39.6279\"/>\n", - "</g>\n", - "<!-- 140060034298960->140060034298192 -->\n", - "<g id=\"edge9\" class=\"edge\">\n", - "<title>140060034298960->140060034298192</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M403.2185,-302.4528C365.3624,-283.5767 294.424,-248.2048 247.9919,-225.0525\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"249.2741,-221.7809 238.7631,-220.4507 246.1504,-228.0453 249.2741,-221.7809\"/>\n", - "</g>\n", - "<!-- 140060034298960->140060165603728 -->\n", - "<g id=\"edge10\" class=\"edge\">\n", - "<title>140060034298960->140060165603728</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M429.1436,-297.2741C429.1436,-279.3665 429.1436,-251.7016 429.1436,-230.9091\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"432.6437,-230.7505 429.1436,-220.7505 425.6437,-230.7505 432.6437,-230.7505\"/>\n", - "</g>\n", - "<!-- 140060034298960->140060034299472 -->\n", - "<g id=\"edge11\" class=\"edge\">\n", - "<title>140060034298960->140060034299472</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M462.0174,-306.7243C504.2144,-295.614 580.9235,-275.417 655.6024,-255.7545\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"656.8664,-259.041 665.6456,-253.1101 655.084,-252.2717 656.8664,-259.041\"/>\n", - "</g>\n", - "<!-- 140060034299664->140060034380240 -->\n", - "<g id=\"edge13\" class=\"edge\">\n", - "<title>140060034299664->140060034380240</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M214.9064,-447.4904C188.6863,-436.4503 149.1689,-419.8114 118.3532,-406.8364\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"119.668,-403.5924 109.0934,-402.9375 116.9515,-410.0439 119.668,-403.5924\"/>\n", - "</g>\n", - "<!-- 140060034299664->140060034381584 -->\n", - "<g id=\"edge14\" class=\"edge\">\n", - "<title>140060034299664->140060034381584</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M243.1436,-441.2111C243.1436,-433.5107 243.1436,-424.3541 243.1436,-415.7964\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"246.6437,-415.793 243.1436,-405.793 239.6437,-415.793 246.6437,-415.793\"/>\n", - "</g>\n", - "<!-- 140060034299664->140060034300688 -->\n", - "<g id=\"edge15\" class=\"edge\">\n", - "<title>140060034299664->140060034300688</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M272.2061,-448.1297C300.5974,-437.1396 344.3973,-420.1847 378.5179,-406.9768\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"379.802,-410.2329 387.8642,-403.3589 377.275,-403.7049 379.802,-410.2329\"/>\n", - "</g>\n", - "<!-- 140060034299536->140060034299984 -->\n", - "<g id=\"edge17\" class=\"edge\">\n", - "<title>140060034299536->140060034299984</title>\n", - "<path fill=\"none\" stroke=\"#000000\" d=\"M243.1436,-585.2111C243.1436,-577.5107 243.1436,-568.3541 243.1436,-559.7964\"/>\n", - "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"246.6437,-559.793 243.1436,-549.793 239.6437,-559.793 246.6437,-559.793\"/>\n", - "</g>\n", - "</g>\n", - "</svg>\n" + "text/html": [ + "<style>pre { line-height: 125%; }\n", + "td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n", + "span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n", + "td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n", + "span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n", + ".highlight .hll { background-color: #ffffcc }\n", + ".highlight { background: #f8f8f8; }\n", + ".highlight .c { color: #408080; font-style: italic } /* Comment */\n", + ".highlight .err { border: 1px solid #FF0000 } /* Error */\n", + ".highlight .k { color: #008000; font-weight: bold } /* Keyword */\n", + ".highlight .o { color: #666666 } /* Operator */\n", + ".highlight .ch { color: #408080; font-style: italic } /* Comment.Hashbang */\n", + ".highlight .cm { color: #408080; font-style: italic } /* Comment.Multiline */\n", + ".highlight .cp { color: #BC7A00 } /* Comment.Preproc */\n", + ".highlight .cpf { color: #408080; font-style: italic } /* Comment.PreprocFile */\n", + ".highlight .c1 { color: #408080; font-style: italic } /* Comment.Single */\n", + ".highlight .cs { color: #408080; font-style: italic } /* Comment.Special */\n", + ".highlight .gd { color: #A00000 } /* Generic.Deleted */\n", + ".highlight .ge { font-style: italic } /* Generic.Emph */\n", + ".highlight .gr { color: #FF0000 } /* Generic.Error */\n", + ".highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */\n", + ".highlight .gi { color: #00A000 } /* Generic.Inserted */\n", + ".highlight .go { color: #888888 } /* Generic.Output */\n", + ".highlight .gp { color: #000080; font-weight: bold } /* Generic.Prompt */\n", + ".highlight .gs { font-weight: bold } /* Generic.Strong */\n", + ".highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */\n", + ".highlight .gt { color: #0044DD } /* Generic.Traceback */\n", + ".highlight .kc { color: #008000; font-weight: bold } /* Keyword.Constant */\n", + ".highlight .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */\n", + ".highlight .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */\n", + ".highlight .kp { color: #008000 } /* Keyword.Pseudo */\n", + ".highlight .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */\n", + ".highlight .kt { color: #B00040 } /* Keyword.Type */\n", + ".highlight .m { color: #666666 } /* Literal.Number */\n", + ".highlight .s { color: #BA2121 } /* Literal.String */\n", + ".highlight .na { color: #7D9029 } /* Name.Attribute */\n", + ".highlight .nb { color: #008000 } /* Name.Builtin */\n", + ".highlight .nc { color: #0000FF; font-weight: bold } /* Name.Class */\n", + ".highlight .no { color: #880000 } /* Name.Constant */\n", + ".highlight .nd { color: #AA22FF } /* Name.Decorator */\n", + ".highlight .ni { color: #999999; font-weight: bold } /* Name.Entity */\n", + ".highlight .ne { color: #D2413A; font-weight: bold } /* Name.Exception */\n", + ".highlight .nf { color: #0000FF } /* Name.Function */\n", + ".highlight .nl { color: #A0A000 } /* Name.Label */\n", + ".highlight .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */\n", + ".highlight .nt { color: #008000; font-weight: bold } /* Name.Tag */\n", + ".highlight .nv { color: #19177C } /* Name.Variable */\n", + ".highlight .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */\n", + ".highlight .w { color: #bbbbbb } /* Text.Whitespace */\n", + ".highlight .mb { color: #666666 } /* Literal.Number.Bin */\n", + ".highlight .mf { color: #666666 } /* Literal.Number.Float */\n", + ".highlight .mh { color: #666666 } /* Literal.Number.Hex */\n", + ".highlight .mi { color: #666666 } /* Literal.Number.Integer */\n", + ".highlight .mo { color: #666666 } /* Literal.Number.Oct */\n", + ".highlight .sa { color: #BA2121 } /* Literal.String.Affix */\n", + ".highlight .sb { color: #BA2121 } /* Literal.String.Backtick */\n", + ".highlight .sc { color: #BA2121 } /* Literal.String.Char */\n", + ".highlight .dl { color: #BA2121 } /* Literal.String.Delimiter */\n", + ".highlight .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */\n", + ".highlight .s2 { color: #BA2121 } /* Literal.String.Double */\n", + ".highlight .se { color: #BB6622; font-weight: bold } /* Literal.String.Escape */\n", + ".highlight .sh { color: #BA2121 } /* Literal.String.Heredoc */\n", + ".highlight .si { color: #BB6688; font-weight: bold } /* Literal.String.Interpol */\n", + ".highlight .sx { color: #008000 } /* Literal.String.Other */\n", + ".highlight .sr { color: #BB6688 } /* Literal.String.Regex */\n", + ".highlight .s1 { color: #BA2121 } /* Literal.String.Single */\n", + ".highlight .ss { color: #19177C } /* Literal.String.Symbol */\n", + ".highlight .bp { color: #008000 } /* Name.Builtin.Pseudo */\n", + ".highlight .fm { color: #0000FF } /* Name.Function.Magic */\n", + ".highlight .vc { color: #19177C } /* Name.Variable.Class */\n", + ".highlight .vg { color: #19177C } /* Name.Variable.Global */\n", + ".highlight .vi { color: #19177C } /* Name.Variable.Instance */\n", + ".highlight .vm { color: #19177C } /* Name.Variable.Magic */\n", + ".highlight .il { color: #666666 } /* Literal.Number.Integer.Long */</style>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "<div class=\"highlight\"><pre><span></span><span class=\"n\">FUNC_PREFIX</span><span class=\"w\"> </span><span class=\"kt\">void</span><span class=\"w\"> </span><span class=\"n\">kernel</span><span class=\"p\">(</span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_dst</span><span class=\"p\">,</span><span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"k\">const</span><span class=\"w\"> </span><span class=\"n\">_data_src</span><span class=\"p\">,</span><span class=\"w\"> </span><span class=\"kt\">int64_t</span><span class=\"w\"> </span><span class=\"k\">const</span><span class=\"w\"> </span><span class=\"n\">_size_dst_0</span><span class=\"p\">,</span><span class=\"w\"> </span><span class=\"kt\">int64_t</span><span class=\"w\"> </span><span class=\"k\">const</span><span class=\"w\"> </span><span class=\"n\">_size_dst_1</span><span class=\"p\">,</span><span class=\"w\"> </span><span class=\"kt\">int64_t</span><span class=\"w\"> </span><span class=\"k\">const</span><span class=\"w\"> </span><span class=\"n\">_stride_dst_0</span><span class=\"p\">,</span><span class=\"w\"> </span><span class=\"kt\">int64_t</span><span class=\"w\"> </span><span class=\"k\">const</span><span class=\"w\"> </span><span class=\"n\">_stride_dst_1</span><span class=\"p\">,</span><span class=\"w\"> </span><span class=\"kt\">int64_t</span><span class=\"w\"> </span><span class=\"k\">const</span><span class=\"w\"> </span><span class=\"n\">_stride_src_0</span><span class=\"p\">,</span><span class=\"w\"> </span><span class=\"kt\">int64_t</span><span class=\"w\"> </span><span class=\"k\">const</span><span class=\"w\"> </span><span class=\"n\">_stride_src_1</span><span class=\"p\">)</span><span class=\"w\"></span>\n", + "<span class=\"p\">{</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"k\">for</span><span class=\"w\"> </span><span class=\"p\">(</span><span class=\"kt\">int64_t</span><span class=\"w\"> </span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"mi\">1</span><span class=\"p\">;</span><span class=\"w\"> </span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\"><</span><span class=\"w\"> </span><span class=\"n\">_size_dst_0</span><span class=\"w\"> </span><span class=\"o\">-</span><span class=\"w\"> </span><span class=\"mi\">1</span><span class=\"p\">;</span><span class=\"w\"> </span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\">+=</span><span class=\"w\"> </span><span class=\"mi\">1</span><span class=\"p\">)</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"p\">{</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_dst_00</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_dst</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"n\">_stride_dst_0</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"p\">;</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_src_0m1</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_src</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"n\">_stride_src_0</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\">-</span><span class=\"w\"> </span><span class=\"n\">_stride_src_0</span><span class=\"p\">;</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"k\">for</span><span class=\"w\"> </span><span class=\"p\">(</span><span class=\"kt\">int64_t</span><span class=\"w\"> </span><span class=\"n\">ctr_1</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"mi\">1</span><span class=\"p\">;</span><span class=\"w\"> </span><span class=\"n\">ctr_1</span><span class=\"w\"> </span><span class=\"o\"><</span><span class=\"w\"> </span><span class=\"n\">_size_dst_1</span><span class=\"w\"> </span><span class=\"o\">-</span><span class=\"w\"> </span><span class=\"mi\">1</span><span class=\"p\">;</span><span class=\"w\"> </span><span class=\"n\">ctr_1</span><span class=\"w\"> </span><span class=\"o\">+=</span><span class=\"w\"> </span><span class=\"mi\">1</span><span class=\"p\">)</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"p\">{</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"n\">_data_dst_00</span><span class=\"p\">[</span><span class=\"n\">_stride_dst_1</span><span class=\"o\">*</span><span class=\"n\">ctr_1</span><span class=\"p\">]</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_src_0m1</span><span class=\"p\">[</span><span class=\"n\">_stride_src_1</span><span class=\"o\">*</span><span class=\"n\">ctr_1</span><span class=\"p\">];</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"p\">{</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span>\n", + "<span class=\"w\"> </span><span class=\"p\">}</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"k\">if</span><span class=\"w\"> </span><span class=\"p\">(</span><span class=\"n\">_data_dst_00</span><span class=\"p\">[</span><span class=\"n\">_stride_dst_1</span><span class=\"o\">*</span><span class=\"n\">ctr_1</span><span class=\"p\">]</span><span class=\"w\"> </span><span class=\"o\">></span><span class=\"w\"> </span><span class=\"mf\">0.0</span><span class=\"p\">)</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"p\">{</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_dst_00</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_dst</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"n\">_stride_dst_0</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"p\">;</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_src_0m1</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_src</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"n\">_stride_src_0</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\">-</span><span class=\"w\"> </span><span class=\"n\">_stride_src_0</span><span class=\"p\">;</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"n\">_data_dst_00</span><span class=\"p\">[</span><span class=\"n\">_stride_dst_1</span><span class=\"o\">*</span><span class=\"n\">ctr_1</span><span class=\"p\">]</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_src_0m1</span><span class=\"p\">[</span><span class=\"n\">_stride_src_1</span><span class=\"o\">*</span><span class=\"n\">ctr_1</span><span class=\"p\">];</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"p\">}</span><span class=\"w\"> </span><span class=\"k\">else</span><span class=\"w\"> </span><span class=\"p\">{</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_dst_00</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_dst</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"n\">_stride_dst_0</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"p\">;</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_src_01</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_src</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"n\">_stride_src_0</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"n\">_stride_src_0</span><span class=\"p\">;</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"n\">_data_dst_00</span><span class=\"p\">[</span><span class=\"n\">_stride_dst_1</span><span class=\"o\">*</span><span class=\"n\">ctr_1</span><span class=\"p\">]</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_src_01</span><span class=\"p\">[</span><span class=\"n\">_stride_src_1</span><span class=\"o\">*</span><span class=\"n\">ctr_1</span><span class=\"p\">];</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"p\">}</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"p\">}</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"p\">}</span><span class=\"w\"></span>\n", + "<span class=\"p\">}</span><span class=\"w\"></span>\n", + "</pre></div>\n" ], "text/plain": [ - "<graphviz.files.Source at 0x7f62452c4110>" + "FUNC_PREFIX void kernel(double * RESTRICT _data_dst, double * RESTRICT const _data_src, int64_t const _size_dst_0, int64_t const _size_dst_1, int64_t const _stride_dst_0, int64_t const _stride_dst_1, int64_t const _stride_src_0, int64_t const _stride_src_1)\n", + "{\n", + " for (int64_t ctr_0 = 1; ctr_0 < _size_dst_0 - 1; ctr_0 += 1)\n", + " {\n", + " double * RESTRICT _data_dst_00 = _data_dst + _stride_dst_0*ctr_0;\n", + " double * RESTRICT _data_src_0m1 = _data_src + _stride_src_0*ctr_0 - _stride_src_0;\n", + " for (int64_t ctr_1 = 1; ctr_1 < _size_dst_1 - 1; ctr_1 += 1)\n", + " {\n", + " _data_dst_00[_stride_dst_1*ctr_1] = _data_src_0m1[_stride_src_1*ctr_1];\n", + " {\n", + " \n", + " }\n", + " if (_data_dst_00[_stride_dst_1*ctr_1] > 0.0)\n", + " {\n", + " double * RESTRICT _data_dst_00 = _data_dst + _stride_dst_0*ctr_0;\n", + " double * RESTRICT _data_src_0m1 = _data_src + _stride_src_0*ctr_0 - _stride_src_0;\n", + " _data_dst_00[_stride_dst_1*ctr_1] = _data_src_0m1[_stride_src_1*ctr_1];\n", + " } else {\n", + " double * RESTRICT _data_dst_00 = _data_dst + _stride_dst_0*ctr_0;\n", + " double * RESTRICT _data_src_01 = _data_src + _stride_src_0*ctr_0 + _stride_src_0;\n", + " _data_dst_00[_stride_dst_1*ctr_1] = _data_src_01[_stride_src_1*ctr_1];\n", + " }\n", + " }\n", + " }\n", + "}" ] }, - "execution_count": 3, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "ps.to_dot(ast, graph_style={'size': \"9.5,12.5\"})" + "ps.show_code(ast)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -316,7 +208,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.9.9" } }, "nbformat": 4, diff --git a/pystencils_tests/test_dot_printer.py b/pystencils_tests/test_dot_printer.py new file mode 100644 index 0000000000000000000000000000000000000000..a9d362c4fc2be1b46e969fe943c8179c532fdb36 --- /dev/null +++ b/pystencils_tests/test_dot_printer.py @@ -0,0 +1,13 @@ +import pystencils as ps + +from pystencils.astnodes import Block, Conditional, SympyAssignment + + +def test_dot_print(): + src, dst = ps.fields("src, dst: double[2D]", layout='c') + + true_block = Block([SympyAssignment(dst[0, 0], src[-1, 0])]) + false_block = Block([SympyAssignment(dst[0, 0], src[1, 0])]) + ur = [true_block, Conditional(dst.center() > 0.0, true_block, false_block)] + + ast = ps.create_kernel(ur) diff --git a/pystencils_tests/test_field_equality.ipynb b/pystencils_tests/test_field_equality.ipynb index 8de31e83b5e496cd57e7e1f7d91a1847588108d8..95959038ec0b3322289a2a6016d3ff43676c1288 100644 --- a/pystencils_tests/test_field_equality.ipynb +++ b/pystencils_tests/test_field_equality.ipynb @@ -6,8 +6,7 @@ "metadata": {}, "outputs": [], "source": [ - "from pystencils.session import *\n", - "from pystencils.data_types import cast_func" + "from pystencils.session import *" ] }, { @@ -164,13 +163,13 @@ "output_type": "stream", "text": [ "Field Accesses:\n", - " - f[0], hash -3276894289571194847, offsets (0,), index (), (('f_C', ('commutative', True)), ((0,), (_size_f_0,), (_stride_f_0,), 3146377891102027609, <FieldType.GENERIC: 0>, 'f', None), 0)\n", - " - f[0], hash -1516451775709390846, offsets (0,), index (), (('f_C', ('commutative', True)), ((0,), (_size_f_0,), (_stride_f_0,), -1421177580377734245, <FieldType.GENERIC: 0>, 'f', None), 0)\n", + " - f[0], hash -8859424145258271267, offsets (0,), index (), ((('f_C', ('commutative', True), ('complex', True), ('extended_real', True), ('finite', True), ('hermitian', True), ('imaginary', False), ('infinite', False), ('real', True)), 2305067722319023373), ((0,), (_size_f_0,), (_stride_f_0,), <FieldType.GENERIC: 0>, 'f', None, double), 0)\n", + " - f[0], hash -6454673863007224785, offsets (0,), index (), ((('f_C', ('commutative', True), ('complex', True), ('extended_real', True), ('finite', True), ('hermitian', True), ('imaginary', False), ('infinite', False), ('real', True)), 4093629613697528859), ((0,), (_size_f_0,), (_stride_f_0,), <FieldType.GENERIC: 0>, 'f', None, float), 0)\n", "\n", " -> 0,1 f[0] == f[0]: False\n", "Fields\n", - " - f, 140548694371968, shape (_size_f_0,), strides (_stride_f_0,), double, FieldType.GENERIC, layout (0,)\n", - " - f, 140548693963104, shape (_size_f_0,), strides (_stride_f_0,), float, FieldType.GENERIC, layout (0,)\n", + " - f, 4881406800, shape (_size_f_0,), strides (_stride_f_0,), double, FieldType.GENERIC, layout (0,)\n", + " - f, 4881445024, shape (_size_f_0,), strides (_stride_f_0,), float, FieldType.GENERIC, layout (0,)\n", "\n", " - f == f: False, ids equal False, hash equal False\n" ] @@ -183,7 +182,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -197,9 +196,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.8" + "version": "3.9.9" } }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/pystencils_tests/test_global_definitions.py b/pystencils_tests/test_global_definitions.py index fa51ccc9d1a6f44672c4ddd4f57217ad8b280638..8b6ee1b5bfb030cddfed2d7e0e70f91d8ccdfc04 100644 --- a/pystencils_tests/test_global_definitions.py +++ b/pystencils_tests/test_global_definitions.py @@ -95,7 +95,7 @@ def test_global_definitions_with_global_symbol(): z, x, y = pystencils.fields("z, y, x: [2d]") normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( - z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) + z[0, 0], x[0, 0] * x[0, 0] * y[0, 0])], []) ast = pystencils.create_kernel(normal_assignments) print(pystencils.show_code(ast)) @@ -115,7 +115,7 @@ def test_global_definitions_without_global_symbol(): z, x, y = pystencils.fields("z, y, x: [2d]") normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( - z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) + z[0, 0], x[0, 0] * x[0, 0] * y[0, 0])], []) ast = pystencils.create_kernel(normal_assignments) print(pystencils.show_code(ast)) diff --git a/pystencils_tests/test_kernel_data_type.py b/pystencils_tests/test_kernel_data_type.py deleted file mode 100644 index 25ca56c2b2468623b42c8122a6a3ba02f51500ff..0000000000000000000000000000000000000000 --- a/pystencils_tests/test_kernel_data_type.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections import defaultdict - -import numpy as np -import pytest -from sympy.abc import x, y - -from pystencils import Assignment, create_kernel, fields, CreateKernelConfig -from pystencils.typing import adjust_c_single_precision_type - - -@pytest.mark.parametrize("data_type", ("float", "double")) -def test_single_precision(data_type): - dtype = f"float{64 if data_type == 'double' else 32}" - s = fields(f"s: {dtype}[1D]") - assignments = [Assignment(x, y), Assignment(s[0], x)] - ast = create_kernel(assignments, config=CreateKernelConfig(data_type=data_type)) - assert ast.body.args[0].lhs.dtype.numpy_dtype == np.dtype(dtype) - assert ast.body.args[0].rhs.dtype.numpy_dtype == np.dtype(dtype) - assert ast.body.args[1].body.args[0].rhs.dtype.numpy_dtype == np.dtype(dtype) - - -def test_adjustment_dict(): - d = dict({"x": "float", "y": "double"}) - adjust_c_single_precision_type(d) - assert np.dtype(d["x"]) == np.dtype("float32") - assert np.dtype(d["y"]) == np.dtype("float64") - - -def test_adjustement_default_dict(): - dd = defaultdict(lambda: "float") - dd["x"] - adjust_c_single_precision_type(dd) - dd["y"] - assert np.dtype(dd["x"]) == np.dtype("float32") - assert np.dtype(dd["y"]) == np.dtype("float32") - assert np.dtype(dd["z"]) == np.dtype("float32") diff --git a/pystencils_tests/test_match_subs_for_assignment_collection.py b/pystencils_tests/test_match_subs_for_assignment_collection.py index 7bb0ec5095fe72b1c25905b6af6c2b90584fea5e..ec305fa52d7c4f1651368f95f9d9c412ad1f5236 100644 --- a/pystencils_tests/test_match_subs_for_assignment_collection.py +++ b/pystencils_tests/test_match_subs_for_assignment_collection.py @@ -11,12 +11,12 @@ import sympy as sp import pystencils -from pystencils.typing import create_type +from pystencils.typing import TypedSymbol, BasicType def test_wild_typed_symbol(): x = pystencils.fields('x: float32[3d]') - typed_symbol = pystencils.typing.data_types.TypedSymbol('a', create_type('float64')) + typed_symbol = TypedSymbol('a', BasicType('float64')) assert x.center().match(sp.Wild('w1')) assert typed_symbol.match(sp.Wild('w1')) diff --git a/pystencils_tests/test_type_interference.py b/pystencils_tests/test_type_interference.py index 179fa2836a34ef2b313c822c4f097df96f2f469b..d240cebcd5b2efe651dd116d67b5d56fdfe0b182 100644 --- a/pystencils_tests/test_type_interference.py +++ b/pystencils_tests/test_type_interference.py @@ -1,4 +1,4 @@ -from sympy.abc import a, b, c, d, e, f +from sympy.abc import a, b, c, d, e, f, g import pystencils from pystencils.typing import CastFunc, create_type @@ -13,13 +13,19 @@ def test_type_interference(): c: b, f: c + b, d: c + b + x.center + e, - x.center: c + b + x.center + x.center: c + b + x.center, + g: a + b + d }) ast = pystencils.create_kernel(assignments) + code = pystencils.get_code_str(ast) + # print(code) - code = str(pystencils.get_code_str(ast)) - assert 'double a' in code - assert 'uint16_t b' in code - assert 'uint16_t f' in code - assert 'int64_t e' in code + assert 'const double a' in code + assert 'const uint16_t b' in code + assert 'const uint16_t f' in code + assert 'const int64_t e' in code + + assert 'const float d = ((float)(b)) + ((float)(c)) + ((float)(e)) + _data_x_00_10[_stride_x_2*ctr_2];' in code + assert '_data_x_00_10[_stride_x_2*ctr_2] = ((float)(b)) + ((float)(c)) + _data_x_00_10[_stride_x_2*ctr_2];' in code + assert 'const double g = a + ((double)(b)) + ((double)(d));' in code