diff --git a/runhelper/parameterstudy.py b/runhelper/parameterstudy.py index bdc1f49ec4a97cdf43d0761cd2eb7b76b868d0e5..5923d6030bc59f811cd3590007407afb6eca1989 100644 --- a/runhelper/parameterstudy.py +++ b/runhelper/parameterstudy.py @@ -2,6 +2,8 @@ import json import datetime import os import socket +import itertools +from copy import deepcopy from collections import namedtuple from time import sleep from pystencils.runhelper import Database @@ -11,6 +13,12 @@ class ParameterStudy(object): Run = namedtuple("Run", ['parameterDict', 'weight']) + class DotDict(dict): + """Normal dict with additional dot access for all keys""" + __getattr__ = dict.get + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + def __init__(self, runFunction, listOfRuns=[], databaseFile='./db'): self.listOfRuns = listOfRuns self.runFunction = runFunction @@ -19,6 +27,22 @@ class ParameterStudy(object): def addRun(self, parameterDict, weight=1): self.listOfRuns.append(self.Run(parameterDict, weight)) + def addCombinations(self, degreesOfFreedom, constantParameters=None, filterFunction=None, weightFunction=None): + parameterNames = [e[0] for e in degreesOfFreedom] + parameterValues = [e[1] for e in degreesOfFreedom] + + defaultParamsDict = {} if constantParameters is None else constantParameters + for valueTuple in itertools.product(*parameterValues): + paramsDict = deepcopy(defaultParamsDict) + paramsDict.update({name: value for name, value in zip(parameterNames, valueTuple)}) + params = self.DotDict(paramsDict) + if filterFunction: + params = filterFunction(params) + if params is None: + continue + weight = 1 if not weightFunction else weightFunction(params) + self.addRun(params, weight) + def filterAlreadySimulated(self, allRuns): return [r for r in allRuns if not self.db.alreadySimulated(r.parameterDict)] @@ -132,11 +156,23 @@ class ParameterStudy(object): print("Cannot connect to server {} retrying in 5 seconds...".format(url)) sleep(5) - def run(self, process, numProcesses): + def run(self, process, numProcesses, parameterUpdate={}): ownRuns = self.distributeRuns(self.listOfRuns, process, numProcesses) for run in ownRuns: - result = self.runFunction(**run.parameterDict) - self.db.save(run.parameterDict, result) + parameterDict = run.parameterDict.copy() + parameterDict.update(parameterUpdate) + result = self.runFunction(**parameterDict) + + self.db.save(run.parameterDict, result, None, changedParams=parameterUpdate) + + def runScenariosNotInDatabase(self, parameterUpdate={}): + filteredRuns = self.filterAlreadySimulated(self.listOfRuns) + for run in filteredRuns: + parameterDict = run.parameterDict.copy() + parameterDict.update(parameterUpdate) + result = self.runFunction(**parameterDict) + + self.db.save(run.parameterDict, result, None, changedParams=parameterUpdate) def runFromCommandLine(self, argv=None): from argparse import ArgumentParser @@ -147,11 +183,24 @@ class ParameterStudy(object): self.runServer(a.host, a.port) def client(a): - print(a.parameterOverride) self.runClient(a.clientName, a.host, a.port, json.loads(a.parameterOverride)) + def local(a): + if a.database: + self.db = Database(a.database) + self.runScenariosNotInDatabase(json.loads(a.parameterOverride)) + parser = ArgumentParser() subparsers = parser.add_subparsers() + + localParser = subparsers.add_parser('local', aliases=['l'], + help="Run scenarios locally which are not yet in database",) + localParser.add_argument("-d", "--database", type=str, default="") + localParser.add_argument("-P", "--parameterOverride", type=str, default="{}", + help="JSON: the parameter dictionary is updated with these parameters. Use this to " + "set host specific options like GPU call parameters. Enclose in \" ") + localParser.set_defaults(func=local) + serverParser = subparsers.add_parser('server', aliases=['serv', 's'], help="Runs server to distribute different scenarios to workers",) serverParser.add_argument("-p", "--port", type=int, default=8342, help="Port to listen on")