Commit 53caa7e0 authored by Martin Bauer's avatar Martin Bauer
Browse files

pystencils: systematic checks for conditions on kernel assignments

- SSA form is checked
- loop independence condition is checked
- bug fix in Field.create_generic when using index_shape
parent 3e7154dc
...@@ -2,7 +2,7 @@ import sympy as sp ...@@ -2,7 +2,7 @@ import sympy as sp
from functools import partial from functools import partial
from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
from pystencils.transformations import resolve_buffer_accesses, resolve_field_accesses, make_loop_over_domain, \ from pystencils.transformations import resolve_buffer_accesses, resolve_field_accesses, make_loop_over_domain, \
type_all_equations, get_optimal_loop_ordering, parse_base_pointer_info, move_constants_before_loop, \ add_types, get_optimal_loop_ordering, parse_base_pointer_info, move_constants_before_loop, \
split_inner_loop, substitute_array_accesses_with_constants split_inner_loop, substitute_array_accesses_with_constants
from pystencils.data_types import TypedSymbol, BasicType, StructType, create_type from pystencils.data_types import TypedSymbol, BasicType, StructType, create_type
from pystencils.field import Field, FieldType from pystencils.field import Field, FieldType
...@@ -15,7 +15,8 @@ AssignmentOrAstNodeList = List[Union[Assignment, ast.Node]] ...@@ -15,7 +15,8 @@ AssignmentOrAstNodeList = List[Union[Assignment, ast.Node]]
def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "kernel", type_info='double', def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "kernel", type_info='double',
split_groups=(), iteration_slice=None, ghost_layers=None) -> KernelFunction: split_groups=(), iteration_slice=None, ghost_layers=None,
skip_independence_check=False) -> KernelFunction:
""" """
Creates an abstract syntax tree for a kernel function, by taking a list of update rules. Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
...@@ -34,6 +35,8 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke ...@@ -34,6 +35,8 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers
if None, the number of ghost layers is determined automatically and assumed to be equal for a if None, the number of ghost layers is determined automatically and assumed to be equal for a
all dimensions all dimensions
skip_independence_check: don't check that loop iterations are independent. This is needed e.g. for
periodicity kernel, that access the field outside the iteration bounds. Use with care!
Returns: Returns:
AST node representing a function, that can be printed as C or CUDA code AST node representing a function, that can be printed as C or CUDA code
...@@ -50,7 +53,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke ...@@ -50,7 +53,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
else: else:
raise ValueError("Term has to be field access or symbol") raise ValueError("Term has to be field access or symbol")
fields_read, fields_written, assignments = type_all_equations(assignments, type_info) fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
all_fields = fields_read.union(fields_written) all_fields = fields_read.union(fields_written)
read_only_fields = set([f.name for f in fields_read - fields_written]) read_only_fields = set([f.name for f in fields_read - fields_written])
...@@ -108,7 +111,7 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, fu ...@@ -108,7 +111,7 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, fu
function_name: see documentation of :func:`create_kernel` function_name: see documentation of :func:`create_kernel`
coordinate_names: name of the coordinate fields in the struct data type coordinate_names: name of the coordinate fields in the struct data type
""" """
fields_read, fields_written, assignments = type_all_equations(assignments, type_info) fields_read, fields_written, assignments = add_types(assignments, type_info, check_independence_condition=False)
all_fields = fields_read.union(fields_written) all_fields = fields_read.union(fields_written)
for index_field in index_fields: for index_field in index_fields:
......
...@@ -165,6 +165,9 @@ class Field: ...@@ -165,6 +165,9 @@ class Field:
that should be iterated over, and BUFFER fields that are used to generate that should be iterated over, and BUFFER fields that are used to generate
communication packing/unpacking kernels communication packing/unpacking kernels
""" """
if index_shape is not None:
assert index_dimensions == 0 or index_dimensions == len(index_shape)
index_dimensions = len(index_shape)
if isinstance(layout, str): if isinstance(layout, str):
layout = spatial_layout_string_to_tuple(layout, dim=spatial_dimensions) layout = spatial_layout_string_to_tuple(layout, dim=spatial_dimensions)
shape_symbol = IndexedBase(TypedSymbol(Field.SHAPE_PREFIX + field_name, Field.SHAPE_DTYPE), shape=(1,)) shape_symbol = IndexedBase(TypedSymbol(Field.SHAPE_PREFIX + field_name, Field.SHAPE_DTYPE), shape=(1,))
...@@ -260,6 +263,7 @@ class Field: ...@@ -260,6 +263,7 @@ class Field:
"""Do not use directly. Use static create* methods""" """Do not use directly. Use static create* methods"""
self._field_name = field_name self._field_name = field_name
assert isinstance(field_type, FieldType) assert isinstance(field_type, FieldType)
assert len(shape) == len(strides)
self.field_type = field_type self.field_type = field_type
self._dtype = create_type(dtype) self._dtype = create_type(dtype)
self._layout = normalize_layout(layout) self._layout = normalize_layout(layout)
......
from functools import partial from functools import partial
from pystencils.gpucuda.indexing import BlockIndexing from pystencils.gpucuda.indexing import BlockIndexing
from pystencils.transformations import resolve_field_accesses, type_all_equations, parse_base_pointer_info, \ from pystencils.transformations import resolve_field_accesses, add_types, parse_base_pointer_info, \
get_common_shape, substitute_array_accesses_with_constants, resolve_buffer_accesses, unify_shape_symbols get_common_shape, substitute_array_accesses_with_constants, resolve_buffer_accesses, unify_shape_symbols
from pystencils.astnodes import Block, KernelFunction, SympyAssignment, LoopOverCoordinate from pystencils.astnodes import Block, KernelFunction, SympyAssignment, LoopOverCoordinate
from pystencils.data_types import TypedSymbol, BasicType, StructType from pystencils.data_types import TypedSymbol, BasicType, StructType
...@@ -10,8 +10,8 @@ from pystencils.gpucuda.cudajit import make_python_function ...@@ -10,8 +10,8 @@ from pystencils.gpucuda.cudajit import make_python_function
def create_cuda_kernel(assignments, function_name="kernel", type_info=None, indexing_creator=BlockIndexing, def create_cuda_kernel(assignments, function_name="kernel", type_info=None, indexing_creator=BlockIndexing,
iteration_slice=None, ghost_layers=None): iteration_slice=None, ghost_layers=None, skip_independence_check=False):
fields_read, fields_written, assignments = type_all_equations(assignments, type_info) fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
all_fields = fields_read.union(fields_written) all_fields = fields_read.union(fields_written)
read_only_fields = set([f.name for f in fields_read - fields_written]) read_only_fields = set([f.name for f in fields_read - fields_written])
...@@ -93,7 +93,7 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde ...@@ -93,7 +93,7 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde
def created_indexed_cuda_kernel(assignments, index_fields, function_name="kernel", type_info=None, def created_indexed_cuda_kernel(assignments, index_fields, function_name="kernel", type_info=None,
coordinate_names=('x', 'y', 'z'), indexing_creator=BlockIndexing): coordinate_names=('x', 'y', 'z'), indexing_creator=BlockIndexing):
fields_read, fields_written, assignments = type_all_equations(assignments, type_info) fields_read, fields_written, assignments = add_types(assignments, type_info, check_independence_condition=False)
all_fields = fields_read.union(fields_written) all_fields = fields_read.union(fields_written)
read_only_fields = set([f.name for f in fields_read - fields_written]) read_only_fields = set([f.name for f in fields_read - fields_written])
......
...@@ -23,7 +23,7 @@ def create_copy_kernel(domain_size, from_slice, to_slice, index_dimensions=0, in ...@@ -23,7 +23,7 @@ def create_copy_kernel(domain_size, from_slice, to_slice, index_dimensions=0, in
eq = Assignment(f(i), f[tuple(offset)](i)) eq = Assignment(f(i), f[tuple(offset)](i))
update_eqs.append(eq) update_eqs.append(eq)
ast = create_cuda_kernel(update_eqs, iteration_slice=to_slice) ast = create_cuda_kernel(update_eqs, iteration_slice=to_slice, skip_independence_check=True)
return make_python_function(ast) return make_python_function(ast)
......
...@@ -12,27 +12,28 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice ...@@ -12,27 +12,28 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice
gpu_indexing='block', gpu_indexing_params=MappingProxyType({})): gpu_indexing='block', gpu_indexing_params=MappingProxyType({})):
""" """
Creates abstract syntax tree (AST) of kernel, using a list of update equations. Creates abstract syntax tree (AST) of kernel, using a list of update equations.
:param assignments: either be a plain list of equations or a AssignmentCollection object
:param target: 'cpu', 'llvm' or 'gpu' Args:
:param data_type: data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name assignments: either be a plain list of equations or a AssignmentCollection object
to type target: 'cpu', 'llvm' or 'gpu'
:param iteration_slice: rectangular subset to iterate over, if not specified the complete non-ghost layer \ data_type: data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name
part of the field is iterated over to type
:param ghost_layers: if left to default, the number of necessary ghost layers is determined automatically iteration_slice: rectangular subset to iterate over, if not specified the complete non-ghost layer \
a single integer specifies the ghost layer count at all borders, can also be a sequence of part of the field is iterated over
pairs [(x_lower_gl, x_upper_gl), .... ] ghost_layers: if left to default, the number of necessary ghost layers is determined automatically
a single integer specifies the ghost layer count at all borders, can also be a sequence of
CPU specific Parameters: pairs [(x_lower_gl, x_upper_gl), .... ]
:param cpu_openmp: True or number of threads for OpenMP parallelization, False for no OpenMP
:param cpu_vectorize_info: pair of instruction set name ('sse, 'avx', 'avx512') and data type ('float', 'double') cpu_openmp: True or number of threads for OpenMP parallelization, False for no OpenMP
cpu_vectorize_info: pair of instruction set name ('sse, 'avx', 'avx512') and data type ('float', 'double')
GPU specific Parameters
:param gpu_indexing: either 'block' or 'line' , or custom indexing class (see gpucuda/indexing.py) gpu_indexing: either 'block' or 'line' , or custom indexing class (see gpucuda/indexing.py)
:param gpu_indexing_params: dict with indexing parameters (constructor parameters of indexing class) gpu_indexing_params: dict with indexing parameters (constructor parameters of indexing class)
e.g. for 'block' one can specify {'block_size': (20, 20, 10) } e.g. for 'block' one can specify {'block_size': (20, 20, 10) }
:return: abstract syntax tree object, that can either be printed as source code or can be compiled with Returns:
through its compile() function abstract syntax tree object, that can either be printed as source code with `show_code` or can be compiled with
through its `compile()` member
""" """
# ---- Normalizing parameters # ---- Normalizing parameters
...@@ -124,8 +125,9 @@ def create_staggered_kernel(staggered_field, expressions, subexpressions=(), tar ...@@ -124,8 +125,9 @@ def create_staggered_kernel(staggered_field, expressions, subexpressions=(), tar
subexpressions: optional sequence of Assignments, that define subexpressions used in the main expressions subexpressions: optional sequence of Assignments, that define subexpressions used in the main expressions
target: 'cpu' or 'gpu' target: 'cpu' or 'gpu'
kwargs: passed directly to create_kernel, iteration slice and ghost_layers parameters are not allowed kwargs: passed directly to create_kernel, iteration slice and ghost_layers parameters are not allowed
Returns: Returns:
AST AST, see `create_kernel`
""" """
assert 'iteration_slice' not in kwargs and 'ghost_layers' not in kwargs assert 'iteration_slice' not in kwargs and 'ghost_layers' not in kwargs
assert staggered_field.index_dimensions == 1, 'Staggered field must have exactly one index dimension' assert staggered_field.index_dimensions == 1, 'Staggered field must have exactly one index dimension'
......
import warnings import warnings
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict, namedtuple
from copy import deepcopy from copy import deepcopy
from types import MappingProxyType from types import MappingProxyType
import sympy as sp import sympy as sp
...@@ -139,17 +139,21 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer ...@@ -139,17 +139,21 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr): def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
r""" r"""
Addressing elements in structured arrays are done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]` Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate. where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
The sum can be split up into multiple parts, such that parts of it can be pulled before loops. The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`. This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
Returns a new typed symbol, where the name encodes which coordinates have been resolved. Returns a new typed symbol, where the name encodes which coordinates have been resolved.
:param field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
:param coordinates: mapping of coordinate ids to its value, where stride*value is calculated
:param previous_ptr: the pointer which is de-referenced
:return: tuple with the new pointer symbol and the calculated offset
Example: Args:
field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
coordinates: mapping of coordinate ids to its value, where stride*value is calculated
previous_ptr: the pointer which is de-referenced
Returns
tuple with the new pointer symbol and the calculated offset
Examples:
>>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1) >>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1)
>>> x, y = sp.symbols("x y") >>> x, y = sp.symbols("x y")
>>> prev_pointer = TypedSymbol("ptr", "double") >>> prev_pointer = TypedSymbol("ptr", "double")
...@@ -193,7 +197,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, field): ...@@ -193,7 +197,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, field):
Specification of how many and which intermediate pointers are created for a field access. Specification of how many and which intermediate pointers are created for a field access.
For example [ (0), (2,3,)] creates on base pointer for coordinates 2 and 3 and writes the offset for coordinate For example [ (0), (2,3,)] creates on base pointer for coordinates 2 and 3 and writes the offset for coordinate
zero directly in the field access. These specifications are more sensible defined dependent on the loop ordering. zero directly in the field access. These specifications are defined dependent on the loop ordering.
This function translates more readable version into the specification above. This function translates more readable version into the specification above.
Allowed specifications: Allowed specifications:
...@@ -362,13 +366,16 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -362,13 +366,16 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
""" """
Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing
:param ast_node: the AST root Args:
:param read_only_field_names: set of field names which are considered read-only ast_node: the AST root
:param field_to_base_pointer_info: a list of tuples indicating which intermediate base pointers should be created read_only_field_names: set of field names which are considered read-only
for details see :func:`parse_base_pointer_info` field_to_base_pointer_info: a list of tuples indicating which intermediate base pointers should be created
:param field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop for details see :func:`parse_base_pointer_info`
field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
counters to index the field these symbols are used as coordinates counters to index the field these symbols are used as coordinates
:return: transformed AST
Returns
transformed AST
""" """
field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0])) field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0])) field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))
...@@ -393,8 +400,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -393,8 +400,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if field.name in field_to_fixed_coordinates: if field.name in field_to_fixed_coordinates:
coordinates[e] = field_to_fixed_coordinates[field.name][e] coordinates[e] = field_to_fixed_coordinates[field.name][e]
else: else:
ctr_name = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
coordinates[e] = TypedSymbol("%s_%d" % (ctr_name, e), 'int')
coordinates[e] *= field.dtype.item_size coordinates[e] *= field.dtype.item_size
else: else:
if isinstance(field.dtype, StructType): if isinstance(field.dtype, StructType):
...@@ -418,7 +424,6 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -418,7 +424,6 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
last_pointer = new_ptr last_pointer = new_ptr
coord_dict = create_coordinate_dict(base_pointer_info[0]) coord_dict = create_coordinate_dict(base_pointer_info[0])
_, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer) _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field, result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
field_access.offsets, field_access.index) field_access.offsets, field_access.index)
...@@ -652,68 +657,125 @@ def symbol_name_to_variable_name(symbol_name): ...@@ -652,68 +657,125 @@ def symbol_name_to_variable_name(symbol_name):
return symbol_name.replace("^", "_") return symbol_name.replace("^", "_")
def type_all_equations(eqs, type_for_symbol): class KernelConstraintsCheck:
"""Checks if the input to create_kernel is valid.
Test the following conditions:
- SSA Form for pure symbols:
- Every pure symbol may occur only once as left-hand-side of an assignment
- Every pure symbol that is read, may not be written to later
- Independence / Parallelization condition:
- a field that is written may only be read at exact the same spatial position
(Pure symbols are symbols that are not Field.Accesses)
""" """
Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`. FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])
def __init__(self, type_for_symbol, check_independence_condition):
self._type_for_symbol = type_for_symbol
self._defined_pure_symbols = set()
self._accessed_pure_symbols = set()
self._field_writes = defaultdict(set)
self.fields_read = set()
self.check_independence_condition = check_independence_condition
def process_assignment(self, assignment):
# for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
new_rhs = self.process_expression(assignment.rhs)
new_lhs = self._process_lhs(assignment.lhs)
return ast.SympyAssignment(new_lhs, new_rhs)
def process_expression(self, rhs):
self._update_accesses_rhs(rhs)
if isinstance(rhs, Field.Access):
return rhs
elif isinstance(rhs, TypedSymbol):
return rhs
elif isinstance(rhs, sp.Symbol):
return TypedSymbol(symbol_name_to_variable_name(rhs.name), self._type_for_symbol[rhs.name])
else:
new_args = [self.process_expression(arg) for arg in rhs.args]
return rhs.func(*new_args) if new_args else rhs
@property
def fields_written(self):
return set(k.field for k, v in self._field_writes.items() if len(v))
def _process_lhs(self, lhs):
assert isinstance(lhs, sp.Symbol)
self._update_accesses_lhs(lhs)
if not isinstance(lhs, Field.Access) and not isinstance(lhs, TypedSymbol):
return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
else:
return lhs
def _update_accesses_lhs(self, lhs):
if isinstance(lhs, Field.Access):
fai = self.FieldAndIndex(lhs.field, lhs.index)
self._field_writes[fai].add(lhs.offsets)
if len(self._field_writes[fai]) > 1:
raise ValueError(f"Field {lhs.field.name} is written at two different locations")
elif isinstance(lhs, sp.Symbol):
if lhs in self._defined_pure_symbols:
raise ValueError(f"Assignments not in SSA form, multiple assignments to {lhs.name}")
if lhs in self._accessed_pure_symbols:
raise ValueError(f"Symbol {lhs.name} is written, after it has been read")
self._defined_pure_symbols.add(lhs)
def _update_accesses_rhs(self, rhs):
if isinstance(rhs, Field.Access) and self.check_independence_condition:
writes = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)]
for write_offset in writes:
assert len(writes) == 1
if write_offset != rhs.offsets:
raise ValueError(f"Violation of loop independence condition. "
f"Field {rhs.field} is read at {rhs.offsets} and written at {write_offset}")
self.fields_read.add(rhs.field)
elif isinstance(rhs, sp.Symbol):
self._accessed_pure_symbols.add(rhs)
def add_types(eqs, type_for_symbol, check_independence_condition):
"""Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
Additionally returns sets of all fields which are read/written Additionally returns sets of all fields which are read/written
:param eqs: list of equations Args:
:param type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double' eqs: list of equations
:return: ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields, type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'
list of equations where symbols have been replaced by typed symbols check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed
kernels
Returns:
``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
list of equations where symbols have been replaced by typed symbols
""" """
if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'): if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'):
type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol) type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
fields_written = set() check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)
fields_read = set()
def process_rhs(term):
"""Replaces Symbols by:
- TypedSymbol if symbol is not a field access
"""
if isinstance(term, Field.Access):
fields_read.add(term.field)
return term
elif isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
return TypedSymbol(symbol_name_to_variable_name(term.name), type_for_symbol[term.name])
else:
new_args = [process_rhs(arg) for arg in term.args]
return term.func(*new_args) if new_args else term
def process_lhs(term):
"""Replaces symbol by TypedSymbol and adds field to fieldsWriten"""
if isinstance(term, Field.Access):
fields_written.add(term.field)
return term
elif isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
return TypedSymbol(term.name, type_for_symbol[term.name])
else:
assert False, "Expected a symbol as left-hand-side"
def visit(obj): def visit(obj):
if isinstance(obj, list) or isinstance(obj, tuple): if isinstance(obj, list) or isinstance(obj, tuple):
return [visit(e) for e in obj] return [visit(e) for e in obj]
if isinstance(obj, sp.Eq) or isinstance(obj, ast.SympyAssignment) or isinstance(obj, Assignment): if isinstance(obj, sp.Eq) or isinstance(obj, ast.SympyAssignment) or isinstance(obj, Assignment):
new_lhs = process_lhs(obj.lhs) return check.process_assignment(obj)
new_rhs = process_rhs(obj.rhs)
return ast.SympyAssignment(new_lhs, new_rhs)
elif isinstance(obj, ast.Conditional): elif isinstance(obj, ast.Conditional):
false_block = None if obj.false_block is None else visit(obj.false_block) false_block = None if obj.false_block is None else visit(obj.false_block)
return ast.Conditional(process_rhs(obj.condition_expr), return ast.Conditional(check.process_expression(obj.condition_expr),
true_block=visit(obj.true_block), false_block=false_block) true_block=visit(obj.true_block), false_block=false_block)
elif isinstance(obj, ast.Block): elif isinstance(obj, ast.Block):
return ast.Block([visit(e) for e in obj.args]) return ast.Block([visit(e) for e in obj.args])
else: elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
return obj return obj
else:
raise ValueError("Invalid object in kernel " + str(type(obj)))
typed_equations = visit(eqs) typed_equations = visit(eqs)
return fields_read, fields_written, typed_equations return check.fields_read, check.fields_written, typed_equations
def insert_casts(node): def insert_casts(node):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment