Skip to content

Commit

Permalink
Merge pull request #151 from cta-observatory/formatting
Browse files Browse the repository at this point in the history
just some formatting fixes with the help of sourcery
  • Loading branch information
vuillaut authored May 7, 2021
2 parents dd59f45 + cf17a86 commit f8a7a4f
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 48 deletions.
19 changes: 5 additions & 14 deletions ctaplot/ana/ana.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def bias(true, reco):
-------
float
"""
if not len(true) == len(reco):
if len(true) != len(reco):
raise ValueError("both arrays should have the same size")
if len(true) == 0:
return 0
Expand Down Expand Up @@ -517,11 +517,7 @@ def resolution_per_bin(x, y_true, y_reco,
)
)

if isinstance(res[0], u.Quantity):
res = u.Quantity(res)
else:
res = np.array(res)

res = u.Quantity(res) if isinstance(res[0], u.Quantity) else np.array(res)
return x_bins, res


Expand Down Expand Up @@ -987,9 +983,8 @@ def logbin_mean(x_bin):
"""
if not isinstance(x_bin, u.Quantity):
return 10 ** ((np.log10(x_bin[:-1]) + np.log10(x_bin[1:])) / 2.)
else:
unit = x_bin.unit
return (10 ** ((np.log10(x_bin[:-1].to_value(unit)) + np.log10(x_bin[1:].to_value(unit))) / 2.)) * unit
unit = x_bin.unit
return (10 ** ((np.log10(x_bin[:-1].to_value(unit)) + np.log10(x_bin[1:].to_value(unit))) / 2.)) * unit


@u.quantity_input(true_x=u.m, reco_x=u.m, true_y=u.m, reco_y=u.m)
Expand Down Expand Up @@ -1288,11 +1283,7 @@ def bias_per_bin(true, reco, x, relative_scaling_method=None, bins=10):
mask = bin_index == ii
b.append(relative_bias(true[mask], reco[mask], relative_scaling_method=relative_scaling_method))

if isinstance(b[0], u.Quantity):
b = u.Quantity(b)
else:
b = np.array(b)

b = u.Quantity(b) if isinstance(b[0], u.Quantity) else np.array(b)
return x_bins, b


Expand Down
55 changes: 22 additions & 33 deletions ctaplot/plots/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,7 @@ def plot_multiplicity_hist(multiplicity, ax=None, outfile=None, quartils=False,
xmin = multiplicity.min()
xmax = multiplicity.max()

if 'label' not in kwargs:
kwargs['label'] = 'Telescope multiplicity'
kwargs.setdefault('label', 'Telescope multiplicity')

n, bins, patches = ax.hist(multiplicity, bins=(xmax - xmin), range=(xmin, xmax), rwidth=0.7, align='left', **kwargs)
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
Expand Down Expand Up @@ -446,8 +445,7 @@ def plot_effective_area_cta_requirement(cta_site, ax=None, **kwargs):
cta_req = ana.cta_requirement(cta_site)
e_cta, ef_cta = cta_req.get_effective_area()

if not 'label' in kwargs:
kwargs['label'] = "CTA requirement {}".format(cta_site)
kwargs.setdefault('label', "CTA requirement {}".format(cta_site))

with quantity_support():
ax.plot(e_cta, ef_cta, **kwargs)
Expand Down Expand Up @@ -514,8 +512,7 @@ def plot_sensitivity_cta_requirement(cta_site, ax=None, **kwargs):
cta_req = ana.cta_requirement(cta_site)
e_cta, ef_cta = cta_req.get_sensitivity()

if not 'label' in kwargs:
kwargs['label'] = "CTA requirement {}".format(cta_site)
kwargs.setdefault('label', "CTA requirement {}".format(cta_site))

with quantity_support():
ax.plot(e_cta, ef_cta, **kwargs)
Expand Down Expand Up @@ -550,8 +547,7 @@ def plot_sensitivity_cta_performance(cta_site, ax=None, **kwargs):
e_cta, ef_cta = cta_perf.get_sensitivity()
e_bin = cta_perf.energy_bins

if not 'label' in kwargs:
kwargs['label'] = "CTA performance {}".format(cta_site)
kwargs.setdefault('label', "CTA performance {}".format(cta_site))

with quantity_support():
ax.errorbar(e_cta, ef_cta, xerr=u.Quantity([e_cta - e_bin[:-1], e_bin[1:] - e_cta]), **kwargs)
Expand Down Expand Up @@ -722,8 +718,7 @@ def plot_angular_resolution_cta_requirement(cta_site, ax=None, **kwargs):
cta_req = ana.cta_requirement(cta_site)
e_cta, ar_cta = cta_req.get_angular_resolution()

if not 'label' in kwargs:
kwargs['label'] = "CTA requirement {}".format(cta_site)
kwargs.setdefault('label', "CTA requirement {}".format(cta_site))

with quantity_support():
ax.plot(e_cta, ar_cta, **kwargs)
Expand Down Expand Up @@ -758,8 +753,7 @@ def plot_angular_resolution_cta_performance(cta_site, ax=None, **kwargs):
cta_req = ana.cta_performance(cta_site)
e_cta, ar_cta = cta_req.get_angular_resolution()

if not 'label' in kwargs:
kwargs['label'] = "CTA performance {}".format(cta_site)
kwargs.setdefault('label', "CTA performance {}".format(cta_site))

ax.plot(e_cta, ar_cta, **kwargs)
ax.set_xscale('log')
Expand Down Expand Up @@ -804,8 +798,8 @@ def plot_impact_parameter_resolution_per_energy(true_x, reco_x, true_y, reco_y,
def plot_impact_map(impact_x, impact_y, tel_x, tel_y, tel_types=None,
ax=None,
outfile=None,
hist_kwargs={},
scatter_kwargs={},
hist_kwargs=None,
scatter_kwargs=None,
):
"""
Map of the site with telescopes positions and impact points heatmap
Expand All @@ -824,29 +818,29 @@ def plot_impact_map(impact_x, impact_y, tel_x, tel_y, tel_types=None,
"""
ax = plt.gca() if ax is None else ax

hist_kwargs = {} if hist_kwargs is None else hist_kwargs
scatter_kwargs = {} if scatter_kwargs is None else scatter_kwargs

hist_kwargs.setdefault('bins', 40)
unit = impact_x.value
ax.hist2d(impact_x.to_value(unit), impact_y.to_value(unit), **hist_kwargs)
pcm = ax.get_children()[0]
plt.colorbar(pcm, ax=ax)

if not len(tel_x) == len(tel_y):
if len(tel_x) != len(tel_y):
raise ValueError("tel_x and tel_y should have the same length")

scatter_kwargs.setdefault('s', 50)

if tel_types and 'color' not in scatter_kwargs and 'c' not in scatter_kwargs:
scatter_kwargs['color'] = tel_types
assert (len(tel_types) == len(tel_x)), "tel_types and tel_x should have the same length"
with quantity_support():
ax.scatter(tel_x, tel_y, **scatter_kwargs)
else:
if 'color' not in scatter_kwargs and 'c' not in scatter_kwargs:
scatter_kwargs['color'] = 'black'
scatter_kwargs['marker'] = '+' if 'marker' not in scatter_kwargs else scatter_kwargs['marker']
with quantity_support():
ax.scatter(tel_x, tel_y, **scatter_kwargs)

with quantity_support():
ax.scatter(tel_x, tel_y, **scatter_kwargs)
ax.axis('equal')
if outfile is not None:
plt.savefig(outfile, bbox_inches="tight", format='png', dpi=200)
Expand All @@ -871,7 +865,7 @@ def plot_energy_bias(true_energy, reco_energy, ax=None, bins=None, **kwargs):
-------
ax: `matplotlib.pyplot.axes`
"""
if not len(true_energy) == len(reco_energy):
if len(true_energy) != len(reco_energy):
raise ValueError("simulated and reconstructured true_energy arrrays should have the same length")

ax = plt.gca() if ax is None else ax
Expand Down Expand Up @@ -971,8 +965,7 @@ def plot_energy_resolution_cta_requirement(cta_site, ax=None, **kwargs):
cta_req = ana.cta_requirement(cta_site)
e_cta, ar_cta = cta_req.get_energy_resolution()

if not 'label' in kwargs:
kwargs['label'] = "CTA requirement {}".format(cta_site)
kwargs.setdefault('label', "CTA requirement {}".format(cta_site))

ax.set_ylabel(r"$(\Delta energy/energy)_{68}$")
ax.set_xlabel(rf'$E_R$ [{e_cta.unit.to_string("latex")}]')
Expand Down Expand Up @@ -1005,8 +998,7 @@ def plot_energy_resolution_cta_performance(cta_site, ax=None, **kwargs):
cta_req = ana.cta_performance(cta_site)
e_cta, ar_cta = cta_req.get_energy_resolution()

if not 'label' in kwargs:
kwargs['label'] = "CTA performance {}".format(cta_site)
kwargs.setdefault('label', "CTA performance {}".format(cta_site))

ax.set_ylabel(r"$(\Delta energy/energy)_{68}$")
ax.set_xlabel(rf'$E_R$ [{e_cta.unit.to_string("latex")}]')
Expand Down Expand Up @@ -1127,10 +1119,9 @@ def plot_migration_matrix(x, y, ax=None, colorbar=False, xy_line=False, hist2d_a
>>> plot_migration_matrix(x, y, colorbar=True, hist2d_args=dict(norm=matplotlib.colors.LogNorm()))
In this example, the colorbar will be log normed
"""
if hist2d_args is None:
hist2d_args = {}
if line_args is None:
line_args = {}

hist2d_args = {} if hist2d_args is None else hist2d_args
line_args = {} if line_args is None else line_args

if 'bins_x' not in hist2d_args:
hist2d_args['bins'] = 50
Expand Down Expand Up @@ -1559,8 +1550,7 @@ def plot_roc_curve(true_type, reco_proba,
if auc_score < 0.5:
auc_score = 1 - auc_score

if 'label' not in kwargs:
kwargs['label'] = "auc score = {:.3f}".format(auc_score)
kwargs.setdefault('label', "auc score = {:.3f}".format(auc_score))

fpr, tpr, thresholds = metrics.roc_curve(true_type,
reco_proba,
Expand Down Expand Up @@ -1813,8 +1803,7 @@ def plot_any_resource(filename, columns_xy=None, ax=None, **kwargs):

data = load_any_resource(filename)

if 'label' not in kwargs:
kwargs['label'] = filename
kwargs.setdefault('label', filename)
ax.plot(data[columns_xy[0]], data[columns_xy[1]], **kwargs)

return ax
Expand Down
2 changes: 1 addition & 1 deletion ctaplot/plots/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def check_latex():
-------
bool: True if a LaTeX distribution could be found
"""
return not find_executable('latex') is None
return find_executable('latex') is not None


@contextmanager
Expand Down

0 comments on commit f8a7a4f

Please sign in to comment.