From 9f0b36ea830e9c4e5afc61f3203b2afea1f20262 Mon Sep 17 00:00:00 2001
From: Christoph Alt <christoph.alt@fau.de>
Date: Mon, 1 Aug 2022 15:03:34 +0200
Subject: [PATCH] added function to build and iterate query

---
 cbutil/postprocessing/__init__.py      |  2 +-
 cbutil/postprocessing/sqlite.py        | 10 ++++++++++
 cbutil/postprocessing/sqlite_helper.py |  2 +-
 tests/test_sqlite.py                   | 15 +++++++++++++--
 4 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/cbutil/postprocessing/__init__.py b/cbutil/postprocessing/__init__.py
index 3af9ad9..5aa1264 100644
--- a/cbutil/postprocessing/__init__.py
+++ b/cbutil/postprocessing/__init__.py
@@ -1,3 +1,3 @@
 from .plain_text import process_linewise
-from .sqlite import sqlite_context, query_complete_table
+from .sqlite import sqlite_context, query_complete_table, build_iterate_query, iterate_all_tables
 from .sqlite_helper import query_builder
diff --git a/cbutil/postprocessing/sqlite.py b/cbutil/postprocessing/sqlite.py
index 228c72a..b4c3ccb 100644
--- a/cbutil/postprocessing/sqlite.py
+++ b/cbutil/postprocessing/sqlite.py
@@ -48,3 +48,13 @@ def iterate_all_tables(path):
 def iterate_join(path, lhs, rhs, key):
     with sqlite_context(path) as connection:
         yield from tables2dict(query_join(connection, lhs, rhs, key))
+
+
+def iterate_query(connection, query):
+    for result in connection.execute(query).fetchall():
+        yield result
+
+
+def build_iterate_query(connection, *args, **kwargs):
+    query = sh.query_builder(*args, **kwargs)
+    yield from iterate_query(connection, query)
diff --git a/cbutil/postprocessing/sqlite_helper.py b/cbutil/postprocessing/sqlite_helper.py
index df9cbf2..5dbee65 100644
--- a/cbutil/postprocessing/sqlite_helper.py
+++ b/cbutil/postprocessing/sqlite_helper.py
@@ -1,7 +1,7 @@
 
 def strip_n_check(arg):
     if not isinstance(arg, str):
-        return str(arg) 
+        return str(arg)
     stripped = arg.strip()
     if not stripped:
         raise ValueError("Empty arg")
diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py
index 01dd08e..da50b47 100644
--- a/tests/test_sqlite.py
+++ b/tests/test_sqlite.py
@@ -4,12 +4,14 @@ from cbutil.postprocessing.sqlite_helper import (select_stmt,
                                                  where_stmt,
                                                  join_stmt,
                                                  table_name_query,
-                                                 query_builder)
+                                                 query_builder,
+                                                 )
 
 from cbutil.postprocessing.sqlite import (get_all_table_names,
                                           tables2dict,
                                           query_join,
-                                          iterate_all_tables)
+                                          iterate_all_tables,
+                                          build_iterate_query)
 
 from cbutil.postprocessing.sqlite import sqlite_context
 
@@ -63,6 +65,8 @@ def test_table_name():
 def test_builder():
     assert query_builder(from_table="table") == "SELECT * FROM table"
     assert query_builder(select="row", from_table="table") == "SELECT row FROM table"
+    query = query_builder(select="row", from_table="table", where_args=["id", "0", "<"])
+    assert query == "SELECT row FROM table WHERE id<0"
 
 
 def test_all_table_names():
@@ -80,3 +84,10 @@ def test_join_query():
 def test_iterate_query():
     dicts = list(tables2dict(iterate_all_tables("tests/cpu_benchmark.sqlite3")))
     assert len(dicts) == 150
+
+
+def test_build_iterate():
+    with sqlite_context("tests/benchmark.sqlite") as connection:
+        for res in build_iterate_query(connection, from_table="timingPool",
+                                       where_args=["runId", 1]):
+            pass
-- 
GitLab