diff --git a/cbutil/postprocessing/__init__.py b/cbutil/postprocessing/__init__.py index 3af9ad9259bca9c1ad865e754795b9d46584c8c8..5aa1264f49329d6e92eb50bc8de62f82b8271011 100644 --- a/cbutil/postprocessing/__init__.py +++ b/cbutil/postprocessing/__init__.py @@ -1,3 +1,3 @@ from .plain_text import process_linewise -from .sqlite import sqlite_context, query_complete_table +from .sqlite import sqlite_context, query_complete_table, build_iterate_query, iterate_all_tables from .sqlite_helper import query_builder diff --git a/cbutil/postprocessing/sqlite.py b/cbutil/postprocessing/sqlite.py index 228c72a16ecbf0f7695686931435e1bca8aabf17..b4c3ccb887aac7aafd11ada7d609690a3797569e 100644 --- a/cbutil/postprocessing/sqlite.py +++ b/cbutil/postprocessing/sqlite.py @@ -48,3 +48,13 @@ def iterate_all_tables(path): def iterate_join(path, lhs, rhs, key): with sqlite_context(path) as connection: yield from tables2dict(query_join(connection, lhs, rhs, key)) + + +def iterate_query(connection, query): + for result in connection.execute(query).fetchall(): + yield result + + +def build_iterate_query(connection, *args, **kwargs): + query = sh.query_builder(*args, **kwargs) + yield from iterate_query(connection, query) diff --git a/cbutil/postprocessing/sqlite_helper.py b/cbutil/postprocessing/sqlite_helper.py index df9cbf2d557aad85fafdee2b24e6320924c11187..5dbee65506d28e57ebdecff70d38333637b31495 100644 --- a/cbutil/postprocessing/sqlite_helper.py +++ b/cbutil/postprocessing/sqlite_helper.py @@ -1,7 +1,7 @@ def strip_n_check(arg): if not isinstance(arg, str): - return str(arg) + return str(arg) stripped = arg.strip() if not stripped: raise ValueError("Empty arg") diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py index 01dd08ec68ea76cb744ecad39b75bf2212938bcf..da50b47b290ff76beaf038b7d0c2bd84b23f23dd 100644 --- a/tests/test_sqlite.py +++ b/tests/test_sqlite.py @@ -4,12 +4,14 @@ from cbutil.postprocessing.sqlite_helper import (select_stmt, where_stmt, join_stmt, table_name_query, - query_builder) + query_builder, + ) from cbutil.postprocessing.sqlite import (get_all_table_names, tables2dict, query_join, - iterate_all_tables) + iterate_all_tables, + build_iterate_query) from cbutil.postprocessing.sqlite import sqlite_context @@ -63,6 +65,8 @@ def test_table_name(): def test_builder(): assert query_builder(from_table="table") == "SELECT * FROM table" assert query_builder(select="row", from_table="table") == "SELECT row FROM table" + query = query_builder(select="row", from_table="table", where_args=["id", "0", "<"]) + assert query == "SELECT row FROM table WHERE id<0" def test_all_table_names(): @@ -80,3 +84,10 @@ def test_join_query(): def test_iterate_query(): dicts = list(tables2dict(iterate_all_tables("tests/cpu_benchmark.sqlite3"))) assert len(dicts) == 150 + + +def test_build_iterate(): + with sqlite_context("tests/benchmark.sqlite") as connection: + for res in build_iterate_query(connection, from_table="timingPool", + where_args=["runId", 1]): + pass