Skip to content
Snippets Groups Projects
Commit 9761a600 authored by Martin Bauer's avatar Martin Bauer
Browse files

Updated square channel scenario & parameter study

parent 90c9bb3c
Branches
Tags
No related merge requests found
...@@ -2,6 +2,8 @@ import json ...@@ -2,6 +2,8 @@ import json
import datetime import datetime
import os import os
import socket import socket
import itertools
from copy import deepcopy
from collections import namedtuple from collections import namedtuple
from time import sleep from time import sleep
from pystencils.runhelper import Database from pystencils.runhelper import Database
...@@ -11,6 +13,12 @@ class ParameterStudy(object): ...@@ -11,6 +13,12 @@ class ParameterStudy(object):
Run = namedtuple("Run", ['parameterDict', 'weight']) Run = namedtuple("Run", ['parameterDict', 'weight'])
class DotDict(dict):
"""Normal dict with additional dot access for all keys"""
__getattr__ = dict.get
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
def __init__(self, runFunction, listOfRuns=[], databaseFile='./db'): def __init__(self, runFunction, listOfRuns=[], databaseFile='./db'):
self.listOfRuns = listOfRuns self.listOfRuns = listOfRuns
self.runFunction = runFunction self.runFunction = runFunction
...@@ -19,6 +27,22 @@ class ParameterStudy(object): ...@@ -19,6 +27,22 @@ class ParameterStudy(object):
def addRun(self, parameterDict, weight=1): def addRun(self, parameterDict, weight=1):
self.listOfRuns.append(self.Run(parameterDict, weight)) self.listOfRuns.append(self.Run(parameterDict, weight))
def addCombinations(self, degreesOfFreedom, constantParameters=None, filterFunction=None, weightFunction=None):
parameterNames = [e[0] for e in degreesOfFreedom]
parameterValues = [e[1] for e in degreesOfFreedom]
defaultParamsDict = {} if constantParameters is None else constantParameters
for valueTuple in itertools.product(*parameterValues):
paramsDict = deepcopy(defaultParamsDict)
paramsDict.update({name: value for name, value in zip(parameterNames, valueTuple)})
params = self.DotDict(paramsDict)
if filterFunction:
params = filterFunction(params)
if params is None:
continue
weight = 1 if not weightFunction else weightFunction(params)
self.addRun(params, weight)
def filterAlreadySimulated(self, allRuns): def filterAlreadySimulated(self, allRuns):
return [r for r in allRuns if not self.db.alreadySimulated(r.parameterDict)] return [r for r in allRuns if not self.db.alreadySimulated(r.parameterDict)]
...@@ -132,11 +156,23 @@ class ParameterStudy(object): ...@@ -132,11 +156,23 @@ class ParameterStudy(object):
print("Cannot connect to server {} retrying in 5 seconds...".format(url)) print("Cannot connect to server {} retrying in 5 seconds...".format(url))
sleep(5) sleep(5)
def run(self, process, numProcesses): def run(self, process, numProcesses, parameterUpdate={}):
ownRuns = self.distributeRuns(self.listOfRuns, process, numProcesses) ownRuns = self.distributeRuns(self.listOfRuns, process, numProcesses)
for run in ownRuns: for run in ownRuns:
result = self.runFunction(**run.parameterDict) parameterDict = run.parameterDict.copy()
self.db.save(run.parameterDict, result) parameterDict.update(parameterUpdate)
result = self.runFunction(**parameterDict)
self.db.save(run.parameterDict, result, None, changedParams=parameterUpdate)
def runScenariosNotInDatabase(self, parameterUpdate={}):
filteredRuns = self.filterAlreadySimulated(self.listOfRuns)
for run in filteredRuns:
parameterDict = run.parameterDict.copy()
parameterDict.update(parameterUpdate)
result = self.runFunction(**parameterDict)
self.db.save(run.parameterDict, result, None, changedParams=parameterUpdate)
def runFromCommandLine(self, argv=None): def runFromCommandLine(self, argv=None):
from argparse import ArgumentParser from argparse import ArgumentParser
...@@ -147,11 +183,24 @@ class ParameterStudy(object): ...@@ -147,11 +183,24 @@ class ParameterStudy(object):
self.runServer(a.host, a.port) self.runServer(a.host, a.port)
def client(a): def client(a):
print(a.parameterOverride)
self.runClient(a.clientName, a.host, a.port, json.loads(a.parameterOverride)) self.runClient(a.clientName, a.host, a.port, json.loads(a.parameterOverride))
def local(a):
if a.database:
self.db = Database(a.database)
self.runScenariosNotInDatabase(json.loads(a.parameterOverride))
parser = ArgumentParser() parser = ArgumentParser()
subparsers = parser.add_subparsers() subparsers = parser.add_subparsers()
localParser = subparsers.add_parser('local', aliases=['l'],
help="Run scenarios locally which are not yet in database",)
localParser.add_argument("-d", "--database", type=str, default="")
localParser.add_argument("-P", "--parameterOverride", type=str, default="{}",
help="JSON: the parameter dictionary is updated with these parameters. Use this to "
"set host specific options like GPU call parameters. Enclose in \" ")
localParser.set_defaults(func=local)
serverParser = subparsers.add_parser('server', aliases=['serv', 's'], serverParser = subparsers.add_parser('server', aliases=['serv', 's'],
help="Runs server to distribute different scenarios to workers",) help="Runs server to distribute different scenarios to workers",)
serverParser.add_argument("-p", "--port", type=int, default=8342, help="Port to listen on") serverParser.add_argument("-p", "--port", type=int, default=8342, help="Port to listen on")
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment