diff --git a/cbutil/postprocessing/sqlite.py b/cbutil/postprocessing/sqlite.py index 54522da32c9b4f9449c71ccdf2a50f196bf7657c..228c72a16ecbf0f7695686931435e1bca8aabf17 100644 --- a/cbutil/postprocessing/sqlite.py +++ b/cbutil/postprocessing/sqlite.py @@ -2,31 +2,11 @@ import sqlite3 import logging from contextlib import contextmanager import os +import cbutil.postprocessing.sqlite_helper as sh logger = logging.getLogger(__file__) -def select_stmt(arg="*"): - return f"SELECT {arg}" - - -def from_stmt(table_name): - return f"FROM {table_name}" - - -def where_stmt(lhs, rhs, op="="): - return f"WHERE {lhs}{op}{rhs}" - - -def join_stmt(lhs, rhs, key): - return f"{lhs} inner join {rhs} on {lhs}.{key} = {rhs}.{key}" - - -def table_name_query(): - where = where_stmt('type', "'table'") - return f"{select_stmt('name')} {from_stmt('sqlite_master')} {where}" - - @contextmanager def sqlite_context(path): if not os.path.exists(path): @@ -43,15 +23,15 @@ def sqlite_context(path): def get_all_table_names(connection): - return [t[0] for t in connection.execute(table_name_query()).fetchall()] + return [t[0] for t in connection.execute(sh.table_name_query()).fetchall()] def query_complete_table(connection, table_name: str): - return connection.execute(f"{select_stmt()} {from_stmt(table_name)}") + return connection.execute(f"{sh.select_stmt()} {sh.from_stmt(table_name)}") def query_join(connection, lhs, rhs, key): - return connection.execute(f"SELECT * FROM {lhs} inner join {rhs} on {lhs}.{key} = {rhs}.{key}") + return connection.execute(f"{sh.select_stmt()} {sh.from_stmt(lhs)} {sh.join_stmt(lhs, rhs, key)}") def tables2dict(table_iterator): diff --git a/cbutil/postprocessing/sqlite_helper.py b/cbutil/postprocessing/sqlite_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..df9cbf2d557aad85fafdee2b24e6320924c11187 --- /dev/null +++ b/cbutil/postprocessing/sqlite_helper.py @@ -0,0 +1,40 @@ + +def strip_n_check(arg): + if not isinstance(arg, str): + return str(arg) + stripped = arg.strip() + if not stripped: + raise ValueError("Empty arg") + return stripped + + +def select_stmt(arg: str = "*") -> str: + return f"SELECT {strip_n_check(arg)}" + + +def from_stmt(table_name: str) -> str: + return f"FROM {strip_n_check(table_name)}" + + +def where_stmt(lhs: str, rhs: str, op="=") -> str: + return f"WHERE {strip_n_check(lhs)}{strip_n_check(op)}{strip_n_check(rhs)}" + + +def join_stmt(lhs: str, rhs: str, key: str) -> str: + lhs = strip_n_check(lhs) + rhs = strip_n_check(rhs) + key = strip_n_check(key) + return f"{lhs} inner join {rhs} on {lhs}.{key} = {rhs}.{key}" + + +def table_name_query() -> str: + return query_builder(select="name", + from_table="sqlite_master", + where_args=['type', "'table'"]) + + +def query_builder(*, select: str = "*", from_table: str, where_args=[]) -> str: + ret = f"{select_stmt(select)} {from_stmt(from_table)}" + if where_args: + ret += f" {where_stmt(*where_args)}" + return ret diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py index 135d8e075635b4ed656d4b38c5f65dbab888bc58..01dd08ec68ea76cb744ecad39b75bf2212938bcf 100644 --- a/tests/test_sqlite.py +++ b/tests/test_sqlite.py @@ -1,22 +1,52 @@ -from cbutil.postprocessing.sqlite import (select_stmt, from_stmt, where_stmt, join_stmt, table_name_query) - -from cbutil.postprocessing.sqlite import get_all_table_names, tables2dict, query_join, iterate_all_tables +import pytest +from cbutil.postprocessing.sqlite_helper import (select_stmt, + from_stmt, + where_stmt, + join_stmt, + table_name_query, + query_builder) + +from cbutil.postprocessing.sqlite import (get_all_table_names, + tables2dict, + query_join, + iterate_all_tables) from cbutil.postprocessing.sqlite import sqlite_context def test_select(): assert select_stmt().strip() == "SELECT *" + assert select_stmt("test").strip() == "SELECT test" + with pytest.raises(ValueError): + select_stmt("") + with pytest.raises(ValueError): + select_stmt(" ") def test_from(): assert from_stmt("table").strip() == "FROM table" + with pytest.raises(ValueError): + from_stmt("") + with pytest.raises(ValueError): + from_stmt(" ") def test_where(): lhs = "lhs" rhs = "rhs" assert where_stmt(lhs, rhs).strip() == f"WHERE {lhs}={rhs}" + with pytest.raises(ValueError): + where_stmt("", lhs) + with pytest.raises(ValueError): + where_stmt(" ", lhs) + with pytest.raises(ValueError): + where_stmt(rhs, "") + with pytest.raises(ValueError): + where_stmt(rhs, " ") + with pytest.raises(ValueError): + where_stmt(rhs, lhs, "") + with pytest.raises(ValueError): + where_stmt(rhs, lhs, " ") def test_join(): @@ -30,6 +60,11 @@ def test_table_name(): assert table_name_query() == "SELECT name FROM sqlite_master WHERE type='table'" +def test_builder(): + assert query_builder(from_table="table") == "SELECT * FROM table" + assert query_builder(select="row", from_table="table") == "SELECT row FROM table" + + def test_all_table_names(): with sqlite_context("tests/benchmark.sqlite") as connection: names = get_all_table_names(connection)