Commit 96c04e6b authored by Markus Holzer's avatar Markus Holzer
Browse files

Start with refactoring of the type system

parent ddad5daf
......@@ -29,11 +29,13 @@ def typed_symbols(names, dtype, *args):
def type_all_numbers(expr, dtype):
# TODO: move to pystnecils_walberla
substitutions = {a: cast_func(a, dtype) for a in expr.atoms(sp.Number)}
return expr.subs(substitutions)
def matrix_symbols(names, dtype, rows, cols):
# TODO: check if needed. (lbmpy, walberla)
if isinstance(names, str):
names = names.replace(' ', '').split(',')
......@@ -46,6 +48,7 @@ def matrix_symbols(names, dtype, rows, cols):
def assumptions_from_dtype(dtype):
# TODO: type hints and if dtype is correct type form Numpy
"""Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype
Args:
......@@ -76,6 +79,9 @@ def assumptions_from_dtype(dtype):
# noinspection PyPep8Naming
class address_of(sp.Function):
# TODO: ask Martin
# TODO: documentation
# TODO: move function to `functions.py`
is_Atom = True
def __new__(cls, arg):
......@@ -103,6 +109,8 @@ class address_of(sp.Function):
# noinspection PyPep8Naming
class cast_func(sp.Function):
# TODO: documentation
# TODO: move function to `functions.py`
is_Atom = True
def __new__(cls, *args, **kwargs):
......@@ -190,22 +198,30 @@ class cast_func(sp.Function):
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
# TODO: documentation
# TODO: move function to `functions.py`
pass
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
# TODO: documentation
# TODO: move function to `functions.py`
# Arguments are: read/write expression, type, aligned, nontemporal, mask (or none), stride
nargs = (6,)
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
# TODO: documentation
# TODO: move function to `functions.py`
pass
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
# TODO: documentation
# TODO: move function to `functions.py`
@property
def canonical(self):
if hasattr(self.args[0], 'canonical'):
......@@ -272,6 +288,8 @@ class TypedSymbol(sp.Symbol):
def create_type(specification):
# TODO: HERE
# TODO: type hint -> np.type
"""Creates a subclass of Type according to a string or an object of subclass Type.
Args:
......@@ -292,6 +310,7 @@ def create_type(specification):
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
# TODO: can be removed after llvm removla and fix of kernelparameters
"""Creates a new Type object from a c-like string specification.
Args:
......@@ -338,12 +357,15 @@ def create_composite_type_from_string(specification):
def get_base_type(data_type):
# TODO: WTF is this?? DOCS!!!
# TODO: Can be removed after removal of kerncraft and fix in FieldPointer Symbol
while data_type.base_type is not None:
data_type = data_type.base_type
return data_type
def to_ctypes(data_type):
# TODO: can be removed with llvm
"""
Transforms a given Type into ctypes
:param data_type: Subclass of Type
......@@ -356,7 +378,7 @@ def to_ctypes(data_type):
else:
return to_ctypes.map[data_type.numpy_dtype]
# TODO: can be removed with llvm
to_ctypes.map = {
np.dtype(np.int8): ctypes.c_int8,
np.dtype(np.int16): ctypes.c_int16,
......@@ -374,6 +396,7 @@ to_ctypes.map = {
def ctypes_from_llvm(data_type):
# TODO can be removed with LLVM
if not ir:
raise _ir_importerror
if isinstance(data_type, ir.PointerType):
......@@ -404,6 +427,7 @@ def ctypes_from_llvm(data_type):
def to_llvm_type(data_type, nvvm_target=False):
# TODO: can be removed with LLVM
"""
Transforms a given type into ctypes
:param data_type: Subclass of Type
......@@ -417,6 +441,7 @@ def to_llvm_type(data_type, nvvm_target=False):
return to_llvm_type.map[data_type.numpy_dtype]
# TODO: can be removed with LLVM
if ir:
to_llvm_type.map = {
np.dtype(np.int8): ir.IntType(8),
......@@ -435,16 +460,19 @@ if ir:
def peel_off_type(dtype, type_to_peel_off):
# TODO: WTF is this??? DOCS!!!
# TODO: used only once.... can be a lambda there
while type(dtype) is type_to_peel_off:
dtype = dtype.base_type
return dtype
############################# This is basically our type system ########################################################
def collate_types(types,
forbid_collation_to_complex=False,
forbid_collation_to_float=False,
default_float_type='float64',
default_int_type='int64'):
forbid_collation_to_complex=False, # TODO: type system shouldn't need this!!!
forbid_collation_to_float=False, # TODO: type system shouldn't need this!!!
default_float_type='float64', # TODO: AST leaves should be typed. Expressions should be able to find out correct type
default_int_type='int64'): # TODO: AST leaves should be typed. Expressions should be able to find out correct type
"""
Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
Uses the collation rules from numpy.
......@@ -495,9 +523,9 @@ def collate_types(types,
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
default_float_type='double',
default_int_type='int',
symbol_type_dict=None):
default_float_type='double', # TODO: we shouldn't need to have default. AST leaves should have a type
default_int_type='int', # TODO: we shouldn't need to have default. AST leaves should have a type
symbol_type_dict=None): # TODO: we shouldn't need to have default. AST leaves should have a type
from pystencils.astnodes import ResolvedFieldAccess
from pystencils.cpu.vectorization import vec_all, vec_any
......@@ -582,6 +610,7 @@ def get_type_of_expression(expr,
return create_type(default_float_type)
raise NotImplementedError("Could not determine type for", expr, type(expr))
############################# End This is basically our type system ##################################################
sympy_version = sp.__version__.split('.')
......@@ -614,6 +643,8 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
class Type(sp.Atom):
# TODO: why is our type system dependent on sympy???
# TODO: ask Martin
def __new__(cls, *args, **kwargs):
return sp.Basic.__new__(cls)
......@@ -622,8 +653,15 @@ class Type(sp.Atom):
class BasicType(Type):
# TODO: check if Type inheritance is needed
# TODO: should be a sensible interface to np.dtype
# TODO: read numpy docs (Jan)
@staticmethod
def numpy_name_to_c(name):
# TODO: this should be a free function
# TODO: also check if numpy has this functionality
# TODO: docs!!!
# TODO: is this C?
if name == 'float64':
return 'double'
elif name == 'float32':
......@@ -644,9 +682,10 @@ class BasicType(Type):
raise NotImplementedError(f"Can map numpy to C name for {name}")
def __init__(self, dtype, const=False):
# TODO: type hints
self.const = const
if isinstance(dtype, Type):
self._dtype = dtype.numpy_dtype
self._dtype = dtype.numpy_dtype # TODO: wtf?
else:
self._dtype = np.dtype(dtype)
assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
......@@ -660,7 +699,7 @@ class BasicType(Type):
return (self.numpy_dtype, self.const), {}
@property
def base_type(self):
def base_type(self): # TODO: what is base_type?
return None
@property
......@@ -672,7 +711,7 @@ class BasicType(Type):
return getattr(sympy.codegen.ast, str(self.numpy_dtype))
@property
def item_size(self):
def item_size(self): # TODO: what is this?
return 1
def is_int(self):
......@@ -691,7 +730,7 @@ class BasicType(Type):
return self.numpy_dtype in np.sctypes['others']
@property
def base_name(self):
def base_name(self): # TODO: name of the function is highly confusing
return BasicType.numpy_name_to_c(str(self._dtype))
def __str__(self):
......@@ -714,6 +753,7 @@ class BasicType(Type):
class VectorType(Type):
# TODO: check with rest
instruction_set = None
def __init__(self, base_type, width=4):
......@@ -760,6 +800,7 @@ class VectorType(Type):
class PointerType(Type):
# TODO: rename to FieldType
def __init__(self, base_type, const=False, restrict=True):
self._base_type = base_type
self.const = const
......@@ -805,6 +846,7 @@ class PointerType(Type):
class StructType:
# TODO: Docs. This is a struct. A list of types (with C offsets)
def __init__(self, numpy_type, const=False):
self.const = const
self._dtype = np.dtype(numpy_type)
......@@ -858,6 +900,8 @@ class StructType:
class TypedImaginaryUnit(TypedSymbol):
# TODO: why is this an extra class???
# TODO: remove?
def __new__(cls, *args, **kwds):
obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)
return obj
......
......@@ -31,6 +31,7 @@ def create_copy_kernel(domain_size, from_slice, to_slice, index_dimensions=0, in
return ast
# TODO: type flot is dangerous here
def get_periodic_boundary_functor(stencil, domain_size, index_dimensions=0, index_dim_shape=1, ghost_layers=1,
thickness=None, dtype=float, target=Target.GPU, opencl_queue=None, opencl_ctx=None):
assert target in {Target.GPU, Target.OPENCL}
......
......@@ -38,6 +38,8 @@ class CreateKernelConfig:
"""
Name of the generated function - only important if generated code is written out
"""
# TODO: config should check that the datatype is a Numpy type
# TODO: check for the python types and issue warnings
data_type: Union[str, dict] = 'double'
"""
Data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type
......@@ -125,6 +127,7 @@ class CreateKernelConfig:
def __post_init__(self):
# ---- Legacy parameters
# TODO adapt here the types
if isinstance(self.target, str):
new_target = Target[self.target.upper()]
warnings.warn(f'Target "{self.target}" as str is deprecated. Use {new_target} instead',
......@@ -249,6 +252,7 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
if config.target == Target.CPU:
if config.backend == Backend.C:
from pystencils.cpu import add_openmp, create_kernel
# TODO: data type keyword should be unified to data_type
ast = create_kernel(assignments, function_name=config.function_name, type_info=config.data_type,
split_groups=split_groups,
iteration_slice=config.iteration_slice, ghost_layers=config.ghost_layers,
......
......@@ -18,6 +18,11 @@ from sympy.core.cache import cacheit
from pystencils.data_types import (
PointerType, TypedSymbol, create_composite_type_from_string, get_base_type)
# TODO: Why do we need extra classes? Why isn't TypedSymbol enough?
# TODO: Replace with a factory function
SHAPE_DTYPE = create_composite_type_from_string("const int64")
STRIDE_DTYPE = create_composite_type_from_string("const int64")
......
import hashlib
import pickle
import warnings
from typing import List, Dict
from collections import OrderedDict, defaultdict, namedtuple
from copy import deepcopy
from types import MappingProxyType
......@@ -424,7 +425,7 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s
return visit_node(ast_node)
def resolve_field_accesses(ast_node, read_only_field_names=set(),
def resolve_field_accesses(ast_node, read_only_field_names=None,
field_to_base_pointer_info=MappingProxyType({}),
field_to_fixed_coordinates=MappingProxyType({})):
"""
......@@ -441,6 +442,8 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
Returns
transformed AST
"""
if read_only_field_names is None:
read_only_field_names = set()
field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))
......@@ -936,7 +939,8 @@ class KernelConstraintsCheck:
self.scopes.access_symbol(rhs)
def add_types(eqs, type_for_symbol, check_independence_condition, check_double_write_condition=True):
def add_types(eqs: List[Assignment], type_for_symbol: Dict[sp.Symbol, np.dtype], check_independence_condition: bool,
check_double_write_condition: bool=True):
"""Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
Additionally returns sets of all fields which are read/written
......@@ -956,9 +960,12 @@ def add_types(eqs, type_for_symbol, check_independence_condition, check_double_w
type_for_symbol = adjust_c_single_precision_type(type_for_symbol)
# TODO what does this do????
# TODO: ask Martin
check = KernelConstraintsCheck(type_for_symbol, check_independence_condition,
check_double_write_condition=check_double_write_condition)
# TODO: check if this adds only types to leave nodes of AST, get type info
def visit(obj):
if isinstance(obj, (list, tuple)):
return [visit(e) for e in obj]
......
import sympy as sp
from pystencils import fields, Assignment, AssignmentCollection
from pystencils.simp.subexpression_insertion import *
......
......@@ -180,3 +180,29 @@ def test_ctypes_from_llvm():
assert ctypes_from_llvm(ir.FloatType()) == ctypes.c_float
assert ctypes_from_llvm(ir.DoubleType()) == ctypes.c_double
def test_division():
f = ps.fields('f(10): float32[2D]')
m, tau = sp.symbols("m, tau")
up = [ps.Assignment(tau, 1.0 / (0.5 + (3.0 * m))),
ps.Assignment(f.center, tau)]
ast = ps.create_kernel(up, config=ps.CreateKernelConfig(data_type="float32"))
code = ps.get_code_str(ast)
assert "1.0f" in code
def test_pow():
f = ps.fields('f(10): float32[2D]')
m, tau = sp.symbols("m, tau")
up = [ps.Assignment(tau, m ** 1.5),
ps.Assignment(f.center, tau)]
ast = ps.create_kernel(up, config=ps.CreateKernelConfig(data_type="float32"))
code = ps.get_code_str(ast)
assert "1.5f" in code
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment