Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 545 additions and 238 deletions
from pystencils.typing import CFunction
def get_argument_string(function_shortcut, first=''): def get_argument_string(function_shortcut, first=''):
args = function_shortcut[function_shortcut.index('[') + 1: -1] args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "(" arg_string = "("
...@@ -16,10 +19,13 @@ def get_argument_string(function_shortcut, first=''): ...@@ -16,10 +19,13 @@ def get_argument_string(function_shortcut, first=''):
def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
if instruction_set != 'neon' and not instruction_set.startswith('sve'): if instruction_set not in ['neon', 'sme'] and not instruction_set.startswith('sve'):
raise NotImplementedError(instruction_set) raise NotImplementedError(instruction_set)
if instruction_set == 'sve': if instruction_set in ['sve', 'sve2', 'sme']:
cmp = 'cmp'
elif instruction_set.startswith('sve2') and instruction_set not in ('sve256', 'sve2048'):
cmp = 'cmp' cmp = 'cmp'
bitwidth = int(instruction_set[4:])
elif instruction_set.startswith('sve'): elif instruction_set.startswith('sve'):
cmp = 'cmp' cmp = 'cmp'
bitwidth = int(instruction_set[3:]) bitwidth = int(instruction_set[3:])
...@@ -35,9 +41,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): ...@@ -35,9 +41,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
'sqrt': 'sqrt[0]', 'sqrt': 'sqrt[0]',
'loadU': 'ld1[0]', 'loadU': 'ld1[0]',
'loadA': 'ld1[0]',
'storeU': 'st1[0, 1]', 'storeU': 'st1[0, 1]',
'storeA': 'st1[0, 1]',
'abs': 'abs[0]', 'abs': 'abs[0]',
'==': f'{cmp}eq[0, 1]', '==': f'{cmp}eq[0, 1]',
...@@ -54,7 +58,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): ...@@ -54,7 +58,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result = dict() result = dict()
if instruction_set == 'sve': if instruction_set in ['sve', 'sve2', 'sme']:
width = 'svcntd()' if data_type == 'double' else 'svcntw()' width = 'svcntd()' if data_type == 'double' else 'svcntw()'
intwidth = 'svcntw()' intwidth = 'svcntw()'
result['bytes'] = 'svcntb()' result['bytes'] = 'svcntb()'
...@@ -62,14 +66,15 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): ...@@ -62,14 +66,15 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
width = bitwidth // bits[data_type] width = bitwidth // bits[data_type]
intwidth = bitwidth // bits['int'] intwidth = bitwidth // bits['int']
result['bytes'] = bitwidth // 8 result['bytes'] = bitwidth // 8
if instruction_set.startswith('sve'): if instruction_set.startswith('sve') or instruction_set == 'sme':
base_names['stream'] = 'stnt1[0, 1]'
prefix = 'sv' prefix = 'sv'
suffix = f'_f{bits[data_type]}' suffix = f'_f{bits[data_type]}'
elif instruction_set == 'neon': elif instruction_set == 'neon':
prefix = 'v' prefix = 'v'
suffix = f'q_f{bits[data_type]}' suffix = f'q_f{bits[data_type]}'
if instruction_set == 'sve': if instruction_set in ['sve', 'sve2', 'sme']:
predicate = f'{prefix}whilelt_b{bits[data_type]}_u64({{loop_counter}}, {{loop_stop}})' predicate = f'{prefix}whilelt_b{bits[data_type]}_u64({{loop_counter}}, {{loop_stop}})'
int_predicate = f'{prefix}whilelt_b{bits["int"]}_u64({{loop_counter}}, {{loop_stop}})' int_predicate = f'{prefix}whilelt_b{bits["int"]}_u64({{loop_counter}}, {{loop_stop}})'
else: else:
...@@ -88,33 +93,36 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): ...@@ -88,33 +93,36 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result[intrinsic_id] = prefix + name + suffix + undef + arg_string result[intrinsic_id] = prefix + name + suffix + undef + arg_string
if instruction_set == 'sve': if instruction_set in ['sve', 'sve2', 'sme']:
from pystencils.backends.cbackend import CFunction
result['width'] = CFunction(width, "int") result['width'] = CFunction(width, "int")
result['intwidth'] = CFunction(intwidth, "int") result['intwidth'] = CFunction(intwidth, "int")
else: else:
result['width'] = width result['width'] = width
result['intwidth'] = intwidth result['intwidth'] = intwidth
if instruction_set.startswith('sve'): if instruction_set.startswith('sve') or instruction_set == 'sme':
result['makeVecConst'] = f'svdup_f{bits[data_type]}' + '({0})' result['makeVecConst'] = f'svdup_f{bits[data_type]}' + '({0})'
result['makeVecConstInt'] = f'svdup_s{bits["int"]}' + '({0})' result['makeVecConstInt'] = f'svdup_s{bits["int"]}' + '({0})'
result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})' result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})'
vindex = f'svindex_u{bits[data_type]}(0, {{0}})' if instruction_set != 'sme':
result['storeS'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \ vindex = f'svindex_u{bits[data_type]}(0, {{0}})'
vindex.format("{2}") + ', {1})' result['storeS'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
result['loadS'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \ vindex.format("{2}") + ', {1})'
vindex.format("{1}") + ')' result['loadS'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{1}") + ')'
if instruction_set.startswith('sve2') and instruction_set not in ('sve256', 'sve2048'):
result['streamS'] = f'svstnt1_scatter_u{bits[data_type]}offset_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format(f"{{2}}*{bits[data_type]//8}") + ', {1})'
result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})" result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})"
result['float'] = f'svfloat{bits["float"]}_{"s" if instruction_set != "sve" else ""}t' result['float'] = f'svfloat{bits["float"]}_{"s" if instruction_set not in ["sve", "sve2", "sme"] else ""}t'
result['double'] = f'svfloat{bits["double"]}_{"s" if instruction_set != "sve" else ""}t' result['double'] = f'svfloat{bits["double"]}_{"s" if instruction_set not in ["sve", "sve2", "sme"] else ""}t'
result['int'] = f'svint{bits["int"]}_{"s" if instruction_set != "sve" else ""}t' result['int'] = f'svint{bits["int"]}_{"s" if instruction_set not in ["sve", "sve2", "sme"] else ""}t'
result['bool'] = f'svbool_{"s" if instruction_set != "sve" else ""}t' result['bool'] = f'svbool_{"s" if instruction_set not in ["sve", "sve2", "sme"] else ""}t'
result['headers'] = ['<arm_sve.h>', '"arm_neon_helpers.h"'] result['headers'] = ['<arm_sve.h>', '<arm_acle.h>', '"arm_neon_helpers.h"']
result['&'] = f'svand_b_z({predicate},' + ' {0}, {1})' result['&'] = f'svand_b_z({predicate},' + ' {0}, {1})'
result['|'] = f'svorr_b_z({predicate},' + ' {0}, {1})' result['|'] = f'svorr_b_z({predicate},' + ' {0}, {1})'
...@@ -123,10 +131,17 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): ...@@ -123,10 +131,17 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result['all'] = f'svcntp_b{bits[data_type]}({predicate}, {{0}}) == {width}' result['all'] = f'svcntp_b{bits[data_type]}({predicate}, {{0}}) == {width}'
result['maskStoreU'] = result['storeU'].replace(predicate, '{2}') result['maskStoreU'] = result['storeU'].replace(predicate, '{2}')
result['maskStoreA'] = result['storeA'].replace(predicate, '{2}') result['maskStream'] = result['stream'].replace(predicate, '{2}')
result['maskStoreS'] = result['storeS'].replace(predicate, '{3}') if instruction_set != 'sme':
result['maskStoreS'] = result['storeS'].replace(predicate, '{3}')
if instruction_set.startswith('sve2') and instruction_set not in ('sve256', 'sve2048'):
result['maskStreamS'] = result['streamS'].replace(predicate, '{3}')
if instruction_set != 'sve': result['streamFence'] = '__dmb(15)'
if instruction_set == 'sme':
result['function_prefix'] = '__attribute__((arm_locally_streaming))'
elif instruction_set not in ['sve', 'sve2', 'sme']:
result['compile_flags'] = [f'-msve-vector-bits={bitwidth}'] result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
else: else:
result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})' result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})'
...@@ -151,9 +166,9 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): ...@@ -151,9 +166,9 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result['any'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) > 0' result['any'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) > 0'
result['all'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) == 16*0xff' result['all'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) == 16*0xff'
if instruction_set == 'sve' or bitwidth & (bitwidth - 1) == 0: # SVE has real nontemporal stores, so we only need to zero cachlines on Neon
# only power-of-2 vector sizes will evenly divide a cacheline
result['cachelineSize'] = 'cachelineSize()'
result['cachelineZero'] = 'cachelineZero((void*) {0})' result['cachelineZero'] = 'cachelineZero((void*) {0})'
result['cachelineSize'] = 'cachelineSize()'
return result return result
...@@ -6,7 +6,6 @@ from typing import Set ...@@ -6,7 +6,6 @@ from typing import Set
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from sympy.core import S from sympy.core import S
from sympy.core.cache import cacheit
from sympy.logic.boolalg import BooleanFalse, BooleanTrue from sympy.logic.boolalg import BooleanFalse, BooleanTrue
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
from sympy.functions.elementary.hyperbolic import HyperbolicFunction from sympy.functions.elementary.hyperbolic import HyperbolicFunction
...@@ -15,7 +14,7 @@ from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node ...@@ -15,7 +14,7 @@ from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize
from pystencils.typing import ( from pystencils.typing import (
PointerType, VectorType, CastFunc, create_type, get_type_of_expression, PointerType, VectorType, CastFunc, create_type, get_type_of_expression,
ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol) ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol, CFunction)
from pystencils.enums import Backend from pystencils.enums import Backend
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.functions import DivFunc, AddressOf from pystencils.functions import DivFunc, AddressOf
...@@ -151,8 +150,8 @@ class CustomCodeNode(Node): ...@@ -151,8 +150,8 @@ class CustomCodeNode(Node):
def undefined_symbols(self): def undefined_symbols(self):
return self._symbols_read - self._symbols_defined return self._symbols_read - self._symbols_defined
def __eq___(self, other): def __eq__(self, other):
return self._code == other._code return type(self) is type(other) and self._code == other._code
def __hash__(self): def __hash__(self):
return hash(self._code) return hash(self._code)
...@@ -166,23 +165,6 @@ class PrintNode(CustomCodeNode): ...@@ -166,23 +165,6 @@ class PrintNode(CustomCodeNode):
self.headers.append("<iostream>") self.headers.append("<iostream>")
class CFunction(TypedSymbol):
def __new__(cls, function, dtype):
return CFunction.__xnew_cached_(cls, function, dtype)
def __new_stage2__(cls, function, dtype):
return super(CFunction, cls).__xnew__(cls, function, dtype)
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __getnewargs__(self):
return self.name, self.dtype
def __getnewargs_ex__(self):
return (self.name, self.dtype), {}
# ------------------------------------------- Printer ------------------------------------------------------------------ # ------------------------------------------- Printer ------------------------------------------------------------------
...@@ -248,12 +230,13 @@ class CBackend: ...@@ -248,12 +230,13 @@ class CBackend:
return f"{node.pragma_line}\n{self._print_Block(node)}" return f"{node.pragma_line}\n{self._print_Block(node)}"
def _print_LoopOverCoordinate(self, node): def _print_LoopOverCoordinate(self, node):
counter_symbol = node.loop_counter_name counter_name = node.loop_counter_name
start = f"int64_t {counter_symbol} = {self.sympy_printer.doprint(node.start)}" counter_dtype = node.loop_counter_symbol.dtype.c_name
condition = f"{counter_symbol} < {self.sympy_printer.doprint(node.stop)}" start = f"{counter_dtype} {counter_name} = {self.sympy_printer.doprint(node.start)}"
update = f"{counter_symbol} += {self.sympy_printer.doprint(node.step)}" condition = f"{counter_name} < {self.sympy_printer.doprint(node.stop)}"
update = f"{counter_name} += {self.sympy_printer.doprint(node.step)}"
loop_str = f"for ({start}; {condition}; {update})" loop_str = f"for ({start}; {condition}; {update})"
self._kwargs['loop_counter'] = counter_symbol self._kwargs['loop_counter'] = counter_name
self._kwargs['loop_stop'] = node.stop self._kwargs['loop_stop'] = node.stop
prefix = "\n".join(node.prefix_lines) prefix = "\n".join(node.prefix_lines)
...@@ -262,33 +245,42 @@ class CBackend: ...@@ -262,33 +245,42 @@ class CBackend:
return f"{prefix}{loop_str}\n{self._print(node.body)}" return f"{prefix}{loop_str}\n{self._print(node.body)}"
def _print_SympyAssignment(self, node): def _print_SympyAssignment(self, node):
printed_lhs = self.sympy_printer.doprint(node.lhs)
printed_rhs = self.sympy_printer.doprint(node.rhs)
if node.is_declaration: if node.is_declaration:
if node.use_auto: if node.use_auto:
data_type = 'auto ' data_type = 'auto'
else: else:
data_type = self._print(node.lhs.dtype).replace(' const', '')
if node.is_const: if node.is_const:
prefix = 'const ' data_type = f'const {data_type}'
else: return f"{data_type} {printed_lhs} = {printed_rhs};"
prefix = ''
data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "
return "%s%s = %s;" % (data_type,
self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs))
else: else:
lhs_type = get_type_of_expression(node.lhs) # TOOD: this should have been typed lhs_type = get_type_of_expression(node.lhs) # TOOD: this should have been typed
printed_mask = "" printed_mask = ""
if type(lhs_type) is VectorType and isinstance(node.lhs, CastFunc): if type(lhs_type) is VectorType and isinstance(node.lhs, CastFunc):
arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args
instr = 'storeU' instr = 'storeU'
if aligned: if nontemporal and 'storeA' not in self._vector_instruction_set and \
'stream' in self._vector_instruction_set:
instr = 'stream'
elif aligned:
instr = 'stream' if nontemporal and 'stream' in self._vector_instruction_set else 'storeA' instr = 'stream' if nontemporal and 'stream' in self._vector_instruction_set else 'storeA'
if mask != True: # NOQA if mask != True: # NOQA
instr = 'maskStoreA' if aligned else 'maskStoreU' instr = 'maskStream' if nontemporal and 'maskStream' in self._vector_instruction_set else \
'maskStoreA' if aligned else 'maskStoreU'
if instr not in self._vector_instruction_set: if instr not in self._vector_instruction_set:
self._vector_instruction_set[instr] = self._vector_instruction_set['store' + instr[-1]].format( if instr == 'maskStream' and 'stream' in self._vector_instruction_set:
store, load = 'stream', 'loadA'
elif (instr in ('maskStream', 'maskStoreA')) and 'storeA' in self._vector_instruction_set:
store, load = 'storeA', 'loadA'
else:
store, load = 'storeU', 'loadU'
load = load if load in self._vector_instruction_set else 'loadU'
self._vector_instruction_set[instr] = self._vector_instruction_set[store].format(
'{0}', self._vector_instruction_set['blendv'].format( '{0}', self._vector_instruction_set['blendv'].format(
self._vector_instruction_set['load' + instr[-1]].format('{0}', **self._kwargs), self._vector_instruction_set[load].format('{0}', **self._kwargs),
'{1}', '{2}', **self._kwargs), **self._kwargs) '{1}', '{2}', **self._kwargs), **self._kwargs)
printed_mask = self.sympy_printer.doprint(mask) printed_mask = self.sympy_printer.doprint(mask)
if data_type.base_type.c_name == 'double': if data_type.base_type.c_name == 'double':
...@@ -313,12 +305,14 @@ class CBackend: ...@@ -313,12 +305,14 @@ class CBackend:
ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0]) ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0])
if stride != 1: if stride != 1:
instr = 'maskStoreS' if mask != True else 'storeS' # NOQA instr = ('maskStreamS' if nontemporal and 'maskStreamS' in self._vector_instruction_set else
'maskStoreS') if mask != True else \
('streamS' if nontemporal and 'streamS' in self._vector_instruction_set else 'storeS') # NOQA
return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs), return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
stride, printed_mask, **self._kwargs) + ';' stride, printed_mask, **self._kwargs) + ';'
pre_code = '' pre_code = ''
if nontemporal and 'cachelineZero' in self._vector_instruction_set: if nontemporal and 'cachelineZero' in self._vector_instruction_set and mask == True: # NOQA
first_cond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0" first_cond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0"
offset = sp.Add(*[sp.Symbol(LoopOverCoordinate.get_loop_counter_name(i)) offset = sp.Add(*[sp.Symbol(LoopOverCoordinate.get_loop_counter_name(i))
* node.lhs.args[0].field.spatial_strides[i] for i in * node.lhs.args[0].field.spatial_strides[i] for i in
...@@ -338,17 +332,26 @@ class CBackend: ...@@ -338,17 +332,26 @@ class CBackend:
code2 = self._vector_instruction_set['flushCacheline'].format( code2 = self._vector_instruction_set['flushCacheline'].format(
ptr, self.sympy_printer.doprint(rhs), **self._kwargs) + ';' ptr, self.sympy_printer.doprint(rhs), **self._kwargs) + ';'
code = f"{code}\nif ({flushcond}) {{\n\t{code2}\n}}" code = f"{code}\nif ({flushcond}) {{\n\t{code2}\n}}"
elif nontemporal and 'storeAAndFlushCacheline' in self._vector_instruction_set: elif aligned and nontemporal and 'storeAAndFlushCacheline' in self._vector_instruction_set:
tmpvar = '_tmp_' + hashlib.sha1(self.sympy_printer.doprint(rhs).encode('ascii')).hexdigest()[:8] lhs_hash = hashlib.sha1(self.sympy_printer.doprint(node.lhs).encode('ascii')).hexdigest()[:8]
rhs_hash = hashlib.sha1(self.sympy_printer.doprint(rhs).encode('ascii')).hexdigest()[:8]
tmpvar = f'_tmp_{lhs_hash}_{rhs_hash}'
code = 'const ' + self._print(node.lhs.dtype).replace(' const', '') + ' ' + tmpvar + ' = ' \ code = 'const ' + self._print(node.lhs.dtype).replace(' const', '') + ' ' + tmpvar + ' = ' \
+ self.sympy_printer.doprint(rhs) + ';' + self.sympy_printer.doprint(rhs) + ';'
code1 = self._vector_instruction_set[instr].format(ptr, tmpvar, printed_mask, **self._kwargs) + ';' code1 = self._vector_instruction_set[instr].format(ptr, tmpvar, printed_mask, **self._kwargs) + ';'
code2 = self._vector_instruction_set['storeAAndFlushCacheline'].format(ptr, tmpvar, printed_mask, maskStore, store, load = 'maskStoreAAndFlushCacheline', 'storeAAndFlushCacheline', 'loadA'
**self._kwargs) + ';' instr2 = maskStore if mask != True else store # NOQA
if instr2 not in self._vector_instruction_set:
self._vector_instruction_set[maskStore] = self._vector_instruction_set[store].format(
'{0}', self._vector_instruction_set['blendv'].format(
self._vector_instruction_set[load].format('{0}', **self._kwargs),
'{1}', '{2}', **self._kwargs),
**self._kwargs)
code2 = self._vector_instruction_set[instr2].format(ptr, tmpvar, printed_mask, **self._kwargs) + ';'
code += f"\nif ({flushcond}) {{\n\t{code2}\n}} else {{\n\t{code1}\n}}" code += f"\nif ({flushcond}) {{\n\t{code2}\n}} else {{\n\t{code1}\n}}"
return pre_code + code return pre_code + code
else: else:
return f"{self.sympy_printer.doprint(node.lhs)} = {self.sympy_printer.doprint(node.rhs)};" return f"{printed_lhs} = {printed_rhs};"
def _print_NontemporalFence(self, _): def _print_NontemporalFence(self, _):
if 'streamFence' in self._vector_instruction_set: if 'streamFence' in self._vector_instruction_set:
...@@ -443,10 +446,22 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -443,10 +446,22 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Pow(self, expr): def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication""" """Don't use std::pow function, for small integer exponents, write as multiplication"""
if not expr.free_symbols: # Ideally the printer has as little logic as possible. Therefore,
raise NotImplementedError("This pow should be simplified already?") # powers should be rewritten as `DivFunc`s / unevaluated `Mul`s before
# return self._typed_number(expr.evalf(), get_type_of_expression(expr.base)) # printing. `NodeCollection` offers a convenience function to do just
return super(CustomSympyPrinter, self)._print_Pow(expr) # that. However, `cut_loops` rewrites unevaluated multiplications as
# `Pow`s again. Neither `deepcopy` nor `func(*args)` are suited to
# rebuild unevaluated expressions. Therefore, as long as we stick with
# SymPy, this is the only way to avoid printing `pow`s.
exp = expr.exp.expr if isinstance(expr.exp, CastFunc) else expr.exp
one_type = expr.base.dtype if hasattr(expr.base, "dtype") else get_type_of_expression(expr.base)
if exp.is_integer and exp.is_number and (0 < exp <= 8):
return f"({self._print(sp.Mul(*[expr.base] * exp, evaluate=False))})"
elif exp.is_integer and exp.is_number and (-8 <= exp < 0):
return f"{self._typed_number(1, one_type)} / ({self._print(sp.Mul(*([expr.base] * -exp), evaluate=False))})"
else:
return super(CustomSympyPrinter, self)._print_Pow(expr)
# TODO don't print ones in sp.Mul # TODO don't print ones in sp.Mul
...@@ -485,11 +500,16 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -485,11 +500,16 @@ class CustomSympyPrinter(CCodePrinter):
return expr.to_c(self._print) return expr.to_c(self._print)
if isinstance(expr, ReinterpretCastFunc): if isinstance(expr, ReinterpretCastFunc):
arg, data_type = expr.args arg, data_type = expr.args
return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))" if isinstance(data_type, PointerType):
const_str = "const" if data_type.const else ""
return f"(({const_str} {self._print(data_type.base_type)} *)(& {self._print(arg)}))"
else:
return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))"
elif isinstance(expr, AddressOf): elif isinstance(expr, AddressOf):
assert len(expr.args) == 1, "address_of must only have one argument" assert len(expr.args) == 1, "address_of must only have one argument"
return f"&({self._print(expr.args[0])})" return f"&({self._print(expr.args[0])})"
elif isinstance(expr, CastFunc): elif isinstance(expr, CastFunc):
cast = "(({data_type})({code}))"
arg, data_type = expr.args arg, data_type = expr.args
if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)): if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)):
return self._typed_number(arg, data_type) return self._typed_number(arg, data_type)
...@@ -504,17 +524,20 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -504,17 +524,20 @@ class CustomSympyPrinter(CCodePrinter):
for k in known: for k in known:
if k in code: if k in code:
return code.replace(k, f'{k}f') return code.replace(k, f'{k}f')
# Powers of small integers are printed as divisions/multiplications.
if '/' in code or '*' in code:
return cast.format(data_type=data_type, code=code)
raise ValueError(f"{code} doesn't give {known=} function back.") raise ValueError(f"{code} doesn't give {known=} function back.")
else: else:
return f"(({data_type})({self._print(arg)}))" return cast.format(data_type=data_type, code=self._print(arg))
elif isinstance(expr, fast_division): elif isinstance(expr, fast_division):
return f"({self._print(expr.args[0] / expr.args[1])})" raise ValueError("fast_division is only supported for Taget.GPU")
elif isinstance(expr, fast_sqrt): elif isinstance(expr, fast_sqrt):
return f"({self._print(sp.sqrt(expr.args[0]))})" raise ValueError("fast_sqrt is only supported for Taget.GPU")
elif isinstance(expr, fast_inv_sqrt):
raise ValueError("fast_inv_sqrt is only supported for Taget.GPU")
elif isinstance(expr, vec_any) or isinstance(expr, vec_all): elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
return self._print(expr.args[0]) return self._print(expr.args[0])
elif isinstance(expr, fast_inv_sqrt):
return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
elif isinstance(expr, sp.Abs): elif isinstance(expr, sp.Abs):
return f"abs({self._print(expr.args[0])})" return f"abs({self._print(expr.args[0])})"
elif isinstance(expr, sp.Mod): elif isinstance(expr, sp.Mod):
...@@ -593,7 +616,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -593,7 +616,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
return None return None
def _print_Abs(self, expr): def _print_Abs(self, expr):
if 'abs' in self.instruction_set and isinstance(expr.args[0], VectorMemoryAccess): if isinstance(get_type_of_expression(expr), (VectorType, VectorMemoryAccess)):
return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs) return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs)
return super()._print_Abs(expr) return super()._print_Abs(expr)
...@@ -627,6 +650,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -627,6 +650,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
def _print_CastFunc(self, expr): def _print_CastFunc(self, expr):
arg, data_type = expr.args arg, data_type = expr.args
if type(data_type) is VectorType: if type(data_type) is VectorType:
base_type = data_type.base_type
# vector_memory_access is a cast_func itself so it should't be directly inside a cast_func # vector_memory_access is a cast_func itself so it should't be directly inside a cast_func
assert not isinstance(arg, VectorMemoryAccess) assert not isinstance(arg, VectorMemoryAccess)
if isinstance(arg, sp.Tuple): if isinstance(arg, sp.Tuple):
...@@ -646,19 +670,18 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -646,19 +670,18 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
elif isinstance(arg, TypedSymbol): elif isinstance(arg, TypedSymbol):
return self._typed_vectorized_symbol(arg, data_type) return self._typed_vectorized_symbol(arg, data_type)
elif isinstance(arg, (InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)) \ elif isinstance(arg, (InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)) \
and data_type == BasicType('float32'): and base_type == BasicType('float32'):
raise NotImplementedError('Vectorizer is not tested for trigonometric functions yet') raise NotImplementedError('Vectorizer is not tested for trigonometric functions yet')
# known = self.known_functions[arg.__class__.__name__.lower()] # known = self.known_functions[arg.__class__.__name__.lower()]
# code = self._print(arg) # code = self._print(arg)
# return code.replace(known, f"{known}f") # return code.replace(known, f"{known}f")
elif isinstance(arg, sp.Pow) and data_type == BasicType('float32'): elif isinstance(arg, sp.Pow):
raise NotImplementedError('Vectorizer cannot print casted aka. not double pow') if base_type == BasicType('float32') or base_type == BasicType('float64'):
# known = ['sqrt', 'cbrt', 'pow'] return self._print_Pow(arg)
# code = self._print(arg) else:
# for k in known: raise NotImplementedError('Integer Pow is not implemented')
# if k in code: elif isinstance(arg, sp.UnevaluatedExpr):
# return code.replace(k, f'{k}f') return self._print(arg.args[0])
# raise ValueError(f"{code} doesn't give {known=} function back.")
else: else:
raise NotImplementedError('Vectorizer cannot cast between different datatypes') raise NotImplementedError('Vectorizer cannot cast between different datatypes')
# to_type = self.instruction_set['suffix'][data_type.base_type.c_name] # to_type = self.instruction_set['suffix'][data_type.base_type.c_name]
...@@ -681,21 +704,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -681,21 +704,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self.instruction_set['/'].format(self._print(expr.divisor), self._print(expr.dividend), result = self.instruction_set['/'].format(self._print(expr.divisor), self._print(expr.dividend),
**self._kwargs) **self._kwargs)
return result return result
elif expr.func == fast_division: elif isinstance(expr, fast_division):
result = self._scalarFallback('_print_Function', expr) raise ValueError("fast_division is only supported for Taget.GPU")
if not result: elif isinstance(expr, fast_sqrt):
result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]), raise ValueError("fast_sqrt is only supported for Taget.GPU")
**self._kwargs) elif isinstance(expr, fast_inv_sqrt):
return result raise ValueError("fast_inv_sqrt is only supported for Taget.GPU")
elif expr.func == fast_sqrt:
return f"({self._print(sp.sqrt(expr.args[0]))})"
elif expr.func == fast_inv_sqrt:
result = self._scalarFallback('_print_Function', expr)
if not result:
if 'rsqrt' in self.instruction_set:
return self.instruction_set['rsqrt'].format(self._print(expr.args[0]), **self._kwargs)
else:
return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
elif isinstance(expr, vec_any) or isinstance(expr, vec_all): elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
instr = 'any' if isinstance(expr, vec_any) else 'all' instr = 'any' if isinstance(expr, vec_any) else 'all'
expr_type = get_type_of_expression(expr.args[0]) expr_type = get_type_of_expression(expr.args[0])
...@@ -777,31 +791,31 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -777,31 +791,31 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
return processed return processed
def _print_Pow(self, expr): def _print_Pow(self, expr):
result = self._scalarFallback('_print_Pow', expr) # Due to loop cutting sp.Mul is evaluated again.
try:
result = self._scalarFallback('_print_Pow', expr)
except ValueError:
result = None
if result: if result:
return result return result
one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs) one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs)
root = self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
if isinstance(expr.exp, CastFunc) and expr.exp.args[0].is_number: if isinstance(expr.exp, CastFunc) and expr.exp.args[0].is_number:
exp = expr.exp.args[0] exp = expr.exp.args[0]
else: else:
exp = expr.exp exp = expr.exp
# TODO the printer should not have any intelligence like this.
# TODO To remove all of these cases the vectoriser needs to be reworked. See loop cutting
if exp.is_integer and exp.is_number and 0 < exp < 8: if exp.is_integer and exp.is_number and 0 < exp < 8:
return "(" + self._print(sp.Mul(*[expr.base] * exp, evaluate=False)) + ")" return self._print(sp.Mul(*[expr.base] * exp, evaluate=False))
elif exp == -1:
one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs)
return self.instruction_set['/'].format(one, self._print(expr.base), **self._kwargs)
elif exp == 0.5: elif exp == 0.5:
return self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs) return root
elif exp == -0.5: elif exp == -0.5:
root = self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
return self.instruction_set['/'].format(one, root, **self._kwargs) return self.instruction_set['/'].format(one, root, **self._kwargs)
elif exp.is_integer and exp.is_number and - 8 < exp < 0:
return self.instruction_set['/'].format(one,
self._print(sp.Mul(*[expr.base] * (-exp), evaluate=False)),
**self._kwargs)
else: else:
raise ValueError("Generic exponential not supported: " + str(expr)) raise ValueError("Generic exponential not supported: " + str(expr))
...@@ -809,7 +823,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -809,7 +823,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
# noinspection PyProtectedMember # noinspection PyProtectedMember
from sympy.core.mul import _keep_coeff from sympy.core.mul import _keep_coeff
result = self._scalarFallback('_print_Mul', expr) if not inside_add:
result = self._scalarFallback('_print_Mul', expr)
else:
result = None
if result: if result:
return result return result
......
from pystencils.typing import CFunction
def get_argument_string(function_shortcut, last=''): def get_argument_string(function_shortcut, last=''):
args = function_shortcut[function_shortcut.index('[') + 1: -1] args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "(" arg_string = "("
...@@ -30,14 +33,11 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'): ...@@ -30,14 +33,11 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'):
'sqrt': 'fsqrt_v[0]', 'sqrt': 'fsqrt_v[0]',
'loadU': f'le{bits[data_type]}_v[0]', 'loadU': f'le{bits[data_type]}_v[0]',
'loadA': f'le{bits[data_type]}_v[0]',
'storeU': f'se{bits[data_type]}_v[0, 1]', 'storeU': f'se{bits[data_type]}_v[0, 1]',
'storeA': f'se{bits[data_type]}_v[0, 1]',
'maskStoreU': f'se{bits[data_type]}_v[2, 0, 1]', 'maskStoreU': f'se{bits[data_type]}_v[2, 0, 1]',
'maskStoreA': f'se{bits[data_type]}_v[2, 0, 1]',
'loadS': f'lse{bits[data_type]}_v[0, 1]', 'loadS': f'lse{bits[data_type]}_v[0, 1]',
'storeS': f'sse{bits[data_type]}_v[0, 2, 1]', 'storeS': f'sse{bits[data_type]}_v[0, 2, 1]',
'maskStoreS': f'sse{bits[data_type]}_v[2, 0, 3, 1]', 'maskStoreS': f'sse{bits[data_type]}_v[3, 0, 2, 1]',
'abs': 'fabs_v[0]', 'abs': 'fabs_v[0]',
'==': 'mfeq_vv[0, 1]', '==': 'mfeq_vv[0, 1]',
...@@ -50,8 +50,8 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'): ...@@ -50,8 +50,8 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'):
'|': 'mor_mm[0, 1]', '|': 'mor_mm[0, 1]',
'blendv': 'merge_vvm[2, 0, 1]', 'blendv': 'merge_vvm[2, 0, 1]',
'any': 'popc_m[0]', 'any': 'cpop_m[0]',
'all': 'popc_m[0]', 'all': 'cpop_m[0]',
} }
result = dict() result = dict()
...@@ -81,7 +81,6 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'): ...@@ -81,7 +81,6 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'):
result[intrinsic_id] = prefix + name + suffix2 + arg_string result[intrinsic_id] = prefix + name + suffix2 + arg_string
from pystencils.backends.cbackend import CFunction
result['width'] = CFunction(width, "int") result['width'] = CFunction(width, "int")
result['intwidth'] = CFunction(intwidth, "int") result['intwidth'] = CFunction(intwidth, "int")
...@@ -92,7 +91,7 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'): ...@@ -92,7 +91,7 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'):
result['storeS'] = result['storeS'].replace('{2}', f'{{2}}*{bits[data_type]//8}') result['storeS'] = result['storeS'].replace('{2}', f'{{2}}*{bits[data_type]//8}')
result['loadS'] = result['loadS'].replace('{1}', f'{{1}}*{bits[data_type]//8}') result['loadS'] = result['loadS'].replace('{1}', f'{{1}}*{bits[data_type]//8}')
result['maskStoreS'] = result['maskStoreS'].replace('{3}', f'{{3}}*{bits[data_type]//8}') result['maskStoreS'] = result['maskStoreS'].replace('{2}', f'{{2}}*{bits[data_type]//8}')
result['+int'] = f"vadd_vv_i{bits['int']}m1({{0}}, {{1}}, {int_vl})" result['+int'] = f"vadd_vv_i{bits['int']}m1({{0}}, {{1}}, {int_vl})"
...@@ -101,9 +100,12 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'): ...@@ -101,9 +100,12 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'):
result['int'] = f'vint{bits["int"]}m1_t' result['int'] = f'vint{bits["int"]}m1_t'
result['bool'] = f'vbool{bits[data_type]}_t' result['bool'] = f'vbool{bits[data_type]}_t'
result['headers'] = ['<riscv_vector.h>'] result['headers'] = ['<riscv_vector.h>', '"riscv_v_helpers.h"']
result['any'] += ' > 0x0' result['any'] += ' > 0x0'
result['all'] += f' == vsetvl_e{bits[data_type]}m1({vl})' result['all'] += f' == vsetvl_e{bits[data_type]}m1({vl})'
result['cachelineSize'] = 'cachelineSize()'
result['cachelineZero'] = 'cachelineZero((void*) {0})'
return result return result
import math
import os import os
import platform import platform
from ctypes import CDLL from ctypes import CDLL, c_int, c_size_t, sizeof, byref
from warnings import warn
import numpy as np
from pystencils.backends.x86_instruction_sets import get_vector_instruction_set_x86 from pystencils.backends.x86_instruction_sets import get_vector_instruction_set_x86
from pystencils.backends.arm_instruction_sets import get_vector_instruction_set_arm from pystencils.backends.arm_instruction_sets import get_vector_instruction_set_arm
from pystencils.backends.ppc_instruction_sets import get_vector_instruction_set_ppc from pystencils.backends.ppc_instruction_sets import get_vector_instruction_set_ppc
from pystencils.backends.riscv_instruction_sets import get_vector_instruction_set_riscv from pystencils.backends.riscv_instruction_sets import get_vector_instruction_set_riscv
from pystencils.cache import memorycache
from pystencils.typing import numpy_name_to_c
def get_vector_instruction_set(data_type='double', instruction_set='avx'): def get_vector_instruction_set(data_type='double', instruction_set='avx'):
if instruction_set in ['neon'] or instruction_set.startswith('sve'): if data_type == 'float':
return get_vector_instruction_set_arm(data_type, instruction_set) warn(f"Ambiguous input for data_type: {data_type}. For single precision please use float32. "
f"For more information please take numpy.dtype as a reference. This input will not be supported in future "
f"releases")
data_type = 'float64'
type_name = numpy_name_to_c(np.dtype(data_type).name)
if instruction_set in ['neon', 'sme'] or instruction_set.startswith('sve'):
return get_vector_instruction_set_arm(type_name, instruction_set)
elif instruction_set in ['vsx']: elif instruction_set in ['vsx']:
return get_vector_instruction_set_ppc(data_type, instruction_set) return get_vector_instruction_set_ppc(type_name, instruction_set)
elif instruction_set in ['rvv']: elif instruction_set in ['rvv']:
return get_vector_instruction_set_riscv(data_type, instruction_set) return get_vector_instruction_set_riscv(type_name, instruction_set)
else: else:
return get_vector_instruction_set_x86(data_type, instruction_set) return get_vector_instruction_set_x86(type_name, instruction_set)
_cache = None
_cachelinesize = None
@memorycache
def get_supported_instruction_sets(): def get_supported_instruction_sets():
"""List of supported instruction sets on current hardware, or None if query failed.""" """List of supported instruction sets on current hardware, or None if query failed."""
global _cache
if _cache is not None:
return _cache.copy()
if 'PYSTENCILS_SIMD' in os.environ: if 'PYSTENCILS_SIMD' in os.environ:
return os.environ['PYSTENCILS_SIMD'].split(',') return os.environ['PYSTENCILS_SIMD'].split(',')
if platform.system() == 'Darwin' and platform.machine() == 'arm64': # not supported by cpuinfo if platform.system() == 'Darwin' and platform.machine() == 'arm64':
result = ['neon']
libc = CDLL('/usr/lib/libc.dylib')
value = c_int(0)
size = c_size_t(sizeof(value))
status = libc.sysctlbyname(b"hw.optional.arm.FEAT_SME", byref(value), byref(size), None, 0)
if status == 0 and value.value == 1:
result.insert(0, "sme")
return result
elif platform.system() == 'Windows' and platform.machine() == 'ARM64':
return ['neon'] return ['neon']
elif platform.system() == 'Linux' and platform.machine().startswith('riscv'): # not supported by cpuinfo elif platform.system() == 'Linux' and platform.machine() == 'aarch64':
result = ['neon'] # Neon is mandatory on 64-bit ARM
libc = CDLL('libc.so.6')
hwcap = libc.getauxval(16) # AT_HWCAP
hwcap2 = libc.getauxval(26) # AT_HWCAP2
if hwcap & (1 << 22): # HWCAP_SVE
if hwcap2 & (1 << 1): # HWCAP2_SVE2
name = 'sve2'
else:
name = 'sve'
length = 8 * libc.prctl(51, 0, 0, 0, 0) # PR_SVE_GET_VL
if length < 0:
raise OSError("SVE length query failed")
while length >= 128:
result.append(f"{name}{length}")
length //= 2
result.append(name)
if hwcap2 & (1 << 23): # HWCAP2_SME
result.insert(0, "sme") # prepend to list so it is not automatically chosen as best instruction set
return result
elif platform.system() == 'Linux' and platform.machine().startswith('riscv'):
libc = CDLL('libc.so.6') libc = CDLL('libc.so.6')
hwcap = libc.getauxval(16) # AT_HWCAP hwcap = libc.getauxval(16) # AT_HWCAP
hwcap_isa_v = 1 << (ord('V') - ord('A')) # COMPAT_HWCAP_ISA_V hwcap_isa_v = 1 << (ord('V') - ord('A')) # COMPAT_HWCAP_ISA_V
return ['rvv'] if hwcap & hwcap_isa_v else [] return ['rvv'] if hwcap & hwcap_isa_v else []
elif platform.machine().startswith('ppc64'): # no flags reported by cpuinfo elif platform.system() == 'Linux' and platform.machine().startswith('ppc64'):
import subprocess libc = CDLL('libc.so.6')
import tempfile hwcap = libc.getauxval(16) # AT_HWCAP
from pystencils.cpu.cpujit import get_compiler_config return ['vsx'] if hwcap & 0x00000080 else [] # PPC_FEATURE_HAS_VSX
f = tempfile.NamedTemporaryFile(suffix='.cpp') elif platform.machine() in ['x86_64', 'x86', 'AMD64', 'i386']:
command = [get_compiler_config()['command'], '-mcpu=native', '-dM', '-E', f.name] try:
macros = subprocess.check_output(command, input='', text=True) from cpuinfo import get_cpu_info
if '#define __VSX__' in macros and '#define __ALTIVEC__' in macros: except ImportError:
_cache = ['vsx'] return None
else:
_cache = []
return _cache.copy()
try:
from cpuinfo import get_cpu_info
except ImportError:
return None
result = [] result = []
required_sse_flags = {'sse', 'sse2', 'ssse3', 'sse4_1', 'sse4_2'} required_sse_flags = {'sse', 'sse2', 'ssse3', 'sse4_1', 'sse4_2'}
required_avx_flags = {'avx', 'avx2'} required_avx_flags = {'avx', 'avx2'}
required_avx512_flags = {'avx512f'} required_avx512_flags = {'avx512f'}
required_neon_flags = {'neon'} possible_avx512vl_flags = {'avx512vl', 'avx10_1'}
required_sve_flags = {'sve'} flags = set(get_cpu_info()['flags'])
flags = set(get_cpu_info()['flags']) if flags.issuperset(required_sse_flags):
if flags.issuperset(required_sse_flags): result.append("sse")
result.append("sse") if flags.issuperset(required_avx_flags):
if flags.issuperset(required_avx_flags): result.append("avx")
result.append("avx") if flags.issuperset(required_avx512_flags):
if flags.issuperset(required_avx512_flags): result.append("avx512")
result.append("avx512") if not flags.isdisjoint(possible_avx512vl_flags):
if flags.issuperset(required_neon_flags): result.append("avx512vl")
result.append("neon") return result
if flags.issuperset(required_sve_flags): else:
if platform.system() == 'Linux': raise NotImplementedError('Instruction set detection for %s on %s is not implemented' %
libc = CDLL('libc.so.6') (platform.system(), platform.machine()))
native_length = 8 * libc.prctl(51, 0, 0, 0, 0) # PR_SVE_GET_VL
if native_length < 0:
raise OSError("SVE length query failed")
pwr2_length = int(2**math.floor(math.log2(native_length)))
if pwr2_length % 256 == 0:
result.append(f"sve{pwr2_length//2}")
if native_length != pwr2_length:
result.append(f"sve{pwr2_length}")
result.append(f"sve{native_length}")
result.append("sve")
return result
@memorycache
def get_cacheline_size(instruction_set): def get_cacheline_size(instruction_set):
"""Get the size (in bytes) of a cache block that can be zeroed without memory access. """Get the size (in bytes) of a cache block that can be zeroed without memory access.
Usually, this is identical to the cache line size.""" Usually, this is identical to the cache line size."""
global _cachelinesize
instruction_sets = get_vector_instruction_set('double', instruction_set) instruction_sets = get_vector_instruction_set('double', instruction_set)
if 'cachelineSize' not in instruction_sets: if 'cachelineSize' not in instruction_sets:
return None return None
if _cachelinesize is not None:
return _cachelinesize
import pystencils as ps import pystencils as ps
from pystencils.astnodes import SympyAssignment from pystencils.astnodes import SympyAssignment
...@@ -108,5 +123,4 @@ def get_cacheline_size(instruction_set): ...@@ -108,5 +123,4 @@ def get_cacheline_size(instruction_set):
ast = ps.create_kernel(ass, cpu_vectorize_info={'instruction_set': instruction_set}) ast = ps.create_kernel(ass, cpu_vectorize_info={'instruction_set': instruction_set})
kernel = ast.compile() kernel = ast.compile()
kernel(**{f.name: arr, CachelineSize.symbol.name: 0}) kernel(**{f.name: arr, CachelineSize.symbol.name: 0})
_cachelinesize = int(arr[0, 0]) return int(arr[0, 0])
return _cachelinesize
...@@ -57,8 +57,8 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): ...@@ -57,8 +57,8 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
'storeU': 'storeu[0,1]', 'storeU': 'storeu[0,1]',
'storeA': 'store[0,1]', 'storeA': 'store[0,1]',
'stream': 'stream[0,1]', 'stream': 'stream[0,1]',
'maskStoreA': 'mask_store[0, 2, 1]' if instruction_set == 'avx512' else 'maskstore[0, 2, 1]', 'maskStoreA': 'mask_store[0, 2, 1]' if instruction_set.startswith('avx512') else 'maskstore[0, 2, 1]',
'maskStoreU': 'mask_storeu[0, 2, 1]' if instruction_set == 'avx512' else 'maskstore[0, 2, 1]', 'maskStoreU': 'mask_storeu[0, 2, 1]' if instruction_set.startswith('avx512') else 'maskstore[0, 2, 1]',
} }
for comparison_op, constant in comparisons.items(): for comparison_op, constant in comparisons.items():
...@@ -66,6 +66,7 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): ...@@ -66,6 +66,7 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
headers = { headers = {
'avx512': ['<immintrin.h>'], 'avx512': ['<immintrin.h>'],
'avx512vl': ['<immintrin.h>'],
'avx': ['<immintrin.h>'], 'avx': ['<immintrin.h>'],
'sse': ['<immintrin.h>', '<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>', 'sse': ['<immintrin.h>', '<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>',
'<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>'] '<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>']
...@@ -79,6 +80,7 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): ...@@ -79,6 +80,7 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
prefix = { prefix = {
'sse': '_mm', 'sse': '_mm',
'avx': '_mm256', 'avx': '_mm256',
'avx512vl': '_mm256',
'avx512': '_mm512', 'avx512': '_mm512',
} }
...@@ -89,6 +91,9 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): ...@@ -89,6 +91,9 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
("double", "avx"): 4, ("double", "avx"): 4,
("float", "avx"): 8, ("float", "avx"): 8,
("int", "avx"): 8, ("int", "avx"): 8,
("double", "avx512vl"): 4,
("float", "avx512vl"): 8,
("int", "avx512vl"): 8,
("double", "avx512"): 8, ("double", "avx512"): 8,
("float", "avx512"): 16, ("float", "avx512"): 16,
("int", "avx512"): 16, ("int", "avx512"): 16,
...@@ -110,7 +115,7 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): ...@@ -110,7 +115,7 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
suf = suffix[data_type] suf = suffix[data_type]
arg_string = get_argument_string(intrinsic_id, result['width'], function_shortcut) arg_string = get_argument_string(intrinsic_id, result['width'], function_shortcut)
mask_suffix = '_mask' if instruction_set == 'avx512' and intrinsic_id in comparisons.keys() else '' mask_suffix = '_mask' if instruction_set.startswith('avx512') and intrinsic_id in comparisons.keys() else ''
result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string
bit_width = result['width'] * (64 if data_type == 'double' else 32) bit_width = result['width'] * (64 if data_type == 'double' else 32)
...@@ -123,29 +128,45 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): ...@@ -123,29 +128,45 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
result['any'] = f"{pre}_movemask_{suf}({{0}}) > 0" result['any'] = f"{pre}_movemask_{suf}({{0}}) > 0"
result['all'] = f"{pre}_movemask_{suf}({{0}}) == {hex(2**result['width']-1)}" result['all'] = f"{pre}_movemask_{suf}({{0}}) == {hex(2**result['width']-1)}"
if instruction_set == 'avx512': setsuf = "x" if bit_width < 512 and bit_width // result['width'] == 64 else ""
if instruction_set.startswith('avx512'):
size = result['width'] size = result['width']
result['&'] = f'_kand_mask{size}({{0}}, {{1}})' masksize = max(size, 8)
result['|'] = f'_kor_mask{size}({{0}}, {{1}})' result['&'] = f'_kand_mask{masksize}({{0}}, {{1}})'
result['any'] = f'!_ktestz_mask{size}_u8({{0}}, {{0}})' result['|'] = f'_kor_mask{masksize}({{0}}, {{1}})'
result['all'] = f'_kortestc_mask{size}_u8({{0}}, {{0}})' result['any'] = f'!_ktestz_mask{masksize}_u8({{0}}, {{0}})'
result['all'] = f'_kortestc_mask{masksize}_u8({{0}}, {{0}})'
result['blendv'] = f'{pre}_mask_blend_{suf}({{2}}, {{0}}, {{1}})' result['blendv'] = f'{pre}_mask_blend_{suf}({{2}}, {{0}}, {{1}})'
result['rsqrt'] = f"{pre}_rsqrt14_{suf}({{0}})" result['rsqrt'] = f"{pre}_rsqrt14_{suf}({{0}})"
result['abs'] = f"{pre}_abs_{suf}({{0}})" result['bool'] = f"__mmask{masksize}"
result['bool'] = f"__mmask{size}"
params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)]) params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)])
result['makeVecBool'] = f"__mmask8(({params}) )" result['makeVecBool'] = f"__mmask8(({params}) )"
params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)]) params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)])
result['makeVecConstBool'] = f"__mmask8(({params}) )" result['makeVecConstBool'] = f"__mmask8(({params}) )"
vindex = f'{pre}_set_epi{bit_width//size}(' + ', '.join([str(i) for i in range(result['width'])][::-1]) + ')' vindex = f'{pre}_set_epi{bit_width//size}{setsuf}(' + \
vindex = f'{pre}_mullo_epi{bit_width//size}({vindex}, {pre}_set1_epi{bit_width//size}({{0}}))' ', '.join([str(i) for i in range(result['width'])][::-1]) + ')'
vindex = f'{pre}_mullo_epi{bit_width//size}({vindex}, {pre}_set1_epi{bit_width//size}{setsuf}({{0}}))'
scale = bit_width // size // 8
result['storeS'] = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + \ result['storeS'] = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + \
f', {{1}}, {64//size})' f', {{1}}, {scale})'
result['maskStoreS'] = f'{pre}_mask_i{bit_width//size}scatter_{suf}({{0}}, {{3}}, ' + vindex.format("{2}") + \ result['maskStoreS'] = f'{pre}_mask_i{bit_width//size}scatter_{suf}({{0}}, {{3}}, ' + vindex.format("{2}") + \
f', {{1}}, {64//size})' f', {{1}}, {scale})'
result['loadS'] = f'{pre}_i{bit_width//size}gather_{suf}(' + vindex.format("{1}") + f', {{0}}, {64//size})' if bit_width == 512:
result['loadS'] = f'{pre}_i{bit_width//size}gather_{suf}(' + vindex.format("{1}") + f', {{0}}, {scale})'
else:
result['loadS'] = f'{pre}_i{bit_width//size}gather_{suf}({{0}}, ' + vindex.format("{1}") + f', {scale})'
# abs intrinsic exists in 512 bits, but expands to a sequence. We generate that same sequence for 128 and 256 bits
if instruction_set == 'avx512':
result['abs'] = f"{pre}_abs_{suf}({{0}})"
else:
result['abs'] = f"{pre}_castsi{bit_width}_{suf}({pre}_and_si{bit_width}(" + \
f"{pre}_set1_epi{bit_width // result['width']}{setsuf}(0x7" + \
'f' * (bit_width // result['width'] // 4 - 1) + "), " + \
f"{pre}_cast{suf}_si{bit_width}({{0}})))"
if instruction_set == 'avx' and data_type == 'float': if instruction_set == 'avx' and data_type == 'float':
result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})" result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})"
......
File moved
...@@ -76,7 +76,7 @@ class Neumann(Boundary): ...@@ -76,7 +76,7 @@ class Neumann(Boundary):
return hash("Neumann") return hash("Neumann")
def __eq__(self, other): def __eq__(self, other):
return type(other) == Neumann return type(other) is Neumann
class Dirichlet(Boundary): class Dirichlet(Boundary):
......
...@@ -9,7 +9,7 @@ from pystencils.backends.cbackend import CustomCodeNode ...@@ -9,7 +9,7 @@ from pystencils.backends.cbackend import CustomCodeNode
from pystencils.boundaries.createindexlist import ( from pystencils.boundaries.createindexlist import (
create_boundary_index_array, numpy_data_type_for_boundary_object) create_boundary_index_array, numpy_data_type_for_boundary_object)
from pystencils.typing import TypedSymbol, create_type from pystencils.typing import TypedSymbol, create_type
from pystencils.datahandling.pycuda import PyCudaArrayHandler from pystencils.gpu.gpu_array_handler import GPUArrayHandler
from pystencils.field import Field from pystencils.field import Field
from pystencils.typing.typed_sympy import FieldPointerSymbol from pystencils.typing.typed_sympy import FieldPointerSymbol
...@@ -18,6 +18,7 @@ try: ...@@ -18,6 +18,7 @@ try:
import waLBerla as wlb import waLBerla as wlb
if wlb.cpp_available: if wlb.cpp_available:
from pystencils.datahandling.parallel_datahandling import ParallelDataHandling from pystencils.datahandling.parallel_datahandling import ParallelDataHandling
import cupy.cuda.runtime
else: else:
ParallelDataHandling = None ParallelDataHandling = None
except ImportError: except ImportError:
...@@ -34,11 +35,11 @@ class FlagInterface: ...@@ -34,11 +35,11 @@ class FlagInterface:
>>> dh = create_data_handling((4, 5)) >>> dh = create_data_handling((4, 5))
>>> fi = FlagInterface(dh, 'flag_field', np.uint8) >>> fi = FlagInterface(dh, 'flag_field', np.uint8)
>>> assert dh.has_data('flag_field') >>> assert dh.has_data('flag_field')
>>> fi.reserve_next_flag() >>> int(fi.reserve_next_flag())
2 2
>>> fi.reserve_flag(4) >>> int(fi.reserve_flag(4))
4 4
>>> fi.reserve_next_flag() >>> int(fi.reserve_next_flag())
8 8
""" """
...@@ -100,7 +101,7 @@ class BoundaryHandling: ...@@ -100,7 +101,7 @@ class BoundaryHandling:
self.flag_interface = fi if fi is not None else FlagInterface(data_handling, name + "Flags") self.flag_interface = fi if fi is not None else FlagInterface(data_handling, name + "Flags")
if ParallelDataHandling and isinstance(self.data_handling, ParallelDataHandling): if ParallelDataHandling and isinstance(self.data_handling, ParallelDataHandling):
array_handler = PyCudaArrayHandler() array_handler = GPUArrayHandler(cupy.cuda.runtime.getDevice())
else: else:
array_handler = self.data_handling.array_handler array_handler = self.data_handling.array_handler
...@@ -116,7 +117,8 @@ class BoundaryHandling: ...@@ -116,7 +117,8 @@ class BoundaryHandling:
for obj, cpu_arr in cpu_version.items(): for obj, cpu_arr in cpu_version.items():
if obj not in gpu_version or gpu_version[obj].shape != cpu_arr.shape: if obj not in gpu_version or gpu_version[obj].shape != cpu_arr.shape:
gpu_version[obj] = array_handler.to_gpu(cpu_arr) gpu_version[obj] = array_handler.empty(cpu_arr.shape, cpu_arr.dtype)
array_handler.upload(gpu_version[obj], cpu_arr)
else: else:
array_handler.upload(gpu_version[obj], cpu_arr) array_handler.upload(gpu_version[obj], cpu_arr)
...@@ -424,29 +426,30 @@ class BoundaryOffsetInfo(CustomCodeNode): ...@@ -424,29 +426,30 @@ class BoundaryOffsetInfo(CustomCodeNode):
code = "\n" code = "\n"
for i in range(dim): for i in range(dim):
offset_str = ", ".join([str(d[i]) for d in stencil]) offset_str = ", ".join([str(d[i]) for d in stencil])
code += "const int64_t %s [] = { %s };\n" % (offset_sym[i].name, offset_str) code += "const int32_t %s [] = { %s };\n" % (offset_sym[i].name, offset_str)
inv_dirs = [] inv_dirs = []
for direction in stencil: for direction in stencil:
inverse_dir = tuple([-i for i in direction]) inverse_dir = tuple([-i for i in direction])
inv_dirs.append(str(stencil.index(inverse_dir))) inv_dirs.append(str(stencil.index(inverse_dir)))
code += "const int64_t %s [] = { %s };\n" % (self.INV_DIR_SYMBOL.name, ", ".join(inv_dirs)) code += "const int32_t %s [] = { %s };\n" % (self.INV_DIR_SYMBOL.name, ", ".join(inv_dirs))
offset_symbols = BoundaryOffsetInfo._offset_symbols(dim) offset_symbols = BoundaryOffsetInfo._offset_symbols(dim)
super(BoundaryOffsetInfo, self).__init__(code, symbols_read=set(), super(BoundaryOffsetInfo, self).__init__(code, symbols_read=set(),
symbols_defined=set(offset_symbols + [self.INV_DIR_SYMBOL])) symbols_defined=set(offset_symbols + [self.INV_DIR_SYMBOL]))
@staticmethod @staticmethod
def _offset_symbols(dim): def _offset_symbols(dim):
return [TypedSymbol(f"c{d}", create_type(np.int64)) for d in ['x', 'y', 'z'][:dim]] return [TypedSymbol(f"c{d}", create_type(np.int32)) for d in ['x', 'y', 'z'][:dim]]
INV_DIR_SYMBOL = TypedSymbol("invdir", np.int64) INV_DIR_SYMBOL = TypedSymbol("invdir", np.int32)
def create_boundary_kernel(field, index_field, stencil, boundary_functor, target=Target.CPU, **kernel_creation_args): def create_boundary_kernel(field, index_field, stencil, boundary_functor, target=Target.CPU, **kernel_creation_args):
elements = [BoundaryOffsetInfo(stencil)] elements = [BoundaryOffsetInfo(stencil)]
dir_symbol = TypedSymbol("dir", np.int64) dir_symbol = TypedSymbol("dir", np.int32)
elements += [SympyAssignment(dir_symbol, index_field[0]('dir'))] elements += [SympyAssignment(dir_symbol, index_field[0]('dir'))]
elements += boundary_functor(field, direction_symbol=dir_symbol, index_field=index_field) elements += boundary_functor(field, direction_symbol=dir_symbol, index_field=index_field)
config = CreateKernelConfig(index_fields=[index_field], target=target, **kernel_creation_args) config = CreateKernelConfig(index_fields=[index_field], target=target, skip_independence_check=True,
**kernel_creation_args)
return create_kernel(elements, config=config) return create_kernel(elements, config=config)
...@@ -2,66 +2,83 @@ import warnings ...@@ -2,66 +2,83 @@ import warnings
import numpy as np import numpy as np
try: try:
# Try to import right away - assume compiled code is available import pyximport
# compile with: python setup.py build_ext --inplace --use-cython
from pystencils.boundaries.createindexlistcython import create_boundary_neighbor_index_list_2d, \
create_boundary_neighbor_index_list_3d, create_boundary_cell_index_list_2d, create_boundary_cell_index_list_3d
pyximport.install(language_level=3)
cython_funcs_available = True cython_funcs_available = True
except ImportError: except ImportError:
try: cython_funcs_available = False
# If not, try development mode and import via pyximport
import pyximport if cython_funcs_available:
from pystencils.boundaries.createindexlistcython import (
pyximport.install(language_level=3) create_boundary_neighbor_index_list_2d,
cython_funcs_available = True create_boundary_neighbor_index_list_3d,
except ImportError: create_boundary_cell_index_list_2d,
cython_funcs_available = False create_boundary_cell_index_list_3d,
if cython_funcs_available: )
from pystencils.boundaries.createindexlistcython import create_boundary_neighbor_index_list_2d, \
create_boundary_neighbor_index_list_3d, create_boundary_cell_index_list_2d, \
create_boundary_cell_index_list_3d
boundary_index_array_coordinate_names = ["x", "y", "z"] boundary_index_array_coordinate_names = ["x", "y", "z"]
direction_member_name = "dir" direction_member_name = "dir"
default_index_array_dtype = np.int32
def numpy_data_type_for_boundary_object(boundary_object, dim): def numpy_data_type_for_boundary_object(boundary_object, dim):
coordinate_names = boundary_index_array_coordinate_names[:dim] coordinate_names = boundary_index_array_coordinate_names[:dim]
return np.dtype([(name, np.int32) for name in coordinate_names] return np.dtype(
+ [(direction_member_name, np.int32)] [(name, default_index_array_dtype) for name in coordinate_names]
+ [(i[0], i[1].numpy_dtype) for i in boundary_object.additional_data], align=True) + [(direction_member_name, default_index_array_dtype)]
+ [(i[0], i[1].numpy_dtype) for i in boundary_object.additional_data],
align=True,
def _create_index_list_python(flag_field_arr, boundary_mask, )
fluid_mask, stencil, single_link, inner_or_boundary=False, nr_of_ghost_layers=None):
def _create_index_list_python(
flag_field_arr,
boundary_mask,
fluid_mask,
stencil,
single_link,
inner_or_boundary=False,
nr_of_ghost_layers=None,
):
if inner_or_boundary and nr_of_ghost_layers is None: if inner_or_boundary and nr_of_ghost_layers is None:
raise ValueError("If inner_or_boundary is set True the number of ghost layers " raise ValueError(
"around the inner domain has to be specified") "If inner_or_boundary is set True the number of ghost layers "
"around the inner domain has to be specified"
)
if nr_of_ghost_layers is None: if nr_of_ghost_layers is None:
nr_of_ghost_layers = 0 nr_of_ghost_layers = 0
coordinate_names = boundary_index_array_coordinate_names[:len(flag_field_arr.shape)] coordinate_names = boundary_index_array_coordinate_names[
index_arr_dtype = np.dtype([(name, np.int32) for name in coordinate_names] + [(direction_member_name, np.int32)]) : len(flag_field_arr.shape)
]
index_arr_dtype = np.dtype(
[(name, default_index_array_dtype) for name in coordinate_names]
+ [(direction_member_name, default_index_array_dtype)]
)
# boundary cells are extracted via np.where. To ensure continous memory access in the compute kernel these cells # boundary cells are extracted via np.where. To ensure continous memory access in the compute kernel these cells
# have to be sorted. # have to be sorted.
boundary_cells = np.transpose(np.nonzero(flag_field_arr == boundary_mask)) boundary_cells = np.transpose(np.nonzero(flag_field_arr == boundary_mask))
for i in range(len(flag_field_arr.shape)): for i in range(len(flag_field_arr.shape)):
boundary_cells = boundary_cells[boundary_cells[:, i].argsort(kind='mergesort')] boundary_cells = boundary_cells[boundary_cells[:, i].argsort(kind="mergesort")]
# First a set is created to save all fluid cells which are near boundary # First a set is created to save all fluid cells which are near boundary
fluid_cells = set() fluid_cells = set()
for cell in boundary_cells: for cell in boundary_cells:
cell = tuple(cell) cell = tuple(cell)
for dir_idx, direction in enumerate(stencil): for dir_idx, direction in enumerate(stencil):
neighbor_cell = tuple([cell_i + dir_i for cell_i, dir_i in zip(cell, direction)]) neighbor_cell = tuple(
[cell_i + dir_i for cell_i, dir_i in zip(cell, direction)]
)
# prevent out ouf bounds access. If boundary cell is at the border, some stencil directions would be out. # prevent out ouf bounds access. If boundary cell is at the border, some stencil directions would be out.
if any(not 0 + nr_of_ghost_layers <= e < upper - nr_of_ghost_layers if any(
for e, upper in zip(neighbor_cell, flag_field_arr.shape)): not 0 + nr_of_ghost_layers <= e < upper - nr_of_ghost_layers
for e, upper in zip(neighbor_cell, flag_field_arr.shape)
):
continue continue
if flag_field_arr[neighbor_cell] & fluid_mask: if flag_field_arr[neighbor_cell] & fluid_mask:
fluid_cells.add(neighbor_cell) fluid_cells.add(neighbor_cell)
...@@ -81,9 +98,14 @@ def _create_index_list_python(flag_field_arr, boundary_mask, ...@@ -81,9 +98,14 @@ def _create_index_list_python(flag_field_arr, boundary_mask,
cell = tuple(cell) cell = tuple(cell)
sum_cells = np.zeros(len(cell)) sum_cells = np.zeros(len(cell))
for dir_idx, direction in enumerate(stencil): for dir_idx, direction in enumerate(stencil):
neighbor_cell = tuple([cell_i + dir_i for cell_i, dir_i in zip(cell, direction)]) neighbor_cell = tuple(
[cell_i + dir_i for cell_i, dir_i in zip(cell, direction)]
)
# prevent out ouf bounds access. If boundary cell is at the border, some stencil directions would be out. # prevent out ouf bounds access. If boundary cell is at the border, some stencil directions would be out.
if any(not 0 <= e < upper for e, upper in zip(neighbor_cell, flag_field_arr.shape)): if any(
not 0 <= e < upper
for e, upper in zip(neighbor_cell, flag_field_arr.shape)
):
continue continue
if flag_field_arr[neighbor_cell] & checkmask: if flag_field_arr[neighbor_cell] & checkmask:
if single_link: if single_link:
...@@ -99,8 +121,15 @@ def _create_index_list_python(flag_field_arr, boundary_mask, ...@@ -99,8 +121,15 @@ def _create_index_list_python(flag_field_arr, boundary_mask,
return np.array(result, dtype=index_arr_dtype) return np.array(result, dtype=index_arr_dtype)
def create_boundary_index_list(flag_field, stencil, boundary_mask, fluid_mask, def create_boundary_index_list(
nr_of_ghost_layers=1, inner_or_boundary=True, single_link=False): flag_field,
stencil,
boundary_mask,
fluid_mask,
nr_of_ghost_layers=1,
inner_or_boundary=True,
single_link=False,
):
"""Creates a numpy array storing links (connections) between domain cells and boundary cells. """Creates a numpy array storing links (connections) between domain cells and boundary cells.
Args: Args:
...@@ -117,10 +146,20 @@ def create_boundary_index_list(flag_field, stencil, boundary_mask, fluid_mask, ...@@ -117,10 +146,20 @@ def create_boundary_index_list(flag_field, stencil, boundary_mask, fluid_mask,
""" """
dim = len(flag_field.shape) dim = len(flag_field.shape)
coordinate_names = boundary_index_array_coordinate_names[:dim] coordinate_names = boundary_index_array_coordinate_names[:dim]
index_arr_dtype = np.dtype([(name, np.int32) for name in coordinate_names] + [(direction_member_name, np.int32)]) index_arr_dtype = np.dtype(
[(name, default_index_array_dtype) for name in coordinate_names]
stencil = np.array(stencil, dtype=np.int32) + [(direction_member_name, default_index_array_dtype)]
args = (flag_field, nr_of_ghost_layers, boundary_mask, fluid_mask, stencil, single_link) )
stencil = np.array(stencil, dtype=default_index_array_dtype)
args = (
flag_field,
nr_of_ghost_layers,
boundary_mask,
fluid_mask,
stencil,
single_link,
)
args_no_gl = (flag_field, boundary_mask, fluid_mask, stencil, single_link) args_no_gl = (flag_field, boundary_mask, fluid_mask, stencil, single_link)
if cython_funcs_available: if cython_funcs_available:
...@@ -139,22 +178,42 @@ def create_boundary_index_list(flag_field, stencil, boundary_mask, fluid_mask, ...@@ -139,22 +178,42 @@ def create_boundary_index_list(flag_field, stencil, boundary_mask, fluid_mask,
return np.array(idx_list, dtype=index_arr_dtype) return np.array(idx_list, dtype=index_arr_dtype)
else: else:
if flag_field.size > 1e6: if flag_field.size > 1e6:
warnings.warn("Boundary setup may take very long! Consider installing cython to speed it up") warnings.warn(
return _create_index_list_python(*args_no_gl, inner_or_boundary=inner_or_boundary, "Boundary setup may take very long! Consider installing cython to speed it up"
nr_of_ghost_layers=nr_of_ghost_layers) )
return _create_index_list_python(
*args_no_gl,
def create_boundary_index_array(flag_field, stencil, boundary_mask, fluid_mask, boundary_object, inner_or_boundary=inner_or_boundary,
nr_of_ghost_layers=1, inner_or_boundary=True, single_link=False): nr_of_ghost_layers=nr_of_ghost_layers,
idx_array = create_boundary_index_list(flag_field, stencil, boundary_mask, fluid_mask, )
nr_of_ghost_layers, inner_or_boundary, single_link)
def create_boundary_index_array(
flag_field,
stencil,
boundary_mask,
fluid_mask,
boundary_object,
nr_of_ghost_layers=1,
inner_or_boundary=True,
single_link=False,
):
idx_array = create_boundary_index_list(
flag_field,
stencil,
boundary_mask,
fluid_mask,
nr_of_ghost_layers,
inner_or_boundary,
single_link,
)
dim = len(flag_field.shape) dim = len(flag_field.shape)
if boundary_object.additional_data: if boundary_object.additional_data:
coordinate_names = boundary_index_array_coordinate_names[:dim] coordinate_names = boundary_index_array_coordinate_names[:dim]
index_arr_dtype = numpy_data_type_for_boundary_object(boundary_object, dim) index_arr_dtype = numpy_data_type_for_boundary_object(boundary_object, dim)
extended_idx_field = np.empty(len(idx_array), dtype=index_arr_dtype) extended_idx_field = np.empty(len(idx_array), dtype=index_arr_dtype)
for prop in coordinate_names + ['dir']: for prop in coordinate_names + ["dir"]:
extended_idx_field[prop] = idx_array[prop] extended_idx_field[prop] = idx_array[prop]
idx_array = extended_idx_field idx_array = extended_idx_field
......
# distutils: language=c # cython: language_level=3str
# Workaround for cython bug
# see https://stackoverflow.com/questions/8024805/cython-compiled-c-extension-importerror-dynamic-module-does-not-define-init-fu
WORKAROUND = "Something"
import cython import cython
......
import os import os
from collections.abc import Hashable from collections.abc import Hashable
from functools import partial, wraps, lru_cache from functools import partial, wraps
from itertools import chain from itertools import chain
from functools import lru_cache as memorycache
from joblib import Memory from joblib import Memory
from appdirs import user_cache_dir from appdirs import user_cache_dir
...@@ -23,7 +25,7 @@ def _wrapper(wrapped_func, cached_func, *args, **kwargs): ...@@ -23,7 +25,7 @@ def _wrapper(wrapped_func, cached_func, *args, **kwargs):
def memorycache_if_hashable(maxsize=128, typed=False): def memorycache_if_hashable(maxsize=128, typed=False):
def wrapper(func): def wrapper(func):
return partial(_wrapper, func, lru_cache(maxsize, typed)(func)) return partial(_wrapper, func, memorycache(maxsize, typed)(func))
return wrapper return wrapper
...@@ -57,6 +59,14 @@ def sharedmethodcache(cache_id: str): ...@@ -57,6 +59,14 @@ def sharedmethodcache(cache_id: str):
return _decorator return _decorator
def clear_cache():
"""
Clears the pystencils cache created by joblib.
"""
memory = Memory(cache_dir, verbose=0)
memory.clear(warn=False)
# Disable memory cache: # Disable memory cache:
# disk_cache = lambda o: o # disk_cache = lambda o: o
# disk_cache_no_fallback = lambda o: o # disk_cache_no_fallback = lambda o: o
import warnings
from copy import copy from copy import copy
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from types import MappingProxyType from types import MappingProxyType
from typing import Union, Tuple, List, Dict, Callable, Any from typing import Union, Tuple, List, Dict, Callable, Any, DefaultDict, Iterable
from pystencils import Target, Backend, Field from pystencils import Target, Backend, Field
from pystencils.typing.typed_sympy import BasicType from pystencils.typing.typed_sympy import BasicType
from pystencils.typing.utilities import collate_types
import numpy as np import numpy as np
# TODO: There exists DTypeLike in NumPy which would be better than type for type hinting, to new at the moment
# from numpy.typing import DTypeLike
# TODO: CreateKernelConfig is bloated think of more classes better usage, factory whatever ... # TODO: CreateKernelConfig is bloated think of more classes better usage, factory whatever ...
# Proposition: CreateKernelConfigs Classes for different targets? # Proposition: CreateKernelConfigs Classes for different targets?
...@@ -30,17 +33,19 @@ class CreateKernelConfig: ...@@ -30,17 +33,19 @@ class CreateKernelConfig:
""" """
Name of the generated function - only important if generated code is written out Name of the generated function - only important if generated code is written out
""" """
# TODO Sane defaults: config should check that the datatype is a Numpy type data_type: Union[type, str, DefaultDict[str, BasicType], Dict[str, BasicType]] = np.float64
# TODO Sane defaults: QoL default_number_float and default_number_int should be data_type if they are not specified
data_type: Union[str, Dict[str, BasicType]] = 'float64'
""" """
Data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type Data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type.
If specified as a dict ideally a defaultdict is used to define a default value for symbols not listed in the
dict. If a plain dict is provided it will be transformed into a defaultdict internally. The default value
will then be specified via type collation then.
""" """
default_number_float: Union[str, np.dtype, BasicType] = 'float64' default_number_float: Union[type, str, BasicType] = None
""" """
Data type used for all untyped floating point numbers (i.e. 0.5) Data type used for all untyped floating point numbers (i.e. 0.5). By default the value of data_type is used.
If data_type is given as a defaultdict its default_factory is used.
""" """
default_number_int: Union[str, np.dtype, BasicType] = 'int64' default_number_int: Union[type, str, BasicType] = np.int64
""" """
Data type used for all untyped integer numbers (i.e. 1) Data type used for all untyped integer numbers (i.e. 1)
""" """
...@@ -73,11 +78,20 @@ class CreateKernelConfig: ...@@ -73,11 +78,20 @@ class CreateKernelConfig:
""" """
If OpenMP is active: whether multiple outer loops are permitted If OpenMP is active: whether multiple outer loops are permitted
""" """
base_pointer_specification: Union[List[Iterable[str]], List[Iterable[int]]] = None
"""
Specification of how many and which intermediate pointers are created for a field access.
For example [ (0), (2,3,)] creates on base pointer for coordinates 2 and 3 and writes the offset for coordinate
zero directly in the field access. These specifications are defined dependent on the loop ordering.
This function translates more readable version into the specification above.
For more information see: `pystencils.transformations.create_intermediate_base_pointer`
"""
gpu_indexing: str = 'block' gpu_indexing: str = 'block'
""" """
Either 'block' or 'line' , or custom indexing class, see `pystencils.gpucuda.AbstractIndexing` Either 'block' or 'line' , or custom indexing class, see `pystencils.gpu.AbstractIndexing`
""" """
gpu_indexing_params: MappingProxyType = field(default=MappingProxyType({})) gpu_indexing_params: MappingProxyType = field(default_factory=lambda: MappingProxyType({}))
""" """
Dict with indexing parameters (constructor parameters of indexing class) Dict with indexing parameters (constructor parameters of indexing class)
e.g. for 'block' one can specify '{'block_size': (20, 20, 10) }'. e.g. for 'block' one can specify '{'block_size': (20, 20, 10) }'.
...@@ -116,23 +130,43 @@ class CreateKernelConfig: ...@@ -116,23 +130,43 @@ class CreateKernelConfig:
allow_double_writes: bool = False allow_double_writes: bool = False
""" """
If True, don't check if every field is only written at a single location. This is required If True, don't check if every field is only written at a single location. This is required
for example for kernels that are compiled with loop step sizes > 1, that handle multiple for example for kernels that are compiled with loop step sizes > 1, that handle multiple
cells at once. Use with care! cells at once. Use with care!
""" """
skip_independence_check: bool = False skip_independence_check: bool = False
""" """
Don't check that loop iterations are independent. This is needed e.g. for By default the assignment list is checked for read/write independence. This means fields are only written at
periodicity kernel, that access the field outside the iteration bounds. Use with care! locations where they are read. Doing so guarantees thread safety. In some cases e.g. for
periodicity kernel, this can not be assured and does the check needs to be deactivated. Use with care!
""" """
class DataTypeFactory:
"""Because of pickle, we need to have a nested class, instead of a lambda in __post_init__"""
def __init__(self, dt):
self.dt = dt
def __call__(self):
return BasicType(self.dt)
def _check_type(self, dtype_to_check):
if isinstance(dtype_to_check, str) and (dtype_to_check == 'float' or dtype_to_check == 'int'):
self._typing_error()
if isinstance(dtype_to_check, type) and not hasattr(dtype_to_check, "dtype"):
# NumPy-types are also of type 'type'. However, they have more properties
self._typing_error()
@staticmethod
def _typing_error():
raise ValueError("It is not possible to use python types (float, int) for datatypes because these "
"types are ambiguous. For example float will map to double. "
"Also the string version like 'float' is not allowed, e.g. use 'float64' instead")
def __post_init__(self): def __post_init__(self):
# ---- Legacy parameters # ---- Legacy parameters
# TODO Sane defaults: Check for abmigous types like "float", python float, which are dangerous for users if not isinstance(self.target, Target):
if isinstance(self.target, str): raise ValueError("target must be provided by the 'Target' enum")
new_target = Target[self.target.upper()]
warnings.warn(f'Target "{self.target}" as str is deprecated. Use {new_target} instead',
category=DeprecationWarning)
self.target = new_target
# ---- Auto Backend # ---- Auto Backend
if not self.backend: if not self.backend:
if self.target == Target.CPU: if self.target == Target.CPU:
...@@ -142,10 +176,33 @@ class CreateKernelConfig: ...@@ -142,10 +176,33 @@ class CreateKernelConfig:
else: else:
raise NotImplementedError(f'Target {self.target} has no default backend') raise NotImplementedError(f'Target {self.target} has no default backend')
# Normalise data types if not isinstance(self.backend, Backend):
raise ValueError("backend must be provided by the 'Backend' enum")
# Normalise data types
for dtype in [self.data_type, self.default_number_float, self.default_number_int]:
self._check_type(dtype)
if not isinstance(self.data_type, dict): if not isinstance(self.data_type, dict):
dt = copy(self.data_type) # The copy is necessary because BasicType has sympy shinanigans dt = copy(self.data_type) # The copy is necessary because BasicType has sympy shinanigans
self.data_type = defaultdict(lambda: BasicType(dt)) self.data_type = defaultdict(self.DataTypeFactory(dt))
if isinstance(self.data_type, dict) and not isinstance(self.data_type, defaultdict):
for dtype in self.data_type.values():
self._check_type(dtype)
dt = collate_types([BasicType(dtype) for dtype in self.data_type.values()])
dtype_dict = self.data_type
self.data_type = defaultdict(self.DataTypeFactory(dt), dtype_dict)
assert isinstance(self.data_type, defaultdict), "At this point data_type must be a defaultdict!"
for dtype in self.data_type.values():
self._check_type(dtype)
self._check_type(self.data_type.default_factory())
if self.default_number_float is None:
self.default_number_float = self.data_type.default_factory()
if not isinstance(self.default_number_float, BasicType): if not isinstance(self.default_number_float, BasicType):
self.default_number_float = BasicType(self.default_number_float) self.default_number_float = BasicType(self.default_number_float)
if not isinstance(self.default_number_int, BasicType): if not isinstance(self.default_number_int, BasicType):
......
from pystencils.cpu.cpujit import make_python_function from pystencils.cpu.cpujit import make_python_function
from pystencils.cpu.kernelcreation import add_openmp, create_indexed_kernel, create_kernel from pystencils.cpu.kernelcreation import add_openmp, create_indexed_kernel, create_kernel, add_pragmas
__all__ = ['create_kernel', 'create_indexed_kernel', 'add_openmp', 'make_python_function'] __all__ = ['create_kernel', 'create_indexed_kernel', 'add_openmp', 'add_pragmas', 'make_python_function']