Skip to content
Snippets Groups Projects
conditional.py 2.61 KiB
Newer Older
Frederik Hennig's avatar
Frederik Hennig committed
from __future__ import annotations
Frederik Hennig's avatar
Frederik Hennig committed
from typing import TYPE_CHECKING, Optional, cast
Frederik Hennig's avatar
Frederik Hennig committed

from pystencils.typing import TypedSymbol, BasicType

Frederik Hennig's avatar
Frederik Hennig committed
from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf
from ..source_concepts.source_objects import TypedSymbolOrObject

Frederik Hennig's avatar
Frederik Hennig committed
if TYPE_CHECKING:
    from ..context import SfgContext


class SfgCondition(SfgCallTreeLeaf):
    """Base class for call-tree leaves that render a condition expression.

    Subclasses in this module provide ``required_parameters`` and
    ``get_code`` to describe the symbols they use and the expression text.
    """

Frederik Hennig's avatar
Frederik Hennig committed
class SfgCustomCondition(SfgCondition):
    """A condition given as a raw expression string.

    The stored text is emitted verbatim by ``get_code``; no symbols are
    extracted from it, so ``required_parameters`` is always empty.
    """

    def __init__(self, cond_text: str):
        super().__init__()
        self._cond_text = cond_text

    @property
    def required_parameters(self) -> set[TypedSymbolOrObject]:
        # Free-form text does not declare which symbols it depends on.
        no_params: set[TypedSymbolOrObject] = set()
        return no_params

    def get_code(self, ctx: SfgContext) -> str:
        text = self._cond_text
        return text
Frederik Hennig's avatar
Frederik Hennig committed

class IntEven(SfgCondition):
    """Condition that holds when an integer-typed symbol is even.

    Raises:
        ValueError: if ``symbol.dtype`` is not an integer ``BasicType``.
    """

    def __init__(self, symbol: TypedSymbol):
        super().__init__()
        dtype = symbol.dtype
        # ``is_int`` is only consulted once we know this is a BasicType.
        has_int_type = isinstance(dtype, BasicType) and dtype.is_int()
        if not has_int_type:
            raise ValueError(f"Symbol {symbol} does not have integer type.")

        self._symbol = symbol

    @property
    def required_parameters(self) -> set[TypedSymbolOrObject]:
        params = {self._symbol}
        return params

    def get_code(self, ctx: SfgContext) -> str:
        # Lowest bit cleared means even; XOR with 1 turns that into a truthy value.
        name = self._symbol.name
        return "((" + name + " & 1) ^ 1)"


class IntOdd(SfgCondition):
    """Condition that holds when an integer-typed symbol is odd.

    Raises:
        ValueError: if ``symbol.dtype`` is not an integer ``BasicType``.
    """

    def __init__(self, symbol: TypedSymbol):
        super().__init__()
        dtype = symbol.dtype
        # ``is_int`` is only consulted once we know this is a BasicType.
        has_int_type = isinstance(dtype, BasicType) and dtype.is_int()
        if not has_int_type:
            raise ValueError(f"Symbol {symbol} does not have integer type.")

        self._symbol = symbol

    @property
    def required_parameters(self) -> set[TypedSymbolOrObject]:
        params = {self._symbol}
        return params

    def get_code(self, ctx: SfgContext) -> str:
        # The lowest bit is set exactly for odd values.
        name = self._symbol.name
        return "(" + name + " & 1)"
Frederik Hennig's avatar
Frederik Hennig committed


class SfgBranch(SfgCallTreeNode):
    """An ``if`` / optional ``else`` construct in the call tree.

    Children are stored as ``(condition, branch_true)`` or
    ``(condition, branch_true, branch_false)``; the else-branch is optional.
    """

    def __init__(self,
                 cond: SfgCondition,
                 branch_true: SfgCallTreeNode,
                 branch_false: Optional[SfgCallTreeNode] = None):
        # Compare against None explicitly: a node's truthiness (e.g. via a
        # __len__ over its children) must not decide whether the else-branch
        # is kept.
        if branch_false is None:
            super().__init__(cond, branch_true)
        else:
            super().__init__(cond, branch_true, branch_false)

    @property
    def condition(self) -> SfgCondition:
        """The condition node (first child)."""
        return cast(SfgCondition, self._children[0])

    @property
    def branch_true(self) -> SfgCallTreeNode:
        """Subtree emitted inside the ``if`` body (second child)."""
        return self._children[1]

    @property
    def branch_false(self) -> Optional[SfgCallTreeNode]:
        """Subtree emitted inside the ``else`` body, or ``None`` if absent."""
        # The else-branch is optional; indexing [2] unconditionally would raise
        # IndexError for branches constructed without one.
        return self._children[2] if len(self._children) > 2 else None

    def get_code(self, ctx: SfgContext) -> str:
        """Render ``if(<cond>) { ... }`` plus ``else { ... }`` when present.

        Branch bodies are indented via ``ctx.codestyle.indent``.
        """
        code = f"if({self.condition.get_code(ctx)}) {{\n"
        code += ctx.codestyle.indent(self.branch_true.get_code(ctx))
        code += "\n}"

        branch_false = self.branch_false
        if branch_false is not None:
            # Leading space so the output reads "} else {" rather than "}else {".
            code += " else {\n"
            code += ctx.codestyle.indent(branch_false.get_code(ctx))
            code += "\n}"

        return code