Coverage for pesummary/core/plots/population.py: 96.6%

29 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import numpy as np 

4from pesummary.core.plots.figure import figure 

5from pesummary.utils.utils import logger 

6 

7__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

8 

9 

def _error_or_zero(errors, analysis, parameter):
    """Return the error stored for `parameter` in run `analysis`.

    Falls back to ``[0, 0]`` when no error dictionary was supplied, when the
    run is absent from it, or when the run holds no entry for `parameter`.
    """
    if errors is None or analysis not in errors:
        return [0, 0]
    run_errors = errors[analysis]
    return run_errors[parameter] if parameter in run_errors else [0, 0]


def scatter_plot(
    parameters, sample_dict, latex_labels, colors=None, xerr=None, yerr=None
):
    """Produce a plot which shows a population of runs over a certain parameter
    space. If errors are given, then plot error bars.

    Runs which do not contain every entry of `parameters` are skipped with a
    warning. A run or parameter that is missing from `xerr`/`yerr` is drawn
    with a zero-length error bar rather than raising.

    Parameters
    ----------
    parameters: list
        names of the parameters that you wish to plot. The first entry is
        drawn on the x axis and the second on the y axis
    sample_dict: dict
        nested dictionary storing the median values for each parameter for each
        run. For example: x = {'one': {'m': 10, 'n': 20}}
    latex_labels: dictionary
        dictionary of latex labels
    colors: list
        list of colors that you wish to use to distinguish the different runs
    xerr: dict
        same structure as sample_dict, but dictionary storing error in x.
        Each entry is assumed to be a length-2 sequence (the fallback used for
        missing entries is [0, 0]) -- presumably [lower, upper]; TODO confirm
    yerr: dict
        same structure as sample_dict, but dictionary storing error in y

    Returns
    -------
    fig:
        the figure object created by `figure(gca=True)`
    """
    fig, ax = figure(gca=True)
    x_name, y_name = parameters[0], parameters[1]

    xdata, ydata, xerr_rows, yerr_rows = [], [], [], []
    for analysis, medians in sample_dict.items():
        if not all(name in medians for name in parameters):
            logger.warning(
                "'{}' does not include samples for '{}' and/or '{}'. This "
                "analysis will not be added to the plot".format(
                    analysis, x_name, y_name
                )
            )
            # nothing is drawn for this run, so its errors are irrelevant
            continue
        xdata.append(medians[x_name])
        ydata.append(medians[y_name])
        xerr_rows.append(_error_or_zero(xerr, analysis, x_name))
        yerr_rows.append(_error_or_zero(yerr, analysis, y_name))

    if xerr is not None or yerr is not None:
        # rows are one-per-run; transpose so each error array is passed to
        # errorbar with one column per plotted run
        ax.errorbar(
            xdata,
            ydata,
            color=colors,
            xerr=np.array(xerr_rows).T,
            yerr=np.array(yerr_rows).T,
            linestyle=" ",
        )
    else:
        ax.scatter(xdata, ydata, color=colors)
    ax.set_xlabel(latex_labels[x_name], fontsize=16)
    ax.set_ylabel(latex_labels[y_name], fontsize=16)
    fig.tight_layout()
    return fig