diff --git a/Dockerfile b/Dockerfile index 69f44c4..41e88b3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.11.6 AS bdi-jupyter +FROM --platform=linux/amd64 python:3.11.6 AS bdi-jupyter # Install JupyterHub and dependencies RUN pip3 --disable-pip-version-check install --no-cache-dir \ diff --git a/bdikit/visualization/schema_matching.py b/bdikit/visualization/schema_matching.py index 3a105f7..65d72b7 100644 --- a/bdikit/visualization/schema_matching.py +++ b/bdikit/visualization/schema_matching.py @@ -150,6 +150,7 @@ def __init__( # Panel configurations self.panel_floatpanel_config = {"headerControls": {"close": "remove"}} self.ai_assistant_status = "minimized" + self.log_status = "minimized" def _generate_top_k_matches(self) -> List[Dict]: if isinstance(self.target, pd.DataFrame): @@ -460,7 +461,9 @@ def _generate_all_value_matches(self): if pd.api.types.is_numeric_dtype(self.source[source_column]): continue - source_values = list(self.source[source_column].dropna().unique())[:20] + source_values = list( + self.source[source_column].dropna().unique().astype(str) + )[:20] elif self.additional_sources and source_df in self.additional_sources: if pd.api.types.is_numeric_dtype( @@ -469,7 +472,10 @@ def _generate_all_value_matches(self): continue source_values = list( - self.additional_sources[source_df][source_column].dropna().unique() + self.additional_sources[source_df][source_column] + .dropna() + .unique() + .astype(str) )[:20] else: @@ -481,7 +487,6 @@ def _generate_all_value_matches(self): for _, row in self.candidates_dfs[source_column].iterrows(): target_values = row["Values (sample)"].split(", ") - value_matcher.match(source_values, target_values) match_results = value_matcher.get_matches() @@ -517,7 +522,6 @@ def _accept_match(self) -> None: self._record_log("accept", candidate_name, top_k_name) self._write_json(recommendations) - self._get_heatmap() return def _reject_match(self) -> None: @@ -563,35 +567,68 @@ def _discard_column(self, select_column: Optional[str]) -> None: self._write_json(recommendations) self._record_user_action("discard", d) self._record_log("discard", candidate_name, "") - self._get_heatmap() return def _plot_heatmap_base(self, heatmap_rec_list: pd.DataFrame) -> pn.pane.Vega: single = alt.selection_point(name="single") tooltip = [ - alt.Tooltip("Column", title="Column"), - alt.Tooltip("Recommendation", title="Recommendation"), - alt.Tooltip("Value", title="Similarity"), + alt.Tooltip("Column", title="Source Column"), + alt.Tooltip("Recommendation", title="Matching Candidate"), + alt.Tooltip("Value", title="Similarity Score", format=".4f"), ] # facet = alt.Facet(alt.Undefined) + if self.additional_sources: + source_transformation = alt.datum["DataFrame"] + ">" + alt.datum["Column"] + else: + source_transformation = alt.datum["Column"] + + size_expr = alt.expr(f"datam.value == {single.name} ? 20 : 10") + weight_expr = alt.expr(f"datam.value == {single.name} ? 800 : 300") + + search_input = alt.param( + value="", + bind=alt.binding( + input="search", + placeholder="Candidate search", + name="Search ", + ), + ) + base = ( alt.Chart(heatmap_rec_list) - .transform_calculate( - Column=alt.datum["DataFrame"] + ">" + alt.datum["Column"] - ) + .transform_calculate(Column=source_transformation) .encode( - y=alt.Y("Column:O", sort=None), - x=alt.X("Recommendation:O", sort=None).axis(labelAngle=-45), + y=alt.Y("Column:O", sort=None).axis( + labelFontSize=12, + titleFontSize=14, + title="Source Columns", + ), + x=alt.X( + "Recommendation:O", + sort=None, + ).axis( + labelAngle=-45, + labelFontSize=12, + titleFontSize=14, + title="Target Schemas", + ), color=alt.condition( single, alt.Color("Value:Q").scale(domainMax=1, domainMin=0), alt.value("lightgray"), ), # type: ignore + opacity=alt.condition( + alt.expr.test( + alt.expr.regexp(search_input, "i"), alt.datum.Recommendation + ), + alt.value(1), + alt.value(0.5), + ), # type: ignore tooltip=tooltip, ) - .add_params(single) + .add_params(single, search_input) ) background = base.mark_rect(size=100) @@ -602,9 +639,7 @@ def _plot_heatmap_base(self, heatmap_rec_list: pd.DataFrame) -> pn.pane.Vega: ] box_source_base = ( alt.Chart(y_source) - .transform_calculate( - Column=alt.datum["DataFrame"] + ">" + alt.datum["Column"] - ) + .transform_calculate(Column=source_transformation) .encode( text=alt.condition( alt.datum["DataFrame"] == self.source_prefix, @@ -627,9 +662,7 @@ def _plot_heatmap_base(self, heatmap_rec_list: pd.DataFrame) -> pn.pane.Vega: y_source = heatmap_rec_list[heatmap_rec_list["DataFrame"] == name] box_source_base = ( alt.Chart(y_source) - .transform_calculate( - Column=alt.datum["DataFrame"] + ">" + alt.datum["Column"] - ) + .transform_calculate(Column=source_transformation) .encode( text="DataFrame:O", y="Column:O", @@ -642,9 +675,9 @@ def _plot_heatmap_base(self, heatmap_rec_list: pd.DataFrame) -> pn.pane.Vega: box_sources.append(box_source) box_sources.append(box_source_text) - # rule1 = background.mark_rect(color="", stroke="orange", strokeWidth=2).transform_filter(alt.FieldEqualPredicate(field='DataFrame', equal='source')) - # rule2 = background.mark_rect(color="", stroke="yellow", strokeWidth=2).transform_filter(alt.FieldOneOfPredicate(field='DataFrame', oneOf=list(self.additional_sources.keys()))) - return pn.pane.Vega(alt.layer(background, *box_sources)) + return pn.pane.Vega(alt.layer(background, *box_sources)) + else: + return pn.pane.Vega(background) def _update_column_selection( self, heatmap_rec_list: pd.DataFrame, selection: List[int] @@ -828,7 +861,7 @@ def _plot_target_histogram( def _plot_value_comparisons( self, source_column: str, heatmap_rec_list: pd.DataFrame, selection: List[int] - ) -> "pn.widgets.Tabulator | pn.pane.Markdown": + ) -> "pn.Column | pn.pane.Markdown": if not selection: column = source_column rec = None @@ -851,7 +884,7 @@ def _plot_value_comparisons( if rec: frozen_columns.append(rec) - return pn.widgets.Tabulator( + tabulator = pn.widgets.Tabulator( pd.DataFrame( dict([(k, pd.Series(v)) for k, v in value_comparisons.items()]) ).fillna(""), @@ -861,6 +894,23 @@ def _plot_value_comparisons( height=200, ) + value_filter = pn.widgets.TextInput(name="Value filter", value="") + + def _filter_values(df: pd.DataFrame, pattern: str): + if not pattern or pattern == "": + return df + col_list = list(df.columns[:1]) + for col in df.columns[1:]: + for value in df[col].values: + if pattern.lower() in str(value).lower(): + col_list.append(col) + continue + print(col_list) + return df[col_list] + + tabulator.add_filter(pn.bind(_filter_values, pattern=value_filter)) + return pn.Column(value_filter, tabulator) + def _plot_pane( self, select_column: Optional[str] = None, @@ -873,6 +923,7 @@ def _plot_pane( discard_click: int = 0, undo_click: int = 0, redo_click: int = 0, + log_click: int = 0, ) -> pn.Column: if self.rec_list_df is None: raise ValueError("Heatmap rec_list_df not generated.") @@ -889,6 +940,12 @@ def _plot_pane( heatmap_rec_list["Column"].isin(clustered_cols) ] + sort_order = { + k: v for k, v in zip(clustered_cols, range(len(clustered_cols))) + } + sorted_indices = (heatmap_rec_list["Column"].map(lambda x: sort_order[x]) + (1 - heatmap_rec_list["Value"])).sort_values().index # type: ignore + heatmap_rec_list = heatmap_rec_list.loc[sorted_indices, :] + candidates_df = self.candidates_dfs[select_column] def _filter_datatype(heatmap_rec: pd.Series) -> bool: @@ -954,6 +1011,7 @@ def _filter_datatype(heatmap_rec: pd.Series) -> bool: align="end", theme="secondary", config=self.panel_floatpanel_config, + status=self.log_status, ), pn.Row( heatmap_pane, @@ -1062,22 +1120,27 @@ def callback(contents: str, user: Any, instance: Any) -> Optional[str]: def plot_heatmap(self) -> pn.Column: select_column = pn.widgets.Select( - name="Column", + name="Source Column", options=list(self.source.columns), width=120, ) select_candidate_type = pn.widgets.Select( - name="Candidate type", + name="Candidate Type", options=["All", "enum", "number", "string", "boolean"], width=120, ) n_similar_slider = pn.widgets.IntSlider( - name="N Similar", start=0, end=5, value=0, width=100 + name="Similar Sources", start=0, end=5, value=0, width=150 ) thresh_slider = pn.widgets.FloatSlider( - name="Threshold", start=0, end=1.0, step=0.01, value=0.1, width=100 + name="Candidate Threshold", + start=0, + end=1.0, + step=0.01, + value=0.1, + width=150, ) acc_button = pn.widgets.Button(name="Accept Match", button_type="success") @@ -1093,14 +1156,13 @@ def plot_heatmap(self) -> pn.Column: name="Redo", button_style="outline", button_type="primary" ) - if self.ai_assistant_status == "minimized": - ai_assistant_button = pn.widgets.Button( - name="Show AI Assistant", button_type="primary" - ) - else: - ai_assistant_button = pn.widgets.Button( - name="Hide AI Assistant", button_type="primary" - ) + ai_assistant_button = pn.widgets.Button( + name="Show/Hide AI Assistant", button_type="primary" + ) + + log_button = pn.widgets.Button( + name="Show/Hide Operation Log", button_type="primary" + ) # Subschemas if not isinstance(self.target, pd.DataFrame) and self.target == "gdc": @@ -1113,12 +1175,26 @@ def plot_heatmap(self) -> pn.Column: def on_click_accept_match(event: Any) -> None: self._accept_match() + if ( + select_column.value + and self.selected_row is not None + and n_similar_slider.value == 0 + ): + value_idx = select_column.options.index(select_column.value) + if value_idx < len(select_column.options) - 1: + select_column.value = select_column.options[value_idx + 1] + self._get_heatmap() def on_click_reject_match(event: Any) -> None: self._reject_match() def on_click_discard_column(event: Any) -> None: self._discard_column(select_column.value) + if select_column.value and n_similar_slider.value == 0: + value_idx = select_column.options.index(select_column.value) + if value_idx < len(select_column.options) - 1: + select_column.value = select_column.options[value_idx + 1] + self._get_heatmap() def on_click_undo(event: Any) -> None: self._undo_user_action() @@ -1132,12 +1208,19 @@ def on_click_ai_assistant(event: Any) -> None: else: self.ai_assistant_status = "minimized" + def on_click_log(event: Any) -> None: + if self.log_status == "minimized": + self.log_status = "normalized" + else: + self.log_status = "minimized" + acc_button.on_click(on_click_accept_match) rej_button.on_click(on_click_reject_match) discard_button.on_click(on_click_discard_column) undo_button.on_click(on_click_undo) redo_button.on_click(on_click_redo) ai_assistant_button.on_click(on_click_ai_assistant) + log_button.on_click(on_click_log) heatmap_bind = pn.bind( self._plot_pane, @@ -1155,20 +1238,21 @@ def on_click_ai_assistant(event: Any) -> None: discard_button.param.clicks, undo_button.param.clicks, redo_button.param.clicks, + log_button.param.clicks, ) buttons_down = pn.Column(acc_button, rej_button, discard_button) buttons_redo_undo = pn.Column(undo_button, redo_button) - buttons_floatpanel = pn.Column(ai_assistant_button) + buttons_floatpanel = pn.Column(ai_assistant_button, log_button) column_top = pn.Row( select_column, select_candidate_type, - ( - subschema_col # type: ignore - if (not isinstance(self.target, pd.DataFrame) and self.target == "gdc") - else None - ), + # ( + # subschema_col # type: ignore + # if (not isinstance(self.target, pd.DataFrame) and self.target == "gdc") + # else None + # ), n_similar_slider, thresh_slider, buttons_down,