Source code for pymchelper.writers.plots

import logging
import os
from enum import IntEnum

import numpy as np

# Module-level logger named after this module; the writer classes in this file
# use it to report which output files are being written.
logger = logging.getLogger(__name__)


class PlotAxis(IntEnum):
    """Axes of a plot, numbered starting from one.

    Members can be looked up by their lowercase name, e.g. ``PlotAxis['x']``.
    """

    # horizontal axis
    x = 1
    # vertical axis
    y = 2
    # third axis (colour scale on 2-D maps)
    z = 3
class PlotDataWriter:
    """Save detector data as a space-delimited text file.

    For detectors with one or more dimensions every output row holds one
    value per axis, followed by the scored value and (when error information
    is present) the error value. For 0-dimensional detectors a single row
    with the value (and optionally the error) is written.
    """

    def __init__(self, filename, options):
        """Store the output path, appending ``.dat`` if it is missing.

        :param filename: output path, with or without the ``.dat`` suffix
        :param options: accepted for interface compatibility, not used here
        """
        self.filename = filename
        if not self.filename.endswith(".dat"):
            self.filename += ".dat"

    @staticmethod
    def _has_error(error_raw):
        """Return True if ``error_raw`` holds usable error information.

        Error data consisting only of NaN values or only of zeros is treated
        as absent.
        """
        return not np.all(np.isnan(error_raw)) and bool(np.any(error_raw))

    def _write_scalar(self, data_raw, error_raw, has_error):
        """Write a single value (plus its error, if present) as one row.

        ``np.ravel`` makes this work both for plain scalars and for
        one-element arrays; nesting one-element arrays in a list would
        otherwise produce a 3-D array, which ``np.savetxt`` cannot save.
        """
        columns = [np.ravel(data_raw)]
        if has_error:
            columns.append(np.ravel(error_raw))
        fmt = " ".join(["%g"] * len(columns))
        np.savetxt(self.filename, np.column_stack(columns), fmt=fmt, delimiter=' ')

    def _write_table(self, axis_columns, data_raw, error_raw, has_error):
        """Write one row per bin: axis values, scored value, optional error.

        :param axis_columns: list with one 1-D array of axis values per dimension
        :param data_raw: flat array of scored values, one per bin
        :param error_raw: flat array of errors, used only when ``has_error``
        :param has_error: whether the error column should be written
        """
        # Expand per-axis values to one entry per bin. The grid is computed
        # once; with 'ij' indexing and the default ravel order the last axis
        # varies fastest.
        # NOTE(review): this assumes data_raw is stored in the same bin
        # order - confirm against the detector implementation.
        columns = [grid.ravel() for grid in np.meshgrid(*axis_columns, indexing='ij')]
        columns.append(data_raw)
        if has_error:
            columns.append(error_raw)
        fmt = " ".join(["%g"] * len(columns))
        # transpose from a list of columns to rows before saving
        np.savetxt(self.filename, np.transpose(columns), fmt=fmt, delimiter=' ')

    def write(self, detector):
        """Save the content of ``detector`` to ``self.filename``.

        :param detector: object providing ``data_raw``, ``error_raw``,
            ``dettyp``, ``dimension`` and ``plot_axis(i)`` (whose result has
            a ``data`` attribute)
        """
        logger.info("Writing: %s", self.filename)

        data_raw = detector.data_raw
        error_raw = detector.error_raw
        has_error = self._has_error(error_raw)

        # change units for LET from MeV/cm to keV/um if necessary
        # a copy of data table is made here
        from pymchelper.shieldhit.detector.detector_type import SHDetType
        if detector.dettyp in (SHDetType.dlet, SHDetType.dletg, SHDetType.tlet, SHDetType.tletg):
            data_raw = data_raw * np.float64(0.1)  # 1 MeV / cm = 0.1 keV / um
            if has_error:
                error_raw = error_raw * np.float64(0.1)  # 1 MeV / cm = 0.1 keV / um

        if detector.dimension == 0:
            # special case for 0-dim data; the (possibly unit-converted)
            # local values are saved, not the untouched detector attributes
            self._write_scalar(data_raw, error_raw, has_error)
            return

        # each axis may have a different number of points
        axis_columns = [detector.plot_axis(i).data for i in range(detector.dimension)]
        self._write_table(axis_columns, data_raw, error_raw, has_error)
class GnuplotDataWriter:
    """Generate a gnuplot script which plots a detector data file.

    The script reads a ``.dat`` file and renders a ``.png`` image. This class
    writes only the ``.plot`` script (and, for 2-D data, a small awk helper);
    it does not write the data file itself.
    """

    # awk helper for 2-D data: inserts a blank line whenever the second
    # column changes, as described in the gnuplot FAQ linked in ``write``
    _awk_2d_script_content = """/^[[:blank:]]*#/ {next} # ignore comments (lines starting with #)
NF < 3 {next} # ignore lines which don't have at least 3 columns
$2 != prev {printf \"\\n\"; prev=$2} # print blank line
{print} # print the line
"""

    # common preamble of the generated gnuplot script
    _header = """set term png
set output \"{plot_filename}\"
set title \"{title}\"
set xlabel \"{xlabel}\"
set ylabel \"{ylabel}\"
"""

    # optional shaded band drawn between value-error and value+error
    _error_plot_command = "'./{data_filename}' u 1:(max($2-$3,0.0)):($2+$3) w filledcurves " \
                          "fs transparent solid 0.2 lc 3 title '1-sigma confidence', "

    # plotting commands, keyed by detector dimension
    _plotting_command = {
        1: """max(x,y) = (x > y) ? x : y
plot {error_plot} './{data_filename}' u 1:2 w l lt 1 lw 2 lc -1 title 'mean value'
""",
        2: """set view map
splot \"<awk -f addblanks.awk '{data_filename}'\" u 1:2:3 with pm3d
"""
    }

    def __init__(self, filename, options):
        """Derive data, script, image and awk-helper paths from ``filename``.

        Each path gets its suffix (``.dat``, ``.plot``, ``.png``) appended
        unless ``filename`` already ends with it. The awk helper is placed in
        the same directory as the gnuplot script.

        :param filename: common base path of the generated files
        :param options: accepted for interface compatibility, not used here
        """
        self.data_filename = self._ensure_suffix(filename, ".dat")
        self.script_filename = self._ensure_suffix(filename, ".plot")
        self.plot_filename = self._ensure_suffix(filename, ".png")

        dirname = os.path.split(self.script_filename)[0]
        self.awk_script_filename = os.path.join(dirname, "addblanks.awk")

    @staticmethod
    def _ensure_suffix(path, suffix):
        """Return ``path`` unchanged if it ends with ``suffix``, else with it appended."""
        return path if path.endswith(suffix) else path + suffix

    def write(self, detector):
        """Write the gnuplot script (and awk helper for 2-D data).

        Nothing is written when there is no plotting command for the
        detector dimension, i.e. for anything other than 1-D and 2-D data.

        :param detector: object providing ``dimension``, ``name``, ``unit``,
            ``error`` and ``plot_axis(i)`` (whose result has ``unit`` and
            ``name`` attributes)
        """
        # Skip dimensions without a plotting template (0-D, 3-D, ...).
        # Checking against the templates rather than a fixed {0, 3} set
        # avoids an unbound `ylabel` and a KeyError for any other dimension.
        if detector.dimension not in self._plotting_command:
            return

        # set labels
        plot_x_axis = detector.plot_axis(0)
        xlabel = ImageWriter._make_label(plot_x_axis.unit, plot_x_axis.name)
        if detector.dimension == 1:
            ylabel = ImageWriter._make_label(detector.unit, detector.name)
        else:
            plot_y_axis = detector.plot_axis(1)
            ylabel = ImageWriter._make_label(plot_y_axis.unit, plot_y_axis.name)

            # for 2-D plots write additional awk script to convert data
            # as described in gnuplot faq: http://www.gnuplot.info/faq/faq.html#x1-320003.9
            with open(self.awk_script_filename, 'w') as script_file:
                logger.info("Writing: %s", self.awk_script_filename)
                script_file.write(self._awk_2d_script_content)

        # save gnuplot script
        with open(self.script_filename, 'w') as script_file:
            logger.info("Writing: %s", self.script_filename)
            script_file.write(self._header.format(plot_filename=self.plot_filename,
                                                  xlabel=xlabel,
                                                  ylabel=ylabel,
                                                  title=detector.name))
            plt_cmd = self._plotting_command[detector.dimension]

            # add error plot if error data present
            # NOTE(review): this checks `detector.error` without a NaN test;
            # confirm it agrees with whether the data file has an error column.
            err_cmd = ""
            if np.any(detector.error):
                err_cmd = self._error_plot_command.format(data_filename=self.data_filename)

            # the 2-D template has no {error_plot} field, so err_cmd is
            # simply ignored by format() in that case
            script_file.write(plt_cmd.format(data_filename=self.data_filename, error_plot=err_cmd))
class ImageWriter:
    """Render 1-D and 2-D detector data as PNG images using matplotlib.

    1-D data is drawn as a line with an optional shaded error band, 2-D data
    as a colour map. For 2-D data with error information a second image,
    with ``_error`` appended to the base file name, shows the error values.
    """

    default_colormap = 'gnuplot2'

    def __init__(self, filename, options):
        """Store the output path and plotting options.

        :param filename: output path; ``.png`` is appended if missing
        :param options: object with ``colormap`` (passed to matplotlib as
            ``cmap``) and ``log`` (iterable of ``PlotAxis`` member names for
            which a logarithmic scale is requested)
        """
        self.plot_filename = filename
        if not self.plot_filename.endswith(".png"):
            self.plot_filename += ".png"
        self.colormap = options.colormap
        self.axis_with_logscale = {PlotAxis[name] for name in options.log}

    @staticmethod
    def _make_label(unit, name):
        """Return an axis label of the form ``name [unit]``."""
        return name + " [" + unit + "]"

    @staticmethod
    def _has_error(error_raw):
        """Return True unless ``error_raw`` is all NaN or all zero."""
        return not np.all(np.isnan(error_raw)) and bool(np.any(error_raw))

    @staticmethod
    def _positive_min(values):
        """Return the smallest strictly positive entry of ``values``.

        Returns ``None`` when no entry is greater than zero, so callers can
        avoid calling ``min()`` on an empty selection.
        """
        positive = values[values > 0]
        if positive.size == 0:
            return None
        return positive.min()

    def _make_norm(self, colors, values):
        """Build the colour normalisation used for 2-D maps.

        A logarithmic norm is used when log scale was requested for the Z
        axis and ``values`` has at least one positive entry; otherwise a
        linear norm spanning the full range of ``values`` is returned.

        :param colors: the ``matplotlib.colors`` module (imported lazily by callers)
        :param values: numpy array with the values to be colour-coded
        """
        if PlotAxis.z in self.axis_with_logscale:
            vmin = self._positive_min(values)
            if vmin is not None:
                return colors.LogNorm(vmin=vmin, vmax=values.max())
            # a log scale has no valid lower bound without positive data
            logger.warning("No positive values found, using linear color scale instead of logarithmic")
        return colors.Normalize(vmin=values.min(), vmax=values.max())

    def _save_2d_error_plot(self, detector, xlist, ylist, elist, x_axis_label, y_axis_label, z_axis_label):
        """Save a colour map of the error values next to the main plot.

        The image path is ``self.plot_filename`` with ``_error`` inserted
        before the extension.

        :param detector: not used; kept for interface compatibility
        :param xlist: 2-D grid of X coordinates
        :param ylist: 2-D grid of Y coordinates
        :param elist: 2-D array of error values; negative entries are
            clipped to zero before plotting
        :param x_axis_label: label of the X axis
        :param y_axis_label: label of the Y axis
        :param z_axis_label: label of the colour bar
        """
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        from matplotlib import colors

        fig, _ = plt.subplots(1, 1)

        # configure logscale on X and Y axis (both for positive and negative numbers)
        if PlotAxis.x in self.axis_with_logscale:
            plt.xscale('symlog')
        if PlotAxis.y in self.axis_with_logscale:
            plt.yscale('symlog')

        norm = self._make_norm(colors, elist)

        plt.xlabel(x_axis_label)
        plt.ylabel(y_axis_label)

        mesh = plt.pcolormesh(xlist, ylist, elist.clip(0.0), cmap=self.colormap, norm=norm)
        cbar = fig.colorbar(mesh)
        cbar.set_label(label=z_axis_label, rotation=270, verticalalignment='bottom')

        base_name, _ = os.path.splitext(self.plot_filename)
        plt.savefig(base_name + "_error.png")
        plt.close()

    def write(self, detector):
        """Plot ``detector`` and save the image to ``self.plot_filename``.

        Returns without output when matplotlib cannot be imported (an error
        is logged) or when the detector is 0- or 3-dimensional.

        :param detector: object providing ``data_raw``, ``error_raw``,
            ``error``, ``dettyp``, ``dimension``, ``unit``, ``name`` and
            ``plot_axis(i)`` (whose result has ``data``, ``n``, ``unit`` and
            ``name``); an optional ``dif_axis`` attribute equal to 1 swaps
            the axes of 2-D plots
        """
        try:
            import matplotlib
            matplotlib.use('Agg')
            import matplotlib.pyplot as plt
            from matplotlib import colors
        except ImportError:
            logger.error("Matplotlib not installed, output won't be generated")
            return

        # skip plotting 0-D and 3-D data
        if detector.dimension in (0, 3):
            return

        data_raw = detector.data_raw
        error_raw = detector.error_raw

        # change units for LET from MeV/cm to keV/um if necessary
        # a copy of datatable is made here
        from pymchelper.shieldhit.detector.detector_type import SHDetType
        if detector.dettyp in (SHDetType.dlet, SHDetType.dletg, SHDetType.tlet, SHDetType.tletg):
            data_raw = data_raw * np.float64(0.1)  # 1 MeV / cm = 0.1 keV / um
            if self._has_error(error_raw):
                error_raw = error_raw * np.float64(0.1)  # 1 MeV / cm = 0.1 keV / um

        logger.info("Writing: %s", self.plot_filename)

        plot_x_axis = detector.plot_axis(0)
        plt.xlabel(self._make_label(plot_x_axis.unit, plot_x_axis.name))

        # configure logscale on X and Y axis (both for positive and negative numbers)
        if PlotAxis.x in self.axis_with_logscale:
            plt.xscale('symlog')
        if PlotAxis.y in self.axis_with_logscale:
            plt.yscale('symlog')

        if detector.dimension == 1:
            # add optional error area, clipped to [0, 1.05 * max(data)]
            if np.any(detector.error):
                plt.fill_between(plot_x_axis.data,
                                 (data_raw - error_raw).clip(0.0),
                                 (data_raw + error_raw).clip(0.0, 1.05 * data_raw.max()),
                                 alpha=0.2, edgecolor='#CC4F1B', facecolor='#FF9848', antialiased=True)
            plt.ylabel(self._make_label(detector.unit, detector.name))
            plt.plot(plot_x_axis.data, data_raw)
        elif detector.dimension == 2:
            plot_y_axis = detector.plot_axis(1)

            xlist, ylist = np.meshgrid(plot_x_axis.data, plot_y_axis.data)
            x_axis_label = self._make_label(plot_x_axis.unit, plot_x_axis.name)
            y_axis_label = self._make_label(plot_y_axis.unit, plot_y_axis.name)
            z_axis_label = self._make_label(detector.unit, detector.name)

            # configure (optionally logarithmic) colour scale on Z axis
            norm = self._make_norm(colors, data_raw)

            # in case differential scorer was used there is a case when axis has to be swapped
            # this happens when X-constant, Y-differential, Z-scored
            if hasattr(detector, 'dif_axis') and detector.dif_axis == 1:
                x_axis_label, y_axis_label = y_axis_label, x_axis_label
                xlist, ylist = ylist, xlist

            plt.xlabel(x_axis_label)
            plt.ylabel(y_axis_label)

            # rows of the reshaped table run along Y, columns along X
            shape_tuple = (plot_y_axis.n, plot_x_axis.n)
            zlist = data_raw.reshape(shape_tuple)

            plt.pcolormesh(xlist, ylist, zlist, cmap=self.colormap, norm=norm)
            cbar = plt.colorbar()
            cbar.set_label(z_axis_label, rotation=270, verticalalignment='bottom')

        plt.savefig(self.plot_filename)
        plt.close()

        # add 2-D error plot if error data present
        if detector.dimension == 2 and self._has_error(error_raw):
            elist = error_raw.reshape(shape_tuple)
            self._save_2d_error_plot(detector, xlist, ylist, elist, x_axis_label, y_axis_label, z_axis_label)