# Source code for plotshop.post_processing

"""
Module plotshop.post_processing
---------------------------------

Functions for plot post-processing.
"""
import numpy as np
from matplotlib import pyplot as plt
import matplotlib as mpl


# Public Functions #############################################################


def merge_two_plots(axes, keep_style=True):
    """
    Merge the contents of several axes into one new figure.

    Despite the name, any number of axes may be given. For each axes, in
    the given order, its plain lines are re-plotted first and then its
    errorbars.

    Args:
        axes (list): axes whose lines and errorbars are copied.
        keep_style (bool): if ``True``, the line styles are kept; otherwise
            line width, line style, color and marker size are dropped so
            matplotlib's defaults apply.

    Returns:
        The newly created figure.
    """
    figure = plt.figure()
    target = figure.gca()
    # (collector, plot kind) pairs, applied per source axes in this order
    collectors = (
        (get_line_data, "lines"),
        (get_errorbar_data, "errorbar"),
    )
    for source in axes:
        for collect, kind in collectors:
            _plot_collected_data(target, collect(source), kind, keep_style)
    return figure
def transpose_legend(leg):
    """
    Transpose the legend, i.e. swap its rows and columns.

    A new legend is created on the same axes. Its number of rows equals the
    old legend's number of columns. The legend location is carried over;
    other legend properties (title, fonts, ...) are not.
    Has some problems with the markers.

    Args:
        leg: the legend to transpose.

    Returns:
        The newly created legend.
    """
    # Matplotlib renamed ``_ncol`` -> ``_ncols`` (3.6) and
    # ``legendHandles`` -> ``legend_handles`` (3.7); support both spellings.
    nrow = getattr(leg, "_ncols", None)
    if nrow is None:
        nrow = leg._ncol
    handles = getattr(leg, "legend_handles", None)
    if handles is None:
        handles = leg.legendHandles
    handles = list(handles)
    loc = leg._get_loc()

    n_handles = len(handles)
    ncol = max(1, int(np.ceil(n_handles / float(nrow))))
    # Index grid of nrow x ncol read column-wise gives the transposed order.
    # The grid can have more cells than handles; those placeholder indices
    # may sit anywhere in the transposed order, so they are filtered out
    # rather than cut off at the end.
    grid = np.arange(nrow * ncol).reshape(nrow, ncol).transpose().ravel()
    order = grid[grid < n_handles]
    handles = [handles[i] for i in order]

    new_leg = leg.axes.legend(handles, [h.get_label() for h in handles], ncol=ncol)
    new_leg._set_loc(loc)
    return new_leg
# Data Extraction ##############################################################
def get_errorbar_data(ax):
    """
    Extract data from all errorbars in axes.

    Args:
        ax: axes handle to axes to extract data from.

    Returns:
        :List of dictionaries, one per errorbar container in
        ``ax.containers``, with keys as follows:

        | label: Errorbar (container) label
        | x: x-position data of the points
        | y: y-position data of the points
        | xerr: lower and upper x-error values, as extracted by ``_get_ebar_err``
        | yerr: lower and upper y-error values, as extracted by ``_get_ebar_err``
        | linewidth: Linewidth
        | linestyle: Linestyle
        | color: Linecolor
        | marker: Marker style
        | markersize: Marker size

        Labels are returned unchanged, so several ``"_nolegend_"`` entries
        may occur.

    See Also:
        ``get_line_data()``
    """
    data = []
    for container in ax.containers:
        if not isinstance(container, mpl.container.ErrorbarContainer):
            continue
        # Positions and style come from the errorbar's data line; the
        # container's own label replaces the line's label.
        # NOTE(review): container[0] is presumably None for errorbars drawn
        # with ``fmt="none"``, which _extract_line_data cannot handle — confirm.
        entry = _extract_line_data(container[0])
        entry.update({
            "label": container.get_label(),
            "xerr": _get_ebar_err(container, "x"),
            "yerr": _get_ebar_err(container, "y"),
        })
        data.append(entry)
    return data
def get_line_data(ax):
    """
    Extract data from all lines in axes.

    Args:
        ax: axes handle to axes to extract data from.

    Returns:
        :List of dictionaries, one per line in ``ax.get_lines()`` that is not
        the data line of an errorbar, with keys as follows:

        | label: Line label
        | x: x-position data of the points
        | y: y-position data of the points
        | linewidth: Linewidth
        | linestyle: Linestyle
        | color: Linecolor
        | marker: Marker style
        | markersize: Marker size

        Lines found to be the data line of an errorbar are ignored; use
        ``get_errorbar_data()`` for those. Labels are returned unchanged.

    See Also:
        ``get_errorbar_data()``
    """
    # Identities of the errorbars' data lines; a set keeps the per-line check O(1).
    # NOTE(review): errorbar cap lines (container[1]) are not excluded here —
    # confirm whether they should be.
    errorbar_line_ids = {
        id(container[0])
        for container in ax.containers
        if isinstance(container, mpl.container.ErrorbarContainer)
    }
    return [
        _extract_line_data(line)
        for line in ax.get_lines()
        if id(line) not in errorbar_line_ids
    ]
# Private Functions ############################################################ def _get_ebar_err(ebar, plane): """ Extract the error information from the errorbar. """ data = ebar[0].get_xdata() if plane == "x" else ebar[0].get_ydata() plane_idx = 0 if plane == "x" else 1 segments = ebar[2][plane_idx].get_segments() lower = np.zeros(len(data)) upper = np.zeros(len(data)) for idx, (dat, bar) in enumerate(zip(data, segments)): lower[idx] = dat - bar[0][plane_idx] upper[idx] = bar[1][plane_idx] - dat return lower, upper def _plot_collected_data(ax, data, data_type, keep_style): plot_fun = { "errorbar": ax.errorbar, "lines": lambda **kwargs: ax.plot(kwargs.pop("x"), kwargs.pop("y"), **kwargs), # mpl !! }[data_type] for dat in data: if not keep_style: [dat.pop(key) for key in ["linewidth", "linestyle", "color", "markersize"]] plot_fun(**dat) def _extract_line_data(line): return { "label": line.get_label(), "x": line.get_xdata(), "y": line.get_ydata(), "linewidth": line.get_linewidth(), "linestyle": line.get_linestyle(), "color": line.get_color(), "marker": line.get_marker(), "markersize": line.get_markersize(), }