From b962a099b28fcd106a9b7dc1ee87665e2f03b079 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Fri, 24 Jan 2020 10:29:45 +0100 Subject: [PATCH] Add assertion that headers follow the pattern /"..."/ or /<...>/ --- pystencils/backends/cbackend.py | 7 +++++ pystencils_tests/test_helpful_errors.py | 37 +++++++++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 pystencils_tests/test_helpful_errors.py diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index ad339a4bd..dd75f42d5 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -1,3 +1,4 @@ +import re from collections import namedtuple from typing import Set @@ -24,6 +25,9 @@ except ImportError: __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] + +HEADER_REGEX = re.compile(r'^[<"].*[">]$') + KERNCRAFT_NO_TERNARY_MODE = False @@ -112,6 +116,9 @@ def get_headers(ast_node: Node) -> Set[str]: if isinstance(g, Node): headers.update(get_headers(g)) + for h in headers: + assert HEADER_REGEX.match(h), f'header /{h}/ does not follow the pattern /"..."/ or /<...>/' + return sorted(headers) diff --git a/pystencils_tests/test_helpful_errors.py b/pystencils_tests/test_helpful_errors.py new file mode 100644 index 000000000..a13178afe --- /dev/null +++ b/pystencils_tests/test_helpful_errors.py @@ -0,0 +1,37 @@ +""" + +""" + +import pytest + +from pystencils.astnodes import Block +from pystencils.backends.cbackend import CustomCodeNode, get_headers + + +def test_headers_have_quotes_or_brackets(): + class ErrorNode1(CustomCodeNode): + + def __init__(self): + super().__init__("", [], []) + self.headers = ["iostream"] + + class ErrorNode2(CustomCodeNode): + headers = ["<iostream>", "foo"] + + def __init__(self): + super().__init__("", [], []) + self.headers = ["<iostream>", "foo"] + + class OkNode3(CustomCodeNode): + + def __init__(self): + super().__init__("", [], []) + self.headers = ["<iostream>", '"foo"'] + + with pytest.raises(AssertionError, match='.* does not follow the pattern .*'): + get_headers(Block([ErrorNode1()])) + + with pytest.raises(AssertionError, match='.* does not follow the pattern .*'): + get_headers(ErrorNode2()) + + get_headers(OkNode3()) -- GitLab