"""Tests for the SQLite query helpers in cbutil.postprocessing."""
from datetime import datetime
import pytest

from cbutil.postprocessing.sqlite import (build_iterate_query,
                                          get_all_table_names,
                                          iterate_all_tables, query_join,
                                          sqlite_context, tables2dict)
from cbutil.postprocessing.sqlite_helper import (from_stmt, join_stmt,
                                                 query_builder, select_stmt,
                                                 table_name_query, where_stmt)
from cbutil.util import time_conversion


def test_select():
    """select_stmt renders a SELECT clause and rejects blank column lists."""
    # With no argument every column is selected.
    assert select_stmt().strip() == "SELECT *"
    assert select_stmt("test").strip() == "SELECT test"
    for blank in ("", "     "):
        with pytest.raises(ValueError):
            select_stmt(blank)


def test_from():
    """from_stmt renders a FROM clause and rejects blank table names."""
    expected = "FROM table"
    assert from_stmt("table").strip() == expected
    for blank_name in ("", "    "):
        with pytest.raises(ValueError):
            from_stmt(blank_name)


def test_where():
    """where_stmt renders ``WHERE lhs=rhs`` and rejects any blank argument.

    Each argument is blanked in its own position, so a failure points at
    the specific validation that is missing.  The optional third argument
    appears to be the comparison operator (``query_builder`` is exercised
    with ``where_args=["id", "0", "<"]``). NOTE(review): confirm.
    """
    lhs = "lhs"
    rhs = "rhs"
    assert where_stmt(lhs, rhs).strip() == f"WHERE {lhs}={rhs}"
    for blank in ("", "    "):
        # Blank left-hand side.
        with pytest.raises(ValueError):
            where_stmt(blank, rhs)
        # Blank right-hand side.
        with pytest.raises(ValueError):
            where_stmt(lhs, blank)
        # Blank third (operator) argument.
        with pytest.raises(ValueError):
            where_stmt(lhs, rhs, blank)


def test_join():
    """join_stmt renders an inner join of two tables on a shared key column."""
    lhs, rhs, key = "lhs", "rhs", "key"
    expected = f"{lhs} inner join {rhs} on {lhs}.{key} = {rhs}.{key}"
    assert join_stmt(lhs, rhs, key).strip() == expected


def test_table_name():
    """table_name_query returns the sqlite_master lookup for table names."""
    expected = "SELECT name FROM sqlite_master WHERE type='table'"
    assert table_name_query() == expected


def test_builder():
    """query_builder composes SELECT, FROM and WHERE into one statement."""
    # Without an explicit select, every column is selected.
    assert query_builder(from_table="table") == "SELECT * FROM table"
    assert query_builder(select="row", from_table="table") == "SELECT row FROM table"
    filtered = query_builder(select="row", from_table="table",
                             where_args=["id", "0", "<"])
    assert filtered == "SELECT row FROM table WHERE id<0"


def test_all_table_names():
    """The benchmark fixture database contains exactly the expected tables."""
    with sqlite_context("tests/benchmark.sqlite") as connection:
        expected = sorted(["runs", "timingPool"])
        actual = sorted(get_all_table_names(connection))
        assert expected == actual


def test_join_query():
    """Joining runs with timingPool on runId yields 9 rows from the fixture."""
    with sqlite_context("tests/benchmark.sqlite") as connection:
        joined = query_join(connection, "runs", "timingPool", "runId")
        rows = list(tables2dict(joined))
        assert len(rows) == 9


def test_iterate_query():
    """iterate_all_tables over the CPU benchmark file yields 150 dictionaries."""
    record_iter = tables2dict(iterate_all_tables("tests/cpu_benchmark.sqlite3"))
    records = list(record_iter)
    assert len(records) == 150


def test_build_iterate():
    """build_iterate_query filtered on runId=1 runs and yields results.

    Merely draining the iterator would pass even if the WHERE clause matched
    nothing.  Materialising the results lets the test check that the filter
    actually selects rows.
    """
    with sqlite_context("tests/benchmark.sqlite") as connection:
        results = list(build_iterate_query(connection, from_table="timingPool",
                                           where_args=["runId", 1]))
        # NOTE(review): assumes benchmark.sqlite has timingPool rows with
        # runId=1 (the runs/timingPool join is known to return 9 rows).
        assert results, "expected at least one timingPool row with runId=1"


def test_time_stamp():
    """Run timestamps survive a round trip through time_conversion.

    Each joined row's ``timestamp`` string is converted with
    ``time_conversion(..., pattern=pattern)``.  The result is formatted back
    via ``datetime.fromtimestamp(...).strftime(pattern)`` and must equal the
    original string.
    """
    pattern = "%Y-%m-%d %H:%M:%S"
    with sqlite_context("tests/benchmark.sqlite") as connection:
        runs = list(tables2dict(query_join(connection, "runs", "timingPool", "runId")))
        # Guard against a vacuous pass: an empty join would skip every check.
        assert runs, "expected joined runs/timingPool rows in the fixture"
        for run in runs:
            ts = time_conversion(run["timestamp"], pattern=pattern)
            # fromtimestamp interprets ts as local time. NOTE(review): this
            # assumes time_conversion also works in local time; confirm.
            round_tripped = datetime.fromtimestamp(ts).strftime(pattern)
            assert run["timestamp"] == round_tripped