Commit dced5877 authored by Markus Holzer's avatar Markus Holzer
Browse files

Add type conversion for SP types

parent 24dde405
Pipeline #31124 failed with stage
in 16 minutes and 30 seconds
......@@ -12,7 +12,8 @@ from pystencils.gpucuda.indexing import indexing_creator_from_params
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.stencil import direction_string_to_offset, inverse_direction_string
from pystencils.transformations import (
loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel)
loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel,
replace_data_type_of_typed_symbols)
def create_kernel(assignments,
......@@ -88,6 +89,8 @@ def create_kernel(assignments,
split_groups = assignments.simplification_hints['split_groups']
assignments = assignments.all_assignments
assignments = replace_data_type_of_typed_symbols(assignments, data_type)
# ---- Creating ast
if target == 'cpu':
from pystencils.cpu import create_kernel
......
......@@ -1101,6 +1101,15 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, i
move_constants_before_loop(function_node.body)
cleanup_blocks(function_node.body)
def replace_data_type_of_typed_symbols(assignments, data_type):
"""changes the data types of the lhs of assignments which are already specified as TypedSymbol. This is needed
if the Assignments are already typed to double but the kernel is created for single precision"""
for i, assignment in enumerate(assignments):
if type(assignment.lhs) is TypedSymbol and assignment.lhs.dtype != data_type:
assignments[i] = Assignment(TypedSymbol(assignments[i].lhs.name, data_type), assignments[i].rhs)
return assignments
# --------------------------------------- Helper Functions -------------------------------------------------------------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment