"""
Module plotshop.post_processing
---------------------------------
Functions for plot post-processing.
"""
import numpy as np
from matplotlib import pyplot as plt
import matplotlib as mpl
# Public Functions #############################################################
def merge_two_plots(axes, keep_style=True):
    """ Merges the plots of several axes into one new figure.

    The lines and errorbars of every axes in ``axes`` are re-plotted, in
    order, onto the axes of a newly created figure.

    Args:
        axes (list): list of axes to merge (despite the function name, any
            number of axes is accepted).
        keep_style (bool): if ``True``, it keeps the lines styles; otherwise
            linewidth, linestyle, color and markersize are dropped before
            re-plotting.

    Returns:
        The newly created figure containing the merged plot.
    """
    fig = plt.figure()
    new_ax = fig.gca()
    for ax in axes:
        # Plain lines first, then errorbars (whose data lines are excluded
        # from get_line_data so they are not drawn twice).
        _plot_collected_data(new_ax, get_line_data(ax), "lines", keep_style)
        _plot_collected_data(new_ax, get_errorbar_data(ax), "errorbar", keep_style)
    return fig
def transpose_legend(leg):
    """ Transposes the legend. Has some problems with the markers.

    The new legend gets as many rows as ``leg`` had columns, and the entries
    are re-ordered so that they are read along the other direction.

    Args:
        leg: legend to transpose.

    Returns:
        The new legend, created on ``leg.axes`` at the location of ``leg``.
    """
    # Newer matplotlib names these ``_ncols`` / ``legend_handles``; older
    # releases use ``_ncol`` / ``legendHandles``.
    nrow = getattr(leg, "_ncols", None)
    if nrow is None:
        nrow = leg._ncol
    handles = getattr(leg, "legend_handles", None)
    if handles is None:
        handles = leg.legendHandles
    handles = list(handles)
    loc = leg._get_loc()

    n_handles = len(handles)
    ncol = int(np.ceil(n_handles / float(nrow)))
    # Lay the indices out row-major in the old (nrow x ncol) grid and read
    # them column-major. If the grid is not full, the padding indices
    # (>= n_handles) can appear anywhere in that order, so filter them out
    # rather than truncating (truncation may keep an out-of-range index).
    order = np.arange(ncol * nrow).reshape(nrow, ncol).T.ravel()
    order = order[order < n_handles]
    handles = [handles[i] for i in order]

    new_leg = leg.axes.legend(handles, [h.get_label() for h in handles], ncol=ncol)
    new_leg._set_loc(loc)
    return new_leg
# Data Extraction ##############################################################
def get_errorbar_data(ax):
    """ Extract data from all errorbars in axes.

    Args:
        ax: axes handle to axes to extract data from.

    Returns:
        :List of dictionaries of extracted data (one per errorbar container,
        in ``ax.containers`` order) with keys as follows:
        |    label: Label of the errorbar container
        |    x: x-position data of the points
        |    y: y-position data of the points
        |    xerr: x-errors as returned by ``_get_ebar_err(ebar, "x")``,
        |          normally a tuple ``(lower, upper)`` of arrays
        |    yerr: y-errors as returned by ``_get_ebar_err(ebar, "y")``
        |    linewidth: Linewidth of the data line
        |    linestyle: Linestyle of the data line
        |    color: Color of the data line
        |    marker: Marker of the data line
        |    markersize: Markersize of the data line

    See Also:
        ``get_line_data()``
    """
    data = []
    for ebar in ax.containers:
        if not isinstance(ebar, mpl.container.ErrorbarContainer):
            continue
        # Positions and style come from the errorbar's data line (ebar[0]);
        # the label is taken from the container itself.
        line_dict = _extract_line_data(ebar[0])
        line_dict.update({
            "label": ebar.get_label(),
            "xerr": _get_ebar_err(ebar, "x"),
            "yerr": _get_ebar_err(ebar, "y"),
        })
        data.append(line_dict)
    return data
def get_line_data(ax):
    """ Extract data from all lines in axes.

    Args:
        ax: axes handle to axes to extract data from.

    Returns:
        :List of dictionaries of extracted data (in ``ax.get_lines()``
        order) with keys as follows:
        |    label: Line label
        |    x: x-position data of the points
        |    y: y-position data of the points
        |    linewidth: Linewidth
        |    linestyle: Linestyle
        |    color: Linecolor
        |    marker: Marker
        |    markersize: Markersize
        Lines that are the data line of an errorbar container are skipped;
        those are returned by ``get_errorbar_data()`` instead.

    See Also:
        ``get_errorbar_data()``
    """
    # Compared by object identity, so building a set makes the lookup O(1).
    # NOTE(review): only each errorbar's data line (ebar[0]) is excluded, not
    # its cap lines (ebar[1]); confirm whether those should be filtered too.
    errorbar_line_ids = {
        id(ebar[0]) for ebar in ax.containers
        if isinstance(ebar, mpl.container.ErrorbarContainer)
    }
    return [_extract_line_data(line) for line in ax.get_lines()
            if id(line) not in errorbar_line_ids]
# Private Functions ############################################################
def _get_ebar_err(ebar, plane):
    """ Extract the error information from the errorbar. """
    data = ebar[0].get_xdata() if plane == "x" else ebar[0].get_ydata()
    plane_idx = 0 if plane == "x" else 1
    segments = ebar[2][plane_idx].get_segments()
    lower = np.zeros(len(data))
    upper = np.zeros(len(data))
    for idx, (dat, bar) in enumerate(zip(data, segments)):
        lower[idx] = dat - bar[0][plane_idx]
        upper[idx] = bar[1][plane_idx] - dat
    return lower, upper
def _plot_collected_data(ax, data, data_type, keep_style):
    plot_fun = {
        "errorbar": ax.errorbar,
        "lines": lambda **kwargs: ax.plot(kwargs.pop("x"), kwargs.pop("y"), **kwargs),  # mpl !!
    }[data_type]
    for dat in data:
        if not keep_style:
            [dat.pop(key) for key in ["linewidth", "linestyle", "color", "markersize"]]
        plot_fun(**dat)
def _extract_line_data(line):
        return {
            "label": line.get_label(),
            "x": line.get_xdata(),
            "y": line.get_ydata(),
            "linewidth": line.get_linewidth(),
            "linestyle": line.get_linestyle(),
            "color": line.get_color(),
            "marker": line.get_marker(),
            "markersize": line.get_markersize(),
        }