diff --git a/plot2d.py b/plot2d.py index 58a37d68a14770fe67a4a6a72af86789476642fd..43ba8af8b95316f0eacece43cc9a1ae238d0be30 100644 --- a/plot2d.py +++ b/plot2d.py @@ -44,6 +44,30 @@ def scalarField(field, **kwargs): return res +def scalarFieldAlphaValue(field, color, clip=False, **kwargs): + import numpy as np + import matplotlib + field = np.swapaxes(field, 0, 1) + color = matplotlib.colors.to_rgba(color) + + fieldToPlot = np.empty(field.shape + (4,)) + for i in range(3): + fieldToPlot[:, :, i] = color[i] + + if clip: + normalizedField = field.copy() + normalizedField[normalizedField<0] = 0 + normalizedField[normalizedField>1] = 1 + else: + min, max = np.min(field), np.max(field) + normalizedField = (field - min) / (max - min) + fieldToPlot[:, :, 3] = normalizedField + + res = imshow(fieldToPlot, origin='lower', **kwargs) + axis('equal') + return res + + def scalarFieldContour(field, **kwargs): field = np.swapaxes(field, 0, 1) res = contour(field, **kwargs)