Skip to content

Commit

Permalink
Merge pull request #489 from qiboteam/fix_classification
Browse files Browse the repository at this point in the history
Fix classification after #447
  • Loading branch information
andrea-pasquale authored Aug 28, 2023
2 parents 91b8ba1 + 7527474 commit 5fb3700
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 34 deletions.
15 changes: 5 additions & 10 deletions src/qibocal/fitting/classifier/qubit_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,12 @@ def fit(self, iq_coordinates: list, states: list):
iq_state0 = iq_coordinates[(states == 0)]
self.iq_mean0 = np.mean(iq_state0, axis=0)
self.iq_mean1 = np.mean(iq_state1, axis=0)
# translate
iq_coordinates_translated = self.translate(iq_coordinates)
iq_state1_trans = self.translate(self.iq_mean1)
self.angle = -1 * atan2(iq_state1_trans[1], iq_state1_trans[0])

vector01 = self.iq_mean1 - self.iq_mean0
self.angle = -1 * atan2(vector01[1], vector01[0])

# rotate
iq_coord_rot = self.rotate(iq_coordinates_translated)
iq_coord_rot = self.rotate(iq_coordinates)

x_values_state0 = np.sort(iq_coord_rot[(states == 0)][:, 0])
x_values_state1 = np.sort(iq_coord_rot[(states == 1)][:, 0])
Expand Down Expand Up @@ -110,15 +109,11 @@ def rotate(self, v):
rot = np.array([[c, -s], [s, c]])
return v @ rot.T

def translate(self, v):
return v - self.iq_mean0

def predict(self, inputs: npt.NDArray):
r"""Classify the `inputs`.
Returns:
List of predictions.
"""
translated = self.translate(inputs)
rotated = self.rotate(translated)
rotated = self.rotate(inputs)
return (rotated[:, 0] > self.threshold).astype(int)
50 changes: 26 additions & 24 deletions src/qibocal/protocols/characterization/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,17 +351,17 @@ def _plot(
fig = make_subplots(
rows=1,
cols=len(models_name),
horizontal_spacing=SPACING * 3 / len(models_name),
horizontal_spacing=SPACING * 3 / len(models_name) * 3,
vertical_spacing=SPACING,
subplot_titles=[run.pretty_name(model) for model in models_name],
column_width=[COLUMNWIDTH] * len(models_name),
)
fig_roc = go.Figure()
fig_roc.add_shape(
type="line", line=dict(dash="dash"), x0=0.0, x1=1.0, y0=0.0, y1=1.0
)

if len(models_name) != 1:
fig_roc = go.Figure()
fig_roc.add_shape(
type="line", line=dict(dash="dash"), x0=0.0, x1=1.0, y0=0.0, y1=1.0
)
fig_benchmarks = make_subplots(
rows=1,
cols=3,
Expand All @@ -381,20 +381,6 @@ def _plot(
predictions = np.round(
np.reshape(fit.models[qubit][i].predict(grid), q_values.shape)
).astype(np.int64)
# Evaluate the ROC curve
fpr, tpr, _ = roc_curve(y_test, y_pred)
auc_score = roc_auc_score(y_test, y_pred)
model = run.pretty_name(model)
name = f"{model} (AUC={auc_score:.2f})"
fig_roc.add_trace(
go.Scatter(
x=fpr,
y=tpr,
name=name,
mode="lines",
marker=dict(size=3, color=get_color_state0(i)),
)
)

max_x = max(grid[:, 0])
max_y = max(grid[:, 1])
Expand Down Expand Up @@ -490,6 +476,20 @@ def _plot(
col=i + 1,
)
if len(models_name) != 1:
# Evaluate the ROC curve
fpr, tpr, _ = roc_curve(y_test, y_pred)
auc_score = roc_auc_score(y_test, y_pred)
model = run.pretty_name(model)
name = f"{model} (AUC={auc_score:.2f})"
fig_roc.add_trace(
go.Scatter(
x=fpr,
y=tpr,
name=name,
mode="lines",
marker=dict(size=3, color=get_color_state0(i)),
)
)
fig_benchmarks.add_trace(
go.Scatter(
x=[model],
Expand Down Expand Up @@ -568,6 +568,13 @@ def _plot(
font=dict(size=LEGEND_FONT_SIZE),
),
)

fitting_report = "No fitting data" if fitting_report == "" else fitting_report

figures.append(fig)
if len(models_name) != 1:
figures.append(fig_roc)
figures.append(fig_benchmarks)
fig_roc.update_layout(
width=ROC_WIDTH,
height=ROC_LENGHT,
Expand All @@ -582,11 +589,6 @@ def _plot(
title_text="True Positive Rate",
range=[0, 1],
)
fitting_report = "No fitting data" if fitting_report == "" else fitting_report
figures.append(fig_roc)
figures.append(fig)
if len(models_name) != 1:
figures.append(fig_benchmarks)
return figures, fitting_report


Expand Down

0 comments on commit 5fb3700

Please sign in to comment.