diff --git a/runhelper/db.py b/runhelper/db.py index 8db4d769aab118f7ddd428e434059f1b543def73..13faa23203a6f1053332b5ce522f56eeda75421d 100644 --- a/runhelper/db.py +++ b/runhelper/db.py @@ -1,18 +1,41 @@ import time import socket -from collections import OrderedDict - import blitzdb -import pandas as pd from pystencils.cpu.cpujit import getCompilerConfig def removeConstantColumns(df): - remainingDf = df.loc[:, (df != df.ix[0]).any()] - constants = df.loc[:, (df == df.ix[0]).all()].ix[0] + import pandas as pd + remainingDf = df.loc[:, df.apply(pd.Series.nunique) > 1] + constants = df.loc[:, df.apply(pd.Series.nunique) <= 1].iloc[0] return remainingDf, constants +def removeColumnsByPrefix(df, prefixes, inplace=False): + if not inplace: + df = df.copy() + + for columnName in df.columns: + for prefix in prefixes: + if columnName.startswith(prefix): + del df[columnName] + return df + + +def removePrefixInColumnName(df, inplace=False): + if not inplace: + df = df.copy() + + newColumnNames = [] + for columnName in df.columns: + if '.' in columnName: + newColumnNames.append(columnName[columnName.index('.') + 1:]) + else: + newColumnNames.append(columnName) + df.columns = newColumnNames + return df + + class Database(object): class SimulationResult(blitzdb.Document): pass @@ -43,21 +66,31 @@ class Database(object): def filter(self, *args, **kwargs): return self.backend.filter(Database.SimulationResult, *args, **kwargs) + def filterParams(self, query, *args, **kwargs): + query = {'params.' + k: v for k, v in query.items()} + return self.filter(query, *args, **kwargs) + def alreadySimulated(self, parameters): return len(self.filter({'params': parameters})) > 0 - def toPandas(self, query): - queryResult = self.backend.filter(self.SimulationResult, query) - records = [] - index = set() - for e in queryResult: - record = OrderedDict(e.params.items()) - record.update(e.result) - records.append(record) - index.update(e.params.keys()) + # Columns with these prefixes are not included in pandas result + pandasColumnsToIgnore = ['changedParams.', 'env.'] - df = pd.DataFrame.from_records(records) + def toPandas(self, parameterQuery, removePrefix=True, dropConstantColumns=False): + import pandas as pd - return df + queryResult = self.filterParams(parameterQuery) + if len(queryResult) == 0: + return + df = pd.io.json.json_normalize([e.attributes for e in queryResult]) + df.set_index('pk', inplace=True) + if self.pandasColumnsToIgnore: + removeColumnsByPrefix(df, self.pandasColumnsToIgnore, inplace=True) + if removePrefix: + removePrefixInColumnName(df, inplace=True) + if dropConstantColumns: + df, _ = removeConstantColumns(df) + + return df