db.py 7.57 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
2
import time
import socket
3
from typing import Dict, Sequence, Iterator
Martin Bauer's avatar
Martin Bauer committed
4
import blitzdb
Martin Bauer's avatar
Martin Bauer committed
5
from pystencils.cpu.cpujit import get_compiler_config
Martin Bauer's avatar
Martin Bauer committed
6
7


8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class Database:
    """NoSQL database for storing simulation results.

    Two backends are supported:
        * `blitzdb`: simple file-based solution similar to sqlite for SQL databases, stores json files
                     no server setup required, but slow for larger collections
        * `mongodb`: mongodb backend via `pymongo`

    A simulation result is stored as an object consisting of
        * params: dict with simulation parameters
        * result: dict with results
        * env: information about the machine, compiler configuration and time

    Args:
        file: database identifier. For blitzdb pass a directory name here; the database folder is created if it
              doesn't exist yet. For larger collections use mongodb: pass "mongo://<db_name>" to use a database
              reachable through a default ``MongoClient()`` connection, or "mongo://<host>:<port>/<db_name>" to
              connect to a specific server, e.g. "mongo://server:9131/simulations"

    Example:
        >>> from tempfile import TemporaryDirectory
        >>> with TemporaryDirectory() as tmp_dir:
        ...     db = Database(tmp_dir)  # create database in temporary folder
        ...     params = {'method': 'finite_diff', 'dx': 1.5}  # some hypothetical simulation parameters
        ...     db.save(params, result={'error': 1e-6})  # save simulation parameters together with hypothetical results
        ...     assert db.was_already_simulated(params)  # search for parameters in database
        ...     assert next(db.filter_params(params))['params'] == params # get data set, keys are 'params', 'result'
        ...                                                               # and 'env'
        ...     # get a pandas object with all results matching a query
        ...     db.to_pandas({'dx': 1.5}, remove_prefix=True)  # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
                                 dx       method     error
        pk
        ...                 1.5  finite_diff  0.000001
    """

    class SimulationResult(blitzdb.Document):
        pass

    def __init__(self, file: str) -> None:
        if file.startswith("mongo://"):
            from pymongo import MongoClient
            # Accepted forms: "mongo://<db_name>" and "mongo://<host>[:<port>]/<db_name>".
            # MongoDB database names cannot contain '/', so the text after the last '/' is the database name and
            # anything before it addresses the server. Without a '/' this behaves exactly like before.
            server, _, db_name = file[len("mongo://"):].rpartition('/')
            client = MongoClient(server) if server else MongoClient()
            self.backend = blitzdb.MongoBackend(client[db_name])
        else:
            self.backend = blitzdb.FileBackend(file)

        self.backend.autocommit = True

    def save(self, params: Dict, result: Dict, env: Dict = None, **kwargs) -> None:
        """Stores a simulation result in the database.

        Args:
            params: dict of simulation parameters
            result: dict of simulation results
            env: optional environment - if None a default environment with compiler configuration, machine info and time
                 is used
            **kwargs: the final object is updated with the keyword arguments
        """
        document_dict = {
            'params': params,
            'result': result,
            # Substitute the default only when env is really missing; an explicitly passed (even empty) dict is kept.
            'env': env if env is not None else self.get_environment(),
        }
        document_dict.update(kwargs)
        document = Database.SimulationResult(document_dict, backend=self.backend)
        document.save()
        self.backend.commit()

    def filter_params(self, parameter_query: Dict, *args, **kwargs) -> Iterator['SimulationResult']:
        """Query using simulation parameters.

        See blitzdb documentation for filter

        Args:
            parameter_query: blitzdb filter dict using only simulation parameters
            *args: arguments passed to blitzdb filter
            **kwargs: arguments passed to blitzdb filter

        Returns:
            generator of SimulationResult, which is a dict-like object with keys 'params', 'result' and 'env'
        """
        # every key refers to a field inside the stored 'params' sub-document
        query = {'params.' + k: v for k, v in parameter_query.items()}
        return self.filter(query, *args, **kwargs)

    def filter(self, *args, **kwargs):
        """blitzdb filter on SimulationResult, not only simulation parameters.

        Can be used to filter for results or environment options.
        The filter dictionary has to have prefixes "params." , "env." or "result."
        """
        return self.backend.filter(Database.SimulationResult, *args, **kwargs)

    def was_already_simulated(self, parameters):
        """Checks if there is at least one simulation result matching the passed parameters."""
        return len(self.filter({'params': parameters})) > 0

    # Columns with these prefixes are not included in pandas result
    pandas_columns_to_ignore = ['changedParams.', 'env.']

    def to_pandas(self, parameter_query, remove_prefix=True, drop_constant_columns=False):
        """Queries for simulations with given parameters and returns them in a pandas data frame.

        Args:
            parameter_query: see filter method
            remove_prefix: if True the name of the pandas columns are not prefixed with "params." or "result."
            drop_constant_columns: if True, all columns are dropped that have the same value in all rows

        Returns:
            pandas data frame indexed by the 'pk' field of the stored documents,
            or None if no simulation matches the query
        """
        try:
            from pandas import json_normalize  # pandas >= 1.0
        except ImportError:
            # older pandas only offers the pandas.io.json location (deprecated since pandas 1.0)
            from pandas.io.json import json_normalize

        query_result = self.filter_params(parameter_query)
        attributes = [e.attributes for e in query_result]
        if not attributes:
            return None
        df = json_normalize(attributes)
        df.set_index('pk', inplace=True)

        if self.pandas_columns_to_ignore:
            remove_columns_by_prefix(df, self.pandas_columns_to_ignore, inplace=True)
        if remove_prefix:
            remove_prefix_in_column_name(df, inplace=True)
        if drop_constant_columns:
            df, _ = remove_constant_columns(df)

        return df

    @staticmethod
    def get_environment():
        """Returns a dict describing the current environment.

        Contains the time ('timestamp', seconds since the epoch), the host name, the pystencils CPU compiler
        configuration and - if GitPython is installed and a repository is found - the current git commit hash.
        """
        result = {
            # time.mktime(time.gmtime()) would interpret the UTC struct as *local* time and be off by the local
            # UTC offset; time.time() is the actual epoch timestamp
            'timestamp': time.time(),
            'hostname': socket.gethostname(),
            'cpuCompilerConfig': get_compiler_config(),
        }

        # The import needs its own try: if it fails, InvalidGitRepositoryError is unbound and naming it in the
        # same except clause would raise NameError instead of being ignored.
        try:
            from git import Repo, InvalidGitRepositoryError
        except ImportError:  # GitPython unavailable - git info is optional
            return result

        try:
            repo = Repo(search_parent_directories=True)
            result['git_hash'] = str(repo.head.commit)
        except (InvalidGitRepositoryError, ValueError):
            # no enclosing git repository, or HEAD does not point to a commit yet (empty repository)
            pass

        return result

# ----------------------------------------- Helper Functions -----------------------------------------------------------


def remove_constant_columns(df):
    """Removes all columns of a pandas data frame that have the same value in all rows.

    Missing values (NaN) are ignored when counting distinct values, as in ``DataFrame.nunique``.

    Args:
        df: pandas data frame; it is not modified

    Returns:
        tuple ``(remaining_df, constants)``: the data frame without the constant columns, and a Series mapping
        each removed column name to its value (taken from the first row). For a data frame without rows every
        column counts as constant and the values in ``constants`` are NaN.
    """
    import pandas as pd
    # count distinct values only once and reuse the (positional) mask for both selections
    is_constant = (df.nunique() <= 1).to_numpy()
    remaining_df = df.loc[:, ~is_constant]
    if len(df.index) == 0:
        # there is no row to read the constant values from - iloc[0] would raise IndexError
        constants = pd.Series(index=df.columns[is_constant], dtype=object)
    else:
        constants = df.loc[:, is_constant].iloc[0]
    return remaining_df, constants


def remove_columns_by_prefix(df, prefixes: Sequence[str], inplace: bool = False):
    """Remove all columns from a pandas data frame whose name starts with one of the given prefixes.

    Args:
        df: pandas data frame
        prefixes: column name prefixes to remove; a single string is treated as one prefix
        inplace: if True ``df`` itself is modified, otherwise a modified copy is returned

    Returns:
        the data frame without the matching columns (``df`` itself if ``inplace`` is True)
    """
    if not inplace:
        df = df.copy()

    # a bare string would otherwise be iterated character by character
    prefix_tuple = (prefixes,) if isinstance(prefixes, str) else tuple(prefixes)

    # Collect first, delete afterwards: each matching column is deleted exactly once, even if it matches several
    # prefixes (deleting it again raised KeyError). Non-string labels cannot match a string prefix.
    columns_to_remove = [name for name in dict.fromkeys(df.columns)
                         if isinstance(name, str) and name.startswith(prefix_tuple)]
    for column_name in columns_to_remove:
        del df[column_name]
    return df


def remove_prefix_in_column_name(df, inplace: bool = False):
    """Strips the first dotted component from every column name of a pandas data frame.

    Example: 'result.finite_diff.dx' becomes 'finite_diff.dx' - only the part up to and including the first dot
    is dropped. Names without any dot are kept unchanged.

    Args:
        df: pandas data frame
        inplace: if True the columns of ``df`` itself are renamed, otherwise a renamed copy is returned

    Returns:
        the renamed data frame (``df`` itself if ``inplace`` is True)
    """
    target = df if inplace else df.copy()

    def strip_first_component(name):
        # split at most once so deeper dots are preserved
        return name.split('.', 1)[1] if '.' in name else name

    target.columns = [strip_first_component(name) for name in target.columns]
    return target