diff --git a/conftest.py b/conftest.py index c27883d6f9ce8166692f8afd462da3a4ae0c0aa5..040ddf59505d29d1235a91279006ad34e6bcb8db 100644 --- a/conftest.py +++ b/conftest.py @@ -65,6 +65,7 @@ try: except ImportError: add_path_to_ignore('pystencils/runhelper') collect_ignore += [os.path.join(SCRIPT_FOLDER, "pystencils_tests/test_parameterstudy.py")] + collect_ignore += [os.path.join(SCRIPT_FOLDER, "pystencils_tests/test_json_serializer.py")] try: import islpy diff --git a/pystencils/runhelper/db.py b/pystencils/runhelper/db.py index 21b75c4ba91a53588db8637651bfda090c4044e6..1c8d3aa66c0e47681b942dd06f0a1805e244e930 100644 --- a/pystencils/runhelper/db.py +++ b/pystencils/runhelper/db.py @@ -1,10 +1,60 @@ import socket import time +from types import MappingProxyType from typing import Dict, Iterator, Sequence import blitzdb +import six +from blitzdb.backends.file.backend import serializer_classes +from blitzdb.backends.file.utils import JsonEncoder from pystencils.cpu.cpujit import get_compiler_config +from pystencils import CreateKernelConfig, Target, Backend, Field + +import json +import sympy as sp + +from pystencils.typing import BasicType + + +class PystencilsJsonEncoder(JsonEncoder): + + def default(self, obj): + if isinstance(obj, CreateKernelConfig): + return obj.__dict__ + if isinstance(obj, (sp.Float, sp.Rational)): + return float(obj) + if isinstance(obj, sp.Integer): + return int(obj) + if isinstance(obj, (BasicType, MappingProxyType)): + return str(obj) + if isinstance(obj, (Target, Backend, sp.Symbol)): + return obj.name + if isinstance(obj, Field): + return f"pystencils.Field(name = {obj.name}, field_type = {obj.field_type.name}, " \ + f"dtype = {str(obj.dtype)}, layout = {obj.layout}, shape = {obj.shape}, " \ + f"strides = {obj.strides})" + return JsonEncoder.default(self, obj) + + +class PystencilsJsonSerializer(object): + + @classmethod + def serialize(cls, data): + if six.PY3: + if isinstance(data, bytes): + return json.dumps(data.decode('utf-8'), cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8') + else: + return json.dumps(data, cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8') + else: + return json.dumps(data, cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8') + + @classmethod + def deserialize(cls, data): + if six.PY3: + return json.loads(data.decode('utf-8')) + else: + return json.loads(data.decode('utf-8')) class Database: @@ -46,7 +96,7 @@ class Database: class SimulationResult(blitzdb.Document): pass - def __init__(self, file: str) -> None: + def __init__(self, file: str, serializer_info: tuple = None) -> None: if file.startswith("mongo://"): from pymongo import MongoClient db_name = file[len("mongo://"):] @@ -57,6 +107,10 @@ class Database: self.backend.autocommit = True + if serializer_info: + serializer_classes.update({serializer_info[0]: serializer_info[1]}) + self.backend.load_config({'serializer_class': serializer_info[0]}, True) + def save(self, params: Dict, result: Dict, env: Dict = None, **kwargs) -> None: """Stores a simulation result in the database. @@ -146,10 +200,15 @@ class Database: 'cpuCompilerConfig': get_compiler_config(), } try: - from git import Repo, InvalidGitRepositoryError + from git import Repo + except ImportError: + return result + + try: + from git import InvalidGitRepositoryError repo = Repo(search_parent_directories=True) result['git_hash'] = str(repo.head.commit) - except (ImportError, InvalidGitRepositoryError): + except InvalidGitRepositoryError: pass return result diff --git a/pystencils/runhelper/parameterstudy.py b/pystencils/runhelper/parameterstudy.py index f4d8327d335125f7f57a8d622e1fd37855c3d9dd..243a30e437b05960a986d6134b3226c47f936948 100644 --- a/pystencils/runhelper/parameterstudy.py +++ b/pystencils/runhelper/parameterstudy.py @@ -9,6 +9,7 @@ from time import sleep from typing import Any, Callable, Dict, Optional, Sequence, Tuple from pystencils.runhelper import Database +from pystencils.runhelper.db import PystencilsJsonSerializer from pystencils.utils import DotDict ParameterDict = Dict[str, Any] @@ -54,10 +55,11 @@ class ParameterStudy: Run = namedtuple("Run", ['parameter_dict', 'weight']) def __init__(self, run_function: Callable[..., Dict], runs: Sequence = (), - database_connector: str = './db') -> None: + database_connector: str = './db', + serializer_info: tuple = ('pystencils_serializer', PystencilsJsonSerializer)) -> None: self.runs = list(runs) self.run_function = run_function - self.db = Database(database_connector) + self.db = Database(database_connector, serializer_info) def add_run(self, parameter_dict: ParameterDict, weight: int = 1) -> None: """Schedule a dictionary of parameters to run in this parameter study. diff --git a/pystencils_tests/test_json_serializer.py b/pystencils_tests/test_json_serializer.py new file mode 100644 index 0000000000000000000000000000000000000000..f4600753559ae1f2052aa115cf55a85dec5c381c --- /dev/null +++ b/pystencils_tests/test_json_serializer.py @@ -0,0 +1,28 @@ +""" +Test the pystencils-specific JSON encoder and serializer as used in the Database class. +""" + +import numpy as np +import tempfile + +from pystencils.config import CreateKernelConfig +from pystencils import Target, Field +from pystencils.runhelper.db import Database, PystencilsJsonSerializer + + +def test_json_serializer(): + + dtype = np.float32 + + index_arr = np.zeros((3,), dtype=dtype) + indexed_field = Field.create_from_numpy_array('index', index_arr) + + # create pystencils config + config = CreateKernelConfig(target=Target.CPU, function_name='dummy_config', data_type=dtype, + index_fields=[indexed_field]) + + # create dummy database + temp_dir = tempfile.TemporaryDirectory() + db = Database(file=temp_dir.name, serializer_info=('pystencils_serializer', PystencilsJsonSerializer)) + + db.save(params={'config': config}, result={'test': 'dummy'})