Skip to content
Snippets Groups Projects
Commit 5c849686 authored by Christoph Alt's avatar Christoph Alt
Browse files

moved sqlite helper to extra file

parent 332bbae9
No related merge requests found
......@@ -2,31 +2,11 @@ import sqlite3
import logging
from contextlib import contextmanager
import os
import cbutil.postprocessing.sqlite_helper as sh
logger = logging.getLogger(__file__)
def select_stmt(arg="*"):
return f"SELECT {arg}"
def from_stmt(table_name):
return f"FROM {table_name}"
def where_stmt(lhs, rhs, op="="):
return f"WHERE {lhs}{op}{rhs}"
def join_stmt(lhs, rhs, key):
return f"{lhs} inner join {rhs} on {lhs}.{key} = {rhs}.{key}"
def table_name_query():
where = where_stmt('type', "'table'")
return f"{select_stmt('name')} {from_stmt('sqlite_master')} {where}"
@contextmanager
def sqlite_context(path):
if not os.path.exists(path):
......@@ -43,15 +23,15 @@ def sqlite_context(path):
def get_all_table_names(connection):
return [t[0] for t in connection.execute(table_name_query()).fetchall()]
return [t[0] for t in connection.execute(sh.table_name_query()).fetchall()]
def query_complete_table(connection, table_name: str):
return connection.execute(f"{select_stmt()} {from_stmt(table_name)}")
return connection.execute(f"{sh.select_stmt()} {sh.from_stmt(table_name)}")
def query_join(connection, lhs, rhs, key):
return connection.execute(f"SELECT * FROM {lhs} inner join {rhs} on {lhs}.{key} = {rhs}.{key}")
return connection.execute(f"{sh.select_stmt()} {sh.from_stmt(lhs)} {sh.join_stmt(lhs, rhs, key)}")
def tables2dict(table_iterator):
......
def strip_n_check(arg):
if not isinstance(arg, str):
return str(arg)
stripped = arg.strip()
if not stripped:
raise ValueError("Empty arg")
return stripped
def select_stmt(arg: str = "*") -> str:
return f"SELECT {strip_n_check(arg)}"
def from_stmt(table_name: str) -> str:
return f"FROM {strip_n_check(table_name)}"
def where_stmt(lhs: str, rhs: str, op="=") -> str:
return f"WHERE {strip_n_check(lhs)}{strip_n_check(op)}{strip_n_check(rhs)}"
def join_stmt(lhs: str, rhs: str, key: str) -> str:
lhs = strip_n_check(lhs)
rhs = strip_n_check(rhs)
key = strip_n_check(key)
return f"{lhs} inner join {rhs} on {lhs}.{key} = {rhs}.{key}"
def table_name_query() -> str:
return query_builder(select="name",
from_table="sqlite_master",
where_args=['type', "'table'"])
def query_builder(*, select: str = "*", from_table: str, where_args=[]) -> str:
ret = f"{select_stmt(select)} {from_stmt(from_table)}"
if where_args:
ret += f" {where_stmt(*where_args)}"
return ret
from cbutil.postprocessing.sqlite import (select_stmt, from_stmt, where_stmt, join_stmt, table_name_query)
from cbutil.postprocessing.sqlite import get_all_table_names, tables2dict, query_join, iterate_all_tables
import pytest
from cbutil.postprocessing.sqlite_helper import (select_stmt,
from_stmt,
where_stmt,
join_stmt,
table_name_query,
query_builder)
from cbutil.postprocessing.sqlite import (get_all_table_names,
tables2dict,
query_join,
iterate_all_tables)
from cbutil.postprocessing.sqlite import sqlite_context
def test_select():
assert select_stmt().strip() == "SELECT *"
assert select_stmt("test").strip() == "SELECT test"
with pytest.raises(ValueError):
select_stmt("")
with pytest.raises(ValueError):
select_stmt(" ")
def test_from():
assert from_stmt("table").strip() == "FROM table"
with pytest.raises(ValueError):
from_stmt("")
with pytest.raises(ValueError):
from_stmt(" ")
def test_where():
lhs = "lhs"
rhs = "rhs"
assert where_stmt(lhs, rhs).strip() == f"WHERE {lhs}={rhs}"
with pytest.raises(ValueError):
where_stmt("", lhs)
with pytest.raises(ValueError):
where_stmt(" ", lhs)
with pytest.raises(ValueError):
where_stmt(rhs, "")
with pytest.raises(ValueError):
where_stmt(rhs, " ")
with pytest.raises(ValueError):
where_stmt(rhs, lhs, "")
with pytest.raises(ValueError):
where_stmt(rhs, lhs, " ")
def test_join():
......@@ -30,6 +60,11 @@ def test_table_name():
assert table_name_query() == "SELECT name FROM sqlite_master WHERE type='table'"
def test_builder():
assert query_builder(from_table="table") == "SELECT * FROM table"
assert query_builder(select="row", from_table="table") == "SELECT row FROM table"
def test_all_table_names():
with sqlite_context("tests/benchmark.sqlite") as connection:
names = get_all_table_names(connection)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment