Coverage for pesummary/core/plots/seaborn/violin.py: 57.3%

438 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3from seaborn.categorical import _ViolinPlotter 

4import matplotlib as mpl 

5import colorsys 

6import numpy as np 

7import math 

8import pandas as pd 

9import matplotlib.pyplot as plt 

10 

11from seaborn import utils 

12from seaborn.utils import remove_na 

13from seaborn.palettes import color_palette, husl_palette, light_palette, dark_palette 

14from scipy.stats import gaussian_kde 

15 

16__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>", "Seaborn authors"] 

17 

18 

class ViolinPlotter(_ViolinPlotter):
    """A class to extend the _ViolinPlotter class provided by Seaborn

    On top of the inherited behaviour this class stores per-group sample
    weights, fits the densities with a user supplied KDE class, allows a
    separate colour per side/hue for every violin (``multi_color``) and can
    draw 'outer' range markers and injected-value markers on each violin.

    Several helpers used below (``kde_support``, ``scale_area``,
    ``scale_width``, ``scale_count``, ``draw_box_lines``, ``draw_quartiles``,
    ``draw_stick_lines``, ``draw_points``, ``draw_single_observation``,
    ``add_legend_data``) and attributes (``dwidth``, ``hue_offsets``,
    ``plot_data``, ``plot_hues``, ``hue_names``, ``group_names``, ...) are not
    defined in this class; they are expected to come from the seaborn base
    class.
    """
    def __init__(self, x=None, y=None, hue=None, data=None, order=None, hue_order=None,
                 bw="scott", cut=2, scale="area", scale_hue=True, gridsize=100,
                 width=.8, inner="box", split=False, dodge=True, orient=None,
                 linewidth=None, color=None, palette=None, saturation=.75,
                 ax=None, outer=None, inj=None, kde=gaussian_kde, kde_kwargs={},
                 weights=None, **kwargs):
        """Validate the options and pre-compute everything needed to plot

        Parameters
        ----------
        x, y, hue, data, order, hue_order, orient:
            forwarded to `establish_variables`
        bw, cut, scale, scale_hue, gridsize:
            forwarded to `estimate_densities`
        color, palette, saturation:
            forwarded to `establish_colors`
        inner: str, optional
            interior style; must start with one of 'quart', 'box', 'stick',
            'point' or 'line'. None draws no interior
        outer: str or dict, optional
            outer style. String values must start with 'percent' or
            'injection'; dictionary keys must start with 'percent' or 'inject'
        inj: sequence, optional
            injected value for each violin, indexed by violin number
        kde: callable, optional
            class used to build the KDE, called as
            ``kde(samples, bw_method=bw, weights=weights, **kde_kwargs)``
        kde_kwargs: dict, optional
            extra keyword arguments for `kde`. The dictionary is only read,
            never modified. None is treated as an empty dictionary
        weights: optional
            sample weights, see `establish_variables`
        ax, **kwargs:
            accepted but not used by this method

        Raises
        ------
        ValueError
            if `inner` or `outer` is not recognized, or if `split` is
            requested without exactly two hue levels
        """
        self.multi_color = False
        self.kde = kde
        self.kde_kwargs = {} if kde_kwargs is None else kde_kwargs
        self.establish_variables(
            x, y, hue, data, orient, order, hue_order, weights=weights
        )
        self.establish_colors(color, palette, saturation)
        self.estimate_densities(bw, cut, scale, scale_hue, gridsize)

        self.gridsize = gridsize
        self.width = width
        self.dodge = dodge
        self.inj = inj

        if inner is not None:
            known_inner = ("quart", "box", "stick", "point", "line")
            if not inner.startswith(known_inner):
                err = "Inner style '{}' not recognized".format(inner)
                raise ValueError(err)
        self.inner = inner

        if outer is not None:
            if isinstance(outer, dict):
                for key in outer:
                    if not key.startswith(("percent", "inject")):
                        err = "Outer style '{}' not recognized".format(outer)
                        raise ValueError(err)
            else:
                if not outer.startswith(("percent", "injection")):
                    err = "Outer style '{}' not recognized".format(outer)
                    raise ValueError(err)
        self.outer = outer

        if split and self.hue_names is not None and len(self.hue_names) != 2:
            msg = "There must be exactly two hue levels to use `split`.'"
            raise ValueError(msg)
        self.split = split

        if linewidth is None:
            linewidth = mpl.rcParams["lines.linewidth"]
        self.linewidth = linewidth

    def establish_variables(self, x, y, hue, data, orient, order, hue_order,
                            weights=None, **kwargs):
        """Convert input specification into a common representation.

        The base class implementation is called first; afterwards a list of
        float weight arrays (one entry per element of ``self.plot_data``) is
        stored in ``self.weights_data``.

        Parameters
        ----------
        weights: optional
            sample weights. When None and `data` is a pandas DataFrame, a
            'weights' (or 'weight') column is used if present; otherwise every
            sample receives unit weight. The caller's DataFrame is never
            modified.

        Raises
        ------
        ValueError
            if `weights` and `data` have a different number of dimensions or
            `weights` has more than 2 dimensions
        """
        super(ViolinPlotter, self).establish_variables(
            x, y, hue, data, orient, order, hue_order, **kwargs
        )
        if weights is None:
            weights_data = []
            if isinstance(data, pd.DataFrame):
                if "weights" in data.columns:
                    column = data["weights"]
                elif "weight" in data.columns:
                    column = data["weight"]
                else:
                    # Unit weights are kept in a local Series so that the
                    # DataFrame handed in by the caller is left untouched
                    column = pd.Series(np.ones(len(data)), index=data.index)
                for group in self.plot_data:
                    if isinstance(group, pd.Series):
                        weights_data.append(column[group.index])
                    else:
                        # Groups without a pandas index cannot be matched to
                        # rows of `data`; fall back to unit weights
                        weights_data.append(np.ones_like(group))
            else:
                for group in self.plot_data:
                    weights_data.append(np.ones_like(group))
        else:
            if hasattr(weights, "shape"):
                data_shape = getattr(data, "shape", None)
                if data_shape is not None and (
                        len(data_shape) != len(weights.shape)):
                    raise ValueError("weights shape must equal data shape")
                if len(weights.shape) == 1:
                    if np.isscalar(weights[0]):
                        weights_data = [weights]
                    else:
                        weights_data = list(weights)
                elif len(weights.shape) == 2:
                    nr, nc = weights.shape
                    if nr == 1 or nc == 1:
                        weights_data = [weights.ravel()]
                    else:
                        weights_data = [weights[:, i] for i in range(nc)]
                else:
                    error = "weights can have no more than 2 dimensions"
                    raise ValueError(error)
            elif np.isscalar(weights[0]):
                weights_data = [weights]
            else:
                weights_data = weights
        self.weights_data = [np.asarray(d, float) for d in weights_data]

    def establish_colors(self, color, palette, saturation):
        """Get a list of colors for the main component of the plots.

        Sets ``self.colors`` (RGB colours) and ``self.gray`` (hex string of a
        gray whose lightness is 0.6 times the darkest selected colour).
        `saturation` is accepted but not used here.
        """
        if self.hue_names is None:
            n_colors = len(self.plot_data)
        else:
            n_colors = len(self.hue_names)

        if color is None and palette is None:
            # Determine whether the current palette will have enough values
            # If not, we'll default to the husl palette so each is distinct
            current_palette = utils.get_color_cycle()
            if n_colors <= len(current_palette):
                colors = color_palette(n_colors=n_colors)
            else:
                colors = husl_palette(n_colors, l=.7)
        elif palette is None:
            if self.hue_names:
                if self.default_palette == "light":
                    colors = light_palette(color, n_colors)
                elif self.default_palette == "dark":
                    colors = dark_palette(color, n_colors)
                else:
                    raise RuntimeError("No default palette specified")
            else:
                colors = [color] * n_colors
        else:
            colors = self.colors_from_palette(palette)
        rgb_colors = color_palette(colors)

        light_vals = [colorsys.rgb_to_hls(*c)[1] for c in rgb_colors]
        lum = min(light_vals) * .6
        gray = mpl.colors.rgb2hex((lum, lum, lum))

        # Assign object attributes
        self.colors = rgb_colors
        self.gray = gray

    def colors_from_palette(self, palette):
        """grab the colors from the chosen palette

        Parameters
        ----------
        palette: dict, list or str
            a dictionary gives one palette entry per side of the violins. It
            must either contain the keys 'left' and 'right' or one key for
            each of the (at least two) hue names. Anything else is handed
            to `_palette_or_color`

        Returns
        -------
        colors: list
            for a dictionary, the colours interleaved as
            [first_0, second_0, first_1, second_1, ...] with one pair per
            entry of ``self.plot_data``; ``self.multi_color`` is set to True

        Raises
        ------
        ValueError
            if a dictionary palette has neither the 'left'/'right' keys nor a
            key for every hue name
        """
        if not isinstance(palette, dict):
            if self.hue_names is None:
                n_colors = len(self.plot_data)
            else:
                n_colors = len(self.hue_names)
            return self._palette_or_color(palette, n_colors)

        n_colors = len(self.plot_data)
        hue_names = [] if self.hue_names is None else list(self.hue_names)
        # Pick the two entries in a fixed order (left before right, or hue
        # order) rather than relying on the insertion order of the dictionary
        if "left" in palette and "right" in palette:
            ordered = ["left", "right"]
        elif len(hue_names) >= 2 and all(j in palette for j in hue_names):
            ordered = hue_names[:2]
        else:
            raise ValueError(
                "A dictionary palette must contain the keys 'left' and "
                "'right' or one key for each hue level"
            )
        self.multi_color = True
        first, second = [
            self._palette_or_color(palette[key], n_colors) for key in ordered
        ]
        colors = []
        for num in range(n_colors):
            colors.extend([first[num], second[num]])
        return colors

    def _palette_or_color(self, palette_entry, n_colors):
        """Determine if the palette is a block color or a palette

        Parameters
        ----------
        palette_entry: list or str
            a list of colours (repeated until it holds at least `n_colors`
            entries), a string containing 'color:<name>' (a single block
            colour) or any other value understood by
            ``seaborn.palettes.color_palette``
        n_colors: int
            number of colours required

        Returns
        -------
        colors: list
            at least `n_colors` colours

        Raises
        ------
        ValueError
            if `palette_entry` is an empty list
        """
        if isinstance(palette_entry, list):
            if not palette_entry and n_colors > 0:
                raise ValueError("The list of colors must not be empty")
            # Work on a copy so the list supplied by the caller is not
            # extended in place
            extended = list(palette_entry)
            while len(extended) < n_colors:
                extended = extended + extended
            return extended

        elif "color:" in palette_entry:
            color = palette_entry.split("color:")[1]
            color = self._flatten_string(color)
            return [color] * n_colors

        else:
            return color_palette(palette_entry, n_colors)

    @staticmethod
    def _flatten_string(string):
        """Remove the leading spaces from a string"""
        return string.lstrip(" ")

    def estimate_densities(self, bw, cut, scale, scale_hue, gridsize):
        """Find the support and density for all of the data.

        Sets ``self.support`` and ``self.density``. Without hue nesting both
        are lists with one array per group; with hue nesting they are lists
        of lists indexed by ``[group][hue]``.

        Raises
        ------
        ValueError
            if `scale` is not one of 'area', 'width' or 'count'
        """
        # NOTE(review): missing values are stripped from the samples but not
        # from the matching weights, so weighted data containing NaNs may
        # reach the KDE with mismatched lengths -- confirm with callers.

        # Initialize data structures to keep track of plotting data
        if self.hue_names is None:
            support = []
            density = []
            counts = np.zeros(len(self.plot_data))
            max_density = np.zeros(len(self.plot_data))
        else:
            support = [[] for _ in self.plot_data]
            density = [[] for _ in self.plot_data]
            size = len(self.group_names), len(self.hue_names)
            counts = np.zeros(size)
            max_density = np.zeros(size)

        for i, group_data in enumerate(self.plot_data):

            # Option 1: we have a single level of grouping
            # --------------------------------------------

            if self.plot_hues is None:

                # Strip missing datapoints
                kde_data = remove_na(group_data)

                # Handle special case of no data at this level
                if kde_data.size == 0:
                    support.append(np.array([]))
                    density.append(np.array([1.]))
                    counts[i] = 0
                    max_density[i] = 0
                    continue

                # Handle special case of a single unique datapoint
                elif np.unique(kde_data).size == 1:
                    support.append(np.unique(kde_data))
                    density.append(np.array([1.]))
                    counts[i] = 1
                    max_density[i] = 0
                    continue

                # Fit the KDE and get the used bandwidth size
                kde, bw_used = self.fit_kde(
                    kde_data, bw, weights=self.weights_data[i]
                )

                # Determine the support grid and get the density over it
                support_i = self.kde_support(kde_data, bw_used, cut, gridsize)
                density_i = kde(support_i)
                # A 2d result is unpacked as (support, density), replacing
                # the grid computed above
                if np.array(density_i).ndim == 2:
                    support_i, density_i = density_i

                # Update the data structures with these results
                support.append(support_i)
                density.append(density_i)
                counts[i] = kde_data.size
                max_density[i] = density_i.max()

            # Option 2: we have nested grouping by a hue variable
            # ---------------------------------------------------

            else:
                for j, hue_level in enumerate(self.hue_names):

                    # Handle special case of no data at this category level
                    if not group_data.size:
                        support[i].append(np.array([]))
                        density[i].append(np.array([1.]))
                        counts[i, j] = 0
                        max_density[i, j] = 0
                        continue

                    # Select out the observations for this hue level
                    hue_mask = self.plot_hues[i] == hue_level

                    # Strip missing datapoints
                    kde_data = remove_na(group_data[hue_mask])

                    # Handle special case of no data at this level
                    if kde_data.size == 0:
                        support[i].append(np.array([]))
                        density[i].append(np.array([1.]))
                        counts[i, j] = 0
                        max_density[i, j] = 0
                        continue

                    # Handle special case of a single unique datapoint
                    elif np.unique(kde_data).size == 1:
                        support[i].append(np.unique(kde_data))
                        density[i].append(np.array([1.]))
                        counts[i, j] = 1
                        max_density[i, j] = 0
                        continue

                    # Fit the KDE and get the used bandwidth size
                    kde, bw_used = self.fit_kde(
                        kde_data, bw, weights=self.weights_data[i][hue_mask]
                    )
                    # Determine the support grid and get the density over it
                    support_ij = self.kde_support(kde_data, bw_used,
                                                  cut, gridsize)
                    density_ij = kde(support_ij)
                    if np.array(density_ij).ndim == 2:
                        support_ij, density_ij = density_ij

                    # Update the data structures with these results
                    support[i].append(support_ij)
                    density[i].append(density_ij)
                    counts[i, j] = kde_data.size
                    max_density[i, j] = density_ij.max()

        # Scale the height of the density curve.
        # For a violinplot the density is non-quantitative.
        # The objective here is to scale the curves relative to 1 so that
        # they can be multiplied by the width parameter during plotting.

        if scale == "area":
            self.scale_area(density, max_density, scale_hue)

        elif scale == "width":
            self.scale_width(density)

        elif scale == "count":
            self.scale_count(density, counts, scale_hue)

        else:
            raise ValueError("scale method '{}' not recognized".format(scale))

        # Set object attributes that will be used while plotting
        self.support = support
        self.density = density

    def _draw_outer_markers(self, ax, index, data, support, density,
                            split=None, weights=None):
        """Draw the optional 'outer' range and injected value for one violin

        The two markers are independent of each other: each one is drawn
        whenever its own option (``self.outer`` / ``self.inj``) is set.
        """
        if self.outer is not None:
            self.draw_external_range(
                ax, data, support, density, index, split, weights=weights
            )
        if self.inj is not None:
            self.draw_injected_line(
                ax, self.inj[index], data, support, density, index, split
            )

    def draw_violins(self, ax):
        """Draw the violins onto `ax`.

        For every violin the filled density is drawn first, followed (when
        ``self.inner`` is not None) by the interior representation and the
        optional outer range / injection markers.
        """
        fill_func = ax.fill_betweenx if self.orient == "v" else ax.fill_between
        sides = ["left", "right"]
        # Running index into self.colors, used when every violin side has its
        # own colour (multi_color)
        checkpoint = 0
        for i, group_data in enumerate(self.plot_data):

            kws = dict(edgecolor=self.gray, linewidth=self.linewidth)

            # Option 1: we have a single level of grouping
            # --------------------------------------------
            if self.plot_hues is None:

                support, density = self.support[i], self.density[i]

                # Handle special case of no observations in this bin
                if support.size == 0:
                    continue

                # Handle special case of a single observation
                elif support.size == 1:
                    # ndarray.item() replaces np.asscalar, which is no longer
                    # available in recent numpy releases
                    val = support.item()
                    d = density.item()
                    self.draw_single_observation(ax, i, val, d)
                    continue

                # Draw the violin for this group. The grid is sized from the
                # support so that it always matches the density array
                grid = np.ones(support.size) * i
                fill_func(support,
                          grid - density * self.dwidth,
                          grid + density * self.dwidth,
                          facecolor=self.colors[i],
                          **kws)

                # Draw the interior representation of the data
                if self.inner is None:
                    continue

                # Get a nan-free vector of datapoints
                violin_data = remove_na(group_data)

                # Draw box and whisker information
                if self.inner.startswith("box"):
                    self.draw_box_lines(ax, violin_data, support, density, i)

                # Draw quartile lines
                elif self.inner.startswith("quart"):
                    self.draw_quartiles(ax, violin_data, support, density, i)

                # Draw stick observations
                elif self.inner.startswith("stick"):
                    self.draw_stick_lines(ax, violin_data, support, density, i)

                # Draw point observations
                elif self.inner.startswith("point"):
                    self.draw_points(ax, violin_data, i)

                # Draw single line
                elif self.inner.startswith("line"):
                    self.draw_single_line(ax, violin_data, i)

                self._draw_outer_markers(ax, i, violin_data, support, density)

            # Option 2: we have nested grouping by a hue variable
            # ---------------------------------------------------

            else:
                offsets = self.hue_offsets
                for j, hue_level in enumerate(self.hue_names):
                    support, density = self.support[i][j], self.density[i][j]
                    kws["facecolor"] = self.colors[j]
                    if self.multi_color:
                        kws["facecolor"] = self.colors[checkpoint]
                        checkpoint += 1

                    # Add legend data, but just for one set of violins
                    if not i and not self.multi_color:
                        self.add_legend_data(ax, self.colors[j], hue_level)

                    # Handle the special case where we have no observations
                    if support.size == 0:
                        continue

                    # Handle the special case where we have one observation
                    elif support.size == 1:
                        val = support.item()
                        d = density.item()
                        if self.split:
                            d = d / 2
                        at_group = i + offsets[j]
                        self.draw_single_observation(ax, at_group, val, d)
                        continue

                    # Option 2a: we are drawing a single split violin
                    # -----------------------------------------------

                    if self.split:

                        grid = np.ones(support.size) * i
                        if j:
                            fill_func(support,
                                      grid,
                                      grid + density * self.dwidth,
                                      **kws)
                        else:
                            fill_func(support,
                                      grid - density * self.dwidth,
                                      grid,
                                      **kws)

                        # Draw the interior representation of the data
                        if self.inner is None:
                            continue

                        # Get a nan-free vector of datapoints
                        hue_mask = self.plot_hues[i] == hue_level
                        violin_data = remove_na(group_data[hue_mask])

                        # Draw quartile lines
                        if self.inner.startswith("quart"):
                            self.draw_quartiles(ax, violin_data,
                                                support, density, i,
                                                sides[j])

                        # Draw stick observations
                        elif self.inner.startswith("stick"):
                            self.draw_stick_lines(ax, violin_data,
                                                  support, density, i,
                                                  sides[j])

                        # The weights are only needed for the outer range
                        side_weights = None
                        if self.outer is not None:
                            side_weights = self.weights_data[i][hue_mask]
                        self._draw_outer_markers(
                            ax, i, violin_data, support, density,
                            split=sides[j], weights=side_weights
                        )

                        # The box and point interior plots are drawn for
                        # all data at the group level, so we just do that once
                        if not j:
                            continue

                        # Get the whole vector for this group level
                        violin_data = remove_na(group_data)

                        # Draw box and whisker information
                        if self.inner.startswith("box"):
                            self.draw_box_lines(ax, violin_data,
                                                support, density, i)

                        # Draw point observations
                        elif self.inner.startswith("point"):
                            self.draw_points(ax, violin_data, i)

                        elif self.inner.startswith("line"):
                            self.draw_single_line(ax, violin_data, i)

                    # Option 2b: we are drawing full nested violins
                    # -----------------------------------------------

                    else:
                        grid = np.ones(support.size) * (i + offsets[j])
                        fill_func(support,
                                  grid - density * self.dwidth,
                                  grid + density * self.dwidth,
                                  **kws)

                        # Draw the interior representation
                        if self.inner is None:
                            continue

                        # Get a nan-free vector of datapoints
                        hue_mask = self.plot_hues[i] == hue_level
                        violin_data = remove_na(group_data[hue_mask])

                        # Draw box and whisker information
                        if self.inner.startswith("box"):
                            self.draw_box_lines(ax, violin_data,
                                                support, density,
                                                i + offsets[j])

                        # Draw quartile lines
                        elif self.inner.startswith("quart"):
                            self.draw_quartiles(ax, violin_data,
                                                support, density,
                                                i + offsets[j])

                        # Draw stick observations
                        elif self.inner.startswith("stick"):
                            self.draw_stick_lines(ax, violin_data,
                                                  support, density,
                                                  i + offsets[j])

                        # Draw point observations
                        elif self.inner.startswith("point"):
                            self.draw_points(ax, violin_data, i + offsets[j])

    def fit_kde(self, x, bw, weights=None):
        """Estimate a KDE for a vector of data with flexible bandwidth.

        Returns
        -------
        kde:
            the object returned by ``self.kde``
        bw_used: float
            ``kde.factor`` multiplied by the unbiased standard deviation of
            `x`
        """
        kde = self.kde(x, bw_method=bw, weights=weights, **self.kde_kwargs)
        # Extract the numeric bandwidth from the KDE object
        bw_used = kde.factor

        # At this point, bw will be a numeric scale factor.
        # To get the actual bandwidth of the kernel, we multiply by the
        # unbiased standard deviation of the data, which we will use
        # elsewhere to compute the range of the support.
        bw_used = bw_used * x.std(ddof=1)

        return kde, bw_used

    def annotate_axes(self, ax):
        """Add descriptive labels to an Axes object.

        Sets the axis labels, the tick positions/labels for the groups, the
        limits of the group axis and, when hue nesting is used without
        ``multi_color``, the legend.
        """
        if self.orient == "v":
            xlabel, ylabel = self.group_label, self.value_label
        else:
            xlabel, ylabel = self.value_label, self.group_label

        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if ylabel is not None:
            ax.set_ylabel(ylabel)

        if self.orient == "v":
            ax.set_xticks(np.arange(len(self.plot_data)))
            ax.set_xticklabels(self.group_names)
        else:
            ax.set_yticks(np.arange(len(self.plot_data)))
            ax.set_yticklabels(self.group_names)

        if self.orient == "v":
            ax.xaxis.grid(False)
            ax.set_xlim(-.5, len(self.plot_data) - .5, auto=None)
        else:
            ax.yaxis.grid(False)
            ax.set_ylim(-.5, len(self.plot_data) - .5, auto=None)

        if self.hue_names is not None:
            if not self.multi_color:
                leg = ax.legend(loc="best")
                if self.hue_title is not None:
                    leg.set_title(self.hue_title)

                    # Make the legend title slightly smaller than the axis
                    # labels when the label size is numeric
                    try:
                        title_size = mpl.rcParams["axes.labelsize"] * .85
                    except TypeError:  # labelsize is something like "large"
                        title_size = mpl.rcParams["axes.labelsize"]
                    prop = mpl.font_manager.FontProperties(size=title_size)
                    # Use the public accessor for the legend title text
                    leg.get_title().set_fontproperties(prop)

    def draw_single_line(self, ax, data, center):
        """Draw a single line through the middle of the violin

        The line runs from the minimum to the maximum of `data` at position
        `center`.
        """
        upper = np.max(data)
        lower = np.min(data)

        ax.plot([center, center], [lower, upper],
                linewidth=self.linewidth,
                color=self.gray)

    def _plot_single_line(self, ax, center, y, density, split=None, color=None):
        """Plot a single line on a violin plot

        Parameters
        ----------
        center: float
            position of the middle of the violin
        y: float
            value at which the line is drawn
        density: np.ndarray
            density of the violin; the line extends 1.1 times the maximum
            scaled density either side of `center`
        split: str, optional
            'left' or 'right' to draw only that half of the line. Any other
            value draws the full line
        color: optional
            line colour. Default ``self.gray``
        """
        width = self.dwidth * np.max(density) * 1.1
        color = self.gray if color is None else color

        if split == "left":
            xs = [center - width, center]
        elif split == "right":
            xs = [center, center + width]
        else:
            xs = [center - width, center + width]
        ax.plot(xs, [y, y], linewidth=self.linewidth, color=color)

    @staticmethod
    def _range_markers(data, percent, weights=None):
        """Return the samples marking the [100 - percent, percent] range

        Returns
        -------
        h1, h2:
            the smallest sample at or above the upper bound and the largest
            sample at or below the lower bound
        """
        if weights is None:
            lower, upper = np.percentile(data, [100 - percent, percent])
        else:
            from pesummary.utils.array import Array

            _data = Array(data, weights=weights)
            lower, upper = _data.credible_interval(
                [100 - percent, percent]
            )
        h1 = np.min(data[data >= (upper)])
        h2 = np.max(data[data <= (lower)])
        return h1, h2

    def draw_external_range(self, ax, data, support, density,
                            center, split=None, weights=None):
        """Draw lines extending outside of the violin showing given range

        The behaviour depends on ``self.outer``:

        * dict with a 'percentage' key: lines are drawn at the
          [100 - p, p] percentiles, where p is the stored value
        * dict with a key containing 'inject': a red line is drawn at the
          stored value (or at ``value[center]`` when the value is a list).
          A key of the form 'injection: <side>' selects the side
        * str containing 'percent': 'percent: p' is converted into the
          symmetric interval that contains p per cent of the samples
        * str containing 'inject': 'injection: v' draws a red line at the
          number v

        When `weights` is given the bounds are computed with
        ``pesummary.utils.array.Array.credible_interval``.
        """
        if isinstance(self.outer, dict):
            # NOTE(review): unlike the string form, the percentage is used
            # as given and is not converted to a symmetric interval --
            # confirm this is intended
            if "percentage" in self.outer:
                percent = float(self.outer["percentage"])
                h1, h2 = self._range_markers(data, percent, weights=weights)
                self._plot_single_line(ax, center, h1, density, split=split)
                self._plot_single_line(ax, center, h2, density, split=split)

            inject_keys = [i for i in self.outer if "inject" in i]
            if inject_keys:
                key = inject_keys[0]
                if "injection:" in key:
                    split = self._flatten_string(key.split("injection:")[1])

                injection = self.outer[key]
                if isinstance(injection, list):
                    injection = injection[center]
                self._plot_single_line(
                    ax, center, injection, density, split=split, color="r"
                )
        elif isinstance(self.outer, str):
            if "percent" in self.outer:
                percent = self.outer.split("percent:")[1]
                percent = float(self._flatten_string(percent))
                # Convert e.g. 90 into the upper percentile (95) of the
                # symmetric interval
                percent += (100 - percent) / 2.

                h1, h2 = self._range_markers(data, percent, weights=weights)
                self._plot_single_line(ax, center, h1, density, split=split)
                self._plot_single_line(ax, center, h2, density, split=split)
            if "inject" in self.outer:
                # The text after 'injection:' is the value to mark. It must
                # be a number, so convert it rather than plotting a string
                injection = float(
                    self._flatten_string(self.outer.split("injection:")[1])
                )
                self._plot_single_line(
                    ax, center, injection, density, split=split, color="r"
                )

    def draw_injected_line(self, ax, inj, data, support, density,
                           center, split=None):
        """Mark the injected value on the violin

        Nothing is drawn when `inj` is NaN.
        """
        if math.isnan(inj):
            return
        self._plot_single_line(ax, center, inj, density, split=split, color='r')

    def plot(self, ax):
        """Make the violin plot."""
        self.draw_violins(ax)
        self.annotate_axes(ax)
        if self.orient == "h":
            ax.invert_yaxis()

724 

725 

def violinplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
               bw="scott", cut=2, scale="area", scale_hue=True, gridsize=100,
               width=.8, inner="box", split=False, dodge=True, orient=None,
               linewidth=None, color=None, palette=None, saturation=.75,
               ax=None, outer=None, inj=None, kde=gaussian_kde, kde_kwargs={},
               weights=None, **kwargs):
    """Draw a violin plot with `ViolinPlotter` and return the axes used

    Parameters
    ----------
    ax: matplotlib axes, optional
        axes to draw on. When None the current axes (``plt.gca()``) are used
    **kwargs:
        accepted but not forwarded to `ViolinPlotter`

    All remaining arguments are handed to `ViolinPlotter` unchanged.

    Returns
    -------
    ax: matplotlib axes
        the axes containing the violin plot
    """
    plotter = ViolinPlotter(
        x=x, y=y, hue=hue, data=data, order=order, hue_order=hue_order,
        bw=bw, cut=cut, scale=scale, scale_hue=scale_hue, gridsize=gridsize,
        width=width, inner=inner, split=split, dodge=dodge, orient=orient,
        linewidth=linewidth, color=color, palette=palette,
        saturation=saturation, outer=outer, inj=inj, kde=kde,
        kde_kwargs=kde_kwargs, weights=weights,
    )

    # The plotter is built before the axes are looked up
    target = plt.gca() if ax is None else ax
    plotter.plot(target)
    return target

744 

745 

def split_dataframe(
    left, right, labels, left_label="left", right_label="right",
    weights_left=None, weights_right=None
):
    """Generate a pandas DataFrame containing two sets of distributions -- one
    set for the left hand side of the violins, and one set for the right hand
    side of the violins

    Parameters
    ----------
    left: np.ndarray
        array of samples representing the left hand side of the violins
    right: np.ndarray
        array of samples representing the right hand side of the violins
    labels: np.array
        array containing the label associated with each violin
    left_label: str, optional
        value stored in the 'side' column for samples taken from `left`
    right_label: str, optional
        value stored in the 'side' column for samples taken from `right`
    weights_left: list, optional
        weights for each set of samples in `left`. If only one of
        `weights_left` and `weights_right` is given, unit weights are used
        for the other side
    weights_right: list, optional
        weights for each set of samples in `right`

    Returns
    -------
    df: pandas.DataFrame
        DataFrame with the columns 'data', 'side' and 'label' (and 'weights'
        when weights are provided). The rows are ordered violin by violin,
        with the left samples preceding the right samples

    Raises
    ------
    ValueError
        if `left`, `right` and `labels` do not all have the same length, or
        if the weights do not contain one entry per violin
    """
    import pandas

    nviolin = len(left)
    # `a != b != c` only tests neighbouring pairs and requires both to differ,
    # so check for full equality instead
    if not (len(left) == len(right) == len(labels)):
        raise ValueError("Please ensure that 'left' == 'right' == 'labels'")

    # Plain lists are used throughout so that violins with different numbers
    # of samples are supported
    samples, sides, names = [], [], []
    for left_samples, right_samples, label in zip(left, right, labels):
        n_left, n_right = len(left_samples), len(right_samples)
        samples.extend(left_samples)
        samples.extend(right_samples)
        sides.extend([left_label] * n_left)
        sides.extend([right_label] * n_right)
        names.extend([label] * (n_left + n_right))

    df = pandas.DataFrame(
        data={"data": samples, "side": sides, "label": names}
    )
    if weights_left is None and weights_right is None:
        return df

    if weights_right is None:
        weights_right = [np.ones(len(sample)) for sample in right]
    elif weights_left is None:
        weights_left = [np.ones(len(sample)) for sample in left]
    if any(len(kwarg) != nviolin for kwarg in [weights_left, weights_right]):
        raise ValueError(
            "Please provide one set of weights for each violin in "
            "'weights_left' and 'weights_right'"
        )

    # Same ordering as the samples: left then right for every violin
    weights = []
    for left_weights, right_weights in zip(weights_left, weights_right):
        weights.extend(left_weights)
        weights.extend(right_weights)
    df["weights"] = weights
    return df
804 return df