Commit dcf2c6f4 authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'interpolation-24.0.9' into 'master'

Interpolation 24.0.9

See merge request pycodegen/pystencils!56
parents 472f6f6c 84d81234
...@@ -12,6 +12,8 @@ from .kernelcreation import create_indexed_kernel, create_kernel, create_stagger ...@@ -12,6 +12,8 @@ from .kernelcreation import create_indexed_kernel, create_kernel, create_stagger
from .simp import AssignmentCollection from .simp import AssignmentCollection
from .slicing import make_slice from .slicing import make_slice
from .sympyextensions import SymbolCreator from .sympyextensions import SymbolCreator
from .spatial_coordinates import (x_, x_staggered, x_staggered_vector, x_vector,
y_, y_staggered, z_, z_staggered)
try: try:
import pystencils_autodiff import pystencils_autodiff
...@@ -30,5 +32,8 @@ __all__ = ['Field', 'FieldType', 'fields', ...@@ -30,5 +32,8 @@ __all__ = ['Field', 'FieldType', 'fields',
'SymbolCreator', 'SymbolCreator',
'create_data_handling', 'create_data_handling',
'kernel', 'kernel',
'x_', 'y_', 'z_',
'x_staggered', 'y_staggered', 'z_staggered',
'x_vector', 'x_staggered_vector',
'fd', 'fd',
'stencil'] 'stencil']
import collections.abc
import itertools
import uuid import uuid
from typing import Any, List, Optional, Sequence, Set, Union from typing import Any, List, Optional, Sequence, Set, Union
...@@ -33,7 +35,7 @@ class Node: ...@@ -33,7 +35,7 @@ class Node:
raise NotImplementedError() raise NotImplementedError()
def subs(self, subs_dict) -> None: def subs(self, subs_dict) -> None:
"""Inplace! substitute, similar to sympy's but modifies the AST inplace.""" """Inplace! Substitute, similar to sympy's but modifies the AST inplace."""
for a in self.args: for a in self.args:
a.subs(subs_dict) a.subs(subs_dict)
...@@ -102,7 +104,8 @@ class Conditional(Node): ...@@ -102,7 +104,8 @@ class Conditional(Node):
result = self.true_block.undefined_symbols result = self.true_block.undefined_symbols
if self.false_block: if self.false_block:
result.update(self.false_block.undefined_symbols) result.update(self.false_block.undefined_symbols)
result.update(self.condition_expr.atoms(sp.Symbol)) if hasattr(self.condition_expr, 'atoms'):
result.update(self.condition_expr.atoms(sp.Symbol))
return result return result
def __str__(self): def __str__(self):
...@@ -212,9 +215,16 @@ class KernelFunction(Node): ...@@ -212,9 +215,16 @@ class KernelFunction(Node):
"""Set of Field instances: fields which are accessed inside this kernel function""" """Set of Field instances: fields which are accessed inside this kernel function"""
return set(o.field for o in self.atoms(ResolvedFieldAccess)) return set(o.field for o in self.atoms(ResolvedFieldAccess))
def fields_written(self): @property
assigments = self.atoms(SympyAssignment) def fields_written(self) -> Set['ResolvedFieldAccess']:
return {a.lhs.field for a in assigments if isinstance(a.lhs, ResolvedFieldAccess)} assignments = self.atoms(SympyAssignment)
return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)}
@property
def fields_read(self) -> Set['ResolvedFieldAccess']:
assignments = self.atoms(SympyAssignment)
return set().union(itertools.chain.from_iterable([f.field for f in a.rhs.free_symbols if hasattr(f, 'field')]
for a in assignments))
def get_parameters(self) -> Sequence['KernelFunction.Parameter']: def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
"""Returns list of parameters for this function. """Returns list of parameters for this function.
...@@ -283,8 +293,15 @@ class Block(Node): ...@@ -283,8 +293,15 @@ class Block(Node):
a.subs(subs_dict) a.subs(subs_dict)
def insert_front(self, node): def insert_front(self, node):
node.parent = self if isinstance(node, collections.abc.Iterable):
self._nodes.insert(0, node) node = list(node)
for n in node:
n.parent = self
self._nodes = node + self._nodes
else:
node.parent = self
self._nodes.insert(0, node)
def insert_before(self, new_node, insert_before): def insert_before(self, new_node, insert_before):
new_node.parent = self new_node.parent = self
...@@ -485,7 +502,7 @@ class SympyAssignment(Node): ...@@ -485,7 +502,7 @@ class SympyAssignment(Node):
def __init__(self, lhs_symbol, rhs_expr, is_const=True): def __init__(self, lhs_symbol, rhs_expr, is_const=True):
super(SympyAssignment, self).__init__(parent=None) super(SympyAssignment, self).__init__(parent=None)
self._lhs_symbol = lhs_symbol self._lhs_symbol = lhs_symbol
self.rhs = rhs_expr self.rhs = sp.simplify(rhs_expr)
self._is_const = is_const self._is_const = is_const
self._is_declaration = self.__is_declaration() self._is_declaration = self.__is_declaration()
...@@ -678,3 +695,49 @@ def early_out(condition): ...@@ -678,3 +695,49 @@ def early_out(condition):
def get_dummy_symbol(dtype='bool'): def get_dummy_symbol(dtype='bool'):
return TypedSymbol('dummy%s' % uuid.uuid4().hex, create_type(dtype)) return TypedSymbol('dummy%s' % uuid.uuid4().hex, create_type(dtype))
class SourceCodeComment(Node):
def __init__(self, text):
self.text = text
@property
def args(self):
return []
@property
def symbols_defined(self):
return set()
@property
def undefined_symbols(self):
return set()
def __str__(self):
return "/* " + self.text + " */"
def __repr__(self):
return self.__str__()
class EmptyLine(Node):
def __init__(self):
pass
@property
def args(self):
return []
@property
def symbols_defined(self):
return set()
@property
def undefined_symbols(self):
return set()
def __str__(self):
return ""
def __repr__(self):
return self.__str__()
...@@ -102,6 +102,10 @@ def get_headers(ast_node: Node) -> Set[str]: ...@@ -102,6 +102,10 @@ def get_headers(ast_node: Node) -> Set[str]:
if isinstance(a, Node): if isinstance(a, Node):
headers.update(get_headers(a)) headers.update(get_headers(a))
for g in get_global_declarations(ast_node):
if isinstance(g, Node):
headers.update(get_headers(g))
return sorted(headers) return sorted(headers)
...@@ -131,6 +135,12 @@ class CustomCodeNode(Node): ...@@ -131,6 +135,12 @@ class CustomCodeNode(Node):
def undefined_symbols(self): def undefined_symbols(self):
return self._symbols_read - self._symbols_defined return self._symbols_read - self._symbols_defined
def __eq___(self, other):
return self._code == other._code
def __hash__(self):
return hash(self._code)
class PrintNode(CustomCodeNode): class PrintNode(CustomCodeNode):
# noinspection SpellCheckingInspection # noinspection SpellCheckingInspection
...@@ -263,6 +273,12 @@ class CBackend: ...@@ -263,6 +273,12 @@ class CBackend:
def _print_CustomCodeNode(self, node): def _print_CustomCodeNode(self, node):
return node.get_code(self._dialect, self._vector_instruction_set) return node.get_code(self._dialect, self._vector_instruction_set)
def _print_SourceCodeComment(self, node):
return "/* " + node.text + " */"
def _print_EmptyLine(self, node):
return ""
def _print_Conditional(self, node): def _print_Conditional(self, node):
cond_type = get_type_of_expression(node.condition_expr) cond_type = get_type_of_expression(node.condition_expr)
if isinstance(cond_type, VectorType): if isinstance(cond_type, VectorType):
...@@ -409,6 +425,7 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -409,6 +425,7 @@ class CustomSympyPrinter(CCodePrinter):
condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>=' condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>='
) )
return code return code
_print_Max = C89CodePrinter._print_Max _print_Max = C89CodePrinter._print_Max
_print_Min = C89CodePrinter._print_Min _print_Min = C89CodePrinter._print_Min
......
...@@ -3,6 +3,7 @@ from os.path import dirname, join ...@@ -3,6 +3,7 @@ from os.path import dirname, join
from pystencils.astnodes import Node from pystencils.astnodes import Node
from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.interpolation_astnodes import InterpolationMode
with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f: with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
lines = f.readlines() lines = f.readlines()
...@@ -43,11 +44,19 @@ class CudaBackend(CBackend): ...@@ -43,11 +44,19 @@ class CudaBackend(CBackend):
return code return code
def _print_TextureDeclaration(self, node): def _print_TextureDeclaration(self, node):
code = "texture<%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
str(node.texture.field.dtype), if node.texture.field.dtype.numpy_dtype.itemsize > 4:
node.texture.field.spatial_dimensions, code = "texture<fp_tex_%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
node.texture str(node.texture.field.dtype),
) node.texture.field.spatial_dimensions,
node.texture
)
else:
code = "texture<%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
str(node.texture.field.dtype),
node.texture.field.spatial_dimensions,
node.texture
)
return code return code
def _print_SkipIteration(self, _): def _print_SkipIteration(self, _):
...@@ -62,17 +71,23 @@ class CudaSympyPrinter(CustomSympyPrinter): ...@@ -62,17 +71,23 @@ class CudaSympyPrinter(CustomSympyPrinter):
self.known_functions.update(CUDA_KNOWN_FUNCTIONS) self.known_functions.update(CUDA_KNOWN_FUNCTIONS)
def _print_TextureAccess(self, node): def _print_TextureAccess(self, node):
dtype = node.texture.field.dtype.numpy_dtype
if node.texture.cubic_bspline_interpolation: if node.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
template = "cubicTex%iDSimple<%s>(%s, %s)" template = "cubicTex%iDSimple(%s, %s)"
else: else:
template = "tex%iD<%s>(%s, %s)" if dtype.itemsize > 4:
# Use PyCuda hack!
# https://github.com/inducer/pycuda/blob/master/pycuda/cuda/pycuda-helpers.hpp
template = "fp_tex%iD(%s, %s)"
else:
template = "tex%iD(%s, %s)"
code = template % ( code = template % (
node.texture.field.spatial_dimensions, node.texture.field.spatial_dimensions,
str(node.texture.field.dtype),
str(node.texture), str(node.texture),
', '.join(self._print(o) for o in node.offsets) # + 0.5 comes from Nvidia's staggered indexing
', '.join(self._print(o + 0.5) for o in reversed(node.offsets))
) )
return code return code
......
...@@ -45,6 +45,7 @@ tex1D ...@@ -45,6 +45,7 @@ tex1D
tex2D tex2D
tex3D tex3D
sqrtf
rsqrtf rsqrtf
cbrtf cbrtf
rcbrtf rcbrtf
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import json
from pystencils.backends.cbackend import CustomSympyPrinter, generate_c
try:
import toml
except Exception:
class toml:
def dumps(self, *args):
raise ImportError('toml not installed')
def dump(self, *args):
raise ImportError('toml not installed')
try:
import yaml
except Exception:
class yaml:
def dumps(self, *args):
raise ImportError('pyyaml not installed')
def dump(self, *args):
raise ImportError('pyyaml not installed')
def expr_to_dict(expr_or_node, with_c_code=True, full_class_names=False):
self = {'str': str(expr_or_node)}
if with_c_code:
try:
self.update({'c': generate_c(expr_or_node)})
except Exception:
try:
self.update({'c': CustomSympyPrinter().doprint(expr_or_node)})
except Exception:
pass
for a in expr_or_node.args:
self.update({str(a.__class__ if full_class_names else a.__class__.__name__): expr_to_dict(a)})
return self
def print_json(expr_or_node):
dict = expr_to_dict(expr_or_node)
return json.dumps(dict, indent=4)
def write_json(filename, expr_or_node):
dict = expr_to_dict(expr_or_node)
with open(filename, 'w') as f:
json.dump(dict, f, indent=4)
def print_toml(expr_or_node):
dict = expr_to_dict(expr_or_node, full_class_names=False)
return toml.dumps(dict)
def write_toml(filename, expr_or_node):
dict = expr_to_dict(expr_or_node)
with open(filename, 'w') as f:
toml.dump(dict, f)
def print_yaml(expr_or_node):
dict = expr_to_dict(expr_or_node, full_class_names=False)
return yaml.dump(dict)
def write_yaml(filename, expr_or_node):
dict = expr_to_dict(expr_or_node)
with open(filename, 'w') as f:
yaml.dump(dict, f)
...@@ -10,8 +10,8 @@ from pystencils.data_types import BasicType, StructType, TypedSymbol, create_typ ...@@ -10,8 +10,8 @@ from pystencils.data_types import BasicType, StructType, TypedSymbol, create_typ
from pystencils.field import Field, FieldType from pystencils.field import Field, FieldType
from pystencils.transformations import ( from pystencils.transformations import (
add_types, filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering, add_types, filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering,
make_loop_over_domain, move_constants_before_loop, parse_base_pointer_info, implement_interpolations, make_loop_over_domain, move_constants_before_loop,
resolve_buffer_accesses, resolve_field_accesses, split_inner_loop) parse_base_pointer_info, resolve_buffer_accesses, resolve_field_accesses, split_inner_loop)
AssignmentOrAstNodeList = List[Union[Assignment, ast.Node]] AssignmentOrAstNodeList = List[Union[Assignment, ast.Node]]
...@@ -67,6 +67,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke ...@@ -67,6 +67,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
ghost_layers=ghost_layers, loop_order=loop_order) ghost_layers=ghost_layers, loop_order=loop_order)
ast_node = KernelFunction(loop_node, 'cpu', 'c', compile_function=make_python_function, ast_node = KernelFunction(loop_node, 'cpu', 'c', compile_function=make_python_function,
ghost_layers=ghost_layer_info, function_name=function_name) ghost_layers=ghost_layer_info, function_name=function_name)
implement_interpolations(body)
if split_groups: if split_groups:
typed_split_groups = [[type_symbol(s) for s in split_group] for split_group in split_groups] typed_split_groups = [[type_symbol(s) for s in split_group] for split_group in split_groups]
...@@ -139,6 +140,8 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, fu ...@@ -139,6 +140,8 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, fu
loop_body = Block([]) loop_body = Block([])
loop_node = LoopOverCoordinate(loop_body, coordinate_to_loop_over=0, start=0, stop=index_fields[0].shape[0]) loop_node = LoopOverCoordinate(loop_body, coordinate_to_loop_over=0, start=0, stop=index_fields[0].shape[0])
implement_interpolations(loop_node)
for assignment in assignments: for assignment in assignments:
loop_body.append(assignment) loop_body.append(assignment)
......
import ctypes import ctypes
from collections import defaultdict from collections import defaultdict
from functools import partial from functools import partial
from typing import Tuple
import numpy as np import numpy as np
import sympy as sp import sympy as sp
import sympy.codegen.ast
from sympy.core.cache import cacheit from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean from sympy.logic.boolalg import Boolean
import pystencils
from pystencils.cache import memorycache, memorycache_if_hashable from pystencils.cache import memorycache, memorycache_if_hashable
from pystencils.utils import all_equal from pystencils.utils import all_equal
...@@ -17,6 +20,26 @@ except ImportError as e: ...@@ -17,6 +20,26 @@ except ImportError as e:
_ir_importerror = e _ir_importerror = e
def typed_symbols(names, dtype, *args):
symbols = sp.symbols(names, *args)
if isinstance(symbols, Tuple):
return tuple(TypedSymbol(str(s), dtype) for s in symbols)
else:
return TypedSymbol(str(symbols), dtype)
def matrix_symbols(names, dtype, rows, cols):
if isinstance(names, str):
names = names.replace(' ', '').split(',')
matrices = []
for n in names:
symbols = typed_symbols("%s:%i" % (n, rows * cols), dtype)
matrices.append(sp.Matrix(rows, cols, lambda i, j: symbols[i * cols + j]))
return tuple(matrices)
# noinspection PyPep8Naming # noinspection PyPep8Naming
class address_of(sp.Function): class address_of(sp.Function):
is_Atom = True is_Atom = True
...@@ -86,6 +109,11 @@ class cast_func(sp.Function): ...@@ -86,6 +109,11 @@ class cast_func(sp.Function):
@property @property
def is_integer(self): def is_integer(self):
"""
Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate
For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
"""
if hasattr(self.dtype, 'numpy_dtype'): if hasattr(self.dtype, 'numpy_dtype'):
return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
else: else:
...@@ -93,6 +121,9 @@ class cast_func(sp.Function): ...@@ -93,6 +121,9 @@ class cast_func(sp.Function):
@property @property
def is_negative(self): def is_negative(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if hasattr(self.dtype, 'numpy_dtype'): if hasattr(self.dtype, 'numpy_dtype'):
if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger): if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
return False return False
...@@ -101,6 +132,9 @@ class cast_func(sp.Function): ...@@ -101,6 +132,9 @@ class cast_func(sp.Function):
@property @property
def is_nonnegative(self): def is_nonnegative(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if self.is_negative is False: if self.is_negative is False:
return True return True
else: else:
...@@ -108,6 +142,9 @@ class cast_func(sp.Function): ...@@ -108,6 +142,9 @@ class cast_func(sp.Function):
@property @property
def is_real(self): def is_real(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if hasattr(self.dtype, 'numpy_dtype'): if hasattr(self.dtype, 'numpy_dtype'):
return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \ return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
np.issubdtype(self.dtype.numpy_dtype, np.floating) or \ np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
...@@ -171,6 +208,11 @@ class TypedSymbol(sp.Symbol): ...@@ -171,6 +208,11 @@ class TypedSymbol(sp.Symbol):
# For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html # For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
@property @property
def is_integer(self): def is_integer(self):
"""
Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate
For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
"""
if hasattr(self.dtype, 'numpy_dtype'): if hasattr(self.dtype, 'numpy_dtype'):
return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
else: else:
...@@ -178,6 +220,9 @@ class TypedSymbol(sp.Symbol): ...@@ -178,6 +220,9 @@ class TypedSymbol(sp.Symbol):
@property @property
def is_negative(self): def is_negative(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if hasattr(self.dtype, 'numpy_dtype'): if hasattr(self.dtype, 'numpy_dtype'):
if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger): if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
return False return False
...@@ -186,6 +231,9 @@ class TypedSymbol(sp.Symbol): ...@@ -186,6 +231,9 @@ class TypedSymbol(sp.Symbol):
@property @property
def is_nonnegative(self): def is_nonnegative(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if self.