Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

just some formatting fixes with the help of sourcery #151

Merged
merged 4 commits into from
May 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions ctaplot/ana/ana.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def bias(true, reco):
-------
float
"""
if not len(true) == len(reco):
if len(true) != len(reco):
raise ValueError("both arrays should have the same size")
if len(true) == 0:
return 0
Expand Down Expand Up @@ -517,11 +517,7 @@ def resolution_per_bin(x, y_true, y_reco,
)
)

if isinstance(res[0], u.Quantity):
res = u.Quantity(res)
else:
res = np.array(res)

res = u.Quantity(res) if isinstance(res[0], u.Quantity) else np.array(res)
return x_bins, res


Expand Down Expand Up @@ -987,9 +983,8 @@ def logbin_mean(x_bin):
"""
if not isinstance(x_bin, u.Quantity):
return 10 ** ((np.log10(x_bin[:-1]) + np.log10(x_bin[1:])) / 2.)
else:
unit = x_bin.unit
return (10 ** ((np.log10(x_bin[:-1].to_value(unit)) + np.log10(x_bin[1:].to_value(unit))) / 2.)) * unit
unit = x_bin.unit
return (10 ** ((np.log10(x_bin[:-1].to_value(unit)) + np.log10(x_bin[1:].to_value(unit))) / 2.)) * unit


@u.quantity_input(true_x=u.m, reco_x=u.m, true_y=u.m, reco_y=u.m)
Expand Down Expand Up @@ -1288,11 +1283,7 @@ def bias_per_bin(true, reco, x, relative_scaling_method=None, bins=10):
mask = bin_index == ii
b.append(relative_bias(true[mask], reco[mask], relative_scaling_method=relative_scaling_method))

if isinstance(b[0], u.Quantity):
b = u.Quantity(b)
else:
b = np.array(b)

b = u.Quantity(b) if isinstance(b[0], u.Quantity) else np.array(b)
return x_bins, b


Expand Down
55 changes: 22 additions & 33 deletions ctaplot/plots/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,7 @@ def plot_multiplicity_hist(multiplicity, ax=None, outfile=None, quartils=False,
xmin = multiplicity.min()
xmax = multiplicity.max()

if 'label' not in kwargs:
kwargs['label'] = 'Telescope multiplicity'
kwargs.setdefault('label', 'Telescope multiplicity')

n, bins, patches = ax.hist(multiplicity, bins=(xmax - xmin), range=(xmin, xmax), rwidth=0.7, align='left', **kwargs)
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
Expand Down Expand Up @@ -446,8 +445,7 @@ def plot_effective_area_cta_requirement(cta_site, ax=None, **kwargs):
cta_req = ana.cta_requirement(cta_site)
e_cta, ef_cta = cta_req.get_effective_area()

if not 'label' in kwargs:
kwargs['label'] = "CTA requirement {}".format(cta_site)
kwargs.setdefault('label', "CTA requirement {}".format(cta_site))

with quantity_support():
ax.plot(e_cta, ef_cta, **kwargs)
Expand Down Expand Up @@ -514,8 +512,7 @@ def plot_sensitivity_cta_requirement(cta_site, ax=None, **kwargs):
cta_req = ana.cta_requirement(cta_site)
e_cta, ef_cta = cta_req.get_sensitivity()

if not 'label' in kwargs:
kwargs['label'] = "CTA requirement {}".format(cta_site)
kwargs.setdefault('label', "CTA requirement {}".format(cta_site))

with quantity_support():
ax.plot(e_cta, ef_cta, **kwargs)
Expand Down Expand Up @@ -550,8 +547,7 @@ def plot_sensitivity_cta_performance(cta_site, ax=None, **kwargs):
e_cta, ef_cta = cta_perf.get_sensitivity()
e_bin = cta_perf.energy_bins

if not 'label' in kwargs:
kwargs['label'] = "CTA performance {}".format(cta_site)
kwargs.setdefault('label', "CTA performance {}".format(cta_site))

with quantity_support():
ax.errorbar(e_cta, ef_cta, xerr=u.Quantity([e_cta - e_bin[:-1], e_bin[1:] - e_cta]), **kwargs)
Expand Down Expand Up @@ -722,8 +718,7 @@ def plot_angular_resolution_cta_requirement(cta_site, ax=None, **kwargs):
cta_req = ana.cta_requirement(cta_site)
e_cta, ar_cta = cta_req.get_angular_resolution()

if not 'label' in kwargs:
kwargs['label'] = "CTA requirement {}".format(cta_site)
kwargs.setdefault('label', "CTA requirement {}".format(cta_site))

with quantity_support():
ax.plot(e_cta, ar_cta, **kwargs)
Expand Down Expand Up @@ -758,8 +753,7 @@ def plot_angular_resolution_cta_performance(cta_site, ax=None, **kwargs):
cta_req = ana.cta_performance(cta_site)
e_cta, ar_cta = cta_req.get_angular_resolution()

if not 'label' in kwargs:
kwargs['label'] = "CTA performance {}".format(cta_site)
kwargs.setdefault('label', "CTA performance {}".format(cta_site))

ax.plot(e_cta, ar_cta, **kwargs)
ax.set_xscale('log')
Expand Down Expand Up @@ -804,8 +798,8 @@ def plot_impact_parameter_resolution_per_energy(true_x, reco_x, true_y, reco_y,
def plot_impact_map(impact_x, impact_y, tel_x, tel_y, tel_types=None,
ax=None,
outfile=None,
hist_kwargs={},
scatter_kwargs={},
hist_kwargs=None,
scatter_kwargs=None,
):
"""
Map of the site with telescopes positions and impact points heatmap
Expand All @@ -824,29 +818,29 @@ def plot_impact_map(impact_x, impact_y, tel_x, tel_y, tel_types=None,
"""
ax = plt.gca() if ax is None else ax

hist_kwargs = {} if hist_kwargs is None else hist_kwargs
scatter_kwargs = {} if scatter_kwargs is None else scatter_kwargs

hist_kwargs.setdefault('bins', 40)
unit = impact_x.value
ax.hist2d(impact_x.to_value(unit), impact_y.to_value(unit), **hist_kwargs)
pcm = ax.get_children()[0]
plt.colorbar(pcm, ax=ax)

if not len(tel_x) == len(tel_y):
if len(tel_x) != len(tel_y):
raise ValueError("tel_x and tel_y should have the same length")

scatter_kwargs.setdefault('s', 50)

if tel_types and 'color' not in scatter_kwargs and 'c' not in scatter_kwargs:
scatter_kwargs['color'] = tel_types
assert (len(tel_types) == len(tel_x)), "tel_types and tel_x should have the same length"
with quantity_support():
ax.scatter(tel_x, tel_y, **scatter_kwargs)
else:
if 'color' not in scatter_kwargs and 'c' not in scatter_kwargs:
scatter_kwargs['color'] = 'black'
scatter_kwargs['marker'] = '+' if 'marker' not in scatter_kwargs else scatter_kwargs['marker']
with quantity_support():
ax.scatter(tel_x, tel_y, **scatter_kwargs)

with quantity_support():
ax.scatter(tel_x, tel_y, **scatter_kwargs)
ax.axis('equal')
if outfile is not None:
plt.savefig(outfile, bbox_inches="tight", format='png', dpi=200)
Expand All @@ -871,7 +865,7 @@ def plot_energy_bias(true_energy, reco_energy, ax=None, bins=None, **kwargs):
-------
ax: `matplotlib.pyplot.axes`
"""
if not len(true_energy) == len(reco_energy):
if len(true_energy) != len(reco_energy):
raise ValueError("simulated and reconstructured true_energy arrrays should have the same length")

ax = plt.gca() if ax is None else ax
Expand Down Expand Up @@ -971,8 +965,7 @@ def plot_energy_resolution_cta_requirement(cta_site, ax=None, **kwargs):
cta_req = ana.cta_requirement(cta_site)
e_cta, ar_cta = cta_req.get_energy_resolution()

if not 'label' in kwargs:
kwargs['label'] = "CTA requirement {}".format(cta_site)
kwargs.setdefault('label', "CTA requirement {}".format(cta_site))

ax.set_ylabel(r"$(\Delta energy/energy)_{68}$")
ax.set_xlabel(rf'$E_R$ [{e_cta.unit.to_string("latex")}]')
Expand Down Expand Up @@ -1005,8 +998,7 @@ def plot_energy_resolution_cta_performance(cta_site, ax=None, **kwargs):
cta_req = ana.cta_performance(cta_site)
e_cta, ar_cta = cta_req.get_energy_resolution()

if not 'label' in kwargs:
kwargs['label'] = "CTA performance {}".format(cta_site)
kwargs.setdefault('label', "CTA performance {}".format(cta_site))

ax.set_ylabel(r"$(\Delta energy/energy)_{68}$")
ax.set_xlabel(rf'$E_R$ [{e_cta.unit.to_string("latex")}]')
Expand Down Expand Up @@ -1127,10 +1119,9 @@ def plot_migration_matrix(x, y, ax=None, colorbar=False, xy_line=False, hist2d_a
>>> plot_migration_matrix(x, y, colorbar=True, hist2d_args=dict(norm=matplotlib.colors.LogNorm()))
In this example, the colorbar will be log normed
"""
if hist2d_args is None:
hist2d_args = {}
if line_args is None:
line_args = {}

hist2d_args = {} if hist2d_args is None else hist2d_args
line_args = {} if line_args is None else line_args

if 'bins_x' not in hist2d_args:
hist2d_args['bins'] = 50
Expand Down Expand Up @@ -1559,8 +1550,7 @@ def plot_roc_curve(true_type, reco_proba,
if auc_score < 0.5:
auc_score = 1 - auc_score

if 'label' not in kwargs:
kwargs['label'] = "auc score = {:.3f}".format(auc_score)
kwargs.setdefault('label', "auc score = {:.3f}".format(auc_score))

fpr, tpr, thresholds = metrics.roc_curve(true_type,
reco_proba,
Expand Down Expand Up @@ -1813,8 +1803,7 @@ def plot_any_resource(filename, columns_xy=None, ax=None, **kwargs):

data = load_any_resource(filename)

if 'label' not in kwargs:
kwargs['label'] = filename
kwargs.setdefault('label', filename)
ax.plot(data[columns_xy[0]], data[columns_xy[1]], **kwargs)

return ax
Expand Down
2 changes: 1 addition & 1 deletion ctaplot/plots/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def check_latex():
-------
bool: True if a LaTeX distribution could be found
"""
return not find_executable('latex') is None
return find_executable('latex') is not None


@contextmanager
Expand Down