diff --git a/ctaplot/ana/ana.py b/ctaplot/ana/ana.py index 5b12e3f..92eba78 100644 --- a/ctaplot/ana/ana.py +++ b/ctaplot/ana/ana.py @@ -374,7 +374,7 @@ def bias(true, reco): ------- float """ - if not len(true) == len(reco): + if len(true) != len(reco): raise ValueError("both arrays should have the same size") if len(true) == 0: return 0 @@ -517,11 +517,7 @@ def resolution_per_bin(x, y_true, y_reco, ) ) - if isinstance(res[0], u.Quantity): - res = u.Quantity(res) - else: - res = np.array(res) - + res = u.Quantity(res) if isinstance(res[0], u.Quantity) else np.array(res) return x_bins, res @@ -987,9 +983,8 @@ def logbin_mean(x_bin): """ if not isinstance(x_bin, u.Quantity): return 10 ** ((np.log10(x_bin[:-1]) + np.log10(x_bin[1:])) / 2.) - else: - unit = x_bin.unit - return (10 ** ((np.log10(x_bin[:-1].to_value(unit)) + np.log10(x_bin[1:].to_value(unit))) / 2.)) * unit + unit = x_bin.unit + return (10 ** ((np.log10(x_bin[:-1].to_value(unit)) + np.log10(x_bin[1:].to_value(unit))) / 2.)) * unit @u.quantity_input(true_x=u.m, reco_x=u.m, true_y=u.m, reco_y=u.m) @@ -1288,11 +1283,7 @@ def bias_per_bin(true, reco, x, relative_scaling_method=None, bins=10): mask = bin_index == ii b.append(relative_bias(true[mask], reco[mask], relative_scaling_method=relative_scaling_method)) - if isinstance(b[0], u.Quantity): - b = u.Quantity(b) - else: - b = np.array(b) - + b = u.Quantity(b) if isinstance(b[0], u.Quantity) else np.array(b) return x_bins, b diff --git a/ctaplot/plots/plots.py b/ctaplot/plots/plots.py index 48214c7..72fb6ad 100644 --- a/ctaplot/plots/plots.py +++ b/ctaplot/plots/plots.py @@ -315,8 +315,7 @@ def plot_multiplicity_hist(multiplicity, ax=None, outfile=None, quartils=False, xmin = multiplicity.min() xmax = multiplicity.max() - if 'label' not in kwargs: - kwargs['label'] = 'Telescope multiplicity' + kwargs.setdefault('label', 'Telescope multiplicity') n, bins, patches = ax.hist(multiplicity, bins=(xmax - xmin), range=(xmin, xmax), rwidth=0.7, align='left', **kwargs) ax.xaxis.set_major_locator(MaxNLocator(integer=True)) @@ -446,8 +445,7 @@ def plot_effective_area_cta_requirement(cta_site, ax=None, **kwargs): cta_req = ana.cta_requirement(cta_site) e_cta, ef_cta = cta_req.get_effective_area() - if not 'label' in kwargs: - kwargs['label'] = "CTA requirement {}".format(cta_site) + kwargs.setdefault('label', "CTA requirement {}".format(cta_site)) with quantity_support(): ax.plot(e_cta, ef_cta, **kwargs) @@ -514,8 +512,7 @@ def plot_sensitivity_cta_requirement(cta_site, ax=None, **kwargs): cta_req = ana.cta_requirement(cta_site) e_cta, ef_cta = cta_req.get_sensitivity() - if not 'label' in kwargs: - kwargs['label'] = "CTA requirement {}".format(cta_site) + kwargs.setdefault('label', "CTA requirement {}".format(cta_site)) with quantity_support(): ax.plot(e_cta, ef_cta, **kwargs) @@ -550,8 +547,7 @@ def plot_sensitivity_cta_performance(cta_site, ax=None, **kwargs): e_cta, ef_cta = cta_perf.get_sensitivity() e_bin = cta_perf.energy_bins - if not 'label' in kwargs: - kwargs['label'] = "CTA performance {}".format(cta_site) + kwargs.setdefault('label', "CTA performance {}".format(cta_site)) with quantity_support(): ax.errorbar(e_cta, ef_cta, xerr=u.Quantity([e_cta - e_bin[:-1], e_bin[1:] - e_cta]), **kwargs) @@ -722,8 +718,7 @@ def plot_angular_resolution_cta_requirement(cta_site, ax=None, **kwargs): cta_req = ana.cta_requirement(cta_site) e_cta, ar_cta = cta_req.get_angular_resolution() - if not 'label' in kwargs: - kwargs['label'] = "CTA requirement {}".format(cta_site) + kwargs.setdefault('label', "CTA requirement {}".format(cta_site)) with quantity_support(): ax.plot(e_cta, ar_cta, **kwargs) @@ -758,8 +753,7 @@ def plot_angular_resolution_cta_performance(cta_site, ax=None, **kwargs): cta_req = ana.cta_performance(cta_site) e_cta, ar_cta = cta_req.get_angular_resolution() - if not 'label' in kwargs: - kwargs['label'] = "CTA performance {}".format(cta_site) + kwargs.setdefault('label', "CTA performance {}".format(cta_site)) ax.plot(e_cta, ar_cta, **kwargs) ax.set_xscale('log') @@ -804,8 +798,8 @@ def plot_impact_parameter_resolution_per_energy(true_x, reco_x, true_y, reco_y, def plot_impact_map(impact_x, impact_y, tel_x, tel_y, tel_types=None, ax=None, outfile=None, - hist_kwargs={}, - scatter_kwargs={}, + hist_kwargs=None, + scatter_kwargs=None, ): """ Map of the site with telescopes positions and impact points heatmap @@ -824,13 +818,16 @@ def plot_impact_map(impact_x, impact_y, tel_x, tel_y, tel_types=None, """ ax = plt.gca() if ax is None else ax + hist_kwargs = {} if hist_kwargs is None else hist_kwargs + scatter_kwargs = {} if scatter_kwargs is None else scatter_kwargs + hist_kwargs.setdefault('bins', 40) unit = impact_x.value ax.hist2d(impact_x.to_value(unit), impact_y.to_value(unit), **hist_kwargs) pcm = ax.get_children()[0] plt.colorbar(pcm, ax=ax) - if not len(tel_x) == len(tel_y): + if len(tel_x) != len(tel_y): raise ValueError("tel_x and tel_y should have the same length") scatter_kwargs.setdefault('s', 50) @@ -838,15 +835,12 @@ def plot_impact_map(impact_x, impact_y, tel_x, tel_y, tel_types=None, if tel_types and 'color' not in scatter_kwargs and 'c' not in scatter_kwargs: scatter_kwargs['color'] = tel_types assert (len(tel_types) == len(tel_x)), "tel_types and tel_x should have the same length" - with quantity_support(): - ax.scatter(tel_x, tel_y, **scatter_kwargs) else: if 'color' not in scatter_kwargs and 'c' not in scatter_kwargs: scatter_kwargs['color'] = 'black' scatter_kwargs['marker'] = '+' if 'marker' not in scatter_kwargs else scatter_kwargs['marker'] - with quantity_support(): - ax.scatter(tel_x, tel_y, **scatter_kwargs) - + with quantity_support(): + ax.scatter(tel_x, tel_y, **scatter_kwargs) ax.axis('equal') if outfile is not None: plt.savefig(outfile, bbox_inches="tight", format='png', dpi=200) @@ -871,7 +865,7 @@ def plot_energy_bias(true_energy, reco_energy, ax=None, bins=None, **kwargs): ------- ax: `matplotlib.pyplot.axes` """ - if not len(true_energy) == len(reco_energy): + if len(true_energy) != len(reco_energy): raise ValueError("simulated and reconstructured true_energy arrrays should have the same length") ax = plt.gca() if ax is None else ax @@ -971,8 +965,7 @@ def plot_energy_resolution_cta_requirement(cta_site, ax=None, **kwargs): cta_req = ana.cta_requirement(cta_site) e_cta, ar_cta = cta_req.get_energy_resolution() - if not 'label' in kwargs: - kwargs['label'] = "CTA requirement {}".format(cta_site) + kwargs.setdefault('label', "CTA requirement {}".format(cta_site)) ax.set_ylabel(r"$(\Delta energy/energy)_{68}$") ax.set_xlabel(rf'$E_R$ [{e_cta.unit.to_string("latex")}]') @@ -1005,8 +998,7 @@ def plot_energy_resolution_cta_performance(cta_site, ax=None, **kwargs): cta_req = ana.cta_performance(cta_site) e_cta, ar_cta = cta_req.get_energy_resolution() - if not 'label' in kwargs: - kwargs['label'] = "CTA performance {}".format(cta_site) + kwargs.setdefault('label', "CTA performance {}".format(cta_site)) ax.set_ylabel(r"$(\Delta energy/energy)_{68}$") ax.set_xlabel(rf'$E_R$ [{e_cta.unit.to_string("latex")}]') @@ -1127,10 +1119,9 @@ def plot_migration_matrix(x, y, ax=None, colorbar=False, xy_line=False, hist2d_a >>> plot_migration_matrix(x, y, colorbar=True, hist2d_args=dict(norm=matplotlib.colors.LogNorm())) In this example, the colorbar will be log normed """ - if hist2d_args is None: - hist2d_args = {} - if line_args is None: - line_args = {} + + hist2d_args = {} if hist2d_args is None else hist2d_args + line_args = {} if line_args is None else line_args if 'bins_x' not in hist2d_args: hist2d_args['bins'] = 50 @@ -1559,8 +1550,7 @@ def plot_roc_curve(true_type, reco_proba, if auc_score < 0.5: auc_score = 1 - auc_score - if 'label' not in kwargs: - kwargs['label'] = "auc score = {:.3f}".format(auc_score) + kwargs.setdefault('label', "auc score = {:.3f}".format(auc_score)) fpr, tpr, thresholds = metrics.roc_curve(true_type, reco_proba, @@ -1813,8 +1803,7 @@ def plot_any_resource(filename, columns_xy=None, ax=None, **kwargs): data = load_any_resource(filename) - if 'label' not in kwargs: - kwargs['label'] = filename + kwargs.setdefault('label', filename) ax.plot(data[columns_xy[0]], data[columns_xy[1]], **kwargs) return ax diff --git a/ctaplot/plots/style.py b/ctaplot/plots/style.py index 9b7cce4..8dfc56c 100644 --- a/ctaplot/plots/style.py +++ b/ctaplot/plots/style.py @@ -12,7 +12,7 @@ def check_latex(): ------- bool: True if a LaTeX distribution could be found """ - return not find_executable('latex') is None + return find_executable('latex') is not None @contextmanager