"""
Module plotshop.plot_style
----------------------------
Helper functions to make the most awesome* plots out there.
* please feel free to add more stuff
"""
import colorsys
import re
from itertools import cycle
import matplotlib
import matplotlib.colors as mc
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from distutils.version import LooseVersion
import pandas as pd
import numpy as np
from tfs_files import tfs_pandas as tfs
[docs]class ArgumentError(Exception):
pass
# commented values are part of matplotlib 2.1 but not 1.5
_PRESENTATION_PARAMS = {
# u'axes.autolimit_mode': u'data',
u'backend': u'pdf',
u'axes.edgecolor': u'k',
u'axes.facecolor': u'w',
u'axes.grid': True,
u'axes.grid.axis': u'both',
u'axes.grid.which': u'major',
u'axes.labelcolor': u'k',
# u'axes.labelpad': 4.0,
u'axes.labelsize': 22,
u'axes.labelweight': u'normal',
u'axes.linewidth': 1.8,
# u'axes.titlepad': 16.0,
u'axes.titlesize': u'xx-large',
u'axes.titleweight': u'bold',
u'figure.edgecolor': u'w',
u'figure.facecolor': u'w',
u'figure.figsize': [10.24, 7.68],
u'figure.frameon': True,
u'figure.titlesize': u'xx-large',
u'figure.titleweight': u'normal',
u'font.size': 20.0,
u'font.stretch': u'normal',
u'font.weight': u'normal',
u'font.family': 'sans-serif',
u'font.serif': ['Computer Modern Roman'],
u'grid.alpha': .6,
u'grid.color': u'#b0b0b0',
u'grid.linestyle': u'--',
u'grid.linewidth': 1,
u'legend.edgecolor': u'0.8',
u'legend.facecolor': u'inherit',
u'legend.fancybox': True,
u'legend.fontsize': 20.0,
u'legend.framealpha': 0.9,
u'legend.frameon': False,
u'legend.handleheight': 0.7,
u'legend.handlelength': 2.0,
u'legend.handletextpad': 0.8,
u'legend.labelspacing': 0.5,
u'legend.loc': u'best',
u'legend.markerscale': 1.2,
u'legend.numpoints': 1,
u'legend.scatterpoints': 1,
u'legend.shadow': False,
u'lines.antialiased': True,
# u'lines.color': u'C0',
u'lines.linestyle': u'-',
u'lines.linewidth': 2,
u'lines.marker': u'o',
u'lines.markeredgewidth': 2,
u'lines.markersize': 14.0,
u'lines.solid_capstyle': u'projecting',
u'lines.solid_joinstyle': u'round',
u'markers.fillstyle': u'none',
u'text.antialiased': True,
u'text.color': u'k',
# u'xtick.alignment': u'center',
# u'xtick.bottom': True,
# u'xtick.color': u'k',
# u'xtick.direction': u'out',
# u'xtick.labelsize': u'medium',
# u'xtick.major.bottom': True,
# u'xtick.major.pad': 3.5,
# u'xtick.major.size': 3.5,
# u'xtick.major.top': True,
# u'xtick.major.width': 1.2,
# u'xtick.minor.bottom': True,
# u'xtick.minor.pad': 3.4,
# u'xtick.minor.size': 2.0,
# u'xtick.minor.top': True,
# u'xtick.minor.visible': False,
# u'xtick.minor.width': 1,
# u'xtick.top': False,
# u'ytick.alignment': u'center_baseline',
# u'ytick.color': u'k',
# u'ytick.direction': u'out',
# u'ytick.labelsize': u'medium',
# u'ytick.left': True,
# u'ytick.major.left': True,
# u'ytick.major.pad': 3.5,
# u'ytick.major.right': True,
# u'ytick.major.size': 3.5,
# u'ytick.major.width': 1.2,
# u'ytick.minor.left': True,
# u'ytick.minor.pad': 3.4,
# u'ytick.minor.right': True,
# u'ytick.minor.size': 2.0,
# u'ytick.minor.visible': False,
# u'ytick.minor.width': 1,
# # u'ytick.right': False
}
_STANDARD_PARAMS = {
# u'axes.autolimit_mode': u'data',
u'axes.edgecolor': u'k',
u'axes.facecolor': u'w',
u'axes.grid': True,
u'axes.grid.axis': u'both',
u'axes.grid.which': u'major',
u'axes.labelcolor': u'k',
# u'axes.labelpad': 4.0,
u'axes.labelsize': u'medium',
u'axes.labelweight': u'normal',
u'axes.linewidth': 1.5,
# u'axes.titlepad': 6.0,
u'axes.titlesize': u'x-large',
u'axes.titleweight': u'bold',
u'figure.edgecolor': u'w',
u'figure.facecolor': u'w',
u'figure.figsize': [10.24, 7.68],
u'figure.frameon': True,
u'figure.titlesize': u'large',
u'figure.titleweight': u'normal',
u'font.size': 15.0,
u'font.stretch': u'normal',
u'font.weight': u'normal',
u'font.family': 'sans-serif',
u'font.serif': ['Computer Modern Roman'],
u'font.sans-serif': ['Computer Modern Sans serif'],
u'grid.alpha': .6,
u'grid.color': u'#b0b0b0',
u'grid.linestyle': u'--',
u'grid.linewidth': 0.6,
u'legend.edgecolor': u'0.8',
u'legend.facecolor': u'inherit',
u'legend.fancybox': True,
u'legend.fontsize': 16.,
u'legend.framealpha': 0.8,
u'legend.frameon': False,
u'legend.handleheight': 0.7,
u'legend.handlelength': 2.0,
u'legend.handletextpad': 0.8,
u'legend.labelspacing': 0.5,
u'legend.loc': u'best',
u'legend.markerscale': .8,
u'legend.numpoints': 1,
u'legend.scatterpoints': 1,
u'legend.shadow': False,
u'lines.antialiased': True,
# u'lines.color': u'C0',
u'lines.linestyle': u'-',
u'lines.linewidth': 1.5,
u'lines.marker': u'o',
u'lines.markeredgewidth': 1.0,
u'lines.markersize': 8.0,
u'lines.solid_capstyle': u'projecting',
u'lines.solid_joinstyle': u'round',
u'markers.fillstyle': u'none',
u'text.antialiased': True,
u'text.color': u'k',
# u'xtick.alignment': u'center',
# u'xtick.bottom': True,
# u'xtick.color': u'k',
# u'xtick.direction': u'out',
# u'xtick.labelsize': u'medium',
# u'xtick.major.bottom': True,
# u'xtick.major.pad': 3.5,
# u'xtick.major.size': 3.5,
# u'xtick.major.top': True,
# u'xtick.major.width': 0.8,
# u'xtick.minor.bottom': True,
# u'xtick.minor.pad': 3.4,
# u'xtick.minor.size': 2.0,
# u'xtick.minor.top': True,
# u'xtick.minor.visible': False,
# u'xtick.minor.width': 0.6,
# u'xtick.top': False,
# u'ytick.alignment': u'center_baseline',
# u'ytick.color': u'k',
# u'ytick.direction': u'out',
# u'ytick.labelsize': u'medium',
# u'ytick.left': True,
# u'ytick.major.left': True,
# u'ytick.major.pad': 3.5,
# u'ytick.major.right': True,
# u'ytick.major.size': 3.5,
# u'ytick.major.width': 0.8,
# u'ytick.minor.left': True,
# u'ytick.minor.pad': 3.4,
# u'ytick.minor.right': True,
# u'ytick.minor.size': 2.0,
# u'ytick.minor.visible': False,
# u'ytick.minor.width': 0.6,
# u'ytick.right': False
}
# Style ######################################################################
[docs]def set_style(style='standard', manual=None):
"""Sets the style for all following plots.
Args:
style: Choose Style, either 'standard' or 'presentation'
manual: Dict of manual parameters to update. Convention: "REMOVE_ENTRY" removes entry
"""
if style == 'standard':
params = _STANDARD_PARAMS.copy()
elif style == 'presentation':
params = _PRESENTATION_PARAMS.copy()
else:
raise ArgumentError("Style '" + style + "' not found.")
if manual:
for key in manual.keys():
if manual[key] == "REMOVE_ENTRY":
params.pop(key)
else:
params[key] = manual[key]
matplotlib.rcParams.update(params)
# Tools ######################################################################
[docs]def sync2d(axs, ax_str='xy', ax_lim=()):
"""
Synchronizes the limits for the given axes
Args:
axs: list of axes or figures, or figure with multiple axes
ax_lim: predefined limits (list or list of lists)
ax_str: string 'x','y' or 'xy' defining the axes to sync
"""
if isinstance(axs, (list, np.ndarray)):
if isinstance(axs[0], matplotlib.figure.Figure):
# axs is list of figures: get all axes and call sync2D
sync2d([ax for fig in axs for ax in fig.axes], ax_str)
elif isinstance(axs[0], matplotlib.axes.Axes) and len(axs) > 1:
# axs is list of axes
if 'x' in ax_str:
if len(ax_lim) == 0:
# find x limits
x_min = min([ax.get_xlim()[0] for ax in axs])
x_max = max([ax.get_xlim()[1] for ax in axs])
x_lim = [x_min, x_max]
else:
# use defined limits
if len(ax_lim[0]) == 1:
x_lim = ax_lim
else:
x_lim = ax_lim[0]
for ax in axs:
ax.set_xlim(x_lim)
if 'y' in ax_str:
if len(ax_lim) == 0:
# find x limits
y_min = min([ax.get_ylim()[0] for ax in axs])
y_max = max([ax.get_ylim()[1] for ax in axs])
y_lim = [y_min, y_max]
else:
# use defined limits
if len(ax_lim[0]) == 1:
y_lim = ax_lim
else:
y_lim = ax_lim[1]
for ax in axs:
ax.set_ylim(y_lim)
elif isinstance(axs, matplotlib.figure.Figure):
# axs is one figure: call sync2D with axes from figure
sync2d(axs.axes, ax_str)
else:
raise TypeError(__file__[:-3] + '.sync2d input is of unknown type (' + str(type(axs)) + ')')
[docs]def set_xLimits(accel, ax=None):
"""
Sets the x-limits to the regularly used ones
Args:
accel: Name of the Accelerator
ax: Axes to put the label on (default: gca())
"""
if not ax:
ax = plt.gca()
if accel.startswith("LHCB"):
ax.set_xlim(-200, 27000)
ax.xaxis.set_minor_locator(mtick.MultipleLocator(base=1000.0))
ax.xaxis.set_major_locator(mtick.MultipleLocator(base=5000.0))
elif accel.startswith("ESRF"):
ax.set_xlim(-5, 850)
ax.xaxis.set_minor_locator(mtick.MultipleLocator(base=50.0))
ax.xaxis.set_major_locator(mtick.MultipleLocator(base=100.0))
else:
raise ArgumentError("Accelerator '" + accel + "' unknown.")
[docs]class MarkerList(object):
""" Create a list of predefined markers """
# markers = ["s", "o", ">", "D", "v", "*", "h", "^", "p", "X", "<", "P"] # matplotlib 2.++
markers = ["s", "o", ">", "D", "v", "*", "h", "^", "p", "<"]
def __init__(self):
self.idx = 0
[docs] @classmethod
def get_marker(cls, marker_num):
""" Return marker of index marker_num
Args:
marker_num (int): return maker at this position in list (mod len(list))
"""
return cls.markers[marker_num % len(cls.markers)]
[docs] def get_next_marker(self):
""" Return the next marker in the list (circularly wrapped) """
marker = self.get_marker(self.idx)
self.idx += 1
return marker
# Colors #####################################################################
def get_mpl_color(idx=None):
c = [
'#1f77b4', # muted blue
'#ff7f0e', # safety orange
'#2ca02c', # cooked asparagus green
'#d62728', # brick red
'#9467bd', # muted purple
'#8c564b', # chestnut brown
'#e377c2', # raspberry yogurt pink
'#7f7f7f', # middle gray
'#bcbd22', # curry yellow-green
'#17becf', # blue-teal
]
if idx is None:
return cycle(c)
return c[idx % len(c)]
def rgb_plotly_to_mpl(rgb_string):
if rgb_string.startswith('#'):
return rgb_string
rgb_string = rgb_string.replace("rgba", "").replace("rgb", "")
rgb = eval(rgb_string)
rgb_norm = [c/255. for c in rgb]
return rgb_norm
[docs]def change_color_brightness(color, amount=0.5):
"""
Lightens the given color by multiplying (1-luminosity) by the given amount.
Input can be matplotlib color string, hex string, or RGB tuple.
An amount of 1 equals to no change. 0 is very bright (white) and 2 is very dark.
By Ian Hincks
Source: https://stackoverflow.com/questions/37765197/darken-or-lighten-a-color-in-matplotlib
"""
if not (0<=amount<=2):
raise ValueError("The brightness change has to be between 0 and 2."
" Instead it was {}".format(amount))
try:
c = mc.cnames[color]
except KeyError:
c = color
try:
c = colorsys.rgb_to_hls(*mc.ColorConverter().to_rgb(c)) # matplotlib 1.5
except AttributeError:
c = colorsys.rgb_to_hls(*mc.to_rgb(c)) # matplotlib > 2
return colorsys.hls_to_rgb(c[0], 1-amount * (1-c[1]), c[2])
[docs]def change_ebar_alpha_for_line(ebar, alpha):
""" loop through bars (ebar[1]) and caps (ebar[2]) and set the alpha value """
for bars_or_caps in ebar[1:]:
for bar_or_cap in bars_or_caps:
bar_or_cap.set_alpha(alpha)
[docs]def change_ebar_alpha_for_axes(ax, alpha):
""" Wrapper for change_ebar_alpha_for_line """
for ebar in ax.containers:
if isinstance(ebar, matplotlib.container.ErrorbarContainer):
change_ebar_alpha_for_line(ebar, alpha)
# Labels #####################################################################
# List of common y-labels. Sorry for the ugly.
_ylabels = {
"beta": r'$\beta_{{{0}}} \quad [m]$',
"betabeat": r'$\Delta \beta_{{{0}}} / \beta_{{{0}}}$',
"betabeat_permile": r'$\Delta \beta_{{{0}}} / \beta_{{{0}}} [$'u'\u2030'r'$]$',
"dbeta": r"$\beta'_{{{0}}} \quad [m]$",
"dbetabeat": r'$1/\beta_{{{0}}} \cdot \partial\beta_{{{0}}} / \partial\delta_{{{0}}}$',
"norm_dispersion": r'$\frac{{D_{{{0}}}}}{{\sqrt{{\beta_{{{0}}}}}}} \quad [\sqrt{{m}}]$',
"norm_dispersion_mu": r'$\frac{{D_{{{0}}}}}{{\sqrt{{\beta_{{{0}}}}}}} \quad [\mu \sqrt{{m}}]$',
"phase": r'$\phi_{{{0}}} \quad [2\pi]$',
"phasetot": r'$\phi_{{{0}}} \quad [2\pi]$',
"phase_milli": r'$\phi_{{{0}}} \quad [2\pi\cdot10^{{-3}}]$',
"dispersion": r'$D_{{{0}}} \quad [m]$',
"dispersion_mm": r'$D_{{{0}}} \quad [mm]$',
"co": r'${0} \quad [mm]$',
"tune": r'$Q_{{{0}}} \quad [Hz]$',
"nattune": r'$Nat Q_{{{0}}} \quad [Hz]$',
"chromamp": r'$W_{{{0}}}$',
"real": r'$re({0})$',
"imag": r'$im({0})$',
"absolute": r'$|{0}|$',
}
[docs]def set_yaxis_label(param, plane, ax=None, delta=False, chromcoup=False): # plot x and plot y
""" Set y-axis labels.
Args:
param: One of the ylabels above
plane: Usually x or y, but can be any string actually to be placed into the label ({0})
ax: Axes to put the label on (default: gca())
delta: If True adds a Delta before the label (default: False)
"""
if not ax:
ax = plt.gca()
try:
label = _ylabels[param].format(plane)
except KeyError:
raise ArgumentError("Label '" + param + "' not found.")
if delta:
if param.startswith("beta") or param.startswith("norm"):
label = r'$\Delta(' + label[1:-1] + ")$"
else:
label = r'$\Delta ' + label[1:]
if chromcoup:
label = label[:-1] + r'/\Delta\delta$'
ax.set_ylabel(label)
[docs]def set_xaxis_label(ax=None):
""" Sets the standard x-axis label
Args:
ax: Axes to put the label on (default: gca())
"""
if not ax:
ax = plt.gca()
ax.set_xlabel(r'Longitudinal location [m]')
[docs]def show_ir(ip_dict, ax=None, mode='inside'):
""" Plots the interaction regions into the background of the plot.
Args:
ip_dict: dict, dataframe or series containing "IPLABEL" : IP_POSITION
ax: Axes to put the irs on (default: gca())
mode: 'inside', 'outside' + 'nolines' or just 'lines'
"""
if not ax:
ax = plt.gca()
xlim = ax.get_xlim()
ylim = ax.get_ylim()
lines = 'nolines' not in mode
inside = 'inside' in mode
lines_only = 'inside' not in mode and 'outside' not in mode and 'lines' in mode
if isinstance(ip_dict, (pd.DataFrame, pd.Series)):
if isinstance(ip_dict, pd.DataFrame):
ip_dict = ip_dict.iloc[:, 0]
d = {}
for ip in ip_dict.index:
d[ip] = ip_dict.loc[ip]
ip_dict = d
for ip in ip_dict.keys():
if xlim[0] <= ip_dict[ip] <= xlim[1]:
xpos = ip_dict[ip]
if lines:
ax.axvline(xpos, linestyle=':', color='grey', marker='', zorder=0)
if not lines_only:
ypos = ylim[not inside] + (ylim[1] - ylim[0]) * 0.01
c = 'grey' if inside else matplotlib.rcParams["text.color"]
ax.text(xpos, ypos, ip, color=c, ha='center', va='bottom')
ax.set_xlim(xlim)
ax.set_ylim(ylim)
[docs]def move_ip_labels(ax, value):
""" Moves IP labels according to max y * value."""
y_max = ax.get_ylim()[1]
for t in ax.texts:
if re.match(r"^IP\s*\d$", t.get_text()):
x = t.get_position()[0]
t.set_position((x, y_max * value))
[docs]def get_ip_positions(path):
""" Returns a dict of IP positions from tfs-file of path.
Args:
path (str): Path to the tfs-file containing IP-positions
"""
df = tfs.read_tfs(path).set_index('NAME')
ip_names = ["IP" + str(i) for i in range(1, 9)]
ip_pos = df.loc[ip_names, 'S'].values
return dict(zip(ip_names, ip_pos))
[docs]def set_name(name, fig_or_ax=None):
""" Sets the name of the figure or axes
Args:
name (str): Sting to set as name.
fig_or_ax: Figure or Axes to to use.
If 'None' takes current figure. (Default: None)
"""
if not fig_or_ax:
fig_or_ax = plt.gcf()
try:
fig_or_ax.figure.canvas.set_window_title(name)
except AttributeError:
fig_or_ax.canvas.set_window_title(name)
[docs]def get_name(fig_or_ax=None):
""" Returns the name of the figure or axes
Args:
fig_or_ax: Figure or Axes to to use.
If 'None' takes current figure. (Default: None)
"""
if not fig_or_ax:
fig_or_ax = plt.gcf()
try:
return fig_or_ax.figure.canvas.get_window_title()
except AttributeError:
return fig_or_ax.canvas.get_window_title()
[docs]def set_annotation(text, ax=None):
""" Writes an annotation on the top right of the axes
Args:
text: The annotation
ax: Axes to set annotation on. If 'None' takes current Axes. (Default: None)
"""
if not ax:
ax = plt.gca()
annotation = get_annotation(ax, by_reference=True)
if annotation is None:
ax.text(1.0, 1.0, text,
verticalalignment='bottom',
horizontalalignment='right',
transform=ax.transAxes,
label='plot_style_annotation')
else:
annotation.set_text(text)
[docs]def get_annotation(ax=None, by_reference=False):
""" Returns the annotation set by set_annotation()
Args:
ax: Axes to get annotation from. If 'None' takes current Axes. (Default: None)
by_reference (bool): If true returns the reference to the annotation,
otherwise the text as string. (Default: False)
"""
if not ax:
ax = plt.gca()
for c in ax.get_children():
if c.get_label() == 'plot_style_annotation':
if by_reference:
return c
else:
return c.get_text()
return None
[docs]def small_title(ax=None):
""" Alternative to annotation, which lets you use the title-functions
Args:
ax: Axes to use. If 'None' takes current Axes. (Default: None)
"""
if not ax:
ax = plt.gca()
# could not get set_title() to work properly, so one parameter at a time
ax.title.set_position([1.0, 1.02])
ax.title.set_transform(ax.transAxes)
ax.title.set_fontsize(matplotlib.rcParams['font.size'])
ax.title.set_fontweight(matplotlib.rcParams['font.weight'])
ax.title.set_verticalalignment('bottom')
ax.title.set_horizontalalignment('right')
[docs]def get_legend_ncols(labels, max_length=78):
""" Calulate the number of columns in legend dynamically """
return max([max_length/max([len(l) for l in labels]), 1])
[docs]def make_top_legend(ax, ncol, frame=False, handles=None, labels=None):
""" Create a legend on top of the plot. """
leg = ax.legend(handles=handles, labels=labels, loc='lower right',
bbox_to_anchor=(1.0, 1.02),
fancybox=frame, shadow=frame, frameon=frame, ncol=ncol)
if LooseVersion(matplotlib.__version__) <= LooseVersion("2.2.0"):
legend_height = leg.get_window_extent().inverse_transformed(leg.axes.transAxes).height
ax.figure.tight_layout(rect=[0, 0, 1, 1-legend_height])
leg.axes.figure.canvas.draw()
legend_width = leg.get_window_extent().inverse_transformed(leg.axes.transAxes).width
if legend_width > 1:
x_shift = (legend_width - 1) / 2.
ax.legend(handles=handles, labels=labels, loc='lower right',
bbox_to_anchor=(1.0 + x_shift, 1.02),
fancybox=frame, shadow=frame, frameon=frame, ncol=ncol)
if LooseVersion(matplotlib.__version__) >= LooseVersion("2.2.0"):
ax.figure.tight_layout()
return leg
# Tick Format #################################################################
[docs]def set_sci_magnitude(ax, axis="both", order=0, fformat="%1.1f", offset=True, math_text=True):
""" Uses the OMMFormatter to set the scientific limits on axes.
Args:
ax: Plotting axes
axis (str): "x", "y" or "both"
order (int): Magnitude Order
fformat (str): Format to use
offset (bool): Formatter offset
math_text (bool): Whether to use mathText
"""
oomf = OOMFormatter(order=order, fformat=fformat, offset=offset, mathText=math_text)
if axis == "x" or axis == "both":
ax.xaxis.set_major_formatter(oomf)
if axis == "y" or axis == "both":
ax.yaxis.set_major_formatter(oomf)
ax.ticklabel_format(axis=axis, style="sci", scilimits=(order, order), useMathText=math_text)
return ax