Skip to content
Snippets Groups Projects
Commit 32b97ae5 authored by Martin Bauer's avatar Martin Bauer
Browse files

Documentation cleanup part1

parent 956c89a0
No related merge requests found
...@@ -7,12 +7,12 @@ from pystencils.gpucuda.indexing import indexing_creator_from_params ...@@ -7,12 +7,12 @@ from pystencils.gpucuda.indexing import indexing_creator_from_params
from pystencils.transformations import remove_conditionals_in_staggered_kernel from pystencils.transformations import remove_conditionals_in_staggered_kernel
def create_kernel(equations, target='cpu', data_type="double", iteration_slice=None, ghost_layers=None, def create_kernel(assignments, target='cpu', data_type="double", iteration_slice=None, ghost_layers=None,
cpu_openmp=False, cpu_vectorize_info=None, cpu_openmp=False, cpu_vectorize_info=None,
gpu_indexing='block', gpu_indexing_params=MappingProxyType({})): gpu_indexing='block', gpu_indexing_params=MappingProxyType({})):
""" """
Creates abstract syntax tree (AST) of kernel, using a list of update equations. Creates abstract syntax tree (AST) of kernel, using a list of update equations.
:param equations: either be a plain list of equations or a AssignmentCollection object :param assignments: either be a plain list of equations or a AssignmentCollection object
:param target: 'cpu', 'llvm' or 'gpu' :param target: 'cpu', 'llvm' or 'gpu'
:param data_type: data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name :param data_type: data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name
to type to type
...@@ -37,16 +37,18 @@ def create_kernel(equations, target='cpu', data_type="double", iteration_slice=N ...@@ -37,16 +37,18 @@ def create_kernel(equations, target='cpu', data_type="double", iteration_slice=N
# ---- Normalizing parameters # ---- Normalizing parameters
split_groups = () split_groups = ()
if isinstance(equations, AssignmentCollection): if isinstance(assignments, AssignmentCollection):
if 'split_groups' in equations.simplification_hints: if 'split_groups' in assignments.simplification_hints:
split_groups = equations.simplification_hints['split_groups'] split_groups = assignments.simplification_hints['split_groups']
equations = equations.all_assignments assignments = assignments.all_assignments
if isinstance(assignments, Assignment):
assignments = [assignments]
# ---- Creating ast # ---- Creating ast
if target == 'cpu': if target == 'cpu':
from pystencils.cpu import create_kernel from pystencils.cpu import create_kernel
from pystencils.cpu import add_openmp from pystencils.cpu import add_openmp
ast = create_kernel(equations, type_info=data_type, split_groups=split_groups, ast = create_kernel(assignments, type_info=data_type, split_groups=split_groups,
iteration_slice=iteration_slice, ghost_layers=ghost_layers) iteration_slice=iteration_slice, ghost_layers=ghost_layers)
if cpu_openmp: if cpu_openmp:
add_openmp(ast, num_threads=cpu_openmp) add_openmp(ast, num_threads=cpu_openmp)
...@@ -60,12 +62,12 @@ def create_kernel(equations, target='cpu', data_type="double", iteration_slice=N ...@@ -60,12 +62,12 @@ def create_kernel(equations, target='cpu', data_type="double", iteration_slice=N
return ast return ast
elif target == 'llvm': elif target == 'llvm':
from pystencils.llvm import create_kernel from pystencils.llvm import create_kernel
ast = create_kernel(equations, type_info=data_type, split_groups=split_groups, ast = create_kernel(assignments, type_info=data_type, split_groups=split_groups,
iteration_slice=iteration_slice, ghost_layers=ghost_layers) iteration_slice=iteration_slice, ghost_layers=ghost_layers)
return ast return ast
elif target == 'gpu': elif target == 'gpu':
from pystencils.gpucuda import create_cuda_kernel from pystencils.gpucuda import create_cuda_kernel
ast = create_cuda_kernel(equations, type_info=data_type, ast = create_cuda_kernel(assignments, type_info=data_type,
indexing_creator=indexing_creator_from_params(gpu_indexing, gpu_indexing_params), indexing_creator=indexing_creator_from_params(gpu_indexing, gpu_indexing_params),
iteration_slice=iteration_slice, ghost_layers=ghost_layers) iteration_slice=iteration_slice, ghost_layers=ghost_layers)
return ast return ast
......
from pystencils.sympy_gmpy_bug_workaround import *
from pystencils import *
import sympy as sp
import numpy as np
import pystencils.plot2d as plt
from pystencils.jupytersetup import *
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment