diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index ad339a4bdc2a2a5d696c964d602fbc6d01fcfaa9..dd75f42d52ee467166fe111f85e868bb33e934ba 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -1,3 +1,4 @@ +import re from collections import namedtuple from typing import Set @@ -24,6 +25,9 @@ except ImportError: __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] + +HEADER_REGEX = re.compile(r'^[<"].*[">]$') + KERNCRAFT_NO_TERNARY_MODE = False @@ -112,6 +116,9 @@ def get_headers(ast_node: Node) -> Set[str]: if isinstance(g, Node): headers.update(get_headers(g)) + for h in headers: + assert HEADER_REGEX.match(h), f'header /{h}/ does not follow the pattern /"..."/ or /<...>/' + return sorted(headers) diff --git a/pystencils_tests/test_helpful_errors.py b/pystencils_tests/test_helpful_errors.py new file mode 100644 index 0000000000000000000000000000000000000000..a13178afebd424c0f3e5e3355f8b0ecf97eab870 --- /dev/null +++ b/pystencils_tests/test_helpful_errors.py @@ -0,0 +1,37 @@ +""" + +""" + +import pytest + +from pystencils.astnodes import Block +from pystencils.backends.cbackend import CustomCodeNode, get_headers + + +def test_headers_have_quotes_or_brackets(): + class ErrorNode1(CustomCodeNode): + + def __init__(self): + super().__init__("", [], []) + self.headers = ["iostream"] + + class ErrorNode2(CustomCodeNode): + headers = ["<iostream>", "foo"] + + def __init__(self): + super().__init__("", [], []) + self.headers = ["<iostream>", "foo"] + + class OkNode3(CustomCodeNode): + + def __init__(self): + super().__init__("", [], []) + self.headers = ["<iostream>", '"foo"'] + + with pytest.raises(AssertionError, match='.* does not follow the pattern .*'): + get_headers(Block([ErrorNode1()])) + + with pytest.raises(AssertionError, match='.* does not follow the pattern .*'): + get_headers(ErrorNode2()) + + get_headers(OkNode3())