From 9f0b36ea830e9c4e5afc61f3203b2afea1f20262 Mon Sep 17 00:00:00 2001 From: Christoph Alt <christoph.alt@fau.de> Date: Mon, 1 Aug 2022 15:03:34 +0200 Subject: [PATCH] added function to build and iterate query --- cbutil/postprocessing/__init__.py | 2 +- cbutil/postprocessing/sqlite.py | 10 ++++++++++ cbutil/postprocessing/sqlite_helper.py | 2 +- tests/test_sqlite.py | 15 +++++++++++++-- 4 files changed, 25 insertions(+), 4 deletions(-) diff --git a/cbutil/postprocessing/__init__.py b/cbutil/postprocessing/__init__.py index 3af9ad9..5aa1264 100644 --- a/cbutil/postprocessing/__init__.py +++ b/cbutil/postprocessing/__init__.py @@ -1,3 +1,3 @@ from .plain_text import process_linewise -from .sqlite import sqlite_context, query_complete_table +from .sqlite import sqlite_context, query_complete_table, build_iterate_query, iterate_all_tables from .sqlite_helper import query_builder diff --git a/cbutil/postprocessing/sqlite.py b/cbutil/postprocessing/sqlite.py index 228c72a..b4c3ccb 100644 --- a/cbutil/postprocessing/sqlite.py +++ b/cbutil/postprocessing/sqlite.py @@ -48,3 +48,13 @@ def iterate_all_tables(path): def iterate_join(path, lhs, rhs, key): with sqlite_context(path) as connection: yield from tables2dict(query_join(connection, lhs, rhs, key)) + + +def iterate_query(connection, query): + for result in connection.execute(query).fetchall(): + yield result + + +def build_iterate_query(connection, *args, **kwargs): + query = sh.query_builder(*args, **kwargs) + yield from iterate_query(connection, query) diff --git a/cbutil/postprocessing/sqlite_helper.py b/cbutil/postprocessing/sqlite_helper.py index df9cbf2..5dbee65 100644 --- a/cbutil/postprocessing/sqlite_helper.py +++ b/cbutil/postprocessing/sqlite_helper.py @@ -1,7 +1,7 @@ def strip_n_check(arg): if not isinstance(arg, str): - return str(arg) + return str(arg) stripped = arg.strip() if not stripped: raise ValueError("Empty arg") diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py index 01dd08e..da50b47 100644 --- a/tests/test_sqlite.py +++ b/tests/test_sqlite.py @@ -4,12 +4,14 @@ from cbutil.postprocessing.sqlite_helper import (select_stmt, where_stmt, join_stmt, table_name_query, - query_builder) + query_builder, + ) from cbutil.postprocessing.sqlite import (get_all_table_names, tables2dict, query_join, - iterate_all_tables) + iterate_all_tables, + build_iterate_query) from cbutil.postprocessing.sqlite import sqlite_context @@ -63,6 +65,8 @@ def test_table_name(): def test_builder(): assert query_builder(from_table="table") == "SELECT * FROM table" assert query_builder(select="row", from_table="table") == "SELECT row FROM table" + query = query_builder(select="row", from_table="table", where_args=["id", "0", "<"]) + assert query == "SELECT row FROM table WHERE id<0" def test_all_table_names(): @@ -80,3 +84,10 @@ def test_join_query(): def test_iterate_query(): dicts = list(tables2dict(iterate_all_tables("tests/cpu_benchmark.sqlite3"))) assert len(dicts) == 150 + + +def test_build_iterate(): + with sqlite_context("tests/benchmark.sqlite") as connection: + for res in build_iterate_query(connection, from_table="timingPool", + where_args=["runId", 1]): + pass -- GitLab