Commit 01ab38e8 authored by Martin Bauer's avatar Martin Bauer
Browse files

Removed python3.6 only constructs for python3.5 compatibility

parent cc52a9f3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from sympy.codegen.ast import Assignment import sympy as sp
from sympy.printing.latex import LatexPrinter from sympy.printing.latex import LatexPrinter
try:
from sympy.codegen.ast import Assignment
except ImportError:
Assignment = None
__all__ = ['Assignment'] __all__ = ['Assignment']
...@@ -9,12 +14,37 @@ def print_assignment_latex(printer, expr): ...@@ -9,12 +14,37 @@ def print_assignment_latex(printer, expr):
"""sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer""" """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer"""
printed_lhs = printer.doprint(expr.lhs) printed_lhs = printer.doprint(expr.lhs)
printed_rhs = printer.doprint(expr.rhs) printed_rhs = printer.doprint(expr.rhs)
return f"{printed_lhs} \leftarrow {printed_rhs}" return "{printed_lhs} \leftarrow {printed_rhs}".format(printed_lhs=printed_lhs, printed_rhs=printed_rhs)
def assignment_str(assignment): def assignment_str(assignment):
return f"{assignment.lhs}{assignment.rhs}" return "{lhs} ← {rhs}".format(lhs=assignment.lhs, rhs=assignment.rhs)
if Assignment:
Assignment.__str__ = assignment_str
LatexPrinter._print_Assignment = print_assignment_latex
else:
# back port for older sympy versions that don't have Assignment yet
class Assignment(sp.Rel):
rel_op = ':='
__slots__ = []
def __new__(cls, lhs, rhs=0, **assumptions):
from sympy.matrices.expressions.matexpr import (
MatrixElement, MatrixSymbol)
from sympy.tensor.indexed import Indexed
lhs = sp.sympify(lhs)
rhs = sp.sympify(rhs)
# Tuple of things that can be on the lhs of an assignment
assignable = (sp.Symbol, MatrixSymbol, MatrixElement, Indexed)
if not isinstance(lhs, assignable):
raise TypeError("Cannot assign to lhs of type %s." % type(lhs))
return sp.Rel.__new__(cls, lhs, rhs, **assumptions)
Assignment.__str__ = assignment_str __str__ = assignment_str
LatexPrinter._print_Assignment = print_assignment_latex _print_Assignment = print_assignment_latex
...@@ -494,7 +494,7 @@ class SympyAssignment(Node): ...@@ -494,7 +494,7 @@ class SympyAssignment(Node):
def _repr_html_(self): def _repr_html_(self):
printed_lhs = sp.latex(self.lhs) printed_lhs = sp.latex(self.lhs)
printed_rhs = sp.latex(self.rhs) printed_rhs = sp.latex(self.rhs)
return f"${printed_lhs} \leftarrow {printed_rhs}$" return "${printed_lhs} \leftarrow {printed_rhs}$".format(printed_lhs=printed_lhs, printed_rhs=printed_rhs)
class ResolvedFieldAccess(sp.Indexed): class ResolvedFieldAccess(sp.Indexed):
......
...@@ -51,13 +51,13 @@ class FlagInterface: ...@@ -51,13 +51,13 @@ class FlagInterface:
self._used_flags.add(flag) self._used_flags.add(flag)
assert self._is_power_of_2(flag) assert self._is_power_of_2(flag)
return flag return flag
raise ValueError(f"All available {self.max_bits} flags are reserved") raise ValueError("All available {} flags are reserved".format(self.max_bits))
def reserve_flag(self, flag): def reserve_flag(self, flag):
assert self._is_power_of_2(flag) assert self._is_power_of_2(flag)
flag = self.dtype(flag) flag = self.dtype(flag)
if flag in self._used_flags: if flag in self._used_flags:
raise ValueError(f"The flag {flag} is already reserved") raise ValueError("The flag {flag} is already reserved".format(flag=flag))
self._used_flags.add(flag) self._used_flags.add(flag)
return flag return flag
......
...@@ -43,7 +43,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width=4): ...@@ -43,7 +43,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width=4):
successful = False successful = False
break break
typed_symbol = base.label typed_symbol = base.label
assert type(typed_symbol.dtype) is PointerType, f"Type of access is {typed_symbol.dtype}, {indexed}" assert type(typed_symbol.dtype) is PointerType, \
"Type of access is {}, {}".format(typed_symbol.dtype, indexed)
substitutions[indexed] = cast_func(indexed, VectorType(typed_symbol.dtype.base_type, vector_width)) substitutions[indexed] = cast_func(indexed, VectorType(typed_symbol.dtype.base_type, vector_width))
if not successful: if not successful:
warnings.warn("Could not vectorize loop because of non-consecutive memory access") warnings.warn("Could not vectorize loop because of non-consecutive memory access")
......
...@@ -27,7 +27,7 @@ def highlight_cpp(code: str): ...@@ -27,7 +27,7 @@ def highlight_cpp(code: str):
from pygments.lexers import CppLexer from pygments.lexers import CppLexer
css = HtmlFormatter().get_style_defs('.highlight') css = HtmlFormatter().get_style_defs('.highlight')
css_tag = f"<style>{css}</style>" css_tag = "<style>{css}</style>".format(css=css)
display(HTML(css_tag)) display(HTML(css_tag))
return HTML(highlight(code, CppLexer(), HtmlFormatter())) return HTML(highlight(code, CppLexer(), HtmlFormatter()))
......
...@@ -269,7 +269,7 @@ class Field: ...@@ -269,7 +269,7 @@ class Field:
self._layout = normalize_layout(layout) self._layout = normalize_layout(layout)
self.shape = shape self.shape = shape
self.strides = strides self.strides = strides
self.latex_name: Optional[str] = None self.latex_name = None # type: Optional[str]
def new_field_with_different_name(self, new_name): def new_field_with_different_name(self, new_name):
return Field(new_name, self.field_type, self._dtype, self._layout, self.shape, self.strides) return Field(new_name, self.field_type, self._dtype, self._layout, self.shape, self.strides)
......
...@@ -97,8 +97,8 @@ class BlockIndexing(AbstractIndexing): ...@@ -97,8 +97,8 @@ class BlockIndexing(AbstractIndexing):
_get_end_from_slice(self._iterationSlice, arr_shape))] _get_end_from_slice(self._iterationSlice, arr_shape))]
widths = sp.Matrix(widths).subs(substitution_dict) widths = sp.Matrix(widths).subs(substitution_dict)
grid: Tuple[int, ...] = tuple(sp.ceiling(length / block_size) grid = tuple(sp.ceiling(length / block_size)
for length, block_size in zip(widths, self._blockSize)) for length, block_size in zip(widths, self._blockSize)) # type: : Tuple[int, ...]
extend_bs = (1,) * (3 - len(self._blockSize)) extend_bs = (1,) * (3 - len(self._blockSize))
extend_gr = (1,) * (3 - len(grid)) extend_gr = (1,) * (3 - len(grid))
......
...@@ -36,11 +36,13 @@ def isl_iteration_set(node: ast.Node): ...@@ -36,11 +36,13 @@ def isl_iteration_set(node: ast.Node):
loop_start_str = remove_brackets(str(loop.start)) loop_start_str = remove_brackets(str(loop.start))
loop_stop_str = remove_brackets(str(loop.stop)) loop_stop_str = remove_brackets(str(loop.stop))
ctr_name = loop.loop_counter_name ctr_name = loop.loop_counter_name
conditions.append(remove_brackets(f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}")) set_string_description = "{} >= {} and {} < {}".format(ctr_name, loop_start_str, ctr_name, loop_stop_str)
conditions.append(remove_brackets(set_string_description))
symbol_names = ','.join(degrees_of_freedom) symbol_names = ','.join(degrees_of_freedom)
condition_str = ' and '.join(conditions) condition_str = ' and '.join(conditions)
set_description = f"{{ [{symbol_names}] : {condition_str} }}" set_description = "{{ [{symbol_names}] : {condition_str} }}".format(symbol_names=symbol_names,
condition_str=condition_str)
return degrees_of_freedom, isl.BasicSet(set_description) return degrees_of_freedom, isl.BasicSet(set_description)
...@@ -51,7 +53,8 @@ def simplify_loop_counter_dependent_conditional(conditional): ...@@ -51,7 +53,8 @@ def simplify_loop_counter_dependent_conditional(conditional):
if dofs_in_condition.issubset(dofs_in_loops): if dofs_in_condition.issubset(dofs_in_loops):
symbol_names = ','.join(dofs_in_loops) symbol_names = ','.join(dofs_in_loops)
condition_str = remove_brackets(str(conditional.condition_expr)) condition_str = remove_brackets(str(conditional.condition_expr))
condition_set = isl.BasicSet(f"{{ [{symbol_names}] : {condition_str} }}") condition_set = isl.BasicSet("{{ [{symbol_names}] : {condition_str} }}".format(symbol_names=symbol_names,
condition_str=condition_str))
if condition_set.is_empty(): if condition_set.is_empty():
conditional.replace_by_false_block() conditional.replace_by_false_block()
......
...@@ -326,10 +326,10 @@ class AssignmentCollection: ...@@ -326,10 +326,10 @@ class AssignmentCollection:
def __str__(self): def __str__(self):
result = "Subexpressions:\n" result = "Subexpressions:\n"
for eq in self.subexpressions: for eq in self.subexpressions:
result += f"\t{eq}\n" result += "\t{eq}\n".format(eq=eq)
result += "Main Assignments:\n" result += "Main Assignments:\n"
for eq in self.main_assignments: for eq in self.main_assignments:
result += f"\t{eq}\n" result += "\t{eq}\n".format(eq=eq)
return result return result
...@@ -343,6 +343,6 @@ class SymbolGen: ...@@ -343,6 +343,6 @@ class SymbolGen:
return self return self
def __next__(self): def __next__(self):
name = f"{self._symbol}_{self._ctr}" name = "{}_{}".format(self._symbol, self._ctr)
self._ctr += 1 self._ctr += 1
return sp.Symbol(name) return sp.Symbol(name)
...@@ -713,12 +713,12 @@ class KernelConstraintsCheck: ...@@ -713,12 +713,12 @@ class KernelConstraintsCheck:
fai = self.FieldAndIndex(lhs.field, lhs.index) fai = self.FieldAndIndex(lhs.field, lhs.index)
self._field_writes[fai].add(lhs.offsets) self._field_writes[fai].add(lhs.offsets)
if len(self._field_writes[fai]) > 1: if len(self._field_writes[fai]) > 1:
raise ValueError(f"Field {lhs.field.name} is written at two different locations") raise ValueError("Field {} is written at two different locations".format(lhs.field.name))
elif isinstance(lhs, sp.Symbol): elif isinstance(lhs, sp.Symbol):
if lhs in self._defined_pure_symbols: if lhs in self._defined_pure_symbols:
raise ValueError(f"Assignments not in SSA form, multiple assignments to {lhs.name}") raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name))
if lhs in self._accessed_pure_symbols: if lhs in self._accessed_pure_symbols:
raise ValueError(f"Symbol {lhs.name} is written, after it has been read") raise ValueError("Symbol {} is written, after it has been read".format(lhs.name))
self._defined_pure_symbols.add(lhs) self._defined_pure_symbols.add(lhs)
def _update_accesses_rhs(self, rhs): def _update_accesses_rhs(self, rhs):
...@@ -727,8 +727,8 @@ class KernelConstraintsCheck: ...@@ -727,8 +727,8 @@ class KernelConstraintsCheck:
for write_offset in writes: for write_offset in writes:
assert len(writes) == 1 assert len(writes) == 1
if write_offset != rhs.offsets: if write_offset != rhs.offsets:
raise ValueError(f"Violation of loop independence condition. " raise ValueError("Violation of loop independence condition. Field "
f"Field {rhs.field} is read at {rhs.offsets} and written at {write_offset}") "{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset))
self.fields_read.add(rhs.field) self.fields_read.add(rhs.field)
elif isinstance(rhs, sp.Symbol): elif isinstance(rhs, sp.Symbol):
self._accessed_pure_symbols.add(rhs) self._accessed_pure_symbols.add(rhs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment