Skip to content

Commit

Permalink
refactor: ready for user test
Browse files Browse the repository at this point in the history
  • Loading branch information
EdenWuyifan committed Sep 5, 2024
1 parent 35640fe commit fb29b98
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 44 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.11.6 AS bdi-jupyter
FROM --platform=linux/amd64 python:3.11.6 AS bdi-jupyter

# Install JupyterHub and dependencies
RUN pip3 --disable-pip-version-check install --no-cache-dir \
Expand Down
170 changes: 127 additions & 43 deletions bdikit/visualization/schema_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(
# Panel configurations
self.panel_floatpanel_config = {"headerControls": {"close": "remove"}}
self.ai_assistant_status = "minimized"
self.log_status = "minimized"

def _generate_top_k_matches(self) -> List[Dict]:
if isinstance(self.target, pd.DataFrame):
Expand Down Expand Up @@ -460,7 +461,9 @@ def _generate_all_value_matches(self):
if pd.api.types.is_numeric_dtype(self.source[source_column]):
continue

source_values = list(self.source[source_column].dropna().unique())[:20]
source_values = list(
self.source[source_column].dropna().unique().astype(str)
)[:20]

elif self.additional_sources and source_df in self.additional_sources:
if pd.api.types.is_numeric_dtype(
Expand All @@ -469,7 +472,10 @@ def _generate_all_value_matches(self):
continue

source_values = list(
self.additional_sources[source_df][source_column].dropna().unique()
self.additional_sources[source_df][source_column]
.dropna()
.unique()
.astype(str)
)[:20]

else:
Expand All @@ -481,7 +487,6 @@ def _generate_all_value_matches(self):

for _, row in self.candidates_dfs[source_column].iterrows():
target_values = row["Values (sample)"].split(", ")

value_matcher.match(source_values, target_values)
match_results = value_matcher.get_matches()

Expand Down Expand Up @@ -517,7 +522,6 @@ def _accept_match(self) -> None:
self._record_log("accept", candidate_name, top_k_name)

self._write_json(recommendations)
self._get_heatmap()
return

def _reject_match(self) -> None:
Expand Down Expand Up @@ -563,35 +567,68 @@ def _discard_column(self, select_column: Optional[str]) -> None:
self._write_json(recommendations)
self._record_user_action("discard", d)
self._record_log("discard", candidate_name, "")
self._get_heatmap()
return

def _plot_heatmap_base(self, heatmap_rec_list: pd.DataFrame) -> pn.pane.Vega:
single = alt.selection_point(name="single")

tooltip = [
alt.Tooltip("Column", title="Column"),
alt.Tooltip("Recommendation", title="Recommendation"),
alt.Tooltip("Value", title="Similarity"),
alt.Tooltip("Column", title="Source Column"),
alt.Tooltip("Recommendation", title="Matching Candidate"),
alt.Tooltip("Value", title="Similarity Score", format=".4f"),
]
# facet = alt.Facet(alt.Undefined)

if self.additional_sources:
source_transformation = alt.datum["DataFrame"] + ">" + alt.datum["Column"]
else:
source_transformation = alt.datum["Column"]

size_expr = alt.expr(f"datam.value == {single.name} ? 20 : 10")
weight_expr = alt.expr(f"datam.value == {single.name} ? 800 : 300")

search_input = alt.param(
value="",
bind=alt.binding(
input="search",
placeholder="Candidate search",
name="Search ",
),
)

base = (
alt.Chart(heatmap_rec_list)
.transform_calculate(
Column=alt.datum["DataFrame"] + ">" + alt.datum["Column"]
)
.transform_calculate(Column=source_transformation)
.encode(
y=alt.Y("Column:O", sort=None),
x=alt.X("Recommendation:O", sort=None).axis(labelAngle=-45),
y=alt.Y("Column:O", sort=None).axis(
labelFontSize=12,
titleFontSize=14,
title="Source Columns",
),
x=alt.X(
"Recommendation:O",
sort=None,
).axis(
labelAngle=-45,
labelFontSize=12,
titleFontSize=14,
title="Target Schemas",
),
color=alt.condition(
single,
alt.Color("Value:Q").scale(domainMax=1, domainMin=0),
alt.value("lightgray"),
), # type: ignore
opacity=alt.condition(
alt.expr.test(
alt.expr.regexp(search_input, "i"), alt.datum.Recommendation
),
alt.value(1),
alt.value(0.5),
), # type: ignore
tooltip=tooltip,
)
.add_params(single)
.add_params(single, search_input)
)
background = base.mark_rect(size=100)

Expand All @@ -602,9 +639,7 @@ def _plot_heatmap_base(self, heatmap_rec_list: pd.DataFrame) -> pn.pane.Vega:
]
box_source_base = (
alt.Chart(y_source)
.transform_calculate(
Column=alt.datum["DataFrame"] + ">" + alt.datum["Column"]
)
.transform_calculate(Column=source_transformation)
.encode(
text=alt.condition(
alt.datum["DataFrame"] == self.source_prefix,
Expand All @@ -627,9 +662,7 @@ def _plot_heatmap_base(self, heatmap_rec_list: pd.DataFrame) -> pn.pane.Vega:
y_source = heatmap_rec_list[heatmap_rec_list["DataFrame"] == name]
box_source_base = (
alt.Chart(y_source)
.transform_calculate(
Column=alt.datum["DataFrame"] + ">" + alt.datum["Column"]
)
.transform_calculate(Column=source_transformation)
.encode(
text="DataFrame:O",
y="Column:O",
Expand All @@ -642,9 +675,9 @@ def _plot_heatmap_base(self, heatmap_rec_list: pd.DataFrame) -> pn.pane.Vega:
box_sources.append(box_source)
box_sources.append(box_source_text)

# rule1 = background.mark_rect(color="", stroke="orange", strokeWidth=2).transform_filter(alt.FieldEqualPredicate(field='DataFrame', equal='source'))
# rule2 = background.mark_rect(color="", stroke="yellow", strokeWidth=2).transform_filter(alt.FieldOneOfPredicate(field='DataFrame', oneOf=list(self.additional_sources.keys())))
return pn.pane.Vega(alt.layer(background, *box_sources))
return pn.pane.Vega(alt.layer(background, *box_sources))
else:
return pn.pane.Vega(background)

def _update_column_selection(
self, heatmap_rec_list: pd.DataFrame, selection: List[int]
Expand Down Expand Up @@ -828,7 +861,7 @@ def _plot_target_histogram(

def _plot_value_comparisons(
self, source_column: str, heatmap_rec_list: pd.DataFrame, selection: List[int]
) -> "pn.widgets.Tabulator | pn.pane.Markdown":
) -> "pn.Column | pn.pane.Markdown":
if not selection:
column = source_column
rec = None
Expand All @@ -851,7 +884,7 @@ def _plot_value_comparisons(
if rec:
frozen_columns.append(rec)

return pn.widgets.Tabulator(
tabulator = pn.widgets.Tabulator(
pd.DataFrame(
dict([(k, pd.Series(v)) for k, v in value_comparisons.items()])
).fillna(""),
Expand All @@ -861,6 +894,23 @@ def _plot_value_comparisons(
height=200,
)

value_filter = pn.widgets.TextInput(name="Value filter", value="")

def _filter_values(df: pd.DataFrame, pattern: str):
if not pattern or pattern == "":
return df
col_list = list(df.columns[:1])
for col in df.columns[1:]:
for value in df[col].values:
if pattern.lower() in str(value).lower():
col_list.append(col)
continue
print(col_list)
return df[col_list]

tabulator.add_filter(pn.bind(_filter_values, pattern=value_filter))
return pn.Column(value_filter, tabulator)

def _plot_pane(
self,
select_column: Optional[str] = None,
Expand All @@ -873,6 +923,7 @@ def _plot_pane(
discard_click: int = 0,
undo_click: int = 0,
redo_click: int = 0,
log_click: int = 0,
) -> pn.Column:
if self.rec_list_df is None:
raise ValueError("Heatmap rec_list_df not generated.")
Expand All @@ -889,6 +940,12 @@ def _plot_pane(
heatmap_rec_list["Column"].isin(clustered_cols)
]

sort_order = {
k: v for k, v in zip(clustered_cols, range(len(clustered_cols)))
}
sorted_indices = (heatmap_rec_list["Column"].map(lambda x: sort_order[x]) + (1 - heatmap_rec_list["Value"])).sort_values().index # type: ignore
heatmap_rec_list = heatmap_rec_list.loc[sorted_indices, :]

candidates_df = self.candidates_dfs[select_column]

def _filter_datatype(heatmap_rec: pd.Series) -> bool:
Expand Down Expand Up @@ -954,6 +1011,7 @@ def _filter_datatype(heatmap_rec: pd.Series) -> bool:
align="end",
theme="secondary",
config=self.panel_floatpanel_config,
status=self.log_status,
),
pn.Row(
heatmap_pane,
Expand Down Expand Up @@ -1062,22 +1120,27 @@ def callback(contents: str, user: Any, instance: Any) -> Optional[str]:

def plot_heatmap(self) -> pn.Column:
select_column = pn.widgets.Select(
name="Column",
name="Source Column",
options=list(self.source.columns),
width=120,
)

select_candidate_type = pn.widgets.Select(
name="Candidate type",
name="Candidate Type",
options=["All", "enum", "number", "string", "boolean"],
width=120,
)

n_similar_slider = pn.widgets.IntSlider(
name="N Similar", start=0, end=5, value=0, width=100
name="Similar Sources", start=0, end=5, value=0, width=150
)
thresh_slider = pn.widgets.FloatSlider(
name="Threshold", start=0, end=1.0, step=0.01, value=0.1, width=100
name="Candidate Threshold",
start=0,
end=1.0,
step=0.01,
value=0.1,
width=150,
)

acc_button = pn.widgets.Button(name="Accept Match", button_type="success")
Expand All @@ -1093,14 +1156,13 @@ def plot_heatmap(self) -> pn.Column:
name="Redo", button_style="outline", button_type="primary"
)

if self.ai_assistant_status == "minimized":
ai_assistant_button = pn.widgets.Button(
name="Show AI Assistant", button_type="primary"
)
else:
ai_assistant_button = pn.widgets.Button(
name="Hide AI Assistant", button_type="primary"
)
ai_assistant_button = pn.widgets.Button(
name="Show/Hide AI Assistant", button_type="primary"
)

log_button = pn.widgets.Button(
name="Show/Hide Operation Log", button_type="primary"
)

# Subschemas
if not isinstance(self.target, pd.DataFrame) and self.target == "gdc":
Expand All @@ -1113,12 +1175,26 @@ def plot_heatmap(self) -> pn.Column:

def on_click_accept_match(event: Any) -> None:
    """Handler for the "Accept Match" button.

    Accepts the currently selected candidate, auto-advances the column
    selector to the next source column (only when a column and a
    candidate row are selected and the similar-sources slider is 0),
    then refreshes the heatmap.
    """
    self._accept_match()
    can_advance = (
        select_column.value
        and self.selected_row is not None
        and n_similar_slider.value == 0
    )
    if can_advance:
        options = select_column.options
        next_idx = options.index(select_column.value) + 1
        # Stop advancing at the last column.
        if next_idx <= len(options) - 1:
            select_column.value = options[next_idx]
    self._get_heatmap()

def on_click_reject_match(event: Any) -> None:
    """Handler for the "Reject Match" button: reject the selected candidate."""
    self._reject_match()

def on_click_discard_column(event: Any) -> None:
    """Handler for the "Discard Column" button.

    Discards the currently selected source column, auto-advances the
    column selector to the next option (only when a column is selected
    and the similar-sources slider is 0), then refreshes the heatmap.
    """
    self._discard_column(select_column.value)
    if select_column.value and n_similar_slider.value == 0:
        options = select_column.options
        next_idx = options.index(select_column.value) + 1
        # Stop advancing at the last column.
        if next_idx <= len(options) - 1:
            select_column.value = options[next_idx]
    self._get_heatmap()

def on_click_undo(event: Any) -> None:
    """Handler for the "Undo" button: revert the most recent user action."""
    self._undo_user_action()
Expand All @@ -1132,12 +1208,19 @@ def on_click_ai_assistant(event: Any) -> None:
else:
self.ai_assistant_status = "minimized"

def on_click_log(event: Any) -> None:
    """Handler for the log button: toggle the floating operation-log panel.

    Flips ``self.log_status`` between "minimized" and "normalized", the
    two states consumed by the FloatPanel ``status`` option.
    """
    is_minimized = self.log_status == "minimized"
    self.log_status = "normalized" if is_minimized else "minimized"

acc_button.on_click(on_click_accept_match)
rej_button.on_click(on_click_reject_match)
discard_button.on_click(on_click_discard_column)
undo_button.on_click(on_click_undo)
redo_button.on_click(on_click_redo)
ai_assistant_button.on_click(on_click_ai_assistant)
log_button.on_click(on_click_log)

heatmap_bind = pn.bind(
self._plot_pane,
Expand All @@ -1155,20 +1238,21 @@ def on_click_ai_assistant(event: Any) -> None:
discard_button.param.clicks,
undo_button.param.clicks,
redo_button.param.clicks,
log_button.param.clicks,
)

buttons_down = pn.Column(acc_button, rej_button, discard_button)
buttons_redo_undo = pn.Column(undo_button, redo_button)
buttons_floatpanel = pn.Column(ai_assistant_button)
buttons_floatpanel = pn.Column(ai_assistant_button, log_button)

column_top = pn.Row(
select_column,
select_candidate_type,
(
subschema_col # type: ignore
if (not isinstance(self.target, pd.DataFrame) and self.target == "gdc")
else None
),
# (
# subschema_col # type: ignore
# if (not isinstance(self.target, pd.DataFrame) and self.target == "gdc")
# else None
# ),
n_similar_slider,
thresh_slider,
buttons_down,
Expand Down

0 comments on commit fb29b98

Please sign in to comment.