Source code for pylhc_submitter.sixdesk_tools.post_process_da

"""
Post Process DA
----------------------------------

Tools to process data after sixdb has calculated the
da. Includes functions for extracting data from database
as well as plotting of DA polar plots.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

import numpy as np
from generic_parser import DotDict
from matplotlib import lines as mlines
from matplotlib import pyplot as plt
from matplotlib import rcParams
from scipy.interpolate import interp1d
from tfs import TfsDataFrame, write_tfs

from pylhc_submitter.constants.autosix import (
    ALOST1,
    ALOST2,
    AMP,
    ANGLE,
    HEADER_HINT,
    HEADER_INFO,
    HEADER_NTOTAL,
    MAX,
    MEAN,
    MIN,
    SEED,
    STD,
    N,
    get_autosix_results_path,
    get_tfs_da_angle_stats_path,
    get_tfs_da_path,
    get_tfs_da_seed_stats_path,
)
from pylhc_submitter.sixdesk_tools.extract_data_from_db import extract_da_data

if TYPE_CHECKING:
    from collections.abc import Iterable
    from pathlib import Path

    import pandas as pd


LOG = logging.getLogger(__name__)

DA_COLUMNS = (ALOST1, ALOST2)
INFO = (
    "Statistics over the N={n:d} {over:s} per {per:s}. "
    "The N-Columns indicate how many non-zero DA values were used."
)
HINT = "{param:s} {val:} is the respective value calculated over all other {param:s}s."
OVER_WHICH = {SEED: "angles", ANGLE: "seeds"}

# Fixed plot styles ---
COLOR_MEAN = "red"
COLOR_SEED = "grey"
COLOR_LIM = "black"
COLOR_FILL = "blue"
ALPHA_SEED = 0.5
ALPHA_FILL = 0.2
ALPHA_FILL_STD = 0.2


[docs] def post_process_da(jobname: str, basedir: Path): """Post process the DA results into dataframes and DA plots.""" LOG.info("Post-Processing Sixdesk Results.") df_da, df_angle, df_seed = create_da_tfs(jobname, basedir) create_polar_plots(jobname, basedir, df_da, df_angle) LOG.info("Post-Processing finished.")
# Data Analysis ----------------------------------------------------------------
[docs] def create_da_tfs(jobname: str, basedir: Path) -> tuple[TfsDataFrame, TfsDataFrame, TfsDataFrame]: """Extracts data from db into dataframes, and writes and returns them. Args: jobname (str): Name of the Job basedir (Path): SixDesk Basefolder Location """ LOG.info("Gathering DA data into tfs-files.") df_da = extract_da_data(jobname, basedir) df_angle = _create_stats_df(df_da, ANGLE) df_seed = _create_stats_df(df_da, SEED, global_index=0) write_tfs(get_tfs_da_path(jobname, basedir), df_da) write_tfs(get_tfs_da_angle_stats_path(jobname, basedir), df_angle, save_index=ANGLE) write_tfs(get_tfs_da_seed_stats_path(jobname, basedir), df_seed, save_index=SEED) return df_da, df_angle, df_seed
def _create_stats_df(df: pd.DataFrame, parameter: str, global_index: Any = None) -> TfsDataFrame: """Calculates the stats over a given parameter. Note: Could be refactored to use `group_by`. Args: df (DataFrame): DataFrame containing the DA information over all seeds. parameter (str): The parameter over which we want to average, i.e. SEED or ANGLE. global_index (Any): identifier to use as a global index, i.e. the statistics over all entries are stored here. (e.g. '0' for SEEDs) """ operation_map = DotDict({MEAN: np.mean, STD: np.std, MIN: np.min, MAX: np.max}) pre_index = [] if global_index is None else [global_index] index = sorted(set(df[parameter])) n_total = sum(df[parameter] == index[0]) df_stats = TfsDataFrame( index=pre_index + index, columns=[f"{fun}{al}" for al in DA_COLUMNS for fun in list(operation_map.keys()) + [N]], ) df_stats.headers[HEADER_INFO] = INFO.format( over=OVER_WHICH[parameter], per=parameter.lower(), n=n_total ) df_stats.headers[HEADER_NTOTAL] = n_total for col_da in DA_COLUMNS: for idx in index: mask = (df[parameter] == idx) & (df[col_da] != 0) df_stats.loc[idx, f"{N}{col_da}"] = sum(mask) for name, operation in operation_map.items(): df_stats.loc[idx, f"{name}{col_da}"] = operation(df.loc[mask, col_da]) for name, operation in operation_map.get_subdict([MIN, MAX]).items(): df_stats.loc[idx, f"{name}{AMP}"] = operation(df.loc[mask, f"{name}{AMP}"]) if global_index is not None: # Note: could be done over df_stats for MEAN, MIN and MAX, but not STD mask = df[col_da] != 0 df_stats.loc[global_index, f"{N}{col_da}"] = sum(mask) # Global MEAN, MIN, MAX Dynamic Aperture for name, operation in operation_map.get_subdict([MEAN, MIN, MAX, STD]).items(): df_stats.loc[global_index, f"{name}{col_da}"] = operation(df.loc[mask, col_da]) # Global MIN, MAX Amplitudes for name, operation in operation_map.get_subdict([MIN, MAX]).items(): df_stats.loc[global_index, f"{name}{AMP}"] = operation( df.loc[mask, f"{name}{AMP}"] # min(MINA) and max(MAXA) ) df_stats.headers[HEADER_HINT] = HINT.format(param=parameter, val=global_index) return df_stats # Single Plots -----------------------------------------------------------------
[docs] def create_polar_plots(jobname: str, basedir: Path, df_da: TfsDataFrame, df_angles: TfsDataFrame): """Plotting loop over da-methods and wrapper so save plots. Args: jobname (str): Name of the Job basedir (Path): SixDesk Basefolder Location df_da (TfsDataFrame): Full DA analysis result. df_angles (TfsDataFrame): Dataframe with the statistics (min, max, mean) per angle """ LOG.info("Creating Polar Plots.") outdir_path = get_autosix_results_path(jobname, basedir) for da_col in DA_COLUMNS: fig = plot_polar(df_angles, da_col, jobname, df_da) fig.tight_layout(), fig.tight_layout() fig.savefig(outdir_path / fig.canvas.get_default_filename())
# plt.show()
[docs] def plot_polar( df_angles: TfsDataFrame, da_col: str = ALOST2, jobname: str = "", df_da: TfsDataFrame = None, **kwargs, ) -> plt.Figure: """Create Polar Plot for DA analysis data. Keyword arguments are all optional. Args: df_angles (TfsDataFrame): Dataframe with the statistics (min, max, mean) per angle da_col (str): DA-Column name from sixdesk analysis to be used , e.g. ``ALOST2``. (optional, default: ``ALOST2``) jobname (str): Name of the job. Used in window title only (optional). df_da (TfsDataFrame): Full DA analysis result. If given, plots the individual DA results per seed. (optional) Keyword Arguments: interpolated (bool): If true, uses interpolation to plot the lines curved fill (bool): If true, fills the area between min and max with light blue angle_ticks (Iterable[numeric]): Positions in degree of the angle ticks (and lines) amplitude ticks (Iterable[numeric]): Positions of the amplitude ticks. Returns: Figure of the polar plot. """ interpolated: bool = kwargs.pop("interpolated", True) fill: bool = kwargs.pop("fill", df_da is None) angle_ticks: Iterable[np.numeric] = kwargs.pop("angle_ticks", None) amplitude_ticks: Iterable[np.numeric] = kwargs.pop("amplitude_ticks", None) if "lines.marker" not in kwargs: kwargs["lines.marker"] = "None" fig, ax = plt.subplots(nrows=1, ncols=1, subplot_kw={"projection": "polar"}) fig.canvas.manager.set_window_title(f"{jobname} polar plot for {da_col}") angles = np.deg2rad(df_angles.index) da_min, da_mean, da_max, da_std = ( df_angles[f"{name}{da_col}"] for name in (MIN, MEAN, MAX, STD) ) seed_h, seed_l = _plot_seeds(ax, df_da, da_col, interpolated) if interpolated: mean_h, max_h = _plot_interpolated(ax, angles, da_min, da_mean, da_max, da_std, fill) else: mean_h, max_h = _plot_straight(ax, angles, da_min, da_mean, da_max, da_std, fill) ax.set_thetamin(0) ax.set_thetamax(90) ax.set_rlim([0, None]) ax.set_xlabel(r"DA$_{x}~[\sigma_{nominal}]$", labelpad=15) ax.set_ylabel(r"DA$_{y}~[\sigma_{nominal}]$", labelpad=20) if angle_ticks is not None: ax.set_xticks(np.deg2rad(angle_ticks)) if amplitude_ticks is not None: ax.set_yticks(amplitude_ticks) ax.tick_params(labelright=True, labelleft=True) ax.legend( loc="upper right", bbox_to_anchor=(0.9, 0.95), bbox_transform=fig.transFigure, # frameon=False, handles=seed_h + [mean_h, max_h], labels=seed_l + ["Mean DA", "Limits"], ncol=1, ) return fig
def _plot_seeds(ax, df_da: TfsDataFrame, da_col: str, interpolated: bool) -> tuple[list, list]: """Add the Seed lines to the polar plots, if df_da is given. Args: ax: Axes to plot in df_da: DataFrame with DA information da_col: Dynamic Aperture column (ALOST1 or ALOST2) interpolated: If true, the lines will be curved. Returns: Tuple of list of one line handle and a list of a single label """ if df_da is not None: seed_h = None for seed in sorted(set(df_da[SEED])): seed_mask = df_da[SEED] == seed angles = np.deg2rad(df_da.loc[seed_mask, ANGLE]) da_data = df_da.loc[seed_mask, da_col] da_data.loc[da_data == 0] = np.nan if interpolated: seed_h, _, _ = _interpolated_line( ax, angles, da_data, c=COLOR_SEED, ls="-", label=f"Seed {seed:d}", alpha=ALPHA_SEED, ) else: (seed_h,) = ax.plot( angles, da_data, c=COLOR_SEED, ls="-", label=f"Seed {seed:d}", alpha=ALPHA_SEED ) return [seed_h], ["DA per Seed"] return [], [] def _plot_interpolated(ax, angles, da_min, da_mean, da_max, da_std, fill): """Plot interpolated DA lines and areas.""" _, _, ip_min = _interpolated_line(ax, angles, da_min, c=COLOR_LIM, ls="--", label="Minimum DA") max_h, ip_x, ip_max = _interpolated_line( ax, angles, da_max, c=COLOR_LIM, ls="--", label="Maximum DA" ) if fill: ax.fill_between(ip_x, ip_min, ip_max, color=COLOR_FILL, alpha=ALPHA_FILL) _, ip_std_min = _interpolated_coords(angles, da_mean - da_std) _, ip_std_max = _interpolated_coords(angles, da_mean + da_std) ax.fill_between(ip_x, ip_std_min, ip_std_max, color=COLOR_FILL, alpha=ALPHA_FILL_STD) mean_h, _, _ = _interpolated_line(ax, angles, da_mean, c=COLOR_MEAN, ls="-", label="Mean DA") return mean_h, max_h def _plot_straight(ax, angles, da_min, da_mean, da_max, da_std, fill): """Plot straight DA lines and areas.""" (_,) = ax.plot(angles, da_min, c=COLOR_LIM, ls="--", label="Minimum DA") (max_h,) = ax.plot(angles, da_max, c=COLOR_LIM, ls="--", label="Maximum DA") if fill: ax.fill_between( angles, da_min.astype(float), da_max.astype(float), # weird conversion to obj otherwise color=COLOR_FILL, alpha=ALPHA_FILL, ) ax.fill_between( angles, (da_mean - da_std).astype(float), (da_mean + da_std).astype(float), color=COLOR_FILL, alpha=ALPHA_FILL_STD, ) (mean_h,) = ax.plot(angles, da_mean, c=COLOR_MEAN, ls="-", label="Mean DA") return mean_h, max_h def _interpolated_line(ax, x, y, npoints: int = 100, **kwargs): """Plot a line that interpolates linearly between points. Useful for polar plots with sparse points.""" ls = kwargs.pop("linestyle", kwargs.pop("ls", rcParams["lines.linestyle"])) marker = kwargs.pop("marker", rcParams["lines.marker"]) label = kwargs.pop("label") ip_x, ip_y = _interpolated_coords(x, y, npoints) (line_h,) = ax.plot(ip_x, ip_y, marker="None", ls=ls, label=f"_{label}_line", **kwargs) if marker.lower() not in ["none", ""]: ax.plot(x, y, ls="None", marker=marker, label=f"_{label}_markers", **kwargs) # fake handle for legend handle = mlines.Line2D([], [], color=line_h.get_color(), ls=ls, marker=marker, label=label) return handle, ip_x, ip_y def _interpolated_coords(x, y, npoints: int = 100): """Do linear interpolation between points.""" ip_x = np.linspace(min(x), max(x), npoints) ip_y = interp1d(x, y)(ip_x) return ip_x, ip_y