from pystencils.transformations import makeLoopOverDomain, typingFromSympyInspection, \
    typeAllEquations, moveConstantsBeforeLoop, getOptimalLoopOrdering
import pystencils.ast as ast
from pystencils.backends.cbackend import CBackend, CustomSympyPrinter
from pystencils import TypedSymbol


def createKerncraftCode(listOfEquations, typeForSymbol=None, ghostLayers=None):
    """
    Generates C source code for a kernel from a list of update rules (the name suggests it is meant as
    input for the kerncraft tool).

    Loops are created according to the field accesses in the equations. The returned text consists of
    one array declaration per accessed field (extents named XS, YS, ZS, E1S, E2S, ordered by the field
    layout), one declaration per typed symbol reported by the loop body's ``undefinedSymbols``, and the
    printed statements of the loop body, with field accesses rendered as array subscripts.

    :param listOfEquations: list of sympy equations, containing accesses to :class:`pystencils.field.Field`.
           Defining the update rules of the kernel
    :param typeForSymbol: a map from symbol name to a C type specifier. If not specified all symbols are assumed to
           be of type 'double' except symbols which occur on the left hand side of equations where the
           right hand side is a sympy Boolean which are assumed to be 'bool' .
    :param ghostLayers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers
                        if None, the number of ghost layers is determined automatically and assumed to be equal for
                        all dimensions

    :return: the generated code as a string (not an AST node)
    """
    if not typeForSymbol:
        typeForSymbol = typingFromSympyInspection(listOfEquations, "double")

    fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
    allFields = fieldsRead.union(fieldsWritten)

    # The optimal ordering is only used for its length (the number of loop dimensions); the loops
    # are deliberately emitted in plain index order 0..n-1 instead.
    numLoopDimensions = len(getOptimalLoopOrdering(allFields))
    cstyleLoopOrder = list(range(numLoopDimensions))

    body = ast.Block(assignments)
    code = makeLoopOverDomain(body, "kerncraft", ghostLayers=ghostLayers, loopOrder=cstyleLoopOrder)
    moveConstantsBeforeLoop(code)
    loopBody = code.body

    printer = CBackend(sympyPrinter=ArraySympyPrinter())

    # Placeholder extent names, one per dimension; a field's layout selects their order.
    FIXED_SIZES = ("XS", "YS", "ZS", "E1S", "E2S")

    parts = []
    # allFields is a set union (and undefinedSymbols may be unordered too): sort by name so the
    # generated code is deterministic from run to run.
    for field in sorted(allFields, key=lambda f: f.name):
        suffix = "".join("[%s]" % (FIXED_SIZES[i],) for i in field.layout)
        # NOTE(review): unlike the parameters below, no element type is emitted for the arrays --
        # confirm the consumer of this code accepts untyped array declarations.
        parts.append("%s%s;\n" % (field.name, suffix))

    # add parameter definitions
    typedParameters = [s for s in loopBody.undefinedSymbols if isinstance(s, TypedSymbol)]
    for s in sorted(typedParameters, key=lambda sym: sym.name):
        parts.append("%s %s;\n" % (s.dtype, s.name))

    for element in loopBody.args:
        parts.append(printer(element))
        parts.append("\n")
    return "".join(parts)


class ArraySympyPrinter(CustomSympyPrinter):
    """Sympy printer that renders field accesses as chained array subscripts, e.g. ``f[i][j]``."""

    def _print_Access(self, fieldAccess):
        """
        Print a field access as ``<field name>[c0][c1]...``.

        The coordinate list holds, per spatial dimension, the loop counter symbol of that dimension
        plus the access offset, followed by the access's index coordinates. The subscripts are emitted
        in the order given by ``fieldAccess.field.layout``.
        """
        counterOf = ast.LoopOverCoordinate.getLoopCounterSymbol
        spatialCoordinates = [counterOf(dim) + shift for dim, shift in enumerate(fieldAccess.offsets)]
        allCoordinates = spatialCoordinates + list(fieldAccess.index)

        subscripts = []
        for position in fieldAccess.field.layout:
            subscripts.append("[%s]" % self._print(allCoordinates[position]))

        return "%s%s" % (self._print(fieldAccess.field.name), "".join(subscripts))