Coverage for pesummary/core/plots/population.py: 17.2%
29 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import numpy as np
4from pesummary.core.plots.figure import figure
5from pesummary.utils.utils import logger
7__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
def _error_pair(errors, analysis, parameter):
    """Return the ``[lower, upper]`` error on ``parameter`` for ``analysis``.

    Parameters
    ----------
    errors: dict or None
        nested dictionary of errors with the same structure as the
        ``sample_dict`` passed to ``scatter_plot``
    analysis: str
        label of the run to look up
    parameter: str
        name of the parameter to look up

    Returns
    -------
    np.ndarray
        length-2 float array. ``[0, 0]`` is returned when ``errors`` is None
        or has no entry for this run/parameter. A scalar (symmetric) error is
        repeated for both sides so that every entry has the same shape and
        they can be stacked into a single ``(N, 2)`` array.
    """
    if errors is None:
        return np.zeros(2)
    # .get so that a run missing from the error dictionary falls back to
    # zero error instead of raising a KeyError
    run_errors = errors.get(analysis, {})
    if parameter not in run_errors.keys():
        return np.zeros(2)
    return np.broadcast_to(np.asarray(run_errors[parameter], dtype=float), (2,))


def _per_run_colors(colors, nruns):
    """Return True if ``colors`` is a list holding one color per run.

    A list made up entirely of numbers is treated as a single RGB(A) color
    rather than as one color per run.
    """
    if not isinstance(colors, list) or len(colors) != nruns:
        return False
    return not all(isinstance(c, (int, float, np.number)) for c in colors)


def scatter_plot(
    parameters, sample_dict, latex_labels, colors=None, xerr=None, yerr=None
):
    """Produce a plot which shows a population of runs over a certain parameter
    space. If errors are given, then plot error bars.

    Parameters
    ----------
    parameters: list
        names of the parameters that you wish to plot. The first is plotted
        on the x axis and the second on the y axis
    sample_dict: dict
        nested dictionary storing the median values for each parameter for each
        run. For example: x = {'one': {'m': 10, 'n': 20}}. Runs which do not
        include every entry in ``parameters`` are skipped with a warning
    latex_labels: dictionary
        dictionary of latex labels. If a parameter has no entry, its name is
        used as the axis label
    colors: list, optional
        list of colors that you wish to use to distinguish the different runs,
        one per run in the order of ``sample_dict``. Any other value (e.g. a
        single color) is passed straight to matplotlib
    xerr: dict, optional
        same structure as sample_dict, but dictionary storing error in x. Each
        error may be a scalar (symmetric) or a ``[lower, upper]`` pair. Runs
        or parameters without an entry are given zero error
    yerr: dict, optional
        same structure as sample_dict, but dictionary storing error in y. Same
        format as ``xerr``

    Returns
    -------
    fig:
        the figure created by ``figure(gca=True)`` containing the plot
    """
    fig, ax = figure(gca=True)
    xparam, yparam = parameters[0], parameters[1]
    runs = list(sample_dict.keys())

    kept, xdata, ydata, xerrdata, yerrdata = [], [], [], [], []
    for num, analysis in enumerate(runs):
        samples = sample_dict[analysis]
        if not all(param in samples.keys() for param in parameters):
            logger.warning(
                "'{}' does not include samples for '{}' and/or '{}'. This "
                "analysis will not be added to the plot".format(
                    analysis, xparam, yparam
                )
            )
            continue
        kept.append(num)
        xdata.append(samples[xparam])
        ydata.append(samples[yparam])
        xerrdata.append(_error_pair(xerr, analysis, xparam))
        yerrdata.append(_error_pair(yerr, analysis, yparam))

    # A per-run color list is indexed by position in ``sample_dict``. Drop
    # the entries of skipped runs so the remaining colors stay aligned with
    # the plotted points.
    if _per_run_colors(colors, len(runs)):
        colors = [colors[num] for num in kept]

    # nothing to draw if no run contains both parameters
    if xdata:
        if xerr is not None or yerr is not None:
            # matplotlib expects asymmetric errors with shape (2, N): the
            # first row holds lower errors and the second row upper errors
            xerrs = np.array(xerrdata, dtype=float).T
            yerrs = np.array(yerrdata, dtype=float).T
            if _per_run_colors(colors, len(xdata)):
                # Axes.errorbar accepts only a single color per call, so each
                # run is drawn separately to honour per-run colors
                for num, color in enumerate(colors):
                    ax.errorbar(
                        [xdata[num]], [ydata[num]], color=color,
                        xerr=xerrs[:, num:num + 1],
                        yerr=yerrs[:, num:num + 1], linestyle=" "
                    )
            else:
                ax.errorbar(
                    xdata, ydata, color=colors, xerr=xerrs, yerr=yerrs,
                    linestyle=" "
                )
        else:
            ax.scatter(xdata, ydata, color=colors)
    ax.set_xlabel(latex_labels.get(xparam, xparam), fontsize=16)
    ax.set_ylabel(latex_labels.get(yparam, yparam), fontsize=16)
    fig.tight_layout()
    return fig